RNN 梯度计算详细推导 (BPTT)

274 阅读2分钟

RNN 梯度计算详细推导 (BPTT)

为了详细推导循环神经网络(RNN)中的梯度计算方法——沿时间反向传播(Backpropagation Through Time, BPTT),我们将使用一个最基础的RNN模型结构。

第一步:定义RNN模型和符号

在一个时间步 tt,RNN的计算过程如下:

  1. 隐藏状态 (Hidden State)

    ht=f(Uxt+Wht1+b)h_t = f(U x_t + W h_{t-1} + b)
    • xtx_t:在时间步 t 的输入向量。
    • hth_t:在时间步 t 的隐藏状态向量。h0h_0 通常初始化为零向量。
    • ht1h_{t-1}:前一个时间步的隐藏状态。
    • U,W,bU, W, b:循环层的参数。UU 是输入到隐藏层的权重矩阵,WW 是隐藏层到隐藏层的权重矩阵(循环权重),bb 是偏置向量。
    • ff:激活函数,通常是 tanhReLU。这里我们假设为 tanh
  2. 输出 (Output)

    ot=Vht+co_t = V h_t + c
    • oto_t:在时间步 t 的输出(或称为 logits)。
    • V,cV, c:输出层的参数。VV 是隐藏层到输出层的权重矩阵,cc 是偏置向量。
  3. 预测概率 (Predicted Probability)

    y^t=g(ot)\hat{y}_t = g(o_t)
    • gg:输出激活函数,对于分类任务通常是 Softmax
  4. 损失函数 (Loss Function)

    Lt=Loss(y^t,yt)L_t = \text{Loss}(\hat{y}_t, y_t)
    • LtL_t:在时间步 t 的损失,例如交叉熵损失。
    • yty_t:在时间步 t 的真实标签。

总体目标

我们的目标是计算总损失 L=t=1TLtL = \sum_{t=1}^{T} L_t 对所有模型参数 θ={U,W,V,b,c}\theta = \{U, W, V, b, c\} 的梯度。即求解:LV,Lc,LW,LU,Lb\frac{\partial L}{\partial V}, \frac{\partial L}{\partial c}, \frac{\partial L}{\partial W}, \frac{\partial L}{\partial U}, \frac{\partial L}{\partial b}


第二步:前向传播

模型按照 t=1,2,,Tt=1, 2, \dots, T 的顺序,依次计算出每个时间步的 ht,ot,y^t,Lth_t, o_t, \hat{y}_t, L_t,并最终得到总损失 LL。这个过程比较直接,就是将输入序列喂给模型,得到输出和损失。


第三步:反向传播 (BPTT)

梯度是反向计算的,从最后一个时间步 TT 开始,一直传播到第一个时间步 11

A. 输出层参数的梯度 (LV,Lc\frac{\partial L}{\partial V}, \frac{\partial L}{\partial c})

这部分比较简单,因为 VVcc 的计算不涉及时间上的循环依赖。总损失对它们的梯度是每个时间步梯度贡献的总和。

LV=t=1TLtV\frac{\partial L}{\partial V} = \sum_{t=1}^{T} \frac{\partial L_t}{\partial V}

我们来看单个时间步 tt 的梯度 LtV\frac{\partial L_t}{\partial V}。根据链式法则:

LtV=LtototV\frac{\partial L_t}{\partial V} = \frac{\partial L_t}{\partial o_t} \frac{\partial o_t}{\partial V}

其中:

  • Ltot\frac{\partial L_t}{\partial o_t} 是损失对输出 logits 的梯度。我们将其记为 δo,t\delta_{o,t}。例如,对于Softmax和交叉熵损失,δo,t=y^tyt\delta_{o,t} = \hat{y}_t - y_t
  • otV\frac{\partial o_t}{\partial V}:因为 ot=Vht+co_t = V h_t + c,所以 otV=htT\frac{\partial o_t}{\partial V} = h_t^T(转置是为了维度匹配)。

所以,

LtV=δo,thtT\frac{\partial L_t}{\partial V} = \delta_{o,t} \cdot h_t^T

最终,

LV=t=1Tδo,thtT\boxed{\frac{\partial L}{\partial V} = \sum_{t=1}^{T} \delta_{o,t} \cdot h_t^T}

同理,对于偏置 cc

Lc=t=1Tδo,t\boxed{\frac{\partial L}{\partial c} = \sum_{t=1}^{T} \delta_{o,t}}

B. 循环层参数的梯度 (LW,LU,Lb\frac{\partial L}{\partial W}, \frac{\partial L}{\partial U}, \frac{\partial L}{\partial b})

这是BPTT的核心和难点。我们以 LW\frac{\partial L}{\partial W} 为例进行推导。 为了解决 WW 在时间上的复杂依赖,我们引入一个关键的中间量:总损失 LL 对隐藏状态 hth_t 的梯度,记为 δh,t=Lht\delta_{h,t} = \frac{\partial L}{\partial h_t}

根据链式法则,LL 通过两条路径影响 hth_t

  1. 通过当前时间步的输出 oto_t
  2. 通过下一个时间步的隐藏状态 ht+1h_{t+1}(因为 ht+1h_{t+1} 的计算用到了 hth_t)。

因此,δh,t\delta_{h,t} 的计算是一个从后向前的递归过程:

δh,t=Lht=Lototht+Lht+1ht+1ht\delta_{h,t} = \frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial o_t}\frac{\partial o_t}{\partial h_t} + \frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_t}

将各部分代入,我们得到 δh,t\delta_{h,t} 的递归公式:

δh,t=δo,tVT+δh,t+1WTdiag(f(ht+1))\delta_{h,t} = \delta_{o,t} V^T + \delta_{h,t+1} W^T \text{diag}(f'(h_{t+1}))
  • 递归的起点(Base Case):在最后一个时间步 TT,没有未来的隐藏状态,所以递归项为0。
    δh,T=LhT=LToToThT=δo,TVT\delta_{h,T} = \frac{\partial L}{\partial h_T} = \frac{\partial L_T}{\partial o_T}\frac{\partial o_T}{\partial h_T} = \delta_{o,T} V^T

我们可以从 δh,T\delta_{h,T} 开始,反向计算出 δh,T1,,δh,1\delta_{h,T-1}, \dots, \delta_{h,1}

现在,我们用 δh,t\delta_{h,t} 来计算最终的梯度 LW\frac{\partial L}{\partial W}

LW=t=1TLhthtW=t=1Tδh,thtW\frac{\partial L}{\partial W} = \sum_{t=1}^{T} \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial W} = \sum_{t=1}^{T} \delta_{h,t} \frac{\partial h_t}{\partial W}

根据 ht=f(Uxt+Wht1+b)h_t = f(U x_t + W h_{t-1} + b),我们有 htW=diag(f(ht))ht1T\frac{\partial h_t}{\partial W} = \text{diag}(f'(h_t)) \cdot h_{t-1}^T。 所以,

LW=t=1Tdiag(f(ht))δh,tht1T\boxed{\frac{\partial L}{\partial W} = \sum_{t=1}^{T} \text{diag}(f'(h_t)) \cdot \delta_{h,t} \cdot h_{t-1}^T}

同理可得 LU\frac{\partial L}{\partial U}Lb\frac{\partial L}{\partial b}

LU=t=1Tdiag(f(ht))δh,txtT\boxed{\frac{\partial L}{\partial U} = \sum_{t=1}^{T} \text{diag}(f'(h_t)) \cdot \delta_{h,t} \cdot x_t^T}
Lb=t=1Tdiag(f(ht))δh,t\boxed{\frac{\partial L}{\partial b} = \sum_{t=1}^{T} \text{diag}(f'(h_t)) \cdot \delta_{h,t}}

第四步:梯度消失与爆炸的根源

回顾 δh,t\delta_{h,t} 的递归公式:

δh,t=+δh,t+1(WTdiag(f(ht+1)))\delta_{h,t} = \dots + \delta_{h,t+1} \cdot (W^T \text{diag}(f'(h_{t+1})))

我们可以看到,梯度在时间上传播时,会反复乘以循环权重矩阵 WW

  • 梯度爆炸 (Gradient Exploding):如果 WW 的某些特征值(或范数)大于1,经过多次连乘后,梯度会呈指数级增长,导致数值溢出,训练发散。
  • 梯度消失 (Gradient Vanishing):如果 WW 的某些特征值(或范数)小于1,经过多次连乘后,梯度会呈指数级衰减,趋近于0。这使得模型难以学习到长距离的依赖关系。

总结

BPTT算法的完整流程如下:

  1. 前向传播:对于 t=1,,Tt=1, \dots, T,计算 ht,ot,Lth_t, o_t, L_t,得到总损失 LL。同时保存所有的 xt,htx_t, h_t
  2. 反向传播: a. 计算最后一个时间步的隐藏层梯度 δh,T\delta_{h,T}。 b. 对于 t=T1,,1t = T-1, \dots, 1,使用递归公式反向计算 δh,t\delta_{h,t}
  3. 计算最终梯度:根据所有时间步的中间值,使用求和公式计算出所有参数的梯度。
  4. 参数更新:使用计算出的梯度,通过梯度下降等优化算法更新模型参数。