transformer 中的 RoPE 位置编码

404 阅读3分钟

总览

为了让 Transformer 模型获知输入序列的位置关系,需要想办法把位置信息嵌入到序列中。比较经典的方法有正余弦位置编码和可学习位置编码。

旋转位置编码 RoPE(Rotary Position Embedding)是目前最为流行的位置编码方案,广泛应用于各种 transformer 模型,例如 Llama 和 Qwen。效果好,有较强外推能力,是位置编码中的豪杰。

绝对位置编码与相对位置编码

Transformer 不能直接感知到字符所在位置,也不能感受到字符之间的位置关系。

绝对位置编码,可以辅助 Transformer 感知字符所在的绝对位置。每个位置分配一个唯一的编码向量,模型可以学习到与绝对位置相关的知识。主要包括训练式位置编码与 Sinusoidal 位置编码。

相对位置编码,可以辅助 Transformer 感知各个字符间的相对距离。感知相对位置关系对模型的外推和长文理解能力尤为重要。

Sinusoidal 位置编码其实理论上有一定的相对位置编码作用:两个相同的词向量,其内积结果大致满足距离越近则越高,距离越远则越低(远程衰减性质)。

但通过数学推导可获知,Sinusoidal 位置编码无法区分相对的前后位置关系。这一点太致命了。

RoPE 从设计之初就同时考虑到绝对位置编码与相对位置编码。利用巧妙的思路,通过绝对位置编码的方式实现了相对位置编码。

利用旋转矩阵

假设词向量 embedding 维度为 2,则一个词向量可表示为 q=[q0 q1]Tq=[q_0\ q_1]^T

以下这个 M(θ)M(\theta) 是旋转矩阵,左乘到 qq 可以让 qq 在二维空间逆时针旋转 mθm\theta 角度,且模长不变。

M(θ)=[cosmθsinmθsinmθcosmθ]M(\theta)= \left [ \begin{matrix} \cos m\theta& -\sin m\theta \\ \sin m\theta& \cos m\theta \\ \end{matrix} \right ]

其中 mm 代表词向量 qq 的所在位置。MqM·q 就是对 qq 逆时针旋转 mθm\theta 角度。

Mq=[cosmθsinmθsinmθcosmθ][q0q1]M·q= \left [ \begin{matrix} \cos m\theta& -\sin m\theta \\ \sin m\theta& \cos m\theta \\ \end{matrix} \right ] \left [ \begin{matrix} q_0\\ q_1 \end{matrix} \right]

MqM·q 通过对 qq 进行 “旋转” 添加了绝对位置编码。而不同位置的两个 MqM·q 的内积结果显然与 “旋转角度” 相关,角度之差与相对位置有关系,这就有了相对位置编码的作用。

接下会提到如何扩展到 n 维,以及 θ\theta 的选取。

多维推广

词向量的维度可比 2 大多了。要将上一节的旋转矩阵进行推广,最直接的思路是,把维度两两分组。假设词向量维度为 dddd 为 2 的倍数):

Mq=[cosmθ0sinmθ000sinmθ0cosmθ00000cosmθd/21sinmθd/2100sinmθd/21cosmθd/21][q0q1qd2qd1]M·q= \left [ \begin{matrix} \cos m\theta_0 & -\sin m\theta_0 & \cdots & 0 & 0 \\ \sin m\theta_0 & \cos m\theta_0 & \cdots & 0 & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & \cdots & \cos m\theta_{d/2-1} & -\sin m\theta_{d/2-1} \\ 0 & 0 & \cdots & \sin m\theta_{d/2-1} & \cos m\theta_{d/2-1} \\ \end{matrix} \right ] \left [ \begin{matrix} q_0\\ q_1\\\vdots\\q_{d-2}\\q_{d-1} \end{matrix} \right]

如此,dd 个维度被分为了 d/2d/2 组,θ\theta 也有了分别 d/2d/2 个取值。

在代码实现上可以用以下思路来等效计算 MqM·q,节省资源:

Mq=[q0q1q2q3qd2qd1][cosmθ0cosmθ0cosmθ1cosmθ1cosmθd/21cosmθd/21]+[q1q0q3q2qd1qd2][sinmθ0sinmθ0sinmθ1sinmθ1sinmθd/21sinmθd/21]M·q= \left [ \begin{matrix} q_0\\ q_1\\q_2\\q_3\\\vdots\\q_{d-2}\\q_{d-1} \end{matrix} \right] \odot \left [ \begin{matrix} \cos m\theta_0\\ \cos m\theta_0\\ \cos m\theta_1\\ \cos m\theta_1\\ \vdots\\ \cos m\theta_{d/2 - 1}\\ \cos m\theta_{d/2 - 1}\\ \end{matrix} \right] + \left [ \begin{matrix} -q_1\\ q_0\\-q_3\\q_2\\\vdots\\-q_{d-1}\\q_{d-2} \end{matrix} \right] \odot \left [ \begin{matrix} \sin m\theta_0\\ \sin m\theta_0\\ \sin m\theta_1\\ \sin m\theta_1\\ \vdots\\ \sin m\theta_{d/2 - 1}\\ \sin m\theta_{d/2 - 1}\\ \end{matrix} \right]

那么,这 d/2d/2θi\theta_i 的值该如何选取?

θ\theta 的选取

RoPE 使用了类似 Sinusoidal 位置编码的思路:

θi=1100002i/d\theta_i=\frac{1}{10000^{2i/d}}

这样会带来一个效果:词向量的低维度被施加了短周期的编码,更擅长捕捉较近相对位置的词的关系信息;而高维度被施加了长周期的编码,更适合发现全局或长距离的依赖关系。

公式中的 1000010000 被称为 base。可以通过增大 base 来实现更长的外推能力,例如 LLaMa 2 Long 的 base 取在了 500000500000,支持最大上下文长度为 32k。但更大的 base 会降低注意力的远程衰减,让模型学习难度提升,需要注意。

总结

通过一段时间学习后对 RoPE 有了比较全面的认识。之前调试 Gemma 看到的奇怪的位置编码计算方式总算有了解答,也对 base 参数有了更好的理解。

RoPE 真妙啊。刚开始学习的时候还在想为什么 2022 年就提出来的 Alibi 鲜有人用,各个开源大模型不约而同都是 RoPE。现在看起来是 RoPE 足够优秀,也能在长上下文方面卷出新高度吧。

参考来源