[2025-04-26] MLP visualization (prototype)

The MLP block used in transformers is usually some variant of:

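import torch.nn as nn


class MLPBlock(nn.Module):
    def __init__(self, dim: int, mlp_dim: int):
        super().__init__()
        # needs a PyTorch version that ships nn.RMSNorm
        self.norm = nn.RMSNorm(dim, elementwise_affine=False)
        self.wk = nn.Linear(dim, mlp_dim, bias=False)  # rows are the "key" vectors
        self.act = nn.ReLU()
        self.wv = nn.Linear(mlp_dim, dim, bias=False)  # columns are the "value" vectors

    def forward(self, x):
        # one non-negative score per hidden unit
        scores = self.act(self.wk(self.norm(x)))
        # residual connection: add the score-weighted sum of value vectors
        return x + self.wv(scores)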

Of course, the activation is usually something other than ReLU; I'm using ReLU here for simplicity. Below is a visualization of `MLPBlock(dim=2, mlp_dim=3)`. The colored weight vectors in the visualization are draggable: the blue vectors are the rows of the first linear layer (`wk`), and the red vectors are the columns of the second linear layer (`wv`). I call these weight vectors the "key" and "value" vectors, respectively, since they play a role similar to that of the key and value activations in attention.
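
To make the key/value reading concrete, here is a minimal sketch in plain Python (no PyTorch) of the same forward pass written as a sum over (key, value) pairs: each key is dotted with the normalized input, passed through ReLU to get a score, and the matching value vector is added to the residual scaled by that score. The weights below are made-up numbers for illustration, not the ones in the visualization, and the `eps` value is an assumption.

```python
import math


def rms_norm(x, eps=1e-6):
    """RMS-normalize a vector with no learned scale (eps is an assumed value)."""
    rms = math.sqrt(sum(xi * xi for xi in x) / len(x) + eps)
    return [xi / rms for xi in x]


def mlp_block(x, keys, values):
    """Residual MLP as a sum over (key, value) pairs.

    keys[i] is row i of wk, values[i] is column i of wv.
    """
    h = rms_norm(x)
    out = list(x)  # start from the residual input
    for k, v in zip(keys, values):
        # score_i = ReLU(k_i . norm(x))
        score = max(0.0, sum(ki * hi for ki, hi in zip(k, h)))
        # add score_i * v_i to the output
        out = [oi + score * vi for oi, vi in zip(out, v)]
    return out


# Hypothetical weights for dim=2, mlp_dim=3
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
values = [[0.5, 0.0], [0.0, 0.5], [1.0, 1.0]]

y = mlp_block([3.0, -4.0], keys, values)
print(y)
```

A key only contributes when it points into the same half-plane as the normalized input; here the second key has a negative dot product with the input, so its value vector is not added at all.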

Yes, I did use an LLM to generate the code for this. I'm not exactly the kind of person who would know d3.js...