DeepSeek R1 Local Training: A Hands-On, End-to-End Guide


I. Setting Up the Environment on Your Own

1. GPU driver and CUDA compatibility essentials

open-r1 explicitly requires CUDA 12.4, so first check your machine's GPU driver version (e.g. with nvidia-smi). If the driver is too old, you will have to upgrade it before it can work with CUDA 12.4. In my testing, any driver version of 470 or above runs fine; mine is 535.

# Check whether your GPU driver and CUDA build work together
import torch
print(torch.cuda.is_available())  # True means you are good to go
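If you want more detail than a bare True/False, the short sketch below (assuming PyTorch is already installed) also shows the CUDA version the wheel was built against and the GPU it sees:

import torch

print(torch.version.cuda)                  # CUDA build of the installed wheel, e.g. 12.4
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))   # e.g. an A100 variant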

2. Installing the environment quickly

Compared with the uv workflow in the README, I still prefer conda for managing virtual environments:

1. conda create -n openr1 python=3.11
2. pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
3. pip install vllm==0.7.2
4. pip install flash-attn
5. Change into the open-r1 directory and run pip install -e ".[dev]"
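Before launching a long job, a quick import check (a minimal sketch; the versions in the comments are the ones pinned above) confirms that everything landed in the environment:

import torch
import vllm
import flash_attn

print("torch:", torch.__version__)            # expect 2.5.1+cu124
print("vllm:", vllm.__version__)              # expect 0.7.2
print("flash_attn:", flash_attn.__version__)
print("CUDA available:", torch.cuda.is_available())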

II. Avoiding the Big Training Pitfalls

1. The many causes of OOM errors

Take GRPO training as an example: training Qwen-14B on A100s easily triggers OOM errors, and there are several possible causes. Let me walk through them one by one. A GRPO job has two parts, model training (7 GPUs) and model inference (1 GPU), and OOM errors can come from either part.

  • Training-side OOM: 7 A100s cannot train the 14B model as-is. Fix: edit recipes/accelerate_configs/zero3.yaml and enable offload.
  • Inference-side OOM: with vLLM versions below 0.7.3, OOM is very likely. Edit recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml and lower vllm_gpu_memory_utilization; 0.2 works for the 14B model and 0.5 for the 7B model.
  • Inference-side OOM: the max_model_len given to vLLM is too long, so the KV cache needs too much GPU memory. Fix: edit recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml and lower vllm_max_model_len. Note that this parameter covers prompt plus model output, so it should not be too short; 4k-8k is a reasonable range (see the estimate sketch after this list). The default is read from the base model's config, e.g. 32768 for Qwen-14B.
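To get a feel for why lowering vllm_max_model_len helps, here is a back-of-the-envelope KV-cache estimate. It is only a rough sketch: the layer/head numbers below are what I believe Qwen2.5-14B uses, so double-check them against the model's config.json before relying on the figures.

def kv_cache_gb(max_model_len, num_layers=48, num_kv_heads=8,
                head_dim=128, dtype_bytes=2):
    """Rough per-sequence KV-cache size in GB (one K and one V tensor per layer, bf16)."""
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    return max_model_len * bytes_per_token / 1024**3

print(kv_cache_gb(32768))  # ~6.0 GB per sequence at the Qwen-14B default length
print(kv_cache_gb(8000))   # ~1.5 GB per sequence at the 8k used in this post

Multiply by however many sequences vLLM keeps in flight at once and it becomes clear why the default 32768 quickly exhausts the slice of a 40G card that vllm_gpu_memory_utilization reserves for the KV cache.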

How do you tell whether an OOM error comes from training or inference? Look at the GPU index in the error message: the last GPU is used for inference by default, so if it is GPU 7 that ran out of memory, the problem is on the inference side, and adjusting the two parameters above is enough.

The tuned config files for training Qwen-14B on 8x A100 (40G) are included at the end of this post for reference.

2. Reward function parameter names matter

One thing to watch when writing a reward function: the parameter names in the function signature are not arbitrary; they must match the dataset's column names. In the reward function below, completions holds the model generations and ground_truth holds the values of the dataset's "ground_truth" column, so the parameter name ground_truth has to line up with that column name.

import re

def reward_func(completions, ground_truth, **kwargs):
    # Regular expression to capture content inside \boxed{}
    matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
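You can sanity-check this function on its own, mimicking how the trainer passes one keyword argument per extra dataset column (a toy example with plain-text completions):

completions = [r"The answer is \boxed{42}.", r"I believe it is \boxed{41}."]
ground_truth = ["42", "42"]
print(reward_func(completions, ground_truth))  # [1.0, 0.0]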

III. Getting DeepSeek R1 Training Started Without Getting Lost

1. Data first! Key points for preparing your business data

Build your offline business dataset data.json with the field names problem and solution, matching the field names in the official example data; this saves you a lot of code changes:

{"problem": "Classify the text into neutral, negative, or positive\nText: I think the food was okay.\nSentiment:\n", "solution": "positive"}
{"problem": "Classify the text into neutral, negative, or positive\nText: I think the food was shit.\nSentiment:\n", "solution": "negative"}

2. A quick switch! Changing how the data is loaded

Modify the data loading in grpo.py to read the offline file instead of a hub dataset:

dataset = load_dataset("json", data_files="XXX/data.json")
dataset = dataset["train"].train_test_split(test_size=0.02)

Make it your own! Defining a custom reward function. Note that the solution parameter in the signature below must match the dataset's field name:

import re

def accuracy_reward_ours(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = sol  # ground truth read from the dataset
        if len(gold_parsed) != 0:
            # Extract the predicted answer from the model output
            answer_parsed = re.findall("<answer>(.*?)</answer>", content)
            if len(answer_parsed) > 0:
                answer_parsed = answer_parsed[0]
                reward = float(1 if answer_parsed == gold_parsed else 0)  # compare prediction with ground truth
            else:
                reward = float(0)
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    return rewards
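A quick standalone test of the function above (a toy example; the code indexes completion[0]["content"], i.e. it expects conversational-format completions):

completions = [
    [{"role": "assistant", "content": "<think>...</think><answer>positive</answer>"}],
    [{"role": "assistant", "content": "no answer tag here"}],
]
solution = ["positive", "positive"]
print(accuracy_reward_ours(completions, solution))  # [1.0, 0.0]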

3. One command to launch! Kicking off DeepSeek R1 training

ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
    --num_processes=7 src/open_r1/grpo.py \
    --config recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml \
    &> /workspace/user_code/Qwen2.5-14B-Instruct.log

IV. The Full Config for Running R1 Smoothly with a 14B Model on A100s

recipes/accelerate_configs/zero3.yaml

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: "cpu"
  offload_param_device: "cpu"
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml

# Model arguments
model_name_or_path: XXX/models/Qwen2.5-14B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2

# Data training arguments
dataset_name: XXX/dataset/data.json
# Num processes is less by 1 as vLLM is using 1 GPU
num_processes: 7

# GRPO trainer config
reward_funcs:
- accuracy_ours
- format
bf16: true
use_vllm: true
vllm_device: cuda:7
vllm_gpu_memory_utilization: 0.2  # for vLLM versions below 0.7.3
vllm_max_model_len: 8000
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen-2.5-7B-Simple-RL
hub_strategy: every_save
learning_rate: 3.0e-06
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_generations: 7
num_train_epochs: 1
output_dir: XXX/Qwen-2.5-7B-Instruct-RL
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 8
push_to_hub: false
report_to: "none"
save_strategy: "steps"
save_steps: 100
save_total_limit: 2
seed: 42
warmup_ratio: 0.1