一、为什么需要LSTM?
想象你要记住一个重要的电话号码。普通循环神经网络(RNN)就像一个容易分心的人:当新信息不断输入时,旧的号码很快会被遗忘。这种现象称为长期依赖问题。长短期记忆网络(LSTM)的设计灵感来自带便签本的聪明人:它可以选择性地记录重要信息,还能随时擦除无用内容。
二、LSTM的核心组件:记忆本与三道门
2.1 记忆本的结构
每个LSTM单元都携带两个关键信息:
- 隐状态():对外传递的短期记忆,类似「当前要说的话」
- 记忆元():内部保存的长期记忆,类似「随身携带的笔记本」
2.2 控制记忆的三道门
LSTM 的核心在于对记忆信息的“读写控制”,这一机制借鉴了计算机逻辑门的思想。主要有三个门:
注:
- 表示当前时间步的输入;
- 表示前一时刻的隐状态;
- 是 sigmoid 激活函数;
- 、、 等均为各门的权重和偏置参数。
(1)输入门:决定写什么
决定当前输入中有多少信息需要写入记忆元。
公式:
示例:当输入是重要名字时,输入门会完全打开(值接近1),确保信息被记录。
(2)遗忘门:决定擦除什么
控制保留多少来自过去记忆元的信息。
公式:
示例:遇到「但是」等转折词时,遗忘门可能关闭(值接近0),清空之前的状态。
(3)输出门:决定读什么
决定记忆元中有多少信息通过处理后参与到最终输出的隐状态中。
公式:
2.3 记忆更新过程
(1)候选记忆内容
除了上述三个门控之外,LSTM 还引入了一个候选记忆元,用于生成可供更新记忆元的新信息。候选记忆元的计算与门控类似,但采用了 激活函数,其值域在 内:
(2)更新记忆本
记忆元的更新由遗忘门和输入门共同控制,其更新公式为:
其中:
- 为上一时刻的记忆元;
- 表示按元素乘法(Hadamard 乘积)。
这意味着当遗忘门输出接近 且输入门输出接近 时,过去的记忆会被大部分保留;反之,当输入门输出较高时,新信息会更多地写入记忆元。
(3)生成新隐状态
最终,LSTM 利用经过门控调制后的记忆元来计算当前的隐状态 。隐状态不仅作为下一时刻计算的输入,还会参与最终的预测输出,其计算公式为:
这样设计既保证了隐状态的数值稳定性(值域为 ),又确保输出层能够根据输出门的控制,灵活地选择传递多少记忆信息。
三、动手实现LSTM
3.1 初始化参数(PyTorch示例)
import torch
from torch import nn
import d2l
batch_size, num_steps = 32, 25
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
def get_lstm_params(vocab_size, num_hiddens, device):
num_inputs = num_outputs = vocab_size
def normal(shape):
return torch.rand(shape, device=device) * 0.01
def three():
return (normal((num_inputs, num_hiddens)),
normal((num_hiddens, num_hiddens)),
torch.zeros(num_hiddens, device=device))
W_xi, W_hi, b_i = three() # 输入门参数
W_xf, W_hf, b_f = three() # 遗忘门参数
W_xo, W_ho, b_o = three() # 输出门参数
W_xc, W_hc, b_c = three() # 候选记忆元参数
# 输出层参数
W_hq = normal((num_hiddens, num_outputs))
b_q = torch.zeros(num_outputs, device=device)
# 附加梯度
params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
b_c, W_hq, b_q]
for param in params:
param.requires_grad_(True)
return params
3.2 前向传播过程
def init_lstm_state(batch_size, num_hiddens, device):
"""初始化长短期记忆网络的隐状态"""
return (torch.zeros(size=(batch_size, num_hiddens), device=device),
torch.zeros(size=(batch_size, num_hiddens), device=device))
def lstm(inputs, state, params):
W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = params
H, C = state
outputs = []
for X in inputs:
I = torch.sigmoid(X @ W_xi + H @ W_hi + b_i) # 输入门
F = torch.sigmoid(X @ W_xf + H @ W_hf + b_f) # 遗忘门
O = torch.sigmoid(X @ W_xo + H @ W_ho + b_o) # 输出门
C_tilda = torch.tanh(X @ W_xc + H @ W_hc + b_c) # 候选记忆元
C = F * C + I * C_tilda # 记忆元
H = O * torch.tanh(C) # 隐状态
Y = H @ W_hq + b_q # 输出
outputs.append(Y)
return torch.cat(outputs, dim=0), (H, C)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(vocab_size, num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
3.3 使用高级API快速搭建
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, vocab_size)
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
四、总结
本文详细介绍了长短期记忆网络(LSTM)的基本机制及其实现方法,主要内容包括:
- 门控记忆元:通过输入门、遗忘门和输出门,LSTM 控制了信息的写入、保留和输出。 例如,输入门的计算公式为:
- 候选记忆元与记忆元更新:通过生成候选记忆元 并结合遗忘门和输入门,LSTM 实现了对过去和新信息的平衡更新。更新公式为:
- 隐状态的计算:依靠输出门对记忆元的调控,隐状态 的计算公式为:
- 模型实现:不仅展示了从零实现 LSTM 的详细过程,还可以利用 PyTorch 内置的高级 API 轻松构建 LSTM 模型,以便在实际项目中快速部署和训练。
长短期记忆网络作为解决长距离依赖问题的重要模型,其思想在自然语言处理、时间序列分析等领域都有广泛应用。希望本篇博客能够帮助你轻松理解 LSTM 的核心原理,并激发你进一步探索深度学习技术的兴趣!