CS336 Assignment 1

0 阅读8分钟

从零实现语言模型:CS336 Assignment 1 基础组件学习总结

最近开始学习 Stanford CS336: Language Modeling from Scratch。这门课的核心不是调用现成大模型 API,而是从底层实现一个语言模型训练系统。课程官网说明,Assignment 1 的目标是实现训练标准 Transformer Language Model 所需的 tokenizer、模型结构和 optimizer,并训练一个最小语言模型。(Stanford CS336) 官方仓库也提供了 assignment1-basics 的学生版代码,里面大量函数初始状态都是 NotImplementedError,需要自己补实现。(GitHub)

这篇文章记录我目前已经学习和实现的内容,包括:

1. 基础张量函数
2. 语言模型训练数据 batch 构造
3. softmax 与 cross entropy
4. RMSNorm
5. SwiGLU
6. BPE tokenizer 的 encode/decode 基础逻辑

1. 项目启动:先跑测试

官方代码的结构大致是:

assignment1-basics/
├── cs336_basics/
│   ├── __init__.py
│   └── pretokenization_example.py
├── tests/
│   ├── adapters.py
│   ├── test_model.py
│   ├── test_nn_utils.py
│   ├── test_data.py
│   ├── test_tokenizer.py
│   └── test_train_bpe.py

一开始很多函数都在 tests/adapters.py 里:

def run_silu(...):
    raise NotImplementedError

也就是说,测试会通过 adapters.py 调用我们自己的实现。学习过程不是先写一个完整模型,而是逐个补齐基础组件,然后通过测试验证。

运行测试:

uv run pytest tests/test_nn_utils.py -q
uv run pytest tests/test_data.py -q
uv run pytest tests/test_model.py -q

我已经通过的基础测试包括:

softmax
cross_entropy
gradient_clipping
linear
embedding
silu
get_batch
rmsnorm
swiglu

前面测试日志里,softmax / cross_entropy / gradient_clipping / linear / embedding / silu 已经通过,run_get_batch 当时因为仍然是 NotImplementedError 而失败。 后续补完 run_get_batch 后,这条基础数据构造线也完成了。


2. SiLU:最简单的激活函数

SiLU 的公式是:

SiLU(x) = x * sigmoid(x)

它是逐元素操作,不改变 tensor 的 shape。

def run_silu(in_features):
    return in_features * torch.sigmoid(in_features)

例子:

x = torch.tensor([-1.0, 0.0, 1.0])
y = run_silu(x)

print(y)
# tensor([-0.2689, 0.0000, 0.7311])

直观理解:

输入是什么 shape,输出还是什么 shape;
只是每个数都经过 x * sigmoid(x) 变换。

3. Linear:线性层

run_linear 做的是:

y = x @ W.T

其中:

输入 x:       [..., d_in]
权重 weights: [d_out, d_in]
输出 y:       [..., d_out]

实现:

def run_linear(d_in, d_out, weights, in_features):
    return in_features @ weights.T

例子:

in_features.shape = [2, 5, 3]
weights.shape = [4, 3]

output = in_features @ weights.T
output.shape = [2, 5, 4]

也就是说,最后一维从 d_in=3 变成 d_out=4


4. Embedding:token id 查表

语言模型不能直接处理整数 token id,所以要先把 token id 转成向量。

def run_embedding(vocab_size, d_model, weights, token_ids):
    return weights[token_ids]

其中:

weights.shape = [vocab_size, d_model]
token_ids.shape = [batch, seq]
output.shape = [batch, seq, d_model]

例子:

token_ids = torch.tensor([
    [2, 5, 7],
    [1, 3, 4],
])

weights.shape = [10, 8]

output = weights[token_ids]
output.shape
# [2, 3, 8]

含义:

2 个样本
每个样本 3 个 token
每个 token 被查成 8 维向量

5. Softmax:把 logits 变成概率

模型输出的 logits 只是原始分数,不是概率。softmax 的作用是把一组分数变成概率分布。

公式:

softmax(x_i) = exp(x_i) / sum(exp(x_j))

为了防止数值溢出,需要先减去最大值:

def run_softmax(in_features, dim):
    shifted = in_features - torch.max(in_features, dim=dim, keepdim=True).values
    exp = torch.exp(shifted)
    return exp / torch.sum(exp, dim=dim, keepdim=True)

例子:

x = torch.tensor([2.0, 1.0, 0.0])

softmax(x)
# 大约 [0.665, 0.245, 0.090]

这表示模型认为:

第 0 类概率最大
第 1 类次之
第 2 类最小

6. Cross Entropy:只惩罚正确答案概率不够高

Cross entropy 可以理解为:

模型有没有把正确答案的概率打高?

输入:

inputs:  [batch_size, vocab_size]  # 模型 logits
targets: [batch_size]              # 正确类别编号

实现:

def run_cross_entropy(inputs, targets):
    shifted = inputs - torch.max(inputs, dim=-1, keepdim=True).values
    log_probs = shifted - torch.log(torch.sum(torch.exp(shifted), dim=-1, keepdim=True))

    batch_indices = torch.arange(targets.shape[0], device=targets.device)
    correct_log_probs = log_probs[batch_indices, targets]

    return -torch.mean(correct_log_probs)

关键逻辑:

correct_log_probs = log_probs[batch_indices, targets]

这句的意思是:

0 行,取 target[0] 指定的那一列
第 1 行,取 target[1] 指定的那一列
第 2 行,取 target[2] 指定的那一列
……

举例:

log_probs = torch.tensor([
    [-0.4076, -1.4076, -2.4076],
    [-3.1698, -0.1698, -2.1698],
])

targets = torch.tensor([0, 1])
batch_indices = torch.tensor([0, 1])

correct_log_probs = log_probs[batch_indices, targets]
# tensor([-0.4076, -0.1698])

最后:

loss = -torch.mean(correct_log_probs)

也就是:

正确答案概率越高,loss 越小;
正确答案概率越低,loss 越大。

7. get_batch:构造语言模型训练样本

语言模型训练的目标是 next-token prediction。也就是:

输入前面的 token,预测下一个 token。

实现:

def run_get_batch(dataset, batch_size, context_length, device):
    max_start = len(dataset) - context_length

    start_indices = torch.randint(
        low=0,
        high=max_start,
        size=(batch_size,),
    )

    x = torch.stack([
        torch.as_tensor(dataset[i : i + context_length], dtype=torch.long)
        for i in start_indices.tolist()
    ])

    y = torch.stack([
        torch.as_tensor(dataset[i + 1 : i + 1 + context_length], dtype=torch.long)
        for i in start_indices.tolist()
    ])

    return x.to(device), y.to(device)

例子:

dataset = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
context_length = 4
batch_size = 3

start_indices = [2, 0, 5]

那么:

x = [
    [2, 3, 4, 5],
    [0, 1, 2, 3],
    [5, 6, 7, 8],
]

y = [
    [3, 4, 5, 6],
    [1, 2, 3, 4],
    [6, 7, 8, 9],
]

核心规律:

x 从 i 开始取 context_length 个 token
y 从 i+1 开始取 context_length 个 token

也就是:

x = dataset[i     : i     + context_length]
y = dataset[i + 1 : i + 1 + context_length]

8. RMSNorm:稳定每个 token 向量的数值规模

RMSNorm 的作用是归一化每个 token 的向量,让数值规模更稳定。

公式:

RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight

实现:

def run_rmsnorm(d_model, eps, weights, in_features):
    rms = torch.sqrt(torch.mean(in_features ** 2, dim=-1, keepdim=True) + eps)
    return in_features / rms * weights

输入输出:

输入: [..., d_model]
输出: [..., d_model]

例子:

in_features.shape = [2, 5, 8]
weights.shape = [8]

output.shape = [2, 5, 8]

RMSNorm 不改变 shape,只改变数值分布。


9. SwiGLU:Transformer 里的前馈网络

SwiGLU 是 Transformer 中 FFN 的一种形式。它不是 attention。

attention 负责 token 之间互相看;
SwiGLU 负责每个 token 自己内部的向量变换。

公式:

SwiGLU(x) = W2( SiLU(W1x) * W3x )

实现:

def run_swiglu(
    d_model,
    d_ff,
    w1_weight,
    w2_weight,
    w3_weight,
    in_features,
):
    w1_x = in_features @ w1_weight.T
    w3_x = in_features @ w3_weight.T

    silu_w1_x = w1_x * torch.sigmoid(w1_x)

    hidden = silu_w1_x * w3_x

    return hidden @ w2_weight.T

维度例子:

d_model = 4
d_ff = 12
in_features.shape = [2, 5, 4]

w1_weight.shape = [12, 4]
w3_weight.shape = [12, 4]
w2_weight.shape = [4, 12]

计算过程:

x:              [2, 5, 4]
W1x:            [2, 5, 12]
W3x:            [2, 5, 12]
SiLU(W1x):      [2, 5, 12]
相乘:            [2, 5, 12]
W2 降维:          [2, 5, 4]

输出仍然是:

[2, 5, 4]

原因是 Transformer block 里有残差连接:

x = x + FFN(x)

所以 FFN(x) 必须和 x 形状一致。


10. Tokenizer:从字符串到 token id

模型不能直接吃字符串,需要先经过 tokenizer:

text -> bytes -> BPE merge -> token ids

我手动创建了:

cs336_basics/tokenizer.py

基础结构:

import regex as re


class BPETokenizer:
    def __init__(self, vocab, merges, special_tokens=None):
        self.vocab = vocab
        self.merges = merges
        self.special_tokens = special_tokens or []

        # bytes -> token id
        self.vocab_inv = {v: k for k, v in vocab.items()}

        # merge pair -> rank
        self.merge_ranks = {pair: i for i, pair in enumerate(merges)}

        self.pat = re.compile(
            r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )

其中:

self.merge_ranks = {pair: i for i, pair in enumerate(merges)}

作用是把 BPE 合并规则变成一个“优先级查询表”。

比如:

merges = [
    (b"l", b"o"),
    (b"lo", b"w"),
    (b"e", b"s"),
]

会变成:

merge_ranks = {
    (b"l", b"o"): 0,
    (b"lo", b"w"): 1,
    (b"e", b"s"): 2,
}

rank 越小,合并优先级越高。


11. BPE encode 的核心逻辑

BPE encode 的核心函数是:

def _bpe_encode_bytes(self, token_bytes: bytes) -> list[int]:
    parts = [bytes([b]) for b in token_bytes]

    if len(parts) == 0:
        return []

    while True:
        pairs = [(parts[i], parts[i + 1]) for i in range(len(parts) - 1)]

        candidate_pairs = [
            pair for pair in pairs
            if pair in self.merge_ranks
        ]

        if not candidate_pairs:
            break

        best_pair = min(candidate_pairs, key=lambda pair: self.merge_ranks[pair])

        new_parts = []
        i = 0
        while i < len(parts):
            if (
                i < len(parts) - 1
                and parts[i] == best_pair[0]
                and parts[i + 1] == best_pair[1]
            ):
                new_parts.append(parts[i] + parts[i + 1])
                i += 2
            else:
                new_parts.append(parts[i])
                i += 1

        parts = new_parts

    return [self.vocab_inv[p] for p in parts]

核心原理:

1. 先把输入 bytes 拆成单个 byte
2. 找所有相邻 pair
3. 判断哪些 pair 在 merge 规则里
4. 选 rank 最小的 pair
5. 合并这个 pair
6. 重复,直到不能再合并
7. 查 vocab_inv,把 bytes token 变成 token id

例子:

token_bytes = b"lowest"

初始:

[b"l", b"o", b"w", b"e", b"s", b"t"]

假设 merges 是:

[    (b"l", b"o"),    (b"lo", b"w"),    (b"e", b"s"),    (b"es", b"t"),    (b"low", b"est"),]

合并过程:

1 轮:l + o       -> lo
[b"lo", b"w", b"e", b"s", b"t"]

第 2 轮:lo + w      -> low
[b"low", b"e", b"s", b"t"]

第 3 轮:e + s       -> es
[b"low", b"es", b"t"]

第 4 轮:es + t      -> est
[b"low", b"est"]

第 5 轮:low + est   -> lowest
[b"lowest"]

最后:

[b"lowest"] -> [token_id]

12. encode:把完整字符串转成 token ids

完整字符串不能直接送进 BPE,需要先做预分词:

def encode(self, text: str) -> list[int]:
    ids = []

    for match in self.pat.finditer(text):
        piece = match.group(0)
        piece_bytes = piece.encode("utf-8")
        ids.extend(self._bpe_encode_bytes(piece_bytes))

    return ids

这段代码的人话版:

准备一个空列表 ids。

用正则把 text 切成一小段一小段。

每一小段:
    取出字符串 piece
    转成 UTF-8 bytes
    用 BPE 规则编码成 token ids
    把这些 ids 加到总列表里

最后返回总 ids。

例子:

text = "Hello world!"

预分词可能得到:

"Hello"
" world"
"!"

然后:

"Hello"  -> b"Hello"  -> BPE -> [15496]
" world" -> b" world" -> BPE -> [995]
"!"      -> b"!"      -> BPE -> [0]

最终:

[15496, 995, 0]

13. decode:从 token ids 还原文本

decode 做的是反向过程:

token ids -> bytes -> string

实现:

def decode(self, ids: list[int]) -> str:
    token_bytes = b"".join(self.vocab[i] for i in ids)
    return token_bytes.decode("utf-8", errors="replace")

例子:

vocab = {
    0: b"h",
    1: b"e",
    2: b"l",
    3: b"o",
    4: b"he",
    5: b"ll",
}

ids = [4, 5, 3]

查表:

4 -> b"he"
5 -> b"ll"
3 -> b"o"

拼起来:

b"he" + b"ll" + b"o" = b"hello"

最后 decode:

"hello"

14. 当前学习进度总结

目前我已经理解并实现了语言模型训练中的一批基础组件:

1. SiLU:逐元素激活函数
2. Linear:线性变换
3. Embedding:token id 查表成向量
4. Softmax:把 logits 变成概率
5. Cross Entropy:计算正确答案概率对应的损失
6. Gradient Clipping:限制梯度范数
7. get_batch:构造 next-token prediction 训练样本
8. RMSNorm:稳定 token 向量数值规模
9. SwiGLU:Transformer FFN 模块
10. BPE tokenizer encode/decode 基础逻辑

现在我对语言模型训练主线的理解是:

文本
-> tokenizer.encode
-> token ids
-> get_batch
-> x, y
-> embedding
-> Transformer blocks
-> logits
-> cross_entropy
-> optimizer update

15. 下一步计划

接下来继续推进 A1:

1. 完善 tokenizer 对 special_tokens 的处理
2. 实现 encode_iterable
3. 实现 train_bpe
4. 进入 scaled dot-product attention
5. 实现 multi-head attention
6. 实现 RoPE
7. 实现 TransformerBlock
8. 实现完整 TransformerLM