一文了解模型蒸馏、多种微调方式以及DeepSpeed

301 阅读13分钟

前言: 本文将针对模型蒸馏,四种常用微调方式,以及模型加速框架DeepSpeed做原理讲解。并且各个微调方法附完整代码。

一、模型蒸馏

1、背景

image.png

  • 大型语言模型凭借其理解和生成类人文本的能力,彻底改变了NLP的格局。然而,它们的大小和复杂性往往在部署、速度和成本方面带来挑战。通常对于专门的小众任务,我们最终部署了最佳可用模型,即使我们没有利用其所有功能。这就是蒸馏发挥作用的地方,它提供了一种方法来创建(微调)更小、定制化、更高效的模型,同时保留了一个显著更大型最先进模型的很大一部分性能

  • 蒸馏是一种将大型预训练模型(“教师”)的知识迁移到较小模型(“学生”)的技术,使学生模型能够达到与教师模型相当的性能。这种技术允许用户利用大型 LLMs 的高质量,同时由于较小的学生模型,降低了生产环境中的推理成本

2、步骤

模型蒸馏示意图

  • 首先,学生模型利用教师模型的输出概率(即软标签,Soft Targets)以及真实标签(硬标签,Hard Targets)进行联合训练

  • 训练过程:学生模型通过最小化软标签与自身预测之间的差异(通常使用 Kullback-Leibler 散度或交叉熵)进行优化,同时结合硬标签监督,确保与真实数据的契合度。这种方法使学生模型更好地理解教师的决策逻辑,提升准确性和可靠性,尤其适用于多分类任务

二、多种微调方式

在实际使用场景中有一个非常明显的现象是小公司喜欢用lora,大公司喜欢用全量。而我们在业务需求中全量微调通常受时间限制,上一个模型还未调完,下一个需求就来了,因此我个人通常用LoRA来微调模型

1、LoRA

  • LoRA 的动机是更高效地获得高质量的微调结果,因为全参数微调需要大量内存,而至今为止的其他方法要么不切实际,要么在质量上做出了妥协。高性能 GPU 是一种宝贵的资源,因此更高效的方法可以使微调更加普及,并允许进行更大程度的实验

  • 神经网络中的所有权重都分组到多个层或模块中。你可以将每一层想象成一个可以表示为矩阵的数字集合。这些矩阵非常庞大。对于微调一个 13B 参数的模型,总共有 130 亿个权重需要调整,并且你需要重复这个过程。每当模型处理完一批样本(提示/完成对)后,它就会使用所谓的“损失函数”来计算调整量。这基本上是模型试图对输出应该是什么做出最佳猜测,然后计算它与正确答案的偏差,最后计算如何调整每个权重以便下次更接近。这被称为学习、训练或适应

1.1 LoRA原理

LoRA原理示意图

  • LoRA 与全参数微调有两个根本不同之处:

    • 跟踪权重的变化而不是直接更新权重
    • 将权重变化的大矩阵分解为两个包含“可训练参数”的小矩阵
  • 将一个 1x5 矩阵与一个 5x1 矩阵相乘,得到一个 5x5 矩阵。你也可以反过来进行这个过程,从一个矩阵开始,尝试找到两个矩阵,当它们相乘时,其结果接近原始矩阵的值

  • 5x5 矩阵共有 25 个值,而如果我们计算分解后的矩阵中的值,只有 10 个(5 + 5)

How LoRA Works?

LoRA权重加载至相应模型

  • 在 LoRA 的上下文中,我们将这两个较小的分解矩阵称为“变化矩阵”,因为它们跟踪我们想要对模型权重进行的改变

  • 将我们在处理过程中将模型的权重存储在内存中,但不会直接更新它们。LoRA 论文将这种方式称为“冻结”模型的权重

  • 然后,我们只需将乘以变化矩阵后的结果加到我们要微调的层的原始矩阵上,从而得到一个崭新的微调模型

但是如何选择一个合适的矩阵秩r?

1.2 LoRA正确的秩是多少

  • 如果我们希望我们的 LoRA 方法更加精确,我们可以增加变化矩阵的秩。我们仍然会得到与模型中正在训练的层相同大小的输出矩阵,但我们让每个变化矩阵编码更多信息,这样当它们相乘时,得到的结果数字会更加精确

  • 在微软的 LoRA 仓库中,他们于 2021 年发布的论文中实现了的例子使用了 8 或 16 的秩

模型矩阵拆分成r为2的矩阵相乘

  • 但是我在查阅QLoRA论文后,发现:

QLoRA论文截取

  • 所以,如果你的秩是 8 或更高,它可能根本就不重要

  • 我通常会修改一个变量:LoRA Alpha,当权重变化被加回到原始模型权重中时,它们会被乘以一个按 alpha 除以秩计算出的缩放因子。在 LoRA 代码库中,微软在所有示例中都设置 alpha 为秩的两倍,这意味着权重变化在添加时会加倍。遵循他们的做法,如果你的秩是 8,就从 16 开始设置 alpha。如果秩是 16,就从 32 开始

我的LoRA微调示例代码如下:

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import time
import torch
import json
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq
)
from peft import LoraConfig, get_peft_model
import wandb
unique_id = str(time.strftime('%m%d%H%M%S'))
model_path = "Qwen3-14B"
json_file_path = "gynecology_train.json"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.float16
)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
data = []
with open(json_file_path, "r", encoding="utf-8") as f:
    data = json.load(f)
if isinstance(data, dict):
    data = [data]
dataset = Dataset.from_list(data)

def preprocess_function(examples):
    inputs = [f"问题: {q}\n答案: " for q in examples["question"]]
    model_inputs = tokenizer(inputs, truncation=True, padding="max_length", max_length=512)
    labels = tokenizer(examples["answer"], truncation=True, padding="max_length", max_length=512)
    # Mask掉padding的label,设置为-100,避免计算loss时参与
    labels_input_ids = [
        [(token if token != tokenizer.pad_token_id else -100) for token in label]
        for label in labels["input_ids"]
    ]
    model_inputs["labels"] = labels_input_ids
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset.column_names)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

training_args = TrainingArguments(
    output_dir="./qwen3_14b_gynecology_checkpoints_" + unique_id,
    num_train_epochs=25,
    per_device_train_batch_size=8,   
    gradient_accumulation_steps=2,   # 显存优化
    save_steps=200,
    save_total_limit=5,
    learning_rate=1e-5,
    logging_steps=10,
    warmup_steps=200,
    weight_decay=0.01,
    fp16=True,
    report_to="wandb",
    deepspeed="deepspeed_config.json",
    dataloader_num_workers=4,
    ddp_find_unused_parameters=False
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator
)

wandb.init(project='qwen3_14b_gynecology_20250710_2', name=time.strftime('%m%d%H%M%S'))
trainer.train()
model.save_pretrained("./qwen3_14b_gynecology_lora_model_0710_2")
model = model.merge_and_unload()
model.save_pretrained("./qwen3_14b_gynecology_merged_model_0710_2")
tokenizer.save_pretrained("./qwen3_14b_gynecology_merged_model_0710_2")

2、Prefix Tunning

  • Prefix Tunning在输入token之前构造一段任务相关的virtual tokens作为Prefix,然后训练的时候只更新Prefix部分的参数,而PLM中的其他部分参数固定

  • 针对自回归架构模型:在句子前面添加前缀,得到 z = [PREFIX; x; y],合适的上文能够在固定 LM 的情况下去引导生成下文

  • 针对编码器-解码器架构模型:Encoder和Decoder都增加了前缀,得到 z = [PREFIX; x; PREFIX0; y]。Encoder端增加前缀是为了引导输入部分的编码,Decoder 端增加前缀是为了引导后续token的生成

  • 该方法其实和构造Prompt类似,只是Prompt是人为构造的“显式”的提示,并且无法更新参数,而Prefix则是可以学习的“隐式”的提示

  • 为了防止直接更新Prefix的参数导致训练不稳定和性能下降的情况,在Prefix层前面加了MLP结构,训练完成后,只保留Prefix的参数

Prefix Tunning结构图

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import time
import torch
import json
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import PrefixTuningConfig, get_peft_model, TaskType
import wandb
unique_id = str(time.strftime('%m%d%H%M%S'))
model_path = "Qwen3-1.7B"
# 数据集文件路径
json_file_path = "merged_xinnei_dataset_shuffled.jsonl"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16
)
from peft import PrefixTuningConfig, get_peft_model, TaskType

prefix_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=20,
    encoder_hidden_size=128  # 前缀编码器隐层维度
)
# 应用 Prefix Tuning
model = get_peft_model(model, prefix_config)
model.print_trainable_parameters()

data = []
with open(json_file_path, 'r', encoding='utf-8') as f:
    for line in f:
        data.append(json.loads(line))
        
dataset = Dataset.from_list(data)

def preprocess_function(examples):
    inputs = [f"问题: {q}\n答案: " for q in examples["input"]]
    model_inputs = tokenizer(inputs, truncation=True, padding="max_length", max_length=2048)
    labels = tokenizer(examples["output"], truncation=True, padding="max_length", max_length=2048)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
    
tokenized_dataset = dataset.map(preprocess_function, batched=True)
training_args = TrainingArguments(
    output_dir="./qwen3_1_7b_prefix_checkpoints_" + unique_id,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    save_steps=200,
    save_total_limit=10,
    learning_rate=5e-5,  # Prefix Tuning 通常需要较高的学习率
    prediction_loss_only=True,
    fp16=True,
    logging_steps=10,
    warmup_steps=100,
    weight_decay=0.01,
    dataloader_num_workers=4,
    ddp_find_unused_parameters=False,
    report_to="wandb"
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator
)
wandb.init(project='qwen3_1_7b_xinnei_prefix_20250526', name=unique_id)
trainer.train()
# 保存 Prefix Tuning 模型(仅保存前缀)

model.save_pretrained("./qwen3_14b_xinnei_prefix_model")

tokenizer.save_pretrained("./qwen3_14b_xinnei_prefix_model")

3、Prompt Tunning

  • Prompt Tunning可以看作是Prefix Tuning的简化版本,它给每个任务定义了自己的Prompt,然后拼接到数据上作为输入,但只在输入层加入prompt tokens,并且不需要加入 MLP 进行调整来解决难训练的问题

Prompt Tunning结构图

  • Prompt Tuning 论文中还探讨了 Prompt token 的初始化方法和长度对于模型性能的影响。Prompt token 的长度在20左右时的表现已经不错(超过20之后对模型的性能提升不明显了),这个gap也会随着模型参数规模的提升而减小

Prompt Tunning效果图

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import time
import torch
import json
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import PromptTuningConfig, get_peft_model, TaskType
import wandb
unique_id = str(time.strftime('%m%d%H%M%S'))
model_path = "Qwen3-1.7B"
json_file_path = "merged_xinnei_dataset_shuffled.jsonl"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16
)
# 配置 Prompt Tuning
prompt_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=20,  # 通常为 10~100,需根据模型大小调整
    tokenizer_name_or_path=model_path,
    prompt_tuning_init="TEXT",
    prompt_tuning_init_text="以下是一个问答任务的示例:"
)
# 应用 Prompt Tuning
model = get_peft_model(model, prompt_config)
model.print_trainable_parameters()

data = []
with open(json_file_path, 'r', encoding='utf-8') as f:
    for line in f:
        data.append(json.loads(line))

dataset = Dataset.from_list(data)

def preprocess_function(examples):
    inputs = [f"问题: {q}\n答案: " for q in examples["input"]]
    model_inputs = tokenizer(inputs, truncation=True, padding="max_length", max_length=2048)
    labels = tokenizer(examples["output"], truncation=True, padding="max_length", max_length=2048)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
    
tokenized_dataset = dataset.map(preprocess_function, batched=True)
training_args = TrainingArguments(
    output_dir="./qwen3_14b_prompt_checkpoints_" + unique_id,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    save_steps=200,
    save_total_limit=10,
    learning_rate=5e-5,  # Prompt Tuning 通常需要稍高的学习率
    prediction_loss_only=True,
    fp16=True,
    logging_steps=10,
    warmup_steps=100,
    weight_decay=0.01,
    dataloader_num_workers=4,
    ddp_find_unused_parameters=False,
    report_to="wandb"
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator
)
trainer.train()
model.save_pretrained("./qwen3_14b_xinnei_prompt_model")
tokenizer.save_pretrained("./qwen3_14b_xinnei_prompt_model")

4、P Tunning

  • 在Prompt-Tuning的基础上,对Prompt部分进行进一步的编码计算,加速收敛。具体来说,PEFT中支持两种编码方式,一种是LSTM,一种是MLP。与Prompt-Tuning不同的是,Prompt的形式只有Soft Prompt

  • 经过预训练的LM的词嵌入已经变得高度离散,如果随机初始化virtual token,容易优化到局部最优值,而这些virtual token理论是应该有相关关联的。因此,作者通过实验发现用一个prompt encoder来编码会收敛更快,效果更好。即用一个LSTM+MLP去编码这些virtual token以后,再输入到模型

P Tunning结构图

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import time
import torch
import json
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import PromptEncoderConfig, get_peft_model, TaskType
import wandb
unique_id = str(time.strftime('%m%d%H%M%S'))
model_path = "/data1/zhouwenzhong/.cache/modelscope/hub/models/Qwen/Qwen3-1.7B"
json_file_path = "merged_xinnei_dataset_shuffled.jsonl"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16
)

from peft import PromptEncoderConfig, get_peft_model, TaskType
# P-Tuning 配置
p_Tunning_config = PromptEncoderConfig(
    peft_type="P_TUNING",
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=20,
    token_dim=model.config.hidden_size,
    num_transformer_submodules=1,
    encoder_reparameterization_type="MLP",
    encoder_hidden_size=128
)
model = get_peft_model(model, p_Tunning_config)
model.print_trainable_parameters()
data = []
with open(json_file_path, 'r', encoding='utf-8') as f:
    for line in f:
        data.append(json.loads(line))
        
dataset = Dataset.from_list(data)

def preprocess_function(examples):
    inputs = [f"问题: {q}\n答案: " for q in examples["input"]]
    model_inputs = tokenizer(inputs, truncation=True, padding="max_length", max_length=2048)
    labels = tokenizer(examples["output"], truncation=True, padding="max_length", max_length=2048)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
    
tokenized_dataset = dataset.map(preprocess_function, batched=True)
training_args = TrainingArguments(
    output_dir="./qwen3_1_7b_p_Tunning_checkpoints_" + unique_id,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    save_steps=200,
    save_total_limit=10,
    learning_rate=5e-5,
    prediction_loss_only=True,
    fp16=True,
    logging_steps=10,
    warmup_steps=100,
    weight_decay=0.01,
    dataloader_num_workers=4,
    ddp_find_unused_parameters=False,
    report_to="wandb"
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator
)
wandb.init(project='qwen3_1_7b_xinnei_p_Tunning_20250526', name=unique_id)
trainer.train()

model.save_pretrained("./qwen3_14b_xinnei_p_Tunning_model")

上述我总共介绍了4种微调方法,经我个人实验,最有用的仍然是LoRA,当然这仅是使用了相同参数的情况。而不同微调方法当然需要不同参数,Prompt Tunning等方法需要的学习率都需要增加,我由于时间原因并没有再次实验,仅做初步结论

三、微调加速方法:DeepSpeed

DeepSpeed框架由微软团队研发。我们都知道模型微调技术难点并不大,无非是炼丹,通常需要考虑的是手里的显存资源。我在LoRA微调qwen3-14b模型时,batchsize设为4大致单卡消耗显存110G,而我使用DeepSpeed后,单卡batchsize设置为8,单卡显存消耗为60G左右。并且时间从原来的10h优化到2h左右,优化效果非常明显。

我们都知道DeepSpeed使用了数据并行技术,数据并行将模型复制多份至各个 GPU 设备上,但显然这个复制模型的过程将产生较大的显存冗余,为了解决这个问题,有效地降低冗余,可以采用 ZeRO-DP 来取代 DP

截图:每张卡保存的内容

我们看到每张卡要保存:模型参数、梯度以及优化器状态。而优化器状态例如梯度、一阶动量等占用了大部分的空间,这也就是ZeRO优化的目标,而从下往上即:优化器状态,梯度,模型参数则构成了我们的3个stage。ZeRO-DP 通过以下方式解决这种冗余问题:

  • Partitioning optimizer state (分割优化器状态)
  • Partitioning gradients (划分梯度)
  • Partitioning model parameters (分割模型参数) 以上对应了ZeRO的3个阶段,优化目标:

截图:每张卡保存的内容-stage1优化版

声明:以上很多文本以及图片来自于网络截取,本文仅对网络内容加个人实操内容做完整汇总,如有侵权可以联系