从 Adam 到 Adam W|关于 Weight Decay 的数学直觉

99 阅读2分钟

在深度学习关于 optimizer 的学习中,Adam 和 AdamW 常常被并列出现,很多时候我们会直接看到结论:AdamW 更好。但为什么更好?数学上如何理解?网络上有很多资源尝试解释这个问题。以我自己的学习经历为例,我的困惑集中在这样一句话上:在传统 Adam(L2 正则并入梯度)中,weight decay 的强度会随着坐标和历史梯度的大小而改变。在这篇文章里,我将尝试把这句话完整地拆解清楚。


我们先从最标准的 Adam 更新公式开始。对参数 ww,在第 tt 步,Adam 会维护一阶和二阶动量 mtvtm_t、v_t,并在 bias correction 之后做如下更新:

wt+1=wtαm^tv^t+ϵw_{t+1} = w_t - \alpha \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}

其中
gt=wL(wt)g_t = \nabla_w \mathcal L(w_t) 是当前的梯度。到这里,一切都很容易理解。

问题出现在我们引入 L2 正则(也就是 weight decay;当 optimizer 为 SGD 时,二者在实现和效果上是同一件事)的时候。传统做法不是修改更新式,而是修改梯度本身。如果目标函数写成

Ltotal(w)=L(w)+λ2w2\mathcal L_{\text{total}}(w) = \mathcal L(w) + \frac{\lambda}{2}\lVert w\rVert^2

那么梯度就变成

gtotal=g+λwg^{\text{total}} = g + \lambda w

而 Adam 并不知道这是「任务梯度 + 衰减项」,它只会把整个 (g+λw)(g + \lambda w) 当作普通梯度送进动量和方差的计算中。于是完整更新变成

Δw=αm^(g+λw)v^(g+λw)\Delta w = -\alpha \frac{\hat m(g + \lambda w)}{\sqrt{\hat v(g + \lambda w)}}

到这里,一切依然是「严格正确」的数学推导。但问题是:如果我们只关心 weight decay 这一项 λw\lambda w,它在 Adam 中到底是如何作用到参数上的?


为了回答这个问题,我们需要暂时放弃「精确等式」的执念,而改用一种在优化分析中非常常见的视角:贡献分析。也就是说,不去问「整个更新是多少」,而是问「λw\lambda w 这一部分对更新产生了多大的影响」。

在不影响结论的前提下,我们可以先做一个简化:忽略动量的时间累积,把 Adam 近似看成「单步自适应缩放」。在这种近似下,可以把

m^g+λw,v^(g+λw)2\hat m \approx g + \lambda w, \qquad \hat v \approx (g + \lambda w)^2

代回更新式中。注意这一步并不是在说 Adam 等于 这个形式,而是在说:局部地看,Adam 的更新方向和尺度由当前输入的梯度决定

我们并不是要比较

g+λw(g+λw)2\frac{g + \lambda w}{\sqrt{(g + \lambda w)^2}}

而是要问:如果在同一个 v^\hat v 下,λw\lambda w 这一项单独存在,它会带来怎样的更新?

在这种一阶近似的视角下,由衰减项引起的有效更新可以写成

ΔwdecayAdamαλwv^\Delta w_{\text{decay}}^{\text{Adam}} \approx -\alpha \frac{\lambda w}{\sqrt{\hat v}}

我们可以在这里清晰地看到:weight decay 的作用强度,被 Adam 的自适应分母 v^\sqrt{\hat v} 再缩放了一次


为了让这个式子不再停留在符号层面,我们代入具体数字来看。假设某个参数当前满足 λw=1\lambda w = 1

  • 如果这个参数所在的坐标在历史上梯度一直很大,那么对应的 v^100\hat v \approx 100,此时

    Δwdecayα1100=0.1α\Delta w_{\text{decay}} \approx -\alpha \frac{1}{\sqrt{100}} = -0.1\alpha

    衰减被压缩到了原来的十分之一。

  • 反过来,如果另一个坐标梯度很小,v^0.01\hat v \approx 0.01,那么

    Δwdecayα10.01=10α\Delta w_{\text{decay}} \approx -\alpha \frac{1}{\sqrt{0.01}} = -10\alpha

    同样的 λw\lambda w,衰减强度却被放大了十倍。

这就是「衰减强度会随坐标、随历史梯度大小而改变」这句话的完整数学含义。在 Adam + L2 中,weight decay 不再是「每一步把参数按固定比例往 0 拉」,而是变成了一种被自适应机制调制过的、坐标相关的力


为什么这种被缩放的 weight decay 是「不好的」?

问题并不在于这种行为是否数学上错误,而在于它违背了我们引入 weight decay 的初衷

在 SGD 的语境下,weight decay 的角色非常清晰:它是一种与梯度大小、梯度历史无关的结构性约束,其作用是持续、稳定地将参数向 0 收缩,从而限制模型的有效容量。这种收缩是各向同性的:每一个坐标、每一个时间步,衰减强度都只由 λ\lambda 和学习率决定。

但在 Adam + L2 中,这种角色被悄然改变了。由于 λw\lambda w 被送入了 Adam 的自适应机制,weight decay 的强度开始依赖于 v^\hat v,也就依赖于任务梯度在这个坐标上的历史统计特性。结果是:

  • 梯度长期较大的参数(往往是「重要的」、频繁被任务信号更新的参数)反而被衰减得更弱
  • 梯度较小、甚至几乎不参与任务优化的参数,却可能被过度衰减

换句话说,weight decay 从一种「全局、稳定的正则化」,变成了一种与任务梯度耦合的、非均匀的隐式调节机制。这种耦合并不是我们在设计正则项时有意为之的,而是 Adam 更新结构的副产物。它让 weight decay 的实际效果变得难以直觉控制,也让 λ\lambda 这个超参数失去了在 SGD 语境下清晰、可解释的含义。

这正是人们说 Adam + L2 「不是真正的 weight decay」的核心原因。


理解了这一点,再来看 AdamW,就会异常清晰。AdamW 做的事情只有一件:把 weight decay 从梯度中拆出来。Adam 仍然只作用在任务梯度 gg 上,而衰减直接写成

w(1αλ)ww \leftarrow (1 - \alpha \lambda) w

或者等价地

ΔwdecayAdamW=αλw\Delta w_{\text{decay}}^{\text{AdamW}} = -\alpha \lambda w

它不进入 mmvv,也不受 v^\sqrt{\hat v} 的影响。每一个坐标、每一步,衰减强度完全一致。这种行为,才与 SGD 时代人们对 weight decay 的直觉完全一致。

从这个角度看,AdamW 并不是一个「更聪明的 Adam」,而是一个更诚实的 Adam:它让自适应学习率只负责调节任务梯度的步长,而让 weight decay 回到它本来的位置:一个简单、稳定、与梯度统计无关的参数收缩。

一句话总结全文:在 Adam + L2 中,weight decay 会被 Adam 的自适应分母重新加权;而在 AdamW 中,weight decay 才真正是 weight decay。