05. Rotary Position Embedding
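Rotary Position Embedding (RoPE) encodes each token's absolute position as a rotation angle applied to its query and key vectors. Unlike learned position embeddings, RoPE adds no trainable parameters, and the resulting attention scores depend only on the relative offset between positions. The rotation can be computed for any sequence length, although quality at lengths beyond those seen in training is not guaranteed.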
RoPE applies a rotation matrix to adjacent pairs of dimensions:
RoPE(q_i, i) = R_i @ q_i
RoPE(k_j, j) = R_j @ k_j
The dot product of position-encoded vectors reduces to a function of relative position:
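⟨RoPE(q_i, i), RoPE(k_j, j)⟩ = q_i^T @ R_i^T @ R_j @ k_j = q_i^T @ R_(j-i) @ k_j
Because R_i^T @ R_j = R_(j-i), the score depends only on the offset j - i, not on i or j individually.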
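The relative-position property can be checked numerically. The following is a minimal sketch using only the Python standard library rather than PyTorch: each pair of dimensions is stored as one complex number, so a rotation becomes a complex multiplication. The helper names rotate_pairs and real_dot and all numeric values are illustrative assumptions.

```python
import cmath

def rotate_pairs(vec, position, freqs):
    """Rotate each dimension pair (stored as a complex number) by position * freq."""
    return [v * cmath.exp(1j * position * f) for v, f in zip(vec, freqs)]

def real_dot(a, b):
    """Real dot product of two vectors whose dimension pairs are stored as complex numbers."""
    return sum((x * y.conjugate()).real for x, y in zip(a, b))

# Two dimension pairs (head_dim = 4); frequencies and vectors are arbitrary.
freqs = [1.0, 0.1]
q = [complex(0.3, -1.2), complex(0.7, 0.5)]
k = [complex(-0.4, 0.9), complex(1.1, -0.2)]

# Same offset (i - j = 3) at very different absolute positions.
score_a = real_dot(rotate_pairs(q, 5, freqs), rotate_pairs(k, 2, freqs))
score_b = real_dot(rotate_pairs(q, 103, freqs), rotate_pairs(k, 100, freqs))
# A different offset (i - j = 2).
score_c = real_dot(rotate_pairs(q, 5, freqs), rotate_pairs(k, 3, freqs))

print(abs(score_a - score_b) < 1e-9)  # equal offsets give the same score
print(abs(score_a - score_c) > 1e-3)  # a different offset changes the score
```

Shifting both positions by the same amount leaves the score unchanged up to floating-point error, while changing the offset changes it.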
import torch
import torch.nn as nn

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute complex rotation factors for RoPE.

    Returns a complex tensor of shape (end, dim // 2).
    """
    # Per-pair inverse frequencies: theta^(-2m/dim) for m = 0 .. dim/2 - 1
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(end).float()
    # Rotation angle for every (position, pair) combination: position * freq
    angles = torch.outer(positions, freqs)
    # Unit-magnitude complex numbers in polar form: e^(i * position * freq)
    return torch.polar(torch.ones_like(angles), angles)
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
    """Apply rotary embedding to a query or key tensor.

    x: (batch, heads, seq_len, head_dim), head_dim must be even
    freqs_cis: (seq_len, head_dim/2) - complex numbers
    """
    # Group adjacent dimensions into pairs and view each pair as one complex number:
    # (batch, heads, seq_len, head_dim) -> (batch, heads, seq_len, head_dim/2)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Reshape freqs_cis to broadcast with x
    freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(0)  # (1, 1, seq, head/2)
    # Complex multiplication rotates each pair by its position-dependent angle
    x_rotated = x_complex * freqs_cis
    # Convert back to real pairs, then merge the pairs into head_dim
    x_out = torch.view_as_real(x_rotated)
    return x_out.flatten(-2).type_as(x)
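class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048, base=10000.0):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_seq_len = max_seq_len
        self._freqs_cis = None

    def forward(self, seq_len, device):
        # Rebuild the cache when it is missing or too short for this sequence
        if self._freqs_cis is None or self._freqs_cis.shape[0] < seq_len:
            self._freqs_cis = precompute_freqs_cis(
                self.dim, max(seq_len, self.max_seq_len), self.base
            )
        # Return only the rows needed so the shape matches the current sequence
        return self._freqs_cis[:seq_len].to(device)

    def clear_cache(self):
        self._freqs_cis = None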
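To build intuition for the frequency schedule computed in precompute_freqs_cis, the sketch below lists the per-pair inverse frequencies and the corresponding wavelengths in tokens. It uses only the standard library; the function name rope_inverse_frequencies and the choice dim=8 are illustrative assumptions, not part of the code above.

```python
import math
from typing import List

def rope_inverse_frequencies(dim: int, base: float = 10000.0) -> List[float]:
    """Return base^(-2m/dim) for each of the dim/2 dimension pairs."""
    if dim % 2 != 0:
        raise ValueError("dim must be even so dimensions can be paired")
    return [base ** (-(2 * m) / dim) for m in range(dim // 2)]

freqs = rope_inverse_frequencies(8)
# Wavelength: number of positions a pair needs to complete one full turn (2*pi radians).
wavelengths = [2 * math.pi / f for f in freqs]

for m, (f, w) in enumerate(zip(freqs, wavelengths)):
    print(f"pair {m}: {f:.4f} rad/token, wavelength {w:.1f} tokens")
```

The first pair rotates by one radian per token, and each later pair rotates more slowly, so low-index pairs resolve nearby positions while high-index pairs vary over long ranges.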
Failure mode: Dimension mismatch. RoPE operates on adjacent pairs of dimensions, so head_dim must be even; with an odd head_dim the pairing step fails. For example, d_model=3072 with n_heads=32 gives d_k=96, which is even and therefore valid, but check this explicitly before implementation rather than assuming it.
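One way to catch this early is to validate the configuration when the model is constructed. The helper below is a minimal sketch; the name checked_head_dim is hypothetical, and the second configuration is chosen only because it produces an odd head dimension.

```python
def checked_head_dim(d_model: int, n_heads: int) -> int:
    """Return head_dim, raising if the configuration cannot be used with RoPE."""
    if d_model % n_heads != 0:
        raise ValueError(
            f"d_model={d_model} is not divisible by n_heads={n_heads}"
        )
    head_dim = d_model // n_heads
    if head_dim % 2 != 0:
        raise ValueError(
            f"head_dim={head_dim} is odd; RoPE needs an even head_dim to form pairs"
        )
    return head_dim

print(checked_head_dim(3072, 32))  # 96

try:
    checked_head_dim(2400, 32)  # head_dim would be 75, which is odd
except ValueError as err:
    print(err)
```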
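Failure mode: Cache invalidation. If frequency factors are cached for one sequence length and a sequence of a different length is processed later, the cached tensor no longer broadcasts against the input and you get a shape mismatch. A shorter sequence can be handled by slicing the cache, but a longer one requires recomputation. Always check that the cached freqs cover the current sequence length.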
Implement a Llama-style attention module using RoPE. Apply RoPE to Q and K before the attention computation, then verify that no positional encoding is present in the learned embeddings (they should be plain token embeddings with no positional information).