[2024-06-15] runtime patching of an nn.Module
so i was trying to patch an existing torch nn.Module to change its forward() behavior
my requirements were:
1. i want to be able to create a patched object from an existing instance of the parent module
2. i want to preserve the parent object's behavior aside from forward (in particular, the patched object should still return True for isinstance() of the parent class)
3. i don't want to create a new object, to avoid expensive copies of the parent module's parameters
4. i want to be reasonably invariant to the implementation of the parent module
5. i want to be able to tell whether an object has been patched
6. i want to add additional attributes / parameters to the module
7. i don't want my ide to complain
subclassing the parent object and overriding the init and forward methods is one option, but this runs into issues with requirements (1) and (4): i would need to write some sort of class method to initialize a patched object and transfer all of the parent object's existing attributes over, which adds a dependency on exactly what parameters the parent object has (see the sketch below). and of course doing a simple deepcopy violates (3).
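to make that concrete, here's a rough sketch of the subclassing route (hypothetical names like SubclassedEmbedding are mine, and this is not the approach i ended up using):

import torch
import torch.nn as nn

class SubclassedEmbedding(nn.Embedding):
    @classmethod
    def from_existing(cls, embedding: nn.Embedding) -> "SubclassedEmbedding":
        # the constructor arguments have to be spelled out by hand, so this
        # breaks whenever the parent's signature changes (violates (4))
        new = cls(
            embedding.num_embeddings,
            embedding.embedding_dim,
            padding_idx=embedding.padding_idx,
            max_norm=embedding.max_norm,
            norm_type=embedding.norm_type,
            scale_grad_by_freq=embedding.scale_grad_by_freq,
            sparse=embedding.sparse,
        )
        # a fresh weight tensor was just allocated, and the old values still
        # need to be copied into it (violates (3))
        with torch.no_grad():
            new.weight.copy_(embedding.weight)
        return new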
something closer to satisfying the reqs would be using types.MethodType to bind a newly defined forward function to the existing object, along with adding an attribute like _is_patched at runtime (sketched below). but checking _is_patched with hasattr is a bit messy, and some type checkers will complain when the newly added parameters are accessed, since they still think the object is of the parent type, violating (7).
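for reference, a rough sketch of the MethodType route (the _patched_forward and _is_patched names are just placeholders, and delta here is a full-size matrix to keep the sketch short):

import types
import torch
import torch.nn as nn
import torch.nn.functional as F

def _patched_forward(self: nn.Embedding, input: torch.Tensor) -> torch.Tensor:
    # type checkers still see a plain nn.Embedding, so accessing .delta gets flagged
    return F.embedding(
        input,
        self.weight + self.delta,
        self.padding_idx,
        self.max_norm,
        self.norm_type,
        self.scale_grad_by_freq,
        self.sparse,
    )

embedding = nn.Embedding(16, 32)
embedding.delta = nn.Parameter(torch.zeros_like(embedding.weight))
embedding.forward = types.MethodType(_patched_forward, embedding)
embedding._is_patched = True  # detectable only via hasattr(embedding, "_is_patched")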
a possible fix could be to use Protocols, specifically the inheritance syntax shown here. unfortunately, a protocol inheriting from a non-protocol class is not actually supported in python at the time of writing, and also, who the heck knows what a protocol is anyway?
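for what it's worth, my reading of that syntax is something like the class below, which fails at class-definition time:

from typing import Protocol

import torch.nn as nn

# raises TypeError: protocols can only inherit from other protocols
class PatchedEmbeddingProtocol(nn.Embedding, Protocol):
    delta: nn.Parameter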
the method i ended up settling on does satisfy all seven requirements, but is a bit cursed: i realized that i could just directly override the __class__ attribute of an existing object, which uhh, is probably fine if you're overriding it with a subclass, i think... as a plus, since method lookup goes through the object's class, python picks up the overridden forward automatically, so no mucking around with MethodType either.
here's a simple example of this method in action. we're patching an existing nn.Embedding object to finetune the embeddings of only a few specific tokens:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MostlyFrozenEmbedding(nn.Embedding):
    @classmethod
    def from_existing(
        cls,
        embedding: nn.Embedding,
        trainable_token_ids: list[int],
    ) -> "MostlyFrozenEmbedding":
        with torch.no_grad():
            # freeze the original embedding matrix
            for param in embedding.parameters():
                param.requires_grad_(False)
            # trainable offsets, one row per trainable token
            embedding.delta = nn.Parameter(
                torch.zeros(len(trainable_token_ids), embedding.embedding_dim)
            )
            # index tensor expanded to delta's shape so scatter_add lines up row-for-row
            embedding.register_buffer(
                "trainable_token_ids",
                torch.tensor(trainable_token_ids)
                .reshape(-1, 1)
                .repeat(1, embedding.embedding_dim),
            )
        # the cursed part: rebind the existing instance to the subclass in place
        embedding.__class__ = cls
        return embedding

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # add the trainable deltas onto the frozen rows they belong to
        updated_weight = torch.scatter_add(
            self.weight, dim=0, index=self.trainable_token_ids, src=self.delta
        )
        # taken from nn.Embedding forward implementation
        return F.embedding(
            input,
            updated_weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
the patched object is created by simply calling the from_existing method:
original = nn.Embedding(16, 32)
patched = MostlyFrozenEmbedding.from_existing(
    original, trainable_token_ids=[1, 2]
)
the patched object works as expected with torch.compile too.
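and a few quick sanity checks against the requirements (just my own checks on the example above):

# still an instance of the parent class, and the patch is detectable via isinstance
assert isinstance(patched, nn.Embedding)
assert isinstance(patched, MostlyFrozenEmbedding)
assert patched is original  # no new object was created

# gradients only flow into the delta rows, not the frozen weight
patched(torch.tensor([0, 1, 2])).sum().backward()
assert patched.weight.grad is None
assert patched.delta.grad is not None

# and wrapping it in torch.compile works like any other module
compiled = torch.compile(patched)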