[back to home]

[2025-06-12] RoPE update distances visualization

Here's a minimal RoPE implementation:

class Rotary(nn.Module):
    def __init__(self, seq_len: int, head_dim: int, min_freq: float, max_mult: float):
        super().__init__()
        self.seq_len = seq_len

        half = head_dim // 2
        self.freqs = nn.Buffer(min_freq * (max_mult ** torch.linspace(0, 1, half)))
        theta_1T1d = (torch.arange(seq_len).unsqueeze(-1) * self.freqs).reshape(
            1, seq_len, 1, half
        )
        self.cos_1T1d = nn.Buffer(torch.cos(theta_1T1d), persistent=False)
        self.sin_1T1d = nn.Buffer(torch.sin(theta_1T1d), persistent=False)

    def forward(self, x_NThd):
        """forward with known sequence length, using cached rotation matrix"""
        assert x_NThd.size(1) == self.seq_len
        x1_NThd, x2_NThd = x_NThd.float().chunk(2, dim=-1)
        y1_NThd = x1_NThd * self.cos_1T1d - x2_NThd * self.sin_1T1d
        y2_NThd = x1_NThd * self.sin_1T1d + x2_NThd * self.cos_1T1d
        return torch.cat([y1_NThd, y2_NThd], dim=-1).type_as(x_NThd)
        

...and for any particular gradient update applied to the key vector at position 0, here's a plot of the relative magnitude of the change in output for the same input, if it were located at position \(x\) instead: