聊聊并写写tokenzier(bpe)

425 阅读6分钟

本文首发于机智流(其实这篇内容应该是整个internlm101的最开始)

借用Llama2 论文里的一句话。词元(tokens)是大型语言模型(LLMs)的基本单元,而词元化是将字符串(即文本)翻译并将其转换为词元序列及其相反过程。

在这篇文章里,我们不会实现一个可以完全适配internlm2这种模型的tokenizer。这是一件很复杂的事情,我决定选择karpathy的minbpe作为讲解。在讲代码之前,还是惯例先讲讲基础知识。

0x01 最开始的开始,如何创建一个词表

  • 收集训练数据
  • 初始化tokenizer
  • 为tokenizer选择一个算法(bpe, wordpiece, sentencepiece)
  • 应用这个算法
  • 分配ID号

在创建tokenizer的流程里,最关键的是选择一个算法。和代码库绑定,这篇文章只局限于bpe分词。至于其他的分词法都有异曲同工之妙。

0x02 bpe详解

BPE的四个基本流程

  • 准备词汇和频率统计:在这一步,你需要统计训练数据中每个单词的出现频率。通常,单词后面会加上特殊符号(如)以表示单词的结束。这有助于区分自然组合的子词和单词边界。例如,单词“low”可能被表示为“low”,或“Ġlow”。并记录其出现频率。
  • 初始化词汇表:初始的词汇表包含所有字符的集合,以及可能的子词序列。每个字符或子词序列都被视为一个独立的词汇单元,频率统计也根据第一步得到的数据初始化。
  • 迭代合并最频繁的字符对:BPE算法的核心是迭代合并过程。在每次迭代中,算法会寻找当前词汇中出现频率最高的相邻字符对(如A B),并将其合并为一个新的单元(如AB)。合并过程会更新词汇表和所有单词中的字符对应关系。例如,如果e和s是最常见的一对,那么所有包含es的单词都会被更新,如“forest”变成“forest”。
  • 重复合并直到达到指定的词表:这个过程会重复进行,直到达到用户设定的词表大小或合并次数。每次合并都可能产生新的最频繁对,因此这一步是动态进行的。最终,这个过程生成一个固定大小的词汇表,其中包括单个字符、常见子词以及频繁的子词组合。

0x03 minbpe代码串讲

我们主要来看BasicTokenizer这个类。

训练代码

def train(self, text, vocab_size, verbose=False):
    assert vocab_size >= 256
    num_merges = vocab_size - 256

首先需要确定vocas_size的大小一定要大于256。因为初始词表包含所有单字符值(0-255),这相当于 UTF-8 编码的所有可能的单个字节。之后根据vocab_size和256的差确定总共需要merge的次数。

    text_bytes = text.encode("utf-8") # raw bytes
    ids = list(text_bytes) # list of integers in range 0..255

之后,把text转化成utf-8编码,统一编码格式。并把输入的tex转换成数组形式。

    merges = {} # (int, int) -> int
    vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
  • 我们创建了一个merges字典,字典的键是一个由两个整数组成的元组 (int, int),这两个整数代表一对需要合并的字节。字典的值是一个整数,代表这对字节合并后新生成的token的唯一标识符(ID)。随着bpe算法的执行,每当发现一个新的最常见的字节对时,就会为其分配一个新的 ID,并将这个合并操作添加到 merges 字典中。
  • 之后,使用字典推导式创建了一个 vocab 字典,它的键是整数(从0到255),值是对应的单字节表示形式。这是因为在 UTF-8 编码中,单个字节可以表示256种不同的值(0-255),这正好对应于所有可能的单字节字符。这个词汇表的初始化反映了 BPE 算法的起始点,即最初,每个单字节字符都被视为一个独立的token。
    for i in range(num_merges):
        # count up the number of times every consecutive pair appears
        stats = get_stats(ids)
        # find the pair with the highest count
        pair = max(stats, key=stats.get)
        # mint a new token: assign it the next available id
        idx = 256 + i
        # replace all occurrences of pair in ids with idx
        ids = merge(ids, pair, idx)
        # save the merge
        merges[pair] = idx
        vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

之后开始正式训练流程,我们要做的就是:寻找当前词汇中出现频率最高的相邻字符对(如A B),并将其合并为一个新的单元(如AB)。合并过程会更新词汇表和所有单词中的字符对应关系。

在第二行我们调用了一个get_stats函数:它做的是把一个数组转换成字典。降序排列value给出排列。

def get_stats(ids, counts=None):
    """
    Given a list of integers, return a dictionary of counts of consecutive pairs
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    Optionally allows to update an existing dictionary of counts
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]): # iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts

之后我们把这个idx拿出来,并和ids进行合并。这里使用了merge函数,我们来分析一下:

def merge(ids, pair, idx):
    """
    In the list of integers (ids), replace all consecutive occurrences
    of pair with the new integer token idx
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    newids = []
    i = 0
    while i < len(ids):
        # if not at the very last position AND the pair matches, replace it
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

乍一眼看这个example并不是很清晰。其实它的流程是这样的:

ids = [1, 2, 3, 1, 2], pair=(1, 2), idx = 4

  • i = 0 发现此时ids[i] == pair[0] & ids[i+1] == pair[1] -> 把 idx append到新ids里(因为idx代表了老的合并)
  • 重复这个过程

最后我们save merge并更新词表

# save the merge
    merges[pair] = idx
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

编码流程

之后让我们来看看编码流程:

def encode(self, text):
    # given a string text, return the token ids
    text_bytes = text.encode("utf-8") # raw bytes
    ids = list(text_bytes) # list of integers in range 0..255
    while len(ids) >= 2:
        # find the pair with the lowest merge index
        stats = get_stats(ids)
        pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
        # subtle: if there are no more merges available, the key will
        # result in an inf for every single pair, and the min will be
        # just the first pair in the list, arbitrarily
        # we can detect this terminating case by a membership check
        if pair not in self.merges:
            break # nothing else can be merged anymore
        # otherwise let's merge the best pair (lowest merge index)
        idx = self.merges[pair]
        ids = merge(ids, pair, idx)
    return ids

主要流程就是:

stats = get_stats(ids)
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))

通过这两行代码,找到最小索引的对。并进行合并。直到不能合并为止。

0x04 小结

tokenizer是大模型整体流程的第一环,也是重要的一部分。所以掌握最基础的bpe的流程很关键。