1. 引言
在之前的学习中,我们提到了循环神经网络(RNN)训练时可能遇到的 梯度爆炸 和 梯度消失 问题,并且在计算梯度时需要 分离梯度。但这些概念背后的数学原理是什么?本节将深入探讨 通过时间反向传播(Backpropagation Through Time,BPTT)算法,解释循环神经网络梯度的计算方式。
2. 循环神经网络的梯度计算
2.1 计算图展开
BPTT 是反向传播在 RNN 中的特定应用。它的核心思想是 将计算图在时间维度上展开,然后应用链式法则计算梯度。
假设时间步 t 的隐藏状态为 ht,输入为 xt,输出为 ot,隐藏层和输出层的权重分别为 Wh 和 Wo,则它们的计算方式为:
ht=f(Whht−1+Wxxt)
ot=g(Woht)
其中 f(⋅) 是激活函数,g(⋅) 是输出层的映射函数。
损失函数 L 对参数 W 的梯度计算如下:
∂W∂L=t=1∑T∂ot∂L⋅∂ht∂ot⋅∂W∂ht
在反向传播时,我们需要递归计算 ∂ht−1∂ht,这导致 梯度可能随着时间步增长而指数级衰减或增长,从而引发梯度消失或爆炸。
2.2 梯度爆炸与梯度消失
假设我们展开梯度的递归表达式(推导过程省略):
∂ht−1∂ht=Wh⋅∂ht−2∂ht−1
展开 T 步后,梯度变为:
∂h0∂hT=WhT
如果 Wh 的特征值大于 1,则 WhT 会指数级增长,导致梯度爆炸;如果特征值小于 1,则梯度会逐渐趋近于 0,导致梯度消失。
3. 解决梯度问题的方法
3.1 完全计算 BPTT
一种直接的方法是计算整个序列的梯度,但这在长序列上计算量过大,且容易发生梯度爆炸或消失。因此,通常不使用这种方式。
3.2 截断 BPTT
为了减少计算量,我们可以 在一定时间步后截断梯度计算,即只反向传播最近 k 个时间步的梯度,这种方法称为 截断的 BPTT(Truncated BPTT)。
3.3 随机截断
另一种方法是 随机决定何时截断梯度,以减少偏差。但实践表明,它并不比固定长度的截断方法效果更好,因此使用较少。
4. BPTT 计算过程示例
假设我们有一个 3 层的 RNN,在时间步 t,输入 xt,输出 ot,损失为 L。我们想计算权重 Wh 的梯度。
- 前向传播:
ht=f(Whht−1+Wxxt)
ot=g(Woht)
- 计算损失的梯度:
∂ot∂L=损失函数对输出的梯度
- 反向传播:
∂ht∂L=∂ot∂L⋅∂ht∂ot
- 梯度递归传播:
∂ht−1∂L=∂ht∂L⋅Wh
5. 结论
- BPTT 是反向传播在 RNN 中的应用,需要沿时间维度展开计算梯度。
- 梯度消失和梯度爆炸是 RNN 训练的主要挑战,可通过梯度截断或特殊结构(如 LSTM)缓解。
- 截断 BPTT 是常用的方法,可以在保证模型效果的同时减少计算量。
通过这些方法,我们可以更稳定地训练 RNN,使其能够更好地学习序列数据中的长期依赖关系。