大模型原理理解-位置编码

870 阅读7分钟

旋转式位置编码(RoPE)最早是论文[1]提出的一种能够将相对位置信息依赖集成到 self-attention 中并提升 transformer 架构性能的位置编码方式。而目前很火的 LLaMA 模型也是采用该位置编码方式。

基本概念

首先论文中定义一个长度为 N 的输入序列为

image.png

其中 wi 表示输入序列中第 i 个 token,而输入序列 SN 对应的 embedding 表示为:

image.png

其中 xi 表示第 i 个 token wi 对应的 d 维词嵌入向量。

接着在做 self-attention 之前,会用词嵌入向量计算 q, k, v 向量同时加入位置信息,函数公式表达如下: image.png

旋转位置编码

关于旋转位置编码需要理解以下几个点:

1.用绝对位置编码表示相对位置

比如要在计算机里表达我喜欢你这四个字是可以用一个矩阵表示,我喜欢你这四个字,可以在计算机里用4x2的矩阵表示,但实际使用中必定会使用更高维度的向量,4x128,4x256.

image.png

RoPE的假设里,先将qm,kn看作是复数,比如字的复数的表征为0.45 + i0.78,Rope的步骤就是在第一步里进行,即qk进行内积的时候,先分别对做一次变换:

image.png 变换之后,这里就分别有了位置m,n的绝对位置编码信息。将两个变换之后的数据做内积(<>表示内积计算):

image.png

备注累积 两个复数的内积等于一个复数乘以另外一个复数的共轭 复数的共轭简单点说: 实部相同,虚部相反

image.png 由此在attention的计算当中,通过 j(mn)θj(m-n)\theta 我们获得了q中第m个 token和k中第n个token的相对位置信息。

理论上推导到此结束,剩下就是实际编码的过程,

其中ejmθe^{jm\theta} 表示如下:

image.png

备注: eiθ=cosθ+isinθe^{iθ}=cosθ+isinθ就是大名鼎鼎的欧拉(Euler)公式

则对公式1的变换的公式可以改为:

image.png

ref: NLP升级打怪之路:Rope 旋转位置编码(1)

2.代码实现

image.png

image.png 比较常见的两种位置编码实现分别是llama和palm对ROPE的实现

llama 中 rope 实现

ref: LLaMA中ROPE位置编码实现源码解析

import torch

def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0):
    '''
    计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx
    :param dim: q,k,v的最后一维,一般为emb_dim/head_num
    :param end: 句长length
    :param constant: 这里指10000
    :return:
    复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t))
    '''
    # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta
    # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)]
    freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2]

    # 计算m
    t = torch.arange(end, device=freqs.device)  # [length]
    # 计算m*theta
    freqs = torch.outer(t, freqs).float()  # [length, d/2]
    # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1

    # 计算cos(m*theta)+j*sin(m*theta)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0),  cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))]
    # 其中j为虚数单位, m=0,1,...,length-1
    return freqs_cis # [length, d/2]

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2)
    return freqs_cis.view(*shape) # [1, length, 1, d/2]

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor,):
    # 先将xq维度变为[bs, length, head,  d/2, 2], 利用torch.view_as_complex转变为复数
    # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2]
    # 同样的,xk_:[k0+j*k1, k2+j*k3, ..., k(d-2)+j*k(d-1)]
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2]
    # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下
    # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0))
    # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0)
    # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部
    # 最后经flatten函数将维度拉平,即[bs, length, head, d]
    # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d]
    # 即为新生成的q

    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

if __name__=='__main__':
    # (bs, length, head, d)
    q = torch.randn((2, 10, 12, 32))  # q=[q0, q1, .., qd-1]
    k = torch.randn((2, 10, 12, 32))
    v = torch.randn((2, 10, 12, 32))
    freqs_cis= precompute_freqs_cis(dim=32, end=10, constant= 10000.0)
    # print(freqs_cis.detach().numpy())

    q_new, k_new = apply_rotary_emb(xq=q, xk=k, freqs_cis=freqs_cis)
    print()

image.png

palm 中 rope 实现

以下代码不是原版palm 的实现方式,来自于modeling_baichuan.py 但是大同小异。

_init_ 中的计算主要是为了缓存,当seq_len<max_position_embeding 的时候,就不需要重新计算了

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta 
        # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)]
        # shape:d/2
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        
        # 计算句子长度
        #shape max_position_embeddings
        self.max_seq_len_cached = max_position_embeddings
        # 计算m
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
        
        # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_d/2],其中 m=0,1,...,length-1
        # 计算m*theta [max_position_embeddings, d/2]
        freqs = torch.outer(t, self.inv_freq)
        
        # 结果形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1),m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1
         #[max_position_embeddings, d]
        emb = torch.cat((freqs, freqs), dim=-1)
        # [1,1,max_position_embeddings, d]
        self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
        self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        # 如果实际长度(seq_len)大于预设的最大位置编码长度(max_position_embeddings),就重新缓存复数实部和虚部的值
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
            self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
        elif self.cos_cached.device != x.device:
            self.cos_cached = self.cos_cached.to(x.device)
            self.sin_cached = self.sin_cached.to(x.device)  
        # 如说seq_len小于预设的最大位置编码长度,直接拿过去计算    
        return (
            self.cos_cached[:, :, :seq_len, ...],
            self.sin_cached[:, :, :seq_len, ...],
        )
def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
    cos = cos_.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin_.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
    k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
    return q_embed.to(q.dtype), k_embed.to(k.dtype)
        
if __name__=='__main__':
    # (bs, head, length, d)
    q = torch.randn((2, 12, 10, 32))  # q=[q0, q1, .., qd-1]
    k = torch.randn((2, 12, 10, 32))
    v = torch.randn((2, 12, 10, 32))
    print('q:', q[0][0][0])
    print('k:', k[0][0][0])
    rotary_emb = RotaryEmbedding(dim=32)
    cos, sin  = rotary_emb(max_seq_len=10, device=torch.device('cpu'))  # [length, d]
    q_new, k_new = apply_rotary_pos_emb(q, k., cos, sin,position_ids)
    print()
————————————————
版权声明:本文为CSDN博主「Bingoyear」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/angel_hben/article/details/132489588        

image.png

验证就会发现同样的输入,这两者得到的输出是不一致的。

原因在于两者将输入拆成两个部分的方式不一致,llama将最后一维度dim拆成(dim//2, 2),而palm将最后一维度dim拆成(2, dim//2),熟悉torch的朋友知道这两者差别很大,苏神是将原始向量相邻的两两为一组,而palm则是将前1/2和后1/2分为两组。

那到底谁是对的呢?抱着这个问题我仔细读了一下苏神关于Positional Embeddings的一系列博客,结论是都对,llama的实现和苏神原始的切分方式一致,且实测运算更快(可能是复数运算算子融合的缘故)。

备注:[分析] ROPE的不同实现:llama&palm