优化器与指数加权平均(EMA)学习笔记

127 阅读2分钟

基础概念

  • epoch(训练轮次) :完整遍历一次训练集。
  • linear(全连接层) :PyTorch 中 nn.Linear(in_features, out_features),执行线性变换 y=xWT+by=xWT+b
  • Weight(权重参数) :线性层中的可训练参数 WWW,通过梯度下降优化
  • EMA(指数加权平均) :用历史值的加权平均来平滑曲线或参数更新,减少抖动。 image.png
  • 鞍点:一阶导数为 0,但既不是局部极大值,也不是局部极小值。

1. 参数初始化

  • 工业界经验:均匀分布优于正态分布
  • 不要用 全 0 或全 1 初始化,会导致梯度传播失效或对称性问题。

2. 指数加权平均(EMA)使用场景

  • 训练过程可视化:平滑 loss/accuracy 曲线。
  • 大模型训练:对全量参数做 EMA(如 θ_t = βθ_{t-1} + (1-β)θ_current),降低抖动。

例子:

  • 第 1 个 epoch 得到参数 w1w_1
  • 第 2 个 epoch 得到 w2w_2,但保留部分 w1w_1 融合进来
  • 第 3 个 epoch 得到 w3w_3,融合 w2w_2w1w_1 ……
    (相当于滑动平均,更新更平稳)

3. 优化方法

  • Momentum:在梯度更新时引入惯性,能帮助越过鞍点。梯度方向更稳(方向优化)
  • AdaGrad:为每个参数引入自适应学习率(梯度平方累计)。
  • RMSProp:改进 AdaGrad,对梯度平方做指数加权平均,避免学习率过快下降。每个参数步长合适(学习率优化)
  • Adam:结合 Momentum(梯度的一阶动量) + RMSProp(二阶动量)

4. 参数更新公式

一般形式:

θ=θ−η⋅g

  • θ\theta:参数
  • η\eta:学习率
  • gg:梯度

不同优化器的区别在于 如何调整 ggη\eta(学习率)

PyTorch 示例代码

from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


# 参数初始化示例
def init_weights_demo():
    linear = nn.Linear(5, 3)
    torch.manual_seed(0)
    nn.init.normal_(linear.weight, mean=0, std=1)  # 正态初始化
    print("Initialized Weights:\n", linear.weight.data)


# 生成随机噪声并绘制曲线
def noise_demo():
    torch.manual_seed(0)
    temp = torch.randn(30) * 10  # ~ N(0,10^2)
    days = torch.arange(1, 31)
    plt.plot(days, temp, color='r', label="Noise Curve")
    plt.scatter(days, temp)
    plt.title("Random Noise")
    plt.show()


# 指数加权平均示例
def ema_demo(beta):
    torch.manual_seed(0)
    raw_data = torch.randn(30) * 10
    ema = []

    for idx, value in enumerate(raw_data, start=1):
        if idx == 1:
            ema.append(value)
        else:
            new_val = beta * ema[-1] + (1 - beta) * value
            ema.append(new_val)

    days = torch.arange(1, 31)
    plt.plot(days, ema, color='r', label=f"EMA (β={beta})")
    plt.scatter(days, raw_data, label="Raw Data")
    plt.title(f"EMA with beta={beta}")
    plt.legend()
    plt.show()


if __name__ == "__main__":
    init_weights_demo()
    noise_demo()
    ema_demo(0)
    ema_demo(0.5)
    ema_demo(0.9)