习惯上位置编码的 theta 都会设置为 10000。但针对不同的问题或是极端情况,例如 ViT 序列相对较短,若仍然保持 10000 会极大浪费通道数量:序列不论在哪个位置上,得到的位置编码有很多通道值恒为 1。
用以下代码绘图直观感受。
import numpy as np
import matplotlib.pyplot as plt
def generate_rope_matrix(seq_len, dim, theta=10000):
positions = np.arange(seq_len)[:, None] # 序列位置 shape: [seq_len, 1]
frequencies = np.power(theta, -2 * (np.arange(dim) // 2) / dim) # shape: [dim]
phase = positions * frequencies # shape: [seq_len, dim]
# RoPE 的 sin 和 cos 编码
rope_matrix = np.zeros((seq_len, dim))
rope_matrix[:, 0::2] = np.sin(phase[:, 0::2]) # 偶数列使用 sin
rope_matrix[:, 1::2] = np.cos(phase[:, 1::2]) # 奇数列使用 cos
return rope_matrix
def visualize_rope(seq_len, dim, theta=10000):
rope_matrix = generate_rope_matrix(seq_len, dim, theta)
# 只展示一个维度的信息(例如第 0 个维度)
plt.figure(figsize=(10, 6))
plt.imshow(rope_matrix.T, aspect='auto', cmap='coolwarm', extent=[0, seq_len, 0, dim])
plt.colorbar(label='Value of Position Encoding')
plt.xlabel('Along Length of Token')
plt.ylabel('Individual Tokens')
plt.title(f'RoPE Position Encoding Visualization (θ={theta})')
plt.show()
seq_len = 100 # 序列长度
dim = 160 # 编码维度
theta_values = [10000, 100] # 不同的 theta 值
for theta in theta_values:
visualize_rope(seq_len, dim, theta)
theta 设为 10000 的情况如下,横轴代表从 0 到 100 的位置,纵轴表示不同通道。可见相当多的通道不包含有意义的信息。
而 theta 设为 100,在序列长度不超过 100 的情况下就显得合理多了。