MindSpore 原生权重 EMA(指数移动平均)

6 阅读1分钟

大模型训练时,权重会因批次噪声、梯度波动、动态批不断抖动,直接导致:

  • 验证精度忽高忽低;
  • 模型容易过拟合;
  • 最终推理效果不稳定。

权重 EMA 是所有工业级大模型必开的单点稳定优化,没有之一。

它不修改训练流程、不增加计算量、不影响前向,只在后台对权重做指数滑动平均,用最平滑、最泛化的权重做验证与推理,精度稳、泛化强、收敛曲线干净。

核心原理(一句话)

维护一份影子权重(shadow weights),每一步按衰减系数 shadow=decay×shadow+(1−decay)×current 做指数平滑,推理 / 验证时加载这份平滑权重,彻底抹平训练抖动。

MindSpore 原生极简代码

import mindspore as ms
import mindspore.nn as nn

# 工业标准衰减率(LLaMA/Qwen 通用)
EMA_DECAY = 0.999

class WeightEMA(nn.Cell):
    """大模型权重EMA,纯原生API,静态图完美支持"""
    def __init__(self, network, decay=EMA_DECAY):
        super().__init__()
        self.network = network
        self.decay = decay
        # 初始化影子权重:和模型参数完全同形状
        self.shadow_weights = ms.Parameter(
            [w.clone() for w in network.trainable_params()],
            requires_grad=False
        )

    def update_ema(self):
        """核心:一步完成EMA权重更新"""
        for i, weight in enumerate(self.network.trainable_params()):
            self.shadow_weights[i] = self.decay * self.shadow_weights[i] + (1 - self.decay) * weight

    def load_ema_weights(self):
        """推理/验证时加载EMA平滑权重"""
        for i, weight in enumerate(self.network.trainable_params()):
            weight.set_data(self.shadow_weights[i])

# 集成到训练(只加一行,零侵入)
class EMATrainOneStepCell(nn.TrainOneStepCell):
    def __init__(self, network, optimizer):
        super().__init__(network, optimizer)
        self.ema = WeightEMA(network)

    def construct(self, *inputs):
        loss = super().construct(*inputs)
        # 每步自动更新EMA权重
        self.ema.update_ema()
        return loss