旋转式位置编码(RoPE)最早是论文[1]提出的一种能够将相对位置信息依赖集成到 self-attention 中并提升 transformer 架构性能的位置编码方式。而目前很火的 LLaMA 模型也是采用该位置编码方式。
基本概念
首先论文中定义一个长度为 N 的输入序列为
其中 wi 表示输入序列中第 i 个 token,而输入序列 SN 对应的 embedding 表示为:
其中 xi 表示第 i 个 token wi 对应的 d 维词嵌入向量。
接着在做 self-attention 之前,会用词嵌入向量计算 q, k, v 向量同时加入位置信息,函数公式表达如下:
旋转位置编码
关于旋转位置编码需要理解以下几个点:
1.用绝对位置编码表示相对位置
比如要在计算机里表达我喜欢你这四个字是可以用一个矩阵表示,我喜欢你这四个字,可以在计算机里用4x2的矩阵表示,但实际使用中必定会使用更高维度的向量,4x128,4x256.
RoPE的假设里,先将qm,kn看作是复数,比如我字的复数的表征为0.45 + i0.78,Rope的步骤就是在第一步里进行,即q,k进行内积的时候,先分别对做一次变换:
变换之后,这里就分别有了位置m,n的绝对位置编码信息。将两个变换之后的数据做内积(<>表示内积计算):
备注累积 两个复数的内积等于一个复数乘以另外一个复数的共轭 复数的共轭简单点说: 实部相同,虚部相反。
由此在attention的计算当中,通过 我们获得了
q中第m个 token和k中第n个token的相对位置信息。
理论上推导到此结束,剩下就是实际编码的过程,
其中 表示如下:
备注: 就是大名鼎鼎的欧拉(Euler)公式
则对公式1的变换的公式可以改为:
2.代码实现
比较常见的两种位置编码实现分别是llama和palm对ROPE的实现
llama 中 rope 实现
ref: LLaMA中ROPE位置编码实现源码解析
import torch
def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0):
'''
计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx
:param dim: q,k,v的最后一维,一般为emb_dim/head_num
:param end: 句长length
:param constant: 这里指10000
:return:
复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t))
'''
# freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta
# 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)]
freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2]
# 计算m
t = torch.arange(end, device=freqs.device) # [length]
# 计算m*theta
freqs = torch.outer(t, freqs).float() # [length, d/2]
# freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1
# 计算cos(m*theta)+j*sin(m*theta)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
# freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))]
# 其中j为虚数单位, m=0,1,...,length-1
return freqs_cis # [length, d/2]
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2)
return freqs_cis.view(*shape) # [1, length, 1, d/2]
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor,):
# 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数
# xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)]
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2]
# 同样的,xk_:[k0+j*k1, k2+j*k3, ..., k(d-2)+j*k(d-1)]
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2]
# 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下
# (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0))
# 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0)
# 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部
# 最后经flatten函数将维度拉平,即[bs, length, head, d]
# 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d]
# 即为新生成的q
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
if __name__=='__main__':
# (bs, length, head, d)
q = torch.randn((2, 10, 12, 32)) # q=[q0, q1, .., qd-1]
k = torch.randn((2, 10, 12, 32))
v = torch.randn((2, 10, 12, 32))
freqs_cis= precompute_freqs_cis(dim=32, end=10, constant= 10000.0)
# print(freqs_cis.detach().numpy())
q_new, k_new = apply_rotary_emb(xq=q, xk=k, freqs_cis=freqs_cis)
print()
palm 中 rope 实现
以下代码不是原版palm 的实现方式,来自于modeling_baichuan.py 但是大同小异。
_init_ 中的计算主要是为了缓存,当seq_len<max_position_embeding 的时候,就不需要重新计算了
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
# freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta
# 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)]
# shape:d/2
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
# 计算句子长度
#shape max_position_embeddings
self.max_seq_len_cached = max_position_embeddings
# 计算m
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
# freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_d/2],其中 m=0,1,...,length-1
# 计算m*theta [max_position_embeddings, d/2]
freqs = torch.outer(t, self.inv_freq)
# 结果形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1),m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1
#[max_position_embeddings, d]
emb = torch.cat((freqs, freqs), dim=-1)
# [1,1,max_position_embeddings, d]
self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
# 如果实际长度(seq_len)大于预设的最大位置编码长度(max_position_embeddings),就重新缓存复数实部和虚部的值
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
elif self.cos_cached.device != x.device:
self.cos_cached = self.cos_cached.to(x.device)
self.sin_cached = self.sin_cached.to(x.device)
# 如说seq_len小于预设的最大位置编码长度,直接拿过去计算
return (
self.cos_cached[:, :, :seq_len, ...],
self.sin_cached[:, :, :seq_len, ...],
)
def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
return q_embed.to(q.dtype), k_embed.to(k.dtype)
if __name__=='__main__':
# (bs, head, length, d)
q = torch.randn((2, 12, 10, 32)) # q=[q0, q1, .., qd-1]
k = torch.randn((2, 12, 10, 32))
v = torch.randn((2, 12, 10, 32))
print('q:', q[0][0][0])
print('k:', k[0][0][0])
rotary_emb = RotaryEmbedding(dim=32)
cos, sin = rotary_emb(max_seq_len=10, device=torch.device('cpu')) # [length, d]
q_new, k_new = apply_rotary_pos_emb(q, k., cos, sin,position_ids)
print()
————————————————
版权声明:本文为CSDN博主「Bingoyear」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/angel_hben/article/details/132489588
验证就会发现同样的输入,这两者得到的输出是不一致的。
原因在于两者将输入拆成两个部分的方式不一致,llama将最后一维度dim拆成(dim//2, 2),而palm将最后一维度dim拆成(2, dim//2),熟悉torch的朋友知道这两者差别很大,苏神是将原始向量相邻的两两为一组,而palm则是将前1/2和后1/2分为两组。
那到底谁是对的呢?抱着这个问题我仔细读了一下苏神关于Positional Embeddings的一系列博客,结论是都对,llama的实现和苏神原始的切分方式一致,且实测运算更快(可能是复数运算算子融合的缘故)。