Skip to content

Verl

📅 发表于 2025/09/04
🔄 更新于 2025/09/04
👁️ -- 次访问
📝 0 字
0 分钟
infra
#verl
#ray
#ray actor
#ray task
#hybrid flow
#控制流
#计算流
#Single Controller
#Multi Controller
#训练引擎
#推理引擎
#fsdp
#模型放置策略
#分组放置
#Hybrid Engine
#Colocate
#Worker
#WorkerDict
#WorkerGroup
#main_ppo.py
#ray_trainer.py
#core_algos.py
#fsdp_workers.py
#ActorRolloutRefWorker
#update_actor
#generate_sequences
#compute_log_prob
#compute_ref_log_prob
#CriticWorker
#compute_values
#update_critic
#RewardModelWorker
#compute_rm_score

参考文章

核心概念

通用概念

RL

RL

惩罚信号

当前模型生产样本训练自身

Ray

Ray

分布式训练框架,管理复杂的Roles。

Ray Actor

  • 有状态的远程计算任务,进程。
  • ray.remote 装饰的python class

Ray Task

  • 无状态的远程计算任务,局部变量仅当前可见,对任务提交者不可见无状态
  • ray.remote 装饰的python class,

资源管理

  • Ray可自动管理CPU/GPU/Mem的分配,比如指定actor所需资源;设计资源组等。
  • 通过Ray实现各种角色并行策略资源分配,实现Hybrid Enginecolocate策略

异步执行

  • ray 异步计算
    • 执行一个ray计算任务后,立刻返回任务的执行句柄,用户代码不会阻塞
    • 通过ray.get/ray.wait进行阻塞式/轮训式的结果获取
  • RL 异步
    • 方便actor/critic/generator/rm之间overlap一些时间
      • 如:actor更新上一个batch时,generator已经可生成下一个batch了

并行策略

并行策略

3D 并行

FSDP

  • 模型参数(权重/优化器状态等) 在GPU之间分片存储

    • 仅当某个GPU需要其他GPU参数时才进行通信。
  • FPSDP,简单、逻辑清晰,research友好;Megatron:大模型更好

  • Verl同时支持megatron和FSDP

Verl 概念

Hybrid Flow

Hybrid Flow

核心思想

  • RL涉及多个模型交互协作,Verl设计两层HybridFlow,对训练dataflow进行解耦

HyBrid Flow

  • 控制流:high-level,描述多个角色之间的交互逻辑
    • 如:actor采样Critic/RM/Refernce开始计算分数,完成后计算GAE和loss
  • 计算流:low-level,描述单个角色内部的计算流程,管理模型训练和推理具体过程
    • 如:前向、反向传播、优化器更新、自回归生成等。

单控制器vs多控制器

Single Controller vs Multiple Controller 两种设计模式

Single Controller

  • 核心:使用1个中央控制器统一管理所有子模块
  • 优点:架构清晰、容易理解
  • VeRL:控制流Single Controller
    • 易于新算法开发

Multiple Controller

  • 核心:把控制逻辑分散到多个控制器每个控制器负责特定模块
  • 优点:缓解单一控制器通信开销过大的问题集合通信实现各角色的同步控制
  • VeRL:计算流, Multi Controller
    • 通过多层级Worker来实现计算流MultiController
      • RayWorkerGroup -> WorkerDict -> MOdelWorker -> ParallelWorker
  • 业界主流训练推理引擎,多是基于多控制器
    • 训练引擎:FSDP、Megatron等
    • 推理引擎:也有计划逐渐适配该模式,如vLLM、sglang等。

SingleController 来实现RL算法的控制流

训练推理/模型放置策略

训练引擎 vs 推理引擎

训练引擎

  • 训练:Actor、Critic训练

推理引擎

  • 推理:Generator生成样本

  • 仅涉及前向过程,也会使用训练引擎

    • Critic/RM/Reference 打分计算ligits和score

    • 因为训练引擎一般比推理引擎精度更高(因为kernal fusion底层因素)

模型放置策略

分开放置

  • 思想:所有角色放在不同设备
  • 优点:异步overlap执行时间
  • 缺点:GPU在训练过程中会空闲

分组放置

  • 思想:角色分组按组分配在相同设备上
  • 优点:可以overlap时间减少GPU Idel时间
  • 常见分组:
    • 分法1:Actor+Ref,Critc+RM, Generator 单独一组
    • 分法2(Verl主流)Actor+Generator一组其余单独放置
      • 因为Actor和Generator参数需要实时同步

一起放置

  • 思想:所有角色放在相同设备上
  • 优点:GPU时钟被占用
  • 缺点:只能串行执行

Hybrid Engine/WorkerDict/WorkerGroup

VeRL Hybrid Engine

灵活支持 各种模型放置策略

  • 通过ResourcePool支持
  • colocate为主:把actor训练和推理引擎 放置在一起动态切换角色
    • colocate共同放置:把不同功能模块计算单元,放在一个设备上,提高效率

WorkerDict

  • Worker被封装进WorkerDict,实现Worker角色灵活切换

    • 不同角色可放置在相同设备上,通过rebind进行转换,通过reload/offload来切换参数
  • 每个GPU调度1个WorkerDict,主要方便ray管理和角色切换

WorkerGroup

  • 当前colocate的RL角色所占据设备的所有WorkerDict
  • WG管理一组远程运行的workers,colocate的RL角色依托WorkerGroup管理
    • 统一管理数据resharding、任务执行等分布式逻辑
  • WG 作为 SingleController和workers之间的中介,把worker方法绑定到WG上。

高效切换策略:Zero Redundancy Model Resharding

  • 背景:worker可动态切换角色(actor->generator),需不同的参数切分逻辑
  • Verl设计了高效切换策略

数据传输协议

  • 背景:worker可动态切换角色(actor->generator),为适配不同角色和方法所需的数据划分细节
    • 如dp维度切分数据、3d维度切分数据等
  • 设计了一套数据传输协议,主要包括数据分发(Dispatch)收集(Collect)

HybridFlow支持的训推并行策略、权重转换策略、模型放置和执行策略

Verl 训练数据执行流程

训练数据流程

概览

  • verl 把数据传输+方法执行协议,设计为python的装饰器

    • 通过定义decorator绑定给各worker类的具体方法
  • 这样每个workergroup调用workerdict时,便可知如何分发和收集数据

具体流程

  • RayPPOTrainer 向 RayWorkerGroup 发起方法调用
  • RayWorkerGroup内部
    • 先执行数据分发逻辑
    • 执行逻辑判断哪些worker需要运行任务
    • 带有数据的任务被分发给指定的WorkerDicts
  • 任务执行
    • 每个WorkDict通过远程执行接受其任务
    • 完成任务后,结果返回RayWorkerGroup
  • 结果处理
    • 结果通过收集逻辑进行处理 collect protocol
    • 最终,处理后的结果返回给RayPPOTrainer

代码架构

Trainer 组件

main_ppo.py

main.py RL 算法主入口

选择奖励函数 (model or rule)

  • model-based or rule-based
  • RewardManager + 用户自定义的reward_score

选择训练后端 (FSDP or Megatron)

  • FSDP:学术界
  • Megatron:工业界,大规模训练

调用RayPPOTrainer

  • 调用trainer的init_workers初始化rl各角色的workergroup
  • 调用fit进行训练

ray_trainer.py

ray_trainer.py

初始化RL中的各个Role

  • 多个角色:Actor、Critic、RM、Ref等
  • 定义好各模型的角色、resource_pool的定义分配、workerdict和workergroup的初始化和分配

WorkerGroup 机制实现 (每类colocate model group)

  • actor_rollout_wg:actor/generator互相切换的hybrid engine
    • reload/offload params,reshard等
  • critic_wg:支持critic角色
  • ref_policy_wg:reference角色,KL 需要
  • rm_wg:RM,model based reward 需要
  • init_workers 初始化各worker group

ResourcePoolManager

  • 资源池管理,封装ray的placement_group,指定角色合理分配到设备上

PPO loss 依赖的函数

  • apply_kl_penalty:token-level kl reward
    • kl loss, 在core_algos.py
  • compute_advantage:计算优势函数,核心在core_algos.py
  • adv_esitmator:支持PPO/GRPO/Reinforce++/Remax等算法,区别主要在advantages,核心依然在core_algos.py

Timer/Metrics函数等

  • metric计算函数 compute_data_metrics, compute_timing_metrics
  • save/load 断点续训、ckpt保存等
  • validate 逻辑、DP负载均衡逻辑等

Fit/Train loop

  • fit:实现RL的完整训练流程,调用各worker进行实际计算

core_algos.py

core_algos.py

各种loss计算逻辑

  • policy loss, value loss, entropy loss, kl loss

各种advantages计算逻辑

  • 各rl算法区分在advantages estimator 如何实现,
python
class AdvantageEstimator(str, Enum):		
 	GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
    REMAX = "remax"
    RLOO = "rloo"
    OPO = "opo"
    GRPO_PASSK = "grpo_passk"
    GPG = "gpg"

main_generator.py

main_eval.py

Workers 组件

workers 文件夹 定义了 RL中

  • 各角色的worker,high level, 负责描述逻辑
  • 各角色计算时实际依赖的worker,low-level,负责描述运算

worker 各角色

fsdp_workers.py

定义RL训练过程中用到的相关Worker,基于实际运行的workers封装的。

ActorRolloutRefWorker

功能

  • 可以单独作为RL中的ActorRolloutReference (负责提供ref_log_prob计算KL)
  • 可以基于hybrid engine,同时扮演多个角色,通过参数进行灵活切换

关键方法

  • update_actor
  • generate_sequences
  • compute_log_prob
  • compute_ref_log_prob

init_model

  • 根据config指定model类型,初始化当前worker
python
if self._is_actor:
    actor_cfg = omega_conf_to_dataclass(self.config.actor)
    self.actor = DataParallelPPOActor(
        config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
    )

update_actor

  • 基于Actor的update_policy,计算loss并更新Policy模型
  • 基于ulysses_sharding_manager支持序列并行的数据前后处理,实现序列并行
python
data = self.ulysses_sharding_manager.preprocess_data(data=data)
metrics = self.actor.update_policy(data=data)
output = DataProto(meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)

generate_sequences

  • rollout引擎,inference生成数据
python
output = self.rollout.generate_sequences(prompts=prompts)

compute_log_prob

  • 基于actor训练引擎,同步计算old_logprobs,方便重要性采样计算
python
data = self.ulysses_sharding_manager.preprocess_data(data)
# actor 计算 log probs
output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
output = DataProto.from_dict(
  tensors={"old_log_probs": output, "entropys": entropys},
  meta_info={"temperature": self.config.rollout.temperature},
)
output = self.ulysses_sharding_manager.postprocess_data(output)
python
# DataParrallelPPOActor.compute_log_prob
log_probs_lst = []
entropy_lst = []
for micro_batch in micro_batches:
    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
    with torch.no_grad():
      	# [bs, response_len], [bs, response_len]
        entropy, log_probs = self._forward_micro_batch(
            model_inputs, temperature=temperature, calculate_entropy=calculate_entropy
        )
    log_probs_lst.append(log_probs)
    if calculate_entropy:
        entropy_lst.append(entropy)
log_probs = torch.concat(log_probs_lst, dim=0)
python
# DataParrallelPPOActor._forward_micor_batch

output = self.actor_module(
    input_ids=input_ids_rmpad,
    attention_mask=None,
    position_ids=position_ids_rmpad,
    **multi_modal_inputs,
    use_cache=False,
    **extra_args,
)  # prevent model thinks we are generating

if self.use_fused_kernels:
    log_probs = output.log_probs.squeeze(0)  # (total_nnz,)
    entropy_rmpad = output.entropy.squeeze(0)  # (total_nnz,)

else:
    logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)
    logits_rmpad.div_(temperature)

# only return response part:
if calculate_entropy:
    entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)
log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)

compute_ref_log_prob

python
data = self.ulysses_sharding_manager.preprocess_data(data)
# ref policy 计算 log probs
output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
output = DataProto.from_dict(tensors={"ref_log_prob": output})
output = self.ulysses_sharding_manager.postprocess_data(output)
CriticWorker

整体思想

  • 整理逻辑和ActorRolloutRefWorker基本一致,只是后端是DataParallelPPOCritic
  • 无需Rollout,且额外多了compute_values操作,通过Value Head 计算 token-level value

关键方法

  • compute_values
  • update_critic

compute_values

  • 计算values
python
# fsdp_workers.py . CriticWorker
data = self.ulysses_sharding_manager.preprocess_data(data=data)
values = self.critic.compute_values(data=data)
output = DataProto.from_dict(tensors={"values": values})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
python
# DataParallelPPOCritic.compute_values
values_lst = []
for micro_batch in micro_batches:
    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
    with torch.no_grad():
        values = self._forward_micro_batch(model_inputs)
    values_lst.append(values)
values = torch.concat(values_lst, dim=0)
python
# DataParallelPPOCritic._forward_micro_batch
output = self.critic_module(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    **multi_modal_inputs,
    use_cache=False,
)  # prevent model thinks we are generating
if hasattr(self.critic_module, "v_head"):
    # For trl.AutoModelForCausalLMWithValueHead
    values = output[2]
else:
    values = output.logits
values = values[:, -response_length - 1 : -1].squeeze(-1)

update_critic

  • 计算critic loss,更新critic
python
# CriticWorker.update_critic
metrics = self.critic.update_critic(data=data)
python
# 数据切分为多个mini-batch
mini_batches = data.split(self.config.ppo_mini_batch_size)
for _ in range(self.config.ppo_epochs):
  	# 数据可使用多轮ppo更新
    for batch_idx, mini_batch in enumerate(mini_batches):
        # 计算梯度累计步数
        self.gradient_accumulation = (
            self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
        )
        # 划分成固定大小的micro_batchs,循环多个micro_batches才进行1次梯度更新
        micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)

        self.critic_optimizer.zero_grad()

        for micro_batch in micro_batches:
            micro_batch_metrics = {}
            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
            # 标记出哪些是回复内容,只对回复内容做critic loss
            response_mask = model_inputs["response_mask"]
            # 旧价值,裁剪参考
            values = model_inputs["values"]
            # 目标回报
            returns = model_inputs["returns"]
            # 当前critic的价值预测
            vpreds = self._forward_micro_batch(model_inputs)
            # value funtion loss,MSE loss,以及有多少比例的样本被clip
            vf_loss, vf_clipfrac = core_algos.compute_value_loss(
                vpreds=vpreds,
                values=values,
                returns=returns,
                response_mask=response_mask,
                cliprange_value=self.config.cliprange_value,
                loss_agg_mode=self.config.loss_agg_mode,
            )
            loss = vf_loss / self.gradient_accumulation
            loss.backward()
            micro_batch_metrics.update(
                {
                    "critic/vf_loss": vf_loss.detach().item(),
                    "critic/vf_clipfrac": vf_clipfrac.detach().item(),
                    "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
                }
            )

            append_to_dict(metrics, micro_batch_metrics)

        grad_norm = self._optimizer_step()
        mini_batch_metrics = {"critic/grad_norm": grad_norm.detach().item()}
        append_to_dict(metrics, mini_batch_metrics)
self.critic_optimizer.zero_grad()
python
def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    cliprange_value: float,
    loss_agg_mode: str = "token-mean",
):
  	# 基于旧values对预测价值做clip,防止更新幅度过大,和PPO clip 类似
    vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
    # 未裁剪的loss
    vf_losses1 = (vpreds - returns) ** 2
    # 裁剪后的loss
    vf_losses2 = (vpredclipped - returns) ** 2
    # 取大值,选择更大的惩罚,避免因为裁切而意外降低loss
    clipped_vf_losses = torch.max(vf_losses1, vf_losses2)
    # 0.5 MSE常见1/2因子,与梯度形式对应;只统计response_mask=1的位置,仅对回复token训练,忽略指令部分
    vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    # 统计裁剪占比
    vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
    return vf_loss, vf_clipfrac
RewardModelWorker

核心思想

  • 基于模型的RM打分

关键函数

  • compute_rm_score

compute_rm_score

python
# RewardModelWorker.compute_rm_score
output = []
for micro_batch in micro_batches:
    rm_score = self._forward_micro_batch(micro_batch)
    output.append(rm_score)
scores = torch.cat(output, dim=0)  # (batch_size)
python
# RewardModelWorker._forward_micro_batch
output = self.reward_module(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False)
rm_score = output.logits  # (batch_size, seq_len, 1)
rm_score = rm_score.squeeze(-1)
AsyncActorRolloutRefWorker

核心功能

  • 训练时rollout,
  • 核心函数:generate_sequences

支持不同backend

  • 原生rollout逻辑:logits->softmax->sampling
  • HF TGI rollout逻辑
  • VLLM rollout逻辑
    • 基于third_party修改的vllm engine进行推理
    • repeat 没有采用n_samples参数,而是直接repeat_interleave输入,多次生成
    • 为了保证重要性采样和KL散度的准确性,old_log_probs没有使用vllm引擎结果,而是用训练引擎一起计算
  • SGLang rollout 逻辑

megatron_workers.py

基于megatron后端实现的RL workers
  1. 基于megatron支持4D并行,DP、TP、SP、PP;
  2. 核心逻辑基本和FSDP版本一致,但是底层逻辑需要适配megatron框架

Single Controller 组件

Single Controller

Worker

  • 方便管理worker进程在workergroup进程组内部信息和资源分配等

Resource Pool

  • 管理资源池,包括池内节点和进程信息

WorkerGroup

  • 管理多个worker所组成的workergroup

Decortor

  • 定义各种worker数据分发和函数执行的装饰器

Ray

  • 去管理worker(WorkerDict)和workergroup(RayWorkerGroup)

Models 组件

Models组件

包含常见模型结构

  • transformers
  • llama
  • Qwen2

Utils 组件

Utils组件

Dataset

  • rl, sft, rm 数据集
  • 常见功能
    • 处理各数据集中的key
    • 取出parquet里的prompt序列
    • apply_chat_ml + tokenize后设为input_ids
  • verl的dataset和dataloader没有和训练过程强绑定,可在训练过程中轻松修改dataloader

reward_score

  • Math/code/search r1等不同的rule-grader
总访客数:   ·   总访问量:
PLM's Blog @ 2016 - 2025