RoPE 旋转位置编码详细代码解析
一、位置编码的作用
既然要讲旋转位置编码,那么我们就要讲一下为什么要采用位置编码。假设我们有一句话"一把把把手把住了",这句话中的每个“把”都有着自己不同的含义,但是如果我们不采用一个位置编码,那么这句话中的每个“把”通过词嵌入模块后就会变成一个相同的向量,从而丢失了很多的语义。所以我们可以通过叠加位置编码给予每个“把”一个不同的位置信息,那么就可以解决这个问题,补上这个缺失的语义。
二、常用位置编码
现在有非常多种的位置编码,在此处我只介绍两种比较有名的位置编码。
第一个就是transformer中原版的位置编码--正余弦位置编码,第二个就是本文标题中的旋转位置编码
正余弦位置编码
Transformer 原版正弦余弦位置编码公式定义如下:
其中:
- :token 在序列中的位置(如第 3 个 token,);
- :编码向量的维度索引(取值范围:);
- :token 嵌入向量的维度(如 512、768 等)。
正余弦位置编码的作用位置是embedding层之后,直接与经过embedding层的词嵌入进行相加。
其有三个致命的短板:
1.相对位置隐式、难学习:内积包含绝对位置项,模型需费力从叠加向量中分离语义与位置,长距离依赖建模弱。
2.长度外推能力差:预定义最大长度,超出后高频分量几乎无差异,位置区分失效。
3.语义信息被干扰:直接相加改变向量模长与方向,破坏语义空间结构,相似度计算被污染。
RoPE旋转位置编码
原理推荐在知乎查找,此处仅对MINImind的RoPE的源码进行解析,同时这里实现了外推算法YARN。
预计算sin矩阵和cos矩阵
import math
import torch
from typing import Optional
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
rope_scaling: Optional[dict] = None):
"""
预计算RoPE(旋转位置编码)的cos和sin矩阵,支持YaRN外推优化
Args:
dim: 模型特征维度(d_model)
end: 要预计算的最大位置数(比如32768)
rope_base: RoPE基数(对应公式中的10000/1e6)
rope_scaling: YaRN外推配置字典(None则用原生RoPE)
Returns:
freqs_cos: 所有位置的cos值矩阵,尺寸[end, dim]
freqs_sin: 所有位置的sin值矩阵,尺寸[end, dim]
"""
# 1. 计算原生RoPE的基础频率(每两个维度为一组)
# torch.arange(0, dim, 2)[:dim//2] → 生成维度对索引,尺寸[dim//2,]
# 公式:freqs_i = 1 / (rope_base^(2i/dim)) (对应θ_i = 10000^(-2i/dim))
freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 注意力缩放因子,默认1.0(不缩放)
attn_factor = 1.0
# 2. YaRN外推优化逻辑(仅当配置了rope_scaling时生效)
if rope_scaling is not None:
# 从配置中取出YaRN关键参数
# orig_max: 模型原生支持的最大位置(比如2048)
# factor: 外推倍数(比如16 → 支持2048×16=32768)
# beta_fast/beta_slow: 控制不同维度的旋转速度调整范围
# attn_factor: 注意力缩放因子
orig_max = rope_scaling.get("original_max_position_embeddings", 2048)
factor = rope_scaling.get("factor", 16)
beta_fast = rope_scaling.get("beta_fast", 32.0)
beta_slow = rope_scaling.get("beta_slow", 1.0)
attn_factor = rope_scaling.get("attention_factor", 1.0)
# 仅当预计算位置超过原生最大位置时,才调整频率
if end / orig_max > 1.0:
# YaRN辅助函数:计算需要调整频率的维度索引边界
# 作用:划分"快旋转维度"和"慢旋转维度",避免全维度过度调整
inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
# 计算调整范围的上下界(确保在0~dim//2-1之间)
low = max(math.floor(inv_dim(beta_fast)), 0)
high = min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
# 生成线性斜坡因子(0~1之间),尺寸[dim//2,]
# 作用:不同维度按不同比例调整频率
ramp = torch.clamp(
(torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001),
0, 1 # 限制值在0~1,避免异常
)
# YaRN核心:调整频率(外推越长,频率越低),尺寸保持[dim//2,]
freqs = freqs * (1 - ramp + ramp / factor)
# 3. 生成位置索引(0~end-1),尺寸[end,]
t = torch.arange(end, device=freqs.device)
# 4. 计算位置×维度的频率矩阵(外积),尺寸[end, dim//2]
# 每个位置的每个维度对,都有对应的旋转角度
freqs = torch.outer(t, freqs).float()
# 5. 生成cos矩阵并扩展到全维度(拼接两次,匹配模型dim)
# torch.cos(freqs) → 尺寸[end, dim//2]
# torch.cat拼接后 → 尺寸[end, dim],乘以注意力缩放因子
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
# 6. 生成sin矩阵(逻辑同cos),尺寸[end, dim]
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
# 返回cos/sin矩阵,供后续query/key旋转使用
return freqs_cos, freqs_sin
RoPE的应用
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def rotate_half(x):
return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
k_embed = (k * cos.u
unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
return q_embed, k_embed
原理可以直接看手写,其实就是一个计算的优化,利用矩阵计算的特性,加速计算。其中还有一些广播方面的内容,可以再进行一下详细理解。