AReaL:大规模异步强化学习系统
AReaL(A Large-Scale Asynchronous Reinforcement Learning System)是一个开源的、完全异步的强化学习训练系统,旨在解决大型语言模型与智能体模型的对齐问题。它由清华大学交叉信息研究院和蚂蚁集团AReaL团队的成员共同开发。AReaL致力于通过提供训练细节、数据和基础设施,帮助每个人都能轻松且经济地构建自己的AI智能体。它不仅速度快、扩展性强,而且提供了无与伦比的灵活性,支持从数学推理到客户服务代理的多种应用场景。
功能特性
- 全异步架构:实现稳定、高性能的强化学习训练,通过异步的rollout和分布式训练,显著提升吞吐量,达到行业领先的训练速度。
- 多算法支持:原生支持多种PPO系列算法,包括PPO、GRPO、DAPO、RLOO、GSPO等,并允许通过配置轻松切换,无需修改核心代码。
- 灵活的模块化设计:核心组件如训练引擎(
Engine)、Rollout工作流(Workflow)、奖励函数(Reward)、数据集(Dataset)均通过清晰API解耦,方便开发者独立扩展和定制。 - 强大的分布式训练引擎:
- FSDPEngine:基于PyTorch FSDP2,针对稠密模型优化,支持TP、DP、CP并行。
- MegatronEngine:专为超大模型设计,支持复杂的张量、流水线、数据并行。
- ArchonEngine:专为MoE模型优化,原生支持专家并行(EP)、专家张量并行(ETP),是训练稀疏模型的利器。
- 丰富的推理后端集成:无缝集成vLLM和SGLang等高性能推理引擎,为rollout提供快速高效的样本生成。
- 智能体强化学习(Agentic RL):支持通过简单的
base_url替换,将任何与OpenAI SDK兼容的模型或服务集成到RL训练流程中,实现灵活的智能体训练。
安装指南
系统要求
- Python 3.12 或更高版本
- CUDA 支持的环境(可选,但强烈推荐用于训练)
- PyTorch (版本需与CUDA版本匹配)
- uv 包管理工具
安装步骤
-
克隆仓库
git clone https://github.com/inclusionAI/AReaL.git cd AReaL -
使用
uv同步依赖环境 推荐使用uv来创建和管理虚拟环境。# 安装依赖(包含CUDA支持) uv sync --extra cuda # 或者,如果没有CUDA环境 # uv sync # 激活虚拟环境 source .venv/bin/activate -
安装预提交钩子(可选,用于开发)
pre-commit install -
验证安装
uv run python3 areal/tools/validate_installation.py
使用说明
基础训练示例
以下示例展示了如何使用AReaL进行GRPO算法的训练。
import asyncio
from areal.api.cli_args import GRPOConfig, ClusterSpecConfig, ExperimentConfig
from areal.infra import TrainController
async def main():
# 1. 配置训练参数
config = GRPOConfig(
experiment_name="my_first_grpo",
model_name="Qwen/Qwen2.5-1.5B-Instruct",
# ... 配置数据集、超参、集群等
cluster=ClusterSpecConfig(device="cuda"),
)
# 2. 创建并启动训练控制器
controller = TrainController(config)
await controller.run()
if __name__ == "__main__":
asyncio.run(main())
使用不同算法
AReaL的算法行为主要通过配置文件控制。例如,切换到DAPO算法只需修改配置文件或命令行参数。
# config/dapo_example.yaml
algorithm: "dapo" # 指定算法
actor:
eps_clip: 0.2
kl_ctl: 0.0 # DAPO通常不使用KL惩罚
norm:
mean_level: "batch"
std_level: "batch"
...
然后通过Hydra加载配置:
python run.py --config-name dapo_example
核心代码
1. 分布式训练引擎接口 (TrainEngine)
TrainEngine 是所有训练引擎的抽象基类,定义了核心生命周期和分布式通信能力。无论是FSDPEngine还是MegatronEngine,都遵循此接口,保证了上层工作流的统一性。
# 文件: areal/api/engine_api.py
from __future__ import annotations
import abc
from typing import TYPE_CHECKING
import torch.distributed as dist
from areal.api.alloc_mode import ParallelStrategy
class TrainEngine(abc.ABC):
"""训练引擎的抽象基类"""
@abc.abstractmethod
def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
"""初始化PyTorch分布式通信组。"""
raise NotImplementedError()
@abc.abstractmethod
def initialize(self, *args, **kwargs):
"""初始化分布式训练环境并加载模型。"""
raise NotImplementedError()
@property
@abc.abstractmethod
def data_parallel_group(self) -> dist.ProcessGroup:
"""获取数据并行通信组。"""
raise NotImplementedError()
@property
@abc.abstractmethod
def data_parallel_rank(self) -> int:
"""获取当前进程在数据并行组中的rank。"""
raise NotImplementedError()
# ... 其他抽象方法
2. Rollout 工作流 (RolloutWorkflow)
RolloutWorkflow 是定义数据生成(即交互)逻辑的核心抽象。用户通过实现 arun_episode 方法来定义一次交互:如何使用推理引擎从提示生成回复,并计算奖励。这种设计使得RL的数据收集阶段高度可定制。
# 文件: areal/api/workflow_api.py
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
class RolloutWorkflow(ABC):
"""Rollout工作流的抽象基类"""
@abstractmethod
async def arun_episode(
self, engine: InferenceEngine, data: dict[str, Any]
) -> dict[str, Any] | None:
"""运行单次交互(episode)。
Args:
engine : 用于生成回复的推理引擎。
data : 包含prompt等信息的输入数据。
Returns:
包含生成的回复、奖励、logprobs等信息的字典。
返回 `None` 表示此轨迹被拒绝,不会用于训练。
"""
raise NotImplementedError()
3. 并行策略配置 (ParallelStrategy)
AReaL通过一个统一的数据类来描述复杂的5D并行策略(张量、流水线、数据、上下文、专家并行)。这个配置会被传递到不同的训练引擎中,用于创建相应的分布式通信组,实现了并行策略的声明式配置。
# 文件: areal/api/alloc_mode.py
from dataclasses import dataclass
@dataclass
class ParallelStrategy:
"""5D并行策略,支持张量、流水线、数据、上下文和专家并行。"""
tensor_parallel_size: int = 1 # TP大小
pipeline_parallel_size: int = 1 # PP大小
data_parallel_size: int = 1 # DP大小
context_parallel_size: int = 1 # CP大小
expert_parallel_size: int = 1 # EP大小
def __post_init__(self):
# ... 验证逻辑
pass
@property
def world_size(self) -> int:
"""计算所需的总GPU数量。"""
return (self.tensor_parallel_size *
self.pipeline_parallel_size *
self.data_parallel_size *
self.context_parallel_size *
self.expert_parallel_size)
VLptrkeiNx7vkcnRlAYW2Tg1PZT+AMc6h27M+7N5m0g=