一、背景介绍
DeepSeek-R1 是一款高效的开源语言模型,适用于多种任务。通过 LoRA(Low-Rank Adaptation)技术,我们可以在不改变原始参数的情况下针对法律任务进行优化。本文使用 unsloth/DeepSeek-R1-Distill-Qwen-14B 模型和 kienhoang123/QR-legal 数据集,目标是训练一个能够生成法律推理思路链的模型。
二、训练过程
2.1.环境准备
首先安装依赖:
pip install unsloth transformers torch trl datasets
硬件建议:至少 16GB 显存的 GPU,若资源有限,可启用 4-bit 量化。
2.2.训练代码
以下是完整的训练脚本,每行代码都附有注释:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# 导入必要的库
import torch # 用于张量计算和模型训练
from unsloth import FastLanguageModel, is_bfloat16_supported # unsloth 提供高效模型加载和微调支持
from transformers import TrainingArguments # 配置训练参数
from trl import SFTTrainer # 监督式微调训练器
from datasets import load_dataset # 加载数据集
# 1. 模型加载参数设置
max_seq_length = 2048 # 定义模型支持的最大序列长度
model_name = "unsloth/DeepSeek-R1-Distill-Qwen-14B" # 指定预训练模型名称
load_in_4bit = True # 启用 4-bit 量化以减少内存占用
# 加载预训练模型和分词器
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name, # 使用指定的模型名称
max_seq_length=max_seq_length, # 设置最大序列长度
dtype=None, # 自动选择数据类型(如 float16 或 bfloat16)
load_in_4bit=load_in_4bit, # 使用 4-bit 量化加载模型
token=None # 公开模型无需认证 token
)
device = "cuda" if torch.cuda.is_available() else "cpu" # 动态检测设备,若有 GPU 则用 cuda,否则用 CPU
print(f"使用设备:{device}") # 输出当前使用的设备
# 2. 数据准备与格式化
train_prompt_style = ( # 定义训练用的提示模板,包含问题、思路链和回答
"你是一位法律专家,具备高级法律推理、案例分析和法律解释能力。\n"
"请根据以下问题,生成逐步的思路链并回答。\n\n"
"问题:{}\n"
"回答:\n"
"<思路>\n"
"{}\n"
"</思路>\n"
"{}"
)
EOS_TOKEN = tokenizer.eos_token # 获取分词器的结束标记,用于标记文本结尾
def formatting_prompts_func(examples): # 定义数据格式化函数,将数据集转为训练格式
qs = examples.get("Question", examples.get("question", [""] * len(next(iter(examples.values()))))) # 获取问题字段,支持大小写
cots = examples.get("Complex_CoT", examples.get("cot", [""] * len(qs))) # 获取思路链字段,若无则用空字符串
resps = examples.get("Response", examples.get("answer", [""] * len(qs))) # 获取回答字段,若无则用空字符串
texts = [train_prompt_style.format(q, cot, resp) + EOS_TOKEN for q, cot, resp in zip(qs, cots, resps)] # 格式化每个样本
return {"text": texts} # 返回格式化后的文本列表
dataset = load_dataset("kienhoang123/QR-legal", "default", split="train[0:500]") # 加载数据集的前 500 条数据
dataset = dataset.map(formatting_prompts_func, batched=True) # 对数据集应用格式化函数,使用批处理提高效率
print(f"数据集大小:{len(dataset)}") # 输出数据集大小,便于确认
print("示例数据:", dataset["text"][0]) # 输出第一个样本以检查格式
# 3. 使用 LoRA 微调模型
model = FastLanguageModel.get_peft_model(
model, # 传入原始模型
r=32, # LoRA 的秩,控制适配器容量
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], # 指定微调的 Transformer 层
lora_alpha=64, # LoRA 的缩放因子,保持数值稳定性
lora_dropout=0, # 不使用 dropout,简化训练
bias="none", # 不微调偏置参数
use_gradient_checkpointing="unsloth", # 使用 unsloth 的梯度检查点节省内存
random_state=3407 # 设置随机种子,确保结果可重复
)
# 4. 配置训练参数
batch_size = 2 # 每个设备的批次大小
gradient_accumulation_steps = 4 # 梯度累积步数,累计 4 次更新一次参数
effective_batch_size = batch_size * gradient_accumulation_steps # 计算实际批次大小
max_steps = int(len(dataset) * 2 / effective_batch_size) # 动态计算训练步数,覆盖 2 个 epoch
training_args = TrainingArguments(
per_device_train_batch_size=batch_size, # 设置每设备批次大小
gradient_accumulation_steps=gradient_accumulation_steps, # 设置梯度累积步数
warmup_steps=5, # 学习率预热步数
max_steps=max_steps, # 总训练步数
learning_rate=5e-5, # 设置学习率,较小值避免过拟合
fp16=not is_bfloat16_supported(), # 若不支持 bfloat16,则用 fp16
bf16=is_bfloat16_supported(), # 若支持 bfloat16,则使用
logging_steps=10, # 每 10 步记录一次日志
optim="adamw_8bit", # 使用 8-bit AdamW 优化器,节省内存
weight_decay=0.01, # 设置权重衰减,防止过拟合
lr_scheduler_type="linear", # 使用线性学习率调度器
seed=3407, # 设置随机种子
output_dir="outputs" # 训练输出目录
)
trainer = SFTTrainer( # 初始化监督式微调训练器
model=model, # 传入微调模型
tokenizer=tokenizer, # 传入分词器
train_dataset=dataset, # 传入训练数据集
dataset_text_field="text", # 指定数据集中的文本字段
max_seq_length=max_seq_length, # 设置最大序列长度
args=training_args # 传入训练参数
)
# 5. 开始训练并保存模型
print(f"开始训练,步数:{max_steps}") # 输出训练开始信息
trainer.train() # 执行训练
new_model_local = "DeepSeek-R1-Legal-COT-merged" # 定义保存目录
model.save_pretrained_merged(new_model_local, tokenizer, save_method="merged_16bit") # 保存合并后的模型(16-bit)
tokenizer.save_pretrained(new_model_local) # 保存分词器
print(f"模型已保存至:{new_model_local}") # 输出保存路径
2.3.训练要点
- 数据格式化:通过
formatting_prompts_func将数据集转为模型可理解的结构,支持大小写字段兼容。 - LoRA 参数:
r=32和lora_alpha=64提供足够的容量,同时保持稳定性。 - 内存优化:4-bit 量化和梯度检查点大幅降低显存需求,适合消费级 GPU。
- 动态步数:根据数据集大小自动计算
max_steps,确保充分训练。
三、使用模型
3.1. 推理代码
训练完成后,我们可以使用以下脚本加载模型并进行交互式推理:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# 导入必要的库
import os # 用于文件和目录操作
import torch # 用于张量计算和设备管理
from unsloth import FastLanguageModel # 加载 unsloth 模型
from transformers import AutoTokenizer # 加载分词器
# 1. 配置与加载模型
merged_model_dir = "DeepSeek-R1-Legal-COT-merged" # 指定训练后模型的本地目录
max_seq_length = 2048 # 设置最大序列长度,与训练保持一致
if not os.path.exists(merged_model_dir): # 检查模型目录是否存在
raise FileNotFoundError(f"错误:目录 {merged_model_dir} 不存在")
model, tokenizer = FastLanguageModel.from_pretrained( # 加载合并后的模型和分词器
merged_model_dir, # 使用本地模型路径
max_seq_length=max_seq_length, # 设置最大序列长度
dtype=None, # 自动选择数据类型
load_in_4bit=True # 启用 4-bit 量化
)
print("模型和分词器加载成功!") # 输出加载成功的提示
device = "cuda" if torch.cuda.is_available() else "cpu" # 动态选择设备
print(f"使用设备:{device}") # 输出当前设备
FastLanguageModel.for_inference(model) # 切换模型至推理模式
# 2. 定义推理提示模板
prompt_template = ( # 定义推理用的提示模板
"你是一位法律专家,具备高级法律推理、案例分析和法律解释能力。\n"
"请根据以下问题,生成逐步的思路链并回答。\n\n"
"问题:{}\n"
"回答:\n"
"<思路>{}</思路>\n"
)
def print_response(response): # 定义输出解析函数
print("\n【模型回答】") # 输出标题
response_text = response.strip() # 去除首尾空白
if "<思路>" in response_text: # 检查是否包含思路链
print(response_text) # 正常输出
else:
print("警告:未生成完整思路链,直接输出:") # 若无思路链,给出警告
print(response_text) # 输出完整内容
# 3. 推理函数
def perform_inference(question): # 定义推理函数
input_text = prompt_template.format(question, "") # 格式化输入提示
inputs = tokenizer(input_text, return_tensors="pt").to(device) # 编码输入并移动到设备
outputs = model.generate( # 调用模型生成回答
input_ids=inputs.input_ids, # 输入 token ID
attention_mask=inputs.attention_mask, # 注意力掩码
max_new_tokens=1000, # 设置最大生成长度
temperature=0.7, # 控制生成随机性
top_p=0.9, # 核采样参数
use_cache=True # 启用缓存加速
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True) # 解码生成结果
response = response[len(input_text):].strip() # 去除输入部分,仅保留生成内容
print_response(response) # 输出回答
# 4. 主程序:交互式推理
def main(): # 定义主函数
print("欢迎使用 DeepSeek-R1-Legal-COT 法律推理模型!") # 输出欢迎信息
print("请输入中文法律问题(输入 '退出' 或 'exit' 结束):") # 提示用户输入
while True: # 进入交互循环
question = input("> ").strip() # 获取用户输入并去除首尾空白
if question.lower() in ["退出", "exit"]: # 检查退出条件
print("程序已退出。") # 输出退出提示
break # 退出循环
if not question: # 检查输入是否为空
print("请输入有效问题!") # 提示用户输入有效问题
continue # 继续下一次循环
print(f"\n处理问题:{question}") # 输出当前处理的问题
perform_inference(question) # 执行推理
# 5. 运行程序
if __name__ == "__main__": # 检查是否为主程序运行
try:
main() # 调用主函数
except Exception as e: # 捕获异常
print(f"程序出错:{e}") # 输出错误信息
3.2.使用示例
运行脚本后,用户可以输入法律问题并获取回答:
欢迎使用 DeepSeek-R1-Legal-COT 法律推理模型!
请输入中文法律问题(输入 '退出' 或 'exit' 结束):
> 合同一方未履行义务,另一方能否解除合同?
处理问题:合同一方未履行义务,另一方能否解除合同?
【模型回答】
<思路>首先,依据《民法典》第563条,合同解除需满足特定条件;其次,判断未履行是否构成根本违约;最后,确认解除程序。</思路>
若一方未履行主要义务且构成根本违约,另一方可依法解除合同,需通知对方。
> 退出
程序已退出。
四、实践经验
-
训练优化:
- 数据集较小时(如 500 条),建议增加 epoch 次数或扩充数据。
- 若显存不足,可减小
batch_size或增大gradient_accumulation_steps。
-
推理调整:
- 若回答过于冗长,调整
max_new_tokens。 - 若内容过于死板,调高
temperature(如 0.9)。
- 若回答过于冗长,调整
-
部署建议:
- 合并后的模型可结合 Ollama 等工具部署为服务。
- 保存路径需妥善管理,避免覆盖。
五、总结
通过上述步骤,我们成功微调了 DeepSeek-R1 模型,使其能够生成带有思路链的法律回答。unsloth 和 LoRA 的结合极大降低了训练成本,而本地推理脚本提供了灵活的使用方式。希望本文能为你的 LLM 实践提供参。