[2025-04-26] MLP visualization (prototype)
The MLP block used in transformers is usually some variant of:
    import torch.nn as nn

    class MLPBlock(nn.Module):
        def __init__(self, dim: int, mlp_dim: int):
            super().__init__()
            self.norm = nn.RMSNorm(dim, elementwise_affine=False)
            self.wk = nn.Linear(dim, mlp_dim, bias=False)   # rows of this weight are the "key" vectors
            self.act = nn.ReLU()
            self.wv = nn.Linear(mlp_dim, dim, bias=False)   # columns of this weight are the "value" vectors

        def forward(self, x):
            scores = self.act(self.wk(self.norm(x)))
            return x + self.wv(scores)
Of course, real models usually use a different activation; I'm using ReLU here for simplicity. Below is a visualization of `MLPBlock(dim=2, mlp_dim=3)`. The colored weight vectors in the visualization are draggable: the blue vectors correspond to the rows of the first linear layer (`wk`), and the red vectors correspond to the columns of the second linear layer (`wv`). I call these the "key" and "value" vectors, respectively, since they play a role similar to the key and value activations in attention.
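As a minimal sketch of where those vectors come from (this is illustrative PyTorch, not the actual visualization code), the key vectors are the rows of `wk.weight` and the value vectors are the columns of `wv.weight`, each living in the 2D input space:

    import torch

    # Illustrative only: extract the 2D "key" and "value" vectors that the
    # visualization draws, assuming the MLPBlock definition above.
    block = MLPBlock(dim=2, mlp_dim=3)

    key_vectors = block.wk.weight       # shape (3, 2): one key vector per row
    value_vectors = block.wv.weight.T   # shape (3, 2): one value vector per column of wv

    x = torch.randn(1, 2)               # a single 2D input point
    out = block(x)                      # shape (1, 2): residual output of the block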
Yes, I did use an LLM to generate the code for this. I'm not exactly the kind of person who would know d3.js...