语言模型困惑度评估实战指南

3 阅读8分钟

语言模型困惑度评估实战指南

语言模型是对一系列词元序列的概率分布。训练语言模型时,需要衡量其预测人类语言使用的准确程度,这是一个困难的任务,需要一个评估模型的度量标准。在本文中,你将了解困惑度这一度量标准。具体来说,你将学习:

  • 什么是困惑度,以及如何计算它
  • 如何使用样本数据评估语言模型的困惑度

让我们开始吧。

概述

本文将分为两部分:

  1. 什么是困惑度以及如何计算它
  2. 使用HellaSwag数据集评估语言模型的困惑度

什么是困惑度以及如何计算它

困惑度衡量的是语言模型预测一段文本样本的能力。它的定义是样本中词元概率的几何平均数的倒数。数学上,困惑度定义为: PPL(x1:L)=i=1Lp(xi)1/L=exp(1Li=1Llogp(xi))PPL(x_{1:L}) = \prod_{i=1}^{L} p(x_i)^{-1/L} = \exp \left( -\frac{1}{L} \sum_{i=1}^{L} \log p(x_i) \right) 困惑度是特定词元序列的函数。在实践中,按照上述公式计算对数概率的平均值来求困惑度更为方便。

困惑度是一个量化语言模型对下一个词元平均犹豫程度的度量。如果语言模型完全确定,困惑度为1。如果语言模型完全不确定,那么词汇表中的每个词元都同样可能;困惑度等于词汇表的大小。不应期望困惑度超出此范围。

使用HellaSwag数据集评估语言模型的困惑度

困惑度是一个依赖于数据集的度量。可以使用的数据集之一是HellaSwag。这是一个包含训练集、测试集和验证集的数据集。它可在某平台的Hugging Face hub上获取,可以使用以下代码加载:

import datasets

dataset = datasets.load_dataset("HuggingFaceFW/hellaswag")
print(dataset)

for sample in dataset["validation"]:
    print(sample)
    break

运行此代码将打印以下内容:

DatasetDict({
    train: Dataset({
        features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings',
                   'source_id', 'split', 'split_type', 'label'],
        num_rows: 39905
    })
    test: Dataset({
        features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings',
                   'source_id', 'split', 'split_type', 'label'],
        num_rows: 10003
    })
    validation: Dataset({
        features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings',
                   'source_id', 'split', 'split_type', 'label'],
        num_rows: 10042
    })
})
{'ind': 24, 'activity_label': 'Roof shingle removal',
'ctx_a': 'A man is sitting on a roof.', 'ctx_b': 'he',
'ctx': 'A man is sitting on a roof. he', 'endings': [
    'is using wrap to wrap a pair of skis.', 'is ripping level tiles off.',
    "is holding a rubik's cube.", 'starts pulling up roofing on a roof.'
], 'source_id': 'activitynet~v_-JhWjGDPHMY', 'split': 'val', 'split_type': 'indomain',
'label': '3'}

可以看到验证集有10042个样本。这将是本文使用的数据集。每个样本是一个字典。键“activity_label”指定活动类别,键“ctx”提供待完成的上下文。模型需要通过从四个结尾中选择一个来完成序列。键“label”的值从0到3,指示哪个结尾是正确的。

有了这些,你可以编写一个简短的代码来评估你自己的语言模型。让我们以某平台上的一个小型模型为例:

import datasets
import torch
import torch.nn.functional as F
import tqdm
import transformers

model = "openai-community/gpt2"

# 加载模型
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained(model)
model = transformers.AutoModelForCausalLM.from_pretrained(model)

# 加载数据集:HellaSwag有训练、测试和验证集
dataset = datasets.load_dataset("hellaswag", split="validation")

# 评估模型:计算每个结尾的困惑度
num_correct = 0
for sample in tqdm.tqdm(dataset):
    # 对样本中的文本进行词元化
    text = tokenizer.encode(" " + sample["activity_label"] + ". " + sample["ctx"])
    endings = [tokenizer.encode(" " + x) for x in sample["endings"]]  # 4个结尾
    groundtruth = int(sample["label"])  # 整数,0到3
    # 为每个结尾生成logits
    perplexities = [0.0] * 4
    for i, ending in enumerate(endings):
        # 将整个输入和结尾送入模型
        input_ids = torch.tensor(text + ending).unsqueeze(0)
        output = model(input_ids).logits
        # 提取结尾中每个词元的logits
        logits = output[0, len(text)-1:, :]
        token_probs = F.log_softmax(logits, dim=-1)
        # 累积生成结尾的概率
        log_prob = 0.0
        for j, token in enumerate(ending):
            log_prob += token_probs[j, token]
        # 将对数概率总和转换为困惑度
        perplexities[i] = torch.exp(-log_prob / len(ending))
    # 打印每个结尾的困惑度
    print(sample["activity_label"] + ". " + sample["ctx"])
    correct = perplexities[groundtruth] == min(perplexities)
    for i, p in enumerate(perplexities):
        if i == groundtruth:
            symbol = '(O)' if correct else '(!)'
        elif p == min(perplexities):
            symbol = '(X)'
        else:
            symbol = '   '
        print(f"Ending {i}: {p:.4g} {symbol} - {sample['endings'][i]}")
    if correct:
        num_correct += 1

print(f"Accuracy: {num_correct}/{len(dataset)} = {num_correct / len(dataset):.4f}")

这段代码从Hugging Face Hub加载最小的GPT-2模型。这是一个124M参数的模型,可以在低配置计算机上轻松运行。使用Hugging Face transformers库加载模型和分词器。同时还加载了HellaSwag验证数据集。

在for循环中,对活动标签和上下文进行词元化。还对四个结尾中的每一个进行词元化。请注意,tokenizer.encode()是使用transformers库中分词器的方法,与前面文章中使用的分词器对象不同。

接下来,对于每个结尾,将拼接后的输入和结尾送入模型运行。input_ids张量是一个形状为(1, L)的整数词元ID二维张量(批处理维度为1)。模型返回一个对象,从中提取输出logits张量。这与前面文章中构建的模型不同,因为这是transformers库中的一个模型对象。只需稍作修改,就可以用它来替换你自己的训练好的模型对象。

GPT-2是一个仅解码器的transformer模型。它使用因果掩码处理输入。对于一个形状为(1,L)的输入张量,输出logits张量的形状为(1,L,V),其中V是词汇表大小。位置p的输出对应于模型对位置p+1词元的估计,这取决于位置1到p的输入。因此,提取从偏移n-1开始的logits,其中n是活动标签和上下文组合的长度。然后将logits转换为对数概率,并计算每个结尾长度上的平均值。

值token_probs[j, token]是位置j处ID为token的词元的对数概率。使用结尾中每个词元的平均对数概率来计算困惑度。期望一个好的模型能够以最低的困惑度识别正确的结尾。可以通过计算在整个HellaSwag验证数据集上的正确预测数量来评估模型。运行此代码时,你将看到:

...
Finance and Business. [header] How to buy a peridot [title] Look at a variety of stones...
Ending 0: 13.02 (X) - You will want to watch several of the gemstones, particularly eme...
Ending 1: 30.19 - Not only are they among the delicates among them, but they can be...
Ending 2: 34.96 (!) - Familiarize yourself with the different shades that it comes in, ...
Ending 3: 28.85 - Neither peridot nor many other jade or allekite stones are necess...
Family Life. [header] How to tell if your teen is being abused [title] Pay attention to...
Ending 0: 16.58 - Try to figure out why they are dressing something that is frowned...
Ending 1: 22.01 - Read the following as a rule for determining your teen's behaviou...
Ending 2: 15.21 (O) - [substeps] For instance, your teen may try to hide the signs of a...
Ending 3: 23.91 - [substeps] Ask your teen if they have black tights (with stripper...
Accuracy: 3041/10042 = 0.3028

代码打印了每个结尾的困惑度,并用(O)或(!)标记正确答案,用(X)标记模型的错误预测。可以看到,即使是对于正确答案,GPT-2的困惑度也在10到20之间。先进的大型语言模型即使词汇量比GPT-2大得多,也能实现困惑度低于10。更重要的是模型能否识别出正确结尾:即能自然完成句子的那个。它应该是困惑度最低的那个;否则,模型就无法生成正确的结尾。GPT-2在这个数据集上仅达到30%的准确率。

你也可以用不同的模型重复这段代码。结果如下:

  • 模型 openai-community/gpt2: 这是最小的GPT-2模型,有124M个参数,在上面的代码中使用。准确率为3041/10042或30.28%。
  • 模型 openai-community/gpt2-medium: 这是较大的GPT-2模型,有355M个参数。准确率为3901/10042或38.85%。
  • 模型 meta-llama/Llama-3.2-1B: 这是Llama家族中最小的模型,有1B个参数。准确率为5731/10042或57.07%。

因此,模型越大,准确率越高是很自然的。

需要注意的是,不应该在架构差异很大的模型之间比较困惑度。由于困惑度是一个范围在1到词汇表大小之间的度量,它高度依赖于分词器。原因在你将上面的代码中的GPT-2替换为Llama 3.2 1B后就变得很明显了:对于Llama 3,困惑度要高一个数量级,但准确率确实更高。这是因为GPT-2的词汇表大小只有50,257,而Llama 3.2 1B的词汇表大小为128,256。

扩展阅读

以下是一些可能有用的资源:

  • Zellers等人(2019),《HellaSwag:机器真的能完成你的句子吗?》
  • 来自Hugging Face transformers库文档的“固定长度模型的困惑度”
  • Guo等人(2023)《评估大型语言模型:一项全面综述》

总结

在本文中,你了解了困惑度度量以及如何使用HellaSwag数据集评估语言模型的困惑度。具体来说,你学到了:

  • 困惑度衡量的是模型对下一个词元的平均犹豫程度。
  • 困惑度是对词汇量大小敏感的度量。
  • 计算困惑度意味着计算样本中词元概率的几何平均数。