[2025-06-12] RoPE visualization
Edit (2025-06-25): Elaborated more.
Here's an implementation of RoPE:
```python
import torch
import torch.nn as nn


class Rotary(nn.Module):
    def __init__(self, seq_len: int, head_dim: int, min_freq: float, max_mult: float):
        """
        Args:
            seq_len: the maximum sequence length for which rotations are cached
            head_dim: the dimensionality of the attention head, equal to n_freqs * 2
            min_freq: the smallest angular frequency of the rotation
            max_mult: the ratio of the largest to the smallest angular frequencies

        Dimension key:
            N: batch size
            T: sequence length
            H: num attention heads
            D: head_dim
            F: n_freqs == head_dim // 2
        """
        super().__init__()
        n_freqs = head_dim // 2
        # Geometric sequence of angular frequencies from min_freq up to min_freq * max_mult.
        self.freqs_F = nn.Buffer(min_freq * (max_mult ** torch.linspace(0, 1, n_freqs)))
        # Rotation angle for every (position, frequency) pair, cached as cos/sin buffers.
        theta_T1F = torch.arange(seq_len)[:, None, None] * self.freqs_F
        self.cos_T1F = nn.Buffer(torch.cos(theta_T1F))
        self.sin_T1F = nn.Buffer(torch.sin(theta_T1F))

    def forward(self, input_NTHD: torch.Tensor):
        """forward with known sequence length, using cached rotation matrices"""
        # Split the head dimension into (x, y) coordinate pairs and rotate each pair by
        # the angle for its position and frequency.
        x_NTHF, y_NTHF = input_NTHD.float().chunk(2, dim=-1)
        x_out_NTHF = x_NTHF * self.cos_T1F - y_NTHF * self.sin_T1F
        y_out_NTHF = x_NTHF * self.sin_T1F + y_NTHF * self.cos_T1F
        output_NTHD = torch.cat([x_out_NTHF, y_out_NTHF], dim=-1).type_as(input_NTHD)
        return output_NTHD
```
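To make the intended usage concrete, here's a minimal sketch of my own (the batch size, sequence length, head count, and head_dim below are placeholders): rotate both queries and keys before computing attention logits, so the logits depend only on relative position.

```python
# Illustrative shapes and hyperparameters, not prescriptive values.
seq_len, head_dim, n_heads = 2048, 64, 8
rotary = Rotary(seq_len=seq_len, head_dim=head_dim, min_freq=1e-4, max_mult=1e4)

q_NTHD = torch.randn(2, seq_len, n_heads, head_dim)  # queries: [batch, seq, heads, head_dim]
k_NTHD = torch.randn(2, seq_len, n_heads, head_dim)  # keys, same layout

q_rot_NTHD = rotary(q_NTHD)
k_rot_NTHD = rotary(k_NTHD)

# Because both sides are rotated, the logits depend only on relative positions.
logits_NHTT = torch.einsum("nqhd,nkhd->nhqk", q_rot_NTHD, k_rot_NTHD) / head_dim**0.5
```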
For language modeling, typical values for min_freq and max_mult would be 1e-4 and 1e4, respectively. (Actually, max_mult would be slightly lower, since typically the upper bound of the exponents would be exclusive.) For applications outside of language, I tend to normalize positional coordinates to the interval [-1, 1], in which case a reasonable default min_freq would be 0.1 or so.
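To make the parenthetical concrete, here's a quick comparison of my own (head_dim = 64 is an assumed value) between the (min_freq, max_mult) schedule with the quoted values and the conventional base-10000 RoPE schedule:

```python
# Frequency range implied by the values quoted above, next to the conventional
# base-10000 RoPE schedule.
head_dim = 64
n_freqs = head_dim // 2

freqs_F = 1e-4 * (1e4 ** torch.linspace(0, 1, n_freqs))
print(freqs_F.min().item(), freqs_F.max().item())        # 1e-4 ... 1.0

# Conventional schedule: 10000 ** (-2i / head_dim) for i = 0..n_freqs-1. The exponent
# grid is half-open, so the ratio of largest to smallest frequency comes out slightly
# below 1e4 -- the caveat about max_mult above.
standard_F = 10000.0 ** (-2 * torch.arange(n_freqs) / head_dim)
print(standard_F.min().item(), standard_F.max().item())  # ~1.33e-4 ... 1.0
```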
Below, I'm plotting the cosine alignment of an unrotated query vector with the rotated version of itself at different positions. In general, attention maps will be linear combinations of (the softmax of) the plotted function. This is, in a sense, the most "specific" a query can be about a particular key, if you bound the norm of the query. Increasing min_freq results in more specific attention maps, which can improve expressivity, but it also results in a weaker prior on the smoothness of the output under small positional translations of the keys.
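A sketch of how such a plot can be generated, reusing the Rotary module above (the head_dim, seq_len, hyperparameters, and random query are illustrative choices, and matplotlib is assumed):

```python
import matplotlib.pyplot as plt

# Cosine alignment between a query and the rotated copy of itself at every position
# (position 0 is the unrotated query itself, so the curve starts at 1).
seq_len, head_dim = 512, 64
rotary = Rotary(seq_len=seq_len, head_dim=head_dim, min_freq=1e-4, max_mult=1e4)

q_D = torch.randn(head_dim)
q_TD = q_D.expand(seq_len, head_dim)                   # same query at every position
q_rot_TD = rotary(q_TD[None, :, None, :])[0, :, 0, :]  # rotate, shape [T, D]

# Rotations preserve norm, so dividing by |q|^2 gives the cosine similarity cos(q, R_t q).
alignment_T = (q_rot_TD @ q_D) / q_D.dot(q_D)

plt.plot(alignment_T.numpy())
plt.xlabel("relative position of the key")
plt.ylabel("cosine alignment with the unrotated query")
plt.show()
```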
Equivalently, the plot shows, given the gradient update applied to a key at position zero, the relative change in the attention score between the same query and that key as the key is placed at different positions.
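A small numerical check of that equivalence (my own, with assumed shapes, and using a gradient-ascent step on the raw score at position zero as a stand-in for whatever update training would actually apply): the step moves the key by a multiple of the query, so the induced score change at relative position t is proportional to q · R_t q, the same curve as above.

```python
seq_len, head_dim, lr = 512, 64, 1e-2
rotary = Rotary(seq_len=seq_len, head_dim=head_dim, min_freq=1e-4, max_mult=1e4)

q_D = torch.randn(head_dim)
k_D = torch.randn(head_dim, requires_grad=True)

# Gradient-ascent step on the score at relative position 0 (no rotation there),
# which nudges the key in the direction of the query: k_D.grad == q_D.
score_0 = q_D.dot(k_D)
score_0.backward()
k_new_D = (k_D + lr * k_D.grad).detach()

def scores_T(key_D):
    """Score of the query against the same key placed at every relative position."""
    key_rot_TD = rotary(key_D.expand(seq_len, head_dim)[None, :, None, :])[0, :, 0, :]
    return key_rot_TD @ q_D

delta_T = scores_T(k_new_D) - scores_T(k_D.detach())

# The change equals lr * (q . R_t q), i.e. it is proportional to the cosine-alignment
# curve plotted above (which is q . R_t q divided by the constant |q|^2).
assert torch.allclose(delta_T, lr * scores_T(q_D), atol=1e-3)
```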