总览
指数移动平均 EMA(Exponential Moving Average)是一种网络训练技巧。通过取一段时间的权重的平均值,能提高模型的泛化能力。
一言蔽之:模型训练完后还能获得 “免费” 的提升。
后合成指数移动平均(post-hoc synthesized EMA)是在论文 Analyzing and Improving the Training Dynamics of Diffusion Models 中提出的 EMA 改进方法。主要改进在于,比起原方法是在训练中完成平均,post-hoc EMA 则会先保存一系列关键节点的 EMA 权重,然后在训练结束后找一个最佳的的超参进行权重平均,进一步提升模型能力。
的更新
使用 EMA 技巧训练模型时,会维护一份模型权重 的副本 。传统方法下,每训练一个 step,会用以下公式更新 :
代表当前训练 step。 是常量超参,通常非常接近 1。
Post-hoc EMA 没有使用这样的指数衰减策略,而是修改为了幂函数:
是控制尖锐程度的超参。通常 不采取随机初始化而是直接赋值为 0。
可见, 为 0 时,公式的含义就是直接取一段时间权重的均值。 越大, 就越受临近 的 的影响。
实际计算时使用递归形式,节省内存:
与本节的第一个公式很像,唯一区别是 会随着 增加而减小,而不是一个恒定值。
可能不是很值观不方便配置,于是论文提出了 超参,含义是 “相邻峰值的宽度相对于整个训练时长的占比”。实际训练时,可以通过 确定 :
即可以用 定义以训练时长为标尺的 EMA 长度,而不是像原始 EMA 公式中的 那样会随着 step 数量变化而剧烈影响训练效果。这是论文对原始 EMA 的另一个改进贡献。
顺带一提,论文在实际实验中发现, 超参控制下的最佳 EMA 长度其实仍会随着 step 数量增加而缓慢地变长。
从快照重建 EMA
Post-hoc EMA 的核心是滞后平均。如何合理地从权重快照重建出各种超参的 EMA 结果尤为重要。换句话说,需要有一种方法,能在训练完成后处理出任意 值(或者是等效的 值)对应的平滑权重。
用什么策略创建快照?可以选取两个 值,隔一段 step 步数保存一次快照。
如何合并快照?论文在公式推导后给出了一个代码示例,用于获得快照的权重:
def p_dot_p(t_a, gamma_a, t_b, gamma_b):
t_ratio = t_a / t_b
t_exp = torch.where(t_a < t_b , gamma_b , -gamma_a)
t_max = torch.maximum(t_a , t_b)
num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
den = (gamma_a + gamma_b + 1) * t_max
return num / den
def solve_weights(t_i, gamma_i, t_r, gamma_r):
rv = lambda x: x.double().reshape(-1, 1)
cv = lambda x: x.double().reshape(1, -1)
A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))
b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))
return np.linalg.solve(A, b)
我自己实测下来重建效果可以说是意外的优秀。即使总共只使用 40 个快照( 取 0.05 和 0.2),大致 在 0.04 到 0.22 范围内重建出的 EMA 权重几近完美。
该怎么写代码?
有一个开箱即用的 EMA 包,通过 pip install ema-pytorch 即可使用。
import torch
from ema_pytorch import PostHocEMA
# your neural network as a pytorch module
model = ...
emas = PostHocEMA(
model,
sigma_rels = (0.05, 0.3),
update_every = 1, # 每调用 1 次 emas.update() 就更新一次 ema 权重
checkpoint_every_num_steps = 50,
checkpoint_folder = './post-hoc-ema-checkpoints' # 保存快照的路径
)
model.train()
for epoch in range(300):
for input, target in loader:
...
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()
emas.update() # 在此调用 emas.update()
# 重建
synthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15)
# 直接调用 synthesized_ema 进行推理
synthesized_ema_output = synthesized_ema(data)
有些细节要注意,
PostHocEMA不会主动清空checkpoint_folder里的内容。目录里原有的 pt 文件会对程序造成干扰- 使用
emas.model获取原模型,使用emas.ema_model获取权重重建后的模型 - 建议保存模型时保存
PostHocEMA。这样能保留 ema step 等信息
实验
为了探究重建效果,使用以下代码进行实验。
取 0.05 和 0.2,总步数 1000,两种 下各取 20 个快照。重建 为 0.15 时的 ema 权重。
import torch
from ema_pytorch import PostHocEMA
import plotly.graph_objects as go
net = torch.nn.Linear(1, 1000, bias=False)
emas = PostHocEMA(
net,
sigma_rels=(
0.05,
0.2,
),
update_every=1,
checkpoint_every_num_steps=50,
checkpoint_folder=r'Z:\post-hoc-ema-checkpoints',
)
net.train()
for i in range(1000):
with torch.no_grad():
net.weight.zero_()
channel_index = i % 1000
net.weight[channel_index, 0] = 1.
emas.update()
synthesized_ema = emas.synthesize_ema_model(sigma_rel=0.15)
ema_weights = synthesized_ema.ema_model.weight.detach().numpy().flatten()
fig = go.Figure(data=go.Scatter(x=list(range(1000)), y=ema_weights, mode='lines+markers'))
fig.update_layout(
title='Synthesized EMA Model Weights',
xaxis_title='Channel Index',
yaxis_title='Weight Value',
template='plotly_white',
)
fig.show()
这个曲线相当够用了。
思考
EMA 能在训练完成后提升模型效果。post-hoc EMA 更进一步,允许在训练完成后找到最合适 ,进一步提升模型能力。可以说是相当值得一用的 trick。
不过显而易见的,post-hoc EMA 要求维护不止一个模型的参数,对显存要求陡然上升。即使单独让 EMA 权重放在 CPU 侧,设备通信过程会大大拖慢训练进程。即使可以设置 update_every 参数减少通信次数,但势必会影响重建效果。另一个缺点就是快照占用磁盘空间。40 可不是小数目,更别说论文中提及的至少 160 个快照。
论文提出 post-hoc EMA 是针对扩散模型的,而扩散模型众所周知大块头一个……普通人很难很好地用上 post-hoc EMA 训练大模型吧。不过训练小玩具时,好歹有了个似乎很有提分希望的可选技巧。
参考来源
- github.com/lucidrains/…
- arxiv.org/abs/2312.02…
- Miika Aittala,“Rethinking How to Train Diffusion Models”,developer.nvidia.com/blog/rethin…