1、为什么要位置编码
transformer结构中在计算不同token的attention时不包含位置信息,即在计算某个token和其他token的相关性时与其位置无关,例如“天天向上”计算出来的attention score和“天向天上”计算出来的是一样的,因此需要再attention计算时添加上位置信息
2、常见位置编码
2.1 直接使用每个token的位置信息
## 1、为什么要位置编码
transformer结构中在计算不同token的attention时不包含位置信息,即在计算某个token和其他token的相关性时与其位置无关,例如“天天向上”计算出来的attention score和“天向天上”计算出来的是一样的,因此需要在attention计算时添加上位置信息
2、常见位置编码
2.1 直接使用每个token的位置信息
第一个token的位置为0,第二个token的位置为1,以此类推,如下所示,其中,m表示token的位置,i表示每个位置的token在embedding后的第i个元素
xm,i=x+posm,i
问题:
- 数值过大,需要归一化,而且embedding后,位置信息更复杂
- 无法泛化到训练时未见的长度,例如位置n是个未知长度,无法根据已有的位置信息计算出n的位置信息
2.2 三角函数位置编码
m位置的token,其位置信息为:
posm=[sin(mθ)cos(mθ)]
同理,n位置的信息也可按照上述类似表达,假如
由于
sin(a+b)=sina∗cosb+cosa∗sinbcos(a+b)=cosa∗cosb−sina∗sinb
则,n位置的信息也可经过m位置的三角变换可获得,即:
posn=[sin((m+Δt)θ)cos((m+Δt)θ)]=[cosΔt−sinΔtsinΔtcosΔt]∗[sin(mθ)cos(mθ)]=RΔt∗[sin(mθ)cos(mθ)]
m位置的token经过embedding到hidden_size(hidden_size=d)后,其位置信息表达如下所示:
posm,2i=sin(mθi)posm,2i+1=cos(mθi)θi=10000−2i/d,i=0,1,2...,d/2−1
从上述公式可以看出
但是有个问题,如果将上述的位置编码经过Attention计算,即:
qm∗knT=[(xm+posm)∗WQ]∗[(xn+posn)∗WK]T=[xmWQ+posmWQ]∗[WKTxnT+WKTposnT]
其中位置信息m,n相关的部分为:
posmWQWKTposnT
从这来看,位置信息不再单独受m、n处的位置编码影响了,而是引入了Q、K矩阵的线性变化
3、旋转位置编码
基于上述问题,是否可以在attention计算时,直接将位置信息融进去,这就引入了旋转位置编码
传统位置编码是加法式:
Attention(q,kT)=[(xm+posm)∗WQ]∗[(xn+posn)∗WK]T
而旋转位置编码是乘法式:
Attention(q,kT)=[xm∗WQ∗Rm,i]∗[xn∗WK∗Rn,i]T
其中,旋转矩阵对应为:
Rm,i=[cos(mθ0)sin(mθ0)−sin(mθ0)cos(mθ0)]...0.........0...[cos(mθd/2−1)sin(mθd/2−1)−sin(mθd/2−1)cos(mθd/2−1)]
则旋转位置编码的Attention计算可以表示为:
Attention(q,k)=[qm∗Rm,i]∗[kn∗Rn,i]T=qm∗Rm,i∗Rn,iT∗knT=qm∗Rn−m,i∗knT
从上述公式可以看出,位置信息仅跟m、n相关
4、网络中旋转位置编码的代码实现
def apply_rotary_emb_simple(x, freqs):
"""
旋转矩阵公式:
[cos(θ), -sin(θ)] [x1] [x1*cos(θ) - x2*sin(θ)]
[sin(θ), cos(θ)] * [x2] = [x1*sin(θ) + x2*cos(θ)]
Args:
x: 输入张量, 形状 (batch, seq_len, n_heads, head_dim)
freqs: 旋转角度, 形状 (seq_len, head_dim//2)
Returns:
旋转后的张量
"""
x1 = x[..., :x.shape[-1]//2]
x2 = x[..., x.shape[-1]//2:]
cos = freqs.cos().unsqueeze(0).unsqueeze(2)
sin = freqs.sin().unsqueeze(0).unsqueeze(2)
x_rotated = torch.cat([
x1 * cos - x2 * sin,
x1 * sin + x2 * cos
], dim=-1)
return x_rotated
5、FAQ
1、三角函数本身存在周期性,会不会导致位置信息重复的问题
对于第i维度(i = 0,1...,d/2-1),
θi=10000−2i/d
则
Ti=2π/θ=2π∗100002i/d
对于 d = 512:
- i = 0: T ≈ 6.28 ← 高频,很快重复
- i = 128: T ≈ 316.23
- i = 256: T ≈ 15811.39 ← 低频,周期很长
虽然单个维度会重复,但所有维度的组合几乎是唯一的。
第一个token的位置为0,第二个token的位置为1,以此类推,如下所示,其中,m表示token的位置,i表示每个位置的token在embedding后的第i个元素
xm=x+posm
问题:
- 数值过大,需要归一化,而且embedding后,位置信息更复杂
- 无法泛化到训练时未见的长度,例如位置n是个未知长度,无法根据已有的位置信息计算出n的位置信息
2.2 三角函数位置编码
m位置的token,其位置信息为:
posm=[sin(mθ)cos(mθ)]
同理,n位置的信息也可按照上述类似表达,假如
由于
sin(a+b)=sina∗cosb+cosa∗sinbcos(a+b)=cosa∗cosb−sina∗sinb
则,n位置的信息也可经过m位置的三角变换可获得,即:
posn=[sin((m+Δt)θ)cos((m+Δt)θ)]=[cosΔt−sinΔtsinΔtcosΔt]∗[sin(mθ)cos(mθ)]=RΔt∗[sin(mθ)cos(mθ)]
m位置的token经过embedding到hidden_size(hidden_size=d)后,其位置信息表达如下所示:
posm,2i=sin(mθi)posm,2i+1=cos(mθi)θi=10000−2i/d,i=0,1,2...,d/2−1
从上述公式可以看出
如果将上述的位置编码经过Attention计算,即:
qm∗knT=[(xm+posm)∗WQ]∗[(xn+posn)∗WK]T=[xmWQ+posmWQ]∗[WKTxnT+WKTposnT]
其中位置信息m,n相关的部分为:
posmWQWKTposnT
从这来看,位置信息不再单独受m、n处的位置编码影响了,而是引入了线性变化
3、旋转位置编码
基于上述问题,是否可以在attention计算时,直接将位置信息融进去,这就引入了旋转位置编码
传统位置编码是加法式:
Attention(q,k)=[(xm+posm)∗WQ]∗[(xn+posn)∗WK]T
而旋转位置编码是乘法式:
Attention(q,k)=[xm∗WQ∗Rm,i]∗[xn∗WK∗Rn,i]T
其中,旋转矩阵对应为:
Rm,i=[cos(mθ0)sin(mθ0)−sin(mθ0)cos(mθ0)]...0.........0...[cos(mθd/2−1)sin(mθd/2−1)−sin(mθd/2−1)cos(mθd/2−1)]
则旋转位置编码的Attention计算可以表示为:
Attention(q,k)=[qm∗Rm,i]∗[kn∗Rn,i]T=qm∗Rm,i∗Rn,iT∗knT=qm∗Rn−m,i∗knT
从上述公式可以看出,位置信息仅跟m、n相关
4、网络中旋转位置编码的代码实现
def apply_rotary_emb_simple(x, freqs):
"""
旋转矩阵公式:
[cos(θ), -sin(θ)] [x1] [x1*cos(θ) - x2*sin(θ)]
[sin(θ), cos(θ)] * [x2] = [x1*sin(θ) + x2*cos(θ)]
Args:
x: 输入张量, 形状 (batch, seq_len, n_heads, head_dim)
freqs: 旋转角度, 形状 (seq_len, head_dim//2)
Returns:
旋转后的张量
"""
x1 = x[..., :x.shape[-1]//2]
x2 = x[..., x.shape[-1]//2:]
cos = freqs.cos().unsqueeze(0).unsqueeze(2)
sin = freqs.sin().unsqueeze(0).unsqueeze(2)
x_rotated = torch.cat([
x1 * cos - x2 * sin,
x1 * sin + x2 * cos
], dim=-1)
return x_rotated
5、FAQ
1、三角函数本身存在周期性,会不会导致位置信息重复的问题
对于第i维度(i = 0,1...,d/2-1),
θi=10000−2i/d
则
Ti=2π/θ=2π∗100002i/d
对于 d = 512:
- i = 0: T ≈ 6.28 ← 高频,很快重复
- i = 128: T ≈ 316.23
- i = 256: T ≈ 15811.39 ← 低频,周期很长
虽然单个维度会重复,但所有维度的组合几乎是唯一的。
1、为什么要位置编码
transformer结构中在计算不同token的attention时不包含位置信息,即在计算某个token和其他token的相关性时与其位置无关,例如“天天向上”计算出来的attention score和“天向天上”计算出来的是一样的,因此需要再attention计算时添加上位置信息
2、常见位置编码
2.1 直接使用每个token的位置信息
第一个token的位置为0,第二个token的位置为1,以此类推,如下所示,其中,m表示token的位置,i表示每个位置的token在embedding后的第i个元素
xm=x+posm
问题:
- 数值过大,需要归一化,而且embedding后,位置信息更复杂
- 无法泛化到训练时未见的长度,例如位置n是个未知长度,无法根据已有的位置信息计算出n的位置信息
2.2 三角函数位置编码
m位置的token,其位置信息为:
posm=[sin(mθ)cos(mθ)]
同理,n位置的信息也可按照上述类似表达,假如
由于
sin(a+b)=sina∗cosb+cosa∗sinbcos(a+b)=cosa∗cosb−sina∗sinb
则,n位置的信息也可经过m位置的三角变换可获得,即:
posn=[cos((m+Δt)θ)sin((m+Δt)θ)]=[cosΔt−sinΔtsinΔtcosΔt]∗[cos(mθ)sin(mθ)]=RΔt∗[cos(mθ)sin(mθ)]
m位置的token经过embedding到hidden_size(hidden_size=d)后,其位置信息表达如下所示:
posm,2i=sin(mθi)posm,2i+1=cos(mθi)θi=10000−2i/d,i=0,1,2...,d/2−1
从上述公式可以看出
如果将上述的位置编码经过Attention计算,即:
qm∗knT=[(xm+posm)∗WQ]∗[(xn+posn)∗WK]T=[xmWQ+posmWQ]∗[WKTxnT+WKTposnT]
其中位置信息m,n相关的部分为:
posmWQWKTposnT
从这来看,位置信息不再单独受m、n处的位置编码影响了,而是引入了线性变化
3、旋转位置编码
基于上述问题,是否可以在attention计算时,直接将位置信息融进去,这就引入了旋转位置编码
传统位置编码是加法式:
Attention(q,k)=[(xm+posm)∗WQ]∗[(xn+posn)∗WK]T
而旋转位置编码是乘法式:
Attention(q,k)=[xm∗WQ∗Rm,i]∗[xn∗WK∗Rn,i]T
其中,旋转矩阵对应为:
Rm,i=[cos(mθ0)sin(mθ0)−sin(mθ0)cos(mθ0)]...0.........0...[cos(mθd/2−1)sin(mθd/2−1)−sin(mθd/2−1)cos(mθd/2−1)]
则旋转位置编码的Attention计算可以表示为:
Attention(q,k)=[qm∗Rm,i]∗[kn∗Rn,i]T=qm∗Rm,i∗Rn,iT∗knT=qm∗Rn−m,i∗knT
从上述公式可以看出,位置信息仅跟m、n相关
4、网络中旋转位置编码的代码实现
def apply_rope(x, pos, dim, theta_base=10000):
"""应用 RoPE 旋转"""
theta = pos / (theta_base ** (2 * np.arange(dim // 2) / dim))
cos_t = np.cos(theta)
sin_t = np.sin(theta)
x_rot = np.zeros_like(x)
for i in range(dim // 2):
x_rot[2*i] = cos_t[i] * x[2*i] - sin_t[i] * x[2*i+1]
x_rot[2*i+1] = sin_t[i] * x[2*i] + cos_t[i] * x[2*i+1]
return x_rot
5、FAQ
1、三角函数本身存在周期性,会不会导致位置信息重复的问题
对于第i维度(i = 0,1...,d/2-1),
θi=10000−2i/d
则
Ti=2π/θ=2π∗100002i/d
对于 d = 512:
- i = 0: T ≈ 6.28 ← 高频,很快重复
- i = 128: T ≈ 316.23
- i = 256: T ≈ 15811.39 ← 低频,周期很长
虽然单个维度会重复,但所有维度的组合几乎是唯一的。