从RNN到Transformer:序列建模的演进之路
理解序列建模技术的发展脉络,掌握Transformer这一革命性架构的核心原理。
前言
在自然语言处理、语音识别、时间序列预测等领域,数据往往是序列形式的——数据的顺序很重要,前后之间存在依赖关系。
如何让神经网络理解这种"顺序"和"依赖"?这就是序列建模要解决的问题。今天,我们来梳理从RNN到Transformer的技术演进之路。
一、什么是序列数据?
序列数据的特点
序列数据中,元素的顺序至关重要:
| 数据类型 | 示例 | 特点 |
|---|---|---|
| 文本 | "我喜欢学习AI" | 词序决定含义 |
| 语音 | 声音波形 | 时间顺序重要 |
| 股票价格 | 时间序列 | 历史影响未来 |
| DNA序列 | ATCG... | 顺序编码信息 |
序列建模的核心问题
输入序列:x₁, x₂, x₃, ..., xₙ
输出序列:y₁, y₂, y₃, ..., yₘ
核心挑战:
1. 输入输出长度可能不同
2. 需要捕捉长距离依赖
3. 序列中位置信息很重要
二、RNN:循环神经网络
核心思想
RNN通过循环连接来保持对之前信息的"记忆":
时刻 t-1 时刻 t 时刻 t+1
hₜ₋₁ ────→ hₜ ────→ hₜ₊₁
↓ ↓ ↓
输出 输出 输出
↑ ↑ ↑
输入 输入 输入
xₜ₋₁ xₜ xₜ₊₁
数学表达
隐藏状态:hₜ = tanh(Wₕₕ · hₜ₋₁ + Wₓₕ · xₜ + bₕ)
输出:yₜ = Wₕᵧ · hₜ + bᵧ
代码实现
import numpy as np
class SimpleRNN:
def __init__(self, input_size, hidden_size, output_size):
# 初始化权重
self.W_xh = np.random.randn(input_size, hidden_size) * 0.01
self.W_hh = np.random.randn(hidden_size, hidden_size) * 0.01
self.W_hy = np.random.randn(hidden_size, output_size) * 0.01
self.b_h = np.zeros((1, hidden_size))
self.b_y = np.zeros((1, output_size))
def forward(self, x_sequence):
"""
x_sequence: shape (seq_len, input_size)
"""
h = np.zeros((1, self.W_hh.shape[0])) # 初始隐藏状态
outputs = []
for x in x_sequence:
x = x.reshape(1, -1)
# 计算新的隐藏状态
h = np.tanh(np.dot(x, self.W_xh) + np.dot(h, self.W_hh) + self.b_h)
# 计算输出
y = np.dot(h, self.W_hy) + self.b_y
outputs.append(y)
return outputs, h
# 使用示例
rnn = SimpleRNN(input_size=3, hidden_size=4, output_size=2)
sequence = np.random.randn(5, 3) # 序列长度5,每个元素3维
outputs, final_h = rnn.forward(sequence)
print(f"输出序列长度: {len(outputs)}")
print(f"最终隐藏状态: {final_h}")
RNN的问题:梯度消失
随着序列变长,RNN很难记住早期的信息:
序列:我 昨天 在 公园 里 看 到 一只 很 可爱 的 小猫
↑ ↑
信息源 预测点
距离太远,信息在传递过程中逐渐"消散"
数学原因:
梯度在反向传播时需要连乘:
∂hₜ/∂h₁ = ∏ᵢ W_hh · tanh'(hᵢ)
如果 |W_hh · tanh'(hᵢ)| < 1,连乘后趋近于0(梯度消失)
如果 |W_hh · tanh'(hᵢ)| > 1,连乘后趋近于∞(梯度爆炸)
三、LSTM:长短期记忆网络
核心创新
LSTM通过门控机制来控制信息的遗忘和保留:
┌─────────────────────────────────────┐
│ LSTM 单元 │
│ │
│ ┌─────┐ │
│ │遗忘门│ ← 决定丢弃哪些信息 │
│ └──┬──┘ │
│ ↓ │
│ ┌─────┐ │
│ │输入门│ ← 决定存储哪些新信息 │
│ └──┬──┘ │
│ ↓ │
│ ┌─────┐ │
│ │输出门│ ← 决定输出哪些信息 │
│ └──┬──┘ │
│ ↓ │
│ hₜ → 输出 │
└─────────────────────────────────────┘
三个门的作用
def lstm_cell(x_t, h_prev, c_prev, params):
"""
x_t: 当前输入
h_prev: 上一个隐藏状态
c_prev: 上一个单元状态(长期记忆)
"""
# 合并输入
combined = np.concatenate([h_prev, x_t], axis=1)
# 遗忘门:决定从长期记忆中丢弃什么
f_t = sigmoid(np.dot(combined, params['W_f']) + params['b_f'])
# 输入门:决定往长期记忆中添加什么
i_t = sigmoid(np.dot(combined, params['W_i']) + params['b_i'])
# 候选记忆
c_tilde = np.tanh(np.dot(combined, params['W_c']) + params['b_c'])
# 更新长期记忆
c_t = f_t * c_prev + i_t * c_tilde
# 输出门:决定输出什么
o_t = sigmoid(np.dot(combined, params['W_o']) + params['b_o'])
# 更新隐藏状态
h_t = o_t * np.tanh(c_t)
return h_t, c_t
LSTM vs RNN
| 特性 | RNN | LSTM |
|---|---|---|
| 记忆能力 | 短期记忆 | 长期+短期记忆 |
| 梯度问题 | 严重梯度消失 | 有效缓解 |
| 参数量 | 较少 | 较多(约4倍) |
| 训练速度 | 较快 | 较慢 |
| 长序列表现 | 差 | 好 |
GRU:LSTM的简化版
GRU将遗忘门和输入门合并,参数更少:
def gru_cell(x_t, h_prev, params):
# 重置门
r_t = sigmoid(np.dot(np.concatenate([h_prev, x_t], axis=1), params['W_r']) + params['b_r'])
# 更新门
z_t = sigmoid(np.dot(np.concatenate([h_prev, x_t], axis=1), params['W_z']) + params['b_z'])
# 候选隐藏状态
h_tilde = np.tanh(np.dot(np.concatenate([r_t * h_prev, x_t], axis=1), params['W_h']) + params['b_h'])
# 新隐藏状态
h_t = (1 - z_t) * h_prev + z_t * h_tilde
return h_t
四、注意力机制:让模型学会"聚焦"
核心思想
人类阅读时会关注重点,神经网络也可以:
句子:我 喜欢 在 咖啡馆 里 看 书
权重:0.1 0.3 0.1 0.4 0.05 0.3 0.2
在理解"看书"这个词时,"咖啡馆"获得较高权重(语境信息)
注意力计算过程
Query (查询): 当前关注的目标
Key (键): 每个位置的特征
Value (值): 每个位置的信息
注意力 = softmax(Q · Kᵀ / √d) · V
import torch
import torch.nn.functional as F
def attention(query, key, value, mask=None):
"""
query: (batch, seq_len_q, d_k)
key: (batch, seq_len_k, d_k)
value: (batch, seq_len_v, d_v)
"""
d_k = query.size(-1)
# 计算注意力分数
scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
# 应用mask(可选)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# softmax归一化
attention_weights = F.softmax(scores, dim=-1)
# 加权求和
output = torch.matmul(attention_weights, value)
return output, attention_weights
# 示例
batch_size = 2
seq_len = 5
d_model = 8
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)
output, weights = attention(Q, K, V)
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")
print(f"权重和: {weights[0, 0].sum()}") # 应该约等于1
五、Transformer:抛弃循环,拥抱注意力
革命性创新
2017年,Google发表论文《Attention Is All You Need》,提出Transformer架构:
核心改变:
- ❌ 不再使用循环结构
- ✅ 完全基于注意力机制
- ✅ 支持并行计算
Transformer架构
输入嵌入
↓
位置编码 ← 为序列添加位置信息
↓
┌─────────────────────┐
│ 编码器(Nx层) │
│ ┌───────────────┐ │
│ │ 多头自注意力 │ │
│ └───────┬───────┘ │
│ ↓ │
│ ┌───────────────┐ │
│ │ 前馈神经网络 │ │
│ └───────────────┘ │
└─────────────────────┘
↓
┌─────────────────────┐
│ 解码器(Nx层) │
│ ┌───────────────┐ │
│ │ 掩码多头注意力 │ │
│ └───────┬───────┘ │
│ ↓ │
│ ┌───────────────┐ │
│ │ 编码器-解码器 │ │
│ │ 注意力 │ │
│ └───────┬───────┘ │
│ ↓ │
│ ┌───────────────┐ │
│ │ 前馈神经网络 │ │
│ └───────────────┘ │
└─────────────────────┘
↓
输出层
关键组件详解
1. 位置编码
由于Transformer没有循环结构,需要显式注入位置信息:
def positional_encoding(seq_len, d_model):
"""
使用正弦和余弦函数生成位置编码
"""
position = torch.arange(seq_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
pe = torch.zeros(1, seq_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term) # 偶数维度
pe[0, :, 1::2] = torch.cos(position * div_term) # 奇数维度
return pe
# 可视化位置编码
pe = positional_encoding(50, 128)
print(f"位置编码形状: {pe.shape}")
2. 多头注意力
将注意力分成多个"头",分别关注不同的信息:
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0
self.d_k = d_model // num_heads
self.num_heads = num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 线性变换
Q = self.W_q(query)
K = self.W_k(key)
V = self.W_v(value)
# 分成多个头
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 注意力计算
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 合并多头
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.num_heads * self.d_k)
# 输出投影
output = self.W_o(attn_output)
return output, attn_weights
3. 前馈网络
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.linear2(self.dropout(F.relu(self.linear1(x))))
完整Transformer编码器层
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 自注意力 + 残差连接 + LayerNorm
attn_output, _ = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# 前馈网络 + 残差连接 + LayerNorm
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x
六、Transformer vs RNN 对比
| 特性 | RNN/LSTM | Transformer |
|---|---|---|
| 计算方式 | 串行 | 并行 |
| 长距离依赖 | 较弱 | 很强 |
| 训练速度 | 慢 | 快 |
| 参数量 | 较少 | 较多 |
| 可解释性 | 较弱 | 较强(注意力可视化) |
| 内存占用 | 低 | 高(O(n²)) |
七、Transformer的衍生模型
Transformer (2017)
│
┌─────────────┼─────────────┐
↓ ↓ ↓
BERT GPT T5
(双向编码) (单向生成) (编码-解码)
│ │ │
↓ ↓ ↓
RoBERTa, GPT-2,3,4 T5-base
ALBERT... ChatGPT... large...
小结
| 模型 | 核心创新 | 解决的问题 |
|---|---|---|
| RNN | 循环结构 | 序列建模的基础 |
| LSTM | 门控机制 | 长期记忆、梯度消失 |
| GRU | 简化门控 | 参数效率 |
| Attention | 动态权重 | 信息聚焦 |
| Transformer | 自注意力 | 并行计算、长距离依赖 |
思考与练习
-
思考题:
- 为什么Transformer需要位置编码?
- 多头注意力比单头有什么优势?
-
动手练习:
- 用PyTorch实现一个简单的Transformer编码器
- 尝试训练一个简单的序列分类任务
-
延伸阅读:
下期预告
下一篇文章,我们将深入探讨:词向量到嵌入:让机器理解语言的奥秘
会解答这些问题:
- Word2Vec是如何学习词义的?
- 词向量为什么能捕捉语义关系?
- 现代嵌入技术有哪些新发展?
关注专栏,不错过后续更新!
作者:ECH00O00 本文首发于掘金专栏《AI科普实验室》 欢迎评论区交流讨论,点赞收藏就是最大的鼓励