从零实现语言模型:CS336 Assignment 1 基础组件学习总结
最近开始学习 Stanford CS336: Language Modeling from Scratch。这门课的核心不是调用现成大模型 API,而是从底层实现一个语言模型训练系统。课程官网说明,Assignment 1 的目标是实现训练标准 Transformer Language Model 所需的 tokenizer、模型结构和 optimizer,并训练一个最小语言模型。(Stanford CS336) 官方仓库也提供了 assignment1-basics 的学生版代码,里面大量函数初始状态都是 NotImplementedError,需要自己补实现。(GitHub)
这篇文章记录我目前已经学习和实现的内容,包括:
1. 基础张量函数
2. 语言模型训练数据 batch 构造
3. softmax 与 cross entropy
4. RMSNorm
5. SwiGLU
6. BPE tokenizer 的 encode/decode 基础逻辑
1. 项目启动:先跑测试
官方代码的结构大致是:
assignment1-basics/
├── cs336_basics/
│ ├── __init__.py
│ └── pretokenization_example.py
├── tests/
│ ├── adapters.py
│ ├── test_model.py
│ ├── test_nn_utils.py
│ ├── test_data.py
│ ├── test_tokenizer.py
│ └── test_train_bpe.py
一开始很多函数都在 tests/adapters.py 里:
def run_silu(...):
raise NotImplementedError
也就是说,测试会通过 adapters.py 调用我们自己的实现。学习过程不是先写一个完整模型,而是逐个补齐基础组件,然后通过测试验证。
运行测试:
uv run pytest tests/test_nn_utils.py -q
uv run pytest tests/test_data.py -q
uv run pytest tests/test_model.py -q
我已经通过的基础测试包括:
softmax
cross_entropy
gradient_clipping
linear
embedding
silu
get_batch
rmsnorm
swiglu
前面测试日志里,softmax / cross_entropy / gradient_clipping / linear / embedding / silu 已经通过,run_get_batch 当时因为仍然是 NotImplementedError 而失败。 后续补完 run_get_batch 后,这条基础数据构造线也完成了。
2. SiLU:最简单的激活函数
SiLU 的公式是:
SiLU(x) = x * sigmoid(x)
它是逐元素操作,不改变 tensor 的 shape。
def run_silu(in_features):
return in_features * torch.sigmoid(in_features)
例子:
x = torch.tensor([-1.0, 0.0, 1.0])
y = run_silu(x)
print(y)
# tensor([-0.2689, 0.0000, 0.7311])
直观理解:
输入是什么 shape,输出还是什么 shape;
只是每个数都经过 x * sigmoid(x) 变换。
3. Linear:线性层
run_linear 做的是:
y = x @ W.T
其中:
输入 x: [..., d_in]
权重 weights: [d_out, d_in]
输出 y: [..., d_out]
实现:
def run_linear(d_in, d_out, weights, in_features):
return in_features @ weights.T
例子:
in_features.shape = [2, 5, 3]
weights.shape = [4, 3]
output = in_features @ weights.T
output.shape = [2, 5, 4]
也就是说,最后一维从 d_in=3 变成 d_out=4。
4. Embedding:token id 查表
语言模型不能直接处理整数 token id,所以要先把 token id 转成向量。
def run_embedding(vocab_size, d_model, weights, token_ids):
return weights[token_ids]
其中:
weights.shape = [vocab_size, d_model]
token_ids.shape = [batch, seq]
output.shape = [batch, seq, d_model]
例子:
token_ids = torch.tensor([
[2, 5, 7],
[1, 3, 4],
])
weights.shape = [10, 8]
output = weights[token_ids]
output.shape
# [2, 3, 8]
含义:
2 个样本
每个样本 3 个 token
每个 token 被查成 8 维向量
5. Softmax:把 logits 变成概率
模型输出的 logits 只是原始分数,不是概率。softmax 的作用是把一组分数变成概率分布。
公式:
softmax(x_i) = exp(x_i) / sum(exp(x_j))
为了防止数值溢出,需要先减去最大值:
def run_softmax(in_features, dim):
shifted = in_features - torch.max(in_features, dim=dim, keepdim=True).values
exp = torch.exp(shifted)
return exp / torch.sum(exp, dim=dim, keepdim=True)
例子:
x = torch.tensor([2.0, 1.0, 0.0])
softmax(x)
# 大约 [0.665, 0.245, 0.090]
这表示模型认为:
第 0 类概率最大
第 1 类次之
第 2 类最小
6. Cross Entropy:只惩罚正确答案概率不够高
Cross entropy 可以理解为:
模型有没有把正确答案的概率打高?
输入:
inputs: [batch_size, vocab_size] # 模型 logits
targets: [batch_size] # 正确类别编号
实现:
def run_cross_entropy(inputs, targets):
shifted = inputs - torch.max(inputs, dim=-1, keepdim=True).values
log_probs = shifted - torch.log(torch.sum(torch.exp(shifted), dim=-1, keepdim=True))
batch_indices = torch.arange(targets.shape[0], device=targets.device)
correct_log_probs = log_probs[batch_indices, targets]
return -torch.mean(correct_log_probs)
关键逻辑:
correct_log_probs = log_probs[batch_indices, targets]
这句的意思是:
第 0 行,取 target[0] 指定的那一列
第 1 行,取 target[1] 指定的那一列
第 2 行,取 target[2] 指定的那一列
……
举例:
log_probs = torch.tensor([
[-0.4076, -1.4076, -2.4076],
[-3.1698, -0.1698, -2.1698],
])
targets = torch.tensor([0, 1])
batch_indices = torch.tensor([0, 1])
correct_log_probs = log_probs[batch_indices, targets]
# tensor([-0.4076, -0.1698])
最后:
loss = -torch.mean(correct_log_probs)
也就是:
正确答案概率越高,loss 越小;
正确答案概率越低,loss 越大。
7. get_batch:构造语言模型训练样本
语言模型训练的目标是 next-token prediction。也就是:
输入前面的 token,预测下一个 token。
实现:
def run_get_batch(dataset, batch_size, context_length, device):
max_start = len(dataset) - context_length
start_indices = torch.randint(
low=0,
high=max_start,
size=(batch_size,),
)
x = torch.stack([
torch.as_tensor(dataset[i : i + context_length], dtype=torch.long)
for i in start_indices.tolist()
])
y = torch.stack([
torch.as_tensor(dataset[i + 1 : i + 1 + context_length], dtype=torch.long)
for i in start_indices.tolist()
])
return x.to(device), y.to(device)
例子:
dataset = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
context_length = 4
batch_size = 3
start_indices = [2, 0, 5]
那么:
x = [
[2, 3, 4, 5],
[0, 1, 2, 3],
[5, 6, 7, 8],
]
y = [
[3, 4, 5, 6],
[1, 2, 3, 4],
[6, 7, 8, 9],
]
核心规律:
x 从 i 开始取 context_length 个 token
y 从 i+1 开始取 context_length 个 token
也就是:
x = dataset[i : i + context_length]
y = dataset[i + 1 : i + 1 + context_length]
8. RMSNorm:稳定每个 token 向量的数值规模
RMSNorm 的作用是归一化每个 token 的向量,让数值规模更稳定。
公式:
RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight
实现:
def run_rmsnorm(d_model, eps, weights, in_features):
rms = torch.sqrt(torch.mean(in_features ** 2, dim=-1, keepdim=True) + eps)
return in_features / rms * weights
输入输出:
输入: [..., d_model]
输出: [..., d_model]
例子:
in_features.shape = [2, 5, 8]
weights.shape = [8]
output.shape = [2, 5, 8]
RMSNorm 不改变 shape,只改变数值分布。
9. SwiGLU:Transformer 里的前馈网络
SwiGLU 是 Transformer 中 FFN 的一种形式。它不是 attention。
attention 负责 token 之间互相看;
SwiGLU 负责每个 token 自己内部的向量变换。
公式:
SwiGLU(x) = W2( SiLU(W1x) * W3x )
实现:
def run_swiglu(
d_model,
d_ff,
w1_weight,
w2_weight,
w3_weight,
in_features,
):
w1_x = in_features @ w1_weight.T
w3_x = in_features @ w3_weight.T
silu_w1_x = w1_x * torch.sigmoid(w1_x)
hidden = silu_w1_x * w3_x
return hidden @ w2_weight.T
维度例子:
d_model = 4
d_ff = 12
in_features.shape = [2, 5, 4]
w1_weight.shape = [12, 4]
w3_weight.shape = [12, 4]
w2_weight.shape = [4, 12]
计算过程:
x: [2, 5, 4]
W1x: [2, 5, 12]
W3x: [2, 5, 12]
SiLU(W1x): [2, 5, 12]
相乘: [2, 5, 12]
W2 降维: [2, 5, 4]
输出仍然是:
[2, 5, 4]
原因是 Transformer block 里有残差连接:
x = x + FFN(x)
所以 FFN(x) 必须和 x 形状一致。
10. Tokenizer:从字符串到 token id
模型不能直接吃字符串,需要先经过 tokenizer:
text -> bytes -> BPE merge -> token ids
我手动创建了:
cs336_basics/tokenizer.py
基础结构:
import regex as re
class BPETokenizer:
def __init__(self, vocab, merges, special_tokens=None):
self.vocab = vocab
self.merges = merges
self.special_tokens = special_tokens or []
# bytes -> token id
self.vocab_inv = {v: k for k, v in vocab.items()}
# merge pair -> rank
self.merge_ranks = {pair: i for i, pair in enumerate(merges)}
self.pat = re.compile(
r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
其中:
self.merge_ranks = {pair: i for i, pair in enumerate(merges)}
作用是把 BPE 合并规则变成一个“优先级查询表”。
比如:
merges = [
(b"l", b"o"),
(b"lo", b"w"),
(b"e", b"s"),
]
会变成:
merge_ranks = {
(b"l", b"o"): 0,
(b"lo", b"w"): 1,
(b"e", b"s"): 2,
}
rank 越小,合并优先级越高。
11. BPE encode 的核心逻辑
BPE encode 的核心函数是:
def _bpe_encode_bytes(self, token_bytes: bytes) -> list[int]:
parts = [bytes([b]) for b in token_bytes]
if len(parts) == 0:
return []
while True:
pairs = [(parts[i], parts[i + 1]) for i in range(len(parts) - 1)]
candidate_pairs = [
pair for pair in pairs
if pair in self.merge_ranks
]
if not candidate_pairs:
break
best_pair = min(candidate_pairs, key=lambda pair: self.merge_ranks[pair])
new_parts = []
i = 0
while i < len(parts):
if (
i < len(parts) - 1
and parts[i] == best_pair[0]
and parts[i + 1] == best_pair[1]
):
new_parts.append(parts[i] + parts[i + 1])
i += 2
else:
new_parts.append(parts[i])
i += 1
parts = new_parts
return [self.vocab_inv[p] for p in parts]
核心原理:
1. 先把输入 bytes 拆成单个 byte
2. 找所有相邻 pair
3. 判断哪些 pair 在 merge 规则里
4. 选 rank 最小的 pair
5. 合并这个 pair
6. 重复,直到不能再合并
7. 查 vocab_inv,把 bytes token 变成 token id
例子:
token_bytes = b"lowest"
初始:
[b"l", b"o", b"w", b"e", b"s", b"t"]
假设 merges 是:
[ (b"l", b"o"), (b"lo", b"w"), (b"e", b"s"), (b"es", b"t"), (b"low", b"est"),]
合并过程:
第 1 轮:l + o -> lo
[b"lo", b"w", b"e", b"s", b"t"]
第 2 轮:lo + w -> low
[b"low", b"e", b"s", b"t"]
第 3 轮:e + s -> es
[b"low", b"es", b"t"]
第 4 轮:es + t -> est
[b"low", b"est"]
第 5 轮:low + est -> lowest
[b"lowest"]
最后:
[b"lowest"] -> [token_id]
12. encode:把完整字符串转成 token ids
完整字符串不能直接送进 BPE,需要先做预分词:
def encode(self, text: str) -> list[int]:
ids = []
for match in self.pat.finditer(text):
piece = match.group(0)
piece_bytes = piece.encode("utf-8")
ids.extend(self._bpe_encode_bytes(piece_bytes))
return ids
这段代码的人话版:
准备一个空列表 ids。
用正则把 text 切成一小段一小段。
每一小段:
取出字符串 piece
转成 UTF-8 bytes
用 BPE 规则编码成 token ids
把这些 ids 加到总列表里
最后返回总 ids。
例子:
text = "Hello world!"
预分词可能得到:
"Hello"
" world"
"!"
然后:
"Hello" -> b"Hello" -> BPE -> [15496]
" world" -> b" world" -> BPE -> [995]
"!" -> b"!" -> BPE -> [0]
最终:
[15496, 995, 0]
13. decode:从 token ids 还原文本
decode 做的是反向过程:
token ids -> bytes -> string
实现:
def decode(self, ids: list[int]) -> str:
token_bytes = b"".join(self.vocab[i] for i in ids)
return token_bytes.decode("utf-8", errors="replace")
例子:
vocab = {
0: b"h",
1: b"e",
2: b"l",
3: b"o",
4: b"he",
5: b"ll",
}
ids = [4, 5, 3]
查表:
4 -> b"he"
5 -> b"ll"
3 -> b"o"
拼起来:
b"he" + b"ll" + b"o" = b"hello"
最后 decode:
"hello"
14. 当前学习进度总结
目前我已经理解并实现了语言模型训练中的一批基础组件:
1. SiLU:逐元素激活函数
2. Linear:线性变换
3. Embedding:token id 查表成向量
4. Softmax:把 logits 变成概率
5. Cross Entropy:计算正确答案概率对应的损失
6. Gradient Clipping:限制梯度范数
7. get_batch:构造 next-token prediction 训练样本
8. RMSNorm:稳定 token 向量数值规模
9. SwiGLU:Transformer FFN 模块
10. BPE tokenizer encode/decode 基础逻辑
现在我对语言模型训练主线的理解是:
文本
-> tokenizer.encode
-> token ids
-> get_batch
-> x, y
-> embedding
-> Transformer blocks
-> logits
-> cross_entropy
-> optimizer update
15. 下一步计划
接下来继续推进 A1:
1. 完善 tokenizer 对 special_tokens 的处理
2. 实现 encode_iterable
3. 实现 train_bpe
4. 进入 scaled dot-product attention
5. 实现 multi-head attention
6. 实现 RoPE
7. 实现 TransformerBlock
8. 实现完整 TransformerLM