本文首发于机智流(其实这篇内容应该是整个internlm101的最开始)
借用Llama2 论文里的一句话。词元(tokens)是大型语言模型(LLMs)的基本单元,而词元化是将字符串(即文本)翻译并将其转换为词元序列及其相反过程。
在这篇文章里,我们不会实现一个可以完全适配internlm2这种模型的tokenizer。这是一件很复杂的事情,我决定选择karpathy的minbpe作为讲解。在讲代码之前,还是惯例先讲讲基础知识。
0x01 最开始的开始,如何创建一个词表
- 收集训练数据
- 初始化tokenizer
- 为tokenizer选择一个算法(bpe, wordpiece, sentencepiece)
- 应用这个算法
- 分配ID号
在创建tokenizer的流程里,最关键的是选择一个算法。和代码库绑定,这篇文章只局限于bpe分词。至于其他的分词法都有异曲同工之妙。
0x02 bpe详解
BPE的四个基本流程
- 准备词汇和频率统计:在这一步,你需要统计训练数据中每个单词的出现频率。通常,单词后面会加上特殊符号(如)以表示单词的结束。这有助于区分自然组合的子词和单词边界。例如,单词“low”可能被表示为“low”,或“Ġlow”。并记录其出现频率。
- 初始化词汇表:初始的词汇表包含所有字符的集合,以及可能的子词序列。每个字符或子词序列都被视为一个独立的词汇单元,频率统计也根据第一步得到的数据初始化。
- 迭代合并最频繁的字符对:BPE算法的核心是迭代合并过程。在每次迭代中,算法会寻找当前词汇中出现频率最高的相邻字符对(如A B),并将其合并为一个新的单元(如AB)。合并过程会更新词汇表和所有单词中的字符对应关系。例如,如果e和s是最常见的一对,那么所有包含es的单词都会被更新,如“forest”变成“forest”。
- 重复合并直到达到指定的词表:这个过程会重复进行,直到达到用户设定的词表大小或合并次数。每次合并都可能产生新的最频繁对,因此这一步是动态进行的。最终,这个过程生成一个固定大小的词汇表,其中包括单个字符、常见子词以及频繁的子词组合。
0x03 minbpe代码串讲
我们主要来看BasicTokenizer这个类。
训练代码
def train(self, text, vocab_size, verbose=False):
assert vocab_size >= 256
num_merges = vocab_size - 256
首先需要确定vocas_size的大小一定要大于256。因为初始词表包含所有单字符值(0-255),这相当于 UTF-8 编码的所有可能的单个字节。之后根据vocab_size和256的差确定总共需要merge的次数。
text_bytes = text.encode("utf-8") # raw bytes
ids = list(text_bytes) # list of integers in range 0..255
之后,把text转化成utf-8编码,统一编码格式。并把输入的tex转换成数组形式。
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
- 我们创建了一个merges字典,字典的键是一个由两个整数组成的元组 (int, int),这两个整数代表一对需要合并的字节。字典的值是一个整数,代表这对字节合并后新生成的token的唯一标识符(ID)。随着bpe算法的执行,每当发现一个新的最常见的字节对时,就会为其分配一个新的 ID,并将这个合并操作添加到 merges 字典中。
- 之后,使用字典推导式创建了一个 vocab 字典,它的键是整数(从0到255),值是对应的单字节表示形式。这是因为在 UTF-8 编码中,单个字节可以表示256种不同的值(0-255),这正好对应于所有可能的单字节字符。这个词汇表的初始化反映了 BPE 算法的起始点,即最初,每个单字节字符都被视为一个独立的token。
for i in range(num_merges):
# count up the number of times every consecutive pair appears
stats = get_stats(ids)
# find the pair with the highest count
pair = max(stats, key=stats.get)
# mint a new token: assign it the next available id
idx = 256 + i
# replace all occurrences of pair in ids with idx
ids = merge(ids, pair, idx)
# save the merge
merges[pair] = idx
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
之后开始正式训练流程,我们要做的就是:寻找当前词汇中出现频率最高的相邻字符对(如A B),并将其合并为一个新的单元(如AB)。合并过程会更新词汇表和所有单词中的字符对应关系。
在第二行我们调用了一个get_stats函数:它做的是把一个数组转换成字典。降序排列value给出排列。
def get_stats(ids, counts=None):
"""
Given a list of integers, return a dictionary of counts of consecutive pairs
Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
Optionally allows to update an existing dictionary of counts
"""
counts = {} if counts is None else counts
for pair in zip(ids, ids[1:]): # iterate consecutive elements
counts[pair] = counts.get(pair, 0) + 1
return counts
之后我们把这个idx拿出来,并和ids进行合并。这里使用了merge函数,我们来分析一下:
def merge(ids, pair, idx):
"""
In the list of integers (ids), replace all consecutive occurrences
of pair with the new integer token idx
Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
"""
newids = []
i = 0
while i < len(ids):
# if not at the very last position AND the pair matches, replace it
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
newids.append(idx)
i += 2
else:
newids.append(ids[i])
i += 1
return newids
乍一眼看这个example并不是很清晰。其实它的流程是这样的:
ids = [1, 2, 3, 1, 2], pair=(1, 2), idx = 4
- i = 0 发现此时ids[i] == pair[0] & ids[i+1] == pair[1] -> 把 idx append到新ids里(因为idx代表了老的合并)
- 重复这个过程
最后我们save merge并更新词表
# save the merge
merges[pair] = idx
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
编码流程
之后让我们来看看编码流程:
def encode(self, text):
# given a string text, return the token ids
text_bytes = text.encode("utf-8") # raw bytes
ids = list(text_bytes) # list of integers in range 0..255
while len(ids) >= 2:
# find the pair with the lowest merge index
stats = get_stats(ids)
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
# subtle: if there are no more merges available, the key will
# result in an inf for every single pair, and the min will be
# just the first pair in the list, arbitrarily
# we can detect this terminating case by a membership check
if pair not in self.merges:
break # nothing else can be merged anymore
# otherwise let's merge the best pair (lowest merge index)
idx = self.merges[pair]
ids = merge(ids, pair, idx)
return ids
主要流程就是:
stats = get_stats(ids)
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
通过这两行代码,找到最小索引的对。并进行合并。直到不能合并为止。
0x04 小结
tokenizer是大模型整体流程的第一环,也是重要的一部分。所以掌握最基础的bpe的流程很关键。