Sharpness-Aware Minimization 相关两篇工作的简要概述

1,168 阅读6分钟

本文正在参加 人工智能创作者扶持计划

一些缘由和动机

  机缘巧合之下,最近读了关于 Sharpness-Aware Minimization 的两篇文章。第一篇奠基的文章是在 ICLR 2021 上发表的,虽然这篇工作里还有一些坑没有填并不能说是非常完备,但是也确实产生了很大的影响。借着刚读完的热劲,在此简单记录一下 (之后的论文笔记我会朝着更精简的道路一路狂奔的)。

0.SAM 的论文信息

1.锐度与泛化能力的感性认知

  如果损失函数的锐度很大,那么它会更加敏感地捕捉模型在训练集上的表现,使得模型需要更好地适应训练数据才能获得较低的损失值。而泛化能力则要求模型在训练集外也有很好的表现,锐度过大的损失函数很可能导致模型在训练集外表现不佳。因此模型要兼顾锐度与平滑。

图 1.(左子图) 表示通过切换到 SAM 获得的错误率降低。每个点都是不同的数据集/模型/数据扩充。(中间子图) 表示用 SGD 训练的 ResNet 收敛到的急剧最小值。(右子图) 表示用SAM训练的相同ResNet收敛到的宽最小值。

2.Sharpness-Aware Minimization (SAM)

  受损失函数图像的锐度和泛化之间的联系的启发,我们提出了一种不同的方法:不是寻找仅具有低训练损失值 LS(w)L_{\boldsymbol{S}}(w) 的参数值 ww,而是寻找整个邻域具有一致低的参数值训练损失值(等效于具有低损失和低曲率的邻域)。以下定理通过在邻域训练损失方面限制泛化能力来说明这种方法的动机

定理 1 (非正式地表述). 对于任何 ρ>0\rho>0 且在从分布 D\mathscr{D} 生成的训练集 S\mathcal{S} 上有大概率满足

LD(w)maxϵ2ρLS(w+ϵ)+h(w22/ρ2)L_{\mathscr{D}}(\boldsymbol{w}) \leq \max _{\|\epsilon\|_2 \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})+h\left(\|\boldsymbol{w}\|_2^2 / \rho^2\right)

  其中 h:R+R+h: \mathbb{R}_{+} \rightarrow \mathbb{R}_{+} 是严格递增的函数 ( 在某些 LD(w)L_{\mathscr{D}}(w) 的技术条件下 )。为了明确锐度项,我们可以将上述不等式的右侧重写为

[maxϵ2ρLS(w+ϵ)LS(w)]+LS(w)+h(w22/ρ2)\left[\max _{\|\epsilon\|_2 \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})-L_{\mathcal{S}}(\boldsymbol{w})\right]+L_{\mathcal{S}}(\boldsymbol{w})+h\left(\|\boldsymbol{w}\|_2^2 / \rho^2\right)

  方括号中的通过测量通过将 ww 移动到附近的参数值来提高训练损失的速度来捕获 wwL_\boldsymbol{S} 的锐度;然后将这个锐度项与训练损失值本身和 ww 大小的正则化器相加。鉴于特定函数 hh 受到证明细节的严重影响,我们将第二项用超参数 λ\lambda 替换为 λw22\lambda\lVert w\rVert^2_2 ,产生标准 L2-正则化项。因此,受边界项的启发,我们建议通过解决以下锐化感知最小化 (SAM) 问题来选择参数值 :

minwLSSAM(w)+λw22 where LSSAM(w)maxϵpρLS(w+ϵ)(1)\min _{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w})+\lambda\|\boldsymbol{w}\|_2^2 \quad \text { where } \quad L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) \triangleq \max _{\|\boldsymbol{\epsilon}\|_p \leq \rho} L_S(\boldsymbol{w}+\boldsymbol{\epsilon}) \tag{1}

  其中 ρ>0\rho>0 是超参数且范数维度 p[1,]p\in[1,\infty],为了最小化 LSSAM(w)L_{\mathcal{S}}^{S A M}(\boldsymbol{w}),通过对 wLSSAM(w)\nabla_wL_{\mathcal{S}}^{S A M}(\boldsymbol{w}) 进行内极值微分,我们得到了一个高效的近似,这反过来使我们能够将随机梯度下降直接应用于 SAM 目标。以此继续,我们首先通过 LS(w+ϵ)L_\mathcal{S}(w+\epsilon) 在 0 附近的一阶泰勒展开来近似内部最大化问题,得到 :

ϵ(w)argmaxϵpρLS(w+ϵ)argmaxϵpρLS(w)+ϵTwLS(w)=argmaxϵpρϵTwLS(w)\epsilon^*(\boldsymbol{w}) \triangleq \underset{\|\epsilon\|_p \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w}+\epsilon) \approx \underset{\|\epsilon\|_p \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w})+\epsilon^T \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})=\underset{\|\epsilon\|_p \leq \rho}{\arg \max } \epsilon^T \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})

  反过来,求解该近似值 ϵ^(w)\hat{\epsilon}(w) 由经典对偶范数问题的解给出 ( q1\lvert\cdot\rvert^{q-1}表示元素绝对值和幂 ) :

ϵ^(w)=ρsign(wLS(w))wLS(w)q1/(wLS(w)qq)1/p(2)\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})=\rho \operatorname{sign}\left(\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right)\left|\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|^{q-1} /\left(\left\|\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right\|_q^q\right)^{1 / p}\tag{2}

  其中 1p+1q=1\frac{1}{p}+\frac{1}{q}=1,代回式 (1) 并求导有 :

wLSSAM(w)wLS(w+ϵ^(w))=d(w+ϵ^(w))dwwLS(w)w+ϵ^(w)=wLS(w)w+ϵ^(w)+dϵ^(w)dwwLS(w)w+ϵ^(w)\begin{aligned} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) & \approx \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}+\hat{\epsilon}(\boldsymbol{w}))=\left.\frac{d(\boldsymbol{w}+\hat{\epsilon}(\boldsymbol{w}))}{d \boldsymbol{w}} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}(w)} \\ & =\left.\nabla_w L_{\mathcal{S}}(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}(\boldsymbol{w})}+\left.\frac{d \hat{\epsilon}(\boldsymbol{w})}{d \boldsymbol{w}} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}(\boldsymbol{w})} \end{aligned}

  对于 wLSSAM(w)\nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) 的近似可以通过自动微分直接计算。虽然这个计算隐含地依赖于 LSL_\mathcal{S} 的 Hessian 矩阵,因为 ϵ^(w)\hat{\epsilon}(w) 本身是 wLS(w)\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}) 的函数,Hessian 只能通过 Hessian 向量积计算,这使得可以在不具体化 Hessian 矩阵的情况下计算。尽管如此,为了进一步加速计算,我们去掉了二阶项。得到我们的最终梯度近似 :

wLSSAM(w)wLS(w)w+ϵ^(w)(3)\left.\nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) \approx \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(w)\right|_{\boldsymbol{w}+\hat{\epsilon}(\boldsymbol{w})}\tag{3}

  以下给出对应算法和示意图

算法 1. (左子图) Sharpness-Aware Minimization 算法对应流程

图 2. (右子图) SAM参数更新示意图

0.LookSAM 的论文信息

1.这篇工作的核心点 : 降低计算量

  SAM 的更新规则需要在每一步进行两次连续的(不可并行化的)梯度计算,这会使计算开销增加一倍。在本文中,我们提出了一种新颖的算法 LookSAM,它只定期计算内部梯度上升,以显着降低 SAM 的额外训练成本。

2.LookSAM 的具体技术细节

  在 SAM 中梯度为 gs=wLS(w)w+ϵ^g_{\boldsymbol{s}}=\left.\nabla_{\boldsymbol{w}} \mathcal{L}_S(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}},在下图中用蓝色箭头进行表示。几种梯度分别如下图所示。

图 3. LookSAM 更新的可视化。蓝色箭头 gsg_{\boldsymbol{s}} 是 SAM 针对较平坦区域的梯度。黄色箭头 ηwLS(w)-\eta\nabla_w\mathcal{L}_{\boldsymbol{S}}(w) 表示 SGD 梯度。ghg_{\boldsymbol{h}} (棕色箭头) 和 gvg_{\boldsymbol{v}} (红色箭头) 是 gsg_{\boldsymbol{s}} 的正交梯度分量,分别与 SGD 梯度平行和垂直。

  文章中提出了一种新颖的 LookSAM 算法来应对这一挑战。主要思想是研究如何重用信息以防止每次都计算 SAM 的梯度。如图 3 所示,与 SGD 梯度 (黄色箭头) 相比,SAM 的梯度 gs=wLS(w)w+ϵ^g_{\boldsymbol{s}}=\left.\nabla_{\boldsymbol{w}} \mathcal{L}_S(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}} 提升到一个更平坦的区域(蓝色箭头)。为了更直观地了解这个平坦区域,我们基于泰勒展开重写了 SAM 的更新:

wLS(w)w+ϵ^=wLS(w+ϵ^)w[LS(w)+ϵ^wLS(w)]=w[LS(w)+ρwLS(w)wLS(w)TwLS(w)]=w[LS(w)+ρwLS(w)](4)\begin{aligned} \nabla_{\boldsymbol{w}} & \left.\mathcal{L}_S(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}}=\nabla_{\boldsymbol{w}} \mathcal{L}_S(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}) \\ & \approx \nabla_{\boldsymbol{w}}\left[\mathcal{L}_S(\boldsymbol{w})+\hat{\epsilon} \cdot \nabla_{\boldsymbol{w}} \mathcal{L}_S(\boldsymbol{w})\right] \\ & =\nabla_{\boldsymbol{w}}\left[\mathcal{L}_S(\boldsymbol{w})+\frac{\rho}{\left\|\nabla_{\boldsymbol{w}} \mathcal{L}_S(\boldsymbol{w})\right\|} \nabla_{\boldsymbol{w}} \mathcal{L}_S(\boldsymbol{w})^T \nabla_{\boldsymbol{w}} \mathcal{L}_S(\boldsymbol{w})\right] \\ & =\nabla_{\boldsymbol{w}}\left[\mathcal{L}_S(\boldsymbol{w})+\rho\left\|\nabla_{\boldsymbol{w}} \mathcal{L}_S(\boldsymbol{w})\right\|\right] \end{aligned}\tag{4}

  我们发现 SAM 的梯度由两部分组成:原始梯度 wLS(w)\nabla_{\boldsymbol{w}} \mathcal{L}_S(\boldsymbol{w}) 和原始梯度 wLS(w)\lVert\nabla_{\boldsymbol{w}} \mathcal{L}_S(\boldsymbol{w})\rVert 的 L2-范数的梯度。我们认为优化梯度的 L2-范数可以提示模型收敛到平坦区域,因为平坦区域通常意味着低梯度范数值。因此,SAM 的更新可分为两部分 : 第一部分 (记为 ghg_{\boldsymbol{h}}) 是降低损失值,第二部分 (记为 gvg_{\boldsymbol{v}}) 是将更新偏向平坦区域。更具体地说,ghg_{\boldsymbol{h}}是普通的 SGD 的梯度方向,即使没有 SAM,也需要在每一步计算。因此,SAM 的额外计算代价主要是由第二部分 gvg_{\boldsymbol{v}} 引起的。已知 SAM 的梯度(蓝色箭头) 和 SGD 的梯度方向 (ghg_{\boldsymbol{h}}),我们可以进行投影得到 gvg_{\boldsymbol{v}} :

gv=wLS(w)w+ϵ^sin(θ)(5)g_{\boldsymbol{v}}=\left.\nabla_{\boldsymbol{w}} \mathcal{L}_S(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}} \cdot \sin (\theta)\tag{5}

图 4. gsg_{\boldsymbol{s}}, ghg_{\boldsymbol{h}}gvg_{\boldsymbol{v}} 的每 5 步之间的梯度差异 (即 gstgst+k\left\|g_s^t-g_s^{t+k}\right\|), gvg_{\boldsymbol{v}} 的变化比 ghg_{\boldsymbol{h}}gsg_{\boldsymbol{s}} 平滑得多

  一个关键的观察结果是 gvg_{\boldsymbol{v}} 的变化比 ghg_{\boldsymbol{h}}gsg_{\boldsymbol{s}} 慢得多。在图 4 中,我们通过 SAM 的整个训练过程绘制了迭代 tt 和迭代 t+5t+5 之间这三个组件的变化,结果表明 gvg_{\boldsymbol{v}} (绿线) 的差异显示出比 ghg_{\boldsymbol{h}} (橙色线) 和 gsg_{\boldsymbol{s}} (蓝线) 更稳定的模式。直观地说,这意味着指向平坦区域的方向在几次迭代中不会显着变化。所以优化呼之欲出 : 减少 gvg_{\boldsymbol{v}} 的采样和计算次数,在小批量中只计算一次,即可大幅加快速度。

算法 2. LookSAM 算法对应流程

  Layer-Wise 的优化策略在此不做过多赘述。