参考文章
核心概念
通用概念
RL
惩罚信号
当前模型生产样本训练自身
Ray
分布式训练框架,管理复杂的Roles。
Ray Actor
- 有状态的
远程计算任务
,进程。 ray.remote
装饰的python class
Ray Task
- 无状态的远程计算任务,局部变量仅当前可见,对任务提交者不可见,
无状态
- ray.remote 装饰的python class,
资源管理
- Ray可
自动管理CPU/GPU/Mem的分配
,比如指定actor所需资源;设计资源组等。 - 通过Ray实现各种角色、并行策略的
资源分配
,实现Hybrid Engine
等colocate策略
。
异步执行
- ray 异步计算
- 执行一个ray计算任务后,立刻返回任务的执行句柄,
用户代码不会阻塞
- 通过
ray.get/ray.wait
进行阻塞式/轮训式的结果获取
- 执行一个ray计算任务后,立刻返回任务的执行句柄,
- RL 异步
- 方便actor/critic/generator/rm之间,
overlap一些时间
,- 如:actor更新上一个batch时,generator已经可生成下一个batch了
- 方便actor/critic/generator/rm之间,
并行策略
3D 并行
把模型参数(权重/优化器状态等) 在GPU之间
分片存储
- 仅当某个GPU需要其他GPU参数时才进行通信。
FPSDP,简单、逻辑清晰,research友好;Megatron:大模型更好
Verl同时支持megatron和FSDP
Verl 概念
Hybrid Flow
核心思想
- RL涉及多个模型交互协作,Verl设计
两层HybridFlow
,对训练dataflow
进行解耦
HyBrid Flow
- 控制流:high-level,描述
多个角色之间的交互逻辑
- 如:
actor采样
,Critic/RM/Refernce开始计算分数
,完成后计算GAE和loss
等
- 如:
- 计算流:low-level,描述
单个角色内部的计算流程
,管理模型训练和推理具体过程- 如:前向、反向传播、优化器更新、自回归生成等。
单控制器vs多控制器
Single Controller
- 核心:使用1个中央控制器来
统一管理所有子模块
- 优点:架构清晰、容易理解
- VeRL:
控制流
,Single Controller
- 易于新算法开发
Multiple Controller
- 核心:把控制逻辑
分散到多个控制器
,每个控制器负责特定模块; - 优点:
缓解单一控制器通信开销过大的问题
;集合通信实现各角色的同步控制 - VeRL:
计算流
,Multi Controller
- 通过多层级Worker来实现计算流MultiController
RayWorkerGroup
->WorkerDict
->MOdelWorker
->ParallelWorker
- 通过多层级Worker来实现计算流MultiController
- 业界主流训练推理引擎,多是基于多控制器。
- 训练引擎:FSDP、Megatron等
- 推理引擎:也有计划逐渐适配该模式,如vLLM、sglang等。
SingleController 来实现RL算法的控制流

训练推理/模型放置策略
训练引擎
- 训练:
Actor、Critic训练
推理引擎
推理:
Generator生成样本
;仅涉及
前向过程
,也会使用训练引擎Critic/RM/Reference
打分计算ligits和score
等因为训练引擎一般比推理引擎精度更高(因为kernal fusion底层因素)
分开放置
- 思想:
所有角色
放在不同设备
- 优点:异步overlap执行时间
- 缺点:GPU在训练过程中会空闲
分组放置
- 思想:
角色分组
,按组分配
在相同设备上 - 优点:可以overlap时间,减少GPU Idel时间
- 常见分组:
- 分法1:Actor+Ref,Critc+RM, Generator 单独一组
- 分法2(Verl主流):
Actor+Generator一组
,其余单独放置
- 因为Actor和Generator参数需要实时同步
一起放置
- 思想:
所有角色
放在相同设备上
- 优点:GPU时钟被占用
- 缺点:只能串行执行
Hybrid Engine/WorkerDict/WorkerGroup
灵活支持 各种模型放置策略
- 通过ResourcePool支持
- 以colocate为主:把actor的
训练和推理引擎
放置在一起
,动态切换角色- colocate共同放置:把不同功能模块计算单元,放在一个设备上,提高效率
WorkerDict
Worker被封装进
WorkerDict
,实现Worker角色灵活切换
- 不同角色可放置在相同设备上,通过rebind进行转换,通过reload/offload来切换参数
每个GPU
调度1个WorkerDict
,主要方便ray管理和角色切换
WorkerGroup
- 当前colocate的RL角色所占据设备的
所有WorkerDict
- WG管理一组远程运行的workers,colocate的RL角色依托WorkerGroup管理
- 统一管理数据resharding、任务执行等分布式逻辑
- WG 作为 SingleController和workers之间的中介,把worker方法绑定到WG上。
高效切换策略:Zero Redundancy Model Resharding
- 背景:worker可动态切换角色(
actor->generator
),需不同的参数切分逻辑
。 - Verl设计了高效切换策略
数据传输协议
- 背景:worker可动态切换角色(
actor->generator
),为适配不同角色和方法所需的数据划分细节- 如dp维度切分数据、3d维度切分数据等
- 设计了一套数据传输协议,主要包括
数据分发(Dispatch)
和收集(Collect)
HybridFlow支持的训推并行策略、权重转换策略、模型放置和执行策略

Verl 训练数据执行流程
概览
verl 把数据传输+方法执行协议,设计为python的装饰器
- 通过定义decorator绑定给各worker类的具体方法
这样每个workergroup调用workerdict时,便可知如何分发和收集数据
具体流程
- RayPPOTrainer 向 RayWorkerGroup 发起方法调用
- RayWorkerGroup内部
- 先执行数据分发逻辑
- 执行逻辑判断哪些worker需要运行任务
- 带有数据的任务被分发给指定的WorkerDicts
- 任务执行
- 每个WorkDict通过远程执行接受其任务
- 完成任务后,结果返回给RayWorkerGroup
- 结果处理
- 结果通过收集逻辑进行处理 collect protocol
- 最终,处理后的结果返回给RayPPOTrainer

代码架构
Trainer 组件
main_ppo.py
选择奖励函数 (model or rule)
- model-based or rule-based
- RewardManager + 用户自定义的reward_score
选择训练后端 (FSDP or Megatron)
- FSDP:学术界
- Megatron:工业界,大规模训练
调用RayPPOTrainer
- 调用trainer的init_workers初始化rl各角色的workergroup
- 调用fit进行训练
ray_trainer.py
初始化RL中的各个Role
- 多个角色:Actor、Critic、RM、Ref等
- 定义好各模型的角色、resource_pool的定义分配、workerdict和workergroup的初始化和分配
WorkerGroup 机制实现 (每类colocate model group)
actor_rollout_wg
:actor/generator互相切换的hybrid engine- reload/offload params,reshard等
critic_wg
:支持critic角色ref_policy_wg
:reference角色,KL 需要rm_wg
:RM,model based reward 需要- 由
init_workers
初始化各worker group
ResourcePoolManager
- 资源池管理,封装ray的placement_group,指定角色合理分配到设备上
PPO loss 依赖的函数
apply_kl_penalty
:token-level kl rewardkl loss
, 在core_algos.py
compute_advantage
:计算优势函数,核心在core_algos.py
adv_esitmator
:支持PPO/GRPO/Reinforce++/Remax等算法,区别主要在advantages,核心依然在core_algos.py
Timer/Metrics函数等
- metric计算函数 compute_data_metrics, compute_timing_metrics
- save/load 断点续训、ckpt保存等
- validate 逻辑、DP负载均衡逻辑等
Fit/Train loop
fit
:实现RL的完整训练流程,调用各worker进行实际计算
core_algos.py
各种loss计算逻辑
- policy loss, value loss, entropy loss, kl loss
各种advantages计算逻辑
- 各rl算法区分在advantages estimator 如何实现,
class AdvantageEstimator(str, Enum):
GAE = "gae"
GRPO = "grpo"
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
REMAX = "remax"
RLOO = "rloo"
OPO = "opo"
GRPO_PASSK = "grpo_passk"
GPG = "gpg"
main_generator.py
main_eval.py
Workers 组件
workers 文件夹 定义了 RL中
- 各角色的worker,high level, 负责描述逻辑
- 各角色计算时实际依赖的worker,low-level,负责描述运算
worker 各角色

fsdp_workers.py
定义RL训练过程中用到的相关Worker,基于实际运行的workers封装的。
功能
- 可以单独作为RL中的Actor、Rollout、Reference (负责提供ref_log_prob计算KL)
- 可以基于hybrid engine,同时扮演多个角色,通过参数进行灵活切换
关键方法
- update_actor
- generate_sequences
- compute_log_prob
- compute_ref_log_prob
init_model
- 根据config指定model类型,初始化当前worker
if self._is_actor:
actor_cfg = omega_conf_to_dataclass(self.config.actor)
self.actor = DataParallelPPOActor(
config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
)
update_actor
- 基于Actor的update_policy,计算loss并更新Policy模型
- 基于ulysses_sharding_manager支持序列并行的数据前后处理,实现序列并行
data = self.ulysses_sharding_manager.preprocess_data(data=data)
metrics = self.actor.update_policy(data=data)
output = DataProto(meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
generate_sequences
- rollout引擎,inference生成数据
output = self.rollout.generate_sequences(prompts=prompts)
compute_log_prob
- 基于actor训练引擎,同步计算old_logprobs,方便重要性采样计算
data = self.ulysses_sharding_manager.preprocess_data(data)
# actor 计算 log probs
output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
output = DataProto.from_dict(
tensors={"old_log_probs": output, "entropys": entropys},
meta_info={"temperature": self.config.rollout.temperature},
)
output = self.ulysses_sharding_manager.postprocess_data(output)
# DataParrallelPPOActor.compute_log_prob
log_probs_lst = []
entropy_lst = []
for micro_batch in micro_batches:
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
with torch.no_grad():
# [bs, response_len], [bs, response_len]
entropy, log_probs = self._forward_micro_batch(
model_inputs, temperature=temperature, calculate_entropy=calculate_entropy
)
log_probs_lst.append(log_probs)
if calculate_entropy:
entropy_lst.append(entropy)
log_probs = torch.concat(log_probs_lst, dim=0)
# DataParrallelPPOActor._forward_micor_batch
output = self.actor_module(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
**multi_modal_inputs,
use_cache=False,
**extra_args,
) # prevent model thinks we are generating
if self.use_fused_kernels:
log_probs = output.log_probs.squeeze(0) # (total_nnz,)
entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,)
else:
logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)
logits_rmpad.div_(temperature)
# only return response part:
if calculate_entropy:
entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length)
log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length)
compute_ref_log_prob
data = self.ulysses_sharding_manager.preprocess_data(data)
# ref policy 计算 log probs
output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
output = DataProto.from_dict(tensors={"ref_log_prob": output})
output = self.ulysses_sharding_manager.postprocess_data(output)
整体思想
- 整理逻辑和ActorRolloutRefWorker基本一致,只是后端是DataParallelPPOCritic
- 无需Rollout,且额外多了compute_values操作,通过Value Head 计算 token-level value
关键方法
- compute_values
- update_critic
compute_values
- 计算values
# fsdp_workers.py . CriticWorker
data = self.ulysses_sharding_manager.preprocess_data(data=data)
values = self.critic.compute_values(data=data)
output = DataProto.from_dict(tensors={"values": values})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
# DataParallelPPOCritic.compute_values
values_lst = []
for micro_batch in micro_batches:
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
with torch.no_grad():
values = self._forward_micro_batch(model_inputs)
values_lst.append(values)
values = torch.concat(values_lst, dim=0)
# DataParallelPPOCritic._forward_micro_batch
output = self.critic_module(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**multi_modal_inputs,
use_cache=False,
) # prevent model thinks we are generating
if hasattr(self.critic_module, "v_head"):
# For trl.AutoModelForCausalLMWithValueHead
values = output[2]
else:
values = output.logits
values = values[:, -response_length - 1 : -1].squeeze(-1)
update_critic
- 计算critic loss,更新critic
# CriticWorker.update_critic
metrics = self.critic.update_critic(data=data)
# 数据切分为多个mini-batch
mini_batches = data.split(self.config.ppo_mini_batch_size)
for _ in range(self.config.ppo_epochs):
# 数据可使用多轮ppo更新
for batch_idx, mini_batch in enumerate(mini_batches):
# 计算梯度累计步数
self.gradient_accumulation = (
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
)
# 划分成固定大小的micro_batchs,循环多个micro_batches才进行1次梯度更新
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
self.critic_optimizer.zero_grad()
for micro_batch in micro_batches:
micro_batch_metrics = {}
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
# 标记出哪些是回复内容,只对回复内容做critic loss
response_mask = model_inputs["response_mask"]
# 旧价值,裁剪参考
values = model_inputs["values"]
# 目标回报
returns = model_inputs["returns"]
# 当前critic的价值预测
vpreds = self._forward_micro_batch(model_inputs)
# value funtion loss,MSE loss,以及有多少比例的样本被clip
vf_loss, vf_clipfrac = core_algos.compute_value_loss(
vpreds=vpreds,
values=values,
returns=returns,
response_mask=response_mask,
cliprange_value=self.config.cliprange_value,
loss_agg_mode=self.config.loss_agg_mode,
)
loss = vf_loss / self.gradient_accumulation
loss.backward()
micro_batch_metrics.update(
{
"critic/vf_loss": vf_loss.detach().item(),
"critic/vf_clipfrac": vf_clipfrac.detach().item(),
"critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
}
)
append_to_dict(metrics, micro_batch_metrics)
grad_norm = self._optimizer_step()
mini_batch_metrics = {"critic/grad_norm": grad_norm.detach().item()}
append_to_dict(metrics, mini_batch_metrics)
self.critic_optimizer.zero_grad()
def compute_value_loss(
vpreds: torch.Tensor,
returns: torch.Tensor,
values: torch.Tensor,
response_mask: torch.Tensor,
cliprange_value: float,
loss_agg_mode: str = "token-mean",
):
# 基于旧values对预测价值做clip,防止更新幅度过大,和PPO clip 类似
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
# 未裁剪的loss
vf_losses1 = (vpreds - returns) ** 2
# 裁剪后的loss
vf_losses2 = (vpredclipped - returns) ** 2
# 取大值,选择更大的惩罚,避免因为裁切而意外降低loss
clipped_vf_losses = torch.max(vf_losses1, vf_losses2)
# 0.5 MSE常见1/2因子,与梯度形式对应;只统计response_mask=1的位置,仅对回复token训练,忽略指令部分
vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
# 统计裁剪占比
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
return vf_loss, vf_clipfrac
核心思想
- 基于模型的RM打分
关键函数
- compute_rm_score
compute_rm_score
# RewardModelWorker.compute_rm_score
output = []
for micro_batch in micro_batches:
rm_score = self._forward_micro_batch(micro_batch)
output.append(rm_score)
scores = torch.cat(output, dim=0) # (batch_size)
# RewardModelWorker._forward_micro_batch
output = self.reward_module(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False)
rm_score = output.logits # (batch_size, seq_len, 1)
rm_score = rm_score.squeeze(-1)
核心功能
- 训练时rollout,
- 核心函数:generate_sequences
支持不同backend
- 原生rollout逻辑:logits->softmax->sampling
- HF TGI rollout逻辑
- VLLM rollout逻辑
- 基于third_party修改的vllm engine进行推理
- repeat 没有采用n_samples参数,而是直接repeat_interleave输入,多次生成
- 为了保证重要性采样和KL散度的准确性,old_log_probs没有使用vllm引擎结果,而是用训练引擎一起计算
- SGLang rollout 逻辑
megatron_workers.py
- 基于megatron支持4D并行,DP、TP、SP、PP;
- 核心逻辑基本和FSDP版本一致,但是底层逻辑需要适配megatron框架
Single Controller 组件
Worker
- 方便管理worker进程在workergroup进程组内部信息和资源分配等
Resource Pool
- 管理资源池,包括池内节点和进程信息
WorkerGroup
- 管理多个worker所组成的workergroup
Decortor
- 定义各种worker数据分发和函数执行的装饰器
Ray
- 去管理worker(WorkerDict)和workergroup(RayWorkerGroup)
Models 组件
包含常见模型结构
- transformers
- llama
- Qwen2
Utils 组件
Dataset
- rl, sft, rm 数据集
- 常见功能
- 处理各数据集中的key
- 取出parquet里的prompt序列
- apply_chat_ml + tokenize后设为input_ids
- verl的dataset和dataloader没有和训练过程强绑定,可在训练过程中轻松修改dataloader
reward_score
- Math/code/search r1等不同的rule-grader