从零构建大语言模型 2:PyTorch 基础与资源核算
基于 Stanford CS336: Language Models From Scratch (Spring 2025) Lecture 2
本讲你将收获什么
- 理解训练 LLM 时的两类核心资源:内存(GB) 和 计算量(FLOPs)
- 掌握 PyTorch 张量的创建、数据类型(float32 / float16 / bfloat16 / fp8)及其内存开销
- 理解张量的存储机制(stride)、视图(view)与拷贝的区别
- 学会使用 einops 进行可读的张量操作
- 掌握前向传播和反向传播的 FLOPs 计算规则
- 走通从数据加载 → 模型定义 → 优化器 → 训练循环的完整流程
- 了解混合精度训练的原理和实践
前置知识
- 第一讲的内容(分词、LLM 全景)
- Python 基础 + 了解矩阵乘法
- 装好 PyTorch(
pip install torch)
第一部分:两道快速估算题
# H100 bf16 峰值性能(不含稀疏加速)
h100_flop_per_sec = 989.5e12 # ~990 TFLOP/s
# MFU (Model FLOPs Utilization): 实际能用到峰值的多少
# 通常 0.3~0.5 算不错
mfu = 0.3
# 总 FLOPs
total_flops = 6 * 70e9 * 15e12
print(f"总 FLOPs: {total_flops:.2e}")
# 每天能算多少
flops_per_day = h100_flop_per_sec * mfu * 1024 * 60 * 60 * 24
print(f"每天 FLOPs (1024×H100): {flops_per_day:.2e}")
# 需要多少天
days = total_flops / flops_per_day
print(f"训练时间: {days:.0f} 天 ≈ {days/30:.1f} 个月")
总 FLOPs: 6.30e+24
每天 FLOPs (1024×H100): 2.63e+22
训练时间: 240 天 ≈ 8.0 个月
问题 2:8 块 H100 最大能训多大的模型?
每块 H100 有 80GB 显存,AdamW 优化器
训练时显存里要放什么?
| 内容 | 每参数字节数 | 说明 |
|---|---|---|
| 参数(weights) | 4 | float32 |
| 梯度(gradients) | 4 | float32 |
| AdamW 一阶矩 (m) | 4 | float32 |
| AdamW 二阶矩 (v) | 4 | float32 |
| 合计 | 16 |
h100_bytes = 80e9 # 80GB per GPU
bytes_per_param = 4 + 4 + (4 + 4) # params + grads + optimizer states
print(f"每个参数占 {bytes_per_param} 字节")
max_params = (h100_bytes * 8) / bytes_per_param
print(f"8×H100 最大参数量: {max_params:.0e} = {max_params/1e9:.0f}B")
每个参数占 16 字节
8×H100 最大参数量: 4e+10 = 40B
注意:这是非常粗略的估算——
- 还没算 激活值(activations) 的显存,这取决于 batch size 和序列长度
- 实际中可以用 bf16 存参数和梯度(各 2 字节),但需要额外保留一份 float32 参数副本,所以总内存差不多
- 可以用 ZeRO 等技术把优化器状态分摊到多块 GPU
第二部分:张量与内存核算
张量基础
张量(Tensor)是 PyTorch 中存储一切数据的基本单元——参数、梯度、激活值、优化器状态都是张量。
import torch
import torch.nn as nn
# 多种创建方式
x = torch.tensor([[1., 2, 3], [4, 5, 6]]) # 从数据创建
print(f"从数据: {x}")
print(f"shape={x.shape}, dtype={x.dtype}")
x = torch.zeros(4, 8) # 全零
x = torch.ones(4, 8) # 全一
x = torch.randn(4, 8) # 标准正态随机
x = torch.empty(4, 8) # 分配内存但不初始化(更快)
print(f"\nempty 的值是随机的垃圾: {x[0, :3]}")
# 截断正态初始化(实际中常用)
nn.init.trunc_normal_(x, mean=0, std=1, a=-2, b=2)
print(f"截断正态初始化后: {x[0, :3]}")
从数据: tensor([[1., 2., 3.],
[4., 5., 6.]])
shape=torch.Size([2, 3]), dtype=torch.float32
empty 的值是随机的垃圾: tensor([0., 0., 0.])
截断正态初始化后: tensor([-0.1682, -1.7984, -1.9893])
浮点数据类型:精度 vs 内存 vs 速度
几乎所有东西(参数、梯度、激活值)都存为浮点数。选择哪种精度,直接决定了内存占用和计算速度。
| 类型 | 位数 | 每个值占内存 | 指数位 | 尾数位 | 动态范围 | 用途 |
|---|---|---|---|---|---|---|
| float32 | 32 | 4 字节 | 8 | 23 | ±3.4×10³⁸ | 默认,安全 |
| float16 | 16 | 2 字节 | 5 | 10 | ±65504 | 省内存,但小数容易下溢 |
| bfloat16 | 16 | 2 字节 | 8 | 7 | ±3.4×10³⁸ | ✅ 深度学习首选——和 float32 一样的范围 |
| fp8 | 8 | 1 字节 | 4/5 | 3/2 | 有限 | 最新,H100 支持 |
关键问题:float16 为什么危险?
# float16 的动态范围问题
x_f16 = torch.tensor([1e-8], dtype=torch.float16)
print(f"float16 存 1e-8: {x_f16.item()}") # 变成 0 了!
# bfloat16 没问题
x_bf16 = torch.tensor([1e-8], dtype=torch.bfloat16)
print(f"bfloat16 存 1e-8: {x_bf16.item():.2e}") # 保留了
# 查看各类型的详细信息
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
info = torch.finfo(dtype)
print(f"\n{str(dtype):20s} | range: [{info.min:.2e}, {info.max:.2e}] | eps: {info.eps:.2e}")
float16 存 1e-8: 0.0
bfloat16 存 1e-8: 1.00e-08
torch.float32 | range: [-3.40e+38, 3.40e+38] | eps: 1.19e-07
torch.float16 | range: [-6.55e+04, 6.55e+04] | eps: 9.77e-04
torch.bfloat16 | range: [-3.39e+38, 3.39e+38] | eps: 7.81e-03
bfloat16 的设计哲学:牺牲精度(尾数只有 7 位 vs float16 的 10 位),保留动态范围(指数 8 位 = 和 float32 一样)。
对深度学习来说,动态范围比精度重要得多——训练中梯度的数量级变化很大,如果下溢到 0 就会导致训练不稳定。
内存计算
张量的内存 = 元素个数 × 每个元素的字节数
def memory_bytes(tensor):
return tensor.nelement() * tensor.element_size()
# 一个 4×8 的 float32 张量
x32 = torch.zeros(4, 8, dtype=torch.float32)
x16 = torch.zeros(4, 8, dtype=torch.bfloat16)
print(f"float32: {x32.nelement()} 个元素 × {x32.element_size()} 字节 = {memory_bytes(x32)} 字节")
print(f"bfloat16: {x16.nelement()} 个元素 × {x16.element_size()} 字节 = {memory_bytes(x16)} 字节")
# GPT-3 的一个 FFN 权重矩阵: 12288 × 49152
gpt3_ffn = torch.empty(12288, 49152, dtype=torch.float32)
print(f"\nGPT-3 一个 FFN 矩阵: {memory_bytes(gpt3_ffn) / 1e9:.2f} GB (float32)")
print(f" {memory_bytes(gpt3_ffn) / 2 / 1e9:.2f} GB (bfloat16)")
float32: 32 个元素 × 4 字节 = 128 字节
bfloat16: 32 个元素 × 2 字节 = 64 字节
GPT-3 一个 FFN 矩阵: 2.42 GB (float32)
1.21 GB (bfloat16)
一个 FFN 矩阵就 2.4GB——GPT-3 有 96 层,每层有多个这样的矩阵。这就是为什么大模型需要那么多显存。
第三部分:张量操作与存储机制
张量的内部存储:stride
PyTorch 张量并不是简单的多维数组——它是一个指向一维连续内存的指针,加上描述如何索引的元数据(shape, stride)。
x = torch.tensor([[1., 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
print(f"shape: {x.shape}") # (3, 4)
print(f"stride: {x.stride()}") # (4, 1)
# stride 的含义:
# - 沿 dim 0(行)移动一步,在内存中跳 4 个元素
# - 沿 dim 1(列)移动一步,在内存中跳 1 个元素
# 手动计算 x[1][2] 在一维内存中的位置
r, c = 1, 2
index = r * x.stride(0) + c * x.stride(1)
print(f"\nx[{r}][{c}] = {x[r][c]:.0f}, 内存偏移 = {index}")
shape: torch.Size([3, 4])
stride: (4, 1)
x[1][2] = 7, 内存偏移 = 6
视图(View)vs 拷贝(Copy)
很多操作只是创建同一块内存的不同视角——不拷贝数据,因此是 的。
x = torch.tensor([[1., 2, 3], [4, 5, 6]])
# 这些操作都是 view(不拷贝,共享内存)
row = x[0] # 取行
col = x[:, 1] # 取列
reshaped = x.view(3, 2) # 改变形状
transposed = x.transpose(0, 1) # 转置
print(f"原始 x:\n{x}")
# 修改 x 会影响所有 view!
x[0][0] = 999
print(f"\n修改 x[0][0]=999 后:")
print(f" row: {row}") # row[0] 也变了
print(f" reshaped:\n{reshaped}") # 也变了
# 转置后是不连续的,某些操作需要先 .contiguous()
x = torch.tensor([[1., 2, 3], [4, 5, 6]])
t = x.transpose(0, 1)
print(f"\n转置后 contiguous? {t.is_contiguous()}")
# t.view(2, 3) # 这会报错!
t_c = t.contiguous().view(2, 3) # 先拷贝成连续的,再 view
print(f"contiguous + view:\n{t_c}")
原始 x:
tensor([[1., 2., 3.],
[4., 5., 6.]])
修改 x[0][0]=999 后:
row: tensor([999., 2., 3.])
reshaped:
tensor([[999., 2.],
[ 3., 4.],
[ 5., 6.]])
转置后 contiguous? False
contiguous + view:
tensor([[1., 4., 2.],
[5., 3., 6.]])
要点:view 是免费的(不消耗额外内存和计算),copy 需要额外内存和计算。知道哪些操作是 view、哪些是 copy,对内存核算非常重要。
逐元素操作
对张量每个元素独立操作,输出形状不变。
# 常见逐元素操作
x = torch.tensor([1., 2., 3., 4.])
print(f"原始: {x}")
print(f" +1: {x + 1}")
print(f" *2: {x * 2}")
print(f" exp: {torch.exp(x)}")
print(f"relu: {torch.relu(x - 2.5)}")
# 因果注意力掩码(上三角)
mask = torch.ones(4, 4).triu()
print(f"\n因果掩码 (triu):\n{mask}")
原始: tensor([1., 2., 3., 4.])
+1: tensor([2., 3., 4., 5.])
*2: tensor([2., 4., 6., 8.])
exp: tensor([ 2.7183, 7.3891, 20.0855, 54.5982])
relu: tensor([0.0000, 0.0000, 0.5000, 1.5000])
因果掩码 (triu):
tensor([[1., 1., 1., 1.],
[0., 1., 1., 1.],
[0., 0., 1., 1.],
[0., 0., 0., 1.]])
矩阵乘法:深度学习的核心运算
# 基本矩阵乘法
# 在实际模型中,x 通常有 batch 和 sequence 维度
# x: (batch, seq_len, d_in) @ w: (d_in, d_out) → (batch, seq_len, d_out)
batch, seq_len, d_in, d_out = 2, 3, 4, 5
x = torch.randn(batch, seq_len, d_in)
w = torch.randn(d_in, d_out)
# PyTorch 的 @ 会自动 broadcast 前面的维度
y = x @ w
print(f"x: {list(x.shape)}")
print(f"w: {list(w.shape)}")
print(f"y = x @ w: {list(y.shape)}")
x: [2, 3, 4]
w: [4, 5]
y = x @ w: [2, 3, 5]
第四部分:用 einops 写可读的张量操作
传统 PyTorch 代码中,张量维度全靠注释和数字索引,非常容易出错:
z = x @ y.transpose(-2, -1) # -2 是哪个维度?-1 呢?
Einops 用命名维度解决这个问题。
from einops import einsum, reduce, rearrange
# ---- einsum: 带名字的矩阵乘法 ----
x = torch.ones(2, 3, 4) # batch=2, seq1=3, hidden=4
y = torch.ones(2, 3, 4) # batch=2, seq2=3, hidden=4
# 旧写法(谁知道 -2 和 -1 是啥)
z_old = x @ y.transpose(-2, -1)
# einops 写法(一目了然)
z_new = einsum(x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")
print(f"注意力分数矩阵 shape: {list(z_new.shape)}")
print(f"两种方式结果一致: {torch.allclose(z_old, z_new)}")
# 用 ... 表示 batch 维度(更通用)
z = einsum(x, y, "... seq1 hidden, ... seq2 hidden -> ... seq1 seq2")
print(f"用 ... 语法: {list(z.shape)}")
注意力分数矩阵 shape: [2, 3, 3]
两种方式结果一致: True
用 ... 语法: [2, 3, 3]
# ---- reduce: 命名的归约操作 ----
x = torch.ones(2, 3, 4) # batch, seq, hidden
# 旧写法
y_old = x.mean(dim=-1) # 对最后一维求均值
# einops 写法
y_sum = reduce(x, "... hidden -> ...", "sum") # 求和
y_mean = reduce(x, "... hidden -> ...", "mean") # 均值
print(f"sum: shape={list(y_sum.shape)}, values={y_sum[0]}")
print(f"mean: shape={list(y_mean.shape)}, values={y_mean[0]}")
sum: shape=[2, 3], values=tensor([4., 4., 4.])
mean: shape=[2, 3], values=tensor([1., 1., 1.])
# ---- rearrange: 维度重组 ----
# 场景:多头注意力中,hidden_dim 实际上是 heads × head_dim
# 需要拆开、操作、再合并
x = torch.ones(2, 3, 8) # batch=2, seq=3, total_hidden=8
w = torch.ones(4, 4) # head_dim_in × head_dim_out
# 拆分 hidden 为 heads × head_dim
x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)
print(f"拆分后: {list(x.shape)}") # [2, 3, 2, 4]
# 对每个 head 分别做线性变换
x = einsum(x, w, "... hidden1, hidden1 hidden2 -> ... hidden2")
print(f"变换后: {list(x.shape)}") # [2, 3, 2, 4]
# 合并回去
x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)")
print(f"合并后: {list(x.shape)}") # [2, 3, 8]
拆分后: [2, 3, 2, 4]
变换后: [2, 3, 2, 4]
合并后: [2, 3, 8]
建议:einops 写起来稍微长一点,但维度关系一目了然,特别适合多头注意力这种复杂的维度操作。实际项目中强烈推荐使用。
第五部分:计算量核算(FLOPs)
什么是 FLOP?
一个 FLOP = 一次浮点加法或乘法。
两个容易混淆的缩写:
| 缩写 | 含义 | 度量什么 |
|---|---|---|
| FLOPs | Floating-point Operations | 总计算量(一个数字) |
| FLOP/s | FLOPs per Second | 计算速度(硬件性能) |
量级直觉
| 模型 | 训练 FLOPs | 说明 |
|---|---|---|
| GPT-3 (2020) | 3.14 × 10²³ | |
| GPT-4 (2023) | ~2 × 10²⁵ | 估计值 |
| 美国报告门槛 | 1 × 10²⁶ | 超过需向政府报告(2025 已撤销) |
| 硬件 | 峰值 FLOP/s (bf16) |
|---|---|
| A100 | 312 TFLOP/s |
| H100 | ~990 TFLOP/s |
矩阵乘法的 FLOPs
矩阵乘法 (m×n) @ (n×p) → (m×p):每个输出元素需要 n 次乘法 + n 次加法 = 2n,共 m×p 个输出。
import time
B, D, K = 256, 512, 256
x = torch.randn(B, D)
w = torch.randn(D, K)
# 理论 FLOPs
flops = 2 * B * D * K
print(f"矩阵乘法 ({B}×{D}) @ ({D}×{K})")
print(f"理论 FLOPs: {flops:,}")
# 实际计时
start = time.time()
for _ in range(100):
y = x @ w
elapsed = (time.time() - start) / 100
print(f"平均耗时: {elapsed*1000:.2f} ms")
print(f"实际 FLOP/s: {flops / elapsed:.2e}")
矩阵乘法 (256×512) @ (512×256)
理论 FLOPs: 67,108,864
平均耗时: 0.14 ms
实际 FLOP/s: 4.84e+11
从矩阵乘法到 Transformer
对于一个线性层 y = x @ w:
- FLOPs = 2 × (数据点数) × (参数量)
这个结论可以推广到整个 Transformer(一阶近似):
| 阶段 | FLOPs |
|---|---|
| 前向传播 | 2 × B × N |
| 反向传播 | 4 × B × N |
| 总计 | 6 × B × N |
其中 B = 数据点数(batch × seq_len × num_tokens),N = 模型参数量。
这就是第一部分那个
6 × 70e9 × 15e12的来源!
Model FLOPs Utilization (MFU)
- MFU ≥ 0.5 算很好
- 矩阵乘法占比越高,MFU 越高
- 数据类型影响很大:bf16 比 float32 快得多
第六部分:梯度与反向传播的 FLOPs
前向 + 反向 = 6 × 数据 × 参数
我们用一个两层线性模型来推导为什么反向传播的 FLOPs 是前向的两倍。
x --[w1]--> h1 --[w2]--> h2 --> loss
为什么是 6 × 数据量 × 参数量?
我们来拆解这个公式的来源。
先用一个买菜的类比建立直觉:
假设你去菜市场,买了 3 样东西:
品类 单价 (元/kg) 数量 (kg) 番茄 5 3 白菜 2 4 鸡蛋 8 2 Forward(算总价):每样东西做一次 "单价 × 数量",然后加起来得到总价。这是一轮乘加运算 → 2N FLOPs。
Backward 1 — 对数量的敏感度(梯度传回前一层):老板说"总价多了 1 块",你想知道"减少哪样东西的数量影响最大?"答案就是看单价——番茄 5 块/kg 影响最大,白菜 2 块/kg 影响最小。这又是一轮乘加 → 2N FLOPs。对应 ,用来把误差继续往前传。
你可能会问: 是输入数据,又不需要更新,为什么还要算它的梯度?因为这一层的 就是上一层的输出 ——链式法则要求把梯度一层层传回去,否则前面层的权重就没法更新了。只有网络最开头的输入才真正不需要算梯度。
Backward 2 — 对单价的敏感度(更新权重):同样的问题反过来——"调整哪样东西的单价影响最大?"答案是看数量——白菜买了 4kg 影响最大,鸡蛋 2kg 影响最小。又是一轮乘加 → 2N FLOPs。对应 ,用来更新权重。
所以:Forward 1 轮 + Backward 2 轮 = 3 轮,每轮 2N → 总计 6N。
下面用数学语言精确表述:
前向传播:≈ 2N FLOPs
对于一个线性层 ,其中 的形状是 :
- 每个输出元素需要 次乘法 + 次加法 = 次运算
- 总共 个输出元素,所以是 FLOPs
- 而 正好就是这一层的参数量
把所有层加起来,整个模型的前向传播 ≈ FLOPs( 是总参数量)。
反向传播:≈ 4N FLOPs
反向传播对每一层需要计算两个梯度:
- 对输入的梯度(用于继续反向传播到前一层): → 一次矩阵乘法,≈ FLOPs
- 对权重的梯度(用于更新参数): → 又一次矩阵乘法,≈ FLOPs
两个加起来: FLOPs。
合计:前向 + 反向 = FLOPs(per data point)
乘以数据点数 ,总计 。
下面用一个两层线性模型 x → w1 → h1 → w2 → h2 → loss 来验证这个结论:
B, D, K = 2, 3, 2 # 小尺寸便于说明
# 各张量的 shape:
# x: (B, D) = (2, 3) — 输入数据
# w1: (D, D) = (3, 3) — 第一层权重
# h1: (B, D) = (2, 3) — 第一层输出 = x @ w1
# w2: (D, K) = (3, 2) — 第二层权重
# h2: (B, K) = (2, 2) — 第二层输出 = h1 @ w2 → 用来算 loss
print(f"x: ({B}, {D}) — 输入")
print(f"w1: ({D}, {D}) — 第一层权重, 参数量 = {D*D}")
print(f"h1: ({B}, {D}) — x @ w1 的输出")
print(f"w2: ({D}, {K}) — 第二层权重, 参数量 = {D*K}")
print(f"h2: ({B}, {K}) — h1 @ w2 的输出")
print()
# === 前向 FLOPs ===
# x @ w1: (B,D) @ (D,D) → 2*B*D*D
# h1 @ w2: (B,D) @ (D,K) → 2*B*D*K
forward_flops = (2 * B * D * D) + (2 * B * D * K)
print(f"前向 FLOPs:")
print(f" x @ w1: 2×{B}×{D}×{D} = {2*B*D*D}")
print(f" h1 @ w2: 2×{B}×{D}×{K} = {2*B*D*K}")
print(f" 合计: {forward_flops}")
print()
# === 反向 FLOPs ===
backward_flops = 0
# 1. w2.grad = h1.T @ h2.grad → (D,B) @ (B,K) → 2*B*D*K
backward_flops += 2 * B * D * K
print(f"反向 FLOPs:")
print(f" w2.grad = h1.T @ h2.grad: 2×{B}×{D}×{K} = {2*B*D*K}")
# 2. h1.grad = h2.grad @ w2.T → (B,K) @ (K,D) → 2*B*D*K
backward_flops += 2 * B * D * K
print(f" h1.grad = h2.grad @ w2.T: 2×{B}×{D}×{K} = {2*B*D*K}")
# 3. w1.grad = x.T @ h1.grad → (D,B) @ (B,D) → 2*B*D*D
# x.grad = h1.grad @ w1.T → (B,D) @ (D,D) → 2*B*D*D (不需要但链式法则会算)
backward_flops += (2 + 2) * B * D * D
print(f" w1.grad + x.grad: 4×{B}×{D}×{D} = {4*B*D*D}")
print(f" 合计: {backward_flops}")
print()
print(f"前向: {forward_flops}, 反向: {backward_flops}")
print(f"反向 / 前向 = {backward_flops / forward_flops:.1f}x")
print(f"总计 = {forward_flops + backward_flops}")
print(f" ≈ 6 × B × params = {6 * B * (D*D + D*K)}")
x: (2, 3) — 输入
w1: (3, 3) — 第一层权重, 参数量 = 9
h1: (2, 3) — x @ w1 的输出
w2: (3, 2) — 第二层权重, 参数量 = 6
h2: (2, 2) — h1 @ w2 的输出
前向 FLOPs:
x @ w1: 2×2×3×3 = 36
h1 @ w2: 2×2×3×2 = 24
合计: 60
反向 FLOPs:
w2.grad = h1.T @ h2.grad: 2×2×3×2 = 24
h1.grad = h2.grad @ w2.T: 2×2×3×2 = 24
w1.grad + x.grad: 4×2×3×3 = 72
合计: 120
前向: 60, 反向: 120
反向 / 前向 = 2.0x
总计 = 180
≈ 6 × B × params = 180
关键结论:反向传播的 FLOPs 是前向的 2 倍。总计 = 6 × 数据点数 × 参数量。
这也解释了为什么推理(只做前向)比训练便宜得多——训练每步要做 3 倍于推理的计算。
第七部分:模型定义
参数初始化的重要性
随机初始化如果不做缩放,输出的数量级会随输入维度增大而爆炸——这会导致梯度爆炸。
# 不缩放的问题
input_dim = 1000
x = torch.randn(1, input_dim)
# 不缩放:每个输出元素 ~ sum of input_dim 个 N(0,1)*N(0,1)
# 方差 = input_dim,标准差 = sqrt(input_dim) ≈ 31.6
w_bad = torch.randn(input_dim, 4)
out_bad = x @ w_bad
print(f"不缩放 (input_dim={input_dim}):")
print(f" 输出: {out_bad.data[0]}")
print(f" 量级 ≈ sqrt({input_dim}) = {input_dim**0.5:.1f}")
# 缩放(Xavier 初始化的核心思想)
w_good = torch.randn(input_dim, 4) / (input_dim ** 0.5)
out_good = x @ w_good
print(f"\n缩放后 (÷ sqrt(input_dim)):")
print(f" 输出: {out_good.data[0]}")
print(f" 量级 ≈ O(1) ✓")
不缩放 (input_dim=1000):
输出: tensor([-27.6163, 17.9119, 26.9574, 17.6714])
量级 ≈ sqrt(1000) = 31.6
缩放后 (÷ sqrt(input_dim)):
输出: tensor([ 1.3417, -0.2218, -0.9940, 0.2977])
量级 ≈ O(1) ✓
用 nn.Module 定义模型
class DeepLinearModel(nn.Module):
"""简单的多层线性模型(用于演示训练流程)"""
def __init__(self, input_dim, hidden_dim, num_layers):
super().__init__()
layers = []
for i in range(num_layers):
d_in = input_dim if i == 0 else hidden_dim
w = nn.Parameter(torch.randn(d_in, hidden_dim) / d_in ** 0.5)
layers.append(w)
self.layers = nn.ParameterList(layers)
self.bias = nn.Parameter(torch.zeros(hidden_dim))
def forward(self, x):
for w in self.layers:
x = x @ w
return x + self.bias
# 创建模型
D = 3
model = DeepLinearModel(input_dim=D, hidden_dim=D, num_layers=4)
# 数参数量
num_params = sum(p.numel() for p in model.parameters())
print(f"参数量: {num_params}")
print(f"内存 (float32): {num_params * 4} 字节")
# 前向传播
x = torch.randn(2, D)
y = model(x)
print(f"\n输入 shape: {list(x.shape)}")
print(f"输出 shape: {list(y.shape)}")
参数量: 39
内存 (float32): 156 字节
输入 shape: [2, 3]
输出 shape: [2, 3]
第八部分:完整训练流程
随机种子
训练中随机性无处不在(初始化、dropout、数据顺序)。为每个随机源设置独立的种子,这样可以精确复现结果。
import random
import numpy as np
def set_seed(seed):
"""设置所有随机种子"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
set_seed(42)
print("随机种子已设置为 42")
随机种子已设置为 42
数据加载
语言模型的训练数据 = 分词后的整数序列。数据量通常很大(LLaMA 的训练数据 ~2.8TB),不能一次性加载。
# 模拟:生成一个简单的回归数据集
set_seed(42)
D = 3
true_weights = torch.arange(D, dtype=torch.float32) # [0, 1, 2]
print(f"真实权重: {true_weights}")
def get_batch(batch_size, dim, seed=None):
"""生成一个 batch 的数据"""
if seed is not None:
torch.manual_seed(seed)
x = torch.randn(batch_size, dim)
y = x @ true_weights # 真实标签
return x, y
x, y = get_batch(4, D, seed=0)
print(f"\n输入 x shape: {list(x.shape)}")
print(f"标签 y shape: {list(y.shape)}")
print(f"标签 y: {y}")
真实权重: tensor([0., 1., 2.])
输入 x shape: [4, 3]
标签 y shape: [4]
标签 y: tensor([-4.6510, -3.8817, -0.6005, -0.2326])
优化器
优化器的演进路线:
| 优化器 | 核心思想 |
|---|---|
| SGD | 最基本:w -= lr * grad |
| SGD + Momentum | 加上梯度的指数移动平均 |
| AdaGrad | 用历史梯度平方和做自适应学习率 |
| RMSProp | AdaGrad + 指数衰减(避免学习率单调递减) |
| Adam | RMSProp + Momentum = 一阶矩 + 二阶矩 |
| AdamW | Adam + 解耦的 weight decay |
训练循环的内存核算
D, num_layers, B = 3, 4, 2
num_parameters = D * D * num_layers + D # 权重 + 偏置
num_activations = B * D * num_layers # 每层的中间结果
num_gradients = num_parameters # 和参数一样多
num_optimizer_states = num_parameters # AdaGrad: 一个状态变量
total_memory = 4 * (num_parameters + num_activations + num_gradients + num_optimizer_states)
print(f"参数: {num_parameters} 个 → {num_parameters * 4} 字节")
print(f"激活值: {num_activations} 个 → {num_activations * 4} 字节")
print(f"梯度: {num_gradients} 个 → {num_gradients * 4} 字节")
print(f"优化器状态: {num_optimizer_states} 个 → {num_optimizer_states * 4} 字节")
print(f"总内存: {total_memory} 字节")
flops_per_step = 6 * B * num_parameters
print(f"\n每步 FLOPs: {flops_per_step}")
参数: 39 个 → 156 字节
激活值: 24 个 → 96 字节
梯度: 39 个 → 156 字节
优化器状态: 39 个 → 156 字节
总内存: 564 字节
每步 FLOPs: 468
完整训练循环
set_seed(42)
# 模型 + 优化器
model = DeepLinearModel(input_dim=D, hidden_dim=D, num_layers=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
# 训练
for step in range(200):
x, y_true = get_batch(batch_size=32, dim=D, seed=step)
# 前向
y_pred = model(x).sum(dim=-1) # 简化:对输出求和
loss = ((y_pred - y_true) ** 2).mean()
# 反向
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 50 == 0:
print(f"step {step:3d} | loss = {loss.item():.4f}")
print(f"\n最终 loss: {loss.item():.6f}")
step 0 | loss = 0.3104
step 50 | loss = 0.0009
step 100 | loss = 0.0000
step 150 | loss = 0.0000
最终 loss: 0.000000
第九部分:Checkpoint 与混合精度训练
Checkpoint(检查点)
训练大模型动辄几天到几个月,中途崩溃是家常便饭。定期保存模型和优化器状态 = 保险。
# 保存
torch.save({
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"step": 200,
}, "/tmp/checkpoint.pt")
print("Checkpoint 已保存")
# 加载
ckpt = torch.load("/tmp/checkpoint.pt", weights_only=False)
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
print(f"从 step {ckpt['step']} 恢复训练")
Checkpoint 已保存
从 step 200 恢复训练
混合精度训练
| 精度 | 优点 | 缺点 |
|---|---|---|
| float32 | 稳定 | 慢,内存大 |
| bfloat16 | 快,内存小 | 可能不稳定 |
混合精度的策略:两者结合
- 前向传播(激活值):用 bfloat16 / fp8——这里精度要求低,但数据量大
- 参数和梯度:保留 float32——精度要求高,但相对数据量小
- 额外维护一份 float32 的参数副本(master weights)
PyTorch 提供了 torch.amp(Automatic Mixed Precision)自动管理:
# 伪代码
scaler = torch.amp.GradScaler()
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
output = model(input) # 前向用 bf16
loss = criterion(output, target)
scaler.scale(loss).backward() # 反向
scaler.step(optimizer)
scaler.update()
NVIDIA 的 Transformer Engine 还支持 FP8 混合精度(论文),可进一步减少内存和加速计算。
本讲总结
资源核算速查表
| 项目 | 公式 |
|---|---|
| 张量内存 | 元素数 × 每元素字节数 |
| 矩阵乘法 FLOPs | 2 × m × n × p |
| 前向 FLOPs | 2 × B × N |
| 反向 FLOPs | 4 × B × N |
| 训练总 FLOPs | 6 × B × N |
| 训练内存(AdamW, fp32) | 16 × N + 激活值 |
| MFU | 实际 FLOP/s ÷ 峰值 FLOP/s |
数据类型选择
| 场景 | 推荐 |
|---|---|
| 默认/调试 | float32 |
| 训练前向 | bfloat16 |
| 参数/梯度 | float32(master weights) |
| 最激进 | fp8(需要 H100) |
关键概念检查
- 为什么 bfloat16 比 float16 更适合深度学习?(动态范围 = float32,不会下溢)
x.view(3, 2)和x.reshape(3, 2)有什么区别?(view 必须连续,不拷贝;reshape 可能拷贝)- 反向传播的 FLOPs 是前向的几倍?(2 倍)
- 为什么初始化权重要除以
sqrt(input_dim)?(保持输出方差为 O(1),防止梯度爆炸) - 混合精度训练的核心思路是什么?(前向用低精度省内存和加速,参数用高精度保稳定)