大模型训练时,权重会因批次噪声、梯度波动、动态批不断抖动,直接导致:
- 验证精度忽高忽低;
- 模型容易过拟合;
- 最终推理效果不稳定。
权重 EMA 是所有工业级大模型必开的单点稳定优化,没有之一。
它不修改训练流程、不增加计算量、不影响前向,只在后台对权重做指数滑动平均,用最平滑、最泛化的权重做验证与推理,精度稳、泛化强、收敛曲线干净。
核心原理(一句话)
维护一份影子权重(shadow weights),每一步按衰减系数 shadow=decay×shadow+(1−decay)×current 做指数平滑,推理 / 验证时加载这份平滑权重,彻底抹平训练抖动。
MindSpore 原生极简代码
import mindspore as ms
import mindspore.nn as nn
# 工业标准衰减率(LLaMA/Qwen 通用)
EMA_DECAY = 0.999
class WeightEMA(nn.Cell):
"""大模型权重EMA,纯原生API,静态图完美支持"""
def __init__(self, network, decay=EMA_DECAY):
super().__init__()
self.network = network
self.decay = decay
# 初始化影子权重:和模型参数完全同形状
self.shadow_weights = ms.Parameter(
[w.clone() for w in network.trainable_params()],
requires_grad=False
)
def update_ema(self):
"""核心:一步完成EMA权重更新"""
for i, weight in enumerate(self.network.trainable_params()):
self.shadow_weights[i] = self.decay * self.shadow_weights[i] + (1 - self.decay) * weight
def load_ema_weights(self):
"""推理/验证时加载EMA平滑权重"""
for i, weight in enumerate(self.network.trainable_params()):
weight.set_data(self.shadow_weights[i])
# 集成到训练(只加一行,零侵入)
class EMATrainOneStepCell(nn.TrainOneStepCell):
def __init__(self, network, optimizer):
super().__init__(network, optimizer)
self.ema = WeightEMA(network)
def construct(self, *inputs):
loss = super().construct(*inputs)
# 每步自动更新EMA权重
self.ema.update_ema()
return loss