手把手教你在预训练权重中嵌入多 Token 预测加速

5 阅读7分钟

在自然语言处理(NLP)领域,预训练语言模型已经成为主流,它们在各种任务中都取得了显著的成就。然而,这些模型在生成长序列时,通常采用自回归的方式,即一个接一个地预测 Token。这种方式效率较低,尤其是在需要快速生成大量文本的场景中。

本文将深入探讨如何在预训练权重中嵌入多 Token 预测(Multi-Token Prediction, MTP)加速技术。MTP 允许模型在一次前向传播中预测多个 Token,从而显著提高生成速度。我们将从理论基础开始,逐步引导你实现这一优化。

1. 理解自回归生成与多 Token 预测的差异

在深入 MTP 之前,我们首先回顾一下传统的自回归生成。

自回归生成

在自回归生成中,模型在时间步 t 预测 Token y_t 时,会依赖于之前所有已生成的 Token y_1, ..., y_{t-1}

y_t = Model(y_1, ..., y_{t-1})

这种方法简单直观,但其固有的串行性限制了生成速度。

多 Token 预测 (MTP)

MTP 的核心思想是打破这种串行性,让模型在一次前向传播中同时预测 k 个 Token。

y_t, y_{t+1}, ..., y_{t+k-1} = MultiTokenModel(y_1, ..., y_{t-1})

为了实现这一点,我们需要对模型的架构和训练策略进行一些调整。

2. 核心思想:修改模型输出层与损失函数

要实现 MTP,最直接的方法是修改模型的输出层,使其能够同时输出多个 Token 的预测。

2.1 扩展输出层

假设我们有一个预训练的 Transformer 解码器模型。通常,它的输出层是一个线性层,将隐藏状态映射到词汇表大小的 logits。

# 原始输出层
self.lm_head = nn.Linear(hidden_size, vocab_size)

为了支持预测 k 个 Token,我们可以将输出层扩展为 k 个独立的线性层,或者一个更大的线性层,其输出维度为 k * vocab_size

方案一:多个独立的线性层

# MTP 输出层 (k个独立的线性层)
self.lm_heads = nn.ModuleList([
    nn.Linear(hidden_size, vocab_size) for _ in range(k)
])

在这种方案中,每个线性层负责预测一个位置的 Token。

方案二:一个扩展的线性层

# MTP 输出层 (一个扩展的线性层)
self.lm_head_mtp = nn.Linear(hidden_size, k * vocab_size)

然后,我们需要将输出的 logits reshape 为 (batch_size, k, vocab_size)

2.2 调整损失函数

在训练阶段,我们需要调整损失函数来同时考虑这 k 个 Token 的预测。通常,我们可以使用交叉熵损失,并对这 k 个预测位置的损失进行平均或求和。

# MTP 损失计算示例
total_loss = 0
for i in range(k):
    # 获取第 i 个 Token 的预测 logits 和真实标签
    logits_i = predicted_logits[:, i, :]
    labels_i = target_labels[:, i]
    
    loss_i = F.cross_entropy(logits_i, labels_i)
    total_loss += loss_i

# 可以选择平均或求和
# final_loss = total_loss / k
final_loss = total_loss

3. 在预训练权重中嵌入 MTP:迁移学习策略

将 MTP 能力嵌入到已有的预训练模型中,需要采用迁移学习的策略。我们不希望从头开始训练,而是希望利用预训练模型的强大能力。

3.1 模型加载与权重初始化

首先,加载你选择的预训练模型(例如 BERT、GPT-2 等)。然后,你需要修改其输出层以适应 MTP。

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn as nn
import torch.nn.functional as F
import torch

# 假设使用 GPT-2
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
original_model = AutoModelForCausalLM.from_pretrained(model_name)

# 定义 MTP 的 k 值,例如每次预测 3 个 Token
k = 3
hidden_size = original_model.config.hidden_size
vocab_size = original_model.config.vocab_size

# 创建 MTP 版本的模型
class GPT2ForMultiTokenPrediction(nn.Module):
    def __init__(self, original_model, k):
        super().__init__()
        self.transformer = original_model.transformer
        self.lm_head_mtp = nn.Linear(hidden_size, k * vocab_size)
        
        # 复制原始 lm_head 的权重到第一个 MTP 头
        with torch.no_grad():
            self.lm_head_mtp.weight[:vocab_size, :] = original_model.lm_head.weight
            self.lm_head_mtp.bias[:vocab_size] = original_model.lm_head.bias
            
            # 为后续的 k-1 个 MTP 头初始化权重 (例如,可以随机初始化或复制第一个头)
            for i in range(1, k):
                self.lm_head_mtp.weight[i*vocab_size:(i+1)*vocab_size, :].normal_(mean=0.0, std=0.02)
                self.lm_head_mtp.bias[i*vocab_size:(i+1)*vocab_size].zero_()

    def forward(self, input_ids, labels=None):
        outputs = self.transformer(input_ids)
        hidden_states = outputs.last_hidden_state
        
        # 取最后一个 Token 的隐藏状态进行预测
        last_hidden_state = hidden_states[:, -1, :] 
        
        # 通过 MTP 输出层
        logits_mtp = self.lm_head_mtp(last_hidden_state)
        logits_mtp = logits_mtp.view(-1, k, vocab_size) # (batch_size, k, vocab_size)
        
        loss = None
        if labels isnotNone:
            # 损失计算与 MTP 目标对齐
            # labels 形状应为 (batch_size, k)
            loss = F.cross_entropy(logits_mtp.view(-1, vocab_size), labels.view(-1), ignore_index=tokenizer.pad_token_id)
        
        return logits_mtp, loss

mtp_model = GPT2ForMultiTokenPrediction(original_model, k)

这是一个关于模型结构改变和权重的可视化,展示了如何从原始的单 Token 预测头扩展到多个 Token 预测头,并如何初始化新添加的权重部分。

4. 训练策略:微调 MTP 模型

由于我们修改了模型结构,需要对模型进行微调,使其适应 MTP 任务。

4.1 数据准备

你需要准备包含长文本序列的数据集。对于 MTP 训练,每个输入序列的标签将不再是一个 Token,而是 k 个连续的 Token。

# 示例数据准备 (简化版)
text = "The quick brown fox jumps over the lazy dog."
tokenized_text = tokenizer.encode(text, return_tensors="pt")

# 创建输入和目标标签
input_ids = tokenized_text[:, :-k] # 假设输入到倒数第k个
labels = tokenized_text[:, -k:] # 目标是最后 k 个 Token

print("Input IDs:", input_ids)
print("Labels:", labels)

4.2 微调过程

微调过程与标准的 Transformer 模型训练类似,但需要确保损失函数正确计算了多 Token 预测的损失。

from torch.optim import AdamW

# 示例训练循环 (简化版)
optimizer = AdamW(mtp_model.parameters(), lr=1e-5)
device = torch.device("cuda"if torch.cuda.is_available() else"cpu")
mtp_model.to(device)

# 假设我们有一个 DataLoader
# for batch in dataloader:
#     input_ids, labels = batch['input_ids'].to(device), batch['labels'].to(device)

# 模拟一个训练步骤
input_ids_example = torch.tensor([[100, 200, 300, 400, 500, 600, 700]]).to(device) # 示例输入
labels_example = torch.tensor([[800, 900, 1000]]).to(device) # 示例标签 (k=3)

mtp_model.train()
optimizer.zero_grad()

logits, loss = mtp_model(input_ids_example, labels_example)

loss.backward()
optimizer.step()

print(f"Loss: {loss.item()}")

在实际训练中,你需要构建一个适当的数据集和 DataLoader,并进行多个 epoch 的训练。训练过程中,建议使用较小的学习率,并监控验证集上的表现,以避免过拟合。

5. 加速推理:MTP 在生成中的应用

微调完成后,MTP 模型就可以用于加速文本生成了。

5.1 贪婪解码与采样

在生成过程中,每次预测 k 个 Token,然后将这 k 个 Token 添加到已生成的序列中,作为下一次预测的输入。

# 示例生成函数
def generate_mtp(model, tokenizer, prompt, max_length, k, device):
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    
    generated_ids = input_ids
    
    with torch.no_grad():
        for _ in range(max_length // k): # 假设生成 max_length 个 Token
            
            # 使用最后一个 Token 的上下文进行预测
            # 注意:这里需要根据具体模型结构调整,可能需要整个序列的隐藏状态
            # 目前我们的示例 MTP 模型只接受最后一个 Token 的隐藏状态
            logits, _ = model(generated_ids) 
            
            # 从 logits 中选择 k 个 Token (贪婪解码)
            predicted_token_ids = torch.argmax(logits[0], dim=-1) # (k,)
            
            generated_ids = torch.cat([generated_ids, predicted_token_ids.unsqueeze(0)], dim=-1)
            
            if generated_ids.shape[1] >= max_length:
                break
    
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# 运行生成示例
prompt = "Once upon a time,"
generated_text = generate_mtp(mtp_model, tokenizer, prompt, max_length=50, k=k, device=device)
print(f"Generated text: {generated_text}")

这是一个可视化的生成过程,对比了自回归生成和 MTP 生成的步骤,突出了 MTP 如何通过并行预测多个 Token 来减少迭代次数。

6. 进阶考量与优化

6.1 Beam Search 与 MTP

将 MTP 与 Beam Search 结合可以进一步提升生成质量。在 Beam Search 的每一步中,不再只考虑下一个 Token 的 top-k 可能性,而是考虑 k 个 Token 组合的 top-k 可能性。这会增加计算复杂度,但可能带来更好的生成效果。

6.2 MTP 训练中的数据对齐

在 MTP 训练中,输入序列的长度和目标序列的长度需要仔细对齐。确保模型在预测 k 个 Token 时,其输入上下文是正确的。

6.3 动态 MTP

固定 k 值可能不是最优的。在某些情况下,模型可能可以预测更多的 Token,而在另一些情况下,则只能预测少数 Token。动态 MTP 允许模型根据当前上下文自适应地调整 k 值。这通常需要更复杂的模型架构和训练策略。

6.4 预训练与 MTP

如果条件允许,从头开始预训练一个支持 MTP 的模型将获得更好的效果。这涉及到修改原始模型的预训练任务,使其在预训练阶段就学习同时预测多个 Token。例如,在掩码语言模型(Masked Language Model, MLM)中,可以同时掩码 k 个连续的 Token,并让模型预测它们。