从零构建大语言模型 2:PyTorch 基础与资源核算

1 阅读13分钟

从零构建大语言模型 2:PyTorch 基础与资源核算

基于 Stanford CS336: Language Models From Scratch (Spring 2025) Lecture 2

课程主页:stanford-cs336.github.io/spring2025


本讲你将收获什么

  1. 理解训练 LLM 时的两类核心资源:内存(GB)计算量(FLOPs)
  2. 掌握 PyTorch 张量的创建、数据类型(float32 / float16 / bfloat16 / fp8)及其内存开销
  3. 理解张量的存储机制(stride)、视图(view)与拷贝的区别
  4. 学会使用 einops 进行可读的张量操作
  5. 掌握前向传播和反向传播的 FLOPs 计算规则
  6. 走通从数据加载 → 模型定义 → 优化器 → 训练循环的完整流程
  7. 了解混合精度训练的原理和实践

前置知识

  • 第一讲的内容(分词、LLM 全景)
  • Python 基础 + 了解矩阵乘法
  • 装好 PyTorch(pip install torch

第一部分:两道快速估算题

# H100 bf16 峰值性能(不含稀疏加速)
h100_flop_per_sec = 989.5e12   # ~990 TFLOP/s

# MFU (Model FLOPs Utilization): 实际能用到峰值的多少
# 通常 0.3~0.5 算不错
mfu = 0.3

# 总 FLOPs
total_flops = 6 * 70e9 * 15e12
print(f"总 FLOPs: {total_flops:.2e}")

# 每天能算多少
flops_per_day = h100_flop_per_sec * mfu * 1024 * 60 * 60 * 24
print(f"每天 FLOPs (1024×H100): {flops_per_day:.2e}")

# 需要多少天
days = total_flops / flops_per_day
print(f"训练时间: {days:.0f} 天 ≈ {days/30:.1f} 个月")
总 FLOPs: 6.30e+24
每天 FLOPs (1024×H100): 2.63e+22
训练时间: 240 天 ≈ 8.0 个月

问题 2:8 块 H100 最大能训多大的模型?

每块 H100 有 80GB 显存,AdamW 优化器

训练时显存里要放什么?

内容每参数字节数说明
参数(weights)4float32
梯度(gradients)4float32
AdamW 一阶矩 (m)4float32
AdamW 二阶矩 (v)4float32
合计16
h100_bytes = 80e9  # 80GB per GPU
bytes_per_param = 4 + 4 + (4 + 4)  # params + grads + optimizer states
print(f"每个参数占 {bytes_per_param} 字节")

max_params = (h100_bytes * 8) / bytes_per_param
print(f"8×H100 最大参数量: {max_params:.0e} = {max_params/1e9:.0f}B")
每个参数占 16 字节
8×H100 最大参数量: 4e+10 = 40B

注意:这是非常粗略的估算——

  • 还没算 激活值(activations) 的显存,这取决于 batch size 和序列长度
  • 实际中可以用 bf16 存参数和梯度(各 2 字节),但需要额外保留一份 float32 参数副本,所以总内存差不多
  • 可以用 ZeRO 等技术把优化器状态分摊到多块 GPU

第二部分:张量与内存核算

张量基础

张量(Tensor)是 PyTorch 中存储一切数据的基本单元——参数、梯度、激活值、优化器状态都是张量。

import torch
import torch.nn as nn

# 多种创建方式
x = torch.tensor([[1., 2, 3], [4, 5, 6]])   # 从数据创建
print(f"从数据: {x}")
print(f"shape={x.shape}, dtype={x.dtype}")

x = torch.zeros(4, 8)    # 全零
x = torch.ones(4, 8)     # 全一
x = torch.randn(4, 8)    # 标准正态随机
x = torch.empty(4, 8)    # 分配内存但不初始化(更快)
print(f"\nempty 的值是随机的垃圾: {x[0, :3]}")

# 截断正态初始化(实际中常用)
nn.init.trunc_normal_(x, mean=0, std=1, a=-2, b=2)
print(f"截断正态初始化后: {x[0, :3]}")
从数据: tensor([[1., 2., 3.],
        [4., 5., 6.]])
shape=torch.Size([2, 3]), dtype=torch.float32

empty 的值是随机的垃圾: tensor([0., 0., 0.])
截断正态初始化后: tensor([-0.1682, -1.7984, -1.9893])

浮点数据类型:精度 vs 内存 vs 速度

几乎所有东西(参数、梯度、激活值)都存为浮点数。选择哪种精度,直接决定了内存占用和计算速度。

类型位数每个值占内存指数位尾数位动态范围用途
float32324 字节823±3.4×10³⁸默认,安全
float16162 字节510±65504省内存,但小数容易下溢
bfloat16162 字节87±3.4×10³⁸✅ 深度学习首选——和 float32 一样的范围
fp881 字节4/53/2有限最新,H100 支持

关键问题:float16 为什么危险?

# float16 的动态范围问题
x_f16 = torch.tensor([1e-8], dtype=torch.float16)
print(f"float16 存 1e-8: {x_f16.item()}")  # 变成 0 了!

# bfloat16 没问题
x_bf16 = torch.tensor([1e-8], dtype=torch.bfloat16)
print(f"bfloat16 存 1e-8: {x_bf16.item():.2e}")  # 保留了

# 查看各类型的详细信息
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
    info = torch.finfo(dtype)
    print(f"\n{str(dtype):20s} | range: [{info.min:.2e}, {info.max:.2e}] | eps: {info.eps:.2e}")
float16  1e-8: 0.0
bfloat16  1e-8: 1.00e-08

torch.float32        | range: [-3.40e+38, 3.40e+38] | eps: 1.19e-07

torch.float16        | range: [-6.55e+04, 6.55e+04] | eps: 9.77e-04

torch.bfloat16       | range: [-3.39e+38, 3.39e+38] | eps: 7.81e-03

bfloat16 的设计哲学:牺牲精度(尾数只有 7 位 vs float16 的 10 位),保留动态范围(指数 8 位 = 和 float32 一样)。

对深度学习来说,动态范围比精度重要得多——训练中梯度的数量级变化很大,如果下溢到 0 就会导致训练不稳定。

内存计算

张量的内存 = 元素个数 × 每个元素的字节数

def memory_bytes(tensor):
    return tensor.nelement() * tensor.element_size()

# 一个 4×8 的 float32 张量
x32 = torch.zeros(4, 8, dtype=torch.float32)
x16 = torch.zeros(4, 8, dtype=torch.bfloat16)
print(f"float32: {x32.nelement()} 个元素 × {x32.element_size()} 字节 = {memory_bytes(x32)} 字节")
print(f"bfloat16: {x16.nelement()} 个元素 × {x16.element_size()} 字节 = {memory_bytes(x16)} 字节")

# GPT-3 的一个 FFN 权重矩阵: 12288 × 49152
gpt3_ffn = torch.empty(12288, 49152, dtype=torch.float32)
print(f"\nGPT-3 一个 FFN 矩阵: {memory_bytes(gpt3_ffn) / 1e9:.2f} GB (float32)")
print(f"                     {memory_bytes(gpt3_ffn) / 2 / 1e9:.2f} GB (bfloat16)")
float32: 32 个元素 × 4 字节 = 128 字节
bfloat16: 32 个元素 × 2 字节 = 64 字节

GPT-3 一个 FFN 矩阵: 2.42 GB (float32)
                     1.21 GB (bfloat16)

一个 FFN 矩阵就 2.4GB——GPT-3 有 96 层,每层有多个这样的矩阵。这就是为什么大模型需要那么多显存。


第三部分:张量操作与存储机制

张量的内部存储:stride

PyTorch 张量并不是简单的多维数组——它是一个指向一维连续内存的指针,加上描述如何索引的元数据(shape, stride)。

x = torch.tensor([[1., 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 10, 11, 12]])

print(f"shape: {x.shape}")        # (3, 4)
print(f"stride: {x.stride()}")    # (4, 1)
# stride 的含义:
# - 沿 dim 0(行)移动一步,在内存中跳 4 个元素
# - 沿 dim 1(列)移动一步,在内存中跳 1 个元素

# 手动计算 x[1][2] 在一维内存中的位置
r, c = 1, 2
index = r * x.stride(0) + c * x.stride(1)
print(f"\nx[{r}][{c}] = {x[r][c]:.0f}, 内存偏移 = {index}")
shape: torch.Size([3, 4])
stride: (4, 1)

x[1][2] = 7, 内存偏移 = 6

视图(View)vs 拷贝(Copy)

很多操作只是创建同一块内存的不同视角——不拷贝数据,因此是 O(1)O(1) 的。

x = torch.tensor([[1., 2, 3], [4, 5, 6]])

# 这些操作都是 view(不拷贝,共享内存)
row = x[0]           # 取行
col = x[:, 1]        # 取列
reshaped = x.view(3, 2)  # 改变形状
transposed = x.transpose(0, 1)  # 转置

print(f"原始 x:\n{x}")

# 修改 x 会影响所有 view!
x[0][0] = 999
print(f"\n修改 x[0][0]=999 后:")
print(f"  row: {row}")           # row[0] 也变了
print(f"  reshaped:\n{reshaped}")  # 也变了

# 转置后是不连续的,某些操作需要先 .contiguous()
x = torch.tensor([[1., 2, 3], [4, 5, 6]])
t = x.transpose(0, 1)
print(f"\n转置后 contiguous? {t.is_contiguous()}")
# t.view(2, 3)  # 这会报错!
t_c = t.contiguous().view(2, 3)  # 先拷贝成连续的,再 view
print(f"contiguous + view:\n{t_c}")
原始 x:
tensor([[1., 2., 3.],
        [4., 5., 6.]])

修改 x[0][0]=999 后:
  row: tensor([999.,   2.,   3.])
  reshaped:
tensor([[999.,   2.],
        [  3.,   4.],
        [  5.,   6.]])

转置后 contiguous? False
contiguous + view:
tensor([[1., 4., 2.],
        [5., 3., 6.]])

要点:view 是免费的(不消耗额外内存和计算),copy 需要额外内存和计算。知道哪些操作是 view、哪些是 copy,对内存核算非常重要。

逐元素操作

对张量每个元素独立操作,输出形状不变。

# 常见逐元素操作
x = torch.tensor([1., 2., 3., 4.])
print(f"原始: {x}")
print(f"  +1: {x + 1}")
print(f"  *2: {x * 2}")
print(f" exp: {torch.exp(x)}")
print(f"relu: {torch.relu(x - 2.5)}")

# 因果注意力掩码(上三角)
mask = torch.ones(4, 4).triu()
print(f"\n因果掩码 (triu):\n{mask}")
原始: tensor([1., 2., 3., 4.])
  +1: tensor([2., 3., 4., 5.])
  *2: tensor([2., 4., 6., 8.])
 exp: tensor([ 2.7183,  7.3891, 20.0855, 54.5982])
relu: tensor([0.0000, 0.0000, 0.5000, 1.5000])

因果掩码 (triu):
tensor([[1., 1., 1., 1.],
        [0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.]])

矩阵乘法:深度学习的核心运算

# 基本矩阵乘法
# 在实际模型中,x 通常有 batch 和 sequence 维度
# x: (batch, seq_len, d_in) @ w: (d_in, d_out) → (batch, seq_len, d_out)

batch, seq_len, d_in, d_out = 2, 3, 4, 5
x = torch.randn(batch, seq_len, d_in)
w = torch.randn(d_in, d_out)

# PyTorch 的 @ 会自动 broadcast 前面的维度
y = x @ w
print(f"x: {list(x.shape)}")
print(f"w: {list(w.shape)}")
print(f"y = x @ w: {list(y.shape)}")
x: [2, 3, 4]
w: [4, 5]
y = x @ w: [2, 3, 5]

第四部分:用 einops 写可读的张量操作

传统 PyTorch 代码中,张量维度全靠注释和数字索引,非常容易出错:

z = x @ y.transpose(-2, -1)  # -2 是哪个维度?-1 呢?

Einops命名维度解决这个问题。

from einops import einsum, reduce, rearrange

# ---- einsum: 带名字的矩阵乘法 ----
x = torch.ones(2, 3, 4)  # batch=2, seq1=3, hidden=4
y = torch.ones(2, 3, 4)  # batch=2, seq2=3, hidden=4

# 旧写法(谁知道 -2 和 -1 是啥)
z_old = x @ y.transpose(-2, -1)

# einops 写法(一目了然)
z_new = einsum(x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")

print(f"注意力分数矩阵 shape: {list(z_new.shape)}")
print(f"两种方式结果一致: {torch.allclose(z_old, z_new)}")

# 用 ... 表示 batch 维度(更通用)
z = einsum(x, y, "... seq1 hidden, ... seq2 hidden -> ... seq1 seq2")
print(f"用 ... 语法: {list(z.shape)}")
注意力分数矩阵 shape: [2, 3, 3]
两种方式结果一致: True
用 ... 语法: [2, 3, 3]
# ---- reduce: 命名的归约操作 ----
x = torch.ones(2, 3, 4)  # batch, seq, hidden

# 旧写法
y_old = x.mean(dim=-1)   # 对最后一维求均值

# einops 写法
y_sum = reduce(x, "... hidden -> ...", "sum")     # 求和
y_mean = reduce(x, "... hidden -> ...", "mean")    # 均值

print(f"sum:  shape={list(y_sum.shape)}, values={y_sum[0]}")
print(f"mean: shape={list(y_mean.shape)}, values={y_mean[0]}")
sum:  shape=[2, 3], values=tensor([4., 4., 4.])
mean: shape=[2, 3], values=tensor([1., 1., 1.])
# ---- rearrange: 维度重组 ----
# 场景:多头注意力中,hidden_dim 实际上是 heads × head_dim
# 需要拆开、操作、再合并

x = torch.ones(2, 3, 8)  # batch=2, seq=3, total_hidden=8
w = torch.ones(4, 4)      # head_dim_in × head_dim_out

# 拆分 hidden 为 heads × head_dim
x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)
print(f"拆分后: {list(x.shape)}")  # [2, 3, 2, 4]

# 对每个 head 分别做线性变换
x = einsum(x, w, "... hidden1, hidden1 hidden2 -> ... hidden2")
print(f"变换后: {list(x.shape)}")  # [2, 3, 2, 4]

# 合并回去
x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)")
print(f"合并后: {list(x.shape)}")  # [2, 3, 8]
拆分后: [2, 3, 2, 4]
变换后: [2, 3, 2, 4]
合并后: [2, 3, 8]

建议:einops 写起来稍微长一点,但维度关系一目了然,特别适合多头注意力这种复杂的维度操作。实际项目中强烈推荐使用。


第五部分:计算量核算(FLOPs)

什么是 FLOP?

一个 FLOP = 一次浮点加法或乘法。

两个容易混淆的缩写:

缩写含义度量什么
FLOPsFloating-point Operations总计算量(一个数字)
FLOP/sFLOPs per Second计算速度(硬件性能)

量级直觉

模型训练 FLOPs说明
GPT-3 (2020)3.14 × 10²³
GPT-4 (2023)~2 × 10²⁵估计值
美国报告门槛1 × 10²⁶超过需向政府报告(2025 已撤销)
硬件峰值 FLOP/s (bf16)
A100312 TFLOP/s
H100~990 TFLOP/s

矩阵乘法的 FLOPs

矩阵乘法 (m×n) @ (n×p) → (m×p):每个输出元素需要 n 次乘法 + n 次加法 = 2n,共 m×p 个输出。

FLOPs=2×m×n×p\text{FLOPs} = 2 \times m \times n \times p

import time

B, D, K = 256, 512, 256

x = torch.randn(B, D)
w = torch.randn(D, K)

# 理论 FLOPs
flops = 2 * B * D * K
print(f"矩阵乘法 ({B}×{D}) @ ({D}×{K})")
print(f"理论 FLOPs: {flops:,}")

# 实际计时
start = time.time()
for _ in range(100):
    y = x @ w
elapsed = (time.time() - start) / 100

print(f"平均耗时: {elapsed*1000:.2f} ms")
print(f"实际 FLOP/s: {flops / elapsed:.2e}")
矩阵乘法 (256×512) @ (512×256)
理论 FLOPs: 67,108,864
平均耗时: 0.14 ms
实际 FLOP/s: 4.84e+11

从矩阵乘法到 Transformer

对于一个线性层 y = x @ w

  • FLOPs = 2 × (数据点数) × (参数量)

这个结论可以推广到整个 Transformer(一阶近似):

阶段FLOPs
前向传播2 × B × N
反向传播4 × B × N
总计6 × B × N

其中 B = 数据点数(batch × seq_len × num_tokens),N = 模型参数量。

这就是第一部分那个 6 × 70e9 × 15e12 的来源!

Model FLOPs Utilization (MFU)

MFU=实际 FLOP/s硬件峰值 FLOP/s\text{MFU} = \frac{\text{实际 FLOP/s}}{\text{硬件峰值 FLOP/s}}

  • MFU ≥ 0.5 算很好
  • 矩阵乘法占比越高,MFU 越高
  • 数据类型影响很大:bf16 比 float32 快得多

第六部分:梯度与反向传播的 FLOPs

前向 + 反向 = 6 × 数据 × 参数

我们用一个两层线性模型来推导为什么反向传播的 FLOPs 是前向的两倍。

x --[w1]--> h1 --[w2]--> h2 --> loss

为什么是 6 × 数据量 × 参数量?

我们来拆解这个公式的来源。

先用一个买菜的类比建立直觉:

假设你去菜市场,买了 3 样东西:

品类单价 ww(元/kg)数量 xx(kg)
番茄53
白菜24
鸡蛋82

Forward(算总价):每样东西做一次 "单价 × 数量",然后加起来得到总价。这是一轮乘加运算 → 2N FLOPs

Backward 1 — 对数量的敏感度(梯度传回前一层):老板说"总价多了 1 块",你想知道"减少哪样东西的数量影响最大?"答案就是看单价——番茄 5 块/kg 影响最大,白菜 2 块/kg 影响最小。这又是一轮乘加 → 2N FLOPs。对应 Lx=WLy\frac{\partial L}{\partial x} = W^\top \cdot \frac{\partial L}{\partial y},用来把误差继续往前传。

你可能会问:xx 是输入数据,又不需要更新,为什么还要算它的梯度?因为这一层的 xx 就是上一层的输出 hh——链式法则要求把梯度一层层传回去,否则前面层的权重就没法更新了。只有网络最开头的输入才真正不需要算梯度。

Backward 2 — 对单价的敏感度(更新权重):同样的问题反过来——"调整哪样东西的单价影响最大?"答案是看数量——白菜买了 4kg 影响最大,鸡蛋 2kg 影响最小。又是一轮乘加 → 2N FLOPs。对应 LW=Lyx\frac{\partial L}{\partial W} = \frac{\partial L}{\partial y} \cdot x^\top,用来更新权重。

所以:Forward 1 轮 + Backward 2 轮 = 3 轮,每轮 2N → 总计 6N

下面用数学语言精确表述:

前向传播:≈ 2N FLOPs

对于一个线性层 y=Wxy = Wx,其中 WW 的形状是 (m,n)(m, n)

  • 每个输出元素需要 nn 次乘法 + nn 次加法 = 2n2n 次运算
  • 总共 mm 个输出元素,所以是 2mn2mn FLOPs
  • mnmn 正好就是这一层的参数量

把所有层加起来,整个模型的前向传播 ≈ 2N2N FLOPs(NN 是总参数量)。

反向传播:≈ 4N FLOPs

反向传播对每一层需要计算两个梯度

  1. 对输入的梯度(用于继续反向传播到前一层):Lx=WLy\frac{\partial L}{\partial x} = W^\top \cdot \frac{\partial L}{\partial y} → 一次矩阵乘法,≈ 2mn2mn FLOPs
  2. 对权重的梯度(用于更新参数):LW=Lyx\frac{\partial L}{\partial W} = \frac{\partial L}{\partial y} \cdot x^\top → 又一次矩阵乘法,≈ 2mn2mn FLOPs

两个加起来:4mn4N4mn \approx 4N FLOPs。

合计:前向 2N2N + 反向 4N4N = 6N6N FLOPs(per data point)

乘以数据点数 BB,总计 6BN6BN

下面用一个两层线性模型 x → w1 → h1 → w2 → h2 → loss 来验证这个结论:

B, D, K = 2, 3, 2  # 小尺寸便于说明

# 各张量的 shape:
# x:  (B, D)   = (2, 3)   — 输入数据
# w1: (D, D)   = (3, 3)   — 第一层权重
# h1: (B, D)   = (2, 3)   — 第一层输出 = x @ w1
# w2: (D, K)   = (3, 2)   — 第二层权重
# h2: (B, K)   = (2, 2)   — 第二层输出 = h1 @ w2 → 用来算 loss

print(f"x:  ({B}, {D})  — 输入")
print(f"w1: ({D}, {D})  — 第一层权重, 参数量 = {D*D}")
print(f"h1: ({B}, {D})  — x @ w1 的输出")
print(f"w2: ({D}, {K})  — 第二层权重, 参数量 = {D*K}")
print(f"h2: ({B}, {K})  — h1 @ w2 的输出")
print()

# === 前向 FLOPs ===
# x @ w1: (B,D) @ (D,D) → 2*B*D*D
# h1 @ w2: (B,D) @ (D,K) → 2*B*D*K
forward_flops = (2 * B * D * D) + (2 * B * D * K)
print(f"前向 FLOPs:")
print(f"  x @ w1:  2×{B}×{D}×{D} = {2*B*D*D}")
print(f"  h1 @ w2: 2×{B}×{D}×{K} = {2*B*D*K}")
print(f"  合计: {forward_flops}")
print()

# === 反向 FLOPs ===
backward_flops = 0

# 1. w2.grad = h1.T @ h2.grad  →  (D,B) @ (B,K) → 2*B*D*K
backward_flops += 2 * B * D * K
print(f"反向 FLOPs:")
print(f"  w2.grad = h1.T @ h2.grad:  2×{B}×{D}×{K} = {2*B*D*K}")

# 2. h1.grad = h2.grad @ w2.T  →  (B,K) @ (K,D) → 2*B*D*K
backward_flops += 2 * B * D * K
print(f"  h1.grad = h2.grad @ w2.T: 2×{B}×{D}×{K} = {2*B*D*K}")

# 3. w1.grad = x.T @ h1.grad  →  (D,B) @ (B,D) → 2*B*D*D
#    x.grad  = h1.grad @ w1.T →  (B,D) @ (D,D) → 2*B*D*D (不需要但链式法则会算)
backward_flops += (2 + 2) * B * D * D
print(f"  w1.grad + x.grad:          4×{B}×{D}×{D} = {4*B*D*D}")

print(f"  合计: {backward_flops}")
print()
print(f"前向: {forward_flops},  反向: {backward_flops}")
print(f"反向 / 前向 = {backward_flops / forward_flops:.1f}x")
print(f"总计 = {forward_flops + backward_flops}")
print(f"     ≈ 6 × B × params = {6 * B * (D*D + D*K)}")
x:  (2, 3)  — 输入
w1: (3, 3)  — 第一层权重, 参数量 = 9
h1: (2, 3)  — x @ w1 的输出
w2: (3, 2)  — 第二层权重, 参数量 = 6
h2: (2, 2)  — h1 @ w2 的输出

前向 FLOPs:
  x @ w1:  2×2×3×3 = 36
  h1 @ w2: 2×2×3×2 = 24
  合计: 60

反向 FLOPs:
  w2.grad = h1.T @ h2.grad:  2×2×3×2 = 24
  h1.grad = h2.grad @ w2.T: 2×2×3×2 = 24
  w1.grad + x.grad:          4×2×3×3 = 72
  合计: 120

前向: 60,  反向: 120
反向 / 前向 = 2.0x
总计 = 180
     ≈ 6 × B × params = 180

关键结论:反向传播的 FLOPs 是前向的 2 倍。总计 = 6 × 数据点数 × 参数量

这也解释了为什么推理(只做前向)比训练便宜得多——训练每步要做 3 倍于推理的计算。


第七部分:模型定义

参数初始化的重要性

随机初始化如果不做缩放,输出的数量级会随输入维度增大而爆炸——这会导致梯度爆炸。

# 不缩放的问题
input_dim = 1000
x = torch.randn(1, input_dim)

# 不缩放:每个输出元素 ~ sum of input_dim 个 N(0,1)*N(0,1)
# 方差 = input_dim,标准差 = sqrt(input_dim) ≈ 31.6
w_bad = torch.randn(input_dim, 4)
out_bad = x @ w_bad
print(f"不缩放 (input_dim={input_dim}):")
print(f"  输出: {out_bad.data[0]}")
print(f"  量级 ≈ sqrt({input_dim}) = {input_dim**0.5:.1f}")

# 缩放(Xavier 初始化的核心思想)
w_good = torch.randn(input_dim, 4) / (input_dim ** 0.5)
out_good = x @ w_good
print(f"\n缩放后 (÷ sqrt(input_dim)):")
print(f"  输出: {out_good.data[0]}")
print(f"  量级 ≈ O(1) ✓")
不缩放 (input_dim=1000):
  输出: tensor([-27.6163,  17.9119,  26.9574,  17.6714])
  量级 ≈ sqrt(1000) = 31.6

缩放后 (÷ sqrt(input_dim)):
  输出: tensor([ 1.3417, -0.2218, -0.9940,  0.2977])
  量级 ≈ O(1) ✓

用 nn.Module 定义模型

class DeepLinearModel(nn.Module):
    """简单的多层线性模型(用于演示训练流程)"""
    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        layers = []
        for i in range(num_layers):
            d_in = input_dim if i == 0 else hidden_dim
            w = nn.Parameter(torch.randn(d_in, hidden_dim) / d_in ** 0.5)
            layers.append(w)
        self.layers = nn.ParameterList(layers)
        self.bias = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x):
        for w in self.layers:
            x = x @ w
        return x + self.bias

# 创建模型
D = 3
model = DeepLinearModel(input_dim=D, hidden_dim=D, num_layers=4)

# 数参数量
num_params = sum(p.numel() for p in model.parameters())
print(f"参数量: {num_params}")
print(f"内存 (float32): {num_params * 4} 字节")

# 前向传播
x = torch.randn(2, D)
y = model(x)
print(f"\n输入 shape: {list(x.shape)}")
print(f"输出 shape: {list(y.shape)}")
参数量: 39
内存 (float32): 156 字节

输入 shape: [2, 3]
输出 shape: [2, 3]

第八部分:完整训练流程

随机种子

训练中随机性无处不在(初始化、dropout、数据顺序)。为每个随机源设置独立的种子,这样可以精确复现结果。

import random
import numpy as np

def set_seed(seed):
    """设置所有随机种子"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
print("随机种子已设置为 42")
随机种子已设置为 42

数据加载

语言模型的训练数据 = 分词后的整数序列。数据量通常很大(LLaMA 的训练数据 ~2.8TB),不能一次性加载。

# 模拟:生成一个简单的回归数据集
set_seed(42)
D = 3
true_weights = torch.arange(D, dtype=torch.float32)  # [0, 1, 2]
print(f"真实权重: {true_weights}")

def get_batch(batch_size, dim, seed=None):
    """生成一个 batch 的数据"""
    if seed is not None:
        torch.manual_seed(seed)
    x = torch.randn(batch_size, dim)
    y = x @ true_weights  # 真实标签
    return x, y

x, y = get_batch(4, D, seed=0)
print(f"\n输入 x shape: {list(x.shape)}")
print(f"标签 y shape: {list(y.shape)}")
print(f"标签 y: {y}")
真实权重: tensor([0., 1., 2.])

输入 x shape: [4, 3]
标签 y shape: [4]
标签 y: tensor([-4.6510, -3.8817, -0.6005, -0.2326])

优化器

优化器的演进路线:

优化器核心思想
SGD最基本:w -= lr * grad
SGD + Momentum加上梯度的指数移动平均
AdaGrad用历史梯度平方和做自适应学习率
RMSPropAdaGrad + 指数衰减(避免学习率单调递减)
AdamRMSProp + Momentum = 一阶矩 + 二阶矩
AdamWAdam + 解耦的 weight decay

训练循环的内存核算

D, num_layers, B = 3, 4, 2

num_parameters = D * D * num_layers + D  # 权重 + 偏置
num_activations = B * D * num_layers      # 每层的中间结果
num_gradients = num_parameters            # 和参数一样多
num_optimizer_states = num_parameters     # AdaGrad: 一个状态变量

total_memory = 4 * (num_parameters + num_activations + num_gradients + num_optimizer_states)
print(f"参数: {num_parameters} 个 → {num_parameters * 4} 字节")
print(f"激活值: {num_activations} 个 → {num_activations * 4} 字节")
print(f"梯度: {num_gradients} 个 → {num_gradients * 4} 字节")
print(f"优化器状态: {num_optimizer_states} 个 → {num_optimizer_states * 4} 字节")
print(f"总内存: {total_memory} 字节")

flops_per_step = 6 * B * num_parameters
print(f"\n每步 FLOPs: {flops_per_step}")
参数: 39 个 → 156 字节
激活值: 24 个 → 96 字节
梯度: 39 个 → 156 字节
优化器状态: 39 个 → 156 字节
总内存: 564 字节

每步 FLOPs: 468

完整训练循环

set_seed(42)

# 模型 + 优化器
model = DeepLinearModel(input_dim=D, hidden_dim=D, num_layers=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)

# 训练
for step in range(200):
    x, y_true = get_batch(batch_size=32, dim=D, seed=step)

    # 前向
    y_pred = model(x).sum(dim=-1)  # 简化:对输出求和
    loss = ((y_pred - y_true) ** 2).mean()

    # 反向
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 50 == 0:
        print(f"step {step:3d} | loss = {loss.item():.4f}")

print(f"\n最终 loss: {loss.item():.6f}")
step   0 | loss = 0.3104
step  50 | loss = 0.0009
step 100 | loss = 0.0000
step 150 | loss = 0.0000

最终 loss: 0.000000

第九部分:Checkpoint 与混合精度训练

Checkpoint(检查点)

训练大模型动辄几天到几个月,中途崩溃是家常便饭。定期保存模型和优化器状态 = 保险。

# 保存
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "step": 200,
}, "/tmp/checkpoint.pt")
print("Checkpoint 已保存")

# 加载
ckpt = torch.load("/tmp/checkpoint.pt", weights_only=False)
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
print(f"从 step {ckpt['step']} 恢复训练")
Checkpoint 已保存
从 step 200 恢复训练

混合精度训练

精度优点缺点
float32稳定慢,内存大
bfloat16快,内存小可能不稳定

混合精度的策略:两者结合

  • 前向传播(激活值):用 bfloat16 / fp8——这里精度要求低,但数据量大
  • 参数和梯度:保留 float32——精度要求高,但相对数据量小
  • 额外维护一份 float32 的参数副本(master weights)

PyTorch 提供了 torch.amp(Automatic Mixed Precision)自动管理:

# 伪代码
scaler = torch.amp.GradScaler()
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model(input)    # 前向用 bf16
    loss = criterion(output, target)

scaler.scale(loss).backward()  # 反向
scaler.step(optimizer)
scaler.update()

NVIDIA 的 Transformer Engine 还支持 FP8 混合精度(论文),可进一步减少内存和加速计算。


本讲总结

资源核算速查表

项目公式
张量内存元素数 × 每元素字节数
矩阵乘法 FLOPs2 × m × n × p
前向 FLOPs2 × B × N
反向 FLOPs4 × B × N
训练总 FLOPs6 × B × N
训练内存(AdamW, fp32)16 × N + 激活值
MFU实际 FLOP/s ÷ 峰值 FLOP/s

数据类型选择

场景推荐
默认/调试float32
训练前向bfloat16
参数/梯度float32(master weights)
最激进fp8(需要 H100)

关键概念检查

  1. 为什么 bfloat16 比 float16 更适合深度学习?(动态范围 = float32,不会下溢)
  2. x.view(3, 2)x.reshape(3, 2) 有什么区别?(view 必须连续,不拷贝;reshape 可能拷贝)
  3. 反向传播的 FLOPs 是前向的几倍?(2 倍)
  4. 为什么初始化权重要除以 sqrt(input_dim)?(保持输出方差为 O(1),防止梯度爆炸)
  5. 混合精度训练的核心思路是什么?(前向用低精度省内存和加速,参数用高精度保稳定)