RoPE 位置编码选择合适的 theta

173 阅读1分钟

习惯上位置编码的 theta 都会设置为 10000。但针对不同的问题或是极端情况,例如 ViT 序列相对较短,若仍然保持 10000 会极大浪费通道数量:序列不论在哪个位置上,得到的位置编码有很多通道值恒为 1。

用以下代码绘图直观感受。

import numpy as np
import matplotlib.pyplot as plt


def generate_rope_matrix(seq_len, dim, theta=10000):
    positions = np.arange(seq_len)[:, None]  # 序列位置 shape: [seq_len, 1]
    frequencies = np.power(theta, -2 * (np.arange(dim) // 2) / dim)  # shape: [dim]
    phase = positions * frequencies  # shape: [seq_len, dim]

    # RoPE 的 sin 和 cos 编码
    rope_matrix = np.zeros((seq_len, dim))
    rope_matrix[:, 0::2] = np.sin(phase[:, 0::2])  # 偶数列使用 sin
    rope_matrix[:, 1::2] = np.cos(phase[:, 1::2])  # 奇数列使用 cos

    return rope_matrix


def visualize_rope(seq_len, dim, theta=10000):
    rope_matrix = generate_rope_matrix(seq_len, dim, theta)

    # 只展示一个维度的信息(例如第 0 个维度)
    plt.figure(figsize=(10, 6))
    plt.imshow(rope_matrix.T, aspect='auto', cmap='coolwarm', extent=[0, seq_len, 0, dim])
    plt.colorbar(label='Value of Position Encoding')
    plt.xlabel('Along Length of Token')
    plt.ylabel('Individual Tokens')
    plt.title(f'RoPE Position Encoding Visualization (θ={theta})')
    plt.show()


seq_len = 100  # 序列长度
dim = 160  # 编码维度
theta_values = [10000, 100]  # 不同的 theta 值

for theta in theta_values:
    visualize_rope(seq_len, dim, theta)

theta 设为 10000 的情况如下,横轴代表从 0 到 100 的位置,纵轴表示不同通道。可见相当多的通道不包含有意义的信息。

image.png

而 theta 设为 100,在序列长度不超过 100 的情况下就显得合理多了。

image.png