现代循环神经网络2-长短期记忆网络(LSTM)

213 阅读5分钟

一、为什么需要LSTM?

想象你要记住一个重要的电话号码。普通循环神经网络(RNN)就像一个容易分心的人:当新信息不断输入时,旧的号码很快会被遗忘。这种现象称为长期依赖问题。长短期记忆网络(LSTM)的设计灵感来自带便签本的聪明人:它可以选择性地记录重要信息,还能随时擦除无用内容

二、LSTM的核心组件:记忆本与三道门

2.1 记忆本的结构

每个LSTM单元都携带两个关键信息:

  • 隐状态HtH_t):对外传递的短期记忆,类似「当前要说的话」
  • 记忆元CtC_t):内部保存的长期记忆,类似「随身携带的笔记本」

2.2 控制记忆的三道门

LSTM 的核心在于对记忆信息的“读写控制”,这一机制借鉴了计算机逻辑门的思想。主要有三个门:

lstm-0.svg

  • XtX_t 表示当前时间步的输入;
  • Ht1H_{t-1} 表示前一时刻的隐状态;
  • σ\sigma 是 sigmoid 激活函数;
  • WxiW_{xi}WhiW_{hi}bib_i 等均为各门的权重和偏置参数。

(1)输入门:决定写什么

决定当前输入中有多少信息需要写入记忆元。

公式:It=σ(XtWxi+Ht1Whi+bi)\boxed{I_t = \sigma(X_t W_{xi} + H_{t-1} W_{hi} + b_i)}

示例:当输入是重要名字时,输入门会完全打开(值接近1),确保信息被记录。

(2)遗忘门:决定擦除什么

控制保留多少来自过去记忆元的信息。

公式:Ft=σ(XtWxf+Ht1Whf+bf)\boxed{F_t = \sigma(X_t W_{xf} + H_{t-1} W_{hf} + b_f)}

示例:遇到「但是」等转折词时,遗忘门可能关闭(值接近0),清空之前的状态。

(3)输出门:决定读什么

决定记忆元中有多少信息通过处理后参与到最终输出的隐状态中。

公式:Ot=σ(XtWxo+Ht1Who+bo)\boxed{O_t = \sigma(X_t W_{xo} + H_{t-1} W_{ho} + b_o)}

2.3 记忆更新过程

(1)候选记忆内容

除了上述三个门控之外,LSTM 还引入了一个候选记忆元,用于生成可供更新记忆元的新信息。候选记忆元的计算与门控类似,但采用了 tanh\tanh 激活函数,其值域在 (1,1)(-1,1) 内:

C~t=tanh(XtWxc+Ht1Whc+bc)\boxed{\tilde{C}_t = \text{tanh}(X_t W_{xc} + H_{t-1} W_{hc} + b_c)}

(2)更新记忆本

记忆元的更新由遗忘门和输入门共同控制,其更新公式为:

Ct=FtCt1+ItC~t\boxed{C_t = F_t \odot C_{t-1} + I_t \odot \tilde{C}_t}

其中

  • Ct1C_{t-1} 为上一时刻的记忆元;
  • \odot 表示按元素乘法(Hadamard 乘积)。

这意味着当遗忘门输出接近 11 且输入门输出接近 00 时,过去的记忆会被大部分保留;反之,当输入门输出较高时,新信息会更多地写入记忆元。

(3)生成新隐状态

最终,LSTM 利用经过门控调制后的记忆元来计算当前的隐状态 HtH_t。隐状态不仅作为下一时刻计算的输入,还会参与最终的预测输出,其计算公式为:

Ht=Ottanh(Ct)\boxed{H_t = O_t \odot \text{tanh}(C_t)}

这样设计既保证了隐状态的数值稳定性(值域为 (1,1)(-1,1)),又确保输出层能够根据输出门的控制,灵活地选择传递多少记忆信息。

三、动手实现LSTM

3.1 初始化参数(PyTorch示例)

import torch
from torch import nn

import d2l

batch_size, num_steps = 32, 25
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)


def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.rand(shape, device=device) * 0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数

    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)

    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

3.2 前向传播过程

def init_lstm_state(batch_size, num_hiddens, device):
    """初始化长短期记忆网络的隐状态"""
    return (torch.zeros(size=(batch_size, num_hiddens), device=device),
            torch.zeros(size=(batch_size, num_hiddens), device=device))


def lstm(inputs, state, params):
    W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = params
    H, C = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid(X @ W_xi + H @ W_hi + b_i)  # 输入门
        F = torch.sigmoid(X @ W_xf + H @ W_hf + b_f)  # 遗忘门
        O = torch.sigmoid(X @ W_xo + H @ W_ho + b_o)  # 输出门
        C_tilda = torch.tanh(X @ W_xc + H @ W_hc + b_c)  # 候选记忆元
        C = F * C + I * C_tilda  # 记忆元
        H = O * torch.tanh(C)  # 隐状态
        Y = H @ W_hq + b_q  # 输出
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)


vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(vocab_size, num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

屏幕截图 2025-03-07 153604.png

3.3 使用高级API快速搭建

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, vocab_size)
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

屏幕截图 2025-03-07 154035.png

四、总结

本文详细介绍了长短期记忆网络(LSTM)的基本机制及其实现方法,主要内容包括:

  • 门控记忆元:通过输入门、遗忘门和输出门,LSTM 控制了信息的写入、保留和输出。 例如,输入门的计算公式为:It=σ(XtWxi+Ht1Whi+bi)\boxed{I_t = \sigma\bigl(X_t W_{xi} + H_{t-1} W_{hi} + b_i\bigr)}
  • 候选记忆元与记忆元更新:通过生成候选记忆元 C~t\tilde{C}_t 并结合遗忘门和输入门,LSTM 实现了对过去和新信息的平衡更新。更新公式为:Ct=FtCt1+ItC~t\boxed{C_t = F_t \odot C_{t-1} + I_t \odot \tilde{C}_t}
  • 隐状态的计算:依靠输出门对记忆元的调控,隐状态 HtH_t 的计算公式为:Ht=Ottanh(Ct)\boxed{H_t = O_t \odot \tanh(C_t)}
  • 模型实现:不仅展示了从零实现 LSTM 的详细过程,还可以利用 PyTorch 内置的高级 API 轻松构建 LSTM 模型,以便在实际项目中快速部署和训练。

长短期记忆网络作为解决长距离依赖问题的重要模型,其思想在自然语言处理、时间序列分析等领域都有广泛应用。希望本篇博客能够帮助你轻松理解 LSTM 的核心原理,并激发你进一步探索深度学习技术的兴趣!