RNN 梯度计算详细推导 (BPTT)
为了详细推导循环神经网络(RNN)中的梯度计算方法——沿时间反向传播(Backpropagation Through Time, BPTT),我们将使用一个最基础的RNN模型结构。
第一步:定义RNN模型和符号
在一个时间步 t,RNN的计算过程如下:
-
隐藏状态 (Hidden State):
ht=f(Uxt+Wht−1+b)
- xt:在时间步
t 的输入向量。
- ht:在时间步
t 的隐藏状态向量。h0 通常初始化为零向量。
- ht−1:前一个时间步的隐藏状态。
- U,W,b:循环层的参数。U 是输入到隐藏层的权重矩阵,W 是隐藏层到隐藏层的权重矩阵(循环权重),b 是偏置向量。
- f:激活函数,通常是
tanh 或 ReLU。这里我们假设为 tanh。
-
输出 (Output):
ot=Vht+c
- ot:在时间步
t 的输出(或称为 logits)。
- V,c:输出层的参数。V 是隐藏层到输出层的权重矩阵,c 是偏置向量。
-
预测概率 (Predicted Probability):
y^t=g(ot)
- g:输出激活函数,对于分类任务通常是
Softmax。
-
损失函数 (Loss Function):
Lt=Loss(y^t,yt)
- Lt:在时间步
t 的损失,例如交叉熵损失。
- yt:在时间步
t 的真实标签。
总体目标
我们的目标是计算总损失 L=∑t=1TLt 对所有模型参数 θ={U,W,V,b,c} 的梯度。即求解:∂V∂L,∂c∂L,∂W∂L,∂U∂L,∂b∂L。
第二步:前向传播
模型按照 t=1,2,…,T 的顺序,依次计算出每个时间步的 ht,ot,y^t,Lt,并最终得到总损失 L。这个过程比较直接,就是将输入序列喂给模型,得到输出和损失。
第三步:反向传播 (BPTT)
梯度是反向计算的,从最后一个时间步 T 开始,一直传播到第一个时间步 1。
A. 输出层参数的梯度 (∂V∂L,∂c∂L)
这部分比较简单,因为 V 和 c 的计算不涉及时间上的循环依赖。总损失对它们的梯度是每个时间步梯度贡献的总和。
∂V∂L=t=1∑T∂V∂Lt
我们来看单个时间步 t 的梯度 ∂V∂Lt。根据链式法则:
∂V∂Lt=∂ot∂Lt∂V∂ot
其中:
- ∂ot∂Lt 是损失对输出 logits 的梯度。我们将其记为 δo,t。例如,对于Softmax和交叉熵损失,δo,t=y^t−yt。
- ∂V∂ot:因为 ot=Vht+c,所以 ∂V∂ot=htT(转置是为了维度匹配)。
所以,
∂V∂Lt=δo,t⋅htT
最终,
∂V∂L=t=1∑Tδo,t⋅htT
同理,对于偏置 c:
∂c∂L=t=1∑Tδo,t
B. 循环层参数的梯度 (∂W∂L,∂U∂L,∂b∂L)
这是BPTT的核心和难点。我们以 ∂W∂L 为例进行推导。
为了解决 W 在时间上的复杂依赖,我们引入一个关键的中间量:总损失 L 对隐藏状态 ht 的梯度,记为 δh,t=∂ht∂L。
根据链式法则,L 通过两条路径影响 ht:
- 通过当前时间步的输出 ot。
- 通过下一个时间步的隐藏状态 ht+1(因为 ht+1 的计算用到了 ht)。
因此,δh,t 的计算是一个从后向前的递归过程:
δh,t=∂ht∂L=∂ot∂L∂ht∂ot+∂ht+1∂L∂ht∂ht+1
将各部分代入,我们得到 δh,t 的递归公式:
δh,t=δo,tVT+δh,t+1WTdiag(f′(ht+1))
- 递归的起点(Base Case):在最后一个时间步 T,没有未来的隐藏状态,所以递归项为0。
δh,T=∂hT∂L=∂oT∂LT∂hT∂oT=δo,TVT
我们可以从 δh,T 开始,反向计算出 δh,T−1,…,δh,1。
现在,我们用 δh,t 来计算最终的梯度 ∂W∂L。
∂W∂L=t=1∑T∂ht∂L∂W∂ht=t=1∑Tδh,t∂W∂ht
根据 ht=f(Uxt+Wht−1+b),我们有 ∂W∂ht=diag(f′(ht))⋅ht−1T。
所以,
∂W∂L=t=1∑Tdiag(f′(ht))⋅δh,t⋅ht−1T
同理可得 ∂U∂L 和 ∂b∂L:
∂U∂L=t=1∑Tdiag(f′(ht))⋅δh,t⋅xtT
∂b∂L=t=1∑Tdiag(f′(ht))⋅δh,t
第四步:梯度消失与爆炸的根源
回顾 δh,t 的递归公式:
δh,t=⋯+δh,t+1⋅(WTdiag(f′(ht+1)))
我们可以看到,梯度在时间上传播时,会反复乘以循环权重矩阵 W。
- 梯度爆炸 (Gradient Exploding):如果 W 的某些特征值(或范数)大于1,经过多次连乘后,梯度会呈指数级增长,导致数值溢出,训练发散。
- 梯度消失 (Gradient Vanishing):如果 W 的某些特征值(或范数)小于1,经过多次连乘后,梯度会呈指数级衰减,趋近于0。这使得模型难以学习到长距离的依赖关系。
总结
BPTT算法的完整流程如下:
- 前向传播:对于 t=1,…,T,计算 ht,ot,Lt,得到总损失 L。同时保存所有的 xt,ht。
- 反向传播:
a. 计算最后一个时间步的隐藏层梯度 δh,T。
b. 对于 t=T−1,…,1,使用递归公式反向计算 δh,t。
- 计算最终梯度:根据所有时间步的中间值,使用求和公式计算出所有参数的梯度。
- 参数更新:使用计算出的梯度,通过梯度下降等优化算法更新模型参数。