专栏
简介
- 蒸馏大模型(Model Distillation)是一种将复杂、庞大的预训练模型(通常称为“教师模型”)的知识压缩到更小、更高效的模型(“学生模型”)中的技术。其核心思想是通过模仿教师模型的输出或中间特征,使学生模型在保持较高性能的同时,显著减少参数量和计算开销。蒸馏过程通常利用软标签(soft labels),即教师模型输出的概率分布,而非硬标签(hard labels),因为软标签包含了类别间的相对关系,能传递更多信息。
- 蒸馏大模型的优势在于,它能够在资源受限的设备(如移动设备或嵌入式系统)上部署高性能的深度学习模型,同时降低推理时间和能耗。例如,BERT 等大型语言模型可以通过蒸馏生成 TinyBERT 或 DistilBERT 等轻量级版本,在保持大部分性能的前提下大幅压缩模型规模。蒸馏技术还可与其他优化方法(如量化、剪枝)结合,进一步提升效率。
- 标注处理方式中会有一个学生模型通常比较小,会学习老师模型通常比较大,数据集如果在老师模型提前微调过通过效果比较好,学生模型通过学习老师模型的logits

标准知识蒸馏实践
- 这边使用过torchtune介绍的蒸馏方式,实战可能是参数没调整好,导致loss比较好,所以转战llm-dojo蒸馏项目,效果不错,不过还是简单提下torchtune的蒸馏KD的使用方式可能别的场景比较合适
torchtune蒸馏
conda create --name torchtune python=3.10
conda activate torchtune
pip install torch torchvision torchao
pip install torchtune
tune

- 建议git clone源码下来很来配置文件还是很不错的
- 执行下面命令, 我的显存比较低,批次啥都按照最低配置来的,这边的config是github项目上的配置文件我的是8B蒸馏到1B的模型配置,理论上72B蒸馏到1B也可以,只是显存要求更高,配置改动更大,下面代码2张A10 24G可以跑。我的老师模型8B因为之前微调过checkpoint_files不一样,dataset.data_files按照你自己的需要微调数据,我的是大模型私有化部署实践(四):打造符合自身业务的垂类模型这里面的数据,model.lora_attn_modules我配置了多个模块都要微调
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/8B_to_1B_KD_lora_distributed \
teacher_checkpointer.checkpoint_files='[
"model-00001-of-00009.safetensors",
"model-00002-of-00009.safetensors",
"model-00003-of-00009.safetensors",
"model-00004-of-00009.safetensors",
"model-00005-of-00009.safetensors",
"model-00006-of-00009.safetensors",
"model-00007-of-00009.safetensors",
"model-00008-of-00009.safetensors",
"model-00009-of-00009.safetensors"
]' \
optimizer.lr=10e-5 \
kd_ratio=0.1 \
dataset=torchtune.datasets.alpaca_dataset \
dataset.source=json \
dataset.data_files=/usr/local/alpaca_zh_demo.json \
batch_size=2 \
output_dir=/usr/local/torchtune/llama3_2_8B_to_1B/KD_lora_distributed \
tokenizer.path=/usr/local/Llama-3.2-1B-Instruct/Llama-3.2-1B-Instruct/original/tokenizer.model \
checkpointer.checkpoint_dir=/usr/local/Llama-3.2-1B-Instruct/Llama-3.2-1B-Instruct/ \
teacher_checkpointer.checkpoint_dir=/usr/local/Meta-Llama-3.1-8B-Instruct/ \
teacher_checkpointer.output_dir=/usr/local/Meta-Llama-3.1-8B-Instruct \
model.lora_rank=8 \
model.lora_alpha=16 \
model.lora_attn_modules="['q_proj', 'v_proj', 'output_proj', 'k_proj']" \
mode.apply_lora_to_output=True \
epochs=3 \
warmup_steps=0
LLM-DOJO微调
- 本质上是trl项目,这边用的GKD(使用student模型生成output,这就属于on policy的方式了,模型每次更新权重后都会进行output输出,再反馈到训练中。)蒸馏方式
- 安装
cd /usr/local/src
git clone https://github.com/mst272/LLM-Dojo.git
conda create --name llmdojo python=3.10
conda activate llmdojo
pip install -r requirements.txt
pip install -U git+https://github.com/huggingface/trl.git
- 使用, CUDA_VISIBLE_DEVICES=0,1是因为我有两张显卡accelerate使用deepspeed分布式多机多卡大模型私有化部署实践(三):使用DeepSpeed多机多卡训练,当然你也可以直接执行train_gkd.py。注意lmbda,beta参数可调整,lmbda:0时为Supervised KD,1时为GKD。可在[0,1]范围内选择,这样就会混合比例,beta: 0时loss为KLD, 1时为JSD。可在[0,1]范围内选择,这样就会混合比例
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file /usr/local/src/LLM-Dojo/rlhf/ds_config/ds_zero2_v2.yaml /usr/local/src/LLM-Dojo/rlhf/train_gkd.py \
--model_name_or_path /usr/local/Llama-3.2-1B-Instruct/Llama-3.2-1B-Instruct \
--teacher_model_name_or_path /usr/local/llama-3-7B-lora-trans-export \
--dataset_name /usr/local/src/LLM-Dojo/converted_data.json \
--learning_rate 2e-5 \
--per_device_train_batch_size 3 \
--gradient_accumulation_steps 2 \
--output_dir gkd-model2 \
--logging_steps 2 \
--dataset_batch_size 3 \
--num_train_epochs 1 \
--gradient_checkpointing \
--lmbda 0.5 \
--beta 0.5 \
--use_peft \
--lora_r 8 \
--lora_alpha 16 \
--trust_remote_code \
--bf16 \
--save_strategy "steps" \
--save_steps 180 \
--save_total_limit 5 \
--warmup_steps 0 \
--lr_scheduler_type "cosine" \
--seq_kd True \
--torch_dtype auto
- 训练的数据结构长这样,训练数据集是从网上下载的,让gpt帮忙写了python脚本简单处理了下,老师模型也是经过这份数据训练的
# /usr/local/src/LLM-Dojo/converted_data.json
[ { "messages": [ { "role": "system", "content": "帮我把下面中文翻译成英文" }, { "role": "user", "content": "一种做蚕茧的结构, 另一种充当粘合剂或是基质, 然后可以将这些纤维聚到一起," }, { "role": "assistant", "content": "" } ],
"prompt": [
{
"role": "system",
"content": "帮我把下面中文翻译成英文"
},
{
"role": "user",
"content": "一种做蚕茧的结构, 另一种充当粘合剂或是基质, 然后可以将这些纤维聚到一起,"
}
]
},
{
"messages": [
{
"role": "system",
"content": "帮我把下面中文翻译成英文"
},
{
"role": "user",
"content": "很好,那是一幅很酷的画作。"
},
{
"role": "assistant",
"content": ""
}
],
"prompt": [
{
"role": "system",
"content": "帮我把下面中文翻译成英文"
},
{
"role": "user",
"content": "很好,那是一幅很酷的画作。"
}
]
}
]
参考文章