后合成指数移动平均(post-hoc synthesized EMA)

1,193 阅读2分钟

总览

指数移动平均 EMA(Exponential Moving Average)是一种网络训练技巧。通过取一段时间的权重的平均值,能提高模型的泛化能力。

一言蔽之:模型训练完后还能获得 “免费” 的提升。

后合成指数移动平均(post-hoc synthesized EMA)是在论文 Analyzing and Improving the Training Dynamics of Diffusion Models 中提出的 EMA 改进方法。主要改进在于,比起原方法是在训练中完成平均,post-hoc EMA 则会先保存一系列关键节点的 EMA 权重,然后在训练结束后找一个最佳的的超参进行权重平均,进一步提升模型能力。

θ^β\hat{\theta}_\beta 的更新

使用 EMA 技巧训练模型时,会维护一份模型权重 θ\theta 的副本 θ^β\hat{\theta}_\beta。传统方法下,每训练一个 step,会用以下公式更新 θ^β\hat{\theta}_\beta

θ^β(t)=βθ^β(t1)+(1β)θ(t)\hat{\theta}_\beta(t)=\beta\hat{\theta}_\beta (t-1)+(1-\beta)\theta(t)

tt 代表当前训练 step。β\beta 是常量超参,通常非常接近 1。

Post-hoc EMA 没有使用这样的指数衰减策略,而是修改为了幂函数:

θ^γ(t)=0tτγθ(τ)dτ0tτγdτ=γ+1tγ+10tτγθ(τ)dτ\hat{\theta}_\gamma(t)=\frac{\int^t_0 \tau^\gamma\theta(\tau)\mathrm{d}\tau}{\int^t_0 \tau^\gamma\mathrm{d}\tau}=\frac{\gamma+1}{t^{\gamma+1}}\int^t_0\tau^\gamma\theta(\tau)\mathrm{d}\tau

γ\gamma 是控制尖锐程度的超参。通常 θt=0\theta_{t=0} 不采取随机初始化而是直接赋值为 0。

可见,γ\gamma 为 0 时,公式的含义就是直接取一段时间权重的均值。γ\gamma 越大,θ^γ(t)\hat{\theta}_\gamma(t) 就越受临近 ttθ(t)\theta(t) 的影响。

实际计算时使用递归形式,节省内存:

θ^β(t)=βγ(t)θ^γ(t1)+(1βγ(t))θ(t)where βγ(t)=(11/t)γ+1\begin{aligned} &\hat{\theta}_\beta(t)=\beta_\gamma(t)\hat{\theta}_\gamma(t-1)+(1-\beta_\gamma(t))\theta(t)\\ \mathrm{where\ }& \beta_\gamma(t)=(1-1/t)^{\gamma+1} \end{aligned}

与本节的第一个公式很像,唯一区别是 βγ(t)\beta_\gamma(t) 会随着 tt 增加而减小,而不是一个恒定值。

γ\gamma 可能不是很值观不方便配置,于是论文提出了 σrel\sigma_{\text{rel}} 超参,含义是 “相邻峰值的宽度相对于整个训练时长的占比”。实际训练时,可以通过 σrel\sigma_{\text{rel}} 确定 γ\gamma

σrel=(γ+1)1/2(γ+2)1(γ+3)1/2\sigma_{\text{rel}}=(\gamma+1)^{1/2}(\gamma+2)^{-1}(\gamma+3)^{-1/2}

即可以用 σrel\sigma_{\text{rel}} 定义以训练时长为标尺的 EMA 长度,而不是像原始 EMA 公式中的 β\beta 那样会随着 step 数量变化而剧烈影响训练效果。这是论文对原始 EMA 的另一个改进贡献。

顺带一提,论文在实际实验中发现,σrel\sigma_{\text{rel}} 超参控制下的最佳 EMA 长度其实仍会随着 step 数量增加而缓慢地变长。

从快照重建 EMA

Post-hoc EMA 的核心是滞后平均。如何合理地从权重快照重建出各种超参的 EMA 结果尤为重要。换句话说,需要有一种方法,能在训练完成后处理出任意 γ\gamma 值(或者是等效的 σrel\sigma_{\text{rel}} 值)对应的平滑权重。

用什么策略创建快照?可以选取两个 σrel\sigma_{\text{rel}} 值,隔一段 step 步数保存一次快照。

如何合并快照?论文在公式推导后给出了一个代码示例,用于获得快照的权重:

def p_dot_p(t_a, gamma_a, t_b, gamma_b):
    t_ratio = t_a / t_b
    t_exp = torch.where(t_a < t_b , gamma_b , -gamma_a)
    t_max = torch.maximum(t_a , t_b)
    num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
    den = (gamma_a + gamma_b + 1) * t_max
    return num / den

def solve_weights(t_i, gamma_i, t_r, gamma_r):
    rv = lambda x: x.double().reshape(-1, 1)
    cv = lambda x: x.double().reshape(1, -1)
    A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))
    b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))
    return np.linalg.solve(A, b)

我自己实测下来重建效果可以说是意外的优秀。即使总共只使用 40 个快照(σrel\sigma_{\text{rel}} 取 0.05 和 0.2),大致 σrel\sigma_{\text{rel}} 在 0.04 到 0.22 范围内重建出的 EMA 权重几近完美。

该怎么写代码?

有一个开箱即用的 EMA 包,通过 pip install ema-pytorch 即可使用。

import torch
from ema_pytorch import PostHocEMA

# your neural network as a pytorch module

model = ...

emas = PostHocEMA(
    model,
    sigma_rels = (0.05, 0.3),
    update_every = 1,   # 每调用 1 次 emas.update() 就更新一次 ema 权重
    checkpoint_every_num_steps = 50,
    checkpoint_folder = './post-hoc-ema-checkpoints'  # 保存快照的路径
)

model.train()

for epoch in range(300):
    for input, target in loader:
        ...
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    
        emas.update()  # 在此调用 emas.update()

# 重建
synthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15)
# 直接调用 synthesized_ema 进行推理
synthesized_ema_output = synthesized_ema(data)

有些细节要注意,

  • PostHocEMA 不会主动清空 checkpoint_folder 里的内容。目录里原有的 pt 文件会对程序造成干扰
  • 使用 emas.model 获取原模型,使用 emas.ema_model 获取权重重建后的模型
  • 建议保存模型时保存 PostHocEMA。这样能保留 ema step 等信息

实验

为了探究重建效果,使用以下代码进行实验。

σrel\sigma_{\text{rel}} 取 0.05 和 0.2,总步数 1000,两种 σrel\sigma_{\text{rel}} 下各取 20 个快照。重建 σrel\sigma_{\text{rel}} 为 0.15 时的 ema 权重。

import torch
from ema_pytorch import PostHocEMA
import plotly.graph_objects as go

net = torch.nn.Linear(1, 1000, bias=False)

emas = PostHocEMA(
    net,
    sigma_rels=(
        0.05,
        0.2,
    ),
    update_every=1,
    checkpoint_every_num_steps=50,
    checkpoint_folder=r'Z:\post-hoc-ema-checkpoints',
)

net.train()

for i in range(1000):
    with torch.no_grad():
        net.weight.zero_()
        channel_index = i % 1000
        net.weight[channel_index, 0] = 1.

    emas.update()

synthesized_ema = emas.synthesize_ema_model(sigma_rel=0.15)
ema_weights = synthesized_ema.ema_model.weight.detach().numpy().flatten()

fig = go.Figure(data=go.Scatter(x=list(range(1000)), y=ema_weights, mode='lines+markers'))
fig.update_layout(
    title='Synthesized EMA Model Weights',
    xaxis_title='Channel Index',
    yaxis_title='Weight Value',
    template='plotly_white',
)
fig.show()

ema 重建效果

这个曲线相当够用了。

思考

EMA 能在训练完成后提升模型效果。post-hoc EMA 更进一步,允许在训练完成后找到最合适 σrel\sigma_{\text{rel}},进一步提升模型能力。可以说是相当值得一用的 trick。

不过显而易见的,post-hoc EMA 要求维护不止一个模型的参数,对显存要求陡然上升。即使单独让 EMA 权重放在 CPU 侧,设备通信过程会大大拖慢训练进程。即使可以设置 update_every 参数减少通信次数,但势必会影响重建效果。另一个缺点就是快照占用磁盘空间。40 可不是小数目,更别说论文中提及的至少 160 个快照。

论文提出 post-hoc EMA 是针对扩散模型的,而扩散模型众所周知大块头一个……普通人很难很好地用上 post-hoc EMA 训练大模型吧。不过训练小玩具时,好歹有了个似乎很有提分希望的可选技巧。

参考来源