0.论文信息
文章概述
这篇论文主要研究了域泛化问题 (Domain Generalization),即如何在训练集和测试集的分布不同的情况下,训练出具有良好泛化性能的模型。为了解决这个问题,作者提出了一种基于惩罚项 (Penalty) 的域泛化方法,并通过理论分析证明了该方法的有效性和收敛性。
具体来说,作者首先对域泛化问题进行数学建模,并将其转化为一个优化问题。然后,作者提出了一种基于惩罚项的域泛化方法,并将其表述为一个联合优化问题。最后,作者通过理论分析证明了该方法可以保证收敛到局部最优解,并且在实验中取得了较好的效果。
总之,这篇论文提出了一种新颖有效的域泛化方法,并通过理论分析和实验验证证明了其有效性。
1.相关工作
相关工作这里我就不使用本文中的相关介绍了,而是使用王晋东老师 Generalizing to Unseen Domains: A Survey on Domain Generalization 中的相关介绍。
王老师将领域泛化中的问题总共分为三大类 : 数据操作、表征学习、学习策略。
数据操作 : 指的是通过对数据的增强和变化使训练数据得到增强。这一类包括数据增强和数据生成两大部分。
表征学习 : 指的是学习领域不变特征 (Domain-invariant representation learning)以使得模型对不同领域都能进行很好地适配。领域不变特征学习方面主要包括四大部分:核方法、显式特征对齐、领域对抗训练、以及不变风险最小化 (Invariant Risk Minimiation, IRM)。特征解耦与领域不变特征学习的目标一致、但学习方法不一致,我们将其单独作为一大类进行介绍。
学习策略 : 指的是将机器学习中成熟的学习模式引入多领域训练中使得模型泛化性更强。这一部分主要包括基于集成学习和元学习的方法。同时,我们还会介绍其他方法,例如自监督方法在领域泛化中的应用。
而本文所想进行辅助解决的是表征学习 对应的领域泛化方法 : 表征学习往往在 ERM loss
之外会增加更多的限定条件以使得学习更具有泛化性,而经典的操作就是进行 Lagrange 乘子法进行操作,即
{ min L E R M ( θ ; { x , y , e } ) min Penalty ( θ ; { x , y , e } ) → min L E R M ( θ ; { x , y , e } ) + Penalty ( θ ; { x , y , e } ) (1) \begin{cases}
\min \mathcal{L}_{ERM}(\theta;\{x,y,e\})\\
\min \text{Penalty}(\theta;\{x,y,e\})
\end{cases}\rightarrow \min \mathcal{L}_{ERM}(\theta;\{x,y,e\})+\text{Penalty}(\theta;\{x,y,e\})\tag{1} { min L ERM ( θ ; { x , y , e }) min Penalty ( θ ; { x , y , e }) → min L ERM ( θ ; { x , y , e }) + Penalty ( θ ; { x , y , e }) ( 1 )
其中 ERM loss
L E R M ( θ ; { x , y , e } ) ≜ 1 M ∑ j = 1 M 1 N j ∑ i = 1 N j l ( h ( x j i ; θ ) , y j i ) \mathcal{L}_{ERM}(\theta;\{x,y,e\})\triangleq \frac{1}{M} \sum_{j=1}^M \frac{1}{N_j} \sum_{i=1}^{N_j} l\left(h\left(x_j^i ; \theta\right), y_j^i\right) L ERM ( θ ; { x , y , e }) ≜ M 1 ∑ j = 1 M N j 1 ∑ i = 1 N j l ( h ( x j i ; θ ) , y j i ) ,其中 M M M 为对应的域的数量。而本文所进行的研究就是对传统 Lagrange 乘子法求解的一种替代。
2.本文的方法
2.1 基本的定义与表示
X 输入空间 Y 标签空间 M 域的数量 π ^ = { π 1 , … , π M } 不同的域的集合 每个域中包含 N j 个样本且 i.i.d. l ( ⋅ , ⋅ ) 损失函数 e 域标签 L E R M ( θ ; { x , y , e } ) ERM 损失函数 Penalty ( θ ; { x , y , e } ) 惩罚项 \begin{aligned}
\mathcal{X}&\quad\text{输入空间}\\
\mathcal{Y}&\quad\text{标签空间}\\
M&\quad\text{域的数量}\\
\hat{\pi}=\{\pi_1,\ldots,\pi_M\}&\quad\text{不同的域的集合}\\
&\quad\text{每个域中包含} N_j \text{个样本且 i.i.d.}\\
l(\cdot,\cdot)&\quad\text{损失函数}\\
e&\quad\text{域标签}\\
\mathcal{L}_{ERM}(\theta;\{x,y,e\})&\quad\text{ERM 损失函数}\\
\text{Penalty}(\theta;\{x,y,e\})&\quad\text{惩罚项}\\
\end{aligned} X Y M π ^ = { π 1 , … , π M } l ( ⋅ , ⋅ ) e L ERM ( θ ; { x , y , e }) Penalty ( θ ; { x , y , e }) 输入空间 标签空间 域的数量 不同的域的集合 每个域中包含 N j 个样本且 i.i.d. 损失函数 域标签 ERM 损失函数 惩罚项
2.2 域泛化优化的目标
考虑一个不可见的分布 π ⋆ \pi_\star π ⋆ ,其相应的总体风险可以使用以下直接分解来限制
E x , y ∼ π ⋆ [ l ( h ( x ; θ ) , y ) ] ⏟ Risk on Unseen Domains ⩽ E x , y ∼ π ^ [ l ( h ( x ; θ ) , y ) ] ⏟ Risk on Seen Domains + ∣ E x , y ∼ π ⋆ [ l ( h ( x ; θ ) , y ) ] − E x , y ∼ π ^ [ l ( h ( x ; θ ) , y ) ] ∣ ⏟ Generalization Error (2) \underbrace{\mathbb{E}_{x, y \sim \pi_{\star}}[l(h(x ; \theta), y)]}_{\text {Risk on Unseen Domains }} \leqslant \underbrace{\mathbb{E}_{x, y \sim \hat{\pi}}[l(h(x ; \theta), y)]}_{\text {Risk on Seen Domains }}+\underbrace{\left|\mathbb{E}_{x, y \sim \pi_{\star}}[l(h(x ; \theta), y)]-\mathbb{E}_{x, y \sim \hat{\pi}}[l(h(x ; \theta), y)]\right|}_{\text {Generalization Error }} \tag{2} Risk on Unseen Domains E x , y ∼ π ⋆ [ l ( h ( x ; θ ) , y )] ⩽ Risk on Seen Domains E x , y ∼ π ^ [ l ( h ( x ; θ ) , y )] + Generalization Error ∣ E x , y ∼ π ⋆ [ l ( h ( x ; θ ) , y )] − E x , y ∼ π ^ [ l ( h ( x ; θ ) , y )] ∣ ( 2 )
为了部分解决这个问题,作者提出了无超额风险约束。作者不是在经验风险最优性的约束下使经验风险和惩罚共同最小化,而是使惩罚最小化。更正式地说,作者考虑以下优化问题 :
min θ Penalty ( θ ; { x i , e i , y i } ) s.t. θ ∈ arg min θ L ( θ ) (3) \begin{aligned}
\min_\theta & \text { Penalty }\left(\theta ;\left\{x^i, e^i, y^i\right\}\right)\\
\text{s.t.} & \quad \theta \in \underset{\theta}{\arg \min } \mathcal{L}(\theta)
\end{aligned}
\tag{3} θ min s.t. Penalty ( θ ; { x i , e i , y i } ) θ ∈ θ arg min L ( θ ) ( 3 )
在后面的部分作者将讨论如何对于 Eq.(3) 进行转化。
2.3 将 ERM 的平稳性作为约束
θ ∈ arg min θ L ( θ ) \theta \in \underset{\theta}{\arg \min } \mathcal{L}(\theta) θ ∈ θ arg min L ( θ ) 的总体目标并不容易直接实现。由于考虑的损失函数族是非凸的,所期望的最好结果是收敛到一个平稳点。因此将 Eq.(3) 转换为
min θ Penalty ( θ ; { x i , e i , y i } ) s.t. ∥ ∇ θ L ( θ ) ∥ ⩽ ϵ (4) \begin{aligned}
\min_\theta & \text { Penalty }\left(\theta ;\left\{x^i, e^i, y^i\right\}\right)\\
\text{s.t.} &\quad \left\|\nabla_\theta \mathcal{L}(\theta)\right\| \leqslant \epsilon
\end{aligned}\tag{4} θ min s.t. Penalty ( θ ; { x i , e i , y i } ) ∥ ∇ θ L ( θ ) ∥ ⩽ ϵ ( 4 )
在解决 Eq.(4) 中的问题时,通过限制使用的迭代方法以与常用于训练深度神经网络的随机梯度下降 (SGD) 兼容。因此,作者希望有一个迭代过程 θ t + 1 = θ t − η G t \theta^{t+1}=\theta^t-\eta G^t θ t + 1 = θ t − η G t 来解决 Eq.(4) 的收敛问题。注意到当 G t G^t G t 是随机梯度时 (例如 E [ G t ] = ∇ θ L ^ ( θ ) \mathbb{E}\left[G^t\right]=\nabla_\theta \hat{\mathcal{L}}(\theta) E [ G t ] = ∇ θ L ^ ( θ ) ),在温和的条件下,经验风险收敛到一个平稳点。然而,也需要使惩罚项最小化。因此,不仅需要允许不同于梯度的更新,同时仍需要保持结果的收敛性。
命题 2.1 使用 T T T 步随机更新 θ t + 1 = θ t − η G t \theta^{t+1}=\theta^t-\eta G^t θ t + 1 = θ t − η G t 。做出如下假设
R ( θ ) \mathcal{R}(\theta) R ( θ ) 是一个以 Δ \Delta Δ 为界的非凸 L-Lipschitz,μ \mu μ -光滑函数
更新期望有界,E [ ∥ G t ∥ 2 2 ] ⩽ V \mathbb{E}\left[\left\|G^t\right\|_2^2\right] \leqslant V E [ ∥ G t ∥ 2 2 ] ⩽ V
更新的偏差逐渐减小,∥ E [ G t ] − ∇ R ( θ t ) ∥ 2 ⩽ D / t \left\|\mathbb{E}\left[G^t\right]-\nabla \mathcal{R}\left(\theta^t\right)\right\|_2 \leqslant D / \sqrt{t} ∥ E [ G t ] − ∇ R ( θ t ) ∥ 2 ⩽ D / t ,然后有
1 T ∑ t = 1 T E [ ∥ ∇ R ( θ t ) ∥ 2 2 ] ⩽ 2 ( Δ μ V + L D ) 1 T \frac{1}{T} \sum_{t=1}^T \mathbb{E}\left[\left\|\nabla \mathcal{R}\left(\theta^t\right)\right\|_2^2\right] \leqslant 2(\sqrt{\Delta \mu V}+L D) \sqrt{\frac{1}{T}} T 1 t = 1 ∑ T E [ ∥ ∥ ∇ R ( θ t ) ∥ ∥ 2 2 ] ⩽ 2 ( Δ μ V + L D ) T 1
在这个命题中,假设风险函数是非凸的、L-Lipschitz 的和 μ \mu μ -光滑的,并且随机更新是有界的。这些都是标准假设,对于常见的深度学习架构选择是有效的。最后,第三个假设是偏差以 1 t \frac{1}{\sqrt{t}} t 1 的速率下降,作者在求解器中明确地处理了它,将在下面进行讨论。
命题 2.1 的证明
考虑将随机更新 θ t + 1 = θ t − η G t \theta^{t+1}=\theta^t-\eta G^t θ t + 1 = θ t − η G t 应用于函数 R ( θ ) \mathcal{R}(\theta) R ( θ ) 。利用函数 R \mathcal{R} R 的 μ \mu μ -平滑度,
R ( θ t + 1 ) = R ( θ t − η G t ) ⩽ R ( θ t ) − η ∇ R ( θ t ) ⊤ G t + μ η 2 2 ∥ G t ∥ 2 2 (7) \begin{aligned} \mathcal{R}\left(\theta^{t+1}\right)
=&\mathcal{R}\left(\theta^t-\eta G^t\right) \leqslant &\\ &\mathcal{R}\left(\theta^t\right)-\eta \nabla \mathcal{R}\left(\theta^t\right)^{\top} G^t+\frac{\mu \eta^2}{2}\left\|G^t\right\|_2^2
\end{aligned}\tag{7} R ( θ t + 1 ) = R ( θ t − η G t ) ⩽ R ( θ t ) − η ∇ R ( θ t ) ⊤ G t + 2 μ η 2 ∥ ∥ G t ∥ ∥ 2 2 ( 7 )
取不等式的期望,并利用 E [ ∥ G t ∥ 2 2 ] ⩽ V \mathbb{E}[\lVert G^t\rVert_2^2]\leqslant V E [∥ G t ∥ 2 2 ] ⩽ V ,
E [ R ( θ t + 1 ) ] ⩽ E [ R ( θ t ) ] − η E [ ∥ ∇ R ( θ t ) ∥ 2 2 ] − η E [ ∇ R ( θ t ) ⊤ ( G t − ∇ R ( θ t ) ) ] + μ η 2 V 2 (8) \begin{gathered} \mathbb{E}\left[\mathcal{R}\left(\theta^{t+1}\right)\right] \leqslant \mathbb{E}\left[\mathcal{R}\left(\theta^t\right)\right]-\eta \mathbb{E}\left[\left\|\nabla \mathcal{R}\left(\theta^t\right)\right\|_2^2\right]-\\ \eta \mathbb{E}\left[\nabla \mathcal{R}\left(\theta^t\right)^{\top}\left(G^t-\nabla \mathcal{R}\left(\theta^t\right)\right)\right]+\frac{\mu \eta^2 V}{2}
\end{gathered}\tag{8} E [ R ( θ t + 1 ) ] ⩽ E [ R ( θ t ) ] − η E [ ∥ ∥ ∇ R ( θ t ) ∥ ∥ 2 2 ] − η E [ ∇ R ( θ t ) ⊤ ( G t − ∇ R ( θ t ) ) ] + 2 μ η 2 V ( 8 )
使用 R \mathcal{R} R 的 Lipschitz-平滑度并对这些项重新排序,
η E [ ∥ ∇ R ( θ t ) ∥ 2 2 ] ⩽ E [ R ( θ t ) ] − E [ R ( θ t + 1 ) ] + μ η 2 V 2 + η L ∥ E [ G t ] − ∇ R ( θ t ) ∥ 2 (9) \begin{gathered}
\eta \mathbb{E}\left[\left\|\nabla \mathcal{R}\left(\theta^t\right)\right\|_2^2\right] \leqslant \mathbb{E}\left[\mathcal{R}\left(\theta^t\right)\right]-\mathbb{E}\left[\mathcal{R}\left(\theta^{t+1}\right)\right]+\\ \frac{\mu \eta^2 V}{2}+\eta L\left\|\mathbb{E}\left[G^t\right]-\nabla \mathcal{R}\left(\theta^t\right)\right\|_2
\end{gathered}\tag{9} η E [ ∥ ∥ ∇ R ( θ t ) ∥ ∥ 2 2 ] ⩽ E [ R ( θ t ) ] − E [ R ( θ t + 1 ) ] + 2 μ η 2 V + η L ∥ ∥ E [ G t ] − ∇ R ( θ t ) ∥ ∥ 2 ( 9 )
从 t = 1 t = 1 t = 1 到 T T T 求和然后除以 η \eta η ,我们得到
∑ t = 1 T E [ ∥ ∇ R ( θ t ) ∥ 2 2 ] ⩽ 2 Δ η + μ η T V 2 + L ∑ t = 1 T ∥ E [ G t ] − ∇ R ( θ t ) ∥ 2 (10) \sum_{t=1}^T \mathbb{E}\left[\left\|\nabla \mathcal{R}\left(\theta^t\right)\right\|_2^2\right] \leqslant \frac{2 \Delta}{\eta}+\frac{\mu \eta T V}{2}+L \sum_{t=1}^T\left\|\mathbb{E}\left[G^t\right]-\nabla \mathcal{R}\left(\theta^t\right)\right\|_2\tag{10} t = 1 ∑ T E [ ∥ ∥ ∇ R ( θ t ) ∥ ∥ 2 2 ] ⩽ η 2Δ + 2 μ η T V + L t = 1 ∑ T ∥ ∥ E [ G t ] − ∇ R ( θ t ) ∥ ∥ 2 ( 10 )
利用扭曲有界为 D t = ∥ E [ G t ] − ∇ R ( θ t ) ∥ 2 ⩽ D t D^t=\lVert\mathbb{E}[G^t]-\nabla\mathcal{R}(\theta^t)\rVert_2\leqslant\frac{D}{\sqrt{t}} D t = ∥ E [ G t ] − ∇ R ( θ t ) ∥ 2 ⩽ t D ,我们可以用 2 D T 2D\sqrt{T} 2 D T 将总和 ∑ t = 1 T D t \sum_{t=1}^TD^t ∑ t = 1 T D t 绑定。通过求解最优 η \eta η 为
η = 2 Δ μ T V (11) \eta=2\sqrt{\frac{\Delta}{\mu TV}}\tag{11} η = 2 μ T V Δ ( 11 )
最终约束转化为
∑ t = 1 T E [ ∥ ∇ R ( θ t ) ∥ 2 2 ] ⩽ 2 Δ η + μ η T V 2 + L ∑ t = 1 T ∥ E [ G t ] − ∇ R ( θ t ) ∥ 2 (12) \sum_{t=1}^T \mathbb{E}\left[\left\|\nabla \mathcal{R}\left(\theta^t\right)\right\|_2^2\right] \leqslant \frac{2 \Delta}{\eta}+\frac{\mu \eta T V}{2}+L \sum_{t=1}^T\left\|\mathbb{E}\left[G^t\right]-\nabla \mathcal{R}\left(\theta^t\right)\right\|_2\tag{12} t = 1 ∑ T E [ ∥ ∥ ∇ R ( θ t ) ∥ ∥ 2 2 ] ⩽ η 2Δ + 2 μ η T V + L t = 1 ∑ T ∥ ∥ E [ G t ] − ∇ R ( θ t ) ∥ ∥ 2 ( 12 )
我们把得到的步长代入边缘分布并除以 T T T ,得到 :
1 T ∑ t = 1 T E [ ∥ ∇ R ( θ t ) ∥ 2 2 ] ≤ 2 ( Δ μ V + L D ) 1 T (13) \frac{1}{T} \sum_{t=1}^T \mathbb{E}\left[\left\|\nabla \mathcal{R}\left(\theta^t\right)\right\|_2^2\right] \leq 2(\sqrt{\Delta \mu V}+L D) \sqrt{\frac{1}{T}}\tag{13} T 1 t = 1 ∑ T E [ ∥ ∥ ∇ R ( θ t ) ∥ ∥ 2 2 ] ≤ 2 ( Δ μ V + L D ) T 1 ( 13 )
即为命题 2.1 所对应的结论。
通过命题2.1,我们实现了满意迭代更新的定义。只要我们能将更新的差值与经验风险的梯度结合起来,就能保证收敛到一个平稳点。因此,我们的域泛化迭代方法从初始化 θ 0 \theta_0 θ 0 开始,应用迭代 θ t + 1 = θ t − η G t \theta^{t+1}=\theta^t-\eta G^t θ t + 1 = θ t − η G t ,其中 G t G^t G t 为以下优化问题的解 :
min p ( G t ) E G t [ Penalty ( θ t + G t ; { x i , e i , y i } ) ] s.t. E G t [ ∥ G t − ∇ θ L ( θ ) ∥ 2 ] ⩽ D t (5) \begin{aligned}
\min_{p\left(G^t\right)} &\quad\mathbb{E}_{G^t}\left[\text { Penalty }\left(\theta^t+G^t ;\left\{x^i, e^i, y^i\right\}\right)\right] \\
\text{s.t.}& \quad \mathbb{E}_{G^t}\left[\left\|G^t-\nabla_\theta \mathcal{L}(\theta)\right\|_2\right] \leqslant \frac{D}{\sqrt{t}}
\end{aligned}\tag{5} p ( G t ) min s.t. E G t [ Penalty ( θ t + G t ; { x i , e i , y i } ) ] E G t [ ∥ ∥ G t − ∇ θ L ( θ ) ∥ ∥ 2 ] ⩽ t D ( 5 )
Eq.(5) 构造的约束保证了命题 2.1 中假设 (3) 的成立。因此,应用作为 Eq.(5) 的解的更新 G t G^t G t ,可以令人满意地最小化经验风险 L ( θ ) \mathcal{L}(\theta) L ( θ ) 。同时,作者在这些令人满意的更新中寻求最小化惩罚的更新。由于提出的方法对经验风险使用满意更新,而不是一阶最优更新,遵循 Herbert Simon 对满意的描述,我们称之为满足域泛化 (satisficing Domain Generalization, SDG)。
2.4 理解满足域泛化 (Satisficing Domain Generalization, SDG)
在本节中,作者利用率失真理论的工具,对提出的公式进行了观察。它将函数 R ( D ) R(D) R ( D ) 视为给定常数因子 D D D 的 Eq.(5) 的解。问题 (5) 只是在更新的有限失真约束下最小化惩罚。这种形式与率失真函数非常相似,其中(压缩)速率通过解码中可接受的退化约束进行优化。如果惩罚是真实梯度和 G t G^t G t 之间的互信息,则它将完全等效。因此,我们可以进一步利用率失真理论的工具来更好地理解 Eq.(5)。
图 1. 惩罚-失真函数 R ( D ) R(D) R ( D ) 是 D D D 的非递增和凸函数,描述了一组可行的更新。我们对具有偏差的 SGD 的分析表明, E [ ∥ G t − ∇ R e ( θ t ) ∥ 2 ] ⩽ D t \mathbb{E}\left[\left\|G^t-\nabla \mathcal{R}_e\left(\theta^t\right)\right\|_2\right] \leqslant D^t E [ ∥ G t − ∇ R e ( θ t ) ∥ 2 ] ⩽ D t 描述了最终最小化经验风险的更新集。最小化更新的代价可能远远超出解决 ERM 的范围。我们的公式选择了在最终解决 ERM 问题的同时最小化惩罚的更新。
极端情况 当 D t = D t = 0 D^t=\frac{D}{\sqrt{t}}=0 D t = t D = 0 时,E [ G t ] = ∇ L ( θ t ) \mathbb{E}[G^t]=\nabla\mathcal{L}(\theta^t) E [ G t ] = ∇ L ( θ t ) 。因此,在结束训练时,我们执行 ERM 更新。当失真无界 (D t = ∞ D^t=\infty D t = ∞ ) 时,更新直接使惩罚最小化。此外,我们提出的公式在训练期间平滑地从最小化对 ERM 更新的惩罚的更新开始。
改进的保证 R ( D ) R(D) R ( D ) 是 D D D 的非递增凸函数,故 R ( 0 ) ⩾ R ( D t ) R(0)\geqslant R(D^t) R ( 0 ) ⩾ R ( D t ) 。换句话说,与 ERM 相比,SDG 优化减少了惩罚。
几何图像 满足条件要求意味着平面上的某个对应结构,其中 x x x 轴是失真,y y y 轴是惩罚项。我们在图 1 中可视化了这种几何图形和解释我们的方法。
2.5 解决满足域泛化 (SDG)
作者提出了一种数值算法来解决 Eq.(5)。我们首先提出了一种类似于 BlhutArimoto (BA) 的迭代方法,其复杂性难以处理。我们稍后会做出简化假设来设计易于处理的方法。
作者使用梯度 (G t G^t G t ) 和域 ID (E E E ) 之间的互信息来规范目标。这种正则化是合理的,因为希望梯度携带关于域 ID 的信息很少 (即提取域无关的深层信息)。我们解决的最终数值优化问题是
min p ( G i t ) E G t [ Penalty ( θ t + G t ; { x i , e i , y i } ) ] + γ I ( G t ; E ) s.t. E G t [ ∥ G t − ∇ θ L ( θ ) ∥ 2 ] ⩽ D t (6) \begin{aligned}
\min _{p\left(G_i^t\right)} \quad&\mathbb{E}_{G^t}\left[\text { Penalty }\left(\theta^t+G^t ;\left\{x^i, e^i, y^i\right\}\right)\right]+\gamma \mathcal{I}\left(G^t ; E\right) \\
& \text {s.t.} \quad \mathbb{E}_{G^t}\left[\left\|G^t-\nabla_\theta \mathcal{L}(\theta)\right\|_2\right] \leqslant \frac{D}{\sqrt{t}}
\end{aligned}\tag{6} p ( G i t ) min E G t [ Penalty ( θ t + G t ; { x i , e i , y i } ) ] + γ I ( G t ; E ) s.t. E G t [ ∥ ∥ G t − ∇ θ L ( θ ) ∥ ∥ 2 ] ⩽ t D ( 6 )
信息正则化不仅是可取的,而且至关重要,因为它使得能应用类似于 Blahut-Arimoto 的技术推导。
为了解决 Eq.(6),我们只考虑 G t ∈ { G 1 t , … , G K t } G^t\in\{G_1^t,\ldots,G_K^t\} G t ∈ { G 1 t , … , G K t } ;因此,我们只需要求解点质量 p G i t = p ( G t = G i t ) p_{G^t_i}=p(G^t = G^t_i) p G i t = p ( G t = G i t ) 。我们首先定义了以域 id 为条件的辅助变量 p G i t ∣ e = p ( G t = G i t ∣ E = e ) p_{G^t_i|e}= p(G^t=G^t_i|E=e) p G i t ∣ e = p ( G t = G i t ∣ E = e ) 。通过这种分解,我们可以分别求解每个域条件概率 p G i t ∣ e p_{G^t_i|e} p G i t ∣ e ,然后通过 p G i t = 1 M ∑ e p G i t ∣ e p_{G^t_i}=\frac{1}{M}\sum_e p_{G^t_i|e} p G i t = M 1 ∑ e p G i t ∣ e 将它们边缘化。我们进一步用对域上均匀边界的约束更强的形式 E G t ∣ E = e [ ∥ G t − ∇ θ L ^ e ( θ ) ∥ 2 ] ⩽ D t \mathbb{E}_{G^{t}|E=e}[\lVert G^{t}-\nabla_{\theta}\hat{\mathcal{L}}^e(\theta)\rVert_2]\leqslant \frac{D}{\sqrt{t}} E G t ∣ E = e [∥ G t − ∇ θ L ^ e ( θ ) ∥ 2 ] ⩽ t D 来替换 E G t [ ∥ G t − ∇ θ L ^ e ( θ ) ∥ 2 ] ⩽ D t \mathbb{E}_{G^{t}}[\lVert G^{t}-\nabla_{\theta}\hat{\mathcal{L}}^e(\theta)\rVert_2]\leqslant \frac{D}{\sqrt{t}} E G t [∥ G t − ∇ θ L ^ e ( θ ) ∥ 2 ] ⩽ t D ,并应用与 Blahut 和 Arimoto 类似的推导。
补充推导 推导 Blahut-Arimoto 风格方法求解 Eq.(6)
我们提出了一种 Blahut-Arimoto 风格方法来解决以下问题
min p ( G t ) E G t [ Penalty ( θ t + G t ; { x i , e i , y i } ) ] + γ I ( G t ; E ) s.t. E G t [ ∥ G t − ∇ θ L ^ ( θ ) ∥ 2 ] ⩽ D t (22) \begin{aligned}
\min_{p\left(G^t\right)} \quad&\mathbb{E}_{G^t}\left[\text { Penalty }\left(\theta^t+G^t ;\left\{x^i, e^i, y^i\right\}\right)\right]+\gamma \mathcal{I}\left(G^t ; E\right) \\
& \text {s.t.} \quad \mathbb{E}_{G^t}\left[\left\|G^t-\nabla_\theta \hat{\mathcal{L}}(\theta)\right\|_2\right] \leqslant \frac{D}{\sqrt{t}}
\end{aligned}\tag{22} p ( G t ) min E G t [ Penalty ( θ t + G t ; { x i , e i , y i } ) ] + γ I ( G t ; E ) s.t. E G t [ ∥ ∥ G t − ∇ θ L ^ ( θ ) ∥ ∥ 2 ] ⩽ t D ( 22 )
在使用互信息的情况下,Eq.(22) 将转化为如下形式,其中 D K L \mathcal{D}_{KL} D K L 表示 KL-散度。
min p ( G t ) E G t [ Penalty ( θ t + G t ; { x i , e i , y i } ) ] + γ E E [ D K L ( p ( G ∣ E ) ∣ ∣ p ( G ) ) ] s.t. E G t [ ∥ G t − ∇ θ L ^ ( θ ) ∥ 2 ] ⩽ D t (23) \begin{aligned}
\min_{p\left(G^t\right)} \quad&\mathbb{E}_{G^t}\left[\text { Penalty }\left(\theta^t+G^t ;\left\{x^i, e^i, y^i\right\}\right)\right]+\\
&\gamma \mathbb{E}_E[\mathcal{D}_{KL}\left(p(G|E)||p(G)\right)] \\
& \text {s.t.} \quad \mathbb{E}_{G^t}\left[\left\|G^t-\nabla_\theta \hat{\mathcal{L}}(\theta)\right\|_2\right] \leqslant \frac{D}{\sqrt{t}}
\end{aligned}\tag{23} p ( G t ) min E G t [ Penalty ( θ t + G t ; { x i , e i , y i } ) ] + γ E E [ D K L ( p ( G ∣ E ) ∣∣ p ( G ) ) ] s.t. E G t [ ∥ ∥ G t − ∇ θ L ^ ( θ ) ∥ ∥ 2 ] ⩽ t D ( 23 )
Blahut-Arimoto 方法中使用的关键技术是将 p ( G t ) p(G^t) p ( G t ) 分解为 p ( G t ) = ∑ e p ( G t ∣ E = e ) p ( E = e ) p(G^t)=\sum_ep(G^t|E=e)p(E=e) p ( G t ) = ∑ e p ( G t ∣ E = e ) p ( E = e ) 。令 d ( G t , e ) = ∥ G t − ∇ L e ( θ ) ∥ 2 d(G^t,e)=\lVert G^t−\nabla \mathcal{L}^e(\theta)\rVert_2 d ( G t , e ) = ∥ G t − ∇ L e ( θ ) ∥ 2 ,其中 L e \mathcal{L}^e L e 为域 e e e 的损失,∣ E ∣ |E| ∣ E ∣ 为域的数量。此外,我们假设所有域的重要性都相同 p ( E = e ) = 1 ∣ E ∣ p(E=e)=\frac{1}{|E|} p ( E = e ) = ∣ E ∣ 1 。基于问题是离散的事实,G t ∈ { G 1 , … , G K } Gt\in\{G_1,\ldots,G^K\} Gt ∈ { G 1 , … , G K } ,Eq.(22) 中的问题可以转化为
min p ( G k ∣ E = e ) ∑ e ∑ k = 1 K p ( G k ∣ E = e ) Penalty ( θ t + G t ; { x i , e i , y i } ) + γ ∑ e ∑ k = 1 K p ( G k ∣ E = e ) log p ( G k ∣ E = e ) p ( G k ) s.t. ∑ e ∑ k = 1 K p ( G k ∣ E = e ) d ( G k , e ) ⩽ D ∣ E ∣ t \begin{aligned}
\min_{p(G_k|E=e)}\quad&\sum_e\sum_{k=1}^Kp(G_k|E=e)\text{Penalty}\left(\theta^t+G^t ;\left\{x^i, e^i, y^i\right\}\right)+\\
\gamma&\sum_e\sum_{k=1}^Kp(G_k|E=e)\log\frac{p(G_k|E=e)}{p(G_k)}\\
\text{s.t.}\quad&\sum_e\sum_{k=1}^Kp(G_k|E=e)d(G_k,e)\leqslant\frac{D|E|}{\sqrt{t}}\\
\end{aligned} p ( G k ∣ E = e ) min γ s.t. e ∑ k = 1 ∑ K p ( G k ∣ E = e ) Penalty ( θ t + G t ; { x i , e i , y i } ) + e ∑ k = 1 ∑ K p ( G k ∣ E = e ) log p ( G k ) p ( G k ∣ E = e ) e ∑ k = 1 ∑ K p ( G k ∣ E = e ) d ( G k , e ) ⩽ t D ∣ E ∣
我们进一步用对域上均匀边界的约束更强的形式 E G t ∣ E = e [ ∥ G t − ∇ θ L ^ e ( θ ) ∥ 2 ] ⩽ D t \mathbb{E}_{G^{t}|E=e}[\lVert G^{t}-\nabla_{\theta}\hat{\mathcal{L}}^e(\theta)\rVert_2]\leqslant \frac{D}{\sqrt{t}} E G t ∣ E = e [∥ G t − ∇ θ L ^ e ( θ ) ∥ 2 ] ⩽ t D 来替换 E G t [ ∥ G t − ∇ θ L ^ e ( θ ) ∥ 2 ] ⩽ D t \mathbb{E}_{G^{t}}[\lVert G^{t}-\nabla_{\theta}\hat{\mathcal{L}}^e(\theta)\rVert_2]\leqslant \frac{D}{\sqrt{t}} E G t [∥ G t − ∇ θ L ^ e ( θ ) ∥ 2 ] ⩽ t D ,然后进一步用拉格朗日乘子 β \beta β 计算这个优化问题的拉格朗日量为
min p ( G k ∣ E = e ) max β ≥ 0 ∑ e ∑ k = 1 n p ( G k ∣ E = e ) [ Penalty ( θ t + G k ; { x i , e i , y i } ) + … + γ log p ( G k ∣ E = e ) p ( G k ) + β d ( G k , e ) ] − β D ∣ E ∣ t (25) \begin{aligned}
\min _{p\left(G_k \mid E=e\right)} \max _{\beta \geq 0}& \sum_e \sum_{k=1}^n p\left(G_k \mid E=e\right) {\left[\text { Penalty }\left(\theta^t+G_k ;\left\{x^i, e^i, y^i\right\}\right)+\right.} \\
& \left.\ldots+\gamma \frac{\log p\left(G_k \mid E=e\right)}{p\left(G_k\right)}+\beta d\left(G_k, e\right)\right]-\beta \frac{D|E|}{\sqrt{t}}
\end{aligned}\tag{25} p ( G k ∣ E = e ) min β ≥ 0 max e ∑ k = 1 ∑ n p ( G k ∣ E = e ) [ Penalty ( θ t + G k ; { x i , e i , y i } ) + … + γ p ( G k ) log p ( G k ∣ E = e ) + β d ( G k , e ) ] − β t D ∣ E ∣ ( 25 )
我们对 p ( G k ∣ E = e ) p\left(G_k \mid E=e\right) p ( G k ∣ E = e ) 取导数,并将其等同于 0。忽略常数因子,可以得到
p ( G k ∣ E = e ) ∼ p ( G k ) exp [ − 1 γ ( Penalty ( θ t + G k ; { x i , e i , y i } ) + β d ( G t , e ) ] (26) \begin{gathered}
p\left(G_k \mid E=e\right)\sim\\
p(G_k)\exp\left[-\frac{1}{\gamma}\left(\text{Penalty}(\theta^t+G_k;\{x^i,e^i,y^i\}\right)+\beta d(G^t,e)\right]
\end{gathered}\tag{26} p ( G k ∣ E = e ) ∼ p ( G k ) exp [ − γ 1 ( Penalty ( θ t + G k ; { x i , e i , y i } ) + β d ( G t , e ) ] ( 26 )
因为概率总和为 1,我们可以分别处理归一化因子。从初始化 p ∘ ( G k ∣ E = e ) p^{\circ}(G_k|E = e) p ∘ ( G k ∣ E = e ) 开始,除非存在先验信息,否则默认每个域是均匀分布的。然后,以下迭代求解 Eq.(22),
p ^ l + 1 ( G k ∣ E = e ) = p l ( G k ) exp [ − 1 γ ( Penalty ( θ t + G k ; { x i , e i , y i } ) + β d ( G t , e ) ) ] p l + 1 ( G k ∣ E = e ) = p ^ l + 1 ( G k ∣ E = e ) ∑ k ^ p ^ l + 1 ( G k ^ ∣ E = e ) p l + 1 ( G k ) = 1 ∣ E ∣ ∑ E = e p l + 1 ( G k ∣ E = e ) \begin{aligned}
\hat{p}^{l+1}(G_k|E=e)&=p^l(G_k)\exp\left[-\frac{1}{\gamma}\left(\text{Penalty}(\theta^t+G_k;\{x^i,e^i,y^i\})+\beta d(G^t,e)\right)\right]\\
p^{l+1}(G_k|E=e)&=\frac{\hat{p}^{l+1}(G_k|E=e)}{\sum_{\hat{k}}\hat{p}^{l+1}(G_{\hat{k}}|E=e)}\\
p^{l+1}(G_k)&=\frac{1}{|E|}\sum_{E=e}p^{l+1}(G_k|E=e)
\end{aligned} p ^ l + 1 ( G k ∣ E = e ) p l + 1 ( G k ∣ E = e ) p l + 1 ( G k ) = p l ( G k ) exp [ − γ 1 ( Penalty ( θ t + G k ; { x i , e i , y i }) + β d ( G t , e ) ) ] = ∑ k ^ p ^ l + 1 ( G k ^ ∣ E = e ) p ^ l + 1 ( G k ∣ E = e ) = ∣ E ∣ 1 E = e ∑ p l + 1 ( G k ∣ E = e )
这里 β t \beta^t β t 是一个与 D t D^t D t 成反比的参数,而其所得到的迭代形式即为 Eq.(6) 的一个迭代解。将这些更新应用于梯度空间是棘手的,因为 G t G^t G t 的空间是高维的。作者提出了两种简化来使 BA 的应用易于处理 :
独立地求解每个参数,以及
我们使用离散集 ( G t ) p ∈ { ( G 1 t ) p = ( ∇ L ) p , ( G 2 t ) p = − ( ∇ L ) p } (G^t)_p\in\{(G_1^t)_p=(\nabla\mathcal{L})_p,(G_2^t)_p=-(\nabla\mathcal{L})_p\} ( G t ) p ∈ {( G 1 t ) p = ( ∇ L ) p , ( G 2 t ) p = − ( ∇ L ) p } 来解决 BA 问题 ( G t ) p (G^t)_p ( G t ) p ,即第 p 个参数的更新。
为了了解所提出的简化,我们考虑了它们对架构其余部分的影响。我们将得到的估计梯度馈送到一阶数值优化器。以下是更多细节。
细节补充 应用 Blahut-Arimto Style Methods 进行基于惩罚项的深度域泛化
在本节中,作者探讨了一种考虑任意惩罚函数的 CORAL+SDG 和 VREX+SDG 方法。为了将其应用于上述的推导,我们需要为任意 G k G^k G k 计算 Penaty ( θ t + G k ; { x i , e i , y i } ) \text{Penaty}(\theta^t+G_k;\{x^i,e^i,y^i\}) Penaty ( θ t + G k ; { x i , e i , y i }) 。我们认为这个惩罚的一阶近似为
Penaty ( θ t + G k ; { x i , e i , y i } ) ≈ Penaty ( θ t ; { x i , e i , y i } ) + ∇ θ Penaty ( θ t ; { x i , e i , y i } ) T G k \begin{gathered}
\text{Penaty}(\theta^t+G_k;\{x^i,e^i,y^i\})\approx\\
\text{Penaty}(\theta^t;\{x^i,e^i,y^i\})+\nabla_{\theta}\text{Penaty}(\theta^t;\{x^i,e^i,y^i\})^TG_k
\end{gathered} Penaty ( θ t + G k ; { x i , e i , y i }) ≈ Penaty ( θ t ; { x i , e i , y i }) + ∇ θ Penaty ( θ t ; { x i , e i , y i } ) T G k
这种近似使得不需要为每个 G k G_k G k 执行额外的惩罚项计算,将其应用于上述的迭代,即为
p ^ l + 1 ( G k ∣ E = e ) = p l ( G k ) exp [ − 1 γ ( β d ( G t , e ) + ∇ θ Penaty ( θ t ) T G k ) ] p l + 1 ( G k ∣ E = e ) = p ^ l + 1 ( G k ∣ E = e ) ∑ k ^ p ^ l + 1 ( G k ^ ∣ E = e ) p l + 1 ( G k ) = 1 ∣ E ∣ ∑ E = e p l + 1 ( G k ∣ E = e ) \begin{aligned}
\hat{p}^{l+1}(G_k|E=e)&=p^l(G_k)\exp\left[-\frac{1}{\gamma}\left(\beta d(G^t,e)+\nabla_{\theta}\text{Penaty}(\theta^t)^TG_k\right)\right]\\
p^{l+1}(G_k|E=e)&=\frac{\hat{p}^{l+1}(G_k|E=e)}{\sum_{\hat{k}}\hat{p}^{l+1}(G_{\hat{k}}|E=e)}\\
p^{l+1}(G_k)&=\frac{1}{|E|}\sum_{E=e}p^{l+1}(G_k|E=e)
\end{aligned} p ^ l + 1 ( G k ∣ E = e ) p l + 1 ( G k ∣ E = e ) p l + 1 ( G k ) = p l ( G k ) exp [ − γ 1 ( β d ( G t , e ) + ∇ θ Penaty ( θ t ) T G k ) ] = ∑ k ^ p ^ l + 1 ( G k ^ ∣ E = e ) p ^ l + 1 ( G k ∣ E = e ) = ∣ E ∣ 1 E = e ∑ p l + 1 ( G k ∣ E = e )
从而得到算法 1
细节补充 将 Blahut-Arimoto Style 方法应用于 Fish
Fish 不使用直接惩罚函数;因此,应用 SDG 需要进行一些修改。为了将 SDG 应用于 Fish,我们简单地使用原始工作中提供的近似值。相关的工作表明,当 SGD 的内部更新应用于几个域的 θ t \theta^t θ t 时,得 近似于惩罚的梯度。因此,我们利用这一事实来执行迭代为
p ^ l + 1 ( G k ∣ E = e ) = p l ( G k ) exp [ − 1 γ ( β d ( G t , e ) + ( θ ~ t − θ ) T G k ) ] p l + 1 ( G k ∣ E = e ) = p ^ l + 1 ( G k ∣ E = e ) ∑ k ^ p ^ l + 1 ( G k ^ ∣ E = e ) p l + 1 ( G k ) = 1 ∣ E ∣ ∑ E = e p l + 1 ( G k ∣ E = e ) \begin{aligned}
\hat{p}^{l+1}(G_k|E=e)&=p^l(G_k)\exp\left[-\frac{1}{\gamma}\left(\beta d(G^t,e)+(\tilde{\theta}^t-\theta)^TG_k\right)\right]\\
p^{l+1}(G_k|E=e)&=\frac{\hat{p}^{l+1}(G_k|E=e)}{\sum_{\hat{k}}\hat{p}^{l+1}(G_{\hat{k}}|E=e)}\\
p^{l+1}(G_k)&=\frac{1}{|E|}\sum_{E=e}p^{l+1}(G_k|E=e)
\end{aligned} p ^ l + 1 ( G k ∣ E = e ) p l + 1 ( G k ∣ E = e ) p l + 1 ( G k ) = p l ( G k ) exp [ − γ 1 ( β d ( G t , e ) + ( θ ~ t − θ ) T G k ) ] = ∑ k ^ p ^ l + 1 ( G k ^ ∣ E = e ) p ^ l + 1 ( G k ∣ E = e ) = ∣ E ∣ 1 E = e ∑ p l + 1 ( G k ∣ E = e )
3.实验
暂时略。
4.总结
总的来说,作者的这种想法比起传统的 Lagrange 乘子法还是让我感到耳目一新的,但是可能受制于较为繁琐的迭代过程,并没有太多的引用和使用。
参考资料 (References)
Sener, O. and Koltun, V., 2022. Domain Generalization without Excess Empirical Risk. Advances in Neural Information Processing Systems, 35, pp.13380-13391.
Wang, J., Lan, C., Liu, C., Ouyang, Y., Qin, T., Lu, W., Chen, Y., Zeng, W. and Yu, P., 2022. Generalizing to unseen domains: A survey on domain generalization. IEEE Transactions on Knowledge and Data Engineering.