Transformer核心架构详解:自注意力与多头注意力

0 阅读3分钟

本文深度剖析Transformer的核心机制——自注意力(Self-Attention)和多头注意力(Multi-Head Attention)。通过数学推导、可视化图表和PyTorch代码实现,详细讲解QKV矩阵计算、注意力分数、缩放点积注意力等关键技术。涵盖Transformer Block完整结构、残差连接、层归一化等工程实践要点,是理解现代大语言模型架构的必读教程。

一、Transformer架构全景

1.1 Transformer诞生的背景

2017年Google发表的《Attention Is All You Need》彻底改变了NLP领域。其核心创新:完全抛弃RNN/CNN,只用注意力机制

graph TB
    subgraph 传统方法
        A["RNN/LSTM"] --> B["串行计算<br/>梯度消失<br/>难以并行"]
    end
    
    subgraph Transformer革命
        C["Self-Attention"] --> D["并行计算<br/>长距离依赖<br/>可扩展性强"]
    end
    
    B -.2017年突破.-> D
    
    style B fill:#ffcdd2
    style D fill:#a5d6a7

1.2 完整架构图

graph TB
    subgraph Encoder[编码器 N×]
        Input1[输入嵌入] --> PE1[位置编码]
        PE1 --> MHA1[多头注意力]
        MHA1 --> Add1[Add & Norm]
        Add1 --> FFN1[前馈网络]
        FFN1 --> Add2[Add & Norm]
    end
    
    subgraph Decoder[解码器 N×]
        Input2[输出嵌入] --> PE2[位置编码]
        PE2 --> MHA2[Masked<br/>多头注意力]
        MHA2 --> Add3[Add & Norm]
        Add3 --> MHA3[交叉注意力]
        MHA3 --> Add4[Add & Norm]
        Add4 --> FFN2[前馈网络]
        FFN2 --> Add5[Add & Norm]
    end
    
    Add2 --> MHA3
    Add5 --> Linear[线性层]
    Linear --> Softmax[Softmax]
    Softmax --> Output[输出概率]
    
    style MHA1 fill:#fff59d
    style MHA2 fill:#fff59d
    style MHA3 fill:#fff59d
    style Output fill:#a5d6a7

1.3 三大核心组件

组件作用关键技术
Self-Attention捕捉序列内部依赖QKV矩阵、缩放点积
Multi-Head Attention多视角信息提取多个注意力头并行
Feed-Forward Network非线性变换两层全连接+激活

二、自注意力机制(Self-Attention)详解

2.1 核心思想

Self-Attention的本质:让序列中的每个元素都能"看到"其他所有元素,并决定关注程度。

graph LR
    subgraph 输入句子
        W1["我(I)"]
        W2["爱(love)"]
        W3["AI"]
    end
    
    subgraph 自注意力计算
        W2 -->|0.2| W1
        W2 -->|0.5| W2
        W2 -->|0.3| W3
    end
    
    Result["增强表示:<br/>爱' = 0.2×我 + 0.5×爱 + 0.3×AI"]
    
    W2 --> Result
    
    style W2 fill:#fff59d
    style Result fill:#a5d6a7

2.2 QKV三剑客:Query、Key、Value

核心概念:

  • Query (Q): "我要找什么?"(查询向量)
  • Key (K): "我是什么?"(键向量)
  • Value (V): "我有什么信息?"(值向量)

类比理解:就像图书馆检索系统

sequenceDiagram
    participant 你的查询
    participant 图书索引
    participant 图书内容
    
    你的查询->>图书索引: Query: "机器学习"
    图书索引->>图书索引: 计算相关度(Key匹配)
    Note over 图书索引: 《深度学习》: 0.9<br/>《数据结构》: 0.2<br/>《Python》: 0.6
    图书索引->>图书内容: 按权重提取Value
    图书内容->>你的查询: 返回加权信息

2.3 数学公式推导

Step 1: 生成QKV矩阵

对于输入序列 XRn×dX \in \mathbb{R}^{n \times d} (n个词,每个d维):

Q=XWQ,WQRd×dkK=XWK,WKRd×dkV=XWV,WVRd×dv\begin{align} Q &= XW^Q, \quad W^Q \in \mathbb{R}^{d \times d_k} \\ K &= XW^K, \quad W^K \in \mathbb{R}^{d \times d_k} \\ V &= XW^V, \quad W^V \in \mathbb{R}^{d \times d_v} \end{align}
graph LR
    X["输入矩阵 X<br/>[n × d]"] --> WQ["权重 W^Q<br/>[d × d_k]"]
    X --> WK["权重 W^K<br/>[d × d_k]"]
    X --> WV["权重 W^V<br/>[d × d_v]"]
    
    WQ --> Q["Query<br/>[n × d_k]"]
    WK --> K["Key<br/>[n × d_k]"]
    WV --> V["Value<br/>[n × d_v]"]
    
    style X fill:#e3f2fd
    style Q fill:#fff9c4
    style K fill:#c5e1a5
    style V fill:#ce93d8

Step 2: 计算注意力分数

Score=QKTdk\text{Score} = \frac{QK^T}{\sqrt{d_k}}

为什么要除以 dk\sqrt{d_k}?

graph TD
    A["问题: 维度d_k增大时<br/>点积值会急剧增长"] --> B["导致Softmax梯度消失"]
    B --> C["解决: 缩放因子 √d_k<br/>保持数值稳定"]
    
    style A fill:#ffcdd2
    style C fill:#a5d6a7

Step 3: Softmax归一化

Attention Weights=softmax(Score)=softmax(QKTdk)\text{Attention Weights} = \text{softmax}(\text{Score}) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)

Step 4: 加权求和

Output=Attention Weights×V\text{Output} = \text{Attention Weights} \times V

完整公式:

Attention(Q,K,V)=softmax(QKTdk)V\boxed{\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V}

2.4 直观示例:计算"爱"的新表示

输入句子:"我 爱 学习 AI"

graph TB
    subgraph 输入[Step 1: 生成QKV]
        X["输入:<br/>我 爱 学习 AI"]
        X --> Q["Q矩阵<br/>(查询)"]
        X --> K["K矩阵<br/>(键)"]
        X --> V["V矩阵<br/>(值)"]
    end
    
    subgraph 计算[Step 2-3: 计算注意力]
        Q --> Score["QK^T / √d_k"]
        K --> Score
        Score --> Softmax["Softmax归一化"]
    end
    
    subgraph 输出[Step 4: 加权求和]
        Softmax --> Weight["权重:<br/>我:0.1 爱:0.2<br/>学习:0.4 AI:0.3"]
        V --> Output["新表示:<br/>爱' = Σ(权重×值)"]
        Weight --> Output
    end
    
    style Score fill:#fff9c4
    style Softmax fill:#c5e1a5
    style Output fill:#a5d6a7

假设计算"爱"对其他词的注意力:

Query·KeyScoreSoftmax最终贡献
2.12.1/√64=0.260.10.1×V(我)
3.50.440.20.2×V(爱)
学习7.20.900.40.4×V(学习)
AI5.80.730.30.3×V(AI)

最终:"爱"的新表示 = 0.1×V(我) + 0.2×V(爱) + 0.4×V(学习) + 0.3×V(AI)


三、多头注意力(Multi-Head Attention)

3.1 为什么需要多头?

单头的局限:只能捕捉一种关系模式

graph LR
    subgraph 单头注意力
        S["我爱学习AI"] --> H1["只关注<br/>语义相关性"]
    end
    
    subgraph 多头注意力
        M["我爱学习AI"] --> MH1["Head1<br/>语义关系"]
        M --> MH2["Head2<br/>语法关系"]
        M --> MH3["Head3<br/>位置关系"]
        M --> MH4["Head4<br/>..."]
    end
    
    style H1 fill:#ffcdd2
    style MH1 fill:#a5d6a7
    style MH2 fill:#90caf9
    style MH3 fill:#ce93d8
    style MH4 fill:#fff59d

3.2 多头注意力架构

graph TB
    X["输入 X"] --> Split["线性变换并分割"]
    
    Split --> H1["Head 1<br/>Attention(Q1,K1,V1)"]
    Split --> H2["Head 2<br/>Attention(Q2,K2,V2)"]
    Split --> H3["Head 3<br/>Attention(Q3,K3,V3)"]
    Split --> H4["Head h<br/>Attention(Qh,Kh,Vh)"]
    
    H1 --> Concat["拼接 Concat"]
    H2 --> Concat
    H3 --> Concat
    H4 --> Concat
    
    Concat --> Linear["线性变换 W^O"]
    Linear --> Output["输出"]
    
    style Split fill:#e3f2fd
    style Concat fill:#fff9c4
    style Output fill:#a5d6a7

3.3 数学公式

对于 hh 个注意力头:

headi=Attention(QWiQ,KWiK,VWiV)MultiHead(Q,K,V)=Concat(head1,...,headh)WO\begin{align} \text{head}_i &= \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \\ \text{MultiHead}(Q,K,V) &= \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O \end{align}

其中:

  • WiQRdmodel×dkW_i^Q \in \mathbb{R}^{d_{model} \times d_k}
  • WiKRdmodel×dkW_i^K \in \mathbb{R}^{d_{model} \times d_k}
  • WiVRdmodel×dvW_i^V \in \mathbb{R}^{d_{model} \times d_v}
  • WORhdv×dmodelW^O \in \mathbb{R}^{hd_v \times d_{model}}

典型配置(如BERT):

  • dmodel=768d_{model} = 768 (模型维度)
  • h=12h = 12 (注意力头数)
  • dk=dv=dmodel/h=64d_k = d_v = d_{model}/h = 64 (每个头的维度)

3.4 多头的优势

graph TB
    subgraph 例子[句子: The animal didn't cross the street because it was too tired]
        Example["it 指代什么?"]
    end
    
    subgraph Heads[不同注意力头的关注点]
        H1["Head 1<br/>it → animal<br/>(主语指代)"]
        H2["Head 2<br/>it → street<br/>(位置关系)"]
        H3["Head 3<br/>didn't → tired<br/>(因果关系)"]
    end
    
    Example --> H1
    Example --> H2
    Example --> H3
    
    H1 --> Final["综合判断:<br/>it = animal"]
    H2 --> Final
    H3 --> Final
    
    style H1 fill:#a5d6a7
    style H2 fill:#ffcdd2
    style H3 fill:#90caf9
    style Final fill:#fff59d

四、PyTorch完整实现

4.1 Scaled Dot-Product Attention

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    参数:
        Q: [batch_size, n_heads, seq_len, d_k]
        K: [batch_size, n_heads, seq_len, d_k]
        V: [batch_size, n_heads, seq_len, d_v]
        mask: [batch_size, 1, 1, seq_len] 可选
    返回:
        output: [batch_size, n_heads, seq_len, d_v]
        attention_weights: [batch_size, n_heads, seq_len, seq_len]
    """
    d_k = Q.size(-1)
    
    # 1. 计算注意力分数: QK^T / √d_k
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores: [batch_size, n_heads, seq_len, seq_len]
    
    # 2. 可选:应用mask(用于Decoder中的自回归)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # 3. Softmax归一化
    attention_weights = F.softmax(scores, dim=-1)
    
    # 4. 加权求和
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights


# 测试
batch_size, n_heads, seq_len, d_k = 2, 8, 10, 64
Q = torch.randn(batch_size, n_heads, seq_len, d_k)
K = torch.randn(batch_size, n_heads, seq_len, d_k)
V = torch.randn(batch_size, n_heads, seq_len, d_k)

output, attn_weights = scaled_dot_product_attention(Q, K, V)
print(f"输出形状: {output.shape}")  # torch.Size([2, 8, 10, 64])
print(f"注意力权重: {attn_weights.shape}")  # torch.Size([2, 8, 10, 10])

4.2 Multi-Head Attention完整实现

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        """
        参数:
            d_model: 模型维度(如768)
            n_heads: 注意力头数(如12)
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model必须能被n_heads整除"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # 每个头的维度
        
        # QKV的线性变换矩阵
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        
        # 输出的线性变换
        self.W_O = nn.Linear(d_model, d_model)
        
    def split_heads(self, x):
        """
        将输入分割成多个头
        x: [batch_size, seq_len, d_model]
        返回: [batch_size, n_heads, seq_len, d_k]
        """
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
    
    def forward(self, Q, K, V, mask=None):
        """
        参数:
            Q, K, V: [batch_size, seq_len, d_model]
            mask: [batch_size, 1, 1, seq_len]
        """
        batch_size = Q.size(0)
        
        # 1. 线性变换
        Q = self.W_Q(Q)  # [batch_size, seq_len, d_model]
        K = self.W_K(K)
        V = self.W_V(V)
        
        # 2. 分割成多个头
        Q = self.split_heads(Q)  # [batch_size, n_heads, seq_len, d_k]
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # 3. 缩放点积注意力
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        # attn_output: [batch_size, n_heads, seq_len, d_k]
        
        # 4. 合并多个头
        attn_output = attn_output.transpose(1, 2).contiguous()
        # [batch_size, seq_len, n_heads, d_k]
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        # [batch_size, seq_len, d_model]
        
        # 5. 最终线性变换
        output = self.W_O(attn_output)
        
        return output, attn_weights


# 使用示例
d_model = 512
n_heads = 8
batch_size = 2
seq_len = 10

mha = MultiHeadAttention(d_model, n_heads)

# 输入
x = torch.randn(batch_size, seq_len, d_model)

# 自注意力:Q=K=V
output, attn_weights = mha(x, x, x)

print(f"输入形状: {x.shape}")              # torch.Size([2, 10, 512])
print(f"输出形状: {output.shape}")         # torch.Size([2, 10, 512])
print(f"注意力权重: {attn_weights.shape}") # torch.Size([2, 8, 10, 10])

4.3 可视化注意力权重

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attn_weights, tokens, head_idx=0):
    """
    可视化某个注意力头的权重
    attn_weights: [batch_size, n_heads, seq_len, seq_len]
    tokens: 词列表
    head_idx: 要可视化的头索引
    """
    # 提取第一个样本的指定头
    attn = attn_weights[0, head_idx].detach().cpu().numpy()
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(attn, 
                xticklabels=tokens,
                yticklabels=tokens,
                cmap='YlOrRd',
                annot=True,
                fmt='.2f',
                cbar_kws={'label': '注意力权重'})
    plt.title(f'Attention Head {head_idx}')
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.tight_layout()
    plt.show()


# 示例:可视化英译中的注意力
tokens = ['I', 'love', 'learning', 'AI', '<EOS>']
seq_len = len(tokens)

# 模拟注意力权重
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(1, seq_len, 512)
_, attn_weights = mha(x, x, x)

# 可视化第0个头
visualize_attention(attn_weights, tokens, head_idx=0)

输出效果:

注意力矩阵 (Head 0)
         I    love  learning  AI   <EOS>
I      0.20  0.15    0.10   0.35  0.20
love   0.10  0.50    0.30   0.05  0.05
learning 0.05 0.30    0.40   0.20  0.05
AI     0.15  0.10    0.25   0.45  0.05
<EOS>  0.05  0.05    0.05   0.10  0.75

五、Transformer Block完整结构

5.1 单个Block的组成

graph TB
    Input[输入 X] --> MHA[多头注意力]
    Input -.残差连接.-> Add1
    MHA --> Add1[Add]
    Add1 --> Norm1[Layer Norm]
    
    Norm1 --> FFN[前馈网络<br/>2层全连接]
    Norm1 -.残差连接.-> Add2
    FFN --> Add2[Add]
    Add2 --> Norm2[Layer Norm]
    Norm2 --> Output[输出 Y]
    
    style MHA fill:#fff59d
    style FFN fill:#90caf9
    style Add1 fill:#a5d6a7
    style Add2 fill:#a5d6a7
    style Norm1 fill:#ce93d8
    style Norm2 fill:#ce93d8

5.2 残差连接(Residual Connection)

为什么需要?

graph LR
    subgraph 无残差[深层网络问题]
        A["梯度消失"] --> B["难以训练"]
    end
    
    subgraph 残差解决[ResNet思想]
        C["Y = X + F(X)"] --> D["梯度直通<br/>易于优化"]
    end
    
    B -.引入残差.-> D
    
    style A fill:#ffcdd2
    style D fill:#a5d6a7

数学表示:

Output=LayerNorm(X+MultiHeadAttention(X))\text{Output} = \text{LayerNorm}(X + \text{MultiHeadAttention}(X))

5.3 Layer Normalization

与Batch Norm的区别:

特性Batch NormLayer Norm
归一化维度跨batch维度跨特征维度
适用场景CNN(固定batch)NLP(变长序列)
依赖性依赖batch大小独立于batch
graph LR
    subgraph BatchNorm[Batch Normalization]
        B1["样本1<br/>[d1,d2,d3]"]
        B2["样本2<br/>[d1,d2,d3]"]
        B3["样本3<br/>[d1,d2,d3]"]
        B1 --> BN["对每个特征<br/>跨样本归一化"]
        B2 --> BN
        B3 --> BN
    end
    
    subgraph LayerNorm[Layer Normalization]
        L1["样本1<br/>[d1,d2,d3]"]
        L1 --> LN["对每个样本<br/>跨特征归一化"]
    end
    
    style BN fill:#ffccbc
    style LN fill:#a5d6a7

Layer Norm公式:

LayerNorm(x)=γxμσ2+ϵ+β\text{LayerNorm}(x) = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

其中 μ,σ\mu, \sigma 是当前层所有特征的均值和标准差。

5.4 前馈网络(Feed-Forward Network)

结构:两层全连接+激活函数

FFN(x)=ReLU(xW1+b1)W2+b2\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2
graph LR
    Input["输入<br/>[seq_len, d_model]"] --> Linear1["全连接1<br/>d_model → d_ff"]
    Linear1 --> ReLU["ReLU激活"]
    ReLU --> Linear2["全连接2<br/>d_ff → d_model"]
    Linear2 --> Output["输出<br/>[seq_len, d_model]"]
    
    style Linear1 fill:#fff9c4
    style ReLU fill:#90caf9
    style Linear2 fill:#c5e1a5

典型配置:

  • BERT: dmodel=768d_{model}=768, dff=3072d_{ff}=3072 (4倍)
  • GPT-3: dmodel=12288d_{model}=12288, dff=49152d_{ff}=49152 (4倍)

5.5 完整Transformer Block实现

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        """
        参数:
            d_model: 模型维度
            n_heads: 注意力头数
            d_ff: 前馈网络隐藏层维度
            dropout: Dropout比率
        """
        super().__init__()
        
        # 多头注意力
        self.mha = MultiHeadAttention(d_model, n_heads)
        
        # 前馈网络
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        
        # Layer Normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        x: [batch_size, seq_len, d_model]
        """
        # 1. 多头自注意力 + 残差连接 + LayerNorm
        attn_output, _ = self.mha(x, x, x, mask)
        attn_output = self.dropout1(attn_output)
        x = self.norm1(x + attn_output)  # 残差连接
        
        # 2. 前馈网络 + 残差连接 + LayerNorm
        ffn_output = self.ffn(x)
        ffn_output = self.dropout2(ffn_output)
        x = self.norm2(x + ffn_output)   # 残差连接
        
        return x


# 测试
d_model = 512
n_heads = 8
d_ff = 2048
block = TransformerBlock(d_model, n_heads, d_ff)

x = torch.randn(2, 10, d_model)  # [batch, seq_len, d_model]
output = block(x)

print(f"输入形状: {x.shape}")      # torch.Size([2, 10, 512])
print(f"输出形状: {output.shape}")  # torch.Size([2, 10, 512])

六、位置编码(Positional Encoding)

6.1 为什么需要位置编码?

问题:自注意力是置换不变的(permutation-invariant)

sentence1 = "我 爱 AI"
sentence2 = "AI 爱 我"

# 如果没有位置编码,Self-Attention会给出相同的结果!
graph TB
    Problem["问题: 自注意力无法区分词序"] --> Solution["解决: 添加位置信息"]
    
    Solution --> PE1["方案1: 学习位置嵌入<br/>(BERT)"]
    Solution --> PE2["方案2: 固定位置编码<br/>(原始Transformer)"]
    
    style Problem fill:#ffcdd2
    style Solution fill:#fff9c4
    style PE1 fill:#c5e1a5
    style PE2 fill:#90caf9

6.2 正弦位置编码

公式:

PE(pos,2i)=sin(pos100002i/dmodel)PE(pos,2i+1)=cos(pos100002i/dmodel)\begin{align} PE_{(pos, 2i)} &= \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \\ PE_{(pos, 2i+1)} &= \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) \end{align}

其中:

  • pospos: 词的位置(0, 1, 2, ...)
  • ii: 维度索引(0, 1, ..., d_model/2)

优点:

  1. ✅ 可以处理任意长度的序列
  2. ✅ 不需要训练参数
  3. ✅ 相对位置关系可以通过线性变换表达

6.3 实现代码

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        # 创建位置编码矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # 计算分母
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                            (-math.log(10000.0) / d_model))
        
        # 计算正弦和余弦
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维度
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维度
        
        # 添加batch维度
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        """
        # 添加位置编码
        x = x + self.pe[:, :x.size(1), :]
        return x


# 可视化位置编码
def visualize_positional_encoding(d_model=128, max_len=100):
    pe = PositionalEncoding(d_model, max_len)
    encoding = pe.pe[0, :max_len, :].numpy()
    
    plt.figure(figsize=(15, 5))
    plt.imshow(encoding.T, aspect='auto', cmap='RdBu', 
               interpolation='nearest')
    plt.colorbar(label='编码值')
    plt.xlabel('位置')
    plt.ylabel('维度')
    plt.title('正弦位置编码可视化')
    plt.tight_layout()
    plt.show()

visualize_positional_encoding()

输出效果:会看到规律的波浪状图案,不同频率编码不同维度。


七、完整Encoder实现

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, d_ff, 
                 n_layers, max_len=5000, dropout=0.1):
        """
        参数:
            vocab_size: 词汇表大小
            d_model: 模型维度
            n_heads: 注意力头数
            d_ff: 前馈网络维度
            n_layers: Transformer Block层数
            max_len: 最大序列长度
            dropout: Dropout比率
        """
        super().__init__()
        
        # 词嵌入
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        # 位置编码
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        
        # 多层Transformer Block
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        x: [batch_size, seq_len] (token indices)
        """
        # 1. 词嵌入 + 位置编码
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        
        # 2. 通过多层Transformer Block
        for layer in self.layers:
            x = layer(x, mask)
        
        return x


# 使用示例:构建一个小型BERT
vocab_size = 30000
d_model = 768
n_heads = 12
d_ff = 3072
n_layers = 12

encoder = TransformerEncoder(vocab_size, d_model, n_heads, d_ff, n_layers)

# 输入token indices
input_ids = torch.randint(0, vocab_size, (2, 20))  # [batch=2, seq_len=20]
output = encoder(input_ids)

print(f"输入形状: {input_ids.shape}")  # torch.Size([2, 20])
print(f"输出形状: {output.shape}")     # torch.Size([2, 20, 768])
print(f"参数量: {sum(p.numel() for p in encoder.parameters())/1e6:.1f}M")
# 输出: 参数量: 110.1M (接近BERT-Base的110M)

八、关键概念总结

8.1 公式总结

组件公式说明
Self-AttentionAttention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V缩放点积注意力
Multi-HeadConcat(head1,...,headh)WO\text{Concat}(\text{head}_1,...,\text{head}_h)W^O多视角融合
FFNReLU(xW1+b1)W2+b2\text{ReLU}(xW_1+b_1)W_2+b_2两层全连接
Layer Normγxμσ2+ϵ+β\gamma\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}}+\beta特征归一化
Positionalsin(pos100002i/d)\sin(\frac{pos}{10000^{2i/d}})位置编码

8.2 配置对比

模型d_modeln_headsn_layersd_ff参数量
BERT-Base76812123072110M
BERT-Large102416244096340M
GPT-276812123072117M
GPT-312288969649152175B

8.3 架构流程图

graph TB
    Start[输入Token IDs] --> Embed[词嵌入层]
    Embed --> PE[位置编码]
    PE --> Block1[Transformer Block 1]
    
    subgraph Block[每个Block内部]
        MHA[多头注意力] --> Add1[Add & Norm]
        Add1 --> FFN[前馈网络]
        FFN --> Add2[Add & Norm]
    end
    
    Block1 --> Block2[Transformer Block 2]
    Block2 --> BlockN[... Block N]
    BlockN --> Output[输出表示]
    
    style Embed fill:#e3f2fd
    style PE fill:#fff9c4
    style MHA fill:#ffccbc
    style FFN fill:#c5e1a5
    style Output fill:#a5d6a7

九、实战练习

练习1:计算注意力权重

题目:给定QKV矩阵,手工计算注意力输出

# 已知(简化为2x2矩阵方便计算)
Q = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0]])

K = torch.tensor([[1.0, 0.0],
                  [0.5, 0.5]])

V = torch.tensor([[2.0, 0.0],
                  [1.0, 1.0]])

# 步骤:
# 1. 计算 QK^T / √d_k
# 2. Softmax
# 3. 乘以V

# 你的答案:

练习2:实现Masked Self-Attention

任务:修改Self-Attention,实现Decoder中的mask机制(当前词不能看到未来词)

def masked_self_attention(Q, K, V):
    """
    TODO: 实现masked attention
    提示: 使用torch.tril创建下三角mask
    """
    pass

练习3:分析注意力头

任务:加载预训练BERT,可视化不同注意力头关注的模式

from transformers import BertModel, BertTokenizer

model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

text = "The animal didn't cross the street because it was too tired."
inputs = tokenizer(text, return_tensors='pt')

with torch.no_grad():
    outputs = model(**inputs)
    attentions = outputs.attentions  # 12层,每层12个头