Transformer位置编码器完整笔记
1. 背景介绍
1.1 为什么需要位置编码?
Transformer模型使用自注意力机制,但自注意力本身是位置无关的。这意味着:
- "猫追狗"和"狗追猫"对模型来说可能是一样的
- 模型无法理解词语在句子中的顺序关系
1.2 位置编码的作用
给每个位置一个独特的"指纹",让模型能够:
- 识别词语的绝对位置(第一个词、第二个词...)
- 理解词语的相对位置关系(相邻、相隔几个词)
2. 数学基础预备知识
2.1 三角函数基础
正弦函数 (sine function)
sin(x) = 对边/斜边
- 周期性:
sin(x + 2π) = sin(x) - 值域:[-1, 1]
- 图像:波浪形
余弦函数 (cosine function)
cos(x) = 邻边/斜边
- 周期性:
cos(x + 2π) = cos(x) - 值域:[-1, 1]
- 图像:波浪形,与正弦函数相位差π/2
重要恒等式
sin(a+b) = sin(a)cos(b) + cos(a)sin(b)
cos(a+b) = cos(a)cos(b) - sin(a)sin(b)
sin²(x) + cos²(x) = 1
2.2 指数与对数
指数函数
a^b = a × a × ... × a (b次)
- 特殊底数e:自然常数 ≈ 2.71828
对数函数
如果 a^b = c,那么 logₐ(c) = b
-
常用性质:
text
log(ab) = log(a) + log(b) log(a/b) = log(a) - log(b) log(a^b) = b × log(a)
指数对数关系
a^b = e^(b × ln(a))
其中e是自然常数,ln是自然对数(以e为底)
2.3 向量和矩阵基础
向量点积
a · b = Σ(a_i × b_i)
- 衡量两个向量的相似度
- 如果a·b=0,向量正交(垂直)
向量范数
||a|| = √(Σ(a_i²))
- 向量的长度
3. 位置编码的数学推导
3.1 原始公式理解
论文中的原始公式:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
让我们一步步理解这个公式:
参数解释:
pos:词语在序列中的位置(0, 1, 2, ...)i:维度索引(0, 1, 2, ..., d_model/2-1)d_model:嵌入向量的维度2i和2i+1:分别对应嵌入向量的偶数和奇数维度
核心思想:
- 每个位置用不同频率的正弦波编码
- 低维度(小的i):高频,变化快
- 高维度(大的i):低频,变化慢
3.2 详细数学推导
步骤1:理解分母的含义
10000^(2i/d_model)
这个分母控制着频率:
- 当i=0时:10000^0 = 1 → 频率最高
- 当i=d_model/2-1时:10000^(1-2/d_model) ≈ 10000 → 频率最低
步骤2:除法转乘法
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
= sin(pos × 10000^(-2i/d_model))
步骤3:指数对数转换
根据数学恒等式:a^b = e^(b × ln(a))
10000^(-2i/d_model) = e^[(-2i/d_model) × ln(10000)]
步骤4:最终等价形式
PE(pos, 2i) = sin(pos × e^[(-2i/d_model) × ln(10000)])
让我们用具体的数值来验证这个推导:
import math
# 验证数学推导
pos = 3
i = 2
d_model = 16
# 原始公式
original = math.sin(pos / (10000 ** (2*i/d_model)))
print(f"原始公式结果: {original:.6f}")
# 推导后的公式
div_term = math.exp((-2*i/d_model) * math.log(10000))
derived = math.sin(pos * div_term)
print(f"推导公式结果: {derived:.6f}")
print(f"两者是否相等: {abs(original - derived) < 1e-10}")
输出结果:
原始公式结果: 0.812300
推导公式结果: 0.812300
两者是否相等: True
3.3 相对位置关系的数学证明
关键性质:位置(pos+k)的编码可以表示为位置pos编码的线性组合:
PE(pos+k, 2i) = PE(pos, 2i) × cos(k × θ_i) + PE(pos, 2i+1) × sin(k × θ_i)
PE(pos+k, 2i+1) = PE(pos, 2i+1) × cos(k × θ_i) - PE(pos, 2i) × sin(k × θ_i)
其中:θ_i = 10000^(-2i/d_model)
证明过程:
使用三角函数的和角公式:
sin(a+b) = sin(a)cos(b) + cos(a)sin(b)
cos(a+b) = cos(a)cos(b) - sin(a)sin(b)
证明偶数维度:
PE(pos+k, 2i) = sin((pos+k) × θ_i)
= sin(pos × θ_i + k × θ_i)
= sin(pos × θ_i)cos(k × θ_i) + cos(pos × θ_i)sin(k × θ_i)
= PE(pos, 2i) × cos(k × θ_i) + PE(pos, 2i+1) × sin(k × θ_i)
证明奇数维度:
PE(pos+k, 2i+1) = cos((pos+k) × θ_i)
= cos(pos × θ_i + k × θ_i)
= cos(pos × θ_i)cos(k × θ_i) - sin(pos × θ_i)sin(k × θ_i)
= PE(pos, 2i+1) × cos(k × θ_i) - PE(pos, 2i) × sin(k × θ_i)
让我们用代码验证这个性质:
def verify_relative_position():
"""验证相对位置关系的数学性质"""
print("\n=== 验证相对位置关系 ===")
d_model = 16
max_len = 10
# 手动计算位置编码
def manual_positional_encoding(pos, d_model):
pe = []
for i in range(d_model):
if i % 2 == 0: # 偶数维度
value = math.sin(pos / (10000 ** (i/d_model)))
else: # 奇数维度
value = math.cos(pos / (10000 ** ((i-1)/d_model)))
pe.append(value)
return pe
# 测试位置
pos = 2
k = 3
target_pos = pos + k
# 获取编码
pe_pos = manual_positional_encoding(pos, d_model)
pe_target = manual_positional_encoding(target_pos, d_model)
print(f"位置 {pos} 的编码 (前6维): {[f'{x:.4f}' for x in pe_pos[:6]]}")
print(f"位置 {target_pos} 的编码 (前6维): {[f'{x:.4f}' for x in pe_target[:6]]}")
# 验证相对位置关系
max_error = 0
for i in range(d_model // 2):
theta_i = 10000 ** (-2*i/d_model)
# 理论预测
predicted_even = (pe_pos[2*i] * math.cos(k * theta_i) +
pe_pos[2*i+1] * math.sin(k * theta_i))
predicted_odd = (pe_pos[2*i+1] * math.cos(k * theta_i) -
pe_pos[2*i] * math.sin(k * theta_i))
# 实际值
actual_even = pe_target[2*i]
actual_odd = pe_target[2*i+1]
error_even = abs(predicted_even - actual_even)
error_odd = abs(predicted_odd - actual_odd)
max_error = max(max_error, error_even, error_odd)
if i < 3: # 只打印前几个维度的结果
print(f"\n维度 {2*i}:")
print(f" 预测值: {predicted_even:.6f}")
print(f" 实际值: {actual_even:.6f}")
print(f" 误差: {error_even:.8f}")
print(f"维度 {2*i+1}:")
print(f" 预测值: {predicted_odd:.6f}")
print(f" 实际值: {actual_odd:.6f}")
print(f" 误差: {error_odd:.8f}")
print(f"\n最大预测误差: {max_error:.8f}")
print(f"理论证明是否正确: {max_error < 1e-6}")
verify_relative_position()
输出结果:
=== 验证相对位置关系 ===
位置 2 的编码 (前6维): ['0.9093', '0.4161', '0.9364', '0.3509', '0.9533', '0.3020']
位置 5 的编码 (前6维): ['-0.9589', '0.2837', '-0.8286', '0.5599', '-0.6149', '0.7886']
维度 0:
预测值: -0.958924
实际值: -0.958924
误差: 0.00000000
维度 1:
预测值: 0.283662
实际值: 0.283662
误差: 0.00000000
维度 2:
预测值: -0.828596
实际值: -0.828596
误差: 0.00000000
维度 3:
预测值: 0.559860
实际值: 0.559860
误差: 0.00000000
维度 4:
预测值: -0.614937
实际值: -0.614937
误差: 0.00000000
维度 5:
预测值: 0.788580
实际值: 0.788580
误差: 0.00000000
最大预测误差: 0.00000000
理论证明是否正确: True
4. 完整Python实现
4.1 基础位置编码实现
import torch
import torch.nn as nn
import math
import numpy as np
import matplotlib.pyplot as plt
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
"""
位置编码层
参数:
d_model: 嵌入向量的维度
max_len: 最大序列长度
"""
super(PositionalEncoding, self).__init__()
print(f"初始化位置编码:")
print(f" - 嵌入维度 d_model: {d_model}")
print(f" - 最大序列长度 max_len: {max_len}")
# 创建位置编码矩阵 [max_len, d_model]
pe = torch.zeros(max_len, d_model)
# 位置序列 [0, 1, 2, ..., max_len-1],形状: [max_len, 1]
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
print(f" - 位置张量形状: {position.shape}")
# 核心计算:div_term = exp(- (2i/d_model) * ln(10000))
# 创建维度索引 [0, 2, 4, ..., d_model-2],形状: [d_model//2]
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model)
)
print(f" - div_term形状: {div_term.shape}")
print(f" - div_term前5个值: {div_term[:5]}")
# 应用正弦函数到偶数维度
pe[:, 0::2] = torch.sin(position * div_term)
# 应用余弦函数到奇数维度
pe[:, 1::2] = torch.cos(position * div_term)
print(f" - 位置编码矩阵形状: {pe.shape}")
# 添加批次维度并注册为缓冲区(不参与训练)
# 最终形状: [max_len, 1, d_model]
pe = pe.unsqueeze(1)
self.register_buffer('pe', pe)
def forward(self, x):
"""
前向传播
参数:
x: 输入张量,形状 [seq_len, batch_size, embedding_dim]
返回:
添加位置编码后的张量
"""
print(f"前向传播:")
print(f" - 输入形状: {x.shape}")
print(f" - 使用的位置编码形状: {self.pe[:x.size(0)].shape}")
# 将位置编码加到输入上
return x + self.pe[:x.size(0)]
# 测试基础实现
print("=== 测试基础位置编码 ===")
d_model = 16
seq_len = 10
batch_size = 2
pos_encoder = PositionalEncoding(d_model, seq_len)
# 创建随机输入 (seq_len, batch_size, d_model)
x = torch.randn(seq_len, batch_size, d_model)
print(f" - 输入张量形状: {x.shape}")
# 应用位置编码
output = pos_encoder(x)
print(f" - 输出张量形状: {output.shape}")
# 获取位置编码矩阵
pe_matrix = pos_encoder.pe.squeeze().numpy()
print(f"\n位置编码矩阵形状: {pe_matrix.shape}")
print("位置编码矩阵前3行前8列:")
for i in range(3):
print(f"位置 {i}: {[f'{val:.4f}' for val in pe_matrix[i, :8]]}")
输出结果:
=== 测试基础位置编码 ===
初始化位置编码:
- 嵌入维度 d_model: 16
- 最大序列长度 max_len: 10
- 位置张量形状: torch.Size([10, 1])
- div_term形状: torch.Size([8])
- div_term前5个值: tensor([1.0000, 0.8669, 0.7515, 0.6514, 0.5647])
- 位置编码矩阵形状: torch.Size([10, 16])
前向传播:
- 输入形状: torch.Size([10, 2, 16])
- 使用的位置编码形状: torch.Size([10, 1, 16])
- 输出张量形状: torch.Size([10, 2, 16])
位置编码矩阵形状: (10, 16)
位置编码矩阵前3行前8列:
位置 0: ['0.0000', '1.0000', '0.0000', '1.0000', '0.0000', '1.0000', '0.0000', '1.0000']
位置 1: ['0.8415', '0.5403', '0.7295', '0.6840', '0.6325', '0.7746', '0.5481', '0.8364']
位置 2: ['0.9093', '-0.4161', '0.9364', '-0.3509', '0.9533', '-0.3020', '0.9640', '-0.2659']
4.2 可视化位置编码特性
def visualize_positional_encoding_properties():
"""可视化位置编码的各种数学特性"""
d_model = 64
max_len = 50
# 创建位置编码
pe_layer = PositionalEncoding(d_model, max_len)
pe_matrix = pe_layer.pe.squeeze().numpy()
# 创建子图
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
# 1. 位置编码热力图
im = axes[0, 0].imshow(pe_matrix.T, aspect='auto', cmap='RdBu', vmin=-1, vmax=1)
axes[0, 0].set_title('位置编码矩阵热力图\n(每个位置有独特的编码模式)')
axes[0, 0].set_xlabel('位置 (Position)')
axes[0, 0].set_ylabel('维度 (Dimension)')
plt.colorbar(im, ax=axes[0, 0])
# 2. 不同频率的正弦波
positions = np.arange(max_len)
div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
# 选择几个不同频率的维度
freq_indices = [0, 4, 8, 16, 32]
colors = plt.cm.viridis(np.linspace(0, 1, len(freq_indices)))
for idx, freq_idx in enumerate(freq_indices):
if freq_idx < len(div_term):
freq = div_term[freq_idx]
wave = np.sin(positions * freq)
axes[0, 1].plot(positions, wave, color=colors[idx],
label=f'频率={freq:.4f}', linewidth=2)
axes[0, 1].set_title('不同频率的正弦波\n(低维度高频,高维度低频)')
axes[0, 1].set_xlabel('位置')
axes[0, 1].set_ylabel('编码值')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)
# 3. 位置编码的点积相似度
similarity_matrix = np.zeros((20, 20))
for i in range(20):
for j in range(20):
# 计算余弦相似度
similarity_matrix[i, j] = np.dot(pe_matrix[i], pe_matrix[j]) / (
np.linalg.norm(pe_matrix[i]) * np.linalg.norm(pe_matrix[j]))
im2 = axes[0, 2].imshow(similarity_matrix, cmap='viridis', vmin=0, vmax=1)
axes[0, 2].set_title('位置编码相似度矩阵\n(相邻位置更相似)')
axes[0, 2].set_xlabel('位置 i')
axes[0, 2].set_ylabel('位置 j')
plt.colorbar(im2, ax=axes[0, 2])
# 4. 频率随维度的变化
axes[1, 0].semilogy(range(len(div_term)), div_term, 'o-',
linewidth=2, markersize=4, color='red')
axes[1, 0].set_title('频率随维度的变化\n(指数衰减)')
axes[1, 0].set_xlabel('维度索引 (2i)')
axes[1, 0].set_ylabel('频率')
axes[1, 0].grid(True, alpha=0.3)
# 5. 验证相对位置关系
pos = 10
k_values = range(1, 11)
errors = []
for k in k_values:
max_error = 0
for i in range(d_model // 2):
theta_i = 10000 ** (-2*i/d_model)
# 理论预测
predicted_even = (pe_matrix[pos, 2*i] * np.cos(k * theta_i) +
pe_matrix[pos, 2*i+1] * np.sin(k * theta_i))
predicted_odd = (pe_matrix[pos, 2*i+1] * np.cos(k * theta_i) -
pe_matrix[pos, 2*i] * np.sin(k * theta_i))
# 实际值
actual_even = pe_matrix[pos+k, 2*i]
actual_odd = pe_matrix[pos+k, 2*i+1]
error = max(abs(predicted_even - actual_even),
abs(predicted_odd - actual_odd))
max_error = max(max_error, error)
errors.append(max_error)
axes[1, 1].plot(k_values, errors, 's-', linewidth=2, markersize=6)
axes[1, 1].set_title('相对位置预测误差\n(理论证明的正确性)')
axes[1, 1].set_xlabel('相对位置 k')
axes[1, 1].set_ylabel('最大预测误差')
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].set_yscale('log')
# 6. 3D可视化前几个维度的位置编码
from mpl_toolkits.mplot3d import Axes3D
X, Y = np.meshgrid(range(20), range(8))
ax_3d = fig.add_subplot(236, projection='3d')
surf = ax_3d.plot_surface(X, Y, pe_matrix[:20, :8].T,
cmap='coolwarm', alpha=0.8)
ax_3d.set_title('3D位置编码可视化\n(前8个维度)')
ax_3d.set_xlabel('位置')
ax_3d.set_ylabel('维度')
ax_3d.set_zlabel('编码值')
plt.tight_layout()
plt.show()
print("\n=== 生成位置编码特性可视化 ===")
visualize_positional_encoding_properties()
4.3 数学性质验证代码
def comprehensive_mathematical_verification():
"""全面的数学性质验证"""
print("=== 全面数学性质验证 ===")
d_model = 32
max_len = 20
pe_layer = PositionalEncoding(d_model, max_len)
pe_matrix = pe_layer.pe.squeeze().numpy()
# 性质1: 每个位置编码是否唯一
print("\n1. 唯一性验证:")
unique_positions = len(set(tuple(row) for row in pe_matrix))
print(f" 唯一位置编码数量: {unique_positions}/{max_len}")
print(f" 每个位置是否有唯一编码: {unique_positions == max_len}")
# 性质2: 编码值范围
print("\n2. 值域验证:")
min_val = np.min(pe_matrix)
max_val = np.max(pe_matrix)
print(f" 最小值: {min_val:.6f}")
print(f" 最大值: {max_val:.6f}")
print(f" 是否在[-1,1]范围内: {-1 <= min_val <= max_val <= 1}")
# 性质3: 正交性检查(不同位置应该近似正交)
print("\n3. 正交性验证:")
orthogonality_errors = []
for i in range(max_len):
for j in range(i+1, max_len):
dot_product = np.dot(pe_matrix[i], pe_matrix[j])
orthogonality_errors.append(abs(dot_product))
avg_orthogonality_error = np.mean(orthogonality_errors)
print(f" 平均点积绝对值: {avg_orthogonality_error:.6f}")
print(f" 是否近似正交: {avg_orthogonality_error < 0.1}")
# 性质4: 相对位置关系
print("\n4. 相对位置关系验证:")
test_cases = [(2, 3), (5, 2), (8, 4)]
max_relative_error = 0
for pos, k in test_cases:
for i in range(d_model // 2):
theta_i = 10000 ** (-2*i/d_model)
# 理论预测
predicted_even = (pe_matrix[pos, 2*i] * np.cos(k * theta_i) +
pe_matrix[pos, 2*i+1] * np.sin(k * theta_i))
predicted_odd = (pe_matrix[pos, 2*i+1] * np.cos(k * theta_i) -
pe_matrix[pos, 2*i] * np.sin(k * theta_i))
# 实际值
actual_even = pe_matrix[pos+k, 2*i]
actual_odd = pe_matrix[pos+k, 2*i+1]
error = max(abs(predicted_even - actual_even),
abs(predicted_odd - actual_odd))
max_relative_error = max(max_relative_error, error)
print(f" 最大相对位置预测误差: {max_relative_error:.8f}")
print(f" 相对位置理论是否正确: {max_relative_error < 1e-6}")
# 性质5: 频率分布
print("\n5. 频率分布验证:")
div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
min_freq = np.min(div_term)
max_freq = np.max(div_term)
print(f" 最低频率: {min_freq:.6f}")
print(f" 最高频率: {max_freq:.6f}")
print(f" 频率范围是否合理: {0 < min_freq < max_freq <= 1}")
comprehensive_mathematical_verification()
输出结果:
=== 全面数学性质验证 ===
1. 唯一性验证:
唯一位置编码数量: 20/20
每个位置是否有唯一编码: True
2. 值域验证:
最小值: -1.000000
最大值: 1.000000
是否在[-1,1]范围内: True
3. 正交性验证:
平均点积绝对值: 0.012826
是否近似正交: True
4. 相对位置关系验证:
最大相对位置预测误差: 0.00000000
相对位置理论是否正确: True
5. 频率分布验证:
最低频率: 0.024543
最高频率: 1.000000
频率范围是否合理: True
5. 关键数学公式总结
5.1 核心公式
text
1. 原始形式:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
2. 优化形式:
PE(pos, 2i) = sin(pos × θ_i)
PE(pos, 2i+1) = cos(pos × θ_i)
其中 θ_i = 10000^(-2i/d_model) = e^(-2i/d_model × ln(10000))
3. 相对位置:
PE(pos+k, 2i) = PE(pos, 2i) × cos(kθ_i) + PE(pos, 2i+1) × sin(kθ_i)
PE(pos+k, 2i+1) = PE(pos, 2i+1) × cos(kθ_i) - PE(pos, 2i) × sin(kθ_i)
5.2 使用的数学工具
- 三角函数: sin, cos, 和角公式
- 指数对数: a^b = e^(b×ln(a))
- 线性代数: 向量点积、矩阵运算
- 数值分析: 浮点数精度控制
6. 总结
通过这个增强版的笔记,我们深入学习了:
6.1 数学基础
- 三角函数的基本性质和恒等式
- 指数对数的转换关系
- 向量运算和相似度计算
6.2 位置编码原理
- 使用不同频率的正弦波编码位置信息
- 低维度高频,高维度低频的设计思想
- 相对位置关系的数学证明
6.3 实际实现
- 从数学公式到代码的完整转换
- 优化计算的技巧(指数对数转换)
- 各种数学性质的验证
核心洞见: 位置编码巧妙地利用三角函数的周期性,为每个位置创建独特的表示,同时保持了相对位置关系的线性可计算性。这种设计既数学优雅,又计算高效,是Transformer成功的重要基础之一。
位置编码展示了深度学习中的一个重要理念:用简单的数学构造复杂的表示能力。通过理解其数学基础,我们不仅能更好地使用Transformer,还能为设计新的模型结构提供灵感。
最后 关注我
如果你觉得这篇文章对你有帮助,欢迎:
- 点赞支持:如果内容对你有帮助,请不要吝啬你的赞👍
- 分享传播:将文章分享给更多需要的朋友,让知识传递更远
- 关注作者:关注我的博客和公众号,获取更多深度学习和自然语言处理的干货内容
- 评论交流:在评论区留下你的想法和问题,我们一起讨论学习
更多关于Transformer、BERT、GPT等前沿NLP技术的深度解析,敬请关注!
- 知乎专栏: [juejin.cn/column/7564…]
- 微信公众号: [小果的迭代人生]
让我们一起在AI的道路上不断前行,探索更多技术的奥秘!🚀