基于unsloth训练与部署实践DeepSeek-R1 法律推理模型

99 阅读9分钟

一、背景介绍

DeepSeek-R1 是一款高效的开源语言模型,适用于多种任务。通过 LoRA(Low-Rank Adaptation)技术,我们可以在不改变原始参数的情况下针对法律任务进行优化。本文使用 unsloth/DeepSeek-R1-Distill-Qwen-14B 模型和 kienhoang123/QR-legal 数据集,目标是训练一个能够生成法律推理思路链的模型。

二、训练过程

2.1.环境准备

首先安装依赖:

pip install unsloth transformers torch trl datasets

硬件建议:至少 16GB 显存的 GPU,若资源有限,可启用 4-bit 量化。

2.2.训练代码

以下是完整的训练脚本,每行代码都附有注释:

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# 导入必要的库
import torch  # 用于张量计算和模型训练
from unsloth import FastLanguageModel, is_bfloat16_supported  # unsloth 提供高效模型加载和微调支持
from transformers import TrainingArguments  # 配置训练参数
from trl import SFTTrainer  # 监督式微调训练器
from datasets import load_dataset  # 加载数据集

# 1. 模型加载参数设置
max_seq_length = 2048  # 定义模型支持的最大序列长度
model_name = "unsloth/DeepSeek-R1-Distill-Qwen-14B"  # 指定预训练模型名称
load_in_4bit = True  # 启用 4-bit 量化以减少内存占用

# 加载预训练模型和分词器
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,  # 使用指定的模型名称
    max_seq_length=max_seq_length,  # 设置最大序列长度
    dtype=None,  # 自动选择数据类型(如 float16 或 bfloat16)
    load_in_4bit=load_in_4bit,  # 使用 4-bit 量化加载模型
    token=None  # 公开模型无需认证 token
)
device = "cuda" if torch.cuda.is_available() else "cpu"  # 动态检测设备,若有 GPU 则用 cuda,否则用 CPU
print(f"使用设备:{device}")  # 输出当前使用的设备

# 2. 数据准备与格式化
train_prompt_style = (  # 定义训练用的提示模板,包含问题、思路链和回答
    "你是一位法律专家,具备高级法律推理、案例分析和法律解释能力。\n"
    "请根据以下问题,生成逐步的思路链并回答。\n\n"
    "问题:{}\n"
    "回答:\n"
    "<思路>\n"
    "{}\n"
    "</思路>\n"
    "{}"
)
EOS_TOKEN = tokenizer.eos_token  # 获取分词器的结束标记,用于标记文本结尾

def formatting_prompts_func(examples):  # 定义数据格式化函数,将数据集转为训练格式
    qs = examples.get("Question", examples.get("question", [""] * len(next(iter(examples.values())))))  # 获取问题字段,支持大小写
    cots = examples.get("Complex_CoT", examples.get("cot", [""] * len(qs)))  # 获取思路链字段,若无则用空字符串
    resps = examples.get("Response", examples.get("answer", [""] * len(qs)))  # 获取回答字段,若无则用空字符串
    texts = [train_prompt_style.format(q, cot, resp) + EOS_TOKEN for q, cot, resp in zip(qs, cots, resps)]  # 格式化每个样本
    return {"text": texts}  # 返回格式化后的文本列表

dataset = load_dataset("kienhoang123/QR-legal", "default", split="train[0:500]")  # 加载数据集的前 500 条数据
dataset = dataset.map(formatting_prompts_func, batched=True)  # 对数据集应用格式化函数,使用批处理提高效率
print(f"数据集大小:{len(dataset)}")  # 输出数据集大小,便于确认
print("示例数据:", dataset["text"][0])  # 输出第一个样本以检查格式

# 3. 使用 LoRA 微调模型
model = FastLanguageModel.get_peft_model(
    model,  # 传入原始模型
    r=32,  # LoRA 的秩,控制适配器容量
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],  # 指定微调的 Transformer 层
    lora_alpha=64,  # LoRA 的缩放因子,保持数值稳定性
    lora_dropout=0,  # 不使用 dropout,简化训练
    bias="none",  # 不微调偏置参数
    use_gradient_checkpointing="unsloth",  # 使用 unsloth 的梯度检查点节省内存
    random_state=3407  # 设置随机种子,确保结果可重复
)

# 4. 配置训练参数
batch_size = 2  # 每个设备的批次大小
gradient_accumulation_steps = 4  # 梯度累积步数,累计 4 次更新一次参数
effective_batch_size = batch_size * gradient_accumulation_steps  # 计算实际批次大小
max_steps = int(len(dataset) * 2 / effective_batch_size)  # 动态计算训练步数,覆盖 2 个 epoch

training_args = TrainingArguments(
    per_device_train_batch_size=batch_size,  # 设置每设备批次大小
    gradient_accumulation_steps=gradient_accumulation_steps,  # 设置梯度累积步数
    warmup_steps=5,  # 学习率预热步数
    max_steps=max_steps,  # 总训练步数
    learning_rate=5e-5,  # 设置学习率,较小值避免过拟合
    fp16=not is_bfloat16_supported(),  # 若不支持 bfloat16,则用 fp16
    bf16=is_bfloat16_supported(),  # 若支持 bfloat16,则使用
    logging_steps=10,  # 每 10 步记录一次日志
    optim="adamw_8bit",  # 使用 8-bit AdamW 优化器,节省内存
    weight_decay=0.01,  # 设置权重衰减,防止过拟合
    lr_scheduler_type="linear",  # 使用线性学习率调度器
    seed=3407,  # 设置随机种子
    output_dir="outputs"  # 训练输出目录
)

trainer = SFTTrainer(  # 初始化监督式微调训练器
    model=model,  # 传入微调模型
    tokenizer=tokenizer,  # 传入分词器
    train_dataset=dataset,  # 传入训练数据集
    dataset_text_field="text",  # 指定数据集中的文本字段
    max_seq_length=max_seq_length,  # 设置最大序列长度
    args=training_args  # 传入训练参数
)

# 5. 开始训练并保存模型
print(f"开始训练,步数:{max_steps}")  # 输出训练开始信息
trainer.train()  # 执行训练
new_model_local = "DeepSeek-R1-Legal-COT-merged"  # 定义保存目录
model.save_pretrained_merged(new_model_local, tokenizer, save_method="merged_16bit")  # 保存合并后的模型(16-bit)
tokenizer.save_pretrained(new_model_local)  # 保存分词器
print(f"模型已保存至:{new_model_local}")  # 输出保存路径

2.3.训练要点

  • 数据格式化:通过 formatting_prompts_func 将数据集转为模型可理解的结构,支持大小写字段兼容。
  • LoRA 参数r=32lora_alpha=64 提供足够的容量,同时保持稳定性。
  • 内存优化:4-bit 量化和梯度检查点大幅降低显存需求,适合消费级 GPU。
  • 动态步数:根据数据集大小自动计算 max_steps,确保充分训练。

三、使用模型

3.1. 推理代码

训练完成后,我们可以使用以下脚本加载模型并进行交互式推理:

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# 导入必要的库
import os  # 用于文件和目录操作
import torch  # 用于张量计算和设备管理
from unsloth import FastLanguageModel  # 加载 unsloth 模型
from transformers import AutoTokenizer  # 加载分词器

# 1. 配置与加载模型
merged_model_dir = "DeepSeek-R1-Legal-COT-merged"  # 指定训练后模型的本地目录
max_seq_length = 2048  # 设置最大序列长度,与训练保持一致

if not os.path.exists(merged_model_dir):  # 检查模型目录是否存在
    raise FileNotFoundError(f"错误:目录 {merged_model_dir} 不存在")

model, tokenizer = FastLanguageModel.from_pretrained(  # 加载合并后的模型和分词器
    merged_model_dir,  # 使用本地模型路径
    max_seq_length=max_seq_length,  # 设置最大序列长度
    dtype=None,  # 自动选择数据类型
    load_in_4bit=True  # 启用 4-bit 量化
)
print("模型和分词器加载成功!")  # 输出加载成功的提示

device = "cuda" if torch.cuda.is_available() else "cpu"  # 动态选择设备
print(f"使用设备:{device}")  # 输出当前设备
FastLanguageModel.for_inference(model)  # 切换模型至推理模式

# 2. 定义推理提示模板
prompt_template = (  # 定义推理用的提示模板
    "你是一位法律专家,具备高级法律推理、案例分析和法律解释能力。\n"
    "请根据以下问题,生成逐步的思路链并回答。\n\n"
    "问题:{}\n"
    "回答:\n"
    "<思路>{}</思路>\n"
)

def print_response(response):  # 定义输出解析函数
    print("\n【模型回答】")  # 输出标题
    response_text = response.strip()  # 去除首尾空白
    if "<思路>" in response_text:  # 检查是否包含思路链
        print(response_text)  # 正常输出
    else:
        print("警告:未生成完整思路链,直接输出:")  # 若无思路链,给出警告
        print(response_text)  # 输出完整内容

# 3. 推理函数
def perform_inference(question):  # 定义推理函数
    input_text = prompt_template.format(question, "")  # 格式化输入提示
    inputs = tokenizer(input_text, return_tensors="pt").to(device)  # 编码输入并移动到设备
    outputs = model.generate(  # 调用模型生成回答
        input_ids=inputs.input_ids,  # 输入 token ID
        attention_mask=inputs.attention_mask,  # 注意力掩码
        max_new_tokens=1000,  # 设置最大生成长度
        temperature=0.7,  # 控制生成随机性
        top_p=0.9,  # 核采样参数
        use_cache=True  # 启用缓存加速
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)  # 解码生成结果
    response = response[len(input_text):].strip()  # 去除输入部分,仅保留生成内容
    print_response(response)  # 输出回答

# 4. 主程序:交互式推理
def main():  # 定义主函数
    print("欢迎使用 DeepSeek-R1-Legal-COT 法律推理模型!")  # 输出欢迎信息
    print("请输入中文法律问题(输入 '退出' 或 'exit' 结束):")  # 提示用户输入
    
    while True:  # 进入交互循环
        question = input("> ").strip()  # 获取用户输入并去除首尾空白
        if question.lower() in ["退出", "exit"]:  # 检查退出条件
            print("程序已退出。")  # 输出退出提示
            break  # 退出循环
        if not question:  # 检查输入是否为空
            print("请输入有效问题!")  # 提示用户输入有效问题
            continue  # 继续下一次循环
        
        print(f"\n处理问题:{question}")  # 输出当前处理的问题
        perform_inference(question)  # 执行推理

# 5. 运行程序
if __name__ == "__main__":  # 检查是否为主程序运行
    try:
        main()  # 调用主函数
    except Exception as e:  # 捕获异常
        print(f"程序出错:{e}")  # 输出错误信息

3.2.使用示例

运行脚本后,用户可以输入法律问题并获取回答:

欢迎使用 DeepSeek-R1-Legal-COT 法律推理模型!
请输入中文法律问题(输入 '退出''exit' 结束):
> 合同一方未履行义务,另一方能否解除合同?

处理问题:合同一方未履行义务,另一方能否解除合同?
【模型回答】
<思路>首先,依据《民法典》第563条,合同解除需满足特定条件;其次,判断未履行是否构成根本违约;最后,确认解除程序。</思路>
若一方未履行主要义务且构成根本违约,另一方可依法解除合同,需通知对方。
> 退出
程序已退出。

四、实践经验

  1. 训练优化

    • 数据集较小时(如 500 条),建议增加 epoch 次数或扩充数据。
    • 若显存不足,可减小 batch_size 或增大 gradient_accumulation_steps
  2. 推理调整

    • 若回答过于冗长,调整 max_new_tokens
    • 若内容过于死板,调高 temperature(如 0.9)。
  3. 部署建议

    • 合并后的模型可结合 Ollama 等工具部署为服务。
    • 保存路径需妥善管理,避免覆盖。

五、总结

通过上述步骤,我们成功微调了 DeepSeek-R1 模型,使其能够生成带有思路链的法律回答。unsloth 和 LoRA 的结合极大降低了训练成本,而本地推理脚本提供了灵活的使用方式。希望本文能为你的 LLM 实践提供参。