从零构建大语言模型 1:全景概览与分词

0 阅读27分钟

从零构建大语言模型 1:全景概览与分词

基于 Stanford CS336: Language Models From Scratch (Spring 2025) Lecture 1

课程主页:stanford-cs336.github.io/spring2025


本讲你将收获什么

读完这篇讲义,你将理解:

  1. 大语言模型(LLM)到底在做什么——它和你以为的"人工智能"有什么不同
  2. 为什么 2024 年了还要从零写一个语言模型,而不是直接调 API
  3. 从 1950 年代到今天,语言模型经历了怎样的技术演变
  4. 分词(Tokenization) 是怎么回事——为什么 GPT 看到的不是"文字"而是"数字"
  5. 动手实现一个 BPE 分词器——这是 GPT-2/3/4 使用的核心分词算法

前置知识

  • 基本的 Python 编程能力
  • 了解什么是神经网络(知道"参数""训练""损失函数"即可)
  • 不需要读过任何论文

第一部分:语言模型到底在做什么?

在深入技术之前,先建立一个准确的直觉。

一句话定义

语言模型 = 一个给"下一个词"打分的概率函数

给定前面的文字(上下文),语言模型会输出一个概率分布,告诉你接下来每个可能的词/token 出现的概率有多大。

输入: "今天天气真"
输出: {"好": 0.35, "不错": 0.20, "差": 0.08, "热": 0.15, ...}

训练的目标就是让模型学会:在大量文本上,尽可能准确地预测下一个 token。

这看起来很简单,但神奇之处在于——当你在海量文本上做好了"预测下一个词"这件事,模型会顺带学会语法、常识、推理、甚至写代码。

从"预测下一个词"到"聊天助手"

你可能好奇:ChatGPT 明明是在"对话",怎么变成了"预测下一个词"?

答案是生成过程

用户输入: "请解释什么是量子计算"
模型看到: "请解释什么是量子计算

"
模型预测: "量"     → 把"量"加到序列末尾
模型预测: "子"     → 把"子"加到序列末尾
模型预测: "计"     → ...
模型预测: "算"
模型预测: "是"
...(逐字生成,直到结束)

每一步都是在"预测下一个 token",拼接起来就形成了流畅的回答。后续我们会学到,要让模型从"预测下一个词"变成"有用的助手",还需要对齐(Alignment)——但那是第五讲的内容。


第二部分:为什么要从零构建?

研究者正在与底层技术脱节

这个趋势你可能已经感受到了:

时间研究者的典型工作方式类比
~2017自己实现并训练模型自己造发动机
~2019下载 BERT 等预训练模型进行微调买来发动机自己改装
今天调用 GPT-4 / Claude API直接打车

"打车"确实方便,但如果你想改进交通工具本身(做基础研究),你需要理解发动机是怎么工作的。

更关键的是:LLM 的抽象层是有漏洞的。什么意思?

  • 编程语言的抽象很好:你写 Python,不用管 CPU 指令集,代码照样跑
  • 但 LLM 的抽象不行:你不理解 tokenization,就不知道为什么 GPT 数不清 "strawberry" 里有几个 r;你不理解 attention 窗口,就不知道为什么超长文档会"失忆"

这门课的理念:通过亲手构建来获得真正的理解。

但有一个现实问题:前沿模型太大了

事实数据
GPT-4 参数量据传约 1.8 万亿(1.8T)参数
GPT-4 训练成本据传 ~1 亿美元
xAI Grok 集群200,000 块 H100 GPU
Stargate 投资4 年 5000 亿美元

参考:一块 H100 约 3 万美元。200,000 块 = 60 亿美元纯硬件。

而且前沿模型几乎不公开细节。GPT-4 技术报告里关于模型架构的描述基本为零:

GPT-4 Technical Report, Section 2:

"Given both the competitive landscape and the safety implications of large-scale models like GPT-4, this report contains no further details about the architecture (including model size), hardware, training compute, dataset construction, training method, or similar."

(鉴于竞争格局和大规模模型的安全影响,本报告不提供任何关于架构(包括模型大小)、硬件、训练算力、数据集构建、训练方法等细节。)

在课程中我们只能训练 < 10 亿参数的小模型——这和万亿参数的 GPT-4 差了 1000 倍以上。那我们学到的东西还有用吗?

哪些知识可以迁移,哪些不行?

类型含义举例迁移性
机制零件怎么工作Transformer 架构、注意力机制、分布式训练✅ 完全通用
思维方式怎么思考效率问题算 FLOPs、分析内存瓶颈、Scaling Laws✅ 完全通用
直觉什么选择效果好SwiGLU 比 ReLU 好?学习率多大合适?⚠️ 可能随规模变化

关于"直觉"——很多架构选择目前没有理论解释。举个真实的例子:

Noam Shazeer(Transformer 论文合著者)提出 SwiGLU 激活函数时,论文里写了这么一句话:

"We offer no explanation as to why these architectures seem to work; we attribute their success to divine benevolence."

(我们无法解释为什么这些架构有效;我们将其成功归功于上天的恩赐。)

论文原文 (GLU Variants Improve Transformer, Section 4 Conclusions):

"We offer no explanation as to why these architectures seem to work; we attribute their success to divine benevolence."

这不是开玩笑——深度学习中很多"best practice"就是试出来的,我们还不完全理解为什么。

效率才是核心竞争力

你可能听说过 Rich Sutton 的 "Bitter Lesson"(苦涩的教训)——一篇极具影响力的博文。很多人把它误读为"算力就是一切",但正确理解是:

能够利用规模(scale)的通用方法,最终会胜出。

换成公式:

模型质量=算法效率×计算资源\boxed{\text{模型质量} = \text{算法效率} \times \text{计算资源}}

这意味着什么?

  • 如果你计算资源少(穷),那效率就是你唯一的武器
  • 如果你计算资源多(富),效率更加重要——因为浪费 1% 可能就是几百万美元
  • 从 2012 到 2019 年,ImageNet 上的算法效率提升了 44 倍——这等于白送了 44 倍算力

这门课的核心问题:给定固定的计算预算和数据,如何训练出最好的模型?


第三部分:从 Shannon 到 ChatGPT——语言模型简史

理解历史有助于理解"为什么现在的技术是这样的"。我们按时间线快速梳理关键节点。

第一幕:统计时代(1950s–2010s)

1950 年,Claude Shannon 提出用数学模型衡量英语的信息量(熵)。这是"语言模型"这个概念的起源。

在之后的几十年里,n-gram 模型主导了语言建模:简单地统计"前 n-1 个词之后,下一个词的出现频率"。这种方法在机器翻译和语音识别中广泛使用,但它有天花板——无法捕捉长距离依赖。

第二幕:神经网络入场(2003–2017)

关键里程碑:

年份事件为什么重要
2003Bengio 提出神经语言模型第一次用神经网络做语言建模——词变成了向量
2014Seq2Seq + 注意力机制让模型能"关注"输入的不同部分,机器翻译大突破
2014Adam 优化器至今仍是最常用的优化器之一
2017Transformer 横空出世抛弃循环结构,全靠注意力——速度快、效果好、可并行

Transformer 的论文标题 "Attention Is All You Need" 已经成为深度学习史上最著名的一句话。它奠定了之后所有 LLM 的架构基础。

第三幕:预训练革命(2018–2019)

模型核心思路
ELMo用大量文本预训练 LSTM,然后迁移到下游任务
BERT用大量文本预训练 Transformer(双向),刷榜无数
GPT-2用大量文本训练单向 Transformer,发现模型能"零样本"完成任务

这个时期的核心发现:不需要针对每个任务单独训练模型——先在大量文本上预训练,模型就能学到通用的语言能力。

第四幕:大力出奇迹 → 闭源化(2020–2023)

模型参数量关键发现
GPT-3175B"In-context learning"——不用微调,只靠 prompt 就能做任务
PaLM540B规模上 540B,但训练数据量不够(undertrained)
Chinchilla70B重要发现:模型大小和数据量应该同步扩大

Chinchilla 的核心结论(后面会细讲):

D20×ND^* \approx 20 \times N^*

即:如果你有一个 1B 参数的模型,应该用 20B token 来训练它。很多模型之前都"太大但训练不够"。

第五幕:开源反击战(2023–至今)

闭源模型越来越强,但开源社区也没闲着:

模型来自亮点
Llama 1/2/3Meta推动了整个开源 LLM 生态系统
Mistral 7BMistral AI7B 参数,效果媲美 13B
Qwen 2.5阿里中文能力突出
DeepSeek V2/V3DeepSeekMoE 架构,极致性价比
OLMo 2AI2完全开源——包括训练数据和全部细节

开放程度有三个层次,记住这个区分很有用:

级别你能拿到什么例子
闭源只有 APIGPT-4、Claude
开放权重模型权重 + 论文,但没有训练数据Llama、DeepSeek
完全开源权重 + 数据 + 训练代码 + 详细文档OLMo

第四部分:训练 LLM 的完整流程

在开始写代码之前,先看一下全景图。训练一个语言模型需要哪些步骤?

BasicsSystemsScaling LawsDataAlignment
TokenizationKernelsScaling sequenceEvaluationSupervised fine-tuning
ArchitectureParallelismModel complexityCurationReinforcement learning
Loss functionQuantizationLoss metricTransformationPreference data
OptimizerActivation checkpointingParametric formFilteringSynthetic data
Learning rateCPU offloadingDeduplicationVerifiers
InferenceMixing

五大模块一览

序号模块一句话类比
1基础分词 + Transformer + 训练循环造一辆能跑的车
2系统GPU kernel + 分布式训练 + 推理优化改装发动机和变速箱
3Scaling Laws小实验预测大模型的最优配置在沙盘上演练再上战场
4数据采集、清洗、去重、评估找到优质燃料
5对齐SFT + RLHF/DPO请教练教它规矩

模块 1:基础——搭建完整 pipeline

这是第一步,也是本讲的重点之一。

分词(Tokenization):把人类读的文字变成模型能处理的数字

"Stanford was founded in 1885."
        ↓ encode()
[93447, 9201, 673, 24303, 306, 220, 13096, 20, 13]
 Stan   ford  was  founded  in  " " 1885    .   .
        ↓ decode()
"Stanford was founded in 1885."

分词器必须保证 decode(encode(text)) == text,即编码-解码完全可逆

模型架构:现代 LLM 几乎都基于 Transformer

不过原始 Transformer(2017)之后有很多改进,主要变体包括:

组件原始版本现代常用版本为什么改
激活函数ReLUSwiGLU实验效果更好(虽然不知道为什么)
位置编码正弦函数RoPE支持变长外推
归一化LayerNormRMSNorm计算更快,效果相当
MLP单一全连接MoE(混合专家)同样计算量激活更多参数
注意力Full attentionGQA / MLA省显存、推理更快

训练:优化器(AdamW → Muon → SOAP)、学习率调度(余弦退火 / WSD)、Batch size 策略

模块 2:系统——榨干每一块 GPU

这是很多人忽视但极其关键的部分。

GPU 不是一块铁板——它内部有分层的存储结构:

组件角色类比
DRAM (HBM)主显存,容量大 (~80GB),带宽"慢"仓库:东西多但取货远
SRAM计算单元旁的缓存,容量小 (~20MB),极快工作台:手边但放不了多少

训练优化的核心原则:最小化数据在 DRAM 和 SRAM 之间的搬运次数。

当模型太大,一块 GPU 装不下怎么办?分布式并行

  • 数据并行:每块 GPU 拿一部分 batch
  • 张量并行:一个矩阵乘法切给多块 GPU
  • 流水线并行:模型的不同层放在不同 GPU 上

推理优化——生成文本分两个阶段:

┌─────────────────────┐  ┌──────────────────────────────────────────┐
│    Prefill Phase     │  │           Decoding Phase                 │
│                      │  │                                          │
│  "Computer           │  │  ┌───────────┐ ┌───────────┐ ┌────────┐ │
│   science is"        │  │  │Iteration 2│→│Iteration 3│→│Iter. 4 │ │
│       ↓              │  │  └─────┬─────┘ └─────┬─────┘ └───┬────┘ │
│  ┌───────────┐       │  │        ↓              ↓           ↓      │
│  │Iteration 1│→ "a"  │  │   "discipline"      "."       <EOS>     │
│  └───────────┘       │  │                                          │
└─────────────────────┘  └──────────────────────────────────────────┘
                    ←——————— KV-Cache ————————→
  • Prefill:一次性处理所有输入 token("Computer science is"),计算并缓存 KV
  • Decode:逐个生成新 token,每步复用之前缓存的 KV(避免重复计算)
阶段做什么瓶颈类比
Prefill处理用户输入的所有 token计算量大(compute-bound)读完一整本书
Decode一个接一个生成新 token内存带宽(memory-bound)一个字一个字写回答

Decode 阶段很慢,因为每次只生成 1 个 token 却要读取整个模型。加速方法:KV Cache、投机解码、量化等。

模块 3:Scaling Laws——用小实验预测大结果

这是最"省钱"的模块。核心问题:

我有 CC 个 FLOPs(浮点运算次数)的预算,应该训练一个多大的模型?用多少数据?

Chinchilla 定律给出了答案:模型参数量 NN 和训练数据量 DD 应该同步增长。

参数量建议训练数据说明
1B20B tokensD20ND \approx 20N
7B140B tokens
70B1.4T tokens

但要注意:这个定律只考虑了训练成本,没考虑推理成本。如果模型要被调用几十亿次,用更小的模型 + 更多数据训练(over-train)可能总成本更低。Llama 系列就采用了这种策略。

模块 4:数据——决定模型上限的隐形力量

"Garbage in, garbage out"——数据的质量直接决定模型的质量。

常见数据来源:

  • Common Crawl:全互联网爬取的网页数据(几十 TB)
  • 书籍:高质量长文本
  • arXiv 论文:学术知识
  • GitHub 代码:编程能力
  • Wikipedia:结构化知识

但原始数据是一片荒野——大量广告、垃圾页面、重复内容、有害信息。需要经过:

  1. 格式转换:HTML/PDF → 纯文本
  2. 质量过滤:训练分类器判断内容质量
  3. 去重:用 MinHash 等算法去除重复,避免模型"记住"而非"理解"
  4. 安全过滤:移除有害、违法内容

模块 5:对齐——从"会说话"到"有用"

预训练完成后,模型是一个很强的"补全引擎"——你给它开头,它能继续写。但它:

  • 不会遵循指令
  • 可能输出有害内容
  • 风格不可控

对齐(Alignment) 就是把这个"原始能力"塑造成"有用的助手"。分两步:

第一步:监督微调(SFT)

用 (指令, 回答) 对来微调模型。一个惊人的发现:

你不需要很多数据!LIMA 论文表明,仅用 1000 条高质量 (指令, 回答) 对,就能让模型表现出不错的指令跟随能力。基座模型已经有能力了,SFT 只是把它"唤醒"。

第二步:从人类反馈中学习(RLHF / DPO)

让模型生成多个回答,人类标注哪个更好,然后用这些偏好数据进一步优化模型。

算法核心思路复杂度
PPO (RLHF)训练一个奖励模型,再用强化学习优化高(需要 4 个模型)
DPO直接从偏好数据优化,不需要单独的奖励模型
GRPO去掉 value function,用组内对比

第五部分:深入理解分词(Tokenization)

推荐视频:Andrej Karpathy — Let's build the GPT Tokenizer(YouTube, 2h13m)

为什么分词是第一课?

因为它是整个 pipeline 的入口。模型看不到文字——它只能处理数字。分词器就是翻译官:

"Hello world" → [15496, 995] → 送入 Transformer → 输出概率 → [下一个 token] → 解码回文字

分词器的设计直接影响:

  • 模型能力:分词方式决定了模型"看到"什么粒度的信息
  • 计算成本:token 越多,序列越长,注意力计算量越大(O(n2)O(n^2)
  • 多语言表现:中文分词不好的话,同样内容要花更多 token,效率低下

两个关键指标

指标含义越大越好?
词表大小(vocabulary size)有多少种不同的 token太大浪费内存,太小表达力差
压缩比(compression ratio)字节数 / token 数✅ 越大越好——同样的文本用更少的 token 表示

词表大小为什么重要?因为模型有一个嵌入矩阵,大小是 词表大小 × 嵌入维度。词表 50K、嵌入维度 4096 → 这个矩阵就有 2 亿参数。词表翻倍,这 2 亿也翻倍。

下面我们依次看四种分词方案。

先体验一下:GPT-2 的分词器

在动手实现之前,先感受一下成熟的分词器是什么样的。

在线体验:tiktokenizer.vercel.app

import tiktoken

# 加载 GPT-2 分词器(OpenAI 提供)
tokenizer = tiktoken.get_encoding("gpt2")

string = "Hello, 🌍! 你好!"
print(f"原始字符串: {string}")
print(f"UTF-8 字节数: {len(string.encode('utf-8'))}")

# 编码:字符串 → token ID 列表
indices = tokenizer.encode(string)
print(f"\nToken IDs: {indices}")
print(f"Token 数量: {len(indices)}")

# 看看每个 token 对应什么
for idx in indices:
    token_bytes = tokenizer.decode_single_token_bytes(idx)
    print(f"  {idx:>6d}{token_bytes}")

# 解码:token ID 列表 → 字符串(验证可逆)
reconstructed = tokenizer.decode(indices)
print(f"\n解码还原: {reconstructed}")
print(f"压缩比: {len(string.encode('utf-8')) / len(indices):.2f}")
原始字符串: Hello, 🌍! 你好!
UTF-8 字节数: 20

Token IDs: [15496, 11, 12520, 234, 235, 0, 220, 19526, 254, 25001, 121, 0]
Token 数量: 12
   15496b'Hello'
      11b','
   12520b' \xf0\x9f'
     234b'\x8c'
     235b'\x8d'
       0b'!'
     220b' '
   19526b'\xe4\xbd'
     254b'\xa0'
   25001b'\xe5\xa5'
     121b'\xbd'
       0b'!'

解码还原: Hello, 🌍! 你好!
压缩比: 1.67

注意观察几个有趣的现象:

  1. "Hello" 是一个常见的英文词,被编码成一个 token(15496)——非常高效
  2. 🌍 这个 emoji 被拆成了 4 个 token——因为它在训练数据中不常见
  3. "你好" 两个汉字被拆成了 5 个 token——GPT-2 的训练数据以英文为主,中文分词效率很差

这就是为什么用 GPT 系列处理中文时 token 消耗更多、成本更高。新一代模型(如 Qwen、DeepSeek)专门优化了中文分词。

方案 1:基于字符的分词

思路:每个 Unicode 字符就是一个 token。

Unicode 是全球通用的字符编码标准,每个字符有一个唯一的"码点"(code point)——一个整数。Python 提供了两个内置函数来转换:

# ord(): 字符 → 码点(整数)
# chr(): 码点 → 字符

print(f"'A' 的码点: {ord('A')}")          # 65
print(f"'你' 的码点: {ord('你')}")        # 20320
print(f"'🌍' 的码点: {ord('🌍')}")       # 127757

# 反向转换
print(f"码点 65 对应: {chr(65)}")          # A
print(f"码点 20320 对应: {chr(20320)}")    # 你
'A' 的码点: 65
'你' 的码点: 20320
'🌍' 的码点: 127757
码点 65 对应: A
码点 20320 对应: 你
# 用字符级分词
string = "Hello, 🌍! 你好!"
indices = [ord(c) for c in string]

print(f"原始: {string}")
print(f"Token IDs: {indices}")
print(f"词表大小至少: {max(indices) + 1:,}")   # 127,758
print(f"Token 数: {len(indices)}")
print(f"压缩比: {len(string.encode('utf-8')) / len(indices):.2f}")
原始: Hello, 🌍! 你好!
Token IDs: [72, 101, 108, 108, 111, 44, 32, 127757, 33, 32, 20320, 22909, 33]
词表大小至少: 127,758
Token 数: 13
压缩比: 1.54

这个方案的问题

问题说明
词表太大Unicode 有约 15 万个字符。嵌入矩阵 = 150K × 维度,参数量爆炸
极度稀疏99% 的字符在训练数据中极少出现,模型学不好它们的表示
没有学到"词"的概念H, e, l, l, o 是 5 个独立 token,模型要自己学习"它们经常连在一起"

类比:这就像让你一个字一个字地读书——信息密度太低,阅读速度太慢。

方案 2:基于字节的分词

思路:用 UTF-8 编码把字符串变成字节序列,每个字节就是一个 token。

什么是 UTF-8?

UTF-8 是最常用的 Unicode 编码方式。它用变长字节表示字符:

字符类型字节数例子
ASCII(英文、数字、标点)1 字节'A'[65]
欧洲语言带重音符号2 字节'é'[195, 169]
中文、日文、韩文3 字节'你'[228, 189, 160]
Emoji4 字节'🌍'[240, 159, 140, 141]
string = "Hello, 🌍! 你好!"

# UTF-8 编码
raw_bytes = string.encode("utf-8")
indices = list(raw_bytes)

print(f"原始: {string}")
print(f"UTF-8 字节: {raw_bytes}")
print(f"Token IDs: {indices}")
print(f"词表大小: 256(固定——一个字节只有 0~255)")
print(f"Token 数: {len(indices)}")
print(f"压缩比: {len(raw_bytes) / len(indices):.2f}")

# 还原
print(f"还原: {bytes(indices).decode('utf-8')}")
原始: Hello, 🌍! 你好!
UTF-8 字节: b'Hello, \xf0\x9f\x8c\x8d! \xe4\xbd\xa0\xe5\xa5\xbd!'
Token IDs: [72, 101, 108, 108, 111, 44, 32, 240, 159, 140, 141, 33, 32, 228, 189, 160, 229, 165, 189, 33]
词表大小: 256(固定——一个字节只有 0~255)
Token 数: 20
压缩比: 1.00
还原: Hello, 🌍! 你好!

优点

  • 词表极小且固定(256),嵌入矩阵很小
  • 天然覆盖所有语言和特殊字符,永远不会遇到"未知字符"

致命缺点

  • 压缩比 = 1.0——一个字节一个 token,完全没有压缩
  • 中文每个字变成 3 个 token,emoji 变成 4 个 token
  • 序列太长 → 注意力计算是 O(n2)O(n^2),成本爆炸

类比:这就像把文章翻译成摩尔斯电码再阅读——虽然"字母表"只有点和划两个符号,但信息表达极其低效。

不过,有一些前沿研究(ByT5、MegaByte、BLT)正在尝试让字节级分词在实际中可用,只是还没有扩展到前沿规模。

方案 3:基于词的分词

思路:回到传统 NLP 的方式——用正则表达式把文本切成"词"。

import regex

string = "I'll say supercalifragilisticexpialidocious!"

# GPT-2 使用的正则模式(处理缩写、数字、标点等)
GPT2_PATTERN = r"'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
segments = regex.findall(GPT2_PATTERN, string)

print(f"原始: {string}")
print(f"切分结果: {segments}")
print(f"Token 数: {len(segments)}")
print(f"压缩比: {len(string.encode('utf-8')) / len(segments):.1f}")
原始: I'll say supercalifragilisticexpialidocious!
切分结果: ['I', "'ll", ' say', ' supercalifragilisticexpialidocious', '!']
Token 数: 5
压缩比: 8.8

压缩比 8.8,非常高!但问题也很严重:

问题说明
词表爆炸英语有几十万个词,加上中文、日文等更是天文数字
稀疏性"supercalifragilisticexpialidocious" 可能只在训练数据中出现过一次,模型学不到有意义的表示
OOV 问题遇到新词(如新的品牌名、人名)只能输出 <UNK>(unknown),丢失信息

类比:词级分词就像只允许你用字典里已有的词说话——遇到新词就只能说"那个东西"。

四种方案对比:为什么需要 BPE?

方案词表大小压缩比OOV?问题
字符级~150K~1.5词表大、序列长、稀疏
字节级2561.0序列太长,计算成本爆炸
词级~500K+~8.8词表爆炸、无法处理新词
BPE可控✅ 目前主流方案

我们需要一个方案:

  1. 词表大小可控(通常 32K~100K)
  2. 常见模式用少量 token 表示(高压缩比)
  3. 再罕见的文本也能分解成已知 token(无 OOV

这就是 Byte Pair Encoding (BPE) 的设计目标。


BPE:当前最主流的分词算法

来历

  • 1994 年 Philip Gage 发明,原本用于数据压缩
  • 2016 年 Sennrich et al. 将其引入 NLP(论文
  • 2019 年被 GPT-2 采用,从此成为 LLM 标配

核心思想用一句话概括

从最小的单位(字节)开始,不断合并最常见的相邻组合,直到词表达到目标大小。

这很像你学习一门新语言的过程:

  1. 先学会认字母(字节)
  2. 然后你发现 t-h 总是一起出现,于是你把它记成一个整体 th
  3. 接着你发现 th-e 也很常见,记成 the
  4. 然后 the (带空格) 也很常见,记成一个整体
  5. ...如此反复,你就建立了一个高效的"词汇表"

算法正式描述:

输入: 训练文本, 目标合并次数 num_merges
1. 将文本转为字节序列(词表 = {0, 1, ..., 255})
2. 重复 num_merges 次:
   a. 统计所有相邻 token 对的出现次数
   b. 找到出现最多的那对 (a, b)
   c. 创建新 token = a+b, 加入词表
   d. 在序列中将所有 (a, b) 替换为新 token
3. 输出: 合并规则表 + 词表

动手实现:一步步训练 BPE

我们用一个简短的字符串来演示完整过程。

string = "the cat in the hat"

# 第 0 步:转为字节(= 初始 token 序列)
indices = list(string.encode("utf-8"))

print(f"训练文本: '{string}'")
print(f"初始 token 序列 (每个字节一个 token):")
print(f"  IDs:  {indices}")
print(f"  字符: {[chr(i) for i in indices]}")
print(f"  长度: {len(indices)}")
训练文本: 'the cat in the hat'
初始 token 序列 (每个字节一个 token):
  IDs:  [116, 104, 101, 32, 99, 97, 116, 32, 105, 110, 32, 116, 104, 101, 32, 104, 97, 116]
  字符: ['t', 'h', 'e', ' ', 'c', 'a', 't', ' ', 'i', 'n', ' ', 't', 'h', 'e', ' ', 'h', 'a', 't']
  长度: 18
from collections import defaultdict

def count_pairs(indices):
    """统计相邻 token 对的出现次数"""
    counts = defaultdict(int)
    for i in range(len(indices) - 1):
        counts[(indices[i], indices[i+1])] += 1
    return dict(counts)

def merge(indices, pair, new_index):
    """将序列中所有的 pair 替换为 new_index"""
    result = []
    i = 0
    while i < len(indices):
        if i < len(indices) - 1 and (indices[i], indices[i+1]) == pair:
            result.append(new_index)
            i += 2  # 跳过两个,因为已经合并
        else:
            result.append(indices[i])
            i += 1
    return result

# 初始化词表:256 个字节
vocab = {i: bytes([i]) for i in range(256)}
merges = {}  # 记录合并规则

# =============================================
# 第 1 轮:统计、找最高频、合并
# =============================================
counts = count_pairs(indices)

# 只显示出现 >= 2 次的 pair
print("=== 第 1 轮 ===")
print("相邻 pair 频率 (>= 2 次):")
for pair, count in sorted(counts.items(), key=lambda x: -x[1]):
    if count >= 2:
        print(f"  ('{chr(pair[0])}', '{chr(pair[1])}') → {count} 次")

best_pair = max(counts, key=counts.get)
new_id = 256
merges[best_pair] = new_id
vocab[new_id] = vocab[best_pair[0]] + vocab[best_pair[1]]

print(f"\n→ 合并 ('{chr(best_pair[0])}', '{chr(best_pair[1])}') = token {new_id} ('{vocab[new_id].decode()}')")
indices = merge(indices, best_pair, new_id)
print(f"  新长度: {len(indices)}  (18 → {len(indices)})")
=== 第 1 轮 ===
相邻 pair 频率 (>= 2 次):
  ('t', 'h') → 2 次
  ('h', 'e') → 2 次
  ('e', ' ') → 2 次
  ('a', 't') → 2 次

→ 合并 ('t', 'h') = token 256 ('th')
  新长度: 16  (1816)
# =============================================
# 第 2 轮
# =============================================
counts = count_pairs(indices)
best_pair = max(counts, key=counts.get)
new_id = 257
merges[best_pair] = new_id
vocab[new_id] = vocab[best_pair[0]] + vocab[best_pair[1]]

print("=== 第 2 轮 ===")
print(f"最高频 pair: token {best_pair[0]} ('{vocab[best_pair[0]].decode()}') + token {best_pair[1]} ('{vocab[best_pair[1]].decode()}')")
print(f"→ 合并为 token {new_id} ('{vocab[new_id].decode()}')")
indices = merge(indices, best_pair, new_id)
print(f"  新长度: {len(indices)}")

# =============================================
# 第 3 轮
# =============================================
counts = count_pairs(indices)
best_pair = max(counts, key=counts.get)
new_id = 258
merges[best_pair] = new_id
vocab[new_id] = vocab[best_pair[0]] + vocab[best_pair[1]]

print(f"\n=== 第 3 轮 ===")
print(f"最高频 pair: token {best_pair[0]} ('{vocab[best_pair[0]].decode()}') + token {best_pair[1]} ('{vocab[best_pair[1]].decode()}')")
print(f"→ 合并为 token {new_id} ('{vocab[new_id].decode()}')")
indices = merge(indices, best_pair, new_id)
print(f"  新长度: {len(indices)}")

# 总结
print(f"\n{'='*50}")
print(f"训练完成! 经过 3 轮合并:")
print(f"  序列长度: 18 → {len(indices)}")
print(f"  词表大小: 256 → {len(vocab)}")
print(f"  学到的合并规则:")
for (a, b), idx in merges.items():
    print(f"    '{vocab[a].decode()}' + '{vocab[b].decode()}' → '{vocab[idx].decode()}'  (token {idx})")
=== 第 2 轮 ===
最高频 pair: token 256 ('th') + token 101 ('e')
→ 合并为 token 257 ('the')
  新长度: 14

=== 第 3 轮 ===
最高频 pair: token 257 ('the') + token 32 (' ')
→ 合并为 token 258 ('the ')
  新长度: 12

==================================================
训练完成! 经过 3 轮合并:
  序列长度: 1812
  词表大小: 256259
  学到的合并规则:
    't' + 'h''th'  (token 256)
    'th' + 'e''the'  (token 257)
    'the' + ' ''the '  (token 258)

注意合并过程中发生了什么:

  1. t + hth:两个最常一起出现的字母合并了
  2. th + ethe:新 token 继续和高频邻居合并
  3. the + the (带空格):注意空格也被合并进去了!

这解释了为什么在 GPT 分词器中," world" 前面带空格——空格是 token 的一部分。

实际应用中,GPT-2 做了约 50,000 次合并,最终词表大小 ≈ 50,257。GPT-4 的词表更大,约 100K。

用训练好的 BPE 分词器编码新文本

训练阶段产出的是一组有序的合并规则。编码新文本时,只需按顺序尝试每条规则:

def bpe_encode(text, merges, vocab):
    """BPE 编码:字符串 → token ID 列表"""
    # 先转成字节级 token
    ids = list(text.encode("utf-8"))
    # 按训练时的顺序逐条应用合并规则
    for pair, new_id in merges.items():
        ids = merge(ids, pair, new_id)
    return ids

def bpe_decode(ids, vocab):
    """BPE 解码:token ID 列表 → 字符串"""
    return b"".join(vocab[i] for i in ids).decode("utf-8")

# 在新文本上测试
test = "the quick brown fox"
encoded = bpe_encode(test, merges, vocab)
decoded = bpe_decode(encoded, vocab)

print(f"原文: '{test}'")
print(f"编码: {encoded}")
print(f"每个 token 的含义:")
for i in encoded:
    print(f"  token {i:>3d} → '{vocab[i].decode()}'")
print(f"\n解码: '{decoded}'")
print(f"压缩比: {len(test.encode('utf-8')) / len(encoded):.2f}")
原文1: 'the quick brown fox'
编码: [258, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120]
每个 token 的含义:
  token 258'the '
  token 113'q'
  token 117'u'
  token 105'i'
  token  99'c'
  token 107'k'
  token  32' '
  token  98'b'
  token 114'r'
  token 111'o'
  token 119'w'
  token 110'n'
  token  32' '
  token 102'f'
  token 111'o'
  token 120'x'

解码: 'the quick brown fox'
压缩比: 1.19

可以观察到:

  • "the " 被识别为一个 token——因为训练时学到了这个高频模式
  • 其他词(quick, brown, fox)仍然是逐字节的——因为我们只做了 3 次合并
  • 如果做 50,000 次合并,这些常见英文词也会逐渐被合并成单个 token

直觉:BPE 就像一个"自适应压缩字典"——常见的东西用短编码,罕见的东西用长编码。这和霍夫曼编码、摩尔斯电码的思想一致:高频信号用短表示。

实际工程中的关键细节

我们实现的是最简化版本。真正的 BPE 分词器还需要处理:

细节说明
预分词先用正则表达式把文本切成"词"级别的段落,再对每段分别做 BPE。这防止了跨词合并(比如不希望 "dog"g 和下一个词的 a 合并成 ga
特殊 token保留 <|endoftext|><|pad|> 等控制标记不被拆分
编码效率不需要每次都遍历所有合并规则——可以用优先队列只处理当前序列中存在的 pair
训练效率在大规模语料上训练时,需要高效的数据结构来统计 pair 频率

本讲总结

大图景

                    原始文本
                       ↓
              ┌─── 分词器 (BPE) ───┐
              │  字符串 → token IDs │  ← 今天学的
              └────────────────────┘
                       ↓
              ┌─── Transformer ───┐
              │  token → 概率分布  │  ← 下一讲
              └───────────────────┘
                       ↓
              ┌─── 训练循环 ──────┐
              │  反向传播更新参数   │  ← 下一讲
              └───────────────────┘
                       ↓
              ┌─── 对齐 ─────────┐
              │  SFT + RLHF/DPO  │  ← 第五讲
              └───────────────────┘
                       ↓
                   ChatGPT 🎉

分词方案对比(今天的核心内容)

方案词表压缩比能处理新词?评价
字符级~150K~1.5词表太大,稀疏
字节级2561.0序列太长
词级~500K+~8.8OOV 问题
BPE可控当前标准方案

关键概念检查

读完本讲,你应该能回答这些问题:

  1. 语言模型的训练目标是什么?(预测下一个 token)
  2. 为什么不能直接把文字送入模型?(模型只接受数字)
  3. 词表大小为什么不能太大?(嵌入矩阵太大,稀疏 token 学不好)
  4. 为什么字节级分词不实用?(序列太长,注意力计算 O(n2)O(n^2)
  5. BPE 的核心操作是什么?(反复合并最高频的相邻 token 对)