On N-dimensional Rotary Positional Embeddings
RoPE in one dimension
One of the simplest ways of encoding relative positional information in attention is to add a scalar to each of the attention logits, with a value somehow depending on the distance between the corresponding query and key (e.g. learned values in T5, fixed values decreasing with distance in ALiBi).
However, this makes it difficult for a query to attend to any specific (key, relative position) pair. In particular, the query must have a component pointing in the direction of the desired key, but this increases the attention scores with all tokens with that key, regardless of their position.
Rotary positional embeddings (RoPE) are an elegant solution to this problem. Essentially, the query and key vectors for every token are rotated by an angle proportional to the token's 1-d coordinate position.
To be specific, each attention head has its \(D\) channel dimensions divided into \(D / 2\) dimension pairs. For a given query or key input vector \(x \in \mathbb R^D\) located at position \(t\), the \(i\)th dimension pair \((x_{2i + 1}, x_{2i + 2})\) is rotated about the origin by an angle \(\omega_i \cdot t\), where \(\omega_i\) is the angular frequency corresponding to \(i\):
\[\begin{pmatrix} x_{2i+1}' \\ x_{2i+2}' \end{pmatrix} = \begin{pmatrix} \cos(\omega_i t) & -\sin(\omega_i t) \\ \sin(\omega_i t) & \cos(\omega_i t) \end{pmatrix} \begin{pmatrix} x_{2i+1} \\ x_{2i+2} \end{pmatrix}.\]
The composition of independent rotations on orthogonal 2d planes is itself just a higher-dimensional rotation (see this article or this post), so RoPE is implementing a single rotation on the vector as a whole.
Typically, a range of log-spaced frequency magnitudes is selected, where the \(i\)th frequency magnitude is \[\omega_i = \omega_\text{min} \cdot (\omega_\text{max} / \omega_\text{min}) ^{\frac{i}{D/2 - 1}}.\]
Intuitively, larger values of \(\omega_\text{min}\) and \(\omega_\text{max}\) enforce the prior that queries should be more specific about position, which can improve expressivity but hurt generalizability. Conversely, smaller values of \(\omega_\text{min}\) and \(\omega_\text{max}\) lead to more invariance w.r.t. position, improving generalizability but hurting expressivity.
As more frequencies are added, their periodic oscillations cancel each other out, resulting in an attention map concentrated at a specific 1d coordinate position.
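As a quick illustration (a minimal sketch, not code from the experiments below; the head dimension and frequency band are just example values), the similarity between a query and a 1d-RoPE-rotated copy of itself reduces to a sum of cosines, which becomes sharply peaked around a relative position of zero once many frequencies are summed:

```python
import torch

# Similarity between a query and a 1d-RoPE-rotated copy of itself, as a function
# of relative position. Example values: D = 64 channels, positions in [-1, 1].
D, omega_min, omega_max = 64, 1.0, 100.0
n_freqs = D // 2
omega = omega_min * (omega_max / omega_min) ** torch.linspace(0, 1, n_freqs)

q = torch.randn(D)
r2 = q.reshape(n_freqs, 2).pow(2).sum(dim=-1)  # squared norm of each dimension pair
delta = torch.linspace(-1, 1, 401)             # relative positions
# <q, R(delta) q> = sum_i r_i^2 * cos(omega_i * delta); concentrates around delta = 0.
sim = (r2[:, None] * torch.cos(omega[:, None] * delta[None, :])).sum(dim=0)
```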
Extensions to 2 or more dimensions
The most common extension of RoPE to 2 dimensions used in vision transformer (ViT) implementations today is called axial RoPE, which applies 1d RoPE twice: rotating the first \(D/2\) dimensions of the query/key according to x-position, and the remaining \(D/2\) dimensions according to y-position.
Similar to 1d RoPE, 2d axial RoPE also encodes purely relative positional information. However, 2d axial RoPE does not enable attending solely to specific (key, relative position) pairs. The first half of a query, which rotates according to x-position, contributes the same amount to the attention score for a key regardless of the key's y-position. Similarly, the second half of a query contributes the same amount regardless of x-position. Attending to a token necessarily means attending to any tokens with similar keys located in the same row or column, with a cosine alignment at least half as large, roughly speaking.
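To state this in equations: write the query as \(q = (q^{(x)}, q^{(y)})\) and the key as \(k = (k^{(x)}, k^{(y)})\), where the first halves are rotated by \(R(\cdot)\) according to x-position and the second halves according to y-position. Using \(R(a)^\top R(b) = R(b - a)\), the attention logit between a query at \((x_1, y_1)\) and a key at \((x_2, y_2)\) is
\[\langle R(x_1)\, q^{(x)}, R(x_2)\, k^{(x)} \rangle + \langle R(y_1)\, q^{(y)}, R(y_2)\, k^{(y)} \rangle = \langle q^{(x)}, R(x_2 - x_1)\, k^{(x)} \rangle + \langle q^{(y)}, R(y_2 - y_1)\, k^{(y)} \rangle,\]
i.e. a term that depends only on the relative x-offset plus a term that depends only on the relative y-offset, with nothing coupling the two.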
The main insight is that, rather than rotating any particular dimension pair based on only x-position or based on only y-position, the rotations can instead be based on the tokens' positions measured along arbitrary 2d directions (Heo et al. 2024).
Although Heo et al. introduced this idea last year, various recent works that cite Heo et al., such as SAM 2, Infinity, and Perception Encoder, are actually still using axial RoPE, based on their official implementations. Maybe it wasn't obvious from only reading the abstract of Heo et al. that they were introducing a novel approach differing from axial RoPE.
At initialization, unit directions \(\{\mathbf{u}_i\}_{i=1}^{D/2}, \mathbf{u}_i \in \mathbb{R}^2, \|\mathbf{u}_i\|_2 = 1\) are selected for each of the \(D/2\) dimension pairs. During inference, the angle of rotation for the \(i\)th pair is given by \(\omega_i\) times the inner product of the token's 2d position with \(\mathbf{u}_i\) (whereas for 1d RoPE, this angle was just \(\omega_i\) times the token's 1d coordinate position).
To be precise, the general N-dimensional RoPE for positions \(\mathbf{t} \in \mathbb R^N\) can be written as
\[\theta_i(\mathbf{t}) = \omega_i \, \langle \mathbf{u}_i, \mathbf{t} \rangle, \qquad \begin{pmatrix} x_{2i+1}' \\ x_{2i+2}' \end{pmatrix} = \begin{pmatrix} \cos\theta_i(\mathbf{t}) & -\sin\theta_i(\mathbf{t}) \\ \sin\theta_i(\mathbf{t}) & \cos\theta_i(\mathbf{t}) \end{pmatrix} \begin{pmatrix} x_{2i+1} \\ x_{2i+2} \end{pmatrix},\]
where \(\mathbf{u}_i \in \mathbb R^N\) is the unit direction assigned to the \(i\)th dimension pair.
This includes axial RoPE as a special case where each \(\mathbf{u}_i\) is in the standard basis.
In other words, the \(i\)th dimension pair is rotated by an angle proportional to "the position of the token measured according to the direction \(\mathbf{u}_i\)". By spacing out \(\{\mathbf{u}_i\}_{i=1}^{D / 2}\) uniformly on the unit circle, while using log-spaced frequency magnitudes \(\{\omega_i\}_{i=1}^{D/2}\), it turns out that RoPE can also produce concentrated attention maps in 2d!
Fig. 3. Cosine similarities between a fixed query and the 2d-RoPE-rotated version of itself over varying positions. Here, the similarities are evaluated on a prefix of the dimension pairs. The attention map becomes gradually more concentrated as the number of frequencies increases.
NOTE: since the sequence lengths relevant for language modeling are much longer than the side length of the patchified inputs for vision transformers, the minimum and maximum frequency magnitudes should be adjusted to compensate. I recommend using a coordinate system with positions normalized to [-1.0, 1.0] and using \(\omega_\text{min}\) between 0.2 and 1.0 and \(\omega_\text{max}\) between 20 and 100.
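For example, a coordinate grid for an \(H \times W\) grid of patches with each axis normalized to \([-1, 1]\) could look like the following minimal sketch (the reference implementation at the end of the post additionally scales the limits to preserve the image's aspect ratio):

```python
import torch

# Patch-center coordinates with each axis normalized to [-1, 1].
def normalized_grid(H: int, W: int) -> torch.Tensor:
    y, x = torch.meshgrid(
        torch.linspace(-1.0, 1.0, H), torch.linspace(-1.0, 1.0, W), indexing="ij"
    )
    return torch.stack((x, y), dim=-1)  # (H, W, 2) positions
```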
Selecting frequency directions
Mixed RoPE (Heo et al. 2024) initializes \(\{\mathbf{u}_i\}_{i=1}^{D / 2}\) by sampling uniformly random vectors from the unit circle, then treating the frequency vectors \(\mathbf{f}_i = \omega_i \mathbf{u}_i\) as learnable parameters. However, it's unclear a priori whether learnable frequency vectors are actually beneficial, especially if the frequencies are selected in such a way that it's already possible to query unique positions at initialization.
Intuitively, typical neural network parameters, e.g. the weights of a linear layer, have the property that a gradient step on some training input will generally only induce a large change in the layer's outputs for inputs similar to that training input. In contrast, changes to the frequency vectors of RoPE will nontrivially modify the attention scores between almost all query-key pairs, which could make these frequencies less amenable to gradient optimization.
If these frequency vectors are kept frozen instead, then ideally we would want to initialize them in a more principled, deterministic fashion. In the initial discussion on the EleutherAI discord, Kevin Yin suggested arranging the frequency vectors for each head in order of increasing magnitude, and rotating the \(i\)th vector to an angle of \(i\cdot 2\pi / \varphi\) where \(\varphi = (1 + \sqrt{5}) / 2\) is the golden ratio, resulting in uniformly distributed frequency vectors. (This is the approach used in the experiments below, though, as Kevin later pointed out, rotating by \(i \cdot \pi / \varphi\) is likely to perform better.)
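As a minimal sketch of this layout (for a single head; the reference implementation at the end of the post additionally continues the angle sequence across heads):

```python
import math
import torch

# Frequency directions at angles i * 2*pi/phi, i.e. spaced by the golden angle.
def golden_ratio_directions(n_freqs: int) -> torch.Tensor:
    phi = (1 + math.sqrt(5)) / 2
    angles = torch.arange(n_freqs) * (2 * math.pi / phi)
    return torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)  # (n_freqs, 2)
```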
For embedding positions in more than 2 dimensions, one option that is easy to implement and works well enough is to draw quasi-random samples from \(U(0, 1)\), map them to Gaussian samples using the inverse CDF, and then normalize to unit length. An example implementation is provided here.
ViT experiments
Here, I compare:
- Learned absolute positional embeddings (APE, e.g. Dosovitskiy et al. 2020)
- Fixed sinusoidal positional embeddings (SinCos)
  - an absolute positional embedding approach without trainable parameters (Chen et al. 2021, Beyer et al. 2022)
  - similar to the positional embedding from the original transformer paper (Vaswani et al. 2017), but with one half derived from x-positions and the other half from y-positions
- Axial RoPE
- RoPE with learned, randomly initialized frequencies (mixed RoPE, Heo et al. 2024)
- LieRE (Ostmeier et al. 2024)
  - a rotary positional embedding scheme where the rotation is parameterized by the matrix exponential of the weighted sum of learnable skew-symmetric matrices, with weights equal to the token's position coordinates
- RoPE with frozen, uniformly spaced directions (uniform RoPE)
CIFAR10
I trained some small (7M parameter) ViTs with 4x4 patches for 200 epochs on CIFAR10. I searched over \(\omega_\text{min} \in \{0.5, 1.0\}\) with \(\omega_\text{max} = 100 \cdot \omega_\text{min}\). Below, I'm reporting the best validation negative log-likelihood (NLL) and accuracy, mean ± std over 2 seeds, for each approach. See hyperparameters here and code here.
The official implementation of mixed RoPE uses \(\omega_\text{min}\) = 0.65, \(\omega_\text{max}\) = 6.5. (Well, more precisely, they used frequencies from 0.1 to 1.0 when working with integer row/column positions ranging from 0 to 13, but here I'm expressing the frequencies w.r.t positions normalized to [-1.0, 1.0].)
CIFAR10 ViT(dim=384, mlp_dim=768, depth=6) / patch 4 @ 200 epochs
| Method | Learned | \(\omega_\text{min}\) | \(\omega_\text{max}\) | Valid NLL (↓) | Valid Accuracy (%) (↑) |
|---|---|---|---|---|---|
| APE | ✓ | N/A | N/A | 0.4287 ± 0.0031 | 89.70 ± 0.03 |
| SinCos | | 1.00 | 100.0 | 0.4144 ± 0.0124 | 89.93 ± 0.45 |
| Axial RoPE | | 0.50 | 50.0 | 0.3535 ± 0.0018 | 91.95 ± 0.05 |
| Mixed RoPE, original freqs | ✓ | 0.65 | 6.5 | 0.3550 ± 0.0072 | 91.63 ± 0.32 |
| Mixed RoPE, adjusted freqs | ✓ | 1.00 | 100.0 | 0.3394 ± 0.0015 | 92.43 ± 0.07 |
| LieRE | ✓ | N/A | N/A | 0.3461 ± 0.0023 | 91.95 ± 0.04 |
| Uniform RoPE | | 1.00 | 100.0 | 0.3292 ± 0.0023 | 92.43 ± 0.05 |
Both absolute positional embedding methods performed poorly. Mixed RoPE underperformed when using the frequencies from the official implementation, but after adjusting the frequencies, it outperformed axial RoPE and did about as well as uniform RoPE. During hyperparameter tuning, axial RoPE was the only approach that benefited from a frequency range of 0.5-50 rather than 1.0-100.0 on CIFAR10. LieRE performed slightly better than axial RoPE when its skew-symmetric parameters were not shared between layers or heads, though this came with high memory usage in my implementation compared to the other approaches, and it still underperformed mixed and uniform RoPE. Uniform RoPE had the best NLL.
ImageNet-1K
I trained ViT B/16 sized models (86M parameters) on ImageNet-1K for 90 epochs. I broadly used the same data augmentation and preprocessing scheme as Beyer et al. 2022, training at 224x224 with inception cropping and a small amount of RandAugment and MixUp, though with the same architecture and optimizer setup as the CIFAR10 experiments above. Due to resource limits, I only compared fixed sinusoidal positional embeddings, axial RoPE, mixed RoPE, and uniform RoPE. See hyperparameters here and code here.
ImageNet-1K ViT B/16 @ 90 epochs
| Method | Learned | \(\omega_\text{min}\) | \(\omega_\text{max}\) | Zero freqs | Valid NLL (↓) | Valid Accuracy (%) (↑) |
|---|---|---|---|---|---|---|
| SinCos | | 1.0 | 100.0 | N/A | 0.8444 | 78.71 |
| Axial RoPE | | 0.2 | 20.0 | 0 / 16 | 0.8034 | 79.58 |
| " | | " | " | 8 / 16 | 0.8055 | 79.61 |
| Mixed RoPE | ✓ | 0.2 | 20.0 | N/A | 0.8025 | 79.73 |
| Uniform RoPE | | 0.2 | 20.0 | 0 / 32 | 0.8064 | 79.67 |
| " | | " | " | 8 / 32 | 0.7979 | 79.78 |
| " | | " | " | 16 / 32 | 0.8002 | 79.68 |
I searched over \(\omega_\text{min} \in \{0.2, 0.5, 1.0\}\), and, in contrast to CIFAR10, I found that lower frequency magnitudes (0.2-20.0) performed better for the RoPE approaches on ImageNet. In my initial runs, axial RoPE, mixed RoPE, and uniform RoPE performed about the same, with mixed RoPE at a slight advantage. From analyzing the learned frequencies of mixed RoPE, I found that a good portion of the learned frequencies decreased to almost zero during training. Based on modded-nanogpt and Barbero et al. 2024, I set 8 / 32 of the frequencies of uniform RoPE to zero and it outperformed mixed RoPE by a slim margin. (Here, 8 / 32 means that 8 of the 32 unique frequency magnitudes were set to zero; axial RoPE repeats the frequency magnitudes for x and y, so there were only 16 unique frequency magnitudes total.)
Here, I evaluated how well each method generalized to different resolutions at inference time. The models were trained at 224x224 only, and to evaluate at a higher resolution (e.g. 384x384), the positions of each patch were scaled to still span [-1.0, 1.0] (so adjacent patches end up with coordinates which are closer together than before). I also tried scaling the temperature of the attention softmax to account for the increased token count (see e.g. here), using a temperature of \(\log(\text{new\_res}^2 / p^2) / \log(\text{old\_res}^2 / p^2)\) in this case, where \(p\) is the patch size.
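Below is a minimal sketch of this temperature adjustment (the function name is hypothetical, and it assumes a PyTorch version whose `F.scaled_dot_product_attention` accepts a `scale` argument):

```python
import math
import torch.nn.functional as F

def attention_with_resolution_temperature(q, k, v, train_res: int, test_res: int, patch_size: int):
    # Multiply the attention logits by log(test-time token count) / log(train-time token count).
    temp = math.log((test_res / patch_size) ** 2) / math.log((train_res / patch_size) ** 2)
    scale = temp / math.sqrt(q.shape[-1])  # default SDPA scale is 1 / sqrt(head_dim)
    return F.scaled_dot_product_attention(q, k, v, scale=scale)
```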
ViT B/16 resolution generalization: validation accuracies (%) (vs in-dist @ 224x224)
| Method | Learned | \(\omega_\text{min}\) | \(\omega_\text{max}\) | Zero freqs | 224x224 (in-dist) | 384x384 | 384x384 w/ temp | 512x512 w/ temp |
|---|---|---|---|---|---|---|---|---|
| SinCos | | 1.0 | 100.0 | N/A | 78.71 | 75.12 (-3.59) | 77.29 (-1.42) | 74.95 (-3.76) |
| Axial RoPE | | 0.2 | 20.0 | 0 / 16 | 79.58 | 78.31 (-1.27) | 79.57 (-0.01) | 77.18 (-2.40) |
| Mixed RoPE | ✓ | 0.2 | 20.0 | N/A | 79.73 | 78.61 (-1.12) | 79.83 (+0.10) | 77.55 (-2.18) |
| Uniform RoPE | | 0.2 | 20.0 | 8 / 32 | 79.78 | 79.19 (-0.59) | 80.41 (+0.63) | 79.15 (-0.63) |
Axial RoPE generalized better than fixed sinusoidal positional embeddings, and mixed RoPE generalized better than axial. Surprisingly, uniform RoPE generalized better than mixed RoPE, and both mixed RoPE and uniform RoPE actually achieved higher validation accuracy at 384x384 with temperature scaling than they did at the training resolution of 224x224.
As a sidenote, I found that thinner vision transformers (dim=384, mlp_dim=768, 15M params) with smaller 8x8 patches and the same depth=12 (due to increased # of patches, approximately equivalent in FLOPs to B/16) resulted in improved performance across the board:
ImageNet-1K ViT(dim=384, mlp_dim=768, depth=12) / patch 8 @ 90 epochs
| Method | Learned | \(\omega_\text{min}\) | \(\omega_\text{max}\) | Valid NLL (↓) | Valid Accuracy (%) (↑) |
|---|---|---|---|---|---|
| SinCos | | 1.0 | 100.0 | 0.7834 | 79.46 |
| Axial RoPE | | 0.2 | 20.0 | 0.7393 | 80.48 |
| Mixed RoPE | ✓ | 0.2 | 20.0 | 0.7455 | 80.26 |
| Uniform RoPE | | 0.2 | 20.0 | 0.7456 | 80.42 |
(These runs used nonzero initialization frequencies only, though I'd expect better performance with some frequencies set to zero.)
Discussion
Overall, it seems that uniform RoPE was consistently among the best approaches when the frequency magnitudes were properly tuned. Mixed RoPE, with learnable frequencies, also performed well. However, mixed RoPE was fairly sensitive to the initialization magnitude of the frequencies, and could also require tuning the learning rate for the frequencies (which, under my setup, were optimized separately with AdamW rather than Muon). I would generally recommend defaulting to uniform RoPE, and performing at least a small amount of tuning on the RoPE frequency magnitudes regardless of the approach used.
Reference implementations
Uniform 2d RoPE (PyTorch)
import math

import torch
from torch import nn


class UniformRoPE2d(nn.Module):
    def __init__(
        self,
        image_size: tuple[int, int],
        n_heads: int,
        head_dim: int,
        min_freq: float,
        max_freq: float,
        p_zero_freqs: float = 0.0,
        direction_spacing: float = math.pi * (math.sqrt(5) - 1) / 2,
    ):
        """
        Args:
            image_size: expected height and width of (patchified) input
            n_heads: number of attention heads
            head_dim: attention head dimensionality
            min_freq, max_freq: lowest and highest nonzero frequency magnitudes
            p_zero_freqs: proportion of frequencies set to 0
            direction_spacing: difference in radians between adjacent directions along
                which position is measured

        Dimension key:
            N: batch size
            H: image_size[0]
            W: image_size[1]
            h: n_heads
            d: head_dim
            F: num_freqs == d // 2
        """
        super().__init__()
        assert head_dim % 2 == 0
        assert 0 <= p_zero_freqs <= 1
        n_freqs = head_dim // 2
        n_zero_freqs = round(p_zero_freqs * n_freqs)
        # Frequency magnitudes: zeros first, then log-spaced from min_freq to max_freq.
        omega_F = torch.cat(
            (
                torch.zeros(n_zero_freqs),
                min_freq
                * (max_freq / min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
            )
        )
        # Direction angles advance by direction_spacing, continuing across heads.
        phi_hF = (
            torch.arange(n_heads * n_freqs).reshape(n_heads, n_freqs)
            * direction_spacing
        )
        directions_hF2 = torch.stack((torch.cos(phi_hF), torch.sin(phi_hF)), dim=-1)
        freqs_hF2 = omega_F.unsqueeze(-1) * directions_hF2
        # Patch coordinates, normalized while preserving the image's aspect ratio.
        H, W = image_size
        xlim, ylim = math.sqrt(W / H), math.sqrt(H / W)
        x_HW = torch.linspace(-xlim, xlim, W).reshape(1, W).expand(H, W)
        y_HW = torch.linspace(-ylim, ylim, H).reshape(H, 1).expand(H, W)
        positions_HW112 = torch.stack((x_HW, y_HW), dim=-1).reshape(H, W, 1, 1, 2)
        theta_HWhF = (freqs_hF2 * positions_HW112).sum(dim=-1)
        self.register_buffer("cos_HWhF", torch.cos(theta_HWhF))
        self.register_buffer("sin_HWhF", torch.sin(theta_HWhF))

    def forward(self, input_NHWhd: torch.Tensor) -> torch.Tensor:
        x_NHWhF, y_NHWhF = input_NHWhd.float().chunk(2, dim=-1)
        x_out_NHWhF = x_NHWhF * self.cos_HWhF - y_NHWhF * self.sin_HWhF
        y_out_NHWhF = x_NHWhF * self.sin_HWhF + y_NHWhF * self.cos_HWhF
        output_NHWhd = torch.cat((x_out_NHWhF, y_out_NHWhF), dim=-1)
        return output_NHWhd.type_as(input_NHWhd)
`direction_spacing` can be set to \(\pi / 2\) to (approximately) emulate axial RoPE (though this will cause the y frequencies to be slightly larger than the x frequencies on average, since the frequencies are assigned in increasing order, alternating between x and y, and x is assigned first). Note that `direction_spacing` was set to \(\pi (\sqrt{5} - 1)\) in the experiments above.
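Continuing from the definitions above, a minimal usage sketch (the shapes and hyperparameter values here are only illustrative):

```python
rope = UniformRoPE2d(
    image_size=(14, 14), n_heads=6, head_dim=64,
    min_freq=1.0, max_freq=100.0, p_zero_freqs=0.25,
)
q = torch.randn(2, 14, 14, 6, 64)  # (N, H, W, n_heads, head_dim)
k = torch.randn(2, 14, 14, 6, 64)
q, k = rope(q), rope(k)  # rotate queries and keys before computing attention scores
```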
Uniform Nd RoPE (PyTorch)
import torch
from torch import nn


def _phi(m: int) -> float:
    # Fixed-point iteration for the generalized golden ratio, i.e. the root of
    # x ** (m + 1) = x + 1, used to build the quasi-random sequence below.
    x = 2.0
    for _ in range(10):
        x = (1 + x) ** (1.0 / (m + 1.0))
    return x


def uniform_directions(n: int, d: int) -> torch.Tensor:
    # Quasi-random points in [0, 1)^d, mapped to Gaussian samples via the inverse CDF,
    # then normalized to unit length to give roughly uniform directions on the sphere.
    g = _phi(d)
    alpha = (1.0 / g) ** torch.arange(1, d + 1, dtype=torch.float64)
    i = torch.arange(1, n + 1, dtype=torch.float64).unsqueeze(1)
    z = torch.fmod(i * alpha, 1.0)
    directions = torch.erfinv(2.0 * z - 1.0)
    directions = directions / directions.norm(dim=1, keepdim=True)
    return directions.float()


class UniformRoPENd(nn.Module):
    def __init__(
        self,
        pos_dim: int,
        n_heads: int,
        head_dim: int,
        min_freq: float,
        max_freq: float,
        p_zero_freqs: float = 0.0,
    ):
        """
        Args:
            pos_dim: dimensionality of the token positions
            n_heads: number of attention heads
            head_dim: attention head dimensionality
            min_freq, max_freq: lowest and highest nonzero frequency magnitudes
            p_zero_freqs: proportion of frequencies set to 0

        Dimension key:
            N: batch size
            L: number of tokens per sample
            P: pos_dim
            h: n_heads
            d: head_dim
            F: num_freqs == head_dim // 2
        """
        super().__init__()
        n_freqs = head_dim // 2
        n_zero_freqs = round(p_zero_freqs * n_freqs)
        omega_F = torch.cat(
            (
                torch.zeros(n_zero_freqs),
                min_freq
                * (max_freq / min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
            )
        )
        directions_hFP = uniform_directions(n_heads * n_freqs, pos_dim).reshape(
            n_heads, n_freqs, pos_dim
        )
        self.register_buffer("freqs_hFP", directions_hFP * omega_F.reshape(n_freqs, 1))

    def forward(self, input_NLhd: torch.Tensor, pos_NLP: torch.Tensor) -> torch.Tensor:
        x_NLhF, y_NLhF = input_NLhd.float().chunk(2, dim=-1)
        theta_NLhF = (self.freqs_hFP * pos_NLP[..., None, None, :].float()).sum(dim=-1)
        cos_NLhF = torch.cos(theta_NLhF)
        sin_NLhF = torch.sin(theta_NLhF)
        x_out_NLhF = x_NLhF * cos_NLhF - y_NLhF * sin_NLhF
        y_out_NLhF = x_NLhF * sin_NLhF + y_NLhF * cos_NLhF
        output_NLhd = torch.cat((x_out_NLhF, y_out_NLhF), dim=-1)
        return output_NLhd.type_as(input_NLhd)
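Continuing from the definitions above, a minimal usage sketch with illustrative shapes, e.g. 3d positions for video or volumetric patches:

```python
rope = UniformRoPENd(pos_dim=3, n_heads=6, head_dim=64, min_freq=1.0, max_freq=100.0)
q = torch.randn(2, 100, 6, 64)        # (N, L, n_heads, head_dim)
pos = torch.rand(2, 100, 3) * 2 - 1   # token positions normalized to [-1, 1]
q_rotated = rope(q, pos)              # keys are rotated with the same module
```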
Hyperparameters
CIFAR10
- Augmentations: ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.025), random padding by up to 4 pixels, and a random horizontal flip with p=0.5
- Architecture: RMSNorm on the input to each residual block, the queries, and the keys; ReLU^2 activations; learnable zero-init per-channel scaling on the last linear layer of each block; and no biases.
- Optimizers: Muon for linear layers. AdamW for per-channel scalings, and, when applicable, the learnable APE, the frequency vectors for mixed RoPE, or the raw parameters for LieRE. The schedule was constant learning rate with no warmup, followed by linear cooldown to 0.
- For the learnable APE baseline, I additionally tuned the initialization std and found that a relatively large value of 0.5 worked best, though learnable APE still performed the worst on CIFAR10 out of all tested positional embedding schemes.
| Hyperparameter | Value |
|---|---|
| Steps | 10,000 (200 epochs) |
| Batch size | 1,000 |
| Muon LR | 0.06 |
| Muon momentum | 0.95 |
| Muon weight decay | 0.01 |
| AdamW LR | 0.003 |
| AdamW betas | (0.9, 0.95) |
| AdamW weight decay | 0.01 |
| LR cooldown start | 7,500 |
| Label smoothing | 0.1 |
| Patch size | 4 |
| dim | 384 |
| MLP dim | 768 |
| depth | 6 |
ImageNet-1K
- Augmentations: RandAugment(2, 10) and MixUp(alpha=0.2)
- Everything else same as the CIFAR10 experiments unless otherwise specified
| Hyperparameter | Value |
|---|---|
| Steps | 187,650 (90 epochs) |
| Global batch size | 1,024 |
| Muon LR | 0.03 |
| Muon momentum | 0.95 |
| Muon weight decay | 0.01 |
| AdamW LR | 0.001 |
| AdamW betas | (0.9, 0.95) |
| AdamW weight decay | 0.01 |
| LR cooldown start | 150,120 (70 epochs) |
| Patch size | 16 |
| dim | 768 |
| MLP dim | 3072 |
| depth | 12 |
Acknowledgements
Kevin Yin for suggesting the uniform rotations based on the golden ratio, as well as general feedback. Stephen Huan for miscellaneous suggestions.
How to cite
@misc{xiong2025ndrope,
author = {Jerry Xiong},
title = {On N-dimensional rotary positional embeddings},
year = {2025},
url = {https://jerryxio.ng/posts/nd-rope/}
}