Minimind项目源码解析(2) RoPE旋转位置编码代码解析

0 阅读5分钟

RoPE 旋转位置编码详细代码解析

一、位置编码的作用

既然要讲旋转位置编码,那么我们就要讲一下为什么要采用位置编码。假设我们有一句话"一把把把手把住了",这句话中的每个“把”都有着自己不同的含义,但是如果我们不采用一个位置编码,那么这句话中的每个“把”通过词嵌入模块后就会变成一个相同的向量,从而丢失了很多的语义。所以我们可以通过叠加位置编码给予每个“把”一个不同的位置信息,那么就可以解决这个问题,补上这个缺失的语义。

二、常用位置编码

现在有非常多种的位置编码,在此处我只介绍两种比较有名的位置编码。
第一个就是transformer中原版的位置编码--正余弦位置编码,第二个就是本文标题中的旋转位置编码

正余弦位置编码

Transformer 原版正弦余弦位置编码公式定义如下:

{PE(pos,2i)=sin(pos100002idmodel)PE(pos,2i+1)=cos(pos100002idmodel)\begin{cases} PE_{(\text{pos}, 2i)} = \sin\left(\frac{\text{pos}}{10000^{\frac{2i}{d_{\text{model}}}}}\right) \\ PE_{(\text{pos}, 2i+1)} = \cos\left(\frac{\text{pos}}{10000^{\frac{2i}{d_{\text{model}}}}}\right) \end{cases}

其中:

  • pos\text{pos}:token 在序列中的位置(如第 3 个 token,pos=3\text{pos}=3);
  • ii:编码向量的维度索引(取值范围:0idmodel210 \le i \le \frac{d_{\text{model}}}{2}-1);
  • dmodeld_{\text{model}}:token 嵌入向量的维度(如 512、768 等)。

正余弦位置编码的作用位置是embedding层之后,直接与经过embedding层的词嵌入进行相加。

其有三个致命的短板:

1.相对位置隐式、难学习:内积包含绝对位置项,模型需费力从叠加向量中分离语义与位置,长距离依赖建模弱。

2.长度外推能力差:预定义最大长度,超出后高频分量几乎无差异,位置区分失效。

3.语义信息被干扰:直接相加改变向量模长与方向,破坏语义空间结构,相似度计算被污染。

RoPE旋转位置编码

原理推荐在知乎查找,此处仅对MINImind的RoPE的源码进行解析,同时这里实现了外推算法YARN。

预计算sin矩阵和cos矩阵

import math
import torch
from typing import Optional

def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                         rope_scaling: Optional[dict] = None):
    """
    预计算RoPE(旋转位置编码)的cos和sin矩阵,支持YaRN外推优化
    Args:
        dim: 模型特征维度(d_model)
        end: 要预计算的最大位置数(比如32768)
        rope_base: RoPE基数(对应公式中的10000/1e6)
        rope_scaling: YaRN外推配置字典(None则用原生RoPE)
    Returns:
        freqs_cos: 所有位置的cos值矩阵,尺寸[end, dim]
        freqs_sin: 所有位置的sin值矩阵,尺寸[end, dim]
    """
    # 1. 计算原生RoPE的基础频率(每两个维度为一组)
    # torch.arange(0, dim, 2)[:dim//2] → 生成维度对索引,尺寸[dim//2,]
    # 公式:freqs_i = 1 / (rope_base^(2i/dim)) (对应θ_i = 10000^(-2i/dim))
    freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # 注意力缩放因子,默认1.0(不缩放)
    attn_factor = 1.0

    # 2. YaRN外推优化逻辑(仅当配置了rope_scaling时生效)
    if rope_scaling is not None:
        # 从配置中取出YaRN关键参数
        # orig_max: 模型原生支持的最大位置(比如2048)
        # factor: 外推倍数(比如16 → 支持2048×16=32768)
        # beta_fast/beta_slow: 控制不同维度的旋转速度调整范围
        # attn_factor: 注意力缩放因子
        orig_max = rope_scaling.get("original_max_position_embeddings", 2048)
        factor = rope_scaling.get("factor", 16)
        beta_fast = rope_scaling.get("beta_fast", 32.0)
        beta_slow = rope_scaling.get("beta_slow", 1.0)
        attn_factor = rope_scaling.get("attention_factor", 1.0)

        # 仅当预计算位置超过原生最大位置时,才调整频率
        if end / orig_max > 1.0:
            # YaRN辅助函数:计算需要调整频率的维度索引边界
            # 作用:划分"快旋转维度"和"慢旋转维度",避免全维度过度调整
            inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
            
            # 计算调整范围的上下界(确保在0~dim//2-1之间)
            low = max(math.floor(inv_dim(beta_fast)), 0)
            high = min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
            
            # 生成线性斜坡因子(0~1之间),尺寸[dim//2,]
            # 作用:不同维度按不同比例调整频率
            ramp = torch.clamp(
                (torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001),
                0, 1  # 限制值在0~1,避免异常
            )
            
            # YaRN核心:调整频率(外推越长,频率越低),尺寸保持[dim//2,]
            freqs = freqs * (1 - ramp + ramp / factor)

    # 3. 生成位置索引(0~end-1),尺寸[end,]
    t = torch.arange(end, device=freqs.device)
    
    # 4. 计算位置×维度的频率矩阵(外积),尺寸[end, dim//2]
    # 每个位置的每个维度对,都有对应的旋转角度
    freqs = torch.outer(t, freqs).float()

    # 5. 生成cos矩阵并扩展到全维度(拼接两次,匹配模型dim)
    # torch.cos(freqs) → 尺寸[end, dim//2]
    # torch.cat拼接后 → 尺寸[end, dim],乘以注意力缩放因子
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
    
    # 6. 生成sin矩阵(逻辑同cos),尺寸[end, dim]
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor

    # 返回cos/sin矩阵,供后续query/key旋转使用
    return freqs_cos, freqs_sin

RoPE的应用

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    def rotate_half(x):
        return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)

    q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
    k_embed = (k * cos.u
    unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
    return q_embed, k_embed
    

原理可以直接看手写,其实就是一个计算的优化,利用矩阵计算的特性,加速计算。其中还有一些广播方面的内容,可以再进行一下详细理解。 e1af1ffe0245139cec3f0b285e619a07.jpg