[2025-06-12] RoPE update distances visualization
Here's a minimal RoPE implementation:
import torch
import torch.nn as nn


class Rotary(nn.Module):
    def __init__(self, seq_len: int, head_dim: int, min_freq: float, max_mult: float):
        super().__init__()
        self.seq_len = seq_len
        half = head_dim // 2
        # Geometric frequency spectrum from min_freq up to min_freq * max_mult.
        self.freqs = nn.Buffer(min_freq * (max_mult ** torch.linspace(0, 1, half)))
        # Rotation angles per (position, frequency), shaped (1, T, 1, half) so they
        # broadcast over the batch and head dimensions.
        theta_1T1d = (torch.arange(seq_len).unsqueeze(-1) * self.freqs).reshape(
            1, seq_len, 1, half
        )
        self.cos_1T1d = nn.Buffer(torch.cos(theta_1T1d), persistent=False)
        self.sin_1T1d = nn.Buffer(torch.sin(theta_1T1d), persistent=False)

    def forward(self, x_NThd):
        """forward with known sequence length, using cached rotation matrix"""
        assert x_NThd.size(1) == self.seq_len
        x1_NThd, x2_NThd = x_NThd.float().chunk(2, dim=-1)
        y1_NThd = x1_NThd * self.cos_1T1d - x2_NThd * self.sin_1T1d
        y2_NThd = x1_NThd * self.sin_1T1d + x2_NThd * self.cos_1T1d
        return torch.cat([y1_NThd, y2_NThd], dim=-1).type_as(x_NThd)
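A quick usage sketch, with made-up shapes and hyperparameters (batch 2, 8 heads, head dim 64, frequencies spanning roughly 1e-4 to 1; these are placeholders, not values from this post):

import torch

rope = Rotary(seq_len=256, head_dim=64, min_freq=1e-4, max_mult=1e4)
# Suffix convention as above: N=batch, T=sequence length, h=heads, d=head_dim.
q_NThd = torch.randn(2, 256, 8, 64)
q_rot_NThd = rope(q_NThd)
assert q_rot_NThd.shape == q_NThd.shape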
...and for any particular gradient update applied to the key vector at position 0, here's a plot of the relative magnitude of the change in output for the same input if it were located at position \(x\) instead (one way such a curve might be computed is sketched below):
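As a hedged sketch of how a curve like this might be computed: I'm reading "change in output" as the overlap between the update as it lands at position 0 (where the rotation is the identity) and the same update seen through the rotation at position \(x\). `update_transfer_curve` and all hyperparameters below are illustrative assumptions, not the measurement code actually used for the plot.

import torch

def update_transfer_curve(rope: Rotary, delta_d: torch.Tensor) -> torch.Tensor:
    """For an update delta applied to the (pre-rotation) key at position 0,
    return how much of the rotated update R(x) @ delta still points along
    delta itself, normalized to 1 at x = 0 (where R(0) is the identity)."""
    T = rope.seq_len
    # Broadcast the same update vector to every position: (1, T, 1, head_dim).
    delta_1T1d = delta_d.view(1, 1, 1, -1).expand(1, T, 1, -1)
    rotated_1T1d = rope(delta_1T1d)
    # Per-position dot product of R(x) @ delta with delta.
    overlap_T = (rotated_1T1d * delta_1T1d).sum(-1).flatten()
    return overlap_T / overlap_T[0]

# Example with a random update direction and placeholder hyperparameters.
rope = Rotary(seq_len=1024, head_dim=64, min_freq=1e-4, max_mult=1e4)
curve_T = update_transfer_curve(rope, torch.randn(64))

Under this reading, the overlap works out to a sum of \(\cos(x f_i)\) terms weighted by the update's energy in each frequency pair: the low-frequency channels keep the update aligned over long distances, while the high-frequency channels decorrelate quickly, so for a random update the curve typically decays with \(x\).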