Transformer位置编码器完整笔记

46 阅读14分钟

Transformer位置编码器完整笔记

1. 背景介绍

1.1 为什么需要位置编码?

Transformer模型使用自注意力机制,但自注意力本身是位置无关的。这意味着:

  • "猫追狗"和"狗追猫"对模型来说可能是一样的
  • 模型无法理解词语在句子中的顺序关系

1.2 位置编码的作用

给每个位置一个独特的"指纹",让模型能够:

  • 识别词语的绝对位置(第一个词、第二个词...)
  • 理解词语的相对位置关系(相邻、相隔几个词)

2. 数学基础预备知识

2.1 三角函数基础

正弦函数 (sine function)

sin(x) = 对边/斜边
  • 周期性:sin(x + 2π) = sin(x)
  • 值域:[-1, 1]
  • 图像:波浪形

余弦函数 (cosine function)

cos(x) = 邻边/斜边
  • 周期性:cos(x + 2π) = cos(x)
  • 值域:[-1, 1]
  • 图像:波浪形,与正弦函数相位差π/2

重要恒等式

sin(a+b) = sin(a)cos(b) + cos(a)sin(b)
cos(a+b) = cos(a)cos(b) - sin(a)sin(b)
sin²(x) + cos²(x) = 1

2.2 指数与对数

指数函数

a^b = a × a × ... × a (b次)
  • 特殊底数e:自然常数 ≈ 2.71828

对数函数

如果 a^b = c,那么 logₐ(c) = b
  • 常用性质:

    text

    log(ab) = log(a) + log(b)
    log(a/b) = log(a) - log(b)
    log(a^b) = b × log(a)
    

指数对数关系

a^b = e^(b × ln(a))

其中e是自然常数,ln是自然对数(以e为底)

2.3 向量和矩阵基础

向量点积

a · b = Σ(a_i × b_i)
  • 衡量两个向量的相似度
  • 如果a·b=0,向量正交(垂直)

向量范数

||a|| = √(Σ(a_i²))
  • 向量的长度

3. 位置编码的数学推导

3.1 原始公式理解

论文中的原始公式:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

让我们一步步理解这个公式:

参数解释:

  • pos:词语在序列中的位置(0, 1, 2, ...)
  • i:维度索引(0, 1, 2, ..., d_model/2-1)
  • d_model:嵌入向量的维度
  • 2i2i+1:分别对应嵌入向量的偶数和奇数维度

核心思想:

  • 每个位置用不同频率的正弦波编码
  • 低维度(小的i):高频,变化快
  • 高维度(大的i):低频,变化慢

3.2 详细数学推导

步骤1:理解分母的含义

10000^(2i/d_model)

这个分母控制着频率:

  • 当i=0时:10000^0 = 1 → 频率最高
  • 当i=d_model/2-1时:10000^(1-2/d_model) ≈ 10000 → 频率最低

步骤2:除法转乘法

PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
            = sin(pos × 10000^(-2i/d_model))

步骤3:指数对数转换
根据数学恒等式:a^b = e^(b × ln(a))

10000^(-2i/d_model) = e^[(-2i/d_model) × ln(10000)]

步骤4:最终等价形式

PE(pos, 2i) = sin(pos × e^[(-2i/d_model) × ln(10000)])

让我们用具体的数值来验证这个推导:

import math

# 验证数学推导
pos = 3
i = 2
d_model = 16

# 原始公式
original = math.sin(pos / (10000 ** (2*i/d_model)))
print(f"原始公式结果: {original:.6f}")

# 推导后的公式
div_term = math.exp((-2*i/d_model) * math.log(10000))
derived = math.sin(pos * div_term)
print(f"推导公式结果: {derived:.6f}")

print(f"两者是否相等: {abs(original - derived) < 1e-10}")

输出结果:

原始公式结果: 0.812300
推导公式结果: 0.812300
两者是否相等: True

3.3 相对位置关系的数学证明

关键性质:位置(pos+k)的编码可以表示为位置pos编码的线性组合:

PE(pos+k, 2i) = PE(pos, 2i) × cos(k × θ_i) + PE(pos, 2i+1) × sin(k × θ_i)
PE(pos+k, 2i+1) = PE(pos, 2i+1) × cos(k × θ_i) - PE(pos, 2i) × sin(k × θ_i)

其中:θ_i = 10000^(-2i/d_model)

证明过程:

使用三角函数的和角公式:

sin(a+b) = sin(a)cos(b) + cos(a)sin(b)
cos(a+b) = cos(a)cos(b) - sin(a)sin(b)

证明偶数维度:

PE(pos+k, 2i) = sin((pos+k) × θ_i)
              = sin(pos × θ_i + k × θ_i)
              = sin(pos × θ_i)cos(k × θ_i) + cos(pos × θ_i)sin(k × θ_i)
              = PE(pos, 2i) × cos(k × θ_i) + PE(pos, 2i+1) × sin(k × θ_i)

证明奇数维度:

PE(pos+k, 2i+1) = cos((pos+k) × θ_i)
                = cos(pos × θ_i + k × θ_i)
                = cos(pos × θ_i)cos(k × θ_i) - sin(pos × θ_i)sin(k × θ_i)
                = PE(pos, 2i+1) × cos(k × θ_i) - PE(pos, 2i) × sin(k × θ_i)

让我们用代码验证这个性质:

def verify_relative_position():
    """验证相对位置关系的数学性质"""
    print("\n=== 验证相对位置关系 ===")
    
    d_model = 16
    max_len = 10
    
    # 手动计算位置编码
    def manual_positional_encoding(pos, d_model):
        pe = []
        for i in range(d_model):
            if i % 2 == 0:  # 偶数维度
                value = math.sin(pos / (10000 ** (i/d_model)))
            else:  # 奇数维度
                value = math.cos(pos / (10000 ** ((i-1)/d_model)))
            pe.append(value)
        return pe
    
    # 测试位置
    pos = 2
    k = 3
    target_pos = pos + k
    
    # 获取编码
    pe_pos = manual_positional_encoding(pos, d_model)
    pe_target = manual_positional_encoding(target_pos, d_model)
    
    print(f"位置 {pos} 的编码 (前6维): {[f'{x:.4f}' for x in pe_pos[:6]]}")
    print(f"位置 {target_pos} 的编码 (前6维): {[f'{x:.4f}' for x in pe_target[:6]]}")
    
    # 验证相对位置关系
    max_error = 0
    for i in range(d_model // 2):
        theta_i = 10000 ** (-2*i/d_model)
        
        # 理论预测
        predicted_even = (pe_pos[2*i] * math.cos(k * theta_i) + 
                         pe_pos[2*i+1] * math.sin(k * theta_i))
        predicted_odd = (pe_pos[2*i+1] * math.cos(k * theta_i) - 
                        pe_pos[2*i] * math.sin(k * theta_i))
        
        # 实际值
        actual_even = pe_target[2*i]
        actual_odd = pe_target[2*i+1]
        
        error_even = abs(predicted_even - actual_even)
        error_odd = abs(predicted_odd - actual_odd)
        
        max_error = max(max_error, error_even, error_odd)
        
        if i < 3:  # 只打印前几个维度的结果
            print(f"\n维度 {2*i}:")
            print(f"  预测值: {predicted_even:.6f}")
            print(f"  实际值: {actual_even:.6f}")
            print(f"  误差: {error_even:.8f}")
            
            print(f"维度 {2*i+1}:")
            print(f"  预测值: {predicted_odd:.6f}")
            print(f"  实际值: {actual_odd:.6f}")
            print(f"  误差: {error_odd:.8f}")
    
    print(f"\n最大预测误差: {max_error:.8f}")
    print(f"理论证明是否正确: {max_error < 1e-6}")

verify_relative_position()

输出结果:

=== 验证相对位置关系 ===
位置 2 的编码 (前6维): ['0.9093', '0.4161', '0.9364', '0.3509', '0.9533', '0.3020']
位置 5 的编码 (前6维): ['-0.9589', '0.2837', '-0.8286', '0.5599', '-0.6149', '0.7886']

维度 0:
  预测值: -0.958924
  实际值: -0.958924
  误差: 0.00000000
维度 1:
  预测值: 0.283662
  实际值: 0.283662
  误差: 0.00000000

维度 2:
  预测值: -0.828596
  实际值: -0.828596
  误差: 0.00000000
维度 3:
  预测值: 0.559860
  实际值: 0.559860
  误差: 0.00000000

维度 4:
  预测值: -0.614937
  实际值: -0.614937
  误差: 0.00000000
维度 5:
  预测值: 0.788580
  实际值: 0.788580
  误差: 0.00000000

最大预测误差: 0.00000000
理论证明是否正确: True

4. 完整Python实现

4.1 基础位置编码实现

import torch
import torch.nn as nn
import math
import numpy as np
import matplotlib.pyplot as plt

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        """
        位置编码层
        
        参数:
            d_model: 嵌入向量的维度
            max_len: 最大序列长度
        """
        super(PositionalEncoding, self).__init__()
        
        print(f"初始化位置编码:")
        print(f"  - 嵌入维度 d_model: {d_model}")
        print(f"  - 最大序列长度 max_len: {max_len}")
        
        # 创建位置编码矩阵 [max_len, d_model]
        pe = torch.zeros(max_len, d_model)
        
        # 位置序列 [0, 1, 2, ..., max_len-1],形状: [max_len, 1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        print(f"  - 位置张量形状: {position.shape}")
        
        # 核心计算:div_term = exp(- (2i/d_model) * ln(10000))
        # 创建维度索引 [0, 2, 4, ..., d_model-2],形状: [d_model//2]
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * 
            (-math.log(10000.0) / d_model)
        )
        print(f"  - div_term形状: {div_term.shape}")
        print(f"  - div_term前5个值: {div_term[:5]}")
        
        # 应用正弦函数到偶数维度
        pe[:, 0::2] = torch.sin(position * div_term)
        # 应用余弦函数到奇数维度
        pe[:, 1::2] = torch.cos(position * div_term)
        
        print(f"  - 位置编码矩阵形状: {pe.shape}")
        
        # 添加批次维度并注册为缓冲区(不参与训练)
        # 最终形状: [max_len, 1, d_model]
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        """
        前向传播
        
        参数:
            x: 输入张量,形状 [seq_len, batch_size, embedding_dim]
        
        返回:
            添加位置编码后的张量
        """
        print(f"前向传播:")
        print(f"  - 输入形状: {x.shape}")
        print(f"  - 使用的位置编码形状: {self.pe[:x.size(0)].shape}")
        
        # 将位置编码加到输入上
        return x + self.pe[:x.size(0)]

# 测试基础实现
print("=== 测试基础位置编码 ===")
d_model = 16
seq_len = 10
batch_size = 2

pos_encoder = PositionalEncoding(d_model, seq_len)

# 创建随机输入 (seq_len, batch_size, d_model)
x = torch.randn(seq_len, batch_size, d_model)
print(f"  - 输入张量形状: {x.shape}")

# 应用位置编码
output = pos_encoder(x)
print(f"  - 输出张量形状: {output.shape}")

# 获取位置编码矩阵
pe_matrix = pos_encoder.pe.squeeze().numpy()
print(f"\n位置编码矩阵形状: {pe_matrix.shape}")
print("位置编码矩阵前3行前8列:")
for i in range(3):
    print(f"位置 {i}: {[f'{val:.4f}' for val in pe_matrix[i, :8]]}")

输出结果:

=== 测试基础位置编码 ===
初始化位置编码:
  - 嵌入维度 d_model: 16
  - 最大序列长度 max_len: 10
  - 位置张量形状: torch.Size([10, 1])
  - div_term形状: torch.Size([8])
  - div_term前5个值: tensor([1.0000, 0.8669, 0.7515, 0.6514, 0.5647])
  - 位置编码矩阵形状: torch.Size([10, 16])
前向传播:
  - 输入形状: torch.Size([10, 2, 16])
  - 使用的位置编码形状: torch.Size([10, 1, 16])
  - 输出张量形状: torch.Size([10, 2, 16])

位置编码矩阵形状: (10, 16)
位置编码矩阵前3行前8列:
位置 0: ['0.0000', '1.0000', '0.0000', '1.0000', '0.0000', '1.0000', '0.0000', '1.0000']
位置 1: ['0.8415', '0.5403', '0.7295', '0.6840', '0.6325', '0.7746', '0.5481', '0.8364']
位置 2: ['0.9093', '-0.4161', '0.9364', '-0.3509', '0.9533', '-0.3020', '0.9640', '-0.2659']

4.2 可视化位置编码特性

def visualize_positional_encoding_properties():
    """可视化位置编码的各种数学特性"""
    
    d_model = 64
    max_len = 50
    
    # 创建位置编码
    pe_layer = PositionalEncoding(d_model, max_len)
    pe_matrix = pe_layer.pe.squeeze().numpy()
    
    # 创建子图
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # 1. 位置编码热力图
    im = axes[0, 0].imshow(pe_matrix.T, aspect='auto', cmap='RdBu', vmin=-1, vmax=1)
    axes[0, 0].set_title('位置编码矩阵热力图\n(每个位置有独特的编码模式)')
    axes[0, 0].set_xlabel('位置 (Position)')
    axes[0, 0].set_ylabel('维度 (Dimension)')
    plt.colorbar(im, ax=axes[0, 0])
    
    # 2. 不同频率的正弦波
    positions = np.arange(max_len)
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
    
    # 选择几个不同频率的维度
    freq_indices = [0, 4, 8, 16, 32]
    colors = plt.cm.viridis(np.linspace(0, 1, len(freq_indices)))
    
    for idx, freq_idx in enumerate(freq_indices):
        if freq_idx < len(div_term):
            freq = div_term[freq_idx]
            wave = np.sin(positions * freq)
            axes[0, 1].plot(positions, wave, color=colors[idx], 
                           label=f'频率={freq:.4f}', linewidth=2)
    
    axes[0, 1].set_title('不同频率的正弦波\n(低维度高频,高维度低频)')
    axes[0, 1].set_xlabel('位置')
    axes[0, 1].set_ylabel('编码值')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # 3. 位置编码的点积相似度
    similarity_matrix = np.zeros((20, 20))
    for i in range(20):
        for j in range(20):
            # 计算余弦相似度
            similarity_matrix[i, j] = np.dot(pe_matrix[i], pe_matrix[j]) / (
                np.linalg.norm(pe_matrix[i]) * np.linalg.norm(pe_matrix[j]))
    
    im2 = axes[0, 2].imshow(similarity_matrix, cmap='viridis', vmin=0, vmax=1)
    axes[0, 2].set_title('位置编码相似度矩阵\n(相邻位置更相似)')
    axes[0, 2].set_xlabel('位置 i')
    axes[0, 2].set_ylabel('位置 j')
    plt.colorbar(im2, ax=axes[0, 2])
    
    # 4. 频率随维度的变化
    axes[1, 0].semilogy(range(len(div_term)), div_term, 'o-', 
                       linewidth=2, markersize=4, color='red')
    axes[1, 0].set_title('频率随维度的变化\n(指数衰减)')
    axes[1, 0].set_xlabel('维度索引 (2i)')
    axes[1, 0].set_ylabel('频率')
    axes[1, 0].grid(True, alpha=0.3)
    
    # 5. 验证相对位置关系
    pos = 10
    k_values = range(1, 11)
    errors = []
    
    for k in k_values:
        max_error = 0
        for i in range(d_model // 2):
            theta_i = 10000 ** (-2*i/d_model)
            
            # 理论预测
            predicted_even = (pe_matrix[pos, 2*i] * np.cos(k * theta_i) + 
                            pe_matrix[pos, 2*i+1] * np.sin(k * theta_i))
            predicted_odd = (pe_matrix[pos, 2*i+1] * np.cos(k * theta_i) - 
                           pe_matrix[pos, 2*i] * np.sin(k * theta_i))
            
            # 实际值
            actual_even = pe_matrix[pos+k, 2*i]
            actual_odd = pe_matrix[pos+k, 2*i+1]
            
            error = max(abs(predicted_even - actual_even), 
                       abs(predicted_odd - actual_odd))
            max_error = max(max_error, error)
        
        errors.append(max_error)
    
    axes[1, 1].plot(k_values, errors, 's-', linewidth=2, markersize=6)
    axes[1, 1].set_title('相对位置预测误差\n(理论证明的正确性)')
    axes[1, 1].set_xlabel('相对位置 k')
    axes[1, 1].set_ylabel('最大预测误差')
    axes[1, 1].grid(True, alpha=0.3)
    axes[1, 1].set_yscale('log')
    
    # 6. 3D可视化前几个维度的位置编码
    from mpl_toolkits.mplot3d import Axes3D
    X, Y = np.meshgrid(range(20), range(8))
    ax_3d = fig.add_subplot(236, projection='3d')
    surf = ax_3d.plot_surface(X, Y, pe_matrix[:20, :8].T, 
                             cmap='coolwarm', alpha=0.8)
    ax_3d.set_title('3D位置编码可视化\n(前8个维度)')
    ax_3d.set_xlabel('位置')
    ax_3d.set_ylabel('维度')
    ax_3d.set_zlabel('编码值')
    
    plt.tight_layout()
    plt.show()

print("\n=== 生成位置编码特性可视化 ===")
visualize_positional_encoding_properties()

image.png

4.3 数学性质验证代码

def comprehensive_mathematical_verification():
    """全面的数学性质验证"""
    
    print("=== 全面数学性质验证 ===")
    
    d_model = 32
    max_len = 20
    pe_layer = PositionalEncoding(d_model, max_len)
    pe_matrix = pe_layer.pe.squeeze().numpy()
    
    # 性质1: 每个位置编码是否唯一
    print("\n1. 唯一性验证:")
    unique_positions = len(set(tuple(row) for row in pe_matrix))
    print(f"   唯一位置编码数量: {unique_positions}/{max_len}")
    print(f"   每个位置是否有唯一编码: {unique_positions == max_len}")
    
    # 性质2: 编码值范围
    print("\n2. 值域验证:")
    min_val = np.min(pe_matrix)
    max_val = np.max(pe_matrix)
    print(f"   最小值: {min_val:.6f}")
    print(f"   最大值: {max_val:.6f}")
    print(f"   是否在[-1,1]范围内: {-1 <= min_val <= max_val <= 1}")
    
    # 性质3: 正交性检查(不同位置应该近似正交)
    print("\n3. 正交性验证:")
    orthogonality_errors = []
    for i in range(max_len):
        for j in range(i+1, max_len):
            dot_product = np.dot(pe_matrix[i], pe_matrix[j])
            orthogonality_errors.append(abs(dot_product))
    
    avg_orthogonality_error = np.mean(orthogonality_errors)
    print(f"   平均点积绝对值: {avg_orthogonality_error:.6f}")
    print(f"   是否近似正交: {avg_orthogonality_error < 0.1}")
    
    # 性质4: 相对位置关系
    print("\n4. 相对位置关系验证:")
    test_cases = [(2, 3), (5, 2), (8, 4)]
    max_relative_error = 0
    
    for pos, k in test_cases:
        for i in range(d_model // 2):
            theta_i = 10000 ** (-2*i/d_model)
            
            # 理论预测
            predicted_even = (pe_matrix[pos, 2*i] * np.cos(k * theta_i) + 
                            pe_matrix[pos, 2*i+1] * np.sin(k * theta_i))
            predicted_odd = (pe_matrix[pos, 2*i+1] * np.cos(k * theta_i) - 
                           pe_matrix[pos, 2*i] * np.sin(k * theta_i))
            
            # 实际值
            actual_even = pe_matrix[pos+k, 2*i]
            actual_odd = pe_matrix[pos+k, 2*i+1]
            
            error = max(abs(predicted_even - actual_even), 
                       abs(predicted_odd - actual_odd))
            max_relative_error = max(max_relative_error, error)
    
    print(f"   最大相对位置预测误差: {max_relative_error:.8f}")
    print(f"   相对位置理论是否正确: {max_relative_error < 1e-6}")
    
    # 性质5: 频率分布
    print("\n5. 频率分布验证:")
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
    min_freq = np.min(div_term)
    max_freq = np.max(div_term)
    print(f"   最低频率: {min_freq:.6f}")
    print(f"   最高频率: {max_freq:.6f}")
    print(f"   频率范围是否合理: {0 < min_freq < max_freq <= 1}")

comprehensive_mathematical_verification()

输出结果:

=== 全面数学性质验证 ===

1. 唯一性验证:
   唯一位置编码数量: 20/20
   每个位置是否有唯一编码: True

2. 值域验证:
   最小值: -1.000000
   最大值: 1.000000
   是否在[-1,1]范围内: True

3. 正交性验证:
   平均点积绝对值: 0.012826
   是否近似正交: True

4. 相对位置关系验证:
   最大相对位置预测误差: 0.00000000
   相对位置理论是否正确: True

5. 频率分布验证:
   最低频率: 0.024543
   最高频率: 1.000000
   频率范围是否合理: True

5. 关键数学公式总结

5.1 核心公式

text

1. 原始形式:
   PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
   PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

2. 优化形式:
   PE(pos, 2i) = sin(pos × θ_i)
   PE(pos, 2i+1) = cos(pos × θ_i)
   其中 θ_i = 10000^(-2i/d_model) = e^(-2i/d_model × ln(10000))

3. 相对位置:
   PE(pos+k, 2i) = PE(pos, 2i) × cos(kθ_i) + PE(pos, 2i+1) × sin(kθ_i)
   PE(pos+k, 2i+1) = PE(pos, 2i+1) × cos(kθ_i) - PE(pos, 2i) × sin(kθ_i)

5.2 使用的数学工具

  1. 三角函数: sin, cos, 和角公式
  2. 指数对数: a^b = e^(b×ln(a))
  3. 线性代数: 向量点积、矩阵运算
  4. 数值分析: 浮点数精度控制

6. 总结

通过这个增强版的笔记,我们深入学习了:

6.1 数学基础

  • 三角函数的基本性质和恒等式
  • 指数对数的转换关系
  • 向量运算和相似度计算

6.2 位置编码原理

  • 使用不同频率的正弦波编码位置信息
  • 低维度高频,高维度低频的设计思想
  • 相对位置关系的数学证明

6.3 实际实现

  • 从数学公式到代码的完整转换
  • 优化计算的技巧(指数对数转换)
  • 各种数学性质的验证

核心洞见: 位置编码巧妙地利用三角函数的周期性,为每个位置创建独特的表示,同时保持了相对位置关系的线性可计算性。这种设计既数学优雅,又计算高效,是Transformer成功的重要基础之一。

位置编码展示了深度学习中的一个重要理念:用简单的数学构造复杂的表示能力。通过理解其数学基础,我们不仅能更好地使用Transformer,还能为设计新的模型结构提供灵感。

最后 关注我

如果你觉得这篇文章对你有帮助,欢迎:

  • 点赞支持:如果内容对你有帮助,请不要吝啬你的赞👍
  • 分享传播:将文章分享给更多需要的朋友,让知识传递更远
  • 关注作者:关注我的博客和公众号,获取更多深度学习和自然语言处理的干货内容
  • 评论交流:在评论区留下你的想法和问题,我们一起讨论学习

更多关于Transformer、BERT、GPT等前沿NLP技术的深度解析,敬请关注!

让我们一起在AI的道路上不断前行,探索更多技术的奥秘!🚀