Skip to content

Verl 训练流程源代码阅读

📅 发表于 2025/12/01
🔄 更新于 2025/12/01
👁️ -- 次访问
📝 0 字
0 分钟
verl
#main_ppo
#RayPPOTrainer.fit
#Rollout
#Reward 计算
#Old LogProb计算
#Ref LogProb计算
#CriticValues计算
#优势计算
#create_rl_dataset
#_build_messages
#_create_dataloader
#Validate
#generate_sequences
#val_reward_fn
#process_validation_metrics
#compute_advantage
#compute_gae_advantage_return
#compute_grpo_outcome_advantage
#MegatronPPOActor.compute_log_prob
#MegatronPPOActor.forward_backward_batch
#forward_step
#计算熵
#计算label logprobs
#Policy Loss 计算
#PG Loss计算
#compute_policy_loss_vanilla
#agg_loss
#token-mean
#seq-mean-token-mean
#MegatronCriticWorker.update_critic
#MegatronPPOCritic.update_critic、forward_backward_batch、forward_step、loss_func
#compute_value_loss
#core_algos.py
#value clip
#update_actor
#update_policy
#load_reward_manager
#NaiveRewardManager
#RayPPOTrainer.init_workers
#MegatronActorRolloutRefWorker
#MegatronActorRolloutRefWorker.init_model

训练主流程

训练主入口 (main_ppo)

main_ppo.py 主入口:

python
def run(self, config):
    """Assemble every PPO component from ``config`` and launch training.

    Order of work: log and resolve the config; register the worker roles
    (actor/rollout, critic, reward model, reference policy); validate the
    config; copy the model locally; build the tokenizer/processor and the
    train/validation reward managers; build the resource pool, datasets and
    sampler; finally construct a ``RayPPOTrainer``, start its workers and
    call ``fit()``.
    """
    from pprint import pprint

    from omegaconf import OmegaConf

    from verl.utils.fs import copy_to_local

    print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
    pprint(OmegaConf.to_container(config, resolve=True))
    OmegaConf.resolve(config)

    # ---- register worker roles ------------------------------------------
    actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)
    self.add_critic_worker(config)
    # Upstream intends a multi-source reward here: rule-based scores, a
    # model-based RM, and a sandbox for code prompts with test cases, chosen
    # by the data's tag and combined at the end.
    self.add_reward_model_worker(config)
    # Reference policy worker, needed when KL loss or KL reward is used.
    self.add_ref_policy_worker(config, actor_rollout_cls)

    validate_config(
        config=config,
        use_reference_policy=need_reference_policy(self.role_worker_mapping),
        use_critic=need_critic(config),
    )

    # ---- model files, tokenizer, processor ------------------------------
    model_cfg = config.actor_rollout_ref.model
    # Copy the checkpoint (e.g. from HDFS) to the local machine; `use_shm`
    # opts into shared memory, which can make model loading faster.
    local_path = copy_to_local(model_cfg.path, use_shm=model_cfg.get("use_shm", False))

    from verl.utils import hf_processor, hf_tokenizer

    remote_code_ok = config.data.get("trust_remote_code", False)
    tokenizer = hf_tokenizer(local_path, trust_remote_code=remote_code_ok)
    # Only used by multimodal LLMs; may be None.
    processor = hf_processor(local_path, trust_remote_code=remote_code_ok, use_fast=True)

    # ---- reward managers: training (num_examine=0), validation (=1) -----
    # If training and validation need different reward functions, customise
    # load_reward_manager so that each call loads its own fn.
    reward_fn, val_reward_fn = (
        load_reward_manager(config, tokenizer, num_examine=examine, **config.reward_model.get("reward_kwargs", {}))
        for examine in (0, 1)
    )

    resource_pool_manager = self.init_resource_pool_mgr(config)
    from verl.utils.dataset.rl_dataset import collate_fn

    def _dataset(split_files, is_train, max_samples_key):
        # One RL dataset split; a missing max-samples key falls back to -1.
        return create_rl_dataset(
            split_files,
            config.data,
            tokenizer,
            processor,
            is_train=is_train,
            max_samples=config.data.get(max_samples_key, -1),
        )

    train_dataset = _dataset(config.data.train_files, True, "train_max_samples")
    val_dataset = _dataset(config.data.val_files, False, "val_max_samples")
    train_sampler = create_rl_sampler(config.data, train_dataset)

    # ---- trainer ----------------------------------------------------------
    trainer = RayPPOTrainer(
        config=config,
        tokenizer=tokenizer,
        processor=processor,
        role_worker_mapping=self.role_worker_mapping,
        resource_pool_manager=resource_pool_manager,
        ray_worker_group_cls=ray_worker_group_cls,
        reward_fn=reward_fn,
        val_reward_fn=val_reward_fn,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        collate_fn=collate_fn,
        train_sampler=train_sampler,
    )
    trainer.init_workers()
    trainer.fit()

主流程 (RayPPOTrainer.fit)

核心流程

RayPPOTrainer.fit()

0. 核心思想

  • RL 的核心循环:采样(Rollout)→ 评估(Evaluate)→ 学习(Learn),下面三步依次对应。

1. Rollout (生成/采样)

  • 使用当前策略 Actor π_θ 对 batch 中的 prompts 采样生成 responses(即与环境交互)。
  • 同步模式:gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output);异步模式:gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)

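2. Evaluate (评估)

  • 为 rollout 结果打分并准备训练所需的各项数据。依次计算:
    – reward(reward function / reward model);
    – old log prob(bypass 模式下直接使用 rollout log prob);
    – ref log prob(使用参考策略时);
    – critic values(使用 critic 时)。
  • 最后在 driver 上计算 advantage。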

3. Learn (学习/更新)

  • 若使用 critic,先更新 critic。
  • global_steps >= trainer.critic_warmup 之后,再更新 actor。

代码

python
def fit(self):
  """
  The training loop of PPO.
  The driver process only need to call the compute functions of the worker group through RPC
  to construct the PPO dataflow.
  The light-weight advantage computation is done on the driver process.

  Each step in this listing runs, in order:
    1. rollout;
    2. reward;
    3. old log-prob (or rollout correction in bypass mode);
    4. ref log-prob;
    5. critic values;
    6. advantage;
    7. critic update;
    8. actor update (once critic warmup is over).
  It then runs optional validation, checkpointing, profiling and metric
  logging. Returns after the last step, or right after the initial
  validation when ``trainer.val_only`` is set.
  """

  self.global_steps = 0
  # NOTE(review): presumably restores self.global_steps (and model state) when resuming — confirm.
  self._load_checkpoint()

  # Epoch to resume from, derived from the (possibly restored) step counter.
  current_epoch = self.global_steps // len(self.train_dataloader)

  # perform validation before training
  # currently, we only support validation using the reward_function.
  if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
      val_metrics = self._validate()
      logger.log(data=val_metrics, step=self.global_steps)
      if self.config.trainer.get("val_only", False):
          return

  # When rollout.skip_rollout is set, generate_sequences is wrapped by RolloutSkip
  # (see RolloutSkip for what is reused/skipped).
  if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
      rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
      rollout_skip.wrap_generate_sequences()

  # add tqdm
  progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

  # we start from step 1
  self.global_steps += 1
  last_val_metrics = None
  # Longest step duration seen so far; feeds the ESI-expiration checkpoint check below.
  self.max_steps_duration = 0

  # Profiling flags; global_profiler.steps (if set) lists the steps to profile.
  # NOTE(review): prev_step_profile is assigned but never read in this listing.
  prev_step_profile = False
  curr_step_profile = (
      self.global_steps in self.config.global_profiler.steps
      if self.config.global_profiler.steps is not None
      else False
  )
  next_step_profile = False

  for epoch in range(current_epoch, self.config.trainer.total_epochs):
      for batch_dict in self.train_dataloader:
          metrics = {}
          timing_raw = {}
          batch: DataProto = DataProto.from_single_dict(batch_dict)

          # add uid to batch: one id per prompt. The batch is repeated further below, so all
          # responses of a prompt share its uid (the NOTE below says advantage grouping uses uid).
          batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)

          # Batch handed to the rollout, built by _get_gen_batch.
          gen_batch = self._get_gen_batch(batch)

          # pass global_steps to trace
          gen_batch.meta_info["global_steps"] = self.global_steps
          # Repeat every prompt rollout.n times so each prompt gets n sampled responses
          # (interleave=True — presumably keeps copies of one prompt adjacent).
          gen_batch_output = gen_batch.repeat(
              repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
          )

          is_last_step = self.global_steps >= self.total_training_steps
          # Everything inside this timer is recorded as timing_raw["step"] (read after the step).
          with marked_timer("step", timing_raw):
              # generate a batch
              with marked_timer("gen", timing_raw, color="red"):
                  # Rollout / environment interaction: sample responses with the current policy,
                  # via the worker group (sync) or the async rollout manager.
                  if not self.async_rollout_mode:
                      gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
                  else:
                      gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)

              # repeat to align with repeated responses in rollout
              batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
              # Merge the generated outputs into the repeated batch.
              batch = batch.union(gen_batch_output)

              if "response_mask" not in batch.batch.keys():
                  batch.batch["response_mask"] = compute_response_mask(batch)
              # Balance the number of valid tokens across DP ranks.
              # NOTE: This usually changes the order of data in the `batch`,
              # which won't affect the advantage calculation (since it's based on uid),
              # but might affect the loss calculation (due to the change of mini-batching).
              if self.config.trainer.balance_batch:
                  self._balance_batch(batch, metrics=metrics)

              # compute global_valid tokens: attention_mask summed over its last dim
              # (one count per sample, assuming a (batch, seq) mask — TODO confirm).
              batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

              with marked_timer("reward", timing_raw, color="yellow"):
                  # compute reward model score
                  if self.use_rm and "rm_scores" not in batch.batch.keys():
                      reward_tensor = self.rm_wg.compute_rm_score(batch)
                      batch = batch.union(reward_tensor)

                  # Reward function: either a Ray remote task (collected with ray.get in the
                  # "adv" section) or computed synchronously right here.
                  if self.config.reward_model.launch_reward_fn_async:
                      future_reward = compute_reward_async.remote(
                          data=batch, config=self.config, tokenizer=self.tokenizer
                      )
                  else:
                      reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

              # Operating Mode Selection:
              # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ)
              # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ)
              #   Note: π_old computed once per data batch, serves as stable reference during mini-batch updates
              rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
              # Truthy only when rollout_correction is configured with bypass_mode enabled.
              bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False)
              if bypass_recomputing_logprobs:  # Use `rollout_log_probs`
                  from verl.trainer.ppo.rollout_corr_helper import apply_rollout_correction

                  apply_rollout_correction(
                      batch=batch,
                      rollout_corr_config=rollout_corr_config,
                      policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss,
                  )
              else:
                  # Recompute old_log_probs, i.e. log π_θ_old (the proximal anchor of decoupled mode).
                  with marked_timer("old_log_prob", timing_raw, color="blue"):
                      old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                      # Per-token entropies also come back; aggregate them over response tokens
                      # with the actor's loss_agg_mode and log them as actor/entropy.
                      entropys = old_log_prob.batch["entropys"]
                      response_masks = batch.batch["response_mask"]
                      loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
                      entropy_agg = agg_loss(
                          loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode
                      )
                      old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
                      metrics.update(old_log_prob_metrics)
                      # Entropies are only a metric; drop them before merging into the batch.
                      old_log_prob.batch.pop("entropys")
                      batch = batch.union(old_log_prob)
                      if "rollout_log_probs" in batch.batch.keys():
                          # TODO: we may want to add diff of probs too.
                          from verl.utils.debug.metrics import calculate_debug_metrics

                          metrics.update(calculate_debug_metrics(batch))

              # NOTE(review): the message says "old_log_prob" but the key checked is "old_log_probs".
              assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}'

              if self.use_reference_policy:
                  # compute reference log_prob, log π_ref
                  with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
                      # Dedicated ref worker group, or the ref model hosted in the actor workers.
                      if not self.ref_in_actor:
                          ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                      else:
                          ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
                      batch = batch.union(ref_log_prob)

              # compute values
              if self.use_critic:
                  with marked_timer("values", timing_raw, color="cyan"):
                      # ppo critic
                      values = self.critic_wg.compute_values(batch)
                      batch = batch.union(values)

              with marked_timer("adv", timing_raw, color="brown"):
                  # we combine with rule-based rm
                  reward_extra_infos_dict: dict[str, list]
                  if self.config.reward_model.launch_reward_fn_async:
                      # Collect the asynchronous reward task launched in the "reward" section.
                      reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
                  # Raw reward scores from the reward function / reward model.
                  batch.batch["token_level_scores"] = reward_tensor

                  if reward_extra_infos_dict:
                      batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})

                  # compute rewards. apply_kl_penalty if available
                  if self.config.algorithm.use_kl_in_reward:
                      # Fold a KL penalty into the reward (apply_kl_penalty presumably writes
                      # token_level_rewards — confirm).
                      batch, kl_metrics = apply_kl_penalty(
                          batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
                      )
                      metrics.update(kl_metrics)
                  else:
                      # No KL in the reward: rewards are the raw scores.
                      batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                  # Compute rollout correction: IS weights, rejection sampling, and metrics
                  # Only runs in decoupled mode (computes once per batch using stable π_old)
                  # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout
                  if (
                      rollout_corr_config is not None
                      and "rollout_log_probs" in batch.batch
                      and not bypass_recomputing_logprobs  # Only in decoupled mode
                  ):
                      from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch

                      # Compute IS weights, apply rejection sampling, compute metrics
                      batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
                      # IS and off-policy metrics already have rollout_corr/ prefix
                      metrics.update(is_metrics)

                  # compute advantages, executed on the driver process
                  norm_adv_by_std_in_grpo = self.config.algorithm.get(
                      "norm_adv_by_std_in_grpo", True
                  )  # GRPO adv normalization factor
                  # Advantage with the configured estimator (algorithm.adv_estimator).
                  batch = compute_advantage(
                      batch,
                      adv_estimator=self.config.algorithm.adv_estimator,
                      gamma=self.config.algorithm.gamma,
                      lam=self.config.algorithm.lam,
                      num_repeat=self.config.actor_rollout_ref.rollout.n,
                      norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                      config=self.config.algorithm,
                  )

              # update critic
              if self.use_critic:
                  with marked_timer("update_critic", timing_raw, color="pink"):
                      critic_output = self.critic_wg.update_critic(batch)
                  critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
                  metrics.update(critic_output_metrics)

              # implement critic warmup: the actor is only updated once global_steps reaches
              # trainer.critic_warmup, so the critic trains alone for the first steps.
              if self.config.trainer.critic_warmup <= self.global_steps:
                  # update actor
                  with marked_timer("update_actor", timing_raw, color="red"):
                      rollout_config = self.config.actor_rollout_ref.rollout
                      batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
                      # TODO: Make "temperature" single source of truth from generation.
                      batch.meta_info["temperature"] = rollout_config.temperature
                      # update actor
                      actor_output = self.actor_rollout_wg.update_actor(batch)
                  actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                  metrics.update(actor_output_metrics)

              # Log rollout generations if enabled
              rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
              if rollout_data_dir:
                  self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)

          # validate
          if (
              self.val_reward_fn is not None
              and self.config.trainer.test_freq > 0
              and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
          ):
              with marked_timer("testing", timing_raw, color="green"):
                  val_metrics: dict = self._validate()
                  if is_last_step:
                      last_val_metrics = val_metrics
              metrics.update(val_metrics)

          # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
          esi_close_to_expiration = should_save_ckpt_esi(
              max_steps_duration=self.max_steps_duration,
              redundant_time=self.config.trainer.esi_redundant_time,
          )
          # Check if the conditions for saving a checkpoint are met.
          # The conditions include a mandatory condition (1) and
          # one of the following optional conditions (2/3/4):
          # 1. The save frequency is set to a positive value.
          # 2. It's the last training step.
          # 3. The current step number is a multiple of the save frequency.
          # 4. The ESI(Elastic Server Instance)/training plan is close to expiration.
          if self.config.trainer.save_freq > 0 and (
              is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
          ):
              if esi_close_to_expiration:
                  print("Force saving checkpoint: ESI instance expiration approaching.")
              with marked_timer("save_checkpoint", timing_raw, color="green"):
                  self._save_checkpoint()

          # Stop profiling for this step unless continuous profiling covers the next step too.
          with marked_timer("stop_profile", timing_raw):
              next_step_profile = (
                  self.global_steps + 1 in self.config.global_profiler.steps
                  if self.config.global_profiler.steps is not None
                  else False
              )
              self._stop_profiling(
                  curr_step_profile and not next_step_profile
                  if self.config.global_profiler.profile_continuous_steps
                  else curr_step_profile
              )
              prev_step_profile = curr_step_profile
              curr_step_profile = next_step_profile

          steps_duration = timing_raw["step"]
          self.max_steps_duration = max(self.max_steps_duration, steps_duration)

          # training metrics
          metrics.update(
              {
                  "training/global_step": self.global_steps,
                  "training/epoch": epoch,
              }
          )
          # collect metrics
          metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
          metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
          # TODO: implement actual tflpo and theoretical tflpo
          n_gpus = self.resource_pool_manager.get_n_gpus()
          metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
          # Note: in this listing, rollout-correction (IS / off-policy) metrics were already merged
          # into `metrics` inside the "adv" section above. The original upstream comment pointed at a
          # line number ("line 1179") that does not apply here.

          # this is experimental and may be changed/removed in the future in favor of a general-purpose one
          if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
              self.train_dataloader.sampler.update(batch=batch)

          # TODO: make a canonical logger that supports various backend
          logger.log(data=metrics, step=self.global_steps)

          progress_bar.update(1)
          self.global_steps += 1

          if (
              hasattr(self.config.actor_rollout_ref.actor, "profiler")
              and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
          ):
              self.actor_rollout_wg.dump_memory_snapshot(
                  tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
              )

          if is_last_step:
              pprint(f"Final validation metrics: {last_val_metrics}")
              progress_bar.close()
              return

          # this is experimental and may be changed/removed in the future
          # in favor of a general-purpose data buffer pool
          if hasattr(self.train_dataset, "on_batch_end"):
              # The dataset may be changed after each training batch
              self.train_dataset.on_batch_end(batch=batch)

训练过程关键前置计算 (RayPPOTrainer)

Rollout

python
# Rollout excerpt: sample responses for the prompts with the current policy.
with marked_timer("gen", timing_raw, color="red"):
  # Sync mode calls the actor/rollout worker group; async mode goes through the async rollout manager.
  if not self.async_rollout_mode:
      gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
  else:
      gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
  # Merge the generator's own timing breakdown into timing_raw, then remove it from meta_info.
  timing_raw.update(gen_batch_output.meta_info["timing"])
  gen_batch_output.meta_info.pop("timing", None)

Reward 计算

python
# Reward excerpt: optional reward-model score, then the reward function (sync or as a Ray remote task).
with marked_timer("reward", timing_raw, color="yellow"):
  # compute reward model score (skipped when the batch already carries "rm_scores")
  if self.use_rm and "rm_scores" not in batch.batch.keys():
      reward_tensor = self.rm_wg.compute_rm_score(batch)
      batch = batch.union(reward_tensor)
  if self.config.reward_model.launch_reward_fn_async:
      # Launched asynchronously; the future is resolved later with ray.get in the advantage step.
      future_reward = compute_reward_async.remote(
          data=batch, config=self.config, tokenizer=self.tokenizer
      )
  else:
      reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

Old LogProb 计算

python
# Old log-prob excerpt: recompute log π_θ_old for the sampled responses with the actor.
with marked_timer("old_log_prob", timing_raw, color="blue"):
  old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
  # Per-token entropies are returned too; aggregate them over response tokens
  # with the actor's loss_agg_mode and log the result as actor/entropy.
  entropys = old_log_prob.batch["entropys"]
  response_masks = batch.batch["response_mask"]
  loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
  entropy_agg = agg_loss(
      loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode
  )
  old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
  metrics.update(old_log_prob_metrics)
  # Entropies are only a metric here; drop them before merging into the batch.
  old_log_prob.batch.pop("entropys")
  batch = batch.union(old_log_prob)
  if "rollout_log_probs" in batch.batch.keys():
      # Rollout-engine log-probs are present: add debug metrics (presumably comparing
      # them with the recomputed ones — see calculate_debug_metrics).
      # TODO: we may want to add diff of probs too.
      from verl.utils.debug.metrics import calculate_debug_metrics

      metrics.update(calculate_debug_metrics(batch))

Ref LogProb 计算

python
# Reference log-prob excerpt (log π_ref), only when a reference policy is used.
if self.use_reference_policy:
  # compute reference log_prob
  with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
      # Either a dedicated ref worker group, or the ref model hosted inside the actor workers.
      if not self.ref_in_actor:
          ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
      else:
          ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
      batch = batch.union(ref_log_prob)

Critic Values 计算 (PPO使用)

python
# compute values: critic predictions merged into the batch, only when a critic is used
if self.use_critic:
    with marked_timer("values", timing_raw, color="cyan"):
        values = self.critic_wg.compute_values(batch)
        batch = batch.union(values)

优势计算

python
# Advantage excerpt: finalize rewards (optional KL penalty, optional rollout correction), then compute advantages.
with marked_timer("adv", timing_raw, color="brown"):
  # we combine with rule-based rm
  reward_extra_infos_dict: dict[str, list]
  if self.config.reward_model.launch_reward_fn_async:
      # Collect the asynchronous reward task launched in the reward step.
      reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
  # Raw reward scores from the reward function / reward model.
  batch.batch["token_level_scores"] = reward_tensor

  # Extra per-sample reward info goes into non_tensor_batch as numpy arrays.
  if reward_extra_infos_dict:
      batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})

  # compute rewards. apply_kl_penalty if available
  if self.config.algorithm.use_kl_in_reward:
      batch, kl_metrics = apply_kl_penalty(
          batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
      )
      metrics.update(kl_metrics)
  else:
      # No KL in the reward: rewards are the raw scores.
      batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

  # Compute rollout correction: IS weights, rejection sampling, and metrics
  # Only runs in decoupled mode (computes once per batch using stable π_old)
  # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout
  if (
      rollout_corr_config is not None
      and "rollout_log_probs" in batch.batch
      and not bypass_recomputing_logprobs  # Only in decoupled mode
  ):
      from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch

      # Compute IS weights, apply rejection sampling, compute metrics
      batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
      # IS and off-policy metrics already have rollout_corr/ prefix
      metrics.update(is_metrics)

  # compute advantages, executed on the driver process
  norm_adv_by_std_in_grpo = self.config.algorithm.get(
      "norm_adv_by_std_in_grpo", True
  )  # GRPO adv normalization factor

  # Compute advantages with the configured estimator (algorithm.adv_estimator).
  batch = compute_advantage(
      batch,
      adv_estimator=self.config.algorithm.adv_estimator,
      gamma=self.config.algorithm.gamma,
      lam=self.config.algorithm.lam,
      num_repeat=self.config.actor_rollout_ref.rollout.n,
      norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
      config=self.config.algorithm,
  )

训练配置参数验证

Verl 基础参数验证

相关并行参数计算
  • n_gpus = n_gpus_per_node * nnodes
  • model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
  • megatron_dp = n_gpus // (model_parallel_size * context_parallel_size)
  • minimal_bsz = megatron_dp * actor.ppo_micro_batch_size_per_gpu(megatron 策略;其他策略下 minimal_bsz = n_gpus)
  • real_train_batch_size = data.train_batch_size * rollout.n
约束条件

条件1:实际训练 batch size(real_train_batch_size)能被 minimal_bsz 整除(仅在 actor.use_dynamic_bsz=False 时检查)

  • real_train_batch_size % minimal_bsz == 0

条件2:GPU 数量能被并行规模整除(megatron 策略)

  • n_gpus % (model_parallel_size * context_parallel_size) == 0 ,非常重要!!(括号不能省:% 与 * 同优先级、从左到右结合)

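条件3:mbs 和 mbs_per_gpu 不能同时设置(也不能都不设置)

  • xx.log_prob_micro_batch_size 与 xx.log_prob_micro_batch_size_per_gpu 必须且只能设置一个。推荐使用 *_per_gpu,前者已废弃。
  • xx 主要是这 3 种:actor_rollout_ref.rollout、actor_rollout_ref.ref、reward_model。其中 reward_model 对应的参数名是 micro_batch_size / micro_batch_size_per_gpu。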

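条件4:验证actor_config(ActorConfig.validate,仅在 use_dynamic_bsz=False 时生效)

  • train_batch_size >= ppo_mini_batch_size
  • ppo_mini_batch_size % ppo_micro_batch_size == 0(仅当设置了 ppo_micro_batch_size)
  • ppo_micro_batch_size * sp_size >= n_gpus(同上;sp_size 即 ulysses_sequence_parallel_size,默认 1)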

条件5:验证critic_config(仅在使用 critic 时,调用 critic_config.validate(n_gpus, data.train_batch_size))

python
def validate_config(
    config: DictConfig,
    use_reference_policy: bool,
    use_critic: bool,
) -> None:
    """Check a PPO OmegaConf ``DictConfig`` for inconsistent settings.

    Args:
        config (DictConfig): Full training config to check.
        use_reference_policy (bool): Whether a reference policy worker is used;
            enables the ``actor_rollout_ref.ref`` micro-batch check.
        use_critic (bool): Whether a critic is used; enables critic validation.

    Raises:
        AssertionError: On a failed divisibility, temperature or LoRA-rank check.
        ValueError: From the micro-batch mutual-exclusion check, from
            ``ActorConfig.validate``, and possibly from the critic config's
            ``validate``.
    """
    n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
    actor_cfg = config.actor_rollout_ref.actor

    # -- 1. the real train batch must split evenly into the smallest possible batch --
    if not actor_cfg.use_dynamic_bsz:
        if actor_cfg.strategy != "megatron":
            minimal_bsz = n_gpus
        else:
            megatron_cfg = actor_cfg.megatron
            model_parallel_size = megatron_cfg.tensor_model_parallel_size * megatron_cfg.pipeline_model_parallel_size
            mp_times_cp = model_parallel_size * megatron_cfg.context_parallel_size
            assert n_gpus % mp_times_cp == 0, (
                f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times "
                f"context_parallel_size ({megatron_cfg.context_parallel_size})"
            )
            # Data-parallel size: GPUs left after TP * PP * CP.
            megatron_dp = n_gpus // mp_times_cp
            minimal_bsz = megatron_dp * actor_cfg.ppo_micro_batch_size_per_gpu

        real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
        assert real_train_batch_size % minimal_bsz == 0, (
            f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size "
            f"({minimal_bsz})"
        )

    # Deprecated micro-batch parameter name for each section that has one.
    mbs_param_by_section = {
        "reward_model": "micro_batch_size",
        "actor_rollout_ref.ref": "log_prob_micro_batch_size",
        "actor_rollout_ref.rollout": "log_prob_micro_batch_size",
    }

    def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
        """Require exactly one of ``<param>`` and ``<param>_per_gpu`` for section ``name``.

        Args:
            mbs: Value of the deprecated micro-batch-size option.
            mbs_per_gpu: Value of the ``*_per_gpu`` replacement option.
            name (str): Config section; sections without an entry are not checked.

        Raises:
            ValueError: If neither option is set, or if both are set.
        """
        param = mbs_param_by_section.get(name)
        if param is None:
            return
        param_per_gpu = f"{param}_per_gpu"

        if mbs is None and mbs_per_gpu is None:
            raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")
        if mbs is not None and mbs_per_gpu is not None:
            raise ValueError(
                f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove "
                f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
            )

    # -- 2. actor checks live on the dataclass (ActorConfig.__post_init__ / validate()) --
    actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor)
    actor_config.validate(n_gpus, config.data.train_batch_size, config.actor_rollout_ref.model)

    # -- 3. micro-batch option conflicts --
    if not actor_cfg.use_dynamic_bsz:
        if use_reference_policy:
            ref_cfg = config.actor_rollout_ref.ref
            check_mutually_exclusive(
                ref_cfg.log_prob_micro_batch_size,
                ref_cfg.log_prob_micro_batch_size_per_gpu,
                "actor_rollout_ref.ref",
            )
        rollout_cfg = config.actor_rollout_ref.rollout
        check_mutually_exclusive(
            rollout_cfg.log_prob_micro_batch_size,
            rollout_cfg.log_prob_micro_batch_size_per_gpu,
            "actor_rollout_ref.rollout",
        )

    reward_cfg = config.reward_model
    if reward_cfg.enable and not reward_cfg.use_dynamic_bsz:
        check_mutually_exclusive(reward_cfg.micro_batch_size, reward_cfg.micro_batch_size_per_gpu, "reward_model")

    if config.algorithm.use_kl_in_reward and actor_cfg.use_kl_loss:
        print("NOTICE: You have both enabled in-reward kl and kl loss.")

    # -- 4. critic --
    if use_critic:
        omega_conf_to_dataclass(config.critic).validate(n_gpus, config.data.train_batch_size)

    if config.data.get("val_batch_size", None) is not None:
        print(
            "WARNING: val_batch_size is deprecated."
            " Validation datasets are sent to inference engines as a whole batch,"
            " which will schedule the memory themselves."
        )

    # -- 5. evaluation sampling and LoRA limits --
    rollout_cfg = config.actor_rollout_ref.rollout
    if rollout_cfg.val_kwargs.do_sample:
        assert rollout_cfg.temperature > 0, (
            "validation gen temperature should be greater than 0 when enabling do_sample"
        )

    model_cfg = config.actor_rollout_ref.model
    if model_cfg.get("lora_rank", 0) > 0 and rollout_cfg.name == "vllm":
        assert model_cfg.lora_rank <= 512, "LoRA rank in vLLM must be less than or equal to 512"

    print("[validate_config] All configuration checks passed successfully!")

ActorConfig.validate:

python
def validate(self, n_gpus: int, train_batch_size: int, model_config: dict = None):
  """Validate actor configuration with runtime parameters."""
  if not self.use_dynamic_bsz:
      if train_batch_size < self.ppo_mini_batch_size:  
          raise ValueError(
              f"train_batch_size ({train_batch_size}) must be >= "
              f"actor.ppo_mini_batch_size ({self.ppo_mini_batch_size})"
          )

      sp_size = getattr(self, "ulysses_sequence_parallel_size", 1)
      if self.ppo_micro_batch_size is not None:
          if self.ppo_mini_batch_size % self.ppo_micro_batch_size != 0: 
              raise ValueError(
                  f"ppo_mini_batch_size ({self.ppo_mini_batch_size}) must be divisible by "
                  f"ppo_micro_batch_size ({self.ppo_micro_batch_size})"
              )
          if self.ppo_micro_batch_size * sp_size < n_gpus:  
              raise ValueError(
                  f"ppo_micro_batch_size ({self.ppo_micro_batch_size}) * "
                  f"ulysses_sequence_parallel_size ({sp_size}) must be >= n_gpus ({n_gpus})"

Megatron 并行参数验证

详细内容具体见文末TransformerConfig 并行参数

并行参数验证

TP参数

  • num_attention_heads % tensor_model_parallel_size == 0

PP参数

  • num_layers % pipeline_model_parallel_size == 0

EP参数

  • num_moe_experts % moe_router_num_groups == 0

CP参数

  • n_gpus % (model_parallel_size * context_parallel_size) == 0(注意括号:% 与 * 优先级相同、从左到右结合,不加括号含义会变),非常重要!!

关键项

数据加载

主入口main_ppo 调用create_rl_dataset

main_ppo.py 主入口:

python
def run(self, config):
  """Excerpt of the TaskRunner entry point: build the training and validation RL datasets.

  NOTE(review): `tokenizer` and `processor` are not defined in this excerpt; they are
  presumably created in the elided part of the method (the "# ...." line) — confirm.
  """
  # .... (earlier setup elided in this excerpt)
  # Create training and validation datasets.
  # Both calls share config.data, tokenizer and processor; they differ in the source
  # files, the is_train flag, and the max_samples value, which falls back to -1 when
  # the corresponding key is absent from config.data.
  train_dataset = create_rl_dataset( 
      config.data.train_files, 
      config.data,
      tokenizer,
      processor,
      is_train=True,
      max_samples=config.data.get("train_max_samples", -1),
  )
  val_dataset = create_rl_dataset( 
      config.data.val_files,
      config.data,
      tokenizer,
      processor,
      is_train=False,
      max_samples=config.data.get("val_max_samples", -1),
  )

create_rl_dataset:读取数据集、过滤超长数据

rl_dataset.py:读取完整数据集,并过滤超长数据(配置参数 data.max_prompt_length)

python
def _read_files_and_tokenize(self):
    """Load all parquet files in ``self.data_files`` into one dataset, then filter long prompts.

    Each file is loaded with ``datasets.load_dataset("parquet", ...)`` and its ``"train"``
    split is taken. The splits are concatenated into ``self.dataframe``, the resulting size
    is printed, and ``self.maybe_filter_out_long_prompts`` is applied to the result.
    """
    dataframes = [
        datasets.load_dataset("parquet", data_files=parquet_file)["train"]
        for parquet_file in self.data_files
    ]
    self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)

    print(f"dataset len: {len(self.dataframe)}")
    self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)
  
def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None):
  """Drop rows whose chat-templated prompt is longer than ``self.max_prompt_length``.

  Does nothing (returns ``dataframe`` unchanged) unless ``self.filter_overlong_prompts`` is set.

  Args:
      dataframe: the dataset to filter.

  Returns:
      The filtered dataset (or the input unchanged when filtering is disabled).

  NOTE(review): the multimodal branch (``processor is not None``) is elided with ``pass`` in
  this excerpt; as written, ``doc2len`` is then undefined and the ``filter`` call below
  would raise NameError.
  """
  # filter out too long prompts
  if self.filter_overlong_prompts:
      tokenizer = self.tokenizer
      processor = self.processor
      prompt_key = self.prompt_key

      if processor is not None:
          pass  # multimodal length computation elided in this excerpt
      else:
          def doc2len(doc) -> int:
              # Length of the prompt after applying the chat template with a generation prompt
              # (and the tool schemas, when configured). Presumably a list of token ids, since
              # tokenize is not disabled here — confirm against the tokenizer in use.
              try:
                  # Copy so that adding "tools" does not mutate the shared kwargs.
                  apply_kwargs = dict(**self.apply_chat_template_kwargs)
                  if self.tool_schemas is not None:
                      apply_kwargs["tools"] = self.tool_schemas

                  return len(
                      tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True, **apply_kwargs) 
                  )
              except Exception:
                  # Best effort: report the failure and return a length just over the limit,
                  # so the filter below drops this sample instead of crashing.
                  print("Error processing one of the samples, skipping...")
                  traceback.print_exc()
                  return self.max_prompt_length + 1

      # Keep only rows whose templated length fits within max_prompt_length.
      dataframe = dataframe.filter( 
          lambda doc: doc2len(doc) <= self.max_prompt_length, 
          num_proc=self.num_workers,
          desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
      )
      print(f"filter dataset len: {len(dataframe)}")
  return dataframe

训练时读取单条数据:构建messages、tokenize、pad、position_ids等

训练时读取单条数据
  • 根据prompt_key读取messages
  • 添加 generation prompt 模板
  • tokenize,获得 input_ids、attention_mask,并计算 position_ids
  • 根据 data.max_prompt_length 做左 padding 或截断
  • 处理 extra_info、tools_kwargs、interaction_kwargs 等内容。
python
def _build_messages(self, example: dict):
  messages: list = example.pop(self.prompt_key)

  if self.image_key in example or self.video_key in example:
      for message in messages:
          content = message["content"]
          content_list = []
          segments = re.split("(<image>|<video>)", content)
          segments = [item for item in segments if item != ""]
          for segment in segments:
              if segment == "<image>":
                  content_list.append({"type": "image"})
              elif segment == "<video>":
                  content_list.append({"type": "video"})
              else:
                  content_list.append({"type": "text", "text": segment})

          message["content"] = content_list

  return messages

def __getitem__(self, item):
    """Build one training sample for index ``item``.

    Steps:
      1. build the chat messages from ``self.prompt_key`` (``self._build_messages``);
      2. apply the chat template with a generation prompt and tokenize it;
      3. left-pad / truncate to ``self.max_prompt_length`` (``verl_F.postprocess_data``);
      4. compute ``position_ids`` from the attention mask;
      5. attach ``raw_prompt_ids`` (truncated per ``self.truncation``), optional raw/full
         prompts, and ``index`` / ``tools_kwargs`` / ``interaction_kwargs`` from ``extra_info``.

    Note that we also return the raw_input_ids so that it can be combined with other chat template.

    NOTE(review): the multimodal (``self.processor is not None``) branches are elided in this
    excerpt; as shown, ``input_ids``, ``raw_prompt`` and ``position_ids`` are only defined on the
    text-only path.
    """
    row_dict: dict = self.dataframe[item]
    messages = self._build_messages(row_dict)
    model_inputs = {}

    if self.processor is not None:
        pass  # multimodal path elided in this excerpt
    else:
        if self.apply_chat_template_kwargs.get("chat_template") is None:
            assert hasattr(self.tokenizer, "chat_template"), (
                "chat_template should be provided in apply_chat_template_kwargs or tokenizer config, "
                "models like GLM can copy chat_template.jinja from instruct models"
            )
        # Render the chat template (including the generation prompt) as text, then tokenize it.
        raw_prompt = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
        )
        model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
        input_ids = model_inputs.pop("input_ids")
        attention_mask = model_inputs.pop("attention_mask")

    # Left-pad or truncate to data.max_prompt_length.
    input_ids, attention_mask = verl_F.postprocess_data(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=self.max_prompt_length,
        pad_token_id=self.tokenizer.pad_token_id,
        left_pad=True,
        truncation=self.truncation,
    )

    if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
        # qwen-vl mrope
        from verl.models.transformers.qwen3_vl import get_rope_index
        # ... (mrope position_ids computation elided in this excerpt)
    elif self.processor is not None and "Glm4vImageProcessor" in self.processor.image_processor.__class__.__name__:
        from verl.models.transformers.glm4v import get_rope_index
        # ... (position_ids computation elided in this excerpt)
    else:
        # Text-only: derive position_ids from the attention mask.
        position_ids = compute_position_id_with_mask(attention_mask)

    # The tokenizer outputs carry a leading batch dimension of 1; store the single row.
    row_dict["input_ids"] = input_ids[0]
    row_dict["attention_mask"] = attention_mask[0]
    row_dict["position_ids"] = position_ids[0]

    raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
    if len(raw_prompt_ids) > self.max_prompt_length:
        if self.truncation == "left":
            raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
        elif self.truncation == "right":
            raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
        elif self.truncation == "middle":
            left_half = self.max_prompt_length // 2
            right_half = self.max_prompt_length - left_half
            raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
        elif self.truncation == "error":
            raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")

    row_dict["raw_prompt_ids"] = raw_prompt_ids
    # encode prompts without chat template
    if self.return_raw_chat:
        row_dict["raw_prompt"] = messages

    # get prompts with chat template
    if self.return_full_prompt:
        row_dict["full_prompts"] = raw_prompt  # array of strings

    # add index for each prompt
    if row_dict.get("extra_info") is None:
        row_dict["extra_info"] = dict()
    extra_info = row_dict["extra_info"]
    index = extra_info.get("index", 0)
    tools_kwargs = extra_info.get("tools_kwargs", {})
    interaction_kwargs = extra_info.get("interaction_kwargs", {})
    need_tools_kwargs = extra_info.get("need_tools_kwargs", self.need_tools_kwargs)
    if need_tools_kwargs and not tools_kwargs:
        # %-style placeholders: stdlib logging formats lazy args with `%`, not `{}`.
        logger.warning("tools_kwargs is empty for index %s, data source: %s", index, row_dict["data_source"])
    row_dict["index"] = index
    row_dict["tools_kwargs"] = tools_kwargs
    row_dict["interaction_kwargs"] = interaction_kwargs
    return row_dict

RayPPOTrainer:create data loader

DataLoader
  • 根据train_batch_size, val_batch_size来划分构建DataLoader
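  • 总训练步数 = len(train_dataloader) × total_epochs,即 ⌊原始数据数量 / batch_size⌋ × 训练 epochs(drop_last=True;batch_size 优先取 gen_batch_size,否则取 train_batch_size);若配置了 trainer.total_training_steps,则以其为准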
python
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
    """Create the train and validation dataloaders and compute ``self.total_training_steps``.

    - The train loader batches by ``data.gen_batch_size`` (falling back to ``data.train_batch_size``)
      and drops the last incomplete batch.
    - The validation loader uses ``data.val_batch_size`` when set. Otherwise the whole validation set
      forms a single batch, because ``val_batch_size`` is deprecated and validation data is sent to
      the inference engine as a whole.
    - ``total_training_steps = len(train_dataloader) * trainer.total_epochs``, unless
      ``trainer.total_training_steps`` is set explicitly.
    """
    # TODO: we have to make sure the batch size is divisible by the dp size
    from verl.trainer.main_ppo import create_rl_sampler

    self.train_dataset, self.val_dataset = train_dataset, val_dataset
    if train_sampler is None:
        train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
    if collate_fn is None:
        from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn

        collate_fn = default_collate_fn

    num_workers = self.config.data["dataloader_num_workers"]

    self.train_dataloader = StatefulDataLoader(
        dataset=self.train_dataset,
        batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
        num_workers=num_workers,
        drop_last=True,
        collate_fn=collate_fn,
        sampler=train_sampler,
    )

    # Prefer an explicit val_batch_size; otherwise use one batch holding the whole validation set.
    # (batch_size=None would disable automatic batching and yield single samples instead.)
    val_batch_size = self.config.data.get("val_batch_size", None)
    if val_batch_size is None:
        val_batch_size = len(self.val_dataset)

    self.val_dataloader = StatefulDataLoader(
        dataset=self.val_dataset,
        batch_size=val_batch_size,
        num_workers=num_workers,
        shuffle=self.config.data.get("validation_shuffle", True),
        drop_last=False,
        collate_fn=collate_fn,
    )

    # total steps = (number of train batches per epoch) * epochs, unless overridden in config
    total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
    if self.config.trainer.total_training_steps is not None:
        total_training_steps = self.config.trainer.total_training_steps
    self.total_training_steps = total_training_steps
    # ... (remaining setup elided in this excerpt)

Validate

主流程 (RayPPOTrainer.Validate)

核心流程

1. 数据准备 (Data Preparation):

  • 遍历 self.val_dataloader
  • 为 batch 中的每个 prompt 生成唯一 uid(在复制之前生成,使同一 prompt 的 n 个回答共享 uid)
  • 每个 prompt 复制 n 份(rollout.val_kwargs.n),即 rollout n 次,生成 n 个不同的答案。

2. 模型生成/环境交互 (Generation):

  • 使用当前 Actor 对 batch 中的 prompts 进行推理,每个 prompt 生成 n 个答案。
    • 同步: self.actor_rollout_wg.generate_sequences
    • 异步:self.async_rollout_manager.generate_sequences ,调用agentloop。

3. 评估打分 (Evaluation & Scoring):

  • 把原始prompt和模型response拼接在一起
  • 调用 self.val_reward_fn 来为每一个生成的答案打分, reward_tensor
    • 请注意:如果agentloop里已经包含reward_score,则不会再计算,而是直接返回
  • 除了主分数(reward),评估函数可能还会返回一些额外的信息(reward_extra_info),
    • 比如准确率(acc)、长度惩罚项等,在这里可以添加其他的指标
    • 这些信息都被收集到 reward_extra_infos_dict 字典中。

4. 数据收集 (Data Collection):

  • 在整个循环过程中,代码会把所有需要的信息都存储在列表中:
    • sample_inputs: 原始输入 (Prompts)
    • sample_outputs: 模型生成的答案
    • sample_gts: 人工标注的参考答案 (Ground Truths)
    • sample_scores: 每个答案的总分
    • sample_uids: 每个样本的唯一ID
    • reward_extra_infos_dict: 包含所有评估指标(如 reward, acc)的字典。
    • data_source_lst: 每个样本来自哪个数据集(如 "wiki", "squad" 等)。

5. 指标处理与格式化 (Metric Processing & Formatting):

  • 收集完所有验证数据后,处理成结构化、有意义的统计指标
    • 调用 process_validation_metrics 来做统计。
    • 再把 process_validation_metrics 返回结果整理成扁平字典 metric_dict,Key为:
      • pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
      • val-core/wiki/reward/mean@8

6. 返回结果

  • 最后,函数返回 metric_dict,框架接收这些指标并做展示
python
def _validate(self):
    """Run one validation pass over ``self.val_dataloader`` and return a flat metric dict.

    For every validation batch:
      - ensure each prompt has a ``uid``;
      - repeat the batch ``val_kwargs.n`` times (interleaved);
      - generate responses, either with the sync worker group or the async rollout manager,
        padding to a divisor first;
      - score them with ``self.val_reward_fn``;
      - collect inputs, outputs, ground truths, scores, uids, extra reward info and data sources.

    ``process_validation_metrics`` then aggregates the collected values. Each metric is stored
    under ``"{val-core|val-aux}/{data_source}/{var_name}/{metric_name}"``.

    Returns:
        dict: metric name -> value. ``{}`` if a model-based reward model is enabled.

    Raises:
        ValueError: if ``self.val_reward_fn`` is None.
    """
    import uuid

    data_source_lst = []
    reward_extra_infos_dict: dict[str, list] = defaultdict(list)

    # Lists to collect samples for the table
    sample_inputs = []
    sample_outputs = []
    sample_gts = []
    sample_scores = []
    sample_turns = []
    sample_uids = []

    for test_data in self.val_dataloader:
        test_batch = DataProto.from_single_dict(test_data)

        # Give every prompt a unique uid BEFORE repeating, so that all n responses of one prompt
        # share a uid; sample_uids is later used to group responses per prompt.
        if "uid" not in test_batch.non_tensor_batch:
            test_batch.non_tensor_batch["uid"] = np.array(
                [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object
            )

        # repeat test batch
        test_batch = test_batch.repeat(
            repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
        )

        # we only do validation on rule-based rm
        if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
            return {}

        # prompt input_ids
        input_ids = test_batch.batch["input_ids"]
        input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
        sample_inputs.extend(input_texts)
        sample_uids.extend(test_batch.non_tensor_batch["uid"])

        ground_truths = [
            item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
        ]
        sample_gts.extend(ground_truths)
        # pop input_ids, attention_mask, position_ids
        test_gen_batch = self._get_gen_batch(test_batch)
        test_gen_batch.meta_info = {
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
            "recompute_log_prob": False,
            "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
            "validate": True,
            "global_steps": self.global_steps,
        }
        print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")

        # pad to be divisible by dp_size
        size_divisor = (
            self.actor_rollout_wg.world_size
            if not self.async_rollout_mode
            else self.config.actor_rollout_ref.rollout.agent.num_workers
        )
        # pad, then rollout
        test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
        if not self.async_rollout_mode:
            test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
        else:
            test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)

        # unpad
        test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)

        print("validation generation end")

        # Store generated outputs
        output_ids = test_output_gen_batch.batch["responses"]
        output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
        sample_outputs.extend(output_texts)

        # merge prompts with the generated outputs
        test_batch = test_batch.union(test_output_gen_batch)
        test_batch.meta_info["validate"] = True

        # evaluate using reward_function
        if self.val_reward_fn is None:
            raise ValueError("val_reward_fn must be provided for validation.")
        # Compute rewards (per the article, returned directly if rollout already stored them).
        result = self.val_reward_fn(test_batch, return_dict=True)
        reward_tensor = result["reward_tensor"]
        # Sum over the token dimension to get one score per response.
        scores = reward_tensor.sum(-1).cpu().tolist()
        sample_scores.extend(scores)

        reward_extra_infos_dict["reward"].extend(scores)
        if "reward_extra_info" in result:
            for key, lst in result["reward_extra_info"].items():
                reward_extra_infos_dict[key].extend(lst)

        # collect num_turns of each prompt
        if "__num_turns__" in test_batch.non_tensor_batch:
            sample_turns.append(test_batch.non_tensor_batch["__num_turns__"])

        data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))

    self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)

    # dump generations
    val_data_dir = self.config.trainer.get("validation_data_dir", None)
    if val_data_dir:
        self._dump_generations(
            inputs=sample_inputs,
            outputs=sample_outputs,
            gts=sample_gts,
            scores=sample_scores,
            reward_extra_infos_dict=reward_extra_infos_dict,
            dump_path=val_data_dir,
        )

    data_sources = np.concatenate(data_source_lst, axis=0)

    data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)
    metric_dict = {}
    for data_source, var2metric2val in data_src2var2metric2val.items():
        core_var = "acc" if "acc" in var2metric2val else "reward"
        # var_name: reward, acc...
        # metric_name: mean@4, std@4, best@4/mean...
        for var_name, metric2val in var2metric2val.items():
            n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
            for metric_name, metric_val in metric2val.items():
                if (
                    (var_name == core_var)
                    and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"])
                    and (f"@{n_max}" in metric_name)
                ):
                    # core metric
                    metric_sec = "val-core"
                else:
                    # auxiliary metric
                    metric_sec = "val-aux"
                # final display name
                pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
                metric_dict[pfx] = metric_val

    if len(sample_turns) > 0:
        sample_turns = np.concatenate(sample_turns)
        metric_dict["val-aux/num_turns/min"] = sample_turns.min()
        metric_dict["val-aux/num_turns/max"] = sample_turns.max()
        metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()

    return metric_dict

统计mean/std各种数据指标(metric_utils.process_validation_metrics)

trainer.ppo.metric_utils.py :process_validation_metrics

python
def process_validation_metrics(
    data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
) -> dict[str, dict[str, dict[str, float]]]:
    """
    Process validation metrics into a structured format with statistical analysis.

    This function organizes validation metrics by data source and prompt, then computes
    various statistical measures including means, standard deviations, best/worst values,
    and majority voting results. It also performs bootstrap sampling to estimate statistics
    for different sample sizes. Per-prompt (per-uid) metrics are finally averaged across
    the uids of each data source. Variables whose values are strings (e.g. "pred") are not
    summarized themselves; "pred" is only used for majority voting.

    Args:
        data_sources: List of data source identifiers for each sample.
        sample_uids: List of sample uids corresponding to each sample.
        infos_dict: Dictionary mapping variable names to lists of values for each sample.
        seed: Random seed for bootstrap sampling. Defaults to 42.

    Returns:
        A nested dictionary with the structure:
        {
            data_source: {
                variable_name: {
                    metric_name: value
                }
            }
        }

        Where metric_name includes:
        - "mean@N": Mean value across N samples
        - "std@N": Standard deviation across N samples
        - "best@N/mean": Mean of the best values in bootstrap samples of size N
        - "best@N/std": Standard deviation of the best values in bootstrap samples
        - "worst@N/mean": Mean of the worst values in bootstrap samples
        - "worst@N/std": Standard deviation of the worst values in bootstrap samples
        - "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists)
        - "maj@N/std": Standard deviation of majority voting results (if "pred" exists)

    Example:
        >>> data_sources = ["source1", "source1", "source2"]
        >>> sample_uids = ["uid1", "uid1", "uid2"]
        >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}
        >>> result = process_validation_metrics(data_sources, sample_uids, infos_dict)
        >>> # result will contain statistics for each data source and variable
    """
    # Group metrics by data source, prompt and variable
    data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for sample_idx, data_source in enumerate(data_sources):
        uid = sample_uids[sample_idx]
        var2vals = data_src2uid2var2vals[data_source][uid]
        for var_name, var_vals in infos_dict.items():
            var2vals[var_name].append(var_vals[sample_idx])

    # Calculate metrics for each group
    data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    for data_source, uid2var2vals in data_src2uid2var2vals.items():
        for uid, var2vals in uid2var2vals.items():
            for var_name, var_vals in var2vals.items():
                if isinstance(var_vals[0], str):
                    continue

                metric = {}
                n_resps = len(var_vals)
                metric[f"mean@{n_resps}"] = np.mean(var_vals)

                if n_resps > 1:
                    metric[f"std@{n_resps}"] = np.std(var_vals)

                    # Bootstrap subset sizes: powers of two below n_resps, plus n_resps itself.
                    ns = []
                    n = 2
                    while n < n_resps:
                        ns.append(n)
                        n *= 2
                    ns.append(n_resps)

                    # The (val, pred) pairs do not depend on the subset size: build them once.
                    preds = var2vals.get("pred", None)
                    vote_data = None
                    if preds is not None:
                        vote_data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, preds, strict=True)]

                    for n in ns:
                        [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(
                            data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed
                        )
                        metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std
                        metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std
                        if vote_data is not None:
                            [(maj_n_mean, maj_n_std)] = bootstrap_metric(
                                data=vote_data,
                                subset_size=n,
                                reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")],
                                seed=seed,
                            )
                            metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std

                data_src2uid2var2metric[data_source][uid][var_name] = metric

    # Aggregate metrics across uids
    data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for data_source, uid2var2metric in data_src2uid2var2metric.items():
        for uid, var2metric in uid2var2metric.items():
            for var_name, metric in var2metric.items():
                for metric_name, metric_val in metric.items():
                    data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val)

    # Final value per metric = mean over the uids (prompts) of its data source.
    data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
    for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items():
        for var_name, metric2uid_vals in var2metric2uid_vals.items():
            for metric_name, uid_vals in metric2uid_vals.items():
                data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals)

    return data_src2var2metric2val

Rollout 过程

内容繁多且重要,具体见笔记 AgentLoop Rollout 笔记

奖励计算

奖励计算不同类型
  • 环境给出奖励
  • 奖励函数给出奖励
  • reward model 计算奖励

环境给出奖励

agent_loop结束时调用环境eval,把reward_score放到AgentLoopOutput里。

python
# Excerpt from the end of an agent loop: ask the environment to evaluate the trajectory
# (`xxx` stands for elided arguments) and store the result on the AgentLoopOutput as
# `reward_score`. The response-side sequences are cut to self.response_length.
reward_score = await call_env_evaluate(xxx)
output = AgentLoopOutput(
    prompt_ids=prompt_ids,
    response_ids=response_ids[: self.response_length],
    response_mask=agent_data.response_mask[: self.response_length],
    reward_score=reward_score,  # environment-provided reward
    multi_modal_data={},
    # None when no per-token logprobs were recorded
    response_logprobs=agent_data.response_logprobs[: self.response_length]
    if agent_data.response_logprobs
    else None,
    num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
    metrics=agent_data.metrics,
    extra_fields={},
)

在reward function计算的时候,直接返回该值

python
# Short-circuit inside the reward function: if the batch already carries "rm_scores"
# (presumably filled from the rollout's reward_score — confirm), return them as-is
# instead of recomputing rewards.
if "rm_scores" in data.batch.keys():  
    if return_dict:
        # Also pass through the extra per-sample fields listed in meta_info["reward_extra_keys"].
        reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
        reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
        return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info}
    else:
        return data.batch["rm_scores"] 

reward function 计算

见下文naive reward manager

python
# Call the configured scoring function (reward_fn) to compute the score of one response,
# given its data source, decoded response text, ground truth and extra info.
score = self.compute_score( 
    data_source=data_source, 
    solution_str=response_str, 
    ground_truth=ground_truth, 
    extra_info=extra_info, 
) 

reward model 计算

pass

Gen RM 计算

pass

优势计算

pass

core_algos优势计算主入口

python
def compute_advantage(
    data: DataProto,
    adv_estimator: AdvantageEstimator,
    gamma: float = 1.0,
    lam: float = 1.0,
    num_repeat: int = 1,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
) -> DataProto:
    """Compute advantage estimates for policy optimization.

    This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc.
    The advantage estimates are used to guide policy optimization in RL algorithms.
    GAE and GRPO are handled inline; every other estimator is looked up with
    ``core_algos.get_adv_estimator_fn``.

    Args:
        data (DataProto): The data containing batched model outputs and inputs.
        adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++).
        gamma (float, optional): Discount factor for future rewards. Defaults to 1.0.
        lam (float, optional): Lambda parameter for GAE. Defaults to 1.0.
        num_repeat (int, optional): Not used by this function; kept for interface compatibility.
            Defaults to 1.
        norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in
            GRPO. Defaults to True.
        config (AlgoConfig, optional): Algorithm configuration. May be None, in which case the
            PF-PPO reweighting step of GAE is skipped. Defaults to None.

    Returns:
        DataProto: The updated data with computed advantages and returns
        (``data.batch["advantages"]`` and ``data.batch["returns"]``).
    """
    # Back-compatible with trainers that do not compute response mask in fit
    if "response_mask" not in data.batch.keys():
        data.batch["response_mask"] = compute_response_mask(data)
    # prepare response group
    if adv_estimator == AdvantageEstimator.GAE:
        # Compute advantages and returns using Generalized Advantage Estimation (GAE)
        advantages, returns = core_algos.compute_gae_advantage_return(
            token_level_rewards=data.batch["token_level_rewards"],
            values=data.batch["values"],
            response_mask=data.batch["response_mask"],
            gamma=gamma,
            lam=lam,
        )
        data.batch["advantages"] = advantages
        data.batch["returns"] = returns
        # config defaults to None: guard before reading the optional PF-PPO switch.
        if config is not None and config.get("use_pf_ppo", False):
            data = core_algos.compute_pf_ppo_reweight_data(
                data,
                config.pf_ppo.get("reweight_method"),
                config.pf_ppo.get("weight_pow"),
            )
    elif adv_estimator == AdvantageEstimator.GRPO:
        # Initialize the mask for GRPO calculation
        grpo_calculation_mask = data.batch["response_mask"]

        # Call compute_grpo_outcome_advantage with parameters matching its definition;
        # responses are grouped by their prompt uid.
        advantages, returns = core_algos.compute_grpo_outcome_advantage(
            token_level_rewards=data.batch["token_level_rewards"],
            response_mask=grpo_calculation_mask,
            index=data.non_tensor_batch["uid"],
            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
        )
        data.batch["advantages"] = advantages
        data.batch["returns"] = returns
    else:
        # handle all other adv estimator type other than GAE and GRPO
        adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator)
        adv_kwargs = {
            "token_level_rewards": data.batch["token_level_rewards"],
            "response_mask": data.batch["response_mask"],
            "config": config,
        }
        if "uid" in data.non_tensor_batch:  # optional
            adv_kwargs["index"] = data.non_tensor_batch["uid"]
        if "reward_baselines" in data.batch:  # optional
            adv_kwargs["reward_baselines"] = data.batch["reward_baselines"]

        # calculate advantage estimator
        advantages, returns = adv_estimator_fn(**adv_kwargs)
        data.batch["advantages"] = advantages
        data.batch["returns"] = returns
    return data

GAE 优势估计

GAE 优势估计

理论笔记

核心数学公式

  • GAE 反向迭代计算公式

    $A_T(s_T,a_T)=\delta_T,\qquad A_t(s_t,a_t)=\delta_t+\gamma\lambda\,A_{t+1}(s_{t+1},a_{t+1})$
  • TD error

    $\delta_{t+l}=r_{t+l}+\gamma V(s_{t+l+1})-V(s_{t+l})$

核心思想

  • 反向迭代计算每步的优势

    • 计算 TD error $\delta_t$,再计算优势 $A_t$
    • 对于 padding、环境返回(observation)等位置,利用 response_mask 跳过计算,即沿用后一位置的值,当前值不做更新
  • returns = advantages + values

    • advantages = returns - values
  • 对优势做标准化处理

python
@register_adv_est(AdvantageEstimator.GAE)  # or simply: @register_adv_est("gae")
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: float,
    lam: float,
):
    """Compute GAE advantages and returns by iterating backwards over the response tokens.

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

    Positions with ``response_mask == 0`` (e.g. padding or environment/observation tokens) do not
    update the running next-value or advantage. The values from the next masked-in position are
    carried through them instead.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        values: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
        gamma: `(float)`
            discounted factor used in RL
        lam: `(float)`
            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length); whitened with ``verl_F.masked_whiten`` using response_mask
        returns: `(torch.Tensor)`
            shape: (bs, response_length); ``advantages + values``, computed before whitening
    """
    with torch.no_grad():
        nextvalues = 0  # V(s_{t+1}); 0 past the last token
        lastgaelam = 0  # A_{t+1}; 0 past the last token
        advantages_reversed = []
        gen_len = token_level_rewards.shape[-1]

        for t in reversed(range(gen_len)):
            # TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
            delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
            # GAE recursion: A_t = delta_t + gamma * lam * A_{t+1}
            lastgaelam_ = delta + gamma * lam * lastgaelam

            # skip values and TD-error on observation tokens: only tokens with response_mask == 1
            # update the running values; masked positions carry the later values through.
            nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues
            lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam

            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        # Returns are computed from the un-whitened advantages.
        returns = advantages + values
        # Whitening: subtract the mean and divide by the std, using response_mask.
        advantages = verl_F.masked_whiten(advantages, response_mask)
    return advantages, returns

GRPO 组优势计算

核心流程

理论笔记

核心思想

  • 每个uid(prompt),维护一个score list
  • 针对每个 prompt,计算该组得分的 mean 和 std
  • 做 normalization:减去组内均值,norm_adv_by_std_in_grpo=True 时再除以组内标准差(False 时不缩放,即 Dr.GRPO 的做法)
python
@register_adv_est(AdvantageEstimator.GRPO)
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Group-relative (GRPO) advantage computed from one outcome reward per response.

    Each response's scalar score is the sum of its token-level rewards. Responses
    that share the same ``index`` value (the same prompt / uid) form a group, and
    every score is normalized against its own group's statistics:

    * ``norm_adv_by_std_in_grpo=True``:  ``(score - mean) / (std + epsilon)`` (original GRPO).
    * ``norm_adv_by_std_in_grpo=False``: ``score - mean`` (Dr.GRPO, https://arxiv.org/abs/2503.20783).

    A group that contains a single response uses mean 0 and std 1.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length); the per-sequence advantage is broadcast over it
        index: `(np.ndarray)`
            group id of every response, used for grouping
        epsilon: `(float)`
            added to the std to avoid division by zero
        norm_adv_by_std_in_grpo: `(bool)`
            whether to divide by the group std
        config: `(Optional[AlgoConfig])`
            algorithm configuration object (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape is (bs, response_length)
        Returns: `(torch.Tensor)`
            the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)

    with torch.no_grad():
        n_responses = scores.shape[0]

        # Collect every group's scores (0-d tensors) in batch order.
        groups = defaultdict(list)
        for pos in range(n_responses):
            groups[index[pos]].append(scores[pos])

        # Per-group (mean, std); a singleton group gets (0, 1) so its score passes through.
        group_stats = {}
        for gid, members in groups.items():
            if not members:
                raise ValueError(f"no score in prompt index: {gid}")
            if len(members) == 1:
                group_stats[gid] = (torch.tensor(0.0), torch.tensor(1.0))
            else:
                stacked = torch.stack(members)
                group_stats[gid] = (torch.mean(stacked), torch.std(stacked))

        # Normalize each response's score in place using its group's statistics.
        for pos in range(n_responses):
            group_mean, group_std = group_stats[index[pos]]
            centered = scores[pos] - group_mean
            scores[pos] = centered / (group_std + epsilon) if norm_adv_by_std_in_grpo else centered

        # Broadcast the per-sequence advantage onto the response token positions only.
        scores = scores.unsqueeze(-1) * response_mask

    # GRPO has no critic: advantages and returns are the same tensor.
    return scores, scores

Compute LogProb (最核心,所有内容都在这)

主入口 (megatron_workers)

主入口 megatron_workers.compute_log_prob
  • 为每个data设置meta_info信息。
    • micro_batch_size = rollout.log_prob_micro_batch_size_per_gpu
    • max_token_len = rollout.log_prob_max_token_len_per_gpu
  • 调用actor计算log_prob
  • 返回output(old_log_probs)、entropys

ActorRolloutRefWorker.compute_log_prob

python
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@GPUMemoryLogger(role="compute_log_prob", logger=logger)
@DistProfiler.annotate(color="blue")
def compute_log_prob(self, data: DataProto):
    """Recompute old log-probs and per-token entropies for ``data`` with the actor.

    Copies the rollout's log-prob batching settings and sampling temperature into
    ``data.meta_info``. It then calls ``self.actor.compute_log_prob`` with entropy
    enabled and packs the results into a CPU ``DataProto``. That ``DataProto`` has
    tensors ``old_log_probs`` / ``entropys`` and ``meta_info["temperature"]``.
    Actor parameters are loaded to GPU before the computation and offloaded
    afterwards when parameter offloading is enabled.
    """
    assert self._is_actor
    if self._is_offload_param:
        load_megatron_model_to_gpu(self.actor_module, load_grad=False)
        log_gpu_memory_usage("After load actor params and grad during compute_log_prob", logger=logger)

    rollout_cfg = self.config.rollout
    temperature = rollout_cfg.temperature

    # HybridEngine: old_log_probs are always recomputed here with the training actor.
    meta = data.meta_info
    meta["micro_batch_size"] = rollout_cfg.log_prob_micro_batch_size_per_gpu
    meta["max_token_len"] = rollout_cfg.log_prob_max_token_len_per_gpu
    meta["use_dynamic_bsz"] = rollout_cfg.log_prob_use_dynamic_bsz
    meta["temperature"] = temperature

    old_log_probs, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
    result = DataProto.from_dict(
        tensors={"old_log_probs": old_log_probs, "entropys": entropys},
        meta_info={"temperature": temperature},
    )
    result = result.to("cpu")

    # Release the actor's GPU parameters again if offloading is enabled.
    if self._is_offload_param:
        offload_megatron_model_to_cpu(self.actor_module)
        log_gpu_memory_usage("After offload actor params and grad during compute_log_prob", logger=logger)
    aggressive_empty_cache(force_sync=True)
    return result

核心compute_log_prob (MegatronPPOActor)

核心 MegatronPPOActor.compute_log_prob
  • 从data.meta_info读取 micro_batch_size、max_token_len、use_dynamic_bsz。

  • 读取data中的 responses、input_ids、attention_mask、position_ids。

  • 调用forward_backward_batch,去做真实计算,具体见下文

    • 传入后处理函数compute_logprobs_fn,根据真实response_length 选择log_probs
  • pipeline逻辑:流水线计算,log_probs(以及entropy)在最后一个pipeline stage上得到,再从该stage广播到其他pp rank。

MegatronPPOActor.compute_log_prob

python
@GPUMemoryLogger(role="megatron actor", logger=logger)
def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the log probability of the responses given input_ids, attention_mask and position_ids

    Runs a forward-only pass through ``forward_backward_batch``. Results are
    assembled on the last pipeline stage and then broadcast to every rank of the
    pipeline-parallel group, so all pp ranks return the same tensors.

    Args:
        data (DataProto): a DataProto containing keys
            ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
            concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.
            ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.
            ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.
            ``responses``:  tensor of shape [batch_size, response_length]. torch.int64.
            ``meta_info`` is read for ``use_dynamic_bsz``, ``micro_batch_size`` and ``max_token_len``.
        calculate_entropy (bool): also compute per-token entropy of the responses.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(log_probs, entropys)``, both float32 on CPU.
            ``log_probs`` has shape (batch_size, response_length). ``entropys`` has the same
            shape when ``calculate_entropy`` is True; otherwise it is the empty ``torch.Tensor()``.

    NOTE(review): ``log_probs`` is only assigned inside ``if recompute_old_log_prob:``.
    With ``recompute_old_log_prob=False`` in the config, the final ``return`` raises
    ``UnboundLocalError``. TODO decide what should be returned in that case.
    """
    use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False)
    micro_batch_size = data.meta_info.get("micro_batch_size", None)
    max_token_len = data.meta_info.get("max_token_len", None)
    if use_dynamic_bsz:
        assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True"
        # Scale the per-GPU token budget by the context-parallel size.
        # NOTE(review): presumably because each CP rank holds only a slice of every sequence - confirm.
        max_token_len = max_token_len * self.config.megatron.context_parallel_size
    else:
        assert micro_batch_size is not None, (
            "micro batch size is needed for forward compute when use_dynamic_bsz is False"
        )

    def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None):
        # Post-process hook run per micro-batch: keep only the slice of the per-position
        # log-probs aligned with the response tokens -> (micro_batch, response_length).
        response = data["responses"]
        response_length = response.size(1)
        log_probs = output["log_probs"][:, -response_length - 1 : -1].contiguous()
        return {"log_probs": log_probs}

    # We make recompute_old_log_prob by default here.
    # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be
    # handled by user outside
    recompute_old_log_prob = self.config.get("recompute_old_log_prob", True)

    # Stays an empty tensor when calculate_entropy is False.
    entropys = torch.Tensor()
    if recompute_old_log_prob:
        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
        batch = data.select(batch_keys=select_keys).batch
        input_ids = batch["input_ids"]
        batch_size = input_ids.size(0)
        response = batch["responses"]
        response_length = response.size(1)
        with torch.no_grad():
            # Forward-only pipeline pass; compute_logprobs_fn trims each micro-batch's output.
            output = self.forward_backward_batch(
                data,
                forward_only=True,
                post_process_fn=compute_logprobs_fn,
                calculate_entropy=calculate_entropy,
                use_dynamic_bsz=use_dynamic_bsz,
                micro_batch_size=micro_batch_size,
                max_token_len=max_token_len,
            )
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # only on last rank. It should be on every tp rank
                # With entropy, loss_func returns [metrics, entropy] per micro-batch, so the
                # metrics dict is o[0]; without entropy the metrics dict is returned directly.
                if calculate_entropy:
                    log_probs = [o[0]["log_probs"] for o in output["output"]]  # (micro_batch, response_length) each
                else:
                    log_probs = [o["log_probs"] for o in output["output"]]  # (micro_batch, response_length) each
                log_probs = torch.cat(log_probs, dim=0).to(torch.float32)
                if use_dynamic_bsz:
                    # Dynamic batching permuted the samples; restore the original batch order.
                    indices = output["indices"]
                    indices = list(itertools.chain.from_iterable(indices))
                    assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
                    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
                    log_probs = log_probs[revert_indices]
            else:
                # Non-last stages only allocate a receive buffer of the final shape.
                log_probs = torch.empty(
                    size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
                )
            # broadcast across pp ranks (collective: every pp rank must reach this call)
            log_probs = log_probs.to(get_device_id())
            torch.distributed.broadcast(
                tensor=log_probs,
                src=mpu.get_pipeline_model_parallel_last_rank(),
                group=mpu.get_pipeline_model_parallel_group(),
                async_op=False,
            )
            log_probs = log_probs.to("cpu")
            if calculate_entropy:
                # Note that o[0] is metrics, o[1] is entropy
                if mpu.is_pipeline_last_stage(ignore_virtual=True):
                    entropys = torch.cat([o[1] for o in output["output"]], dim=0)
                    entropys = entropys.to(torch.float32)
                    if use_dynamic_bsz:
                        # Same order restoration as for log_probs above.
                        indices = output["indices"]
                        indices = list(itertools.chain.from_iterable(indices))
                        assert len(indices) == entropys.size(0), f"{len(indices)} vs. {entropys.size()}"
                        revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
                        entropys = entropys[revert_indices]
                else:
                    entropys = torch.empty(
                        size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
                    )
                # broadcast across pp ranks
                entropys = entropys.to(get_device_id())
                torch.distributed.broadcast(
                    tensor=entropys,
                    src=mpu.get_pipeline_model_parallel_last_rank(),
                    group=mpu.get_pipeline_model_parallel_group(),
                    async_op=False,
                )
                entropys = entropys.to("cpu")

    # add empty cache after each compute
    get_torch_device().empty_cache()

    return log_probs, entropys

核心forward_backward_batch主流程 (MegatronPPOActor)

MegatronPPOActor.forward_backward_batch
  • 传入data即为mini_batch,根据micro_batch_size 划分成多个micro_batches
    • micro_batch_size=log_prob_micro_batch_size_per_gpu,见上文主入口
  • 计算n_micro_batch,构建batch_generator,后续运算基于此
  • 计算total_seqlen=micro_batch_size * seq_len(input_ids长度)
  • 调用megatron.core.pipeline_parallel的get_forward_backward_func
    • 传参自定义forward_step_func,即forward_step,里包含loss_func,具体见下文。

MegatronPPOActor.forward_backward_batch

python
def forward_backward_batch(
    self,
    data: DataProto,
    forward_only=False,
    post_process_fn=None,
    calculate_entropy=False,
    use_dynamic_bsz=False,
    micro_batch_size=None,
    max_token_len=None,
    mini_batch_size=None,
):
    """Run Megatron's pipeline schedule over ``data`` split into micro-batches.

    We assume:
    - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input
    - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled

    Args:
        data: the mini-batch. Its tensors are broadcast from the last pipeline rank to the
            other pp ranks. ``data.meta_info["temperature"]`` must be present.
        forward_only: passed to the Megatron schedule (True for inference-only passes).
        post_process_fn, calculate_entropy: not read in this body; presumably consumed by the
            ``forward_step`` / ``loss_func`` closures - TODO confirm.
        use_dynamic_bsz: if True, micro-batches are packed by ``rearrange_micro_batches`` up to
            ``max_token_len`` tokens each; this permutes the samples (see ``indices`` below).
        micro_batch_size: fixed micro-batch size, required when ``use_dynamic_bsz`` is False.
        max_token_len: token budget per micro-batch, required when ``use_dynamic_bsz`` is True.
        mini_batch_size: unused here.

    Returns:
        dict: ``{"output": <value returned by the forward_backward_func schedule>}``, plus
        ``"indices"`` (the permutation from ``rearrange_micro_batches``) when ``use_dynamic_bsz``.
    """
    # broadcast from last pp rank to all other pp ranks
    # TODO: actually, we just need to control the sampling order.
    data.to(get_device_id())
    data.batch = data.batch.contiguous()
    mini_batch = data
    broadcast_dict_tensor(
        mini_batch.batch,
        src=mpu.get_pipeline_model_parallel_last_rank(),
        group=mpu.get_pipeline_model_parallel_group(),
    )
    mini_batch.to("cpu")
    # split into micro-batches
    mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool)
    self.has_multi_modal_inputs = "multi_modal_inputs" in mini_batch.non_tensor_batch.keys()
    if self.has_multi_modal_inputs:
        mm_inputs = mini_batch.non_tensor_batch["multi_modal_inputs"]
        mini_batch.batch["multi_modal_inputs"] = mm_inputs
        # Row index of each sample, so a micro-batch can recover its own multi-modal entries
        # after splitting / reordering.
        mini_batch.batch["multi_modal_inputs_idx"] = torch.arange(len(mm_inputs), dtype=torch.int64)

    if mini_batch.batch["position_ids"].dim() == 3:  # qwen2vl mrope [bs, 3, seq_len]
        mini_batch.batch["position_ids"] = mini_batch.batch["position_ids"][
            :, 0
        ]  # mcore patch recompute qwen2vl's pos ids during forward

    indices = None
    # Read here (KeyError if missing); presumably used by the forward_step closure.
    temperature = data.meta_info["temperature"]
    if use_dynamic_bsz:
        assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True"
        vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
        if vpp_size is not None and vpp_size > 1:
            microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage
            micro_batches, indices = rearrange_micro_batches(
                batch=mini_batch.batch,
                num_batches_divided_by=microbatch_group_size_per_vp_stage,
                max_token_len=max_token_len,
            )
            # Report the micro-batch COUNT (formatting the list itself would dump every tensor).
            assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, (
                f"micro_batches {len(micro_batches)} must be divisible by microbatch_group_size_per_vp_stage "
                f"{microbatch_group_size_per_vp_stage} for megatron backend"
            )
        else:
            micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)
        total_seqlen = max_token_len
    else:
        assert micro_batch_size is not None, (
            "micro_batch_size is needed to be passed in when not using dynamic batch size"
        )
        micro_batches = mini_batch.batch.split(micro_batch_size)
        seq_len = micro_batches[0]["input_ids"].shape[1]
        total_seqlen = micro_batch_size * seq_len
    # compute input shapes for pp stages
    n_micro_batch = len(micro_batches)
    # from megatron.core.pipeline_parallel import get_forward_backward_func
    forward_backward_func = get_forward_backward_func()
    # batch should be a list of batches inside micro-batches
    batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.actor_module))

    # TODO: we may use the new schedule instead
    # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
    # The pp > 1 and pp == 1 cases used identical arguments, so a single call covers both;
    # seq_length / micro_batch_size only matter for pp == 1 (no use when input_shapes was set).
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step,  # custom forward_step defined alongside this method
        data_iterator=batch_generator,
        model=self.actor_module,
        num_microbatches=n_micro_batch,
        seq_length=total_seqlen,
        micro_batch_size=1,
        forward_only=forward_only,
    )
    # loss_reduces contains the stats returned from loss_func

    if self.has_multi_modal_inputs:
        data.batch.pop("multi_modal_inputs")
        data.batch.pop("multi_modal_inputs_idx")
        data.non_tensor_batch.pop("multi_modal_inputs")

    losses_reduced = {"output": losses_reduced}
    if use_dynamic_bsz:
        losses_reduced["indices"] = indices
    return losses_reduced

核心forward_step (forward_backward_batch)

MegatronPPOActor.forward_backward_batch.forward_step
  • 单步batch data
    • 为前面的batch_generator每一步取得数据
      • micro_batch_size=log_prob_micro_batch_size_per_gpu
    • 获取 input_ids、attention_mask、position_ids、responses
    • 构建 label 和 label_mask:label 在 response 对应的位置(整体左移一位)填入 responses,其余位置由 label_mask 屏蔽
  • 根据hf_config调用get_mcore_forward_fn获取forward_fn
  • 自定义logits_processor函数,传入logits、label、label_mask,计算log_probs
    • 根据给定label(LLM response)计算log_probs,label_mask为False的位置置为0
    • 具体计算见下文,本质就是取 index=row_labels 处的 log_probs
  • 调用forward_fn,获得output
    • 入参:model、input_ids、attention_mask、position_ids和logits_processor
  • 最后,传入loss_func、batch data、meta_info去计算loss

MegatronPPOActor.forward_backward_batch.forward_step

python
def forward_step(batch_iter, model, return_schedule_plan: bool = False):
    """Megatron ``forward_step_func``: run the model on one micro-batch.

    NOTE(review): reads ``self``, ``temperature``, ``calculate_entropy``,
    ``forward_only``, ``loss_func`` and ``gptmodel_forward_1f1b_overlap`` from an
    enclosing scope (presumably ``forward_backward_batch``); it is not a standalone function.

    Args:
        batch_iter: the batch iterator
        model: the model
        return_schedule_plan: whether to return the schedule plan, for 1f1b overlap

    Returns:
        ``(output, loss_fn)``: the model output (a dict with ``log_probs`` and, when
        entropy is computed, ``entropy``) and ``loss_func`` bound to this micro-batch
        and its ``meta_info``, as the pipeline schedule expects.
    """
    batch = next(batch_iter)
    batch = batch.to(get_device_id())
    batch = batch.contiguous()

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"].to(bool)
    position_ids = batch["position_ids"]

    multi_modal_inputs = {}
    if "multi_modal_inputs" in batch:
        from verl.utils.model import extract_multi_modal_inputs

        indices = batch.get("multi_modal_inputs_idx", None)
        multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices)
    responses = batch["responses"]
    response_length = responses.size(1)
    # Labels: position_ids is only cloned as an int placeholder of the right shape.
    # Positions [-response_length - 1, -1) get the response tokens, shifted one step left so
    # that the logit at position t is scored against the token at t + 1.
    label = position_ids.clone()
    label[:, -response_length - 1 : -1] = responses
    # Score only those shifted response positions (and only where attention_mask is True).
    label_mask = attention_mask.clone()
    label_mask[:, : -response_length - 1] = False
    label_mask[:, -1] = False

    from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn

    if self.use_fused_kernels:
        forward_fn = get_mcore_forward_fused_fn(self.hf_config)
        if return_schedule_plan:
            forward_fn = gptmodel_forward_1f1b_overlap
        # return dict of [logits, entropy]
        output = forward_fn(
            model=model,
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            labels=label,
            labels_mask=label_mask,
            temperature=temperature,
            multi_modal_inputs=multi_modal_inputs,
        )
    else:
        forward_fn = get_mcore_forward_fn(self.hf_config)

        def logits_processor(logits, label, label_mask):
            # Turns raw logits into {"log_probs", optionally "entropy"} for every position.
            assert logits.shape[:2] == label.shape[:2]
            assert label.shape == label_mask.shape
            # Temperature scaling, applied in place on the logits.
            logits.div_(temperature)
            ret = {}
            if calculate_entropy:
                # The log-prob path gets its own copy; per the warning below, this clone is
                # required for correctness when entropy is also computed on the same logits.
                logits_bak = logits.clone()
                logger.warning_once(
                    "For memory-efficient computation, enable fused kernels via "
                    "`actor_rollout_ref.model.use_fused_kernels=True`. "
                    "The current `clone()` operation ensures correctness but increases memory usage."
                )
                entropy = vocab_parallel_entropy(logits)
                ret["entropy"] = entropy
            else:
                logits_bak = logits
            # Log-prob of each label token; positions where label_mask is False are set to 0.
            log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)
            log_probs = log_probs.masked_fill(~label_mask, 0.0)
            ret["log_probs"] = log_probs
            return ret

        logits_processor_args = {"label": label, "label_mask": label_mask}
        output = forward_fn(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            multi_modal_inputs=multi_modal_inputs,
            logits_processor=logits_processor,
            logits_processor_args=logits_processor_args,
        )

    # Loss hyper-parameters are only needed when training (forward_only is False).
    if forward_only:
        meta_info = None
    else:
        clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
        meta_info = {
            "clip_ratio": self.config.clip_ratio,
            "entropy_coeff": self.config.entropy_coeff,
            "clip_ratio_c": clip_ratio_c,
        }
    return output, partial(loss_func, data=batch, meta_info=meta_info)

熵和logprobs计算

计算熵
python
class _VocabParallelEntropy(torch.autograd.Function):
    """Entropy of ``softmax(logits)`` over the last dim, with a memory-lean custom backward.

    ``forward`` and ``backward`` are static methods, as ``torch.autograd.Function``
    requires; call it via ``_VocabParallelEntropy.apply(logits)``.

    NOTE(review): despite the name, this version performs no cross-rank reduction
    (no all-reduce of the max, the normalizer or the weighted sum). It therefore
    returns the true entropy only when the last dimension holds the whole vocabulary
    (e.g. tensor-parallel size 1). A vocab-sharded version must all-reduce those
    three quantities across the tensor-parallel group.
    """

    @staticmethod
    def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
        """Return ``H = logsumexp(x) - sum_i softmax(x)_i * x_i`` over the last dim.

        Args:
            vocab_parallel_logits: logits with the vocabulary on the last dim (not modified).

        Returns:
            entropy with the last dimension removed.
        """

        def mul_reduce(a, b):
            return (a * b).sum(dim=-1, keepdim=True)

        # Subtract the row max before exp() for numerical stability.
        logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values
        normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max
        # exp(x - max), in place on the freshly allocated tensor.
        normalized_exp_logits = normalized_vocab_parallel_logits.exp_()
        # sum_j exp(x_j - max)
        normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)
        # softmax probabilities p_i, again in place.
        softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits)
        # E_p[x] = sum_i p_i * x_i
        sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits)
        # H = (max + log sum_j exp(x_j - max)) - E_p[x] = logsumexp(x) - E_p[x]
        entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits
        ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)
        return entropy.squeeze(dim=-1)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """``dH/dx_i = -p_i * (x_i - E_p[x])``, scaled by ``grad_output``.

        The saved softmax buffer is reused (overwritten) as the gradient, so this
        graph cannot be back-propagated twice. The saved logits are shifted in place
        and shifted back before returning (equal up to floating-point rounding).
        """
        vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors
        # x_i - E_p[x]
        vocab_parallel_logits.sub_(sum_softmax_times_logits)
        # p_i * (x_i - E_p[x]) * upstream gradient
        softmax_logits.mul_(vocab_parallel_logits)
        softmax_logits.mul_(grad_output.unsqueeze(dim=-1))
        # recover vocab_parallel_logits
        vocab_parallel_logits.add_(sum_softmax_times_logits)
        softmax_logits.mul_(-1)
        return softmax_logits
计算label的logprobs

FSDP 实现-torch_function.py

  • flash_attention实现:再点进去看代码。
  • 有简单版实现:logprobs_from_logits_v2
    • 其实就是取 index=row_labels 处的 log_probs

Megatron实现

  • 待看
python
def logprobs_from_logits(logits, labels, inplace_backward=True):
    """Per-token log-probability of ``labels`` under ``logits``.

    Dispatch order:
    1. the Flash-Attention cross-entropy kernel when available (efficient backward);
    2. otherwise the NPU cross-entropy kernel when available;
    3. otherwise the plain log-softmax + gather fallback ``logprobs_from_logits_v2``.

    See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591

    Args:
        logits (Tensor): model outputs of shape (..., vocab_size).
        labels (LongTensor): target class indices of shape ``logits.shape[:-1]``.
        inplace_backward (bool): forwarded to the Flash-Attention kernel only.

    Returns:
        Tensor: log-probabilities of the target labels, shape ``logits.shape[:-1]``.
    """
    if not FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:
        if NPU_CROSS_ENTROPY_LOSS_AVAILABLE:
            return logprobs_from_logits_torch_npu(logits, labels)
        return logprobs_from_logits_v2(logits, labels)

    # The flash-attn kernel works on flat (N, vocab) logits and (N,) labels.
    leading_dims = logits.shape[:-1]
    vocab_size = logits.shape[-1]
    flat_logprobs = logprobs_from_logits_flash_attn(
        logits.reshape(-1, vocab_size),
        labels.reshape(-1),
        inplace_backward=inplace_backward,
    )
    return flat_logprobs.view(*leading_dims)

核心方法

python
def logprobs_from_logits_v2(logits: torch.FloatTensor, labels):
    """
    A memory efficient implementation of logprobs_from_logits
    """
    if logits.dtype in [torch.float32, torch.float64]:
        logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
        # loop to reduce peak mem consumption
        logsumexp_values = torch.stack([torch.logsumexp(logit, dim=-1) for logit in logits])
        logprobs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    else:
        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach
        logprobs_labels = [] 
        for row_logits, row_labels in zip(logits, labels, strict=True):  # loop to reduce peak mem consumption  #
            row_logprobs = F.log_softmax(row_logits, dim=-1) 
            # 取index=row_labels的logprobs
            row_logprobs_labels = row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) 
            logprobs_labels.append(row_logprobs_labels)
        logprobs_labels = torch.stack(logprobs_labels)
    return logprobs_labels

Megatron实现:封装得较深,这里直接调用 tensor_parallel.vocab_parallel_cross_entropy,对结果取负号即为 log_probs。

python
def vocab_parallel_log_probs_from_logits(logits, labels):
    """Log-probabilities of ``labels`` computed with Megatron's vocab-parallel cross-entropy.

    ``vocab_parallel_cross_entropy`` yields the per-token cross-entropy (negative
    log-likelihood), so negating it gives ``log p(label)``.

    TODO(zhangchi.usc1992): We may change the implementation later
    """
    from megatron.core import tensor_parallel

    token_nll = tensor_parallel.vocab_parallel_cross_entropy(vocab_parallel_logits=logits, target=labels)
    return -token_nll

核心Policy Loss计算 (forward_backward_batch)

调用逻辑
  • 在上文forward_step里,计算出output的entropy和log_probs之后,再计算loss
    • 如果forward_only,则无需计算。
    • 否则,才需要计算loss。
python
def forward_step(batch_iter, model, return_schedule_plan: bool = False):
  # Abridged excerpt: the body (input preparation, forward_fn call, meta_info
  # construction) is elided; only the return statement is shown.
  # Returns the model output together with loss_func bound to this micro-batch's
  # data and meta_info; presumably the pipeline schedule invokes it later.
  return output, partial(loss_func, data=batch, meta_info=meta_info)
Policy Loss 核心计算逻辑

理论笔记

  • 具体理论见:PPO-Actor 笔记
  • Policy Loss = PPO loss - 熵奖励 + KL 惩罚
  • Policy Loss = pg_loss - entropy_coeff * entropy_loss + kl_loss_coef * kl_loss

PG Loss

  • 优势*重要性权重 + PPO-Clip
  • 具体见下文

熵奖励

  • Policy Loss = Policy Loss - entropy_coeff * entropy_loss

KL惩罚

  • Policy Loss = Policy Loss + kl_loss_coef * kl_loss

Loss_func 返回值

  • policy_loss, [metrics, ret_entropy]
Policy Loss 具体计算过程
  • 数据读取

    • 从output读取当前log_probs、entropy
    • 从data读取old_log_probs、advantages
    • 从config里读取相关参数,entropy_coeff, kl_loss_coef, loss_agg_mode等。
  • 计算 PGLoss

    • 根据配置获得policy_loss_fn,默认是compute_policy_loss_vanilla
    • 输入old_log_prob、log_prob、advantages、response_mask、loss_agg_mode等参数
  • 计算 熵奖励

    • 根据entropy_coeff,总loss减去熵奖励
  • 计算 KL惩罚

    • 利用ref_log_prob和log_prob计算klloss,根据kl_loss_coef,总loss加上klloss

MegatronPPOActor.forward_backward_batch.loss_func

python
def loss_func(output, data, meta_info):
    """Per-micro-batch loss callback for the actor (bound via ``partial`` in forward_step).

    NOTE(review): reads ``forward_only``, ``post_process_fn``, ``calculate_entropy`` and
    ``self`` from an enclosing scope (presumably ``forward_backward_batch``); it is not
    callable as a free function.

    Args:
        output: model output; a dict with ``"log_probs"`` (and optionally ``"entropy"``)
            or a bare log-prob tensor.
        data: the micro-batch. Reads ``responses`` and ``response_mask``; when training also
            ``old_log_probs``, ``advantages``, optional ``rollout_is_weights`` /
            ``rollout_log_probs``, and ``ref_log_prob`` if ``use_kl_loss``.
        meta_info: ``None`` in forward-only mode, otherwise a dict containing ``entropy_coeff``.

    Returns:
        * forward-only without entropy: ``(tensor(1.0), metrics)``, where ``metrics`` holds the
          output of ``post_process_fn`` (note: NOT wrapped in a list).
        * otherwise ``(policy_loss, [metrics, ret_entropy])``. In forward-only mode
          ``policy_loss`` is a dummy ``tensor(1.0)`` and ``ret_entropy`` is the response-aligned
          entropy. In training mode
          ``policy_loss = pg_loss [- entropy_coeff * entropy_loss] [+ kl_loss_coef * kl_loss]``
          and ``ret_entropy`` is ``None``.
    """
    # For memory efficiency
    # We move calculation of entropy to compute_log_probs, forward_only == True
    log_probs = None
    entropy = None  # re-assigned below from output["entropy"] when calculate_entropy is set
    if isinstance(output, dict):
        log_probs = output["log_probs"]
        if "entropy" in output:
            entropy = output["entropy"]
    else:
        assert isinstance(output, torch.Tensor)
        log_probs = output

    device = log_probs.device
    metrics = {}
    if forward_only:
        if post_process_fn is None:
            pass
            # metrics["logits"] = output
        else:
            stats = post_process_fn(output, data)
            metrics.update(stats)
        if not calculate_entropy:
            # Early exit: forward-only log-prob pass, no loss needed.
            return torch.tensor(1.0, device=device), metrics

    responses = data["responses"]
    response_length = responses.size(1)
    response_mask = data["response_mask"].to(bool)
    loss_agg_mode = self.config.loss_agg_mode
    # compute policy loss
    # Response-aligned slice of the per-position log-probs -> (micro_batch, response_length).
    log_prob = log_probs[:, -response_length - 1 : -1].contiguous()
    ret_entropy = None
    stats = {}
    if not forward_only:
        old_log_prob = data["old_log_probs"]
        advantages = data["advantages"]

        # NOTE(review): this entropy_coeff is shadowed by meta_info["entropy_coeff"] below, and
        # loss_agg_mode was already read above; both assignments are redundant.
        entropy_coeff = self.config.entropy_coeff
        loss_agg_mode = self.config.loss_agg_mode
        # Policy-gradient loss from core_algos.py, selected by config.policy_loss.loss_mode
        loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
        policy_loss_fn = get_policy_loss_fn(loss_mode)

        # Extract pre-computed rollout correction weights if present
        # Weights are computed centrally in trainer and added when algorithm.rollout_is=True
        rollout_is_weights = data.get("rollout_is_weights", None)
        pg_loss, pg_metrics = policy_loss_fn(
            old_log_prob=old_log_prob,
            log_prob=log_prob,
            advantages=advantages,
            response_mask=response_mask,
            loss_agg_mode=loss_agg_mode,
            config=self.config,
            rollout_is_weights=rollout_is_weights,
        )
        stats.update(pg_metrics)

        # Skip if using pure rollout correction mode (metrics already in pg_metrics)
        rollout_log_prob = data.get("rollout_log_probs", None)
        if loss_mode != "rollout_correction" and rollout_log_prob is not None:
            # Compute metrics using CURRENT policy π_θ vs π_rollout
            # Tracks evolving off-policy gap as π_θ updates during mini-batch training
            from verl.trainer.ppo.rollout_corr_helper import compute_rollout_corr_metrics_from_logprobs

            rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs(
                log_prob=log_prob,
                rollout_log_prob=rollout_log_prob,
                response_mask=response_mask,
            )
            stats.update(rollout_corr_metrics)

        stats["actor/pg_loss"] = pg_loss.detach().item()
        policy_loss = pg_loss

    if calculate_entropy:
        # Entropy bonus: slice to the same response-aligned positions as log_prob.
        entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
        if not forward_only:
            # Training: subtract the aggregated entropy from the loss (encourages exploration).
            entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
            entropy_coeff = meta_info["entropy_coeff"]
            policy_loss = pg_loss - entropy_coeff * entropy_loss
        else:
            # Forward-only: hand the per-token entropy back to the caller instead.
            ret_entropy = entropy

    if forward_only:
        policy_loss = torch.tensor(1.0, device=device)
    else:
        if self.config.use_kl_loss:
            # KL penalty against the reference policy
            ref_log_prob = data["ref_log_prob"]
            # compute kl loss
            kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
            kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)

            policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
            metrics["actor/kl_loss"] = kl_loss.detach().item()
            metrics["actor/kl_coef"] = self.config.kl_loss_coef

        # return loss and stats

    append_to_dict(metrics, stats)
    return policy_loss, [metrics, ret_entropy]

策略梯度 PG Loss 计算 (core_algos.py)

PPO PG Loss

PPO Loss

  • 优势 * 重要性权重、PPO-Clip、PPO-Dual-Clip
  • Loss 聚合类型:token-mean、seq-mean-token-mean 等。
  • 入参
    • old_log_prob、log_prob、advantages、response_mask、loss_agg_mode等

具体计算过程

  • 计算重要性采样权重和ppo_kl监控指标

    • ratio=torch.exp(log_prob - old_log_prob)

    • 就是用-(log_prob - old_log_prob),再根据response_mask做mean近似kl

  • 计算朴素无clip的pg_losses1

    • pg_losses1 = -advantages * ratio
  • 计算clip_pg_losses1

    • 根据clip_range(low和high),对IS权重做clip,再乘以优势,计算pg_losses2
    • pg_losses1和pg_losses2逐token取较大者(loss取max,相当于对优化目标取更保守的min)
  • 计算clip_pg_losses2 (dual clip loss),具体原理见 dual-clip-loss

    • 使用clip_ratio_c(默认3.0)乘以优势(取负号作为loss),再与clip_pg_losses1取较小者。
  • 计算最终clip_pg_losses

    • 优势小于0的token选择clip_pg_losses2,其余仍选择clip_pg_losses1
    • 根据loss_agg_mode对loss做聚合

core_algos.compute_policy_loss_vanilla

python
@register_policy_loss("vanilla")  # type: ignore[arg-type]
def compute_policy_loss_vanilla(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | AlgoConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute the clipped (PPO-Clip + dual-clip) policy objective and related metrics for PPO.

    Adapted from
    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config (verl.trainer.config.ActorConfig):
            Actor config. Must not be None and must not be an ``AlgoConfig``. Reads ``clip_ratio``,
            ``clip_ratio_low`` / ``clip_ratio_high`` (each falls back to ``clip_ratio`` when None)
            and ``clip_ratio_c`` (dual-clip bound, default 3.0, must be > 1.0).
        rollout_is_weights (torch.Tensor, optional):
            Rollout-correction importance weights, multiplied element-wise into the per-token
            loss before aggregation. Presumably shaped like ``log_prob`` — TODO confirm.

    Returns:
        tuple[torch.Tensor, dict[str, Any]]:
            The aggregated scalar loss, and a metrics dict with Python floats under
            ``actor/pg_clipfrac``, ``actor/ppo_kl`` and ``actor/pg_clipfrac_lower``.
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig)
    # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
    clip_ratio = config.clip_ratio
    # Asymmetric clip bounds; each falls back to the symmetric ε when unset.
    cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
    cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
    # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
    clip_ratio_c = config.get("clip_ratio_c", 3.0)

    assert clip_ratio_c > 1.0, (
        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
        + f" but get the value: {clip_ratio_c}."
    )

    # log(ratio), clamped before exponentiation for numerical stability.
    negative_approx_kl = log_prob - old_log_prob
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    # Importance-sampling ratio between the current and the old policy.
    ratio = torch.exp(negative_approx_kl)
    # Monitoring only: masked mean of -(log_prob - old_log_prob) as an approximate KL.
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    # Unclipped surrogate loss: -A * ratio.
    pg_losses1 = -advantages * ratio
    # Clipped surrogate loss: -A * clip(ratio, 1 - eps_low, 1 + eps_high).
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    # Element-wise max of the losses == min of the PPO objectives.
    clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
    # Fraction of valid tokens where the clipped loss is strictly larger (clipping active).
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

    # Dual-clip: cap the loss at -A * clip_ratio_c; applied only where A < 0 (see torch.where below).
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    # Fraction of valid tokens with A < 0 where the dual-clip cap was active.
    pg_clipfrac_lower = verl_F.masked_mean(
        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
    )
    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

    # Apply rollout correction weights if provided
    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    pg_metrics = {
        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
        "actor/ppo_kl": ppo_kl.detach().item(),
        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
    }
    return pg_loss, pg_metrics

loss聚合类型

python
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
    """
    Aggregate the loss matrix into a scalar.

    Args:
        loss_mat: `(torch.Tensor)`:
            shape: (bs, response_length)
        loss_mask: `(torch.Tensor)`:
            shape: (bs, response_length)
        loss_agg_mode: (str) one of:
            - "token-mean": ``verl_F.masked_mean`` of ``loss_mat`` under ``loss_mask``
              (every valid token in the batch weighs the same).
            - "seq-mean-token-sum": sum the masked losses of each sequence, then average over
              sequences that have at least one valid token.
            - "seq-mean-token-mean": average the masked losses within each sequence, then average
              over sequences that have at least one valid token.
            - "seq-mean-token-sum-norm": sum of all masked losses divided by the padded response
              length ``loss_mask.shape[-1]`` (Dr. GRPO style normalization).

    Returns:
        loss: `a scalar torch.Tensor`
            aggregated loss

    Raises:
        ValueError: if ``loss_agg_mode`` is not one of the modes above.
    """
    if loss_agg_mode == "token-mean":
        loss = verl_F.masked_mean(loss_mat, loss_mask)
    elif loss_agg_mode == "seq-mean-token-sum":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
        seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # exclude fully masked sequences
        loss = verl_F.masked_mean(seq_losses, seq_mask)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-mean":
        token_counts = torch.sum(loss_mask, dim=-1)  # per-sequence valid-token count
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (token_counts + 1e-8)  # token-mean
        seq_mask = (token_counts > 0).float()  # exclude fully masked sequences
        loss = verl_F.masked_mean(seq_losses, seq_mask)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-sum-norm":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
        # The divisor (loss_mask.shape[-1]) should ideally be constant throughout training
        # to well-replicate the DrGRPO paper.
        # TODO: Perhaps add user-defined normalizer argument to agg_loss to ensure the
        # divisor stays constant throughout.
        loss = torch.sum(seq_losses) / loss_mask.shape[-1]
    else:
        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")

    return loss

Update Critic

python
def fit(self):
    # Excerpt of RayPPOTrainer.fit showing only the critic-update step. Names such as
    # `batch`, `metrics` and `timing_raw` come from the elided code ("# ....").
    # ....
    # ....
    # update critic
    # Skipped entirely when no critic is configured (self.use_critic is False).
    if self.use_critic:
        # Dispatch the update to the critic workers (self.critic_wg), timed as "update_critic".
        with marked_timer("update_critic", timing_raw, color="pink"):
            critic_output = self.critic_wg.update_critic(batch)  
        # Reduce the metrics returned in meta_info and merge them into this step's metrics.
        critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
        metrics.update(critic_output_metrics)

update_critic 主流程 (Megatron CriticWorker)

CriticWorker.update_critic 主入口 (megatron)
  • 为data设置meta_info
    • micro_batch_size=ppo_micro_batch_size_per_gpu
  • 构建dataloader
    • 根据 ppo_mini_batch_size 和 ppo_epochs 建立,非常重要!
    • 代码
  • 传入 dataloader,调用 critic 的 update_critic 方法更新 critic,具体见下文
  • 返回一些metrics

CriticWorker.update_critic

python
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
@DistProfiler.annotate(color="pink")
def update_critic(self, data: DataProto):
    """Run one critic (value-function) update on this worker's shard of ``data``.

    If parameters / optimizer state are offloaded, they are loaded onto the GPU for
    the update and offloaded again afterwards. The returned ``DataProto`` has no
    tensor batch; its ``meta_info["metrics"]`` carries the critic metrics plus MFU
    and learning rate. It is moved to CPU before returning.
    """
    data = data.to(get_device_id())

    # Bring offloaded state back onto the GPU for the update.
    if self._is_offload_param:
        load_megatron_model_to_gpu(self.critic_module)
    if self._is_offload_optimizer:
        load_megatron_optimizer(self.critic_optimizer)

    minibatch_iter = self.critic.make_minibatch_iterator(data)
    with Timer(name="update_critic", logger=None) as update_timer:
        metrics = self.critic.update_critic(dataloader=minibatch_iter)
    elapsed_s = update_timer.last

    # The iterator replays the data ppo_epochs times, so MFU is scaled accordingly.
    estimated_flops, promised_flops = self.flops_counter.estimate_flops(
        data.meta_info["global_token_num"], elapsed_s
    )
    metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size

    from verl.utils.megatron.optimizer import get_megatron_last_lr

    metrics["critic/lr"] = get_megatron_last_lr(self.critic_optimizer)
    self.critic_optimizer_scheduler.step(1)

    output = DataProto(batch=None, meta_info={"metrics": metrics})

    # Release GPU memory again if offloading is enabled.
    if self._is_offload_param:
        offload_megatron_model_to_cpu(self.critic_module)
    if self._is_offload_optimizer:
        offload_megatron_optimizer(self.critic_optimizer)
    return output.to("cpu")

make_minibatch_iterator代码

MegatronPPOCritic.make_minibatch_iterator
  • 传入data,调用data.make_iterator,关键参数
    • mini_batch_size=self.config.ppo_mini_batch_size
    • epochs=self.config.ppo_epochs

MegatronPPOCritic.make_minibatch_iterator

python
def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
    """Build the mini-batch iterator consumed by ``update_critic``.

    Only the tensors needed by the critic loss are kept. The iterator comes from
    ``DataProto.make_iterator`` with ``mini_batch_size=config.ppo_mini_batch_size``,
    ``epochs=config.ppo_epochs``, ``seed=config.data_loader_seed`` and
    ``shuffle=config.shuffle``.
    """
    critic_keys = ("input_ids", "responses", "attention_mask", "position_ids", "values", "returns")
    cfg = self.config
    critic_data = data.select(batch_keys=list(critic_keys))
    return critic_data.make_iterator(
        mini_batch_size=cfg.ppo_mini_batch_size,
        epochs=cfg.ppo_epochs,
        seed=cfg.data_loader_seed,
        dataloader_kwargs={"shuffle": cfg.shuffle},
    )

update_critic (MegatronPPOCritic)

MegatronPPOCritic.update_critic
  • 遍历dataloader获取当前data多次计算和更新
    • micro_batch_size=ppo_micro_batch_size_per_gpu
    • mini_batch_size = ppo_mini_batch_size
  • 单步更新过程
    • critic_optimizer.zero_grad() ,梯度归零
    • 调用 forward_backward_batch 计算各种 metrics(含 loss),具体见下文
    • critic_optimizer.step() 更新 critic

MegatronPPOCritic.update_critic

python
@GPUMemoryLogger("megatron critic", logger=logger)
def update_critic(self, dataloader: Iterable[DataProto]):
    """Update the critic once per mini-batch yielded by ``dataloader``.

    For each mini-batch: zero the gradients, run ``forward_backward_batch`` over its
    micro-batches, then take one optimizer step.

    Args:
        dataloader: mini-batches, e.g. as produced by ``make_minibatch_iterator``.

    Returns:
        dict: metric name -> list of values accumulated with ``append_to_dict``
        (per-step ``critic/grad_norm`` and ``critic/lr``, plus the per-micro-batch
        stats returned by the loss function).

    Raises:
        NotImplementedError: if the optimizer reports an unsuccessful step.
    """
    metrics = {}
    # These depend only on the config, so compute them once for all mini-batches.
    micro_batch_size = self.config.ppo_micro_batch_size_per_gpu
    max_token_len = None
    if self.config.use_dynamic_bsz:
        max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size

    for mini_batch in dataloader:
        self.critic_optimizer.zero_grad()
        # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
        for chunk in self.critic_module:
            chunk.zero_grad_buffer()

        # Core step: forward + backward over all micro-batches of this mini-batch.
        micro_batch_stats = self.forward_backward_batch(
            mini_batch,
            forward_only=False,
            use_dynamic_bsz=self.config.use_dynamic_bsz,
            micro_batch_size=micro_batch_size,
            max_token_len=max_token_len,
            mini_batch_size=self.config.ppo_mini_batch_size,
        )["output"]

        update_successful, grad_norm, _num_zeros_in_grad = self.critic_optimizer.step()
        learning_rate = self.critic_optimizer.param_groups[-1]["lr"]
        append_to_dict(metrics, {"critic/grad_norm": grad_norm, "critic/lr": learning_rate})

        # On success, the all-gather already ran inside optimizer.step() (newer Megatron).
        if not update_successful:
            raise NotImplementedError

        for stats in micro_batch_stats:
            append_to_dict(metrics, stats)  # append the metric from this micro-batch to global metrics.

    # Release cached device memory once all updates are done.
    get_torch_device().empty_cache()
    return metrics

forward_backward_batch (MegatronPPOCritic)

MegatronPPOCritic.forward_backward_batch
  • 传入data即为mini_batch,根据micro_batch_size 划分成多个micro_batches

    • micro_batch_size=ppo_micro_batch_size_per_gpu,见上文主入口
  • 计算n_micro_batch,构建batch_generator,后续运算基于此

    • 计算 total_seqlen = micro_batch_size * seq_len(seq_len 为 input_ids 长度;使用 dynamic bsz 时 total_seqlen = max_token_len)
  • 调用megatron.core.pipeline_parallel的get_forward_backward_func

    • 传参自定义forward_step_func,即forward_step,里包含loss_func,具体见下文。

MegatronPPOCritic.forward_backward_batch

python
def forward_backward_batch(
    self,
    data: DataProto,
    forward_only=False,
    use_dynamic_bsz=False,
    micro_batch_size=None,
    max_token_len=None,
    mini_batch_size=None,
):
    """Run Megatron's pipeline schedule over one mini-batch of critic data.

    Steps:
    1. Broadcast the mini-batch from the last pipeline stage to all stages.
    2. Split it into micro-batches. The split uses a fixed ``micro_batch_size``, or a
       ``max_token_len`` token budget when ``use_dynamic_bsz`` is set.
    3. Feed the micro-batches to ``forward_step`` through ``get_forward_backward_func()``.
       Backward is skipped when ``forward_only`` is set.

    Args:
        data: the mini-batch.
        forward_only: run forward passes only (no backward).
        use_dynamic_bsz: split by token budget instead of a fixed sample count.
        micro_batch_size: samples per micro-batch; required unless ``use_dynamic_bsz``.
        max_token_len: token budget per micro-batch; required when ``use_dynamic_bsz``.
        mini_batch_size: accepted for interface compatibility; not used here.

    Returns:
        dict: ``{"output": <stats returned from loss_func, per micro-batch>}``. When
        ``use_dynamic_bsz`` is set, it also holds ``"indices"`` from ``rearrange_micro_batches``.
    """
    # broadcast from last pp rank to all other pp ranks
    data.to(get_device_id())
    mini_batch = data
    mini_batch.batch = mini_batch.batch.contiguous()
    broadcast_dict_tensor(
        mini_batch.batch,
        src=mpu.get_pipeline_model_parallel_last_rank(),
        group=mpu.get_pipeline_model_parallel_group(),
    )
    mini_batch.to("cpu")
    # split into micro-batches
    mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool)

    indices = None
    if use_dynamic_bsz:
        assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True"
        vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
        if vpp_size is not None and vpp_size > 1:
            group_size = self.tf_config.microbatch_group_size_per_vp_stage
            micro_batches, indices = rearrange_micro_batches(
                batch=mini_batch.batch,
                num_batches_divided_by=group_size,
                max_token_len=max_token_len,
            )
            # Report the micro-batch count, not the micro-batch tensors themselves.
            assert len(micro_batches) % group_size == 0, (
                f"micro_batches {len(micro_batches)} must be divisible by microbatch_group_size_per_vp_stage "
                f"{group_size} for megatron backend"
            )
        else:
            micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)
        total_seqlen = max_token_len
    else:
        assert micro_batch_size is not None, (
            "micro_batch_size is needed to be passed in when not using dynamic batch size"
        )
        micro_batches = mini_batch.batch.split(micro_batch_size)
        # All micro-batches are slices of the same padded tensor, so they share this width.
        seq_len = micro_batches[0]["input_ids"].shape[1]
        total_seqlen = micro_batch_size * seq_len

    n_micro_batch = len(micro_batches)
    forward_backward_func = get_forward_backward_func()
    # batch should be a list of batches inside micro-batches
    batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.critic_module))

    # TODO: we may use the new schedule instead
    # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
    # One call serves both pp == 1 and pp > 1: seq_length / micro_batch_size are used
    # by the pp == 1 schedule, and are unused when input_shapes was set (pp > 1).
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step,
        data_iterator=batch_generator,
        model=self.critic_module,
        num_microbatches=n_micro_batch,
        seq_length=total_seqlen,
        micro_batch_size=1,
        forward_only=forward_only,
    )
    # losses_reduced contains the stats returned from loss_func
    result = {"output": losses_reduced}
    if use_dynamic_bsz:
        result["indices"] = indices
    return result

forward_step (forward_backward_batch)

forward_backward_batch.forward_step
  • 当前dataloader信息,对batch根据下面2个参数构建dataloader,具体见上文
    • mini_batch_size = ppo_mini_batch_size
    • epochs = critic.ppo_epochs
  • 单步batch data
    • 为前面的batch_generator每一步取得数据
      • micro_batch_size=ppo_micro_batch_size_per_gpu
    • 获取 input_ids、attention_mask、position_ids
  • 调用forward_fn,获得output
    • 入参:model、input_ids、attention_mask、position_ids
  • 最后,传入loss_func、batch data、meta_info去计算loss

MegatronPPOCritic.forward_backward_batch.forward_step

python
def forward_step(batch_iter, model):
    """Megatron ``forward_step_func`` for the critic.

    Pulls the next micro-batch from ``batch_iter`` and runs the value model on it.
    Returns the model output together with ``loss_func`` bound to that micro-batch;
    the pipeline schedule calls the bound function to obtain the loss.
    """
    micro_batch = next(batch_iter).to(get_device_id()).contiguous()

    from verl.models.mcore import get_mcore_forward_fn

    value_forward = get_mcore_forward_fn(self.hf_config)
    output = value_forward(
        model,
        micro_batch["input_ids"],
        micro_batch["attention_mask"],
        micro_batch["position_ids"],
        {},  # multi_modal_inputs
        value_model=True,
    )
    # Bind the micro-batch so the schedule can compute the critic loss from `output`.
    return output, partial(loss_func, data=micro_batch, meta_info={})

loss_func (forward_backward_batch)

forward_backward_batch.loss_func

笔记

实际计算过程

  • 基于 vpreds 和 returns 计算 critic loss,同时用旧的 values 做 Value Clip,防止更新过大

MegatronPPOCritic.forward_backward_batch.loss_func

python
def loss_func(output, data, meta_info):
    """Critic loss callback invoked by the Megatron pipeline schedule.

    Args:
        output: value-model output for the micro-batch; the original annotates it as
            (bs, sequence_length).
        data: the micro-batch. Reads ``responses``, ``attention_mask``, ``values``
            (old values, used as the clip anchor) and ``returns``.
        meta_info: not read here.

    Returns:
        (loss, stats). In ``forward_only`` mode this is a dummy ``1.0`` loss and
        ``{"vpreds": output}``. Otherwise it is the clipped value loss and a dict of
        scalar metrics.
    """
    if forward_only:
        return torch.tensor(1.0, device=output.device), {"vpreds": output}

    responses = data["responses"]
    attention_mask = data["attention_mask"]
    # Regression targets are `returns`; old `values` anchor the value clip so one update
    # cannot move predictions too far.
    values = data["values"]
    returns = data["returns"]
    response_length = responses.size(1)

    # response mask
    response_mask = attention_mask[:, -response_length:]

    # value clip
    cliprange_value = self.config.cliprange_value

    vpreds = output  # (bs, sequence_length)
    # Positions offset by one from the response tokens, ending before the last position.
    vpreds = vpreds[:, -response_length - 1 : -1]

    vf_loss, vf_clipfrac = core_algos.compute_value_loss(
        vpreds=vpreds,
        values=values,
        returns=returns,
        response_mask=response_mask,
        cliprange_value=cliprange_value,
        loss_agg_mode=self.config.loss_agg_mode,
    )

    stats = {
        "critic/vf_loss": vf_loss.detach().item(),
        "critic/vf_clipfrac": vf_clipfrac.detach().item(),
        "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
    }

    return vf_loss, stats

compute_value_loss (core_algos.py)

compute_value_loss

笔记

Loss

  • 平方误差loss、Value Clip等

core_algos.compute_value_loss

python
def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    cliprange_value: float,
    loss_agg_mode: str = "token-mean",
):
    """
    Compute the clipped value-function loss for PPO.

    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

    The new prediction is clipped to ``[values - cliprange_value, values + cliprange_value]``.
    Per token, the larger of the unclipped and clipped squared errors against ``returns``
    is kept. That matrix is aggregated with ``agg_loss`` and scaled by 0.5.

    Args:
        vpreds (torch.FloatTensor):
            Predicted values from the value head, shape (batch_size, response_length).
        returns (torch.FloatTensor):
            Ground-truth returns, shape (batch_size, response_length).
        values (torch.FloatTensor):
            Old (baseline) values from the value head, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the value loss calculation.
        cliprange_value (float):
            Clip range for value prediction updates.
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".

    Returns:
        vf_loss (torch.Tensor):
            Scalar tensor with the aggregated value-function loss (0.5 factor included).
        vf_clipfrac (torch.Tensor):
            Single-element tensor (not a Python float; callers use ``.item()``): masked fraction
            of tokens where the clipped squared error was strictly larger than the unclipped one.
    """
    # Keep the new prediction within cliprange_value of the old value.
    vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
    vf_losses1 = (vpreds - returns) ** 2
    vf_losses2 = (vpredclipped - returns) ** 2
    # Pessimistic (element-wise larger) of the two squared errors.
    clipped_vf_losses = torch.maximum(vf_losses1, vf_losses2)
    # 0.5 cancels the factor 2 in the gradient of the squared error.
    vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
    return vf_loss, vf_clipfrac

Update Actor

RayPPOTrainer.fit()里的actor更新代码

python
 def fit(self):
    # Excerpt of RayPPOTrainer.fit showing only the actor-update step. Names such as
    # `batch`, `metrics` and `timing_raw` come from the elided code ("# ....").
    # ....
    # ....
    # implement critic warmup
    # The actor is updated only once global_steps has reached trainer.critic_warmup.
    if self.config.trainer.critic_warmup <= self.global_steps:
        # update actor
        with marked_timer("update_actor", timing_raw, color="red"):
            rollout_config = self.config.actor_rollout_ref.rollout
            # Rollout settings are forwarded to the actor workers through batch.meta_info.
            batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable  
            # TODO: Make "temperature" single source of truth from generation.
            batch.meta_info["temperature"] = rollout_config.temperature  
            actor_output = self.actor_rollout_wg.update_actor(batch) 
        # Reduce the metrics returned in meta_info and merge them into this step's metrics.
        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
        metrics.update(actor_output_metrics)

update_actor 主入口 (megatron_worker)

update_actor 主入口 (megatron_worker)
  • 为data设置meta_info
    • micro_batch_size=ppo_micro_batch_size_per_gpu
  • 构建dataloader
    • 根据 ppo_mini_batch_size 和 ppo_epochs 建立 dataloader,非常重要!
  • 传入 dataloader,调用 actor 的 update_policy 方法更新策略,具体见下文
  • 返回一些metrics

ActorRolloutRefWorker.update_actor

python
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@GPUMemoryLogger(role="update_actor", logger=logger)
@DistProfiler.annotate(color="red")
def update_actor(self, data: DataProto):
    """Run one PPO policy update on this worker's shard of ``data``.

    If actor parameters / optimizer state are offloaded, they are loaded onto the GPU
    for the update and offloaded again afterwards. Returns a CPU ``DataProto`` whose
    ``meta_info["metrics"]`` holds the policy metrics plus MFU, memory and LR stats.
    """
    assert self._is_actor
    gib = 1024**3

    # Bring offloaded state back onto the GPU for the update.
    if self._is_offload_param:
        load_megatron_model_to_gpu(self.actor_module)
        log_gpu_memory_usage("After load actor params and grad during update_actor", logger=logger)
    if self._is_offload_optimizer:
        load_megatron_optimizer(self.actor_optimizer)
        log_gpu_memory_usage("After load actor optimizer during update_actor", logger=logger)

    # update_policy reads the micro-batch size from meta_info.
    data.meta_info["micro_batch_size"] = self.config.actor.ppo_micro_batch_size_per_gpu
    minibatches = self.actor.make_minibatch_iterator(data=data)
    with Timer(name="update_policy", logger=None) as policy_timer:
        metrics = self.actor.update_policy(dataloader=minibatches)
    elapsed = policy_timer.last

    token_count = data.meta_info["global_token_num"]
    estimated_flops, promised_flops = self.flops_counter.estimate_flops(token_count, elapsed)
    metrics.update(
        {
            "perf/mfu/actor": estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size,
            "perf/max_memory_allocated_gb": get_torch_device().max_memory_allocated() / gib,
            "perf/max_memory_reserved_gb": get_torch_device().max_memory_reserved() / gib,
            "perf/cpu_memory_used_gb": psutil.virtual_memory().used / gib,
        }
    )

    from verl.utils.megatron.optimizer import get_megatron_last_lr

    metrics["actor/lr"] = get_megatron_last_lr(self.actor_optimizer)
    self.actor_optimizer_scheduler.step(1)

    # TODO: here, we should return all metrics
    output = DataProto(meta_info={"metrics": metrics}).to("cpu")

    # Release GPU memory again if offloading is enabled.
    if self._is_offload_param:
        offload_megatron_model_to_cpu(self.actor_module)
        log_gpu_memory_usage("After offload actor params and grad during update_actor", logger=logger)
    if self._is_offload_optimizer:
        offload_megatron_optimizer(self.actor_optimizer)
        log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)

    aggressive_empty_cache(force_sync=True)
    return output

make_minibatch_iterator 代码

MegatronPPOActor.make_minibatch_iterator
  • 传入data,调用data.make_iterator,关键参数
    • mini_batch_size=self.config.ppo_mini_batch_size
    • epochs=self.config.ppo_epochs

MegatronPPOActor.make_minibatch_iterator

python
def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
    """Make the mini-batch iterator used by ``update_policy``.

    Args:
      data (DataProto): a DataProto whose tensor batch contains
        ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where
        ``sequence_length = prompt_length + response_length``
        ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64
        ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64
        ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that
        responses = input_ids[:, -response_length:]
        ``response_mask``: mask over response tokens (presumably [batch_size, response_length]
        — TODO confirm).
        ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability
        of responses.
        ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of
        responses.
        See PPO paper for details. https://arxiv.org/abs/1707.06347
        ``ref_log_prob``: required only when ``config.use_kl_loss`` is set.
        ``rollout_is_weights``: optional; kept when present.
        If ``data.non_tensor_batch`` holds ``multi_modal_inputs``, it is kept as well.

    Returns:
        Iterable[DataProto]: the iterator from ``DataProto.make_iterator`` over the selected
        keys. It is built with ``mini_batch_size=config.ppo_mini_batch_size``,
        ``epochs=config.ppo_epochs``, ``seed=config.data_loader_seed`` and
        ``shuffle=config.shuffle``.

    Side effects:
        Sets ``self.has_multi_modal_inputs``.
    """
    select_keys = [
        "responses",
        "input_ids",
        "attention_mask",
        "response_mask",
        "position_ids",
        "old_log_probs",
        "advantages",
    ]
    if self.config.use_kl_loss:
        select_keys.append("ref_log_prob")
    # Include pre-computed IS weights if present in batch
    # Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
    if "rollout_is_weights" in data.batch.keys():
        select_keys.append("rollout_is_weights")
    self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
    if self.has_multi_modal_inputs:
        data = data.select(select_keys, ["multi_modal_inputs"])
    else:
        data = data.select(batch_keys=select_keys)
    return data.make_iterator(
        mini_batch_size=self.config.ppo_mini_batch_size,
        epochs=self.config.ppo_epochs,
        seed=self.config.data_loader_seed,
        dataloader_kwargs={"shuffle": self.config.shuffle},
    )

update_policy (MegatronPPOActor)

MegatronPPOActor.update_policy

MegatronPPOActor.update_policy
  • 当前dataloader信息,对batch根据下面2个参数构建dataloader,具体见上文
    • mini_batch_size = ppo_mini_batch_size
    • epochs= actor_rollout_ref.actor.ppo_epochs
  • 遍历dataloader获取当前data多次计算和更新
    • micro_batch_size=ppo_micro_batch_size_per_gpu
  • 单步更新过程
    • actor_optimizer.zero_grad() ,梯度归零
    • 调用forward_backward_batch计算各种metrics(含loss),具体见上文
    • actor_optimizer.step()更新actor
python
@GPUMemoryLogger(role="megatron actor", logger=logger)
def update_policy(self, dataloader: Iterable[DataProto]) -> dict:
    """Update the policy with an iterator of DataProto.

    For every mini-batch yielded by ``dataloader``: zero the gradients, run
    ``forward_backward_batch`` over its micro-batches, then take one optimizer step.

    Args:
        dataloader (Iterable[DataProto]): the iterator returned by ``make_minibatch_iterator``.
            The keys of each data batch are described in ``make_minibatch_iterator``.
            A ``meta_info["micro_batch_size"]`` on a mini-batch overrides
            ``config.ppo_micro_batch_size_per_gpu``.

    Returns:
        Dict: a dictionary containing the statistics, accumulated with ``append_to_dict``.
        Note that the statistics are only valid in the last pp stage
        and users have to combine the output in each dp rank manually.

    Raises:
        NotImplementedError: if the optimizer reports an unsuccessful step.
    """
    metrics = {}
    if self.use_torch_profiler and self.prof and self.prof.enable:
        self.prof.start()

    # These depend only on the config, so compute them once for all mini-batches.
    calculate_entropy = self.config.entropy_coeff != 0
    max_token_len = None
    if self.config.use_dynamic_bsz:
        max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size

    for mini_batch in dataloader:
        self.actor_optimizer.zero_grad()
        # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
        for chunk in self.actor_module:
            # if use distributed optimizer, zero grad buffer will be handled by optimizer
            chunk.zero_grad_buffer()

        micro_batch_size = mini_batch.meta_info.get("micro_batch_size", None)
        if micro_batch_size is None:
            micro_batch_size = self.config.ppo_micro_batch_size_per_gpu

        # Forward + backward over all micro-batches of this mini-batch.
        micro_batch_outputs = self.forward_backward_batch(
            mini_batch,
            calculate_entropy=calculate_entropy,
            use_dynamic_bsz=self.config.use_dynamic_bsz,
            micro_batch_size=micro_batch_size,
            max_token_len=max_token_len,
            mini_batch_size=self.config.ppo_mini_batch_size,
        )["output"]
        for micro_output in micro_batch_outputs:
            # Index 0 is the metrics dict (the actor loss_func returns [metrics, entropy]).
            append_to_dict(metrics, micro_output[0])  # append the metric from this micro-batch to global metrics.

        # Optimizer step on the accumulated gradients.
        update_successful, grad_norm, _num_zeros_in_grad = self.actor_optimizer.step()
        append_to_dict(metrics, {"actor/grad_norm": grad_norm})

        # On success, the all-gather already ran inside optimizer.step() (newer Megatron).
        if not update_successful:
            raise NotImplementedError
        if self.use_torch_profiler and self.prof and self.prof.enable:
            self.prof.step()

    if self.use_torch_profiler and self.prof and self.prof.enable:
        self.prof.stop_and_save()
        self.prof.stop_trace()
    # Release cached device memory once all updates are done.
    get_torch_device().empty_cache()
    return metrics

Load 相关

加载 Reward_manager

load_reward_manager

python
def load_reward_manager(
    config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any
) -> AbstractRewardManager:
    """
    Build the reward manager selected by ``config.reward_model.reward_manager``.

    Scoring-function resolution:
    - Use the user-defined function from ``get_custom_reward_fn(config)`` if it exists.
    - Otherwise, if ``reward_model.sandbox_fusion.url`` is set, use ``default_compute_score``
      bound to that sandbox.
    - Otherwise, use plain ``default_compute_score``.

    Args:
        config: PPO trainer configuration object containing reward_model fields.
        tokenizer: Tokenizer object used for processing text.
        num_examine: Number of samples to examine (only passed to AbstractRewardManager-style classes).
        **reward_kwargs: Additional keyword arguments for the reward manager.

    Returns:
        An instance of the selected reward manager class.
    """
    # A user-defined scoring function (custom_reward_fn) takes precedence.
    custom_score_fn = get_custom_reward_fn(config)

    # Pre-defined managers live in `verl/workers/reward_manager/`
    # (naive: NaiveRewardManager, prime: PrimeRewardManager, batch: BatchRewardManager,
    # dapo: DAPORewardManager). Note(haibin.lin): custom reward managers must be imported and
    # registered via `verl.workers.reward_manager.register`. Defaults to "naive".
    manager_name = config.reward_model.get("reward_manager", "naive")
    reward_manager_cls = get_reward_manager_cls(manager_name)

    if custom_score_fn is not None:
        final_compute_score = custom_score_fn
    else:
        sandbox_config = config.reward_model.get("sandbox_fusion")
        sandbox_url = sandbox_config.get("url") if sandbox_config else None
        memory_limit_mb = sandbox_config.get("memory_limit_mb", 1024) if sandbox_config else 1024
        if not sandbox_url:
            final_compute_score = default_compute_score
        else:
            # A Manager-backed semaphore can be shared across processes; it caps concurrent
            # sandbox requests at `max_concurrent` (default 64).
            sandbox_manager = multiprocessing.Manager()
            concurrency_limit = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64))
            final_compute_score = partial(
                default_compute_score,
                sandbox_fusion_url=sandbox_url,
                concurrent_semaphore=concurrency_limit,
                memory_limit_mb=memory_limit_mb,
            )

    # RewardLoopManagerBase subclasses (e.g. RateLimitedRewardLoopManager) take `config` and no
    # `num_examine`; traditional AbstractRewardManager subclasses (e.g. NaiveRewardManager) take
    # `num_examine` and `reward_fn_key`.
    is_loop_manager = RewardLoopManagerBase is not None and issubclass(reward_manager_cls, RewardLoopManagerBase)
    if not is_loop_manager:
        return reward_manager_cls(
            tokenizer=tokenizer,
            num_examine=num_examine,
            compute_score=final_compute_score,
            reward_fn_key=config.data.reward_fn_key,
            **reward_kwargs,
        )
    return reward_manager_cls(
        config=config,
        tokenizer=tokenizer,
        compute_score=final_compute_score,
        **reward_kwargs,
    )

NaiveRewardManager

python
from collections import defaultdict
from typing import Any

import torch

from verl import DataProto
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager import register
from verl.workers.reward_manager.abstract import AbstractRewardManager


@register("naive")
class NaiveRewardManager(AbstractRewardManager):
    """Rule-based reward manager that scores each response with ``compute_score``.

    For every sample, the valid prompt/response tokens are decoded, the scoring
    function is called, and the returned scalar is written onto the last valid
    response token of a token-level reward tensor.
    """

    def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None:
        """
        Initialize the NaiveRewardManager instance.

        Args:
            tokenizer: The tokenizer used to decode token IDs into text.
            num_examine: Maximum number of decoded samples printed per data source on each call,
                for debugging purposes.
            compute_score: A function to compute the reward score. If None, `default_compute_score` will be used.
            reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to
                "data_source".
        """
        self.tokenizer = tokenizer  # Used to decode token IDs for scoring and debug printing
        self.num_examine = num_examine  # Per-data-source print budget for each call
        self.compute_score = compute_score or default_compute_score
        self.reward_fn_key = reward_fn_key  # Key of the data source in non_tensor_batch

    def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
        """Compute token-level rewards for a batch.

        Args:
            data: Batch with ``prompts``, ``responses`` and ``attention_mask`` tensors, and
                ``reward_model`` / ``self.reward_fn_key`` entries in ``non_tensor_batch``.
            return_dict: If True, return ``{"reward_tensor": ..., "reward_extra_info": ...}``;
                otherwise return only the reward tensor.

        Returns:
            If ``data.batch`` already contains ``rm_scores``, those scores are returned unchanged
            (optionally with the extra info listed in ``meta_info["reward_extra_keys"]``).
            Otherwise a float32 tensor shaped like ``data.batch["responses"]``: zeros everywhere
            except each sample's score at its last valid response token.
        """
        # Scores were already produced upstream; pass them through untouched.
        if "rm_scores" in data.batch.keys():
            if return_dict:
                reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
                reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
                return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info}
            return data.batch["rm_scores"]

        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        reward_extra_info = defaultdict(list)

        # How many samples of each data source have been printed so far in this call.
        already_print_data_sources = defaultdict(int)

        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch["prompts"]
            prompt_length = prompt_ids.shape[-1]

            # The slicing below assumes prompts are left-padded (valid tokens at the end).
            # Use an explicit start index: with zero valid tokens, ``prompt_ids[-0:]`` would
            # return the whole padded prompt instead of an empty slice.
            valid_prompt_length = int(data_item.batch["attention_mask"][:prompt_length].sum())
            valid_prompt_ids = prompt_ids[prompt_length - valid_prompt_length :]

            # Responses are assumed right-padded (valid tokens at the start).
            response_ids = data_item.batch["responses"]
            valid_response_length = int(data_item.batch["attention_mask"][prompt_length:].sum())
            valid_response_ids = response_ids[:valid_response_length]

            prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)

            ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"]
            data_source = data_item.non_tensor_batch[self.reward_fn_key]
            # Work on a shallow copy: the stored dict belongs to the input batch (and may be
            # referenced by other samples), so adding per-sample fields in place would leak
            # them elsewhere. A missing or None entry is treated as empty.
            extra_info = dict(data_item.non_tensor_batch.get("extra_info") or {})
            extra_info["num_turns"] = data_item.non_tensor_batch.get("__num_turns__", None)
            extra_info["rollout_reward_scores"] = data_item.non_tensor_batch.get("reward_scores", {})

            # Call the reward function to compute the score.
            score = self.compute_score(
                data_source=data_source,
                solution_str=response_str,
                ground_truth=ground_truth,
                extra_info=extra_info,
            )

            if isinstance(score, dict):
                # A dict result can carry several metrics; "score" is the reward itself.
                reward = score["score"]
                # Keep every returned field (including the reward) as extra info.
                for key, value in score.items():
                    reward_extra_info[key].append(value)
            else:
                # A single scalar score.
                reward = score

            # The sequence-level reward is placed on the last valid response token.
            reward_tensor[i, valid_response_length - 1] = reward

            if already_print_data_sources[data_source] < self.num_examine:
                already_print_data_sources[data_source] += 1
                print("[prompt]", prompt_str)
                print("[response]", response_str)
                print("[ground_truth]", ground_truth)
                if isinstance(score, dict):
                    for key, value in score.items():
                        print(f"[{key}]", value)
                else:
                    print("[score]", score)

        if return_dict:
            return {
                "reward_tensor": reward_tensor,
                "reward_extra_info": reward_extra_info,
            }
        return reward_tensor

get_custom_reward_fn

自定义的reward_fn.py

python
def get_custom_reward_fn(config: DictConfig) -> Optional[RawRewardFn]:
    """Load and return a custom reward function from an external file.

    Dynamically imports a reward function from a specified file path and wraps
    it with additional keyword arguments from the configuration. The loaded
    module is cached in ``sys.modules`` under the fixed name ``"custom_module"``,
    so the file is executed at most once per process.

    Args:
        config: Configuration (DictConfig or dict) containing a ``custom_reward_function``
            section with ``path``, ``name`` and optional ``reward_kwargs`` fields.

    Returns:
        callable or None: The function wrapped (via ``partial``) with the configured
        ``reward_kwargs``, using the async wrapper for coroutine functions; or None if
        no custom reward function path is configured.

    Raises:
        FileNotFoundError: If the specified reward function file doesn't exist.
        RuntimeError: If there's an error loading the module from file.
        AttributeError: If the specified function name isn't found in the module.
    """
    reward_fn_config = config.get("custom_reward_function") or {}
    file_path = reward_fn_config.get("path")
    if not file_path:
        return None

    function_name = reward_fn_config.get("name")
    assert function_name is not None

    module = sys.modules.get("custom_module", None)
    if module is None:
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Reward function file '{file_path}' not found.")

        spec = importlib.util.spec_from_file_location("custom_module", file_path)
        assert spec is not None
        module = importlib.util.module_from_spec(spec)
        # Register before executing, as in the standard importlib "import a source file" recipe.
        sys.modules["custom_module"] = module
        try:
            assert spec.loader is not None
            spec.loader.exec_module(module)
        except Exception as e:
            # Never leave a half-initialised module cached: a later call would find it in
            # sys.modules, skip loading, and fail with a misleading AttributeError.
            sys.modules.pop("custom_module", None)
            raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e

    if not hasattr(module, function_name):
        raise AttributeError(f"Reward function '{function_name}' not found in '{module.__file__}'.")

    print(f"using customized reward function '{function_name}' from '{module.__file__}'")
    raw_fn = getattr(module, function_name)

    # ``reward_kwargs`` may be absent or explicitly null in the config.
    reward_kwargs = dict(reward_fn_config.get("reward_kwargs") or {})

    if inspect.iscoroutinefunction(raw_fn):
        return partial(_call_with_kwargs_async, raw_fn, reward_kwargs)
    return partial(_call_with_kwargs, raw_fn, reward_kwargs)

default_compute_score

verl 内置的默认 reward 计算函数(按 data_source 分发到不同的打分模块)

python
def default_compute_score(
    data_source,
    solution_str,
    ground_truth,
    extra_info=None,
    sandbox_fusion_url=None,
    concurrent_semaphore=None,
    memory_limit_mb=None,
    **kwargs,
):
    """Dispatch to the built-in scorer matching ``data_source`` and normalise its result.

    Args:
        data_source (str): Dataset identifier; selects which scoring module is used.
        solution_str (str): The response text to score.
        ground_truth (str): Reference answer; for code datasets it is forwarded unchanged
            to the code scorer.
        extra_info (dict, optional): Accepted for interface compatibility; none of the
            built-in scorers dispatched here receive it.
        sandbox_fusion_url (str, optional): When set, code datasets are scored through
            ``sandbox_fusion``; otherwise ``prime_code`` is used.
        concurrent_semaphore (optional): Forwarded to ``sandbox_fusion.compute_score``.
        memory_limit_mb (int, optional): Forwarded to ``sandbox_fusion.compute_score``.
        **kwargs: Ignored.

    Returns:
        dict | float: A dict result is returned as-is. An int/float/bool result is converted
        to float. Any other result is indexed and its first element converted to float.

    Raises:
        NotImplementedError: If no built-in scorer handles ``data_source``.
    """
    math_lighteval_sources = ("lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval", "HuggingFaceH4/MATH-500")
    math_dapo_sources = ("math_dapo", "math", "math_dapo_reasoning")
    numina_sources = (
        "numina_aops_forum",
        "numina_synthetic_math",
        "numina_amc_aime",
        "numina_synthetic_amc",
        "numina_cn_k12",
        "numina_olympiads",
    )
    code_sources = ("codecontests", "apps", "codeforces", "taco")
    geometry_sources = ("hiyouga/geometry3k",)
    search_r1_sources = (
        "searchR1_nq",
        "searchR1_triviaqa",
        "searchR1_popqa",
        "searchR1_hotpotqa",
        "searchR1_2wikimultihopqa",
        "searchR1_musique",
        "searchR1_bamboogle",
    )

    if data_source == "openai/gsm8k":
        from . import gsm8k

        outcome = gsm8k.compute_score(solution_str, ground_truth)
    elif data_source in math_lighteval_sources:
        from . import math_reward

        outcome = math_reward.compute_score(solution_str, ground_truth)
        # Math-Verify (https://github.com/huggingface/Math-Verify, `pip install math-verify`)
        # can give more accurate results; to use it, swap the call above for
        # `from . import math_verify` + `math_verify.compute_score(solution_str, ground_truth)`.
    elif data_source in math_dapo_sources or data_source.startswith("aime"):
        from . import math_dapo

        outcome = math_dapo.compute_score(solution_str, ground_truth)
    elif data_source in numina_sources:
        from . import prime_math

        outcome = prime_math.compute_score(solution_str, ground_truth)
    elif data_source in code_sources:
        if not sandbox_fusion_url:
            # No sandbox configured: score locally with prime_code (no URL needed).
            from . import prime_code

            outcome = prime_code.compute_score(solution_str, ground_truth, continuous=True)
        else:
            # ground_truth is expected to hold the test cases run inside the sandbox.
            from . import sandbox_fusion

            outcome = sandbox_fusion.compute_score(
                sandbox_fusion_url, concurrent_semaphore, memory_limit_mb, solution_str, ground_truth, continuous=True
            )
    elif data_source in geometry_sources:
        from . import geo3k

        outcome = geo3k.compute_score(solution_str, ground_truth)
    elif data_source in search_r1_sources:
        from . import search_r1_like_qa_em

        outcome = search_r1_like_qa_em.compute_score(solution_str, ground_truth)
    else:
        raise NotImplementedError(f"Reward function is not implemented for {data_source=}")

    if isinstance(outcome, dict):
        return outcome
    if isinstance(outcome, (int, float, bool)):
        return float(outcome)
    return float(outcome[0])

初始化Workers-主流程 (RayPPOTrainer)

python
def init_workers(self):
    """Initialize distributed training workers using the Ray backend.

    Steps, in order:
    1. Create the Ray resource pools from configuration.
    2. Register a ``RayClassWithInitArgs`` for every enabled role (actor/rollout, critic,
       reference policy, reward model) under its resource pool.
    3. For each resource pool, build one colocated worker class and spawn the per-role
       worker groups from it.
    4. Call ``init_model`` on each enabled worker group, with the actor/rollout group last.
    5. If ``actor_rollout_ref.rollout.mode == "async"``, create an ``AgentLoopManager`` on top
       of the actor/rollout worker group.

    Raises:
        NotImplementedError: If ``self.hybrid_engine`` is False.
    """
    self.resource_pool_manager.create_resource_pool()

    # resource pool -> {role name: RayClassWithInitArgs}
    self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}

    # create actor and rollout (only the hybrid-engine layout is supported here)
    if self.hybrid_engine:
        resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
        actor_rollout_cls = RayClassWithInitArgs(
            cls=self.role_worker_mapping[Role.ActorRollout],
            config=self.config.actor_rollout_ref,
            role=str(Role.ActorRollout),
        )
        self.resource_pool_to_cls[resource_pool][str(Role.ActorRollout)] = actor_rollout_cls
    else:
        raise NotImplementedError

    # create critic
    if self.use_critic:
        resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
        # The critic receives a dataclass config, unlike the other roles which get the raw config.
        critic_cfg = omega_conf_to_dataclass(self.config.critic)
        critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
        self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls

    # create reference policy if needed
    if self.use_reference_policy:
        resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
        ref_policy_cls = RayClassWithInitArgs(
            self.role_worker_mapping[Role.RefPolicy],
            config=self.config.actor_rollout_ref,
            role=str(Role.RefPolicy),
        )
        self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls

    # create a reward model worker when use_rm is enabled
    if self.use_rm:
        # we create a RM here
        resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
        rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
        self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls

    # initialize WorkerGroup
    # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
    # you should not use `create_colocated_worker_cls`.
    # Instead, directly pass different resource pool to different worker groups.
    # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
    all_wg = {}
    wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup
    if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
        wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
    if OmegaConf.select(self.config.global_profiler, "steps") is not None:
        wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
        # Only require nsight worker options when tool is nsys
        if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
            assert (
                OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
                is not None
            ), "worker_nsight_options must be set when using nsys with profile_steps"
            wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
                OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
            )
    wg_kwargs["device_name"] = self.device_name

    # One colocated worker group per resource pool. spawn() splits it into per-role groups,
    # which are looked up below by role name.
    for resource_pool, class_dict in self.resource_pool_to_cls.items():
        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
        wg_dict = self.ray_worker_group_cls(
            resource_pool=resource_pool,
            ray_cls_with_init=worker_dict_cls,
            **wg_kwargs,
        )
        spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
        all_wg.update(spawn_wg)

    if self.use_critic:
        self.critic_wg = all_wg[str(Role.Critic)]
        self.critic_wg.init_model()

    # NOTE(review): the RefPolicy class is registered above whenever use_reference_policy is set,
    # but it is only initialised here when the ref is not folded into the actor (ref_in_actor).
    if self.use_reference_policy and not self.ref_in_actor:
        self.ref_policy_wg = all_wg[str(Role.RefPolicy)]
        self.ref_policy_wg.init_model()

    self.rm_wg = None
    # initialization of rm_wg will be deprecated in the future
    if self.use_rm:
        self.rm_wg = all_wg[str(Role.RewardModel)]
        self.rm_wg.init_model()

    # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
    self.actor_rollout_wg = all_wg[str(Role.ActorRollout)]
    self.actor_rollout_wg.init_model()

    # create async rollout manager and request scheduler
    self.async_rollout_mode = False
    if self.config.actor_rollout_ref.rollout.mode == "async":
        from verl.experimental.agent_loop import AgentLoopManager

        self.async_rollout_mode = True
        self.async_rollout_manager = AgentLoopManager(
            config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
        )

初始化 Megatron Workers

MegatronActorRolloutRefWorker 构造函数

python
class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
    """Megatron worker acting as actor, rollout, reference policy, or a hybrid of these.

    Which components are enabled is decided by the ``role`` passed to ``__init__``.
    """

    def __init__(self, config: DictConfig, role: str, **kwargs):
        """Set up process groups, profiling and per-role config normalisation.

        Args:
            config: The ``actor_rollout_ref`` config; its ``actor``, ``rollout`` and ``ref``
                sub-configs are read (and partly rewritten in place) here.
            role: One of "actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref".
            **kwargs: Accepted for interface compatibility; unused here.

        When actor and rollout are colocated:
            * ``actor.ppo_mini_batch_size`` is multiplied by ``rollout.n`` and divided by the
              data-parallel world size.
            * ``actor.ppo_micro_batch_size`` and ``rollout.log_prob_micro_batch_size`` are each
              divided by the data-parallel world size when set, and mirrored into their
              ``*_per_gpu`` fields.
        """
        Worker.__init__(self)
        self.config = config
        if repatch is not None:
            # NPU MindSpeed patch, will be refactored with MindSpeedEngine.
            repatch(self.config.actor.megatron.get("override_transformer_config", {}))

        # NOTE(sgm): We utilize colocate WorkerGroup by default.
        # As a result, Workers for different model share the same process.
        # Therefore, we only require one distribute initialization.
        # To utilize different parallel strategy in different models:
        # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models,
        # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385
        if not torch.distributed.is_initialized():
            set_numa_affinity()
            local_rank = int(os.environ["LOCAL_RANK"])
            torch.distributed.init_process_group(
                backend=get_nccl_backend(),
                timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)),
                init_method=os.environ.get("DIST_INIT_METHOD", None),
            )
            get_torch_device().set_device(local_rank)

            megatron_cfg = self.config.actor.megatron
            mpu.initialize_model_parallel(
                tensor_model_parallel_size=megatron_cfg.tensor_model_parallel_size,
                pipeline_model_parallel_size=megatron_cfg.pipeline_model_parallel_size,
                virtual_pipeline_model_parallel_size=megatron_cfg.virtual_pipeline_model_parallel_size,
                use_sharp=False,
                context_parallel_size=megatron_cfg.context_parallel_size,
                expert_model_parallel_size=megatron_cfg.expert_model_parallel_size,
                expert_tensor_parallel_size=megatron_cfg.expert_tensor_parallel_size,
                nccl_communicator_config_path=None,
            )

        # Only one rank per DP group (TP rank 0, last PP stage, CP rank 0) collects outputs.
        is_collect = (
            mpu.get_tensor_model_parallel_rank() == 0
            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
            and mpu.get_context_parallel_rank() == 0
        )
        self._register_dispatch_collect_info(
            mesh_name="actor", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
        )

        set_random_seed(seed=self.config.actor.megatron.seed)

        self.role = role
        assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]

        self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
        self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
        self._is_ref = self.role in ["ref", "actor_rollout_ref"]

        if self._is_actor:
            omega_profiler_config = config.actor.get("profiler", {})
        elif self._is_rollout:
            # NOTE: In colocation mode, rollout config may not take effect (follow the actor config)
            # This is for extendability in AsyncRL cases
            omega_profiler_config = config.rollout.get("profiler", {})
        elif self._is_ref:
            omega_profiler_config = config.ref.get("profiler", {})
        else:
            raise ValueError(
                f"Invalid role {self.role}, should be one of "
                "['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']"
            )
        # omega_profiler_config is DictConfig
        # profiler_config is a ProfilerConfig dataclass
        profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
        if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
            tool_config = omega_conf_to_dataclass(
                omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
            )
        else:
            tool_config = None
        DistProfilerExtension.__init__(
            self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)
        )

        # TODO(sgm): Currently, we only support reference model param offload
        # will support other offload later
        self._is_offload_param = False
        self._is_offload_grad = False
        self._is_offload_optimizer = False
        # Default for every role: init_model reads this flag whenever the ref model is built,
        # including the "actor_rollout_ref" role where the ref branch below is not taken.
        self._ref_is_offload_param = False

        # normalize config: convert global batch sizes into per-DP-rank sizes
        if self._is_actor and self._is_rollout:
            dp_size = mpu.get_data_parallel_world_size()
            self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
            self.config.actor.ppo_mini_batch_size //= dp_size
            if self.config.actor.get("ppo_micro_batch_size", None):
                self.config.actor.ppo_micro_batch_size //= dp_size
                self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size
            # Handled independently of the actor micro batch so that either setting can be left
            # unset (dividing an unset value would raise TypeError).
            if self.config.rollout.get("log_prob_micro_batch_size", None):
                self.config.rollout.log_prob_micro_batch_size //= dp_size
                self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size

            self._is_offload_param = self.config.actor.megatron.get("param_offload", False)
            self._is_offload_grad = self.config.actor.megatron.get("grad_offload", False)
            self._is_offload_optimizer = self.config.actor.megatron.get("optimizer_offload", False)
        elif self._is_ref:
            if self.config.ref.get("log_prob_micro_batch_size", None):
                self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
                self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size
            else:
                assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, (
                    "Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and "
                    "`log_prob_micro_batch_size` should not be None at the same time."
                )
            self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False)

MegatronActorRolloutRefWorker.init_model(初始化模型)

python
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    """Build the models enabled by ``self.role`` and their wrappers.

    * actor: builds the Megatron model + optimizer (optionally offloaded to CPU), wraps it in
      ``MegatronPPOActor``, and creates the FLOPs counter, checkpoint manager and weight
      converter.
    * rollout: builds the rollout engine via ``self._build_rollout``.
    * ref: builds a model without an optimizer and wraps it in a ``MegatronPPOActor``
      (optionally offloading its parameters).
    """
    if self.config.model.get("external_lib", None) is not None:
        # This is used to import external_lib into the huggingface systems
        import importlib

        importlib.import_module(self.config.model.external_lib)

    from verl.utils.torch_dtypes import PrecisionType

    override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
    # Transformer overrides come from the actor config when this worker is an actor,
    # otherwise from the ref config. DDP overrides are only needed for the trainable actor.
    if self._is_actor:
        override_transformer_config = OmegaConf.to_container(
            OmegaConf.create(self.config.actor.megatron.get("override_transformer_config", {}))
        )
        override_ddp_config = OmegaConf.to_container(
            OmegaConf.create(self.config.actor.megatron.get("override_ddp_config", {}))
        )
    elif self._is_ref:
        override_transformer_config = OmegaConf.to_container(
            OmegaConf.create(self.config.ref.megatron.get("override_transformer_config", {}))
        )
    else:
        override_transformer_config = {}
    self.param_dtype = PrecisionType.to_dtype(self.config.actor.megatron.dtype)
    log_gpu_memory_usage("Before init actor model and optimizer", logger=logger)
    self.dtype = PrecisionType.to_dtype(self.param_dtype)

    if self._is_actor:
        # we need the model for actor and rollout
        (
            self.actor_module,
            self.actor_optimizer,
            self.actor_optimizer_scheduler,
            self.actor_model_config,
            self.actor_optim_config,
        ) = self._build_model_optimizer(
            model_path=self.config.model.path,
            optim_config=self.config.actor.optim,
            override_model_config=override_model_config,
            override_transformer_config=override_transformer_config,
            override_ddp_config=override_ddp_config,
        )
        if self._is_offload_param:
            offload_megatron_model_to_cpu(self.actor_module)
            log_gpu_memory_usage("After offload actor params and grad during init", logger=logger)
        if self._is_offload_optimizer:
            offload_megatron_optimizer(self.actor_optimizer)
            log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

        actor_cfg = omega_conf_to_dataclass(self.config.actor)
        self.actor = MegatronPPOActor(
            config=actor_cfg,
            model_config=self.actor_model_config,
            hf_config=self.hf_config,
            tf_config=self.tf_config,
            actor_module=self.actor_module,
            actor_optimizer=self.actor_optimizer,
        )
        log_gpu_memory_usage("After MegatronPPOActor init", logger=logger)

    if self._is_rollout:
        self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False))
        log_gpu_memory_usage("After rollout init", logger=logger)

    if self._is_ref:
        # No optimizer for the reference model: it is only used to compute log-probs.
        self.ref_module, self.ref_model_config = self._build_model_optimizer(
            model_path=self.config.model.path,
            optim_config=None,
            override_model_config=override_model_config,
            override_transformer_config=override_transformer_config,
        )
        log_gpu_memory_usage("After ref model init", logger=logger)
        self.ref_policy = MegatronPPOActor(
            config=self.config.ref,
            model_config=self.ref_model_config,
            hf_config=self.hf_config,
            tf_config=self.tf_config,
            actor_module=self.ref_module,
            actor_optimizer=None,
        )
        # __init__ may only set this flag for the standalone "ref" role; default to False so the
        # combined "actor_rollout_ref" role does not raise AttributeError here.
        if getattr(self, "_ref_is_offload_param", False):
            offload_megatron_model_to_cpu(self.ref_module)
            log_gpu_memory_usage("After offload ref params during init", logger=logger)

    if self._is_actor:
        self.flops_counter = FlopsCounter(self.actor_model_config)
        # NOTE: attribute name keeps the historical "mananager" spelling used by other methods.
        self.checkpoint_mananager = MegatronCheckpointManager(
            config=self.config,
            checkpoint_config=self.config.actor.checkpoint,
            model_config=self.actor_model_config,
            transformer_config=self.tf_config,
            role="actor",
            model=self.actor_module,
            arch=self.architectures[0],
            hf_config=self.hf_config,
            param_dtype=self.param_dtype,
            share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
            processing_class=self.processor if self.processor is not None else self.tokenizer,
            optimizer=self.actor_optimizer,
            optimizer_scheduler=self.actor_optimizer_scheduler,
            use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,
            use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler,
            bridge=self.bridge,
            use_dist_checkpointing=self.config.actor.megatron.use_dist_checkpointing,
        )

        self.layer_name_mapping = {
            "qkv_layer_name": "self_attention.linear_qkv.",
            "gate_proj_layer_name": "linear_fc1.",
        }
        # A weight converter is only needed when mbridge is not used.
        self.weight_converter = None
        if not self.config.actor.megatron.use_mbridge:
            self.weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)

    get_torch_device().empty_cache()
    log_gpu_memory_usage("After init_model finish", logger=logger)

MegatronPPOActor 初始化

python
class MegatronPPOActor(BasePPOActor):
    def __init__(
        self,
        config,
        model_config,
        hf_config,
        tf_config,
        actor_module: nn.ModuleList,
        actor_optimizer: DistributedOptimizer | None,
    ):
        """MegatronPPOActor class. This class implements the simple PPO logics when the model is built with Megatron.

        Args:
            config (OmegaConf): the basic config that contains the hyper-parameters of PPO Actor. It must contain

                ``ppo_micro_batch_size_per_gpu``: micro batch size when updating ppo.

                ``ppo_mini_batch_size``: minibatch size when updating ppo using the batch data.

                ``ppo_epochs``: number of epochs to update the actor using the batch data.

                ``shuffle``: whether to shuffle the data after each ppo epoch.

                ``clip_ratio``: clip ratio of the ppo algorithm. See https://arxiv.org/abs/1707.06347.

                ``entropy_coeff``: entropy coefficient of the PPO loss. See https://arxiv.org/abs/1707.06347.

                It is also read for ``profiler`` and the optional ``use_fused_kernels`` /
                ``overlap_moe_expert_parallel_comm`` settings.
            model_config (OmegaConf): model configuration. It must contain ``model_config.vocab_size`` and
                ``model_config.hidden_size``
            hf_config (PretrainedConfig): huggingface config
            tf_config (TransformerConfig): mcore transformer config
            actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this
                pp stage.
                each nn.Module in this rank holds a vpp module chunk. See https://arxiv.org/pdf/2104.04473.pdf for
                more details.
                The actor module has some constraints to follow in order to use the updating logics implemented here

                1. It must implement unpad_input before any computation and pad_input after all the computation.
                Remove padding is an
                optimization that removes the padding tokens. See unpad_input and pad_input function in flash-attn
                (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py).

                2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size],
                where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size
                of the hidden state is [total_nnz // tp, 1, hidden_size].
            actor_optimizer (DistributedOptimizer or None): currently, we only support DistributedOptimizer in
                Megatron. It implements zero1 optimizer that shards the optimizer state across dp ranks.
                Pass None for a model that is never updated (e.g. the reference policy).

        >>> from megatron.training import get_model
        >>> from megatron.optimizer import get_megatron_optimizer
        >>> actor_module = get_model(megatron_actor_model_provider, wrap_with_ddp=True)
        >>> actor_module = nn.ModuleList(actor_module)
        >>> actor_optimizer = get_megatron_optimizer(actor_module)
        >>> actor = MegatronPPOActor(config=config,
        ...                          model_config=actor_model_config,
        ...                          hf_config=hf_config,
        ...                          tf_config=tf_config,
        ...                          actor_module=actor_module,
        ...                          actor_optimizer=actor_optimizer)
        """
        super().__init__(config)
        self._validate_config(config)
        self.model_config = model_config
        self.hf_config = hf_config
        self.tf_config = tf_config
        self.actor_module = actor_module
        self.actor_optimizer: DistributedOptimizer = actor_optimizer
        self.use_torch_profiler = self.config.profiler.get("tool") == "torch"
        if self.use_torch_profiler:
            self.prof = Profiler(
                self.config.profiler, tool_config=self.config.profiler.get("tool_config", {}).get("torch", {})
            )
        else:
            self.prof = None
        self.use_fused_kernels = self.config.get("use_fused_kernels", False)
        if self.use_fused_kernels and not getattr(self.config, "overlap_moe_expert_parallel_comm", False):
            # do not patch if overlap_moe_expert_parallel_comm is enabled
            from verl.models.mcore.model_forward_fused import patch_fused_forward

            for model in self.actor_module:
                patch_fused_forward(model)

        self.optimizer_step_args = OmegaConf.create(
            {
                "skip_grad": None,
                "overlap_dp_param_comm": False,
                "overlap_dp_grad_comm": False,
                "gradient_accumulation_steps": 1,
                "sequence_parallel": self.tf_config.sequence_parallel,
                "DDP_impl": "local",
                "layernorm_allreduce_bucket_threshold": 0,
                "reduce_grads_use_alltoall": False,
            }
        )

        # Use a distinct name so the ``config`` argument (the PPO config) is not shadowed.
        mcore_model_config = get_model_config(self.actor_module[0])
        print(mcore_model_config)
        mcore_model_config.finalize_model_grads_func = finalize_model_grads

其中,BasePPOActor 接口如下,核心是 compute_log_prob 和 update_policy 两个抽象方法。

python
class BasePPOActor(ABC):
    """Abstract interface implemented by every PPO actor backend.

    Concrete actors must implement two operations:

    * ``compute_log_prob`` -- forward-only scoring of a batch.
    * ``update_policy`` -- one round of policy optimisation on a batch.
    """

    def __init__(self, config):
        """Store the actor configuration.

        Args:
            config (DictConfig): a config passed to the PPOActor. A DictConfig
                (https://omegaconf.readthedocs.io/) is expected, but any object that
                exposes the required fields (e.g. a namedtuple) works.
        """
        super().__init__()
        self.config = config

    @abstractmethod
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute log probabilities for a batch of data.

        Args:
            data (DataProto): a batch of data represented by DataProto. It must
                contain the keys ``input_ids``, ``attention_mask`` and ``position_ids``.

        Returns:
            The computed log probabilities; the signature declares a ``torch.Tensor``.
            NOTE(review): an earlier version of this docstring described a DataProto
            containing the key ``log_probs`` -- confirm the exact return type against
            the concrete subclasses.
        """
        pass

    @abstractmethod
    def update_policy(self, data: DataProto) -> dict:
        """Update the policy using a batch of data.

        Args:
            data (DataProto): the training batch. How it is split into mini-batches
                is left to the implementation.

        Returns:
            dict: arbitrary entries, typically statistics gathered during the update,
            such as ``loss``, ``grad_norm``, etc.
        """
        pass

加载Megatron TransformerConfig

初始化 HF config 和 TF config (MegatronWorker)

python
class MegatronWorker(Worker):
    """Worker class (excerpt): only the helper that sets up the tokenizer/processor,
    the HuggingFace model config and the Megatron TransformerConfig is shown here."""

    def _init_hf_config_and_tf_config(
        self,
        model_path,
        tokenizer_or_path,
        dtype,
        override_model_config,
        override_transformer_config,
        trust_remote_code=False,
        use_mbridge=False,
    ):
        """Load the tokenizer/processor and build the HF and Megatron configs.

        Returns nothing. It sets these attributes on ``self``: ``local_path``,
        ``tokenizer``, ``processor``, ``share_embeddings_and_output_weights``,
        ``architectures``, ``bridge``, ``hf_config`` and ``tf_config``.

        Args:
            model_path: model location; passed through ``copy_to_local`` to obtain
                ``self.local_path``, from which the HF config is loaded.
            tokenizer_or_path: selects where the tokenizer/processor come from:
                ``None`` loads both from ``model_path``; a ``str`` loads both from
                that path; any other object is used as both tokenizer and processor.
            dtype: forwarded to ``hf_to_mcore_config`` (not used on the mbridge path).
            override_model_config: mapping whose optional ``"model_config"`` entry
                overrides fields of the HF config (applied after the token-id fields).
            override_transformer_config: extra TransformerConfig keyword arguments,
                first normalized by ``mapping_string_to_attn_backend``.
            trust_remote_code: forwarded to tokenizer, processor and AutoConfig loading.
            use_mbridge: if True, derive the TransformerConfig through ``AutoBridge``
                instead of ``hf_to_mcore_config``.
        """
        from transformers import AutoConfig

        from verl.models.mcore import hf_to_mcore_config
        from verl.utils import hf_processor, hf_tokenizer
        from verl.utils.fs import copy_to_local
        from verl.utils.model import update_model_config

        # Step 1: initialize the tokenizer (and the optional multimodal processor)
        self.local_path = copy_to_local(model_path)
        if tokenizer_or_path is None:
            self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code)
            self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code)
        elif isinstance(tokenizer_or_path, str):
            # NOTE(review): copy_to_local(tokenizer_or_path) is evaluated twice here;
            # presumably cheap when the path is already local/cached -- confirm.
            self.tokenizer = hf_tokenizer(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code)
            self.processor = hf_processor(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code)
        else:
            # A pre-built object is used both as tokenizer and as processor.
            self.tokenizer = tokenizer_or_path
            self.processor = tokenizer_or_path

        # A custom chat template is installed on the processor when one exists,
        # otherwise on the tokenizer.
        if self.config.model.get("custom_chat_template", None) is not None:
            if self.processor is not None:
                self.processor.chat_template = self.config.model.custom_chat_template
            else:
                self.tokenizer.chat_template = self.config.model.custom_chat_template

        # Step 2: load the HF model config from the local model path
        hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code)  

        # Step 3: override the hf config (token ids from the tokenizer first, then
        # any user-supplied "model_config" entries, which win on conflicts)
        override_config_kwargs = {
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_model_config.get("model_config", {}))
        # NOTE(review): read before update_model_config() below, so a
        # tie_word_embeddings override in "model_config" is not reflected here -- confirm intended.
        self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False)
        update_model_config(hf_config, override_config_kwargs=override_config_kwargs)
        self.architectures = getattr(hf_config, "architectures", None)
        if self.rank == 0:
            print(f"Model config after override: {hf_config}")

        from verl.models.mcore.config_converter import mapping_string_to_attn_backend

        # todo: remove this line after mcore adopt mbridge 0.15, now for compatibility
        override_transformer_config = mapping_string_to_attn_backend(override_transformer_config)

        # Step 4: build the Megatron TransformerConfig, either via mbridge or via
        # verl's own HF -> mcore converter registry.
        if use_mbridge:
            from verl.models.mcore.mbridge import AutoBridge

            bridge = AutoBridge.from_config(hf_config)
            bridge.set_extra_args(**override_transformer_config)
            tf_config = bridge.config
            self.bridge = bridge
        else:
            tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config) 
            self.bridge = None

        # NOTE(review): unlike the HF config print above, this one is not gated on rank 0.
        print(f"TF config: {tf_config}")
        self.hf_config = hf_config
        self.tf_config = tf_config

转换成TransformerConfig

小心
  • 本质是把HFConfig转换成Megatron的TransformerConfig

支持的模型

python
# Registry mapping each supported model family to the function that converts its
# HuggingFace PretrainedConfig into a Megatron TransformerConfig.
# Built from an ordered list of (model, converter) pairs; the key order is preserved.
MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = dict(
    [
        (SupportedModel.LLAMA, hf_to_mcore_config_dense),
        (SupportedModel.QWEN2, hf_to_mcore_config_dense),
        (SupportedModel.QWEN2_MOE, hf_to_mcore_config_qwen2moe),
        (SupportedModel.DEEPSEEK_V3, hf_to_mcore_config_dpskv3),
        (SupportedModel.MIXTRAL, hf_to_mcore_config_mixtral),
        (SupportedModel.QWEN2_5_VL, hf_to_mcore_config_qwen2_5_vl),
        (SupportedModel.LLAMA4, hf_to_mcore_config_llama4),
        (SupportedModel.QWEN3, hf_to_mcore_config_dense),
        (SupportedModel.QWEN3_MOE, hf_to_mcore_config_qwen3moe),
        (SupportedModel.QWEN3_TOKEN_CLASSIFICATION, hf_to_mcore_config_dense),
    ]
)

Dense示例

python
def hf_to_mcore_config_dense(
    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
) -> TransformerConfig:
    """Convert the HF config of a dense decoder (e.g. LlamaForCausalLM, Qwen2ForCausalLM).

    Architecture-specific choices:

    * Qwen2 architectures always use a QKV bias. Other architectures follow the HF
      ``attention_bias`` attribute, defaulting to False.
    * Qwen3 architectures enable QK layernorm.
    * All other linear layers are built without bias.

    Args:
        hf_config: HuggingFace model configuration.
        dtype: dtype used for the model parameters and pipeline communication.
        **override_transformer_config_kwargs: TransformerConfig fields that take
            precedence over every value derived here.

    Returns:
        The TransformerConfig produced by ``check_and_construct_configs``.
    """
    # Architecture name such as "Qwen2ForCausalLM". A config without an
    # ``architectures`` list is treated as a generic (non-Qwen) dense model rather
    # than failing with an opaque TypeError/IndexError.
    architectures = getattr(hf_config, "architectures", None) or [""]
    arch_name = architectures[0] or ""

    qkv_bias = "Qwen2" in arch_name or getattr(hf_config, "attention_bias", False)
    qk_layernorm = "Qwen3" in arch_name

    args: dict = _get_base_transformer_config(
        hf_config=hf_config,
        dtype=dtype,
        use_cpu_initialization=False,
        add_bias_linear=False,
        add_qkv_bias=qkv_bias,
        qk_layernorm=qk_layernorm,
    )
    # User overrides are applied last so they win over the derived values.
    # override_transformer_config_kwargs as kwargs shall never be none
    args.update(override_transformer_config_kwargs)
    return check_and_construct_configs(args, TransformerConfig)

_get_base_transformer_config

python
def _get_base_transformer_config(
    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
) -> dict:
    """
    Build the TransformerConfig keyword arguments shared across model architectures.
    TODO: (ycl) use dataclass or converter config?

    Model-architecture fields are read from ``hf_config``. Parallel sizes are read
    from the Megatron parallel state (``mpu``), which must already be initialized.
    Architecture-specific converters pass their own settings through
    ``override_transformer_config_kwargs``.

    Args:
        hf_config: HuggingFace model configuration
        dtype: Data type for the model; sets ``pipeline_dtype`` and ``params_dtype``,
            and ``bf16`` is True exactly when it is ``torch.bfloat16``
        override_transformer_config_kwargs: Additional parameters to override defaults

    Returns:
        dict of TransformerConfig keyword arguments (not a TransformerConfig instance;
        callers construct the config from it, e.g. via ``check_and_construct_configs``)
    """
    # Query parallel sizes that are used more than once only a single time.
    tp_size = mpu.get_tensor_model_parallel_world_size()
    vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()

    # Common parallel state parameters
    # P2P communication overlap is only enabled with an interleaved (virtual) pipeline.
    overlap_p2p_comm = vpp_size is not None and vpp_size > 1
    batch_p2p_comm = False

    # Base configuration with common parameters
    base_config = {
        # Model architecture parameters
        "num_layers": hf_config.num_hidden_layers,
        "hidden_size": hf_config.hidden_size,
        "num_attention_heads": hf_config.num_attention_heads,
        "num_query_groups": hf_config.num_key_value_heads,
        "ffn_hidden_size": hf_config.intermediate_size,
        "attention_dropout": hf_config.attention_dropout,
        "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0),
        # None lets TransformerConfig derive kv_channels itself.
        "kv_channels": getattr(hf_config, "head_dim", None),
        "layernorm_epsilon": hf_config.rms_norm_eps,
        "add_bias_linear": True,
        # Activation and normalization
        "activation_func": F.silu,
        "normalization": "RMSNorm",
        "gated_linear_unit": True,
        # Data types
        "pipeline_dtype": dtype,
        "params_dtype": dtype,
        "bf16": dtype is torch.bfloat16,
        # Parallel configuration
        "tensor_model_parallel_size": tp_size,
        "pipeline_model_parallel_size": mpu.get_pipeline_model_parallel_world_size(),
        "expert_model_parallel_size": mpu.get_expert_model_parallel_world_size(),
        "expert_tensor_parallel_size": mpu.get_expert_tensor_parallel_world_size(),
        "virtual_pipeline_model_parallel_size": vpp_size,
        "context_parallel_size": mpu.get_context_parallel_world_size(),
        "overlap_p2p_comm": overlap_p2p_comm,
        "batch_p2p_comm": batch_p2p_comm,
        # Sequence parallelism is switched on whenever tensor parallelism is used.
        "sequence_parallel": tp_size > 1,
        # Common settings
        "variable_seq_lengths": True,
        "masked_softmax_fusion": True,
        "moe_token_dispatcher_type": "alltoall",
    }

    # Update with any provided overrides
    # override_transformer_config_kwargs as kwargs shall never be none
    base_config.update(override_transformer_config_kwargs)

    return base_config

TransformerConfig 源码

这个有点长,本打算截一下,但有的参数,可能未来用得到,就懒得截取了,方便查看。

| 并行/特性 | 核心互斥/限制 | 原因简述 |
| --- | --- | --- |
| TP (张量) | 头数必须整除 TP | 物理上无法将半个 Head 分给 GPU |
| PP (流水线) | 不支持 CPU Offloading | 复杂的调度逻辑难以兼容 CPU 换入换出 |
| PP (流水线) | Layout 与 num_layers_first 互斥 | 两种定义切分的方式冲突 |
| VPP (虚拟流水线) | 层数必须整除 VPP | 每个虚拟 Chunk 必须包含完整的层 |
| MoE (专家) | EP>1 时必须定义专家数 | 逻辑完备性检查 |
| CPU Offload | 不支持 Recompute, PP, CudaGraph | 内存管理机制冲突 |
| CudaGraph | 不支持 Recompute | 静态图捕获难以处理重计算的动态性 |

并行参数

TP 参数 (Tensor Parallel)

关键参数

  • tensor_model_parallel_size, num_attention_heads,
  • num_query_groups (GQA组数),sequence_parallel

约束条件

  • num_attention_heads % tensor_model_parallel_size == 0
  • num_query_groups % tensor_model_parallel_size == 0
  • 如果使用sequence_parallel,distribute_saved_activations必须设为False
python
# Excerpt of TransformerConfig validation (``self`` is the config object).
# Attention heads must split evenly across tensor-parallel ranks.
if self.num_attention_heads % self.tensor_model_parallel_size != 0: 
    raise ValueError(
        f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
        f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
    )

# Unset num_query_groups means normal attention: one group per head. This default
# must be applied before the divisibility check below.
if self.num_query_groups is None:
    self.num_query_groups = self.num_attention_heads

# Query groups (KV heads under GQA) must also split evenly across TP ranks.
if self.num_query_groups % self.tensor_model_parallel_size != 0:
    raise ValueError(
        f"num_query_groups ({self.num_query_groups}) must be a multiple of "
        f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
    )
PP 参数 (Pipeline Parallel)

相关参数

  • pipeline_model_parallel_size:传统PP并行度
  • virtual_pipeline_model_parallel_size:虚拟PP并行度,用于Interleaved调度,减少气泡

约束条件

  • num_layers % pipeline_parallel_size == 0
  • 首尾Stage、中间Stage层数,需要被VPP大小整除
  • cpu_offloading 不支持流水线并行,pipeline_model_parallel_size必须为1.
python
elif (self.num_layers_in_first_pipeline_stage is not None
    or self.num_layers_in_last_pipeline_stage is not None
):
    pipeline_parallel_size = self.pipeline_model_parallel_size
    num_layers = self.num_layers

    if self.num_layers_in_first_pipeline_stage is not None:
        if self.num_layers_in_first_pipeline_stage <= 0:
            raise ValueError("num_layers_in_first_pipeline_stage must be larger than 0")

        if self.virtual_pipeline_model_parallel_size is not None:
            if (
                self.num_layers_in_first_pipeline_stage
                % self.virtual_pipeline_model_parallel_size
                != 0
            ):
                raise ValueError(
                    f"number of layers at first stage: "
                    f"{self.num_layers_in_first_pipeline_stage}"
                    f"must be divisible by virtual pipeline"
                    f"parallel degree {self.virtual_pipeline_model_parallel_size}"
                )
        num_layers -= self.num_layers_in_first_pipeline_stage
        pipeline_parallel_size -= 1

    if self.num_layers_in_last_pipeline_stage is not None:
        if self.num_layers_in_last_pipeline_stage <= 0:
            raise ValueError("num_layers_in_last_pipeline_stage must be larger than 0")

        if self.virtual_pipeline_model_parallel_size is not None:
            if (
                self.num_layers_in_last_pipeline_stage
                % self.virtual_pipeline_model_parallel_size
                != 0
            ):
                raise ValueError(
                    f"number of layers at last stage: "
                    f"{self.num_layers_in_last_pipeline_stage}"
                    f"must be divisible by virtual pipeline"
                    f"parallel degree {self.virtual_pipeline_model_parallel_size}"
                )
        num_layers -= self.num_layers_in_last_pipeline_stage
        pipeline_parallel_size -= 1

    # Here pipeline_parallel_size is the number of middle PP stages. If there are middle
    # PP stages, check number of layers at middle stage is divisible by middle PP size.
    if pipeline_parallel_size and not num_layers % pipeline_parallel_size == 0:
        raise ValueError(
            f"number of layers at middle stage: {num_layers} must be divisible by"
            f"the middle pipeline model parallel size {pipeline_parallel_size}"
        )

    # If there are middle PP stages, check number of layers
    # on each middle PP rank is divisible by VPP size.
    if pipeline_parallel_size and self.virtual_pipeline_model_parallel_size is not None:
        num_layers_per_middle_pipeline_rank = num_layers // pipeline_parallel_size
        if (
            not num_layers_per_middle_pipeline_rank
            % self.virtual_pipeline_model_parallel_size
            == 0
        ):
            raise ValueError(
                f"number of layers on each middle pipeline rank:"
                f"{num_layers_per_middle_pipeline_rank} must be divisible by virtual"
                f"pipeline parallel degree {self.virtual_pipeline_model_parallel_size}"
            )
Expert Parallelism EP 专家并行

相关参数

  • expert_model_parallel_size:EP并行度
  • num_moe_experts:专家总数
  • moe_router_num_groups、moe_router_group_topk:组限制路由配置
  • moe_token_dispatcher_type:token分发方式,allgather、alltoall、flex

约束条件

  • expert_model_parallel_size > 1时,必须设置num_moe_experts
  • num_moe_experts % moe_router_num_groups == 0
Context Parallelism CP 上下文并行

相关参数

  • context_parallel_size:CP并行度
  • cp_comm_type:通信类型,取值p2p, all_gather, a2a

约束条件

  • 如果cp_comm_type是列表,即为每层指定不同通信方式,列表长度必须等于num_layers.

显存关键参数

重计算

  • 为了节省显存,在反向时重计算模型的部分,用计算时间换显存空间
  • selective:选择显存占用大、但计算小的部分。
    • 一般是core_attn。Attention产生的激活值显存占用高,但重计算相对代价小
    • 可选mlp,moe,layernorm等。

配置示例

  • recompute_granularity = selective
  • recompute_modules = '["core_attn", "mlp"]'
  • actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity="${recompute_granularity}"
  • actor_rollout_ref.actor.megatron.override_transformer_config.recompute_modules="${recompute_modules}"

TransformerConfig post_init 源码

python
@dataclass
class TransformerConfig(ModelParallelConfig):
  """Configuration object for megatron-core transformers.

  The initialization function has an argument for each parameter,
  including those in ModelParallelConfig.
  """

  ####################
  # model architecture
  ####################

  num_layers: int = 0
  """Number of transformer layers in a transformer block."""

  mtp_num_layers: Optional[int] = None
  """Number of Multi-Token Prediction (MTP) Layers."""

  mtp_loss_scaling_factor: Optional[float] = None
  """Weighting factor of Multi-Token Prediction (MTP) loss."""

  num_layers_in_first_pipeline_stage: Optional[int] = None
  """Number of transformer layers on first pipeline stage.
  None implies equal layer division across PP ranks."""

  num_layers_in_last_pipeline_stage: Optional[int] = None
  """Number of transformer layers on last pipeline stage.
  None implies equal layer division across PP ranks."""

  pipeline_model_parallel_layout: Optional[Union[str, list, PipelineParallelLayerLayout]] = None
  """Custom definition of the pipeline parallel partitioning.
  Support type:
  - str: e.g., 'Et*3|(tt|)*29,m|L'. Stages are split by '|', replicated stages or layers
  can be described with multiplication. Commas can be used cosmetically.
  - list: e.g., [['embedding', 'decoder'], ['decoder', 'decoder', 'decoder', 'loss']].
  - PipelineParallelLayerLayout: a PipelineParallelLayerLayout object.
  If given either a string or a list, it will be transferred into a PipelineParallelLayerLayout
  in post init. Let i = a * pp_size + b, then layout[i] gives a list of the layers 
  in the a-th vpp stage and the b-th pp stage, i.e., vpp(0)pp(0), vpp(0)pp(1), ..., 
  vpp(i)pp(j), vpp(i)pp(j+1), ..., vpp(-1)pp(-2), vpp(-1)pp(-1).
  In the inner lists of layers, 'embedding' or 'E' denotes the embedding layer, 'loss' or 'L'
  denotes the loss function, and 'decoder' or 't' denotes the transformer decoder layer.
  Examples:
      [['embedding', 'decoder'], ['decoder', 'decoder', 'decoder', 'loss']]:
      pp = 2, vpp = None
      pp rank 0 holds: embedding, decoder
      pp rank 1 holds: decoder*3, loss
      'E|(tt|)*2,(t|)*4,mL':
      pp = 2, vpp = 4
      vpp rank 0 pp rank 0 holds: embedding
      vpp rank 0 pp rank 1~2 holds: decoder*2
      vpp rank 0 pp rank 3 holds: decoder
      vpp rank 1 pp rank 0~2 holds: decoder
      vpp rank 1 pp rank 3 holds: mtp, loss"""

  account_for_embedding_in_pipeline_split: bool = False
  """If set, the embedding layer will be treated as a standard transformer
  layer in the context of partition and placement for pipeline parallelism."""

  account_for_loss_in_pipeline_split: bool = False
  """If set, the loss layer will be treated as a standard transformer
  layer in the context of partition and placement for pipeline parallelism."""

  hidden_size: int = 0
  """Transformer hidden size."""

  num_attention_heads: int = 0
  """Number of transformer attention heads."""

  attention_backend: AttnBackend = AttnBackend.auto
  """Attention backend to run. By default we let transformer engine
  decide the best backend to run (except in the case of local).
  If attention backend is local we use the local pytorch implementation in mcore.
  Users can specify exact backend by changing this config. """

  softmax_scale: Optional[float] = None
  """Softmax scale for attention scaling."""

  num_query_groups: Optional[int] = None
  """Number of query groups for group query attention. If None, normal attention is used."""

  ffn_hidden_size: Optional[int] = None
  """Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size
  if not provided."""

  kv_channels: Optional[int] = None
  """Projection weights dimension in multi-head attention. This is set to hidden_size //
  num_attention_heads if not provided."""

  hidden_dropout: float = 0.1
  """Dropout probability for transformer hidden state."""

  attention_dropout: float = 0.1
  """Post attention dropout probability."""

  fp32_residual_connection: bool = False
  """If true, move residual connections to fp32."""

  # @jcasper should we keep this option?
  apply_residual_connection_post_layernorm: bool = False
  """If True, uses the original BERT residule connection ordering."""

  layernorm_epsilon: float = 1e-5
  """Epsilon value for any LayerNorm operations."""

  layernorm_zero_centered_gamma: bool = False
  """If set to True, the LayerNorm is adjusted to center the gamma values around 0. This improves
  numerical stability."""

  add_bias_linear: bool = True
  """Include a bias term in all linear layers (QKV projections, after core attention, and two in
  MLP layer)."""

  add_qkv_bias: bool = False
  """Add a bias term only for QKV projections."""

  gated_linear_unit: bool = False
  """Use a gated linear unit for the first linear layer in the MLP."""

  activation_func: Callable = F.gelu
  """Activation function to use for the non-linearity in the MLP."""

  activation_func_fp8_input_store: bool = False
  """Store the input of MLP activation function in FP8 for backprop to save memory.
  The stored input is casted back to the original precision before backprop compuatation."""

  num_moe_experts: Optional[int] = None
  """Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None
  for no MoE."""

  rotary_interleaved: bool = False
  """True is rotate pairs of even and odd dimensions (RoFormer style), False is rotate pairs of
  first half and second half (LLaMa style). Default to False."""

  window_size: Optional[Tuple[int, int]] = None
  """If not None, then will use sliding window attention. The size of the window is specified by
  the numbers inside the tuple; -1 is special value meaning "infinite window size"."""

  normalization: str = "LayerNorm"
  """Which norm to use for normalization layers, valid options are `LayerNorm` and `RMSNorm`."""

  qk_layernorm: bool = False
  """Whether to apply `normalization` type of normalization to the query and key embeddings."""

  test_mode: bool = False
  """Whether to run real-time tests."""

  calculate_per_token_loss: bool = False
  """Whether cross entropy loss is calculated over the actual number of non-padded tokens in the
  global batch, versus the default behavior of assuming all tokens are non-padded."""

  multi_latent_attention: bool = False
  """Whether to use multi-latent attention."""

  no_rope_freq: Optional[Union[int, List[int]]] = None
  """Controls which layers perform Rotary Position Embedding (RoPE). Accepts either:
  An integer N: Creates a pattern where RoPE is skipped every N-1 layers. For example,
  no_rope=4 means RoPE is applied for 3 layers, then skipped for 1 layer, repeating this pattern.
  A list of integers: Defines a custom pattern where 1 means skip RoPE and 0 means apply RoPE.
  For example, [0,1,1,0] means: apply RoPE, skip RoPE, skip RoPE, apply RoPE."""

  moe_deepep_num_sms: int = 20
  """Number of SMs to use for DeepEP."""

  ####################
  # initialization
  ####################
  init_method: Optional[Callable] = None
  """Method to initialize weights. Note that bias is always set to zero. Should be a function that
  takes a single Tensor and initializes it. If None, will be set to
  megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with
  mean=0.0 and std=init_method_std."""

  output_layer_init_method: Optional[Callable] = None
  """Method to initialize weights of the output layer of both attention and MLP blocks. If None,
  will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn
  init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers)."""

  init_method_std: float = 0.02
  """Standard deviation of the zero mean normal for the default initialization method, not used if
  init_method and output_layer_init_method are provided."""

  init_model_with_meta_device: bool = False
  """
  If True, initializes the model with the meta device. This is helpful for
  training of very large models. This feature is only works when custom fsdp is turned on.
  """

  ####################
  # mixed-precision
  ####################
  apply_query_key_layer_scaling: bool = False
  """If true, scale Q * K^T by 1 / layer-number. This improve numeric stability when training with
  fp16."""

  attention_softmax_in_fp32: bool = True
  """If True, run attention masking and softmax in fp32. This should be True if
  apply_query_key_layer_scaling is True."""

  disable_bf16_reduced_precision_matmul: bool = False
  """If True, sets torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction=False to
  prevent matmul from using reduced precision accumulation when using BF16."""

  ####################
  # fusion
  ####################
  bias_activation_fusion: bool = False
  """If True, fuses bias addition and the activation function when possible."""

  masked_softmax_fusion: bool = False
  """If True, uses softmax fusion."""

  persist_layer_norm: bool = False
  """If True, uses the persistent fused layer norm kernel. This kernel only supports a fixed set
  of hidden sizes."""

  memory_efficient_layer_norm: bool = False
  """If True, and using local layers (not from TransformerEngine), tells Apex to use the memory
  efficient fused LayerNorm kernel. Ignored if not using LayerNorm."""

  bias_dropout_fusion: bool = False  # TODO: this should be bias_dropout_add_fusion?
  """If True, uses bias dropout fusion."""

  apply_rope_fusion: bool = False
  """If True, use fused RoPE kernel."""

  ####################
  # activation recomputation,节省显存占用
  ####################
  recompute_granularity: Optional[str] = None
  """Determines which type of activation recompute to use.  Megatron-core supports 'selective'
  activation checkpointing where the submodules set in --recompute-modules is checkpointed.
  The default is "core_attn" which is the memory intensive part of attention.
  These memory intensive activations are also less compute intensive which makes activation
  checkpointing more efficient for LLMs (20B+).  See Reducing Activation Recomputation in Large
  Transformer Models (https://arxiv.org/abs/2205.05198) for more details.  'full' will checkpoint
  the entire transformer layer.  If None, no recompute is performed and all activations are saved.
  If set, must be 'selective' or 'full'. 'selective' always uses all layers.
  """

  recompute_method: Optional[str] = None
  """Determines which transformer layers will be recomputed. uniform will uniformly divide the
  total number of transformer layers in a transformer block and recompute the input activation of
  each divided chunk at the specified granularity.  block will recompute the input activations for
  only a set number of transformer layers per pipeline stage.  The rest of the layers in the
  pipeline stage will not have any activations recomputed.  If None, and recompute is enabled, all
  layers will do recomputation. If set, must be 'uniform' or 'block'."""

  recompute_num_layers: Optional[int] = None
  """When recompute_method is uniform, recompute_num_layers is the number of transformer layers in
  each uniformly divided recompute unit.  When recompute_method is block, recompute_num_layers is
  the number of transformer layers to recompute within each pipeline stage.  Must be None for
  'selective' activation checkpointing."""

  distribute_saved_activations: Optional[bool] = None
  """If True, distribute recomputed activations across the model parallel group."""

  recompute_modules: Optional[List[str]] = None
  """The submodules to recompute.
  choices: "core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe".
  default: ["core_attn"].
  "core_attn": recompute the core attention part of the transformer layer.
  "moe_act": recompute the MoE MLP activation function.
  "layernorm": recompute the input_layernorm and pre_mlp_layernorm.
  "mla_up_proj": recompute the MLA up projection and RoPE applying parts.
  "mlp": recompute the dense MLP submodule.
  "moe": recompute the MoE layer.
  "moe_act", "layernorm", and "mla_up_proj" use output-discarding checkpointing,
  "core_attn", "mlp", and "moe" uses normal checkpointing.
  """

  ####################
  # fp8 related
  ####################
  fp8: Optional[str] = None
  """If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined
  choices (1) 'e4m3' uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8
  activation and weight tensors and e5m2 for all FP8 output activation gradient tensors."""

  fp8_recipe: Optional[str] = "delayed"
  """If set, enables the use of FP8 precision through Transformer Engine. There are 3 predefined
  choices (1) 'tensorwise' uses per tensor current scaling recipe, (2) 'delayed'
  uses delayed scaling recipe, 3) 'mxfp8' for Blackwell architecture only,
  4) 'blockwise' for blockwise scaling recipe."""

  fp8_param: bool = False
  """If set, keep the parameters in fp8 precision to save memory. This option must be used
  together with fp8 mode (i.e., TransformerConfig.fp8 is not None). Note that not all parameters
  will be converted to fp8; for example, biases will remain unchanged. The parameters affected are
  primarily the weights of GEMMs. The specific parameters that will be converted to fp8 are
  determined by TE."""

  fp8_margin: int = 0
  """Margin for the scaling factor computation."""

  fp8_interval: int = 1
  """DEPRECATED from TransformerEngine v1.8.0. This flag is ignored.
  Controls how often the scaling factor is recomputed.
  """

  fp8_amax_history_len: int = 1
  """The length of the amax history window used for scaling factor computation."""

  fp8_amax_compute_algo: str = "most_recent"
  """Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2
  predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent`
  always chooses the most recently seen value.

  """

  fp8_wgrad: bool = True
  """When set to False, override FP8 config options and do the wgrad computation
  in higher precision."""

  fp8_dot_product_attention: bool = False
  """When set to True, use the FP8 implementation of Dot Product Attention."""

  fp8_multi_head_attention: bool = False
  """When set to True, use the FP8 implementation of Multi Head Attention."""

  tp_only_amax_red: bool = False
  """When set to True, reduce the FP8 AMAX only in the TP or TP-CP domain"""

  first_last_layers_bf16: bool = False
  """If True, retains first and last N TransformerBlocks in BF16 as opposed to FP8."""

  num_layers_at_start_in_bf16: int = 1
  """Number of layers at the start of the model to keep in BF16 precision when
  first_last_layers_bf16 is True."""

  num_layers_at_end_in_bf16: int = 1
  """Number of layers at the end of the model to keep in BF16 precision when
  first_last_layers_bf16 is True."""

  use_kitchen: bool = False
  """Use the kitchen extension for transformer quantization."""

  ####################
  # MoE related
  ####################
  moe_shared_expert_intermediate_size: Optional[int] = None
  """Shared expert total ffn hidden size.
  It should be equal to 'num_shared_experts * ffn_size_of_each_shared_expert' if
  there are multiple shared experts.
  None means no shared expert."""

  moe_shared_expert_overlap: bool = False
  """Enable overlapping between shared expert computations and dispatcher communications.
  Without this, the shared epxerts execute after the routed experts."""

  moe_layer_freq: Union[int, List[int]] = 1
  """Frequency between MoE layers and Dense layers. Accepts either:
  - An integer N: Represents a 1:N ratio, meaning one expert layer for every N-1 dense layers.
  - A list that defines a custom pattern, e.g.: [1,1,1,0,1,1,1,0,1,1,1,0]"""

  moe_ffn_hidden_size: Optional[int] = None
  """MoE Feed-Forward Network hidden size"""

  moe_router_load_balancing_type: str = "aux_loss"
  """The load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss
  used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the load balancing loss used
  in DeepSeekV2 and DeepSeekV3, which computes the loss for each individual sample; "sinkhorn"
  corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing.
  The default is "aux_loss"."""

  moe_router_topk: int = 2
  """Number of experts to route to for each token."""

  moe_router_topk_limited_devices: Optional[int] = None
  """Number of EP ranks to consider for each token in group-limited routing,
  DEPRECATED and replaced by moe_router_num_groups and moe_router_group_topk.
  """

  moe_router_padding_for_fp8: Optional[bool] = False
  """Whether to pad the routing_map to make sure the number of tokens each expert received
  is a multiple of 16/32 for FP8 precision. This can remove the explicit padding in the
  GroupedMLP layer."""

  moe_router_num_groups: Optional[int] = None
  """Number of groups to divide experts into for group-limited routing.
  When using group-limited routing:
  1. Experts are divided into 'moe_router_num_groups' equal-sized groups
  2. For each token, 'moe_router_group_topk' groups are selected based on sum of
  top-('moe_router_topk'/'moe_router_group_topk') routing scores within each group
  3. From these selected groups, 'moe_router_topk' individual experts are chosen
  Two common use cases:
  - Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP)
  to limit each token to experts on a subset of devices
  (See DeepSeek-V2: https://arxiv.org/pdf/2405.04434)
  - Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group
  to limit each token to experts on a subset of nodes
  (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)
  """

  moe_router_group_topk: Optional[int] = None
  """Number of selected groups for group-limited routing."""

  moe_router_pre_softmax: bool = False
  """Enable pre-softmax(pre-sigmoid) routing for MoE, which means softmax is before the 
  top-k selection.
  By default, softmax is done after top-k."""

  moe_router_topk_scaling_factor: Optional[float] = None
  """Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax
  enabled. Defaults to None, which means no scaling."""

  moe_router_score_function: str = "softmax"
  """Score function for MoE routing. Can be "softmax" or "sigmoid"."""

  moe_router_dtype: Optional[str] = None
  """Data type for routing and expert output weighted averaging. Using fp32 or fp64 can
  improve stability especially when the number of experts is large (e.g. finegrained-moe).
  None means no changes for dtype."""

  moe_router_enable_expert_bias: bool = False
  """TopK routing with dynamic per-expert bias in the aux-loss-free load balancing strategy.
  The routing decision is based on the sum of the routing scores and the expert bias.
  See https://arxiv.org/abs/2408.15664 for details."""

  moe_router_bias_update_rate: float = 1e-3
  """The expert bias is updated based on the number of assigned tokens to each expert
  in a global batch, where the bias is increased for the experts with less assigned tokens
  and decreased for the experts with more assigned tokens.
  The default value 1e-3 is same as that used in DeepSeekV3."""

  moe_router_force_load_balancing: bool = False
  """[Experimental] Force load balancing with random logits for MoE router, supports naive topk 
  and group-limited topk. This is an experimental feature and only for benchmark."""

  moe_grouped_gemm: bool = False
  """When there are multiple experts per rank, compress multiple local (potentially small) gemms
  in a single kernel launch to improve the utilization and performance by leveraging the Grouped
  GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).
  """

  moe_use_legacy_grouped_gemm: bool = False
  """Use legacy GroupedMLP rather than TEGroupedMLP.
  Note: The legacy one will be deprecated soon."""

  moe_aux_loss_coeff: float = 0  # 1e-2 would be a good start value for load balance loss.
  """Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended."""

  moe_z_loss_coeff: Optional[float] = None  # 1e-3 would be a good start value for z-loss
  """Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended."""

  moe_input_jitter_eps: Optional[float] = None
  """Add noise to the input tensor by applying jitter with a specified epsilon value."""

  moe_token_dropping: bool = False
  """This feature involves selectively dropping and padding tokens for each expert to achieve a
  specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is
  currently unsupported so should remain False."""

  moe_token_dispatcher_type: str = "allgather"
  """The type of token dispatcher to use. The default is 'allgather'.
  Options are 'allgather','alltoall' and 'flex'."""

  moe_enable_deepep: bool = False
  """[Experimental] Enable DeepEP for efficient token dispatching and combine in MoE models."""

  moe_per_layer_logging: bool = False
  """Enable per-layer logging for MoE, currently supports auxiliary loss and z loss."""

  moe_expert_capacity_factor: Optional[float] = None
  """moe_expert_capacity_factor (float): The capacity factor for each expert, None means no token
  will be dropped. The default is None."""

  moe_pad_expert_input_to_capacity: bool = False
  """moe_pad_expert_input_to_capacity (bool): If True, pads the input for each expert to match
  the expert capacity length, effective only after the moe_expert_capacity_factor is set. The
  default setting is False."""

  moe_token_drop_policy: str = "probs"
  """The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with
  the lowest probabilities will be dropped. If "position", tokens at the end of each batch will
  be dropped.
  """

  moe_layer_recompute: bool = False
  """Memory optimization: checkpointing moe_layer to save activation memory.
  Deprecated: when set, __post_init__ emits a warning, rejects recompute_granularity="full",
  switches recompute_granularity to "selective" and adds "moe" to recompute_modules."""

  moe_permute_fusion: bool = False
  """Fuse token rearrangement ops during token dispatching."""

  moe_apply_probs_on_input: bool = False
  """Apply probs on input of experts instead of applying after activation and glu."""

  ##################
  # Context Parallel
  ##################
  cp_comm_type: Optional[Union[str, List[str]]] = None
  """Inter-gpu communication type for context parallelism.
  str: all layers share same communication type.
  List[str]: each layer has its separate communication type.
  cp_comm_type of each layer can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
  "p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be
  overlapped with attention compute.
  "all_gather": All-gather to get full sequence of KV before attention. The all-gather is not
  async, and cannot be overlapped.
  "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get
  full sequence of QKV.
  "a2a+p2p": A hierarchical implementation of context parallelism to attention.
  It uses A2A communications in low-level CP groups (e.g., via NVLink),
  and P2P communications in high-level CP groups (e.g., via IBLink).
  """

  ##################
  # Cuda Graphs
  ##################
  enable_cuda_graph: bool = False
  """When set to true, TransformerLayer layers are swapped with a CUDA graphed version."""

  cuda_graph_use_single_mempool: bool = False
  """When set to true, cudagraphs will be captured inside a single mempool, in which all
  cudagraphs may only be used once per step. If false, cudagraphs may be reused across
  microbatches. Enabling may reduce cudagraph memory overheads due to memory fragmentation,
  however may greatly increase the number of cudagraphs created when the number of microbatches
  is high."""

  cuda_graph_retain_backward_graph: bool = False
  """When set to true, cudagraph backward passes will be graph captured with 'retain_graph=True'.
  This may enable cudagraphs for certain modules that are not completely cudagraph safe. For
  more details, see: https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html."""

  cuda_graph_warmup_steps: int = 3
  """Number of warmup steps for CUDA graphs"""

  external_cuda_graph: bool = False
  """When set to true, TransformerLayer layers are swapped with user provided CUDA graphs."""

  cuda_graph_scope: str = "full"
  """When external_cuda_graph is set to true, cuda_graph_scope determines the CUDA graphs
  capturing scope. Valid values are "full" and "attn". "Full" scope captures a whole Transformer
  layer. "Attn" scope only captures operations in TransformerLayer._forward_attention()."""

  ####################
  # miscellaneous
  ####################
  clone_scatter_output_in_embedding: bool = True
  """When set to True, clone the output of scatter_to_sequence_parallel_region in embedding layer
  to facilitate garbage collection of input."""

  disable_parameter_transpose_cache: bool = False
  """When set to true, the parameter transposes are not cached for subsequent iterations."""

  config_logger_dir: str = ""
  """When non-empty, dumps entry-point configs to config_logger_dir"""

  flash_decode: bool = False
  """ Use the optimized flash decoding kernel during inference. """

  use_te_rng_tracker: bool = False
  """ Whether to use the TE or MCore version of the RNG tracker. """

  inference_rng_tracker: bool = False
  """ Whether we should instantiate a separate RNG tracker for inference. """

  symmetric_ar_type: Optional[str] = None
  """Type of symmetric all reduce to use"""

  mrope_section: Optional[List[int]] = None
  """ Multimodal rope section is for channel dimension of temporal, height and width
  in rope calculation. """

  is_hybrid_model: bool = False
  """ Indicates whether this is a hybrid model. """

  mamba_state_dim: int = 128
  """The dimensionality of the state representation in Mamba layers."""

  mamba_head_dim: int = 64
  """The dimensionality of the heads in the Mamba layers."""

  mamba_num_groups: int = 8
  """The number of groups used in Mamba layers."""

  mamba_num_heads: Optional[int] = None
  """The number of heads used in Mamba layers. 
  If None, the number of heads will be hidden_size * expand // mamba_head_dim."""

  use_mamba_mem_eff_path: bool = True
  """If True, use the memory efficient path for Mamba layers."""

  mlp_chunks_for_prefill: int = 1
  """The number of chunks along the sequence dimension to use for MLP computation
  during prefill."""

  heterogeneous_block_specs: bool = False
  """Whether to use heterogeneous block specs (nemotron-nas architecture)."""

  hetereogenous_dist_checkpoint: bool = False
  """Whether to use heterogenous layers in distributed checkpoint."""

  ####################
  # Quantization
  ####################
  quant_recipe: Optional[RecipeConfig] = None
  """Configuration of any quantization to be applied to the model"""

  def __post_init__(self):
      """Python dataclass method that is used to modify attributes after initialization.
      See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
      details.
      """
      super().__post_init__()
      if self.fp16 and self.bf16:
          raise ValueError(
              f"Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True."
          )

      # Apply BF16 matmul precision setting if needed
      if self.bf16 and self.disable_bf16_reduced_precision_matmul:
          torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

      if self.num_attention_heads % self.tensor_model_parallel_size != 0:
          raise ValueError(
              f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
              f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
          )

      if self.ffn_hidden_size is None:
          self.ffn_hidden_size = 4 * self.hidden_size

      if self.kv_channels is None:
          self.kv_channels = self.hidden_size // self.num_attention_heads

      if self.num_query_groups is None:
          self.num_query_groups = self.num_attention_heads

      if self.num_query_groups % self.tensor_model_parallel_size != 0:
          raise ValueError(
              f"num_query_groups ({self.num_query_groups}) must be a multiple of "
              f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
          )

      if self.fp8:
          # cannot support first last layer bf16 with delayed scaling
          if self.first_last_layers_bf16 and self.fp8_recipe == Fp8Recipe.delayed:
              raise ValueError("Delayed scaling does not support first / last layer in BF16.")

          # max bf16 layers per pipeline stage
          max_bf16_layers_per_pipeline_stage = (
              self.num_layers // self.pipeline_model_parallel_size
          )

          # check start/end bf16 layer counts are valid
          if self.first_last_layers_bf16:
              if (
                  self.num_layers_at_start_in_bf16 < 0
                  or self.num_layers_at_start_in_bf16 > max_bf16_layers_per_pipeline_stage
              ):
                  raise ValueError(
                      f"num_layers_at_start_in_bf16 ({self.num_layers_at_start_in_bf16}) must be "
                      f"between 0 and number of layers per pipeline stage "
                      f"({max_bf16_layers_per_pipeline_stage})."
                  )
              if (
                  self.num_layers_at_end_in_bf16 < 0
                  or self.num_layers_at_end_in_bf16 > max_bf16_layers_per_pipeline_stage
              ):
                  raise ValueError(
                      f"num_layers_at_end_in_bf16 ({self.num_layers_at_end_in_bf16}) must be "
                      f"between 0 and number of layers per pipeline stage "
                      f"({max_bf16_layers_per_pipeline_stage})."
                  )

      if self.fp8_param and not self.fp8:
          raise ValueError("fp8_param must be used together with fp8 mode.")

      if self.apply_query_key_layer_scaling:
          self.attention_softmax_in_fp32 = True

      if self.expert_model_parallel_size > 1 and self.num_moe_experts is None:
          raise ValueError("num_moe_experts must be non None to use expert-parallel.")

      if self.num_moe_experts is not None and self.num_moe_experts <= 0:
          raise ValueError("num_moe_experts must be non-negative.")

      if self.num_moe_experts is not None and self.moe_ffn_hidden_size is None:
          self.moe_ffn_hidden_size = self.ffn_hidden_size
          warnings.warn("moe_ffn_hidden_size is not set, using ffn_hidden_size instead.")

      if self.num_moe_experts is None:
          assert (
              self.moe_ffn_hidden_size is None
          ), "moe_ffn_hidden_size must be None when num_experts is not set."

      if self.moe_enable_deepep:
          if self.moe_token_dispatcher_type != "flex":
              raise ValueError("DeepEP backend is only supported with flex token dispatcher.")

      if self.moe_token_dispatcher_type == "flex":
          if self.moe_pad_expert_input_to_capacity:
              raise ValueError(
                  "Flex token dispatcher does not support moe_pad_expert_input_to_capacity"
              )

      if self.moe_shared_expert_intermediate_size is not None:
          if self.moe_shared_expert_intermediate_size <= 0:
              raise ValueError(
                  f"moe_shared_expert_intermediate_size must be "
                  f"num_shared_experts * ffn_size_of_each_shared_expert, "
                  f"but got {self.moe_shared_expert_intermediate_size}"
              )
          if self.moe_shared_expert_overlap and self.moe_token_dispatcher_type not in [
              "alltoall"
          ]:
              raise ValueError(
                  f"moe_shared_expert_overlap only works with alltoall token dispatcher."
              )

      if self.moe_expert_capacity_factor is not None:
          if self.moe_expert_capacity_factor < 0:
              self.moe_expert_capacity_factor = None
          if self.moe_router_load_balancing_type not in ["aux_loss", "seq_aux_loss", "none"]:
              raise ValueError(
                  "moe_expert_capacity_factor only works with aux_loss or none load balancing"
              )

      if self.moe_pad_expert_input_to_capacity:
          if self.moe_expert_capacity_factor is None:
              raise ValueError(
                  "moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity"
              )

      if self.cpu_offloading and (
          self.cpu_offloading_num_layers < 0 or self.cpu_offloading_num_layers >= self.num_layers
      ):
          raise ValueError(
              f"CPU offloading can be done only for layers less than {self.num_layers}"
          )

      if self.cpu_offloading and self.pipeline_model_parallel_size > 1:
          raise ValueError(
              "Currently there is no support for Pipeline parallelism with CPU offloading"
          )

      if self.cpu_offloading and self.recompute_granularity is not None:
          raise ValueError(
              "CPU offloading does not work when activation recomputation is enabled"
          )

      if self.recompute_granularity is not None:
          if self.recompute_granularity not in ["full", "selective"]:
              raise ValueError(
                  f'When using recompute_granuarlity: {self.recompute_granularity} must be "full"'
                  'or "selective".'
              )

          if self.recompute_method is not None:
              if self.recompute_method not in ["block", "uniform"]:
                  raise ValueError(
                      f'recompute_method: {self.recompute_method} must be "block" or "uniform".'
                  )
          elif self.recompute_granularity != "selective":
              raise ValueError(
                  f"Using recompute_granularity: {self.recompute_granularity} so "
                  'recompute_method must be "block" or "uniform"'
              )

          if self.recompute_granularity != "selective" and self.recompute_num_layers is None:
              raise ValueError(
                  f"When using recompute_granularity: {self.recompute_granularity} "
                  "recompute_num_layers must be between "
                  "1 and num_layers_per_pipeline_rank: "
                  f"{self.num_layers // self.pipeline_model_parallel_size}"
              )
          elif (
              self.recompute_granularity == "selective" and self.recompute_num_layers is not None
          ):
              raise ValueError(
                  f"When using recompute_granularity: {self.recompute_granularity} "
                  "recompute_num_layers must be None."
              )

          if self.distribute_saved_activations and self.sequence_parallel:
              raise ValueError(
                  f"distribute_saved_activations: {self.distribute_saved_activations} must be "
                  f"false when sequence parallel is enabled: {self.sequence_parallel}"
              )

      if self.recompute_modules is None:
          self.recompute_modules = ["core_attn"]

      if self.recompute_granularity == "selective":
          if len(self.recompute_modules) > 0:
              allowed_modules = {"core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe"}
              invalid_modules = set(self.recompute_modules) - allowed_modules
              assert not invalid_modules, (
                  f"Invalid choices for recompute_modules: {invalid_modules}. "
                  f"Allowed modules are: {allowed_modules}"
              )

          if "moe_act" in self.recompute_modules and not self.moe_grouped_gemm:
              raise ValueError(
                  "moe_act in recompute_modules is only supported with moe_grouped_gemm."
              )

          if "mla_up_proj" in self.recompute_modules and not self.multi_latent_attention:
              raise ValueError(
                  "mla_up_proj in recompute_modules is only supported with "
                  "multi_latent_attention."
              )

          if "core_attn" in self.recompute_modules:
              warnings.warn(
                  "If you are using transformer_engine as the transformer implementation, "
                  "the core_attn is from transformer_engine and may be the fused version. "
                  "For fused attention, you have no need to set 'core_attn' to recompute. "
                  "Please check that the core_attn recompute is really needed."
              )

          if self.fp8:
              if "moe_act" in self.recompute_modules or "layernorm" in self.recompute_modules:
                  raise ValueError("moe_act and layernorm recompute cannot work with fp8.")

      if self.moe_layer_recompute:
          warnings.warn(
              "--moe-layer-recompute is deprecated. "
              "Use --recompute-granularity selective --recompute-modules moe_layer instead."
          )
          if self.recompute_granularity == "full":
              raise ValueError(
                  "Do not set --moe-layer-recompute with full recompute granularity. "
              )
          self.recompute_granularity = "selective"
          if "moe" not in self.recompute_modules:
              self.recompute_modules.append("moe")

      if (
          self.num_layers_in_first_pipeline_stage is not None
          or self.num_layers_in_last_pipeline_stage is not None
      ) and (
          self.account_for_embedding_in_pipeline_split or self.account_for_loss_in_pipeline_split
      ):
          raise ValueError(
              "num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage cannot be"
              "set at the same time with account_for_embedding_in_pipeline_split"
              "and account_for_loss_in_pipeline_split"
          )

      # PP layout
      if self.pipeline_model_parallel_layout is not None:
          # If pipeline layout is set, we will check the conflicts
          # with other pipeline layout arguments.
          any_conflict = (
              self.num_layers_in_first_pipeline_stage is not None
              or self.num_layers_in_last_pipeline_stage is not None
              or self.account_for_embedding_in_pipeline_split
              or self.account_for_loss_in_pipeline_split
          )
          if any_conflict:
              raise ValueError(
                  "pipeline_model_parallel_layout cannot be set"
                  " with other pipeline layout arguments."
                  f" {self.num_layers_in_first_pipeline_stage=},"
                  f" {self.num_layers_in_last_pipeline_stage=},"
                  f" {self.account_for_embedding_in_pipeline_split=},"
                  f" {self.account_for_loss_in_pipeline_split=}."
              )

          # Transfer pipeline_model_parallel_layout from str or list to
          # PipelineParallelLayerLayout
          if isinstance(self.pipeline_model_parallel_layout, str):
              self.pipeline_model_parallel_layout = PipelineParallelLayerLayout.from_str(
                  layout=self.pipeline_model_parallel_layout,
                  pipeline_model_parallel_size=self.pipeline_model_parallel_size,
              )
          elif isinstance(self.pipeline_model_parallel_layout, list):
              # Since list is not hashable, the initialization will not be cached.
              self.pipeline_model_parallel_layout = PipelineParallelLayerLayout(
                  layout=self.pipeline_model_parallel_layout,
                  pipeline_model_parallel_size=self.pipeline_model_parallel_size,
              )

          # Check whether the input VPP size conflicts with the PP layout
          detected_vpp_size = (
              self.pipeline_model_parallel_layout.virtual_pipeline_model_parallel_size
          )
          if self.virtual_pipeline_model_parallel_size is not None:
              assert self.virtual_pipeline_model_parallel_size == detected_vpp_size, (
                  f"virtual_pipeline_model_parallel_size conflicts with"
                  f" pipeline_model_parallel_layout,"
                  f" ({self.virtual_pipeline_model_parallel_size=}, "
                  f" {detected_vpp_size=})"
              )
          elif detected_vpp_size > 1:
              self.virtual_pipeline_model_parallel_size = detected_vpp_size

          # Check whether the layout is valid.
          self.pipeline_model_parallel_layout.validate_layer_layout(num_layers=self.num_layers)

      # Uneven PP
      elif (
          self.num_layers_in_first_pipeline_stage is not None
          or self.num_layers_in_last_pipeline_stage is not None
      ):
          pipeline_parallel_size = self.pipeline_model_parallel_size
          num_layers = self.num_layers

          if self.num_layers_in_first_pipeline_stage is not None:
              if self.num_layers_in_first_pipeline_stage <= 0:
                  raise ValueError("num_layers_in_first_pipeline_stage must be larger than 0")

              if self.virtual_pipeline_model_parallel_size is not None:
                  if (
                      self.num_layers_in_first_pipeline_stage
                      % self.virtual_pipeline_model_parallel_size
                      != 0
                  ):
                      raise ValueError(
                          f"number of layers at first stage: "
                          f"{self.num_layers_in_first_pipeline_stage}"
                          f"must be divisible by virtual pipeline"
                          f"parallel degree {self.virtual_pipeline_model_parallel_size}"
                      )
              num_layers -= self.num_layers_in_first_pipeline_stage
              pipeline_parallel_size -= 1

          if self.num_layers_in_last_pipeline_stage is not None:
              if self.num_layers_in_last_pipeline_stage <= 0:
                  raise ValueError("num_layers_in_last_pipeline_stage must be larger than 0")

              if self.virtual_pipeline_model_parallel_size is not None:
                  if (
                      self.num_layers_in_last_pipeline_stage
                      % self.virtual_pipeline_model_parallel_size
                      != 0
                  ):
                      raise ValueError(
                          f"number of layers at last stage: "
                          f"{self.num_layers_in_last_pipeline_stage}"
                          f"must be divisible by virtual pipeline"
                          f"parallel degree {self.virtual_pipeline_model_parallel_size}"
                      )
              num_layers -= self.num_layers_in_last_pipeline_stage
              pipeline_parallel_size -= 1

          # Here pipeline_parallel_size is the number of middle PP stages. If there are middle
          # PP stages, check number of layers at middle stage is divisible by middle PP size.
          if pipeline_parallel_size and not num_layers % pipeline_parallel_size == 0: 
              raise ValueError(
                  f"number of layers at middle stage: {num_layers} must be divisible by"
                  f"the middle pipeline model parallel size {pipeline_parallel_size}"
              )

          # If there are middle PP stages, check number of layers
          # on each middle PP rank is divisible by VPP size.
          if pipeline_parallel_size and self.virtual_pipeline_model_parallel_size is not None:
              num_layers_per_middle_pipeline_rank = num_layers // pipeline_parallel_size
              if (
                  not num_layers_per_middle_pipeline_rank
                  % self.virtual_pipeline_model_parallel_size
                  == 0
              ):
                  raise ValueError(
                      f"number of layers on each middle pipeline rank:"
                      f"{num_layers_per_middle_pipeline_rank} must be divisible by virtual"
                      f"pipeline parallel degree {self.virtual_pipeline_model_parallel_size}"
                  )

      elif (
          self.account_for_embedding_in_pipeline_split or self.account_for_loss_in_pipeline_split
      ):
          if self.virtual_pipeline_model_parallel_size is None:
              num_layers = self.num_layers

              if self.account_for_embedding_in_pipeline_split:
                  num_layers += 1

              if self.account_for_loss_in_pipeline_split:
                  num_layers += 1

              if not num_layers % self.pipeline_model_parallel_size == 0:
                  raise ValueError(
                      f"number of middle layers: {num_layers} must be divisible by "
                      f"middle pipeline_model_parallel_size {self.pipeline_model_parallel_size}"
                  )
          else:
              num_layers = self.num_layers
              if self.account_for_embedding_in_pipeline_split:
                  num_layers += 1

              if self.account_for_loss_in_pipeline_split:
                  num_layers += 1

              if not num_layers % self.pipeline_model_parallel_size == 0:
                  raise ValueError(
                      f"num_layers: {num_layers} after enable"
                      f"account_for_embedding_in_pipeline_split or "
                      f"account_for_loss_in_pipeline_split must be divisible"
                      f"by pipeline_model_parallel_size "
                      f"{self.pipeline_model_parallel_size}"
                  )

              num_layers_per_pipeline_rank = num_layers // self.pipeline_model_parallel_size
              if (
                  not num_layers_per_pipeline_rank % self.virtual_pipeline_model_parallel_size
                  == 0
              ):
                  raise ValueError(
                      f"number of layers on each pipeline rank: {num_layers_per_pipeline_rank}"
                      f"(after enable account_for_embedding_in_pipeline_split or "
                      f"account_for_loss_in_pipeline_split) must be divisible by"
                      f"virtual_pipeline_model_parallel_size"
                      f"{self.virtual_pipeline_model_parallel_size}"
                  )

      if self.apply_query_key_layer_scaling:
          self.attention_softmax_in_fp32 = True

      if self.bias_activation_fusion:
          if self.activation_func not in [F.gelu, F.silu]:
              raise ValueError(
                  "When bias_activation_fusion is True, activation function should be either "
                  "gelu or swiglu"
              )
          if (
              self.activation_func == F.gelu
              and not self.gated_linear_unit
              and not self.add_bias_linear
          ):
              raise ValueError(
                  "When bias_activation_fusion is True, gated_linear_unit is False, "
                  "and activation function is gelu, add_bias_linear must also be True."
              )

      if self.activation_func_fp8_input_store:
          if self.activation_func != F.silu or not self.gated_linear_unit:
              raise ValueError("Storing activation input in FP8 is supported only for SwiGLU.")

      if self.apply_rope_fusion:
          if self.multi_latent_attention:
              warnings.warn(
                  "apply_rope_fusion for multi-latent attention only supports training. "
                  "It is experimental and may change in future versions."
              )
          else:
              if self.rotary_interleaved:
                  if not is_te_min_version("2.3.0.dev0"):
                      raise ValueError(
                          "rotary_interleaved does not work with apply_rope_fusion for "
                          "TE < 2.3.0.dev0. Please install TE >= 2.3.0.dev0"
                      )

              from megatron.core.models.common.embeddings.rope_utils import (
                  fused_apply_rotary_pos_emb,
                  fused_apply_rotary_pos_emb_thd,
              )

              if fused_apply_rotary_pos_emb is None and fused_apply_rotary_pos_emb_thd is None:
                  raise ValueError(
                      "apply_rope_fusion is not available. Please install TE >= 1.4 or Apex."
                  )

      if self.multi_latent_attention and self.rotary_interleaved:
          raise ValueError("rotary_interleaved does not work with multi_latent_attention.")

      if self.init_method is None:
          self.init_method = init_method_normal(self.init_method_std)

      if self.output_layer_init_method is None:
          self.output_layer_init_method = scaled_init_method_normal(
              self.init_method_std,
              self.num_layers,
              multiplier=2.0 if not self.is_hybrid_model else 1.0,
          )

      if self.num_moe_experts is not None:
          assert not self.add_bias_linear, "Bias is not supported for MoE"

      if self.moe_router_enable_expert_bias and self.moe_router_score_function != "sigmoid":
          raise ValueError(
              "Expert bias for aux-loss-free routing only supports sigmoid score function."
              "Please set --moe-router-score-function sigmoid for sigmoid score function."
          )

      if self.num_moe_experts and self.fp8:
          # TE version below 1.7.0 will raise Error when handle zeros tokens for expert
          if not is_te_min_version("1.7.0.dev0"):
              raise ValueError(
                  "Only transformer-engine>=1.7.0 supports MoE FP8 training, "
                  f"but your version is {get_te_version()}."
              )

          if self.moe_grouped_gemm and not is_te_min_version("1.11.0"):
              raise ValueError(
                  "Only transformer-engine>=1.11.0 supports FP8 grouped gemm, "
                  f"but your version is {get_te_version()}."
              )

      if self.moe_router_padding_for_fp8:
          if self.fp8 is None:
              raise ValueError("fp8 must be specified when moe_router_padding_for_fp8 is True.")

          if self.moe_token_dispatcher_type in ["allgather", "alltoall_seq"]:
              raise ValueError(
                  "allgather and alltoall_seq dispatcher does not support "
                  "moe_router_padding_for_fp8."
              )

      if (
          self.moe_router_topk == 1
          and self.moe_router_score_function == "softmax"
          and not self.moe_router_pre_softmax
          and self.moe_router_load_balancing_type != "sinkhorn"
      ):
          # Requires applying softmax before selecting the top-k when k is 1,
          # since softmax on a [num_tokens, 1] would yield a zero gradient.
          raise ValueError("Please use --moe-router-pre-softmax when topk is 1.")

      if self.moe_router_group_topk:
          if self.moe_router_topk_limited_devices:
              raise ValueError(
                  "moe_router_topk_limited_devices is deprecated and replaced by "
                  "moe_router_group_topk and moe_router_num_groups."
              )
          if not self.moe_router_num_groups:
              raise ValueError(
                  "When using group limited routing, moe_router_num_groups must be specified."
              )
          else:
              assert self.num_moe_experts % self.moe_router_num_groups == 0, (
                  f"num_moe_experts ({self.num_moe_experts}) should be divisible by "
                  f"moe_router_num_groups ({self.moe_router_num_groups})."
              )
              assert self.moe_router_group_topk <= self.moe_router_num_groups, (
                  f"moe_router_group_topk ({self.moe_router_group_topk}) should be smaller than "
                  f"moe_router_num_groups ({self.moe_router_num_groups})."
              )
      elif self.moe_router_topk_limited_devices:
          warnings.warn(
              "moe_router_topk_limited_devices is deprecated. Use moe_router_group_topk and "
              "moe_router_num_groups instead."
          )
          self.moe_router_group_topk = self.moe_router_topk_limited_devices
          self.moe_router_num_groups = self.expert_model_parallel_size

      if self.enable_cuda_graph:
          if self.cpu_offloading:
              raise ValueError("CUDA graphs not supported with CPU offloading.")
          if self.recompute_granularity:
              raise ValueError("CUDA graphs not supported with activation recomputation.")

      if self.moe_token_dispatcher_type in ["allgather"]:
          if self.variable_seq_lengths is True:
              raise ValueError(
                  f"Token dispatcher type: {self.moe_token_dispatcher_type} does not support "
                  f"variable sequence length, please use alltoall dispatcher instead."
              )

      if self.moe_permute_fusion:
          from megatron.core.transformer.moe.moe_utils import (
              fused_permute,
              fused_permute_with_probs,
              fused_sort_chunks_by_index,
              fused_sort_chunks_by_index_with_probs,
              fused_unpermute,
          )

          if (
              fused_permute is None
              or fused_permute_with_probs is None
              or fused_sort_chunks_by_index is None
              or fused_sort_chunks_by_index_with_probs is None
              or fused_unpermute is None
          ):
              raise ValueError("fused permutation is not available. Please install TE >= 2.1.0.")

      if self.context_parallel_size > 1 and self.cp_comm_type is not None:
          if isinstance(self.cp_comm_type, list):
              assert len(self.cp_comm_type) == self.num_layers, (
                  f"Length of cp_comm_type ({len(self.cp_comm_type)}) should equal to "
                  f"the total number of transformer layers ({self.num_layers})!"
              )
          else:
              assert isinstance(
                  self.cp_comm_type, str
              ), "Unsupported communication type for context parallelism!"

      assert (
          self.pipeline_model_parallel_size > 0
      ), f"Pipeline model parallel size must be larger than 0 \
          when enable --standalone-embedding-stage and --standalone-loss-stage"

      if (
          self.num_moe_experts is not None
          and self.num_moe_experts >= 32
          and not self.moe_router_dtype
      ):
          warnings.warn(
              "Using a large number of experts (e.g. >=32) without fp32 routing. "
              "Consider enabling moe_router_dtype for better numerical stability."
          )
      if self.symmetric_ar_type is not None:
          if not HAVE_PACKAGING:
              raise ImportError(
                  "packaging is not installed. Please install it with `pip install packaging`."
              )
          assert is_torch_min_version("2.7.0a0"), "Must have at least torch version 2.7 or higher"
          assert is_te_min_version("2.3.0") or get_te_version() == PkgVersion(
              "2.3.0.dev0+39c0e70"
          ), "Must have at least TE version 2.3 or higher to use symmetric memory all reduce"

      if self.no_rope_freq:
          assert not self.flash_decode, "flash_decode cannot be used with no_rope."
          if isinstance(self.no_rope_freq, int):
              assert self.num_layers % self.no_rope_freq == 0, (
                  f"no_rope_freq={self.no_rope_freq} should be "
                  f"divisible by num_layers={self.num_layers}."
              )
              # Convert integer pattern to list pattern
              # e.g. no_rope=4 with num_layers=8 becomes [0,0,0,1,0,0,0,1]
              pattern = [0] * (self.no_rope_freq - 1) + [1]
              self.no_rope_freq = pattern * (self.num_layers // self.no_rope_freq)
          else:
              assert len(self.no_rope_freq) == self.num_layers, (
                  f"Length of no_rope list ({len(self.no_rope_freq)}) must match "
总访客数:   ·   总访问量:
PLM's Blog © 2016–2025