从零学习大模型--旋转位置编码

0 阅读8分钟

1、为什么要位置编码

​transformer结构中在计算不同token的attention时不包含位置信息,即在计算某个token和其他token的相关性时与其位置无关,例如“天天向上”计算出来的attention score和“天向天上”计算出来的是一样的,因此需要再attention计算时添加上位置信息

2、常见位置编码

2.1 直接使用每个token的位置信息

​ ## 1、为什么要位置编码

transformer结构中在计算不同token的attention时不包含位置信息,即在计算某个token和其他token的相关性时与其位置无关,例如“天天向上”计算出来的attention score和“天向天上”计算出来的是一样的,因此需要在attention计算时添加上位置信息

2、常见位置编码

2.1 直接使用每个token的位置信息

第一个token的位置为0,第二个token的位置为1,以此类推,如下所示,其中,m表示token的位置,i表示每个位置的token在embedding后的第i个元素

xm,i=x+posm,ix_{m,i} = x + pos_{m,i}

问题:

  • 数值过大,需要归一化,而且embedding后,位置信息更复杂
  • 无法泛化到训练时未见的长度,例如位置n是个未知长度,无法根据已有的位置信息计算出n的位置信息

2.2 三角函数位置编码

m位置的token,其位置信息为:

posm=[sin(mθ)cos(mθ)]pos_m= \begin{bmatrix} sin(mθ) \\ cos(mθ) \end{bmatrix}

同理,n位置的信息也可按照上述类似表达,假如

n=m+Δtn = m + Δt

由于

sin(a+b)=sinacosb+cosasinbcos(a+b)=cosacosbsinasinbsin(a + b) = sina * cosb + cosa * sinb \\ cos(a + b) = cosa * cosb - sina * sinb

则,n位置的信息也可经过m位置的三角变换可获得,即:

posn=[sin((m+Δt)θ)cos((m+Δt)θ)]=[cosΔtsinΔtsinΔtcosΔt][sin(mθ)cos(mθ)]=RΔt[sin(mθ)cos(mθ)]pos_n = \begin{bmatrix} sin((m + Δt)θ) \\ cos((m + Δt)θ) \end{bmatrix} = \begin{bmatrix} cosΔt &sinΔt \\ -sinΔt &cosΔt \end{bmatrix} * \begin{bmatrix} sin(mθ) \\ cos(mθ) \end{bmatrix} = \mathcal{R}_{Δt} * \begin{bmatrix} sin(mθ) \\ cos(mθ) \end{bmatrix}

m位置的token经过embedding到hidden_size(hidden_size=d)后,其位置信息表达如下所示:

posm,2i=sin(mθi)posm,2i+1=cos(mθi)θi=100002i/d,i=0,1,2...,d/21pos_{m,2i} = sin(mθ_i) \\ pos_{m,2i+1} = cos(mθ_i) \\ θ_i = 10000^{-2i/d}, i = 0,1,2...,d/2-1

从上述公式可以看出

  • 三角函数的值域[-1,1],不存在数值过大的问题,不需要额外归一化

  • 具有外推性,可以泛化到未见的长度,任意位置的位置信息都可以根据已有的位置信息来计算

但是有个问题,如果将上述的位置编码经过Attention计算,即:

qmknT=[(xm+posm)WQ][(xn+posn)WK]T=[xmWQ+posmWQ][WKTxnT+WKTposnT]q_{m} * k_{n}^{T} = [(x_{m} + pos_{m}) * W_{Q} ] * [(x_{n} + pos_{n}) * W_{K}]^{T} \\ = [x_{m}W_{Q} + pos_{m}W_{Q}] *[W_{K}^{T}x_{n}^{T} + W_{K}^{T}pos_{n}^{T}]

其中位置信息m,n相关的部分为:

posmWQWKTposnTpos_{m}W_{Q}W_{K}^{T}pos_{n}^{T}

从这来看,位置信息不再单独受m、n处的位置编码影响了,而是引入了Q、K矩阵的线性变化

3、旋转位置编码

基于上述问题,是否可以在attention计算时,直接将位置信息融进去,这就引入了旋转位置编码

传统位置编码是加法式:

Attention(q,kT)=[(xm+posm)WQ][(xn+posn)WK]TAttention(q,k^T) = [(x_{m} + pos_{m}) * W_{Q} ] * [(x_{n} + pos_{n}) * W_{K}]^{T}

而旋转位置编码是乘法式:

Attention(q,kT)=[xmWQRm,i][xnWKRn,i]TAttention(q,k^T) = [x_{m} * W_{Q} * \mathcal{R}_{m, i}] * [x_{n} * W_{K} * \mathcal{R}_{n, i}]^{T}

其中,旋转矩阵对应为:

Rm,i=[[cos(mθ0)sin(mθ0)sin(mθ0)cos(mθ0)]...0.........0...[cos(mθd/21)sin(mθd/21)sin(mθd/21)cos(mθd/21)]]\mathcal{R}_{m, i} = \begin{bmatrix} \begin{bmatrix} cos(mθ_{0}) &-sin(mθ_{0}) \\ sin(mθ_{0}) &cos(mθ_{0}) \end{bmatrix} &... &0 \\ ... &... &... \\ 0 &... &\begin{bmatrix} cos(mθ_{d/2-1}) &-sin(mθ_{d/2-1}) \\ sin(mθ_{d/2-1}) &cos(mθ_{d/2-1}) \end{bmatrix} \end{bmatrix}

则旋转位置编码的Attention计算可以表示为:

Attention(q,k)=[qmRm,i][knRn,i]T=qmRm,iRn,iTknT=qmRnm,iknTAttention(q,k) = [q_{m} * \mathcal{R}_{m, i}] * [k_{n} * \mathcal{R}_{n, i}]^{T} = q_{m} * \mathcal{R}_{m, i} * \mathcal{R}_{n, i}^{T} * k_{n}^{T} = q_{m} * \mathcal{R}_{n-m, i} * k_{n}^{T}

从上述公式可以看出,位置信息仅跟m、n相关

4、网络中旋转位置编码的代码实现

def apply_rotary_emb_simple(x, freqs):
    """
    旋转矩阵公式:
    [cos(θ), -sin(θ)]   [x1]   [x1*cos(θ) - x2*sin(θ)]
    [sin(θ),  cos(θ)] * [x2] = [x1*sin(θ) + x2*cos(θ)]

    Args:
        x: 输入张量, 形状 (batch, seq_len, n_heads, head_dim)
        freqs: 旋转角度, 形状 (seq_len, head_dim//2)

    Returns:
        旋转后的张量
    """

    # 将向量分成两半
    x1 = x[..., :x.shape[-1]//2]   # 前半部分
    x2 = x[..., x.shape[-1]//2:]   # 后半部分

    # cos 和 sin
    cos = freqs.cos().unsqueeze(0).unsqueeze(2)  # (1, seq, 1, dim//2)
    sin = freqs.sin().unsqueeze(0).unsqueeze(2)

    # 2D旋转: [x1*cos - x2*sin, x1*sin + x2*cos]
    x_rotated = torch.cat([
        x1 * cos - x2 * sin,
        x1 * sin + x2 * cos
    ], dim=-1)

    return x_rotated

5、FAQ

1、三角函数本身存在周期性,会不会导致位置信息重复的问题

对于第i维度(i = 0,1...,d/2-1),

θi=100002i/dθ_i = 10000^{-2i/d}

Ti=2π/θ=2π100002i/dT_{i} = 2π/θ = 2π*10000^{2i/d}
对于 d = 512:
- i = 0:   T ≈ 6.28      ← 高频,很快重复
- i = 128: T ≈ 316.23
- i = 256: T ≈ 15811.39  ← 低频,周期很长

虽然单个维度会重复,但所有维度的组合几乎是唯一的。

第一个token的位置为0,第二个token的位置为1,以此类推,如下所示,其中,m表示token的位置,i表示每个位置的token在embedding后的第i个元素

xm=x+posmx_m = x + pos_m

问题:

  • 数值过大,需要归一化,而且embedding后,位置信息更复杂
  • 无法泛化到训练时未见的长度,例如位置n是个未知长度,无法根据已有的位置信息计算出n的位置信息

2.2 三角函数位置编码

m位置的token,其位置信息为:

posm=[sin(mθ)cos(mθ)]pos_m= \begin{bmatrix} sin(mθ) \\ cos(mθ) \end{bmatrix}

同理,n位置的信息也可按照上述类似表达,假如

n=m+Δtn = m + Δt

由于

sin(a+b)=sinacosb+cosasinbcos(a+b)=cosacosbsinasinbsin(a + b) = sina * cosb + cosa * sinb \\ cos(a + b) = cosa * cosb - sina * sinb

则,n位置的信息也可经过m位置的三角变换可获得,即:

posn=[sin((m+Δt)θ)cos((m+Δt)θ)]=[cosΔtsinΔtsinΔtcosΔt][sin(mθ)cos(mθ)]=RΔt[sin(mθ)cos(mθ)]pos_n = \begin{bmatrix} sin((m + Δt)θ) \\ cos((m + Δt)θ) \end{bmatrix} = \begin{bmatrix} cosΔt &sinΔt \\ -sinΔt &cosΔt \end{bmatrix} * \begin{bmatrix} sin(mθ) \\ cos(mθ) \end{bmatrix} = \mathcal{R}_{Δt} * \begin{bmatrix} sin(mθ) \\ cos(mθ) \end{bmatrix}

m位置的token经过embedding到hidden_size(hidden_size=d)后,其位置信息表达如下所示:

posm,2i=sin(mθi)posm,2i+1=cos(mθi)θi=100002i/d,i=0,1,2...,d/21pos_{m,2i} = sin(mθ_i) \\ pos_{m,2i+1} = cos(mθ_i) \\ θ_i = 10000^{-2i/d}, i = 0,1,2...,d/2-1

从上述公式可以看出

  • 三角函数的值域[-1,1],不存在数值过大的问题,不需要额外归一化

  • 具有外推性,可以泛化到未见的长度,任意位置的位置信息都可以根据已有的位置信息来计算

如果将上述的位置编码经过Attention计算,即:

qmknT=[(xm+posm)WQ][(xn+posn)WK]T=[xmWQ+posmWQ][WKTxnT+WKTposnT]q_{m} * k_{n}^{T} = [(x_{m} + pos_{m}) * W_{Q} ] * [(x_{n} + pos_{n}) * W_{K}]^{T} \\ = [x_{m}W_{Q} + pos_{m}W_{Q}] *[W_{K}^{T}x_{n}^{T} + W_{K}^{T}pos_{n}^{T}]

其中位置信息m,n相关的部分为:

posmWQWKTposnTpos_{m}W_{Q}W_{K}^{T}pos_{n}^{T}

从这来看,位置信息不再单独受m、n处的位置编码影响了,而是引入了线性变化

3、旋转位置编码

基于上述问题,是否可以在attention计算时,直接将位置信息融进去,这就引入了旋转位置编码

传统位置编码是加法式:

Attention(q,k)=[(xm+posm)WQ][(xn+posn)WK]TAttention(q,k) = [(x_{m} + pos_{m}) * W_{Q} ] * [(x_{n} + pos_{n}) * W_{K}]^{T}

而旋转位置编码是乘法式:

Attention(q,k)=[xmWQRm,i][xnWKRn,i]TAttention(q,k) = [x_{m} * W_{Q} * \mathcal{R}_{m, i}] * [x_{n} * W_{K} * \mathcal{R}_{n, i}]^{T}

其中,旋转矩阵对应为:

Rm,i=[[cos(mθ0)sin(mθ0)sin(mθ0)cos(mθ0)]...0.........0...[cos(mθd/21)sin(mθd/21)sin(mθd/21)cos(mθd/21)]]\mathcal{R}_{m, i} = \begin{bmatrix} \begin{bmatrix} cos(mθ_{0}) &-sin(mθ_{0}) \\ sin(mθ_{0}) &cos(mθ_{0}) \end{bmatrix} &... &0 \\ ... &... &... \\ 0 &... &\begin{bmatrix} cos(mθ_{d/2-1}) &-sin(mθ_{d/2-1}) \\ sin(mθ_{d/2-1}) &cos(mθ_{d/2-1}) \end{bmatrix} \end{bmatrix}

则旋转位置编码的Attention计算可以表示为:

Attention(q,k)=[qmRm,i][knRn,i]T=qmRm,iRn,iTknT=qmRnm,iknTAttention(q,k) = [q_{m} * \mathcal{R}_{m, i}] * [k_{n} * \mathcal{R}_{n, i}]^{T} = q_{m} * \mathcal{R}_{m, i} * \mathcal{R}_{n, i}^{T} * k_{n}^{T} = q_{m} * \mathcal{R}_{n-m, i} * k_{n}^{T}

从上述公式可以看出,位置信息仅跟m、n相关

4、网络中旋转位置编码的代码实现

def apply_rotary_emb_simple(x, freqs):
    """
    旋转矩阵公式:
    [cos(θ), -sin(θ)]   [x1]   [x1*cos(θ) - x2*sin(θ)]
    [sin(θ),  cos(θ)] * [x2] = [x1*sin(θ) + x2*cos(θ)]

    Args:
        x: 输入张量, 形状 (batch, seq_len, n_heads, head_dim)
        freqs: 旋转角度, 形状 (seq_len, head_dim//2)

    Returns:
        旋转后的张量
    """

    # 将向量分成两半
    x1 = x[..., :x.shape[-1]//2]   # 前半部分
    x2 = x[..., x.shape[-1]//2:]   # 后半部分

    # cos 和 sin
    cos = freqs.cos().unsqueeze(0).unsqueeze(2)  # (1, seq, 1, dim//2)
    sin = freqs.sin().unsqueeze(0).unsqueeze(2)

    # 2D旋转: [x1*cos - x2*sin, x1*sin + x2*cos]
    x_rotated = torch.cat([
        x1 * cos - x2 * sin,
        x1 * sin + x2 * cos
    ], dim=-1)

    return x_rotated

5、FAQ

1、三角函数本身存在周期性,会不会导致位置信息重复的问题

对于第i维度(i = 0,1...,d/2-1),

θi=100002i/dθ_i = 10000^{-2i/d}

Ti=2π/θ=2π100002i/dT_{i} = 2π/θ = 2π*10000^{2i/d}
对于 d = 512:
- i = 0:   T ≈ 6.28      ← 高频,很快重复
- i = 128: T ≈ 316.23
- i = 256: T ≈ 15811.39  ← 低频,周期很长

虽然单个维度会重复,但所有维度的组合几乎是唯一的。

1、为什么要位置编码

​ transformer结构中在计算不同token的attention时不包含位置信息,即在计算某个token和其他token的相关性时与其位置无关,例如“天天向上”计算出来的attention score和“天向天上”计算出来的是一样的,因此需要再attention计算时添加上位置信息

2、常见位置编码

2.1 直接使用每个token的位置信息

​ 第一个token的位置为0,第二个token的位置为1,以此类推,如下所示,其中,m表示token的位置,i表示每个位置的token在embedding后的第i个元素

xm=x+posmx_m = x + pos_m

问题:

  • 数值过大,需要归一化,而且embedding后,位置信息更复杂
  • 无法泛化到训练时未见的长度,例如位置n是个未知长度,无法根据已有的位置信息计算出n的位置信息

2.2 三角函数位置编码

m位置的token,其位置信息为:

posm=[sin(mθ)cos(mθ)]pos_m= \begin{bmatrix} sin(mθ) \\ cos(mθ) \end{bmatrix}

同理,n位置的信息也可按照上述类似表达,假如

n=m+Δtn = m + Δt

由于

sin(a+b)=sinacosb+cosasinbcos(a+b)=cosacosbsinasinbsin(a + b) = sina * cosb + cosa * sinb \\ cos(a + b) = cosa * cosb - sina * sinb

则,n位置的信息也可经过m位置的三角变换可获得,即:

posn=[cos((m+Δt)θ)sin((m+Δt)θ)]=[cosΔtsinΔtsinΔtcosΔt][cos(mθ)sin(mθ)]=RΔt[cos(mθ)sin(mθ)]pos_n = \begin{bmatrix} cos((m + Δt)θ) \\ sin((m + Δt)θ) \end{bmatrix} = \begin{bmatrix} cosΔt &sinΔt \\ -sinΔt &cosΔt \end{bmatrix} * \begin{bmatrix} cos(mθ) \\ sin(mθ) \end{bmatrix} = \mathcal{R}_{Δt} * \begin{bmatrix} cos(mθ) \\ sin(mθ) \end{bmatrix}

m位置的token经过embedding到hidden_size(hidden_size=d)后,其位置信息表达如下所示:

posm,2i=sin(mθi)posm,2i+1=cos(mθi)θi=100002i/d,i=0,1,2...,d/21pos_{m,2i} = sin(mθ_i) \\ pos_{m,2i+1} = cos(mθ_i) \\ θ_i = 10000^{-2i/d}, i = 0,1,2...,d/2-1

从上述公式可以看出

  • 三角函数的值域[-1,1],不存在数值过大的问题,不需要额外归一化

  • 具有外推性,可以泛化到未见的长度,任意位置的位置信息都可以根据已有的位置信息来计算

如果将上述的位置编码经过Attention计算,即:

qmknT=[(xm+posm)WQ][(xn+posn)WK]T=[xmWQ+posmWQ][WKTxnT+WKTposnT]q_{m} * k_{n}^{T} = [(x_{m} + pos_{m}) * W_{Q} ] * [(x_{n} + pos_{n}) * W_{K}]^{T} \\ = [x_{m}W_{Q} + pos_{m}W_{Q}] *[W_{K}^{T}x_{n}^{T} + W_{K}^{T}pos_{n}^{T}]

其中位置信息m,n相关的部分为:

posmWQWKTposnTpos_{m}W_{Q}W_{K}^{T}pos_{n}^{T}

从这来看,位置信息不再单独受m、n处的位置编码影响了,而是引入了线性变化

3、旋转位置编码

基于上述问题,是否可以在attention计算时,直接将位置信息融进去,这就引入了旋转位置编码

传统位置编码是加法式:

Attention(q,k)=[(xm+posm)WQ][(xn+posn)WK]TAttention(q,k) = [(x_{m} + pos_{m}) * W_{Q} ] * [(x_{n} + pos_{n}) * W_{K}]^{T}

而旋转位置编码是乘法式:

Attention(q,k)=[xmWQRm,i][xnWKRn,i]TAttention(q,k) = [x_{m} * W_{Q} * \mathcal{R}_{m, i}] * [x_{n} * W_{K} * \mathcal{R}_{n, i}]^{T}

其中,旋转矩阵对应为:

Rm,i=[[cos(mθ0)sin(mθ0)sin(mθ0)cos(mθ0)]...0.........0...[cos(mθd/21)sin(mθd/21)sin(mθd/21)cos(mθd/21)]]\mathcal{R}_{m, i} = \begin{bmatrix} \begin{bmatrix} cos(mθ_{0}) &-sin(mθ_{0}) \\ sin(mθ_{0}) &cos(mθ_{0}) \end{bmatrix} &... &0 \\ ... &... &... \\ 0 &... &\begin{bmatrix} cos(mθ_{d/2-1}) &-sin(mθ_{d/2-1}) \\ sin(mθ_{d/2-1}) &cos(mθ_{d/2-1}) \end{bmatrix} \end{bmatrix}

则旋转位置编码的Attention计算可以表示为:

Attention(q,k)=[qmRm,i][knRn,i]T=qmRm,iRn,iTknT=qmRnm,iknTAttention(q,k) = [q_{m} * \mathcal{R}_{m, i}] * [k_{n} * \mathcal{R}_{n, i}]^{T} = q_{m} * \mathcal{R}_{m, i} * \mathcal{R}_{n, i}^{T} * k_{n}^{T} = q_{m} * \mathcal{R}_{n-m, i} * k_{n}^{T}

从上述公式可以看出,位置信息仅跟m、n相关

4、网络中旋转位置编码的代码实现

def apply_rope(x, pos, dim, theta_base=10000):
    """应用 RoPE 旋转"""
    theta = pos / (theta_base ** (2 * np.arange(dim // 2) / dim))
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    x_rot = np.zeros_like(x)
    for i in range(dim // 2):
        x_rot[2*i]   = cos_t[i] * x[2*i] - sin_t[i] * x[2*i+1]
        x_rot[2*i+1] = sin_t[i] * x[2*i] + cos_t[i] * x[2*i+1]

    return x_rot

5、FAQ

1、三角函数本身存在周期性,会不会导致位置信息重复的问题

对于第i维度(i = 0,1...,d/2-1),

θi=100002i/dθ_i = 10000^{-2i/d}

Ti=2π/θ=2π100002i/dT_{i} = 2π/θ = 2π*10000^{2i/d}
对于 d = 512:
- i = 0:   T ≈ 6.28      ← 高频,很快重复
- i = 128: T ≈ 316.23
- i = 256: T ≈ 15811.39  ← 低频,周期很长

虽然单个维度会重复,但所有维度的组合几乎是唯一的。