03-从RNN到Transformer-序列建模的演进之路

4 阅读7分钟

从RNN到Transformer:序列建模的演进之路

理解序列建模技术的发展脉络,掌握Transformer这一革命性架构的核心原理。

前言

在自然语言处理、语音识别、时间序列预测等领域,数据往往是序列形式的——数据的顺序很重要,前后之间存在依赖关系。

如何让神经网络理解这种"顺序"和"依赖"?这就是序列建模要解决的问题。今天,我们来梳理从RNN到Transformer的技术演进之路。


一、什么是序列数据?

序列数据的特点

序列数据中,元素的顺序至关重要:

数据类型示例特点
文本"我喜欢学习AI"词序决定含义
语音声音波形时间顺序重要
股票价格时间序列历史影响未来
DNA序列ATCG...顺序编码信息

序列建模的核心问题

输入序列:x₁, x₂, x₃, ..., xₙ
输出序列:y₁, y₂, y₃, ..., yₘ

核心挑战:
1. 输入输出长度可能不同
2. 需要捕捉长距离依赖
3. 序列中位置信息很重要

二、RNN:循环神经网络

核心思想

RNN通过循环连接来保持对之前信息的"记忆":

时刻 t-1      时刻 t       时刻 t+1
   hₜ₋₁ ────→ hₜ ────→ hₜ₊₁
    ↓          ↓           ↓
   输出       输出         输出
    ↑          ↑           ↑
   输入       输入         输入
   xₜ₋₁       xₜ          xₜ₊₁

数学表达

隐藏状态:hₜ = tanh(Wₕₕ · hₜ₋₁ + Wₓₕ · xₜ + bₕ)
输出:yₜ = Wₕᵧ · hₜ + b

代码实现

import numpy as np

class SimpleRNN:
    def __init__(self, input_size, hidden_size, output_size):
        # 初始化权重
        self.W_xh = np.random.randn(input_size, hidden_size) * 0.01
        self.W_hh = np.random.randn(hidden_size, hidden_size) * 0.01
        self.W_hy = np.random.randn(hidden_size, output_size) * 0.01
        self.b_h = np.zeros((1, hidden_size))
        self.b_y = np.zeros((1, output_size))

    def forward(self, x_sequence):
        """
        x_sequence: shape (seq_len, input_size)
        """
        h = np.zeros((1, self.W_hh.shape[0]))  # 初始隐藏状态
        outputs = []

        for x in x_sequence:
            x = x.reshape(1, -1)
            # 计算新的隐藏状态
            h = np.tanh(np.dot(x, self.W_xh) + np.dot(h, self.W_hh) + self.b_h)
            # 计算输出
            y = np.dot(h, self.W_hy) + self.b_y
            outputs.append(y)

        return outputs, h

# 使用示例
rnn = SimpleRNN(input_size=3, hidden_size=4, output_size=2)
sequence = np.random.randn(5, 3)  # 序列长度5,每个元素3维
outputs, final_h = rnn.forward(sequence)
print(f"输出序列长度: {len(outputs)}")
print(f"最终隐藏状态: {final_h}")

RNN的问题:梯度消失

随着序列变长,RNN很难记住早期的信息:

序列:我 昨天 在 公园 里 看 到 一只 很 可爱 的 小猫
                        ↑                      ↑
                      信息源                  预测点

距离太远,信息在传递过程中逐渐"消散"

数学原因

梯度在反向传播时需要连乘:
∂hₜ/∂h₁ = ∏ᵢ W_hh · tanh'(hᵢ)

如果 |W_hh · tanh'(hᵢ)| < 1,连乘后趋近于0(梯度消失)
如果 |W_hh · tanh'(hᵢ)| > 1,连乘后趋近于∞(梯度爆炸)

三、LSTM:长短期记忆网络

核心创新

LSTM通过门控机制来控制信息的遗忘和保留:

┌─────────────────────────────────────┐
│              LSTM 单元               │
│                                     │
│  ┌─────┐                           │
│  │遗忘门│ ← 决定丢弃哪些信息         │
│  └──┬──┘                           │
│     ↓                              │
│  ┌─────┐                           │
│  │输入门│ ← 决定存储哪些新信息       │
│  └──┬──┘                           │
│     ↓                              │
│  ┌─────┐                           │
│  │输出门│ ← 决定输出哪些信息         │
│  └──┬──┘                           │
│     ↓                              │
│   hₜ → 输出                        │
└─────────────────────────────────────┘

三个门的作用

def lstm_cell(x_t, h_prev, c_prev, params):
    """
    x_t: 当前输入
    h_prev: 上一个隐藏状态
    c_prev: 上一个单元状态(长期记忆)
    """
    # 合并输入
    combined = np.concatenate([h_prev, x_t], axis=1)

    # 遗忘门:决定从长期记忆中丢弃什么
    f_t = sigmoid(np.dot(combined, params['W_f']) + params['b_f'])

    # 输入门:决定往长期记忆中添加什么
    i_t = sigmoid(np.dot(combined, params['W_i']) + params['b_i'])

    # 候选记忆
    c_tilde = np.tanh(np.dot(combined, params['W_c']) + params['b_c'])

    # 更新长期记忆
    c_t = f_t * c_prev + i_t * c_tilde

    # 输出门:决定输出什么
    o_t = sigmoid(np.dot(combined, params['W_o']) + params['b_o'])

    # 更新隐藏状态
    h_t = o_t * np.tanh(c_t)

    return h_t, c_t

LSTM vs RNN

特性RNNLSTM
记忆能力短期记忆长期+短期记忆
梯度问题严重梯度消失有效缓解
参数量较少较多(约4倍)
训练速度较快较慢
长序列表现

GRU:LSTM的简化版

GRU将遗忘门和输入门合并,参数更少:

def gru_cell(x_t, h_prev, params):
    # 重置门
    r_t = sigmoid(np.dot(np.concatenate([h_prev, x_t], axis=1), params['W_r']) + params['b_r'])

    # 更新门
    z_t = sigmoid(np.dot(np.concatenate([h_prev, x_t], axis=1), params['W_z']) + params['b_z'])

    # 候选隐藏状态
    h_tilde = np.tanh(np.dot(np.concatenate([r_t * h_prev, x_t], axis=1), params['W_h']) + params['b_h'])

    # 新隐藏状态
    h_t = (1 - z_t) * h_prev + z_t * h_tilde

    return h_t

四、注意力机制:让模型学会"聚焦"

核心思想

人类阅读时会关注重点,神经网络也可以:

句子:我 喜欢 在 咖啡馆 里 看 书
权重:0.1  0.3  0.1  0.4   0.05 0.3  0.2

在理解"看书"这个词时,"咖啡馆"获得较高权重(语境信息)

注意力计算过程

Query (查询): 当前关注的目标
Key (键): 每个位置的特征
Value (值): 每个位置的信息

注意力 = softmax(Q · Kᵀ / √d) · V
import torch
import torch.nn.functional as F

def attention(query, key, value, mask=None):
    """
    query: (batch, seq_len_q, d_k)
    key: (batch, seq_len_k, d_k)
    value: (batch, seq_len_v, d_v)
    """
    d_k = query.size(-1)

    # 计算注意力分数
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

    # 应用mask(可选)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # softmax归一化
    attention_weights = F.softmax(scores, dim=-1)

    # 加权求和
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

# 示例
batch_size = 2
seq_len = 5
d_model = 8

Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output, weights = attention(Q, K, V)
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")
print(f"权重和: {weights[0, 0].sum()}")  # 应该约等于1

五、Transformer:抛弃循环,拥抱注意力

革命性创新

2017年,Google发表论文《Attention Is All You Need》,提出Transformer架构:

核心改变

  • ❌ 不再使用循环结构
  • ✅ 完全基于注意力机制
  • ✅ 支持并行计算

Transformer架构

输入嵌入
    ↓
位置编码 ← 为序列添加位置信息
    ↓
┌─────────────────────┐
│   编码器(Nx层)     │
│  ┌───────────────┐  │
│  │ 多头自注意力   │  │
│  └───────┬───────┘  │
│          ↓          │
│  ┌───────────────┐  │
│  │ 前馈神经网络   │  │
│  └───────────────┘  │
└─────────────────────┘
    ↓
┌─────────────────────┐
│   解码器(Nx层)     │
│  ┌───────────────┐  │
│  │ 掩码多头注意力 │  │
│  └───────┬───────┘  │
│          ↓          │
│  ┌───────────────┐  │
│  │ 编码器-解码器  │  │
│  │   注意力      │  │
│  └───────┬───────┘  │
│          ↓          │
│  ┌───────────────┐  │
│  │ 前馈神经网络   │  │
│  └───────────────┘  │
└─────────────────────┘
    ↓
  输出层

关键组件详解

1. 位置编码

由于Transformer没有循环结构,需要显式注入位置信息:

def positional_encoding(seq_len, d_model):
    """
    使用正弦和余弦函数生成位置编码
    """
    position = torch.arange(seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))

    pe = torch.zeros(1, seq_len, d_model)
    pe[0, :, 0::2] = torch.sin(position * div_term)  # 偶数维度
    pe[0, :, 1::2] = torch.cos(position * div_term)  # 奇数维度

    return pe

# 可视化位置编码
pe = positional_encoding(50, 128)
print(f"位置编码形状: {pe.shape}")
2. 多头注意力

将注意力分成多个"头",分别关注不同的信息:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # 线性变换
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # 分成多个头
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 注意力计算
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # 合并多头
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.num_heads * self.d_k)

        # 输出投影
        output = self.W_o(attn_output)

        return output, attn_weights
3. 前馈网络
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

完整Transformer编码器层

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # 自注意力 + 残差连接 + LayerNorm
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))

        # 前馈网络 + 残差连接 + LayerNorm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))

        return x

六、Transformer vs RNN 对比

特性RNN/LSTMTransformer
计算方式串行并行
长距离依赖较弱很强
训练速度
参数量较少较多
可解释性较弱较强(注意力可视化)
内存占用高(O(n²))

七、Transformer的衍生模型

                    Transformer (2017)
                          │
            ┌─────────────┼─────────────┐
            ↓             ↓             ↓
         BERT          GPT           T5
       (双向编码)     (单向生成)    (编码-解码)
            │             │             │
            ↓             ↓             ↓
      RoBERTa,      GPT-2,3,4     T5-base
      ALBERT...     ChatGPT...    large...

小结

模型核心创新解决的问题
RNN循环结构序列建模的基础
LSTM门控机制长期记忆、梯度消失
GRU简化门控参数效率
Attention动态权重信息聚焦
Transformer自注意力并行计算、长距离依赖

思考与练习

  1. 思考题

    • 为什么Transformer需要位置编码?
    • 多头注意力比单头有什么优势?
  2. 动手练习

    • 用PyTorch实现一个简单的Transformer编码器
    • 尝试训练一个简单的序列分类任务
  3. 延伸阅读


下期预告

下一篇文章,我们将深入探讨:词向量到嵌入:让机器理解语言的奥秘

会解答这些问题:

  • Word2Vec是如何学习词义的?
  • 词向量为什么能捕捉语义关系?
  • 现代嵌入技术有哪些新发展?

关注专栏,不错过后续更新!


作者:ECH00O00 本文首发于掘金专栏《AI科普实验室》 欢迎评论区交流讨论,点赞收藏就是最大的鼓励