大模型之旋转位置编码RoPE

124 阅读5分钟

大模型之旋转位置编码RoPE

一、产生背景

1.计算attention是本身是和位置无关的,所以需要加上位置的信息

传统attention计算时候。乱序的句子,和不乱序的句子,最后计算出的各个字的attention值是一样的。

2.transformer中绝对位置编码不好用,因为只是在input的时候加入了位置编码,后续层度堆上去后,可能会忽略掉前面的信息。其次这里的位置编码并不能满足具体的语义。

3.加入位置信息的后,需要满足一定条件,才能使用

  • f(m-n) = f(m) * f(n),m和n是句子中不同的位置的词。因为计算相关性的时候要计算内积,但是计算距离的时候是减法。

eg:今天上海天气的温度如何?请告诉我上海具体的温度。其中分词为【今天/上海/天气/的/温度/是多少/?/请/告诉/我/上海/具体/的/温度/。】,此时前半句【天气】和【温度】和后半句里【天气】和【温度】的相关性需要是一样的,但是前半句里的【天气】和后半句里的【天气】求出来的attention值需要是不一样的,【温度】也是同理。

二、解决方案

  • 向量旋转

    A向量旋转a角度后和B向量旋转b角度后,他们进行求相关性后会出现a-b。可以体现a和b之间的距离。

三、公式推导

1)旋转公式

一个二维向量(a, b)逆时针旋转θ度后,具体计算如下:

(ab)(cosθsinθsinθcosθ)=(acosθbsinθasinθbcosθ) \begin{pmatrix} a & b \end{pmatrix} \begin{pmatrix} cosθ & sinθ \\ -sinθ & cosθ \end{pmatrix} = \begin{pmatrix} acosθ - bsinθ & asinθ - bcosθ \end{pmatrix}

2)性质

  • 旋转矩阵转置后,再进行计算相当于顺时针旋转θ

    R(θ)T=R(θ)R(θ)^T = R(-θ)

  • 旋转两次,等于旋转一次

    R(θ1)R(θ2)=R(θ1+θ2)R(θ_1)R(θ_2)=R(θ1+θ2)

3)推导

token的向量为二维的,有两个token,他们在句子中的位置分别是i和j,此时对位置i和j的token对应的向量分别逆时针旋转iθ度和jθ度,然后再进行求相关性,则有如下公式:

QiKjT=XiWQR(iθ)(XjWKR(jθ))T           =XiWQR(iθ)R(jθ)TWKTXjT           =XiWQR(iθ)R(jθ)WKTXjT           =XiWQR((ij)θ)WKTXjT           =g(Xi,Xj,ij)Q_iK_j^T = X_iW_QR(iθ)(X_jW_KR(jθ))^T\\ \ \ \ \ \ \ \ \ \ \ \ =X_iW_QR(iθ)R(jθ)^TW_K^TX_j^T\\ \ \ \ \ \ \ \ \ \ \ \ = X_iW_QR(iθ)R(-jθ)W_K^TX_j^T\\ \ \ \ \ \ \ \ \ \ \ \ = X_iW_QR((i-j)θ)W_K^TX_j^T\\ \ \ \ \ \ \ \ \ \ \ \ = g(X_i,X_j,i-j)

4)拓展到多维度

高维向量的旋转可以分解为多个子二维向量的旋转。

eg:向量(a,b,c,d)在句子中的位置为i,此时对向量进行分解成(a b)和(c d)两个二维向量,此时进行分别逆时针旋转θ1和θ2。

公式为

(abcd)(cosiθ1siniθ100siniθ1cosiθ10000cosiθ2siniθ200siniθ2cosiθ2)=(acosθ1bsinθ1asinθ1bcosθ1ccosθ2dsinθ2csinθ2dcosθ2) \begin{pmatrix} a & b & c & d \end{pmatrix} \begin{pmatrix} cosiθ_1 & siniθ_1 & 0 & 0 \\ -siniθ_1 & cosiθ_1 & 0 & 0 \\ 0 & 0 & cosiθ_2 & siniθ_2\\ 0 & 0 & -siniθ_2 & cosiθ_2 \end{pmatrix} = \begin{pmatrix} acosθ_1 - bsinθ_1 & asinθ_1 - bcosθ_1 & ccosθ_2 - dsinθ_2 & csinθ_2 - dcosθ_2 \end{pmatrix}

其中旋转矩阵可以看成一个更高维度的向量逆时针旋转了iΘ角度。

R(iΘ)=(R(iθ1)00R(iθ2))R(iΘ) = \begin{pmatrix} R(iθ_1) & 0 \\ 0 & R(iθ_2) \end{pmatrix}

此时可以把第3步中的推导公式中的θ替换成Θ,然后我们可以证明R(iΘ)R(jΘ)T=R((ij)Θ)R(iΘ)R(jΘ)^T=R((i-j)Θ),此时高维向量的旋转后,进行求相似度后,也会产生距离之间的关系,即(i-j)。

四、代码实现

import torch.nn as nn
import torch


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=1024):
        super(RotaryEmbedding, self).__init__()
        self.dim = dim  # 向量的维度
        self.max_seq_len = max_seq_len  # 序列最大长度
        inv_freq = 1.0 / (
                    10000 ** (torch.arange(0, dim, 2).float() / dim))  # [dim/2,] # 各个分量的旋转角度(高维度向量的旋转只能依靠不同二维分量的旋转)
                                                                        # 这个值表示了多维向量的旋转角度,不过是用向量表示的。

        t = torch.arange(max_seq_len).float().unsqueeze(1)  # [max_seq_len,1] # 表示序列内token的位置索引,从0到max_seq_len-1
        freqs = t @ inv_freq.unsqueeze(
            0)  # [max_seq_len,1] 和 [1,dim/2] 矩阵相乘 ==> [max_seq_len,dim/2] 表示当前位置上各个二维分量上应该旋转的角度。这里是以二维分量为粒度。
                # 不同位置的token,他们的旋转角度不一样,是乘倍数的概念。
        freqs = torch.cat((freqs, freqs),
                          dim=-1)  # [max_seq_len,dim/2] ---> [max_seq_len,dim]  表示当前位置上各个二维分量上应该旋转的角度,这里是以每个值为粒度。
        # freqs从这里就可以看出来是如何划分各个二维向量的。
        # eg:dim=4时,四维向量为:(a1, a2, b1, b2),(a1,b1)是一对,他们对应一个旋转角度。  (a2,b2)是一对,他们对应一个旋转角度。
        self.register_buffer("cos_cached", freqs.cos())  # 注册cos值。[max_seq_len,dim]
        self.register_buffer("sin_cached", freqs.sin())  # 注册sin值。[max_seq_len,dim]

    def forward(self, q, k):
        cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)  # cos_cached是[max_seq_len,dim] ---->cos为[1,T,E], E=dim  拿到T个位置上各个分量的旋转角度。
        sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)  # [1,T,E]
        return apply_rotate_pos_emb(q, k, cos, sin)


'''
一分为2。eg,二维向量为(a,b), 则此时x1为各个二维向量的前一半,也就是a。 x2为各个向量的后一半,也就是b
这里没有a,b,a,b,a,b.......这样分,而是先a后b的分法
'''


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


'''
旋转矩阵R为
    cosθ    sinθ
    -sinθ   cosθ
这里对每个二维向量进行旋转,也就是对每个(a,b)二维向量 乘以矩阵R,使得旋转对应角度θ
根据矩阵相乘得到旋转后的向量(acosθ-bsinθ,asinθ+bcosθ)为(A,B)

(q * cos) + (rotate_half(q) * sin) 的解释如下:
eg:q的维度为4时,也就是(a,a,b,b) 此时1 3 的a b为一对, 2 4的a b为一对,他们对应的θ一样的。
        (q * cos)为:(acosθ1,acosθ2,bcosθ1,bcosθ2)
        rotate_half(q)为:(-b,-b,a,a)
        (rotate_half(q) * sin) 为:(-bsinθ1,-bsinθ2,asinθ1,asinθ2)
    结果(q * cos) + (rotate_half(q) * sin) 为:
        (acosθ1-bsinθ1, acosθ2-bsinθ2, bcosθ1+asinθ1, bcosθ2+asinθ2)
        即:
        (A1, A2, B1, B2)
      

'''


def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    cos = cos.unsqueeze(unsqueeze_dim)  # [1,T,1,E]
    sin = sin.unsqueeze(unsqueeze_dim)  # [1,T,1,E]

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed


if __name__ == '__main__':
    q = torch.randn(2, 3, 4, 6)  # N T H E
    k = torch.randn(2, 3, 4, 6)  # N T H E
    rope = RotaryEmbedding(dim=6)
    res = rope(q, k)  # q和k经过旋转位置编码后,就带了位置信息了
    print(res[0], res[0].shape)
    print(res[1], res[1].shape)