循环神经网络7-通过时间反向传播理解序列模型训练细节

57 阅读2分钟

1. 引言

在之前的学习中,我们提到了循环神经网络(RNN)训练时可能遇到的 梯度爆炸梯度消失 问题,并且在计算梯度时需要 分离梯度。但这些概念背后的数学原理是什么?本节将深入探讨 通过时间反向传播(Backpropagation Through Time,BPTT)算法,解释循环神经网络梯度的计算方式。

2. 循环神经网络的梯度计算

2.1 计算图展开

BPTT 是反向传播在 RNN 中的特定应用。它的核心思想是 将计算图在时间维度上展开,然后应用链式法则计算梯度。

假设时间步 tt 的隐藏状态为 hth_t,输入为 xtx_t,输出为 oto_t,隐藏层和输出层的权重分别为 WhW_hWoW_o,则它们的计算方式为:

ht=f(Whht1+Wxxt)h_t = f(W_h h_{t-1} + W_x x_t)
ot=g(Woht)o_t = g(W_o h_t)

其中 f()f(\cdot) 是激活函数,g()g(\cdot) 是输出层的映射函数。

损失函数 LL 对参数 WW 的梯度计算如下:

LW=t=1TLotoththtW\frac{\partial L}{\partial W} = \sum_{t=1}^{T} \frac{\partial L}{\partial o_t} \cdot \frac{\partial o_t}{\partial h_t} \cdot \frac{\partial h_t}{\partial W}

在反向传播时,我们需要递归计算 htht1\frac{\partial h_t}{\partial h_{t-1}},这导致 梯度可能随着时间步增长而指数级衰减或增长,从而引发梯度消失或爆炸。

2.2 梯度爆炸与梯度消失

假设我们展开梯度的递归表达式(推导过程省略):

htht1=Whht1ht2\frac{\partial h_t}{\partial h_{t-1}} = W_h \cdot \frac{\partial h_{t-1}}{\partial h_{t-2}}

展开 TT 步后,梯度变为:

hTh0=WhT\frac{\partial h_T}{\partial h_0} = W_h^{T}

如果 WhW_h 的特征值大于 1,则 WhTW_h^T 会指数级增长,导致梯度爆炸;如果特征值小于 1,则梯度会逐渐趋近于 0,导致梯度消失。

3. 解决梯度问题的方法

3.1 完全计算 BPTT

一种直接的方法是计算整个序列的梯度,但这在长序列上计算量过大,且容易发生梯度爆炸或消失。因此,通常不使用这种方式。

3.2 截断 BPTT

为了减少计算量,我们可以 在一定时间步后截断梯度计算,即只反向传播最近 kk 个时间步的梯度,这种方法称为 截断的 BPTT(Truncated BPTT)。

3.3 随机截断

另一种方法是 随机决定何时截断梯度,以减少偏差。但实践表明,它并不比固定长度的截断方法效果更好,因此使用较少。

4. BPTT 计算过程示例

假设我们有一个 3 层的 RNN,在时间步 tt,输入 xtx_t,输出 oto_t,损失为 LL。我们想计算权重 WhW_h 的梯度。

  1. 前向传播
ht=f(Whht1+Wxxt)h_t = f(W_h h_{t-1} + W_x x_t)
ot=g(Woht)o_t = g(W_o h_t)
  1. 计算损失的梯度
Lot=损失函数对输出的梯度\frac{\partial L}{\partial o_t} = \text{损失函数对输出的梯度}
  1. 反向传播
Lht=Lototht\frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial o_t} \cdot \frac{\partial o_t}{\partial h_t}
  1. 梯度递归传播
Lht1=LhtWh\frac{\partial L}{\partial h_{t-1}} = \frac{\partial L}{\partial h_t} \cdot W_h

5. 结论

  • BPTT 是反向传播在 RNN 中的应用,需要沿时间维度展开计算梯度。
  • 梯度消失和梯度爆炸是 RNN 训练的主要挑战,可通过梯度截断或特殊结构(如 LSTM)缓解。
  • 截断 BPTT 是常用的方法,可以在保证模型效果的同时减少计算量。

通过这些方法,我们可以更稳定地训练 RNN,使其能够更好地学习序列数据中的长期依赖关系。