训练主流程
训练主入口 (main_ppo)
main_ppo.py 主入口:
def run(self, config):
from pprint import pprint
from omegaconf import OmegaConf
from verl.utils.fs import copy_to_local
print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
pprint(OmegaConf.to_container(config, resolve=True))
OmegaConf.resolve(config)
actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)
self.add_critic_worker(config)
# We should adopt a multi-source reward function here:
# - for rule-based rm, we directly call a reward score
# - for model-based rm, we call a model
# - for code related prompt, we send to a sandbox if there are test cases
# finally, we combine all the rewards together
# The reward type depends on the tag of the data
self.add_reward_model_worker(config)
# Add a reference policy worker if KL loss or KL reward is used.
self.add_ref_policy_worker(config, actor_rollout_cls)
# validate config
validate_config(
config=config,
use_reference_policy=need_reference_policy(self.role_worker_mapping),
use_critic=need_critic(config),
)
# Download the checkpoint from HDFS to the local machine.
# `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
local_path = copy_to_local(
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
)
# Instantiate the tokenizer and processor.
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
# Load the reward manager for training and validation.
# 如果训练和测试reward不一致,需要自己修改load_reward_manager读取不同的fn
reward_fn = load_reward_manager(
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
)
val_reward_fn = load_reward_manager(
config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
)
resource_pool_manager = self.init_resource_pool_mgr(config)
from verl.utils.dataset.rl_dataset import collate_fn
# Create training and validation datasets.
train_dataset = create_rl_dataset(
config.data.train_files,
config.data,
tokenizer,
processor,
is_train=True,
max_samples=config.data.get("train_max_samples", -1),
)
val_dataset = create_rl_dataset(
config.data.val_files,
config.data,
tokenizer,
processor,
is_train=False,
max_samples=config.data.get("val_max_samples", -1),
)
train_sampler = create_rl_sampler(config.data, train_dataset)
# Initialize the PPO trainer.
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=self.role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
train_dataset=train_dataset,
val_dataset=val_dataset,
collate_fn=collate_fn,
train_sampler=train_sampler,
)
# Initialize the workers of the trainer.
trainer.init_workers()
# Start the training process.
trainer.fit()

主流程 (RayPPOTrainer.fit)
核心流程
0. 核心思想
- RL 思想:采样、评估、学习(三段式循环的极简示意见下文)。
1. Rollout (生成/采样)
- 使用当前策略 Actor,对 batch prompts 做采样,生成 responses(环境交互)。
- gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
2. Evaluate (评估)
- 计算 reward(reward_fn / RM)、old_log_prob、ref_log_prob、values,再计算 advantage。
3. Learn (学习/更新)
- Critic-Free
  - Actor 更新
- Critic
  - Actor 更新:Policy Loss 笔记
  - Critic 更新:Critic Loss 笔记
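下面是一个极简的三段式骨架示意(占位实现的 Python 草图,非 verl 实际代码,函数名均为假设),仅用于说明"采样-评估-学习"的循环结构:

def rollout(prompts, n):
    # 1. Rollout:用当前策略为每个 prompt 采样 n 个 response(此处为占位实现)
    return [f"{p} -> resp_{i}" for p in prompts for i in range(n)]

def evaluate(responses):
    # 2. Evaluate:规则 / RM / 沙箱打分(此处为占位实现)
    return [float(len(r) % 2) for r in responses]

def learn(responses, rewards):
    # 3. Learn:基于优势更新 actor(和可选的 critic),此处仅返回平均 reward 作示意
    return sum(rewards) / len(rewards)

for step, prompts in enumerate([["1+1=?", "2+2=?"]]):
    responses = rollout(prompts, n=2)
    rewards = evaluate(responses)
    print(step, learn(responses, rewards))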
代码
def fit(self):
"""
The training loop of PPO.
The driver process only need to call the compute functions of the worker group through RPC
to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
self.global_steps = 0
self._load_checkpoint()
current_epoch = self.global_steps // len(self.train_dataloader)
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
rollout_skip.wrap_generate_sequences()
# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
# we start from step 1
self.global_steps += 1
last_val_metrics = None
self.max_steps_duration = 0
prev_step_profile = False
curr_step_profile = (
self.global_steps in self.config.global_profiler.steps
if self.config.global_profiler.steps is not None
else False
)
next_step_profile = False
for epoch in range(current_epoch, self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}
batch: DataProto = DataProto.from_single_dict(batch_dict)
# add uid to batch
batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
gen_batch = self._get_gen_batch(batch)
# pass global_steps to trace
gen_batch.meta_info["global_steps"] = self.global_steps
# 每个prompt复制n次,对应rollout n次
gen_batch_output = gen_batch.repeat(
repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
)
is_last_step = self.global_steps >= self.total_training_steps
with marked_timer("step", timing_raw):
# generate a batch
with marked_timer("gen", timing_raw, color="red"):
# rollout,环境交互/模型生成
if not self.async_rollout_mode:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
else:
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
# 拼接生成结果
batch = batch.union(gen_batch_output)
if "response_mask" not in batch.batch.keys():
batch.batch["response_mask"] = compute_response_mask(batch)
# Balance the number of valid tokens across DP ranks.
# NOTE: This usually changes the order of data in the `batch`,
# which won't affect the advantage calculation (since it's based on uid),
# but might affect the loss calculation (due to the change of mini-batching).
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
with marked_timer("reward", timing_raw, color="yellow"):
# compute reward model score
if self.use_rm and "rm_scores" not in batch.batch.keys():
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)
if self.config.reward_model.launch_reward_fn_async:
future_reward = compute_reward_async.remote(
data=batch, config=self.config, tokenizer=self.tokenizer
)
else:
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
# Operating Mode Selection:
# - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ)
# - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ)
# Note: π_old computed once per data batch, serves as stable reference during mini-batch updates
rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False)
if bypass_recomputing_logprobs: # Use `rollout_log_probs`
from verl.trainer.ppo.rollout_corr_helper import apply_rollout_correction
apply_rollout_correction(
batch=batch,
rollout_corr_config=rollout_corr_config,
policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss,
)
else:
# Recompute old_log_probs, 计算\pi_{\theta_{old}}
with marked_timer("old_log_prob", timing_raw, color="blue"):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
entropys = old_log_prob.batch["entropys"]
response_masks = batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_agg = agg_loss(
loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode
)
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)
if "rollout_log_probs" in batch.batch.keys():
# TODO: we may want to add diff of probs too.
from verl.utils.debug.metrics import calculate_debug_metrics
metrics.update(calculate_debug_metrics(batch))
assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}'
if self.use_reference_policy:
# compute reference log_prob, \pi_{ref}
with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with marked_timer("values", timing_raw, color="cyan"):
# ppo critic
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with marked_timer("adv", timing_raw, color="brown"):
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
if self.config.reward_model.launch_reward_fn_async:
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
# reward,环境最后一步给的奖励
batch.batch["token_level_scores"] = reward_tensor
if reward_extra_infos_dict:
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
# 在reward里增加kl
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
# 不使用的话,就直接使用环境给的奖励
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# Compute rollout correction: IS weights, rejection sampling, and metrics
# Only runs in decoupled mode (computes once per batch using stable π_old)
# In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout
if (
rollout_corr_config is not None
and "rollout_log_probs" in batch.batch
and not bypass_recomputing_logprobs # Only in decoupled mode
):
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch
# Compute IS weights, apply rejection sampling, compute metrics
batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
# IS and off-policy metrics already have rollout_corr/ prefix
metrics.update(is_metrics)
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(
"norm_adv_by_std_in_grpo", True
) # GRPO adv normalization factor
# 计算advantage
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
config=self.config.algorithm,
)
# update critic
if self.use_critic:
with marked_timer("update_critic", timing_raw, color="pink"):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# implement critic warmup, warm up:critic先学习一会儿,避免actor不稳定
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with marked_timer("update_actor", timing_raw, color="red"):
rollout_config = self.config.actor_rollout_ref.rollout
batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
# TODO: Make "temperature" single source of truth from generation.
batch.meta_info["temperature"] = rollout_config.temperature
# update actor
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
# validate
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with marked_timer("testing", timing_raw, color="green"):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
# Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
esi_close_to_expiration = should_save_ckpt_esi(
max_steps_duration=self.max_steps_duration,
redundant_time=self.config.trainer.esi_redundant_time,
)
# Check if the conditions for saving a checkpoint are met.
# The conditions include a mandatory condition (1) and
# one of the following optional conditions (2/3/4):
# 1. The save frequency is set to a positive value.
# 2. It's the last training step.
# 3. The current step number is a multiple of the save frequency.
# 4. The ESI(Elastic Server Instance)/training plan is close to expiration.
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
):
if esi_close_to_expiration:
print("Force saving checkpoint: ESI instance expiration approaching.")
with marked_timer("save_checkpoint", timing_raw, color="green"):
self._save_checkpoint()
with marked_timer("stop_profile", timing_raw):
next_step_profile = (
self.global_steps + 1 in self.config.global_profiler.steps
if self.config.global_profiler.steps is not None
else False
)
self._stop_profiling(
curr_step_profile and not next_step_profile
if self.config.global_profiler.profile_continuous_steps
else curr_step_profile
)
prev_step_profile = curr_step_profile
curr_step_profile = next_step_profile
steps_duration = timing_raw["step"]
self.max_steps_duration = max(self.max_steps_duration, steps_duration)
# training metrics
metrics.update(
{
"training/global_step": self.global_steps,
"training/epoch": epoch,
}
)
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: implement actual tflpo and theoretical tflpo
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
# Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation
# this is experimental and may be changed/removed in the future in favor of a general-purpose one
if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
self.train_dataloader.sampler.update(batch=batch)
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
progress_bar.update(1)
self.global_steps += 1
if (
hasattr(self.config.actor_rollout_ref.actor, "profiler")
and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
):
self.actor_rollout_wg.dump_memory_snapshot(
tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
)
if is_last_step:
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return
# this is experimental and may be changed/removed in the future
# in favor of a general-purpose data buffer pool
if hasattr(self.train_dataset, "on_batch_end"):
# The dataset may be changed after each training batch
self.train_dataset.on_batch_end(batch=batch)

训练过程关键前置计算 (RayPPOTrainer)
Rollout
with marked_timer("gen", timing_raw, color="red"):
if not self.async_rollout_mode:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
else:
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)

Reward 计算
with marked_timer("reward", timing_raw, color="yellow"):
# compute reward model score
if self.use_rm and "rm_scores" not in batch.batch.keys():
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)
if self.config.reward_model.launch_reward_fn_async:
future_reward = compute_reward_async.remote(
data=batch, config=self.config, tokenizer=self.tokenizer
)
else:
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

Old LogProb 计算
- 具体见下文 ComputeLogProb
with marked_timer("old_log_prob", timing_raw, color="blue"):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
entropys = old_log_prob.batch["entropys"]
response_masks = batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_agg = agg_loss(
loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode
)
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)
if "rollout_log_probs" in batch.batch.keys():
# TODO: we may want to add diff of probs too.
from verl.utils.debug.metrics import calculate_debug_metrics
metrics.update(calculate_debug_metrics(batch))

Ref LogProb 计算
if self.use_reference_policy:
# compute reference log_prob
with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)

Critic Values 计算 (PPO使用)
# compute values
if self.use_critic:
with marked_timer("values", timing_raw, color="cyan"):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)

优势计算
with marked_timer("adv", timing_raw, color="brown"):
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
if self.config.reward_model.launch_reward_fn_async:
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
batch.batch["token_level_scores"] = reward_tensor
if reward_extra_infos_dict:
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# Compute rollout correction: IS weights, rejection sampling, and metrics
# Only runs in decoupled mode (computes once per batch using stable π_old)
# In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout
if (
rollout_corr_config is not None
and "rollout_log_probs" in batch.batch
and not bypass_recomputing_logprobs # Only in decoupled mode
):
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch
# Compute IS weights, apply rejection sampling, compute metrics
batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
# IS and off-policy metrics already have rollout_corr/ prefix
metrics.update(is_metrics)
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(
"norm_adv_by_std_in_grpo", True
) # GRPO adv normalization factor
# 计算优势
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
config=self.config.algorithm,
)

训练配置参数验证
Verl 基础参数验证
- n_gpus = n_gpus_per_node * nnodes
- model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
- megatron_dp = n_gpus // (model_parallel_size * context_parallel_size)
- minimal_bsz = megatron_dp * actor.ppo_micro_batch_size_per_gpu
- real_train_batch_size = data.train_batch_size * rollout.n
条件1:实际 train_bs 能被 minimal_bsz 整除
- real_train_batch_size % minimal_bsz == 0
条件2:gpu 数量能被 parallel_size 整除
- n_gpus % (model_parallel_size * context_parallel_size) == 0,非常重要!!
条件3:mbs 和 mbs_per_gpu 不能同时设置
- _log_prob_micro_batch_size 和 _log_prob_micro_batch_size_per_gpu 不能同时设置
- 涉及的配置段主要是这 3 种:actor_rollout_ref.rollout、actor_rollout_ref.ref、reward_model
条件4:验证 actor_config
- train_batch_size >= ppo_mini_batch_size
- ppo_mini_batch_size % ppo_micro_batch_size == 0
- ppo_micro_batch_size * sp_size >= n_gpus
条件5:验证 critic_config
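一个极简数值示意(配置数值均为假设,非实际默认值),演示上述 minimal_bsz / real_train_batch_size 的整除检查:

# 假设:2 节点 × 8 卡,megatron TP=2、PP=2、CP=1,ppo_micro_batch_size_per_gpu=2
n_gpus = 2 * 8                                                           # n_gpus_per_node * nnodes = 16
model_parallel_size = 2 * 2                                              # tp * pp = 4
context_parallel_size = 1
megatron_dp = n_gpus // (model_parallel_size * context_parallel_size)    # 16 // 4 = 4
minimal_bsz = megatron_dp * 2                                            # 4 * 2 = 8

real_train_batch_size = 128 * 4                                          # data.train_batch_size * rollout.n = 512
assert n_gpus % (model_parallel_size * context_parallel_size) == 0       # 条件2
assert real_train_batch_size % minimal_bsz == 0                          # 条件1:512 % 8 == 0
print(megatron_dp, minimal_bsz, real_train_batch_size)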
def validate_config(
config: DictConfig,
use_reference_policy: bool,
use_critic: bool,
) -> None:
"""Validate an OmegaConf DictConfig.
Args:
config (DictConfig): The OmegaConf DictConfig to validate.
use_reference_policy (bool): is ref policy needed
use_critic (bool): is critic needed
"""
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
if config.actor_rollout_ref.actor.strategy == "megatron":
model_parallel_size = (
config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size
* config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
)
assert (
n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0
), (
f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times "
f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})"
)
megatron_dp = n_gpus // (
model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size
)
minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
else:
minimal_bsz = n_gpus
# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % minimal_bsz == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size "
f"({minimal_bsz})"
)
# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
"""Validate mutually exclusive micro batch size configuration options.
Ensures that users don't set both deprecated micro_batch_size and
the new micro_batch_size_per_gpu parameters simultaneously.
Args:
mbs: Deprecated micro batch size parameter value.
mbs_per_gpu: New micro batch size per GPU parameter value.
name (str): Configuration section name for error messages.
Raises:
ValueError: If both parameters are set or neither is set.
"""
settings = {
"reward_model": "micro_batch_size",
"actor_rollout_ref.ref": "log_prob_micro_batch_size",
"actor_rollout_ref.rollout": "log_prob_micro_batch_size",
}
if name in settings:
param = settings[name]
param_per_gpu = f"{param}_per_gpu"
if mbs is None and mbs_per_gpu is None:
raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")
if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove "
f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
)
# Actor validation done in ActorConfig.__post_init__ and validate()
actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor)
actor_config.validate(n_gpus, config.data.train_batch_size, config.actor_rollout_ref.model)
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
if use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.ref.log_prob_micro_batch_size,
config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.ref",
)
# The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.rollout",
)
# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(
config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
)
if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
print("NOTICE: You have both enabled in-reward kl and kl loss.")
# critic
if use_critic:
critic_config = omega_conf_to_dataclass(config.critic)
critic_config.validate(n_gpus, config.data.train_batch_size)
if config.data.get("val_batch_size", None) is not None:
print(
"WARNING: val_batch_size is deprecated."
+ " Validation datasets are sent to inference engines as a whole batch,"
+ " which will schedule the memory themselves."
)
# check eval config
if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
assert config.actor_rollout_ref.rollout.temperature > 0, (
"validation gen temperature should be greater than 0 when enabling do_sample"
)
# check LoRA rank in vLLM
if config.actor_rollout_ref.model.get("lora_rank", 0) > 0 and config.actor_rollout_ref.rollout.name == "vllm":
assert config.actor_rollout_ref.model.lora_rank <= 512, "LoRA rank in vLLM must be less than or equal to 512"
print("[validate_config] All configuration checks passed successfully!")ActorConfig.validate:
def validate(self, n_gpus: int, train_batch_size: int, model_config: dict = None):
"""Validate actor configuration with runtime parameters."""
if not self.use_dynamic_bsz:
if train_batch_size < self.ppo_mini_batch_size:
raise ValueError(
f"train_batch_size ({train_batch_size}) must be >= "
f"actor.ppo_mini_batch_size ({self.ppo_mini_batch_size})"
)
sp_size = getattr(self, "ulysses_sequence_parallel_size", 1)
if self.ppo_micro_batch_size is not None:
if self.ppo_mini_batch_size % self.ppo_micro_batch_size != 0:
raise ValueError(
f"ppo_mini_batch_size ({self.ppo_mini_batch_size}) must be divisible by "
f"ppo_micro_batch_size ({self.ppo_micro_batch_size})"
)
if self.ppo_micro_batch_size * sp_size < n_gpus:
raise ValueError(
f"ppo_micro_batch_size ({self.ppo_micro_batch_size}) * "
f"ulysses_sequence_parallel_size ({sp_size}) must be >= n_gpus ({n_gpus})"Megatron 并行参数验证
详细内容具体见文末 TransformerConfig 并行参数。
- TP参数:num_attention_heads % tensor_model_parallel_size == 0
- PP参数:num_layers % pipeline_model_parallel_size == 0
- EP参数:num_moe_experts % moe_router_num_groups == 0
- CP参数:n_gpus % (model_parallel_size * context_parallel_size) == 0,非常重要!!
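一个极简示意(模型超参为假设数值,非某个真实模型配置),把上面几条整除约束写成断言:

# 假设:num_attention_heads=32, num_layers=48, num_moe_experts=64
num_attention_heads, num_layers, num_moe_experts = 32, 48, 64
tp, pp, cp, moe_router_num_groups = 4, 2, 1, 8
n_gpus = 16

assert num_attention_heads % tp == 0                   # TP
assert num_layers % pp == 0                            # PP
assert num_moe_experts % moe_router_num_groups == 0    # EP
assert n_gpus % (tp * pp * cp) == 0                    # CP / 总并行度
print("megatron parallel config ok")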
关键项
数据加载
主入口main_ppo 调用create_rl_dataset
main_ppo.py 主入口:
def run(self, config):
# ....
# Create training and validation datasets.
train_dataset = create_rl_dataset(
config.data.train_files,
config.data,
tokenizer,
processor,
is_train=True,
max_samples=config.data.get("train_max_samples", -1),
)
val_dataset = create_rl_dataset(
config.data.val_files,
config.data,
tokenizer,
processor,
is_train=False,
max_samples=config.data.get("val_max_samples", -1),
)

create_rl_dataset:读取数据集、过滤超长数据
rl_dataset.py 读取完整数据集、过滤超长数据 (配置参数data.max_prompt_length)
def _read_files_and_tokenize(self):
dataframes = []
for parquet_file in self.data_files:
dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
dataframes.append(dataframe)
self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)
total = len(self.dataframe)
print(f"dataset len: {len(self.dataframe)}")
self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)
def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None):
# filter out too long prompts
if self.filter_overlong_prompts:
tokenizer = self.tokenizer
processor = self.processor
prompt_key = self.prompt_key
if processor is not None:
pass
else:
def doc2len(doc) -> int:
try:
apply_kwargs = dict(**self.apply_chat_template_kwargs)
if self.tool_schemas is not None:
apply_kwargs["tools"] = self.tool_schemas
return len(
tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True, **apply_kwargs)
)
except Exception:
print("Error processing one of the samples, skipping...")
traceback.print_exc()
return self.max_prompt_length + 1
dataframe = dataframe.filter(
lambda doc: doc2len(doc) <= self.max_prompt_length,
num_proc=self.num_workers,
desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
)
print(f"filter dataset len: {len(dataframe)}")
return dataframe

训练时读取单条数据:构建messages、tokenize、pad、position_ids等
- 根据 prompt_key 读取 messages
- 添加 generation prompt 模板并 tokenize,获得 input_ids 和 attention_mask,计算 position_ids
- 根据 data.max_prompt_length 做左 padding 或截断
- 处理 extra_info、tools_kwargs、interaction_kwargs 等内容
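下面是一个极简示意(非 verl_F.postprocess_data 的实际实现),演示左 padding / 左截断到 max_prompt_length 的效果:

import torch

def left_pad_or_truncate(input_ids, attention_mask, max_length, pad_token_id):
    # input_ids / attention_mask 形状为 (1, seq_len)
    seq_len = input_ids.size(1)
    if seq_len < max_length:
        pad_len = max_length - seq_len
        pad_ids = torch.full((1, pad_len), pad_token_id, dtype=input_ids.dtype)
        pad_mask = torch.zeros((1, pad_len), dtype=attention_mask.dtype)
        input_ids = torch.cat([pad_ids, input_ids], dim=1)        # 左侧补 pad
        attention_mask = torch.cat([pad_mask, attention_mask], dim=1)
    elif seq_len > max_length:
        input_ids = input_ids[:, -max_length:]                    # 左截断:保留最后 max_length 个 token
        attention_mask = attention_mask[:, -max_length:]
    return input_ids, attention_mask

ids = torch.tensor([[11, 12, 13]])
print(left_pad_or_truncate(ids, torch.ones_like(ids), max_length=5, pad_token_id=0))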
def _build_messages(self, example: dict):
messages: list = example.pop(self.prompt_key)
if self.image_key in example or self.video_key in example:
for message in messages:
content = message["content"]
content_list = []
segments = re.split("(<image>|<video>)", content)
segments = [item for item in segments if item != ""]
for segment in segments:
if segment == "<image>":
content_list.append({"type": "image"})
elif segment == "<video>":
content_list.append({"type": "video"})
else:
content_list.append({"type": "text", "text": segment})
message["content"] = content_list
return messages
def __getitem__(self, item):
"""
Note that we also return the raw_input_ids so that it can be combined with other chat template
"""
row_dict: dict = self.dataframe[item]
messages = self._build_messages(row_dict)
model_inputs = {}
if self.processor is not None:
pass
else:
if self.apply_chat_template_kwargs.get("chat_template") is None:
assert hasattr(self.tokenizer, "chat_template"), (
"chat_template should be provided in apply_chat_template_kwargs or tokenizer config, "
"models like GLM can copy chat_template.jinja from instruct models"
)
# add generation prompt 模板 和 tokenize
raw_prompt = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
)
model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
# input_ids和attention_mask
input_ids = model_inputs.pop("input_ids")
attention_mask = model_inputs.pop("attention_mask")
# 根据data.max_prompt_length做左padding或截断
input_ids, attention_mask = verl_F.postprocess_data(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=self.max_prompt_length,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation,
)
if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
# qwen-vl mrope
from verl.models.transformers.qwen3_vl import get_rope_index
pass
elif self.processor is not None and "Glm4vImageProcessor" in self.processor.image_processor.__class__.__name__:
from verl.models.transformers.glm4v import get_rope_index
pass
else:
# 根据attention_mask计算position_ids
position_ids = compute_position_id_with_mask(attention_mask)
row_dict["input_ids"] = input_ids[0]
row_dict["attention_mask"] = attention_mask[0]
row_dict["position_ids"] = position_ids[0]
raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
if len(raw_prompt_ids) > self.max_prompt_length:
if self.truncation == "left":
raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
elif self.truncation == "right":
raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
elif self.truncation == "middle":
left_half = self.max_prompt_length // 2
right_half = self.max_prompt_length - left_half
raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
elif self.truncation == "error":
raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
row_dict["raw_prompt_ids"] = raw_prompt_ids
# encode prompts without chat template
if self.return_raw_chat:
row_dict["raw_prompt"] = messages
# get prompts with chat template
if self.return_full_prompt:
row_dict["full_prompts"] = raw_prompt # array of strings
# add index for each prompt
if "extra_info" not in row_dict or row_dict["extra_info"] is None:
row_dict["extra_info"] = dict()
index = row_dict.get("extra_info", {}).get("index", 0)
tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {})
interaction_kwargs = row_dict.get("extra_info", {}).get("interaction_kwargs", {})
need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs)
if need_tools_kwargs and not tools_kwargs:
logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"])
row_dict["index"] = index
row_dict["tools_kwargs"] = tools_kwargs
row_dict["interaction_kwargs"] = interaction_kwargs
return row_dict

RayPPOTrainer:create data loader
- 根据 train_batch_size、val_batch_size 来划分构建 DataLoader。
- 总训练步数 = 原始数据数量 / train_batch_size * 训练epochs
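一个极简数值示意(数据量与配置均为假设值),drop_last=True 时的步数计算:

num_samples, train_batch_size, total_epochs = 10000, 256, 3
steps_per_epoch = num_samples // train_batch_size        # len(train_dataloader) = 39
total_training_steps = steps_per_epoch * total_epochs    # 39 * 3 = 117
print(steps_per_epoch, total_training_steps)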
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
"""
Creates the train and validation dataloaders.
"""
# TODO: we have to make sure the batch size is divisible by the dp size
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
self.train_dataset, self.val_dataset = train_dataset, val_dataset
if train_sampler is None:
train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
if collate_fn is None:
from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
collate_fn = default_collate_fn
num_workers = self.config.data["dataloader_num_workers"]
self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset,
batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
num_workers=num_workers,
drop_last=True,
collate_fn=collate_fn,
sampler=train_sampler,
)
val_batch_size = self.config.data.val_batch_size
self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
batch_size=val_batch_size,
num_workers=num_workers,
shuffle=self.config.data.get("validation_shuffle", True),
drop_last=False,
collate_fn=collate_fn,
)
# 训练步数=原始数据数量/train_batch_size*训练epochs
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.trainer.total_training_steps is not None:
total_training_steps = self.config.trainer.total_training_steps
self.total_training_steps = total_training_steps
pass

Validate
主流程 (RayPPOTrainer.Validate)
1. 数据准备 (Data Preparation):
- 遍历 self.val_dataloader
- 为 batch 中每个 prompt 生成唯一 uid
- prompt 复制 n 份,rollout n 次,生成 n 个不同的答案
2. 模型生成/环境交互 (Generation):
- 使用当前 Actor 对 batch prompts 进行推理,生成 n 个答案
  - 同步:self.actor_rollout_wg.generate_sequences
  - 异步:self.async_rollout_manager.generate_sequences,调用 agentloop
3. 评估打分 (Evaluation & Scoring):
- 把原始 prompt 和模型 response 拼接在一起
- 调用 self.val_reward_fn 为每一个生成的答案打分,得到 reward_tensor
  - 请注意:如果 agentloop 里已经包含 reward_score,则不会再计算,而是直接返回
- 除了主分数(reward),评估函数可能还会返回一些额外的信息(reward_extra_info)
  - 比如准确率(acc)、长度惩罚项等,在这里可以添加其他的指标
  - 这些信息都被收集到 reward_extra_infos_dict 字典中
4. 数据收集 (Data Collection):
- 在整个循环过程中,代码会把所有需要的信息都存储在列表中:
  - sample_inputs: 原始输入 (Prompts)
  - sample_outputs: 模型生成的答案
  - sample_gts: 人工标注的参考答案 (Ground Truths)
  - sample_scores: 每个答案的总分
  - sample_uids: 每个样本的唯一 ID
  - reward_extra_infos_dict: 包含所有评估指标(如 reward、acc)的字典
  - data_source_lst: 每个样本来自哪个数据集(如 "wiki"、"squad" 等)
5. 指标处理与格式化 (Metric Processing & Formatting):
- 收集完所有验证数据后,处理成结构化、有意义的统计指标
- 调用 process_validation_metrics 来做统计
- 再把 process_validation_metrics 的返回结果整理成扁平字典 metric_dict,Key 为:pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
  - 如 val-core/wiki/reward/mean@8
6. 返回结果
- 最后,函数返回 metric_dict,框架接收这些指标并做展示
def _validate(self):
data_source_lst = []
reward_extra_infos_dict: dict[str, list] = defaultdict(list)
# Lists to collect samples for the table
sample_inputs = []
sample_outputs = []
sample_gts = []
sample_scores = []
sample_turns = []
sample_uids = []
for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data)
# repeat test batch
test_batch = test_batch.repeat(
repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
)
# we only do validation on rule-based rm
if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
return {}
# prompt input_ids
input_ids = test_batch.batch["input_ids"]
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
sample_inputs.extend(input_texts)
sample_uids.extend(test_batch.non_tensor_batch["uid"])
ground_truths = [ item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch]
sample_gts.extend(ground_truths)
# pop input_ids, attention_mask, position_ids
test_gen_batch = self._get_gen_batch(test_batch)
test_gen_batch.meta_info = {
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
"recompute_log_prob": False,
"do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
"validate": True,
"global_steps": self.global_steps,
}
print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")
# pad to be divisible by dp_size
size_divisor = (
self.actor_rollout_wg.world_size
if not self.async_rollout_mode
else self.config.actor_rollout_ref.rollout.agent.num_workers
)
# pad&rollout
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
if not self.async_rollout_mode:
test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
else:
test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
# unpad
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
print("validation generation end")
# Store generated outputs
output_ids = test_output_gen_batch.batch["responses"]
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
sample_outputs.extend(output_texts)
# 拼接结果
test_batch = test_batch.union(test_output_gen_batch)
test_batch.meta_info["validate"] = True
# evaluate using reward_function
if self.val_reward_fn is None:
raise ValueError("val_reward_fn must be provided for validation.")
# 计算reward,如果rollout的test_batch里已经有reward,则直接返回
result = self.val_reward_fn(test_batch, return_dict=True)
reward_tensor = result["reward_tensor"]
# 一般只在最后一个位置有得分,其余位置都是0,所以各位置加起来作为最终得分
scores = reward_tensor.sum(-1).cpu().tolist()
sample_scores.extend(scores)
reward_extra_infos_dict["reward"].extend(scores)
if "reward_extra_info" in result:
for key, lst in result["reward_extra_info"].items():
reward_extra_infos_dict[key].extend(lst)
# collect num_turns of each prompt
if "__num_turns__" in test_batch.non_tensor_batch:
sample_turns.append(test_batch.non_tensor_batch["__num_turns__"])
data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))
self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
# dump generations
val_data_dir = self.config.trainer.get("validation_data_dir", None)
if val_data_dir:
self._dump_generations(
inputs=sample_inputs,
outputs=sample_outputs,
gts=sample_gts,
scores=sample_scores,
reward_extra_infos_dict=reward_extra_infos_dict,
dump_path=val_data_dir,
)
data_sources = np.concatenate(data_source_lst, axis=0)
data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)
metric_dict = {}
for data_source, var2metric2val in data_src2var2metric2val.items():
core_var = "acc" if "acc" in var2metric2val else "reward"
# var_name: reward, acc...
# metric_name: mean@4, std@4, best@4/mean...
for var_name, metric2val in var2metric2val.items():
n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
for metric_name, metric_val in metric2val.items():
if (
(var_name == core_var)
and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"])
and (f"@{n_max}" in metric_name)
):
# 核心指标
metric_sec = "val-core"
else:
# 辅助指标
metric_sec = "val-aux"
# 最终展示名称
pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
metric_dict[pfx] = metric_val
if len(sample_turns) > 0:
sample_turns = np.concatenate(sample_turns)
metric_dict["val-aux/num_turns/min"] = sample_turns.min()
metric_dict["val-aux/num_turns/max"] = sample_turns.max()
metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()
return metric_dict

统计mean/std各种数据指标(metric_utils.process_validation_metrics)
trainer.ppo.metric_utils.py :process_validation_metrics
def process_validation_metrics(
data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
) -> dict[str, dict[str, dict[str, float]]]:
"""
Process validation metrics into a structured format with statistical analysis.
This function organizes validation metrics by data source and prompt, then computes
various statistical measures including means, standard deviations, best/worst values,
and majority voting results. It also performs bootstrap sampling to estimate statistics
for different sample sizes.
Args:
data_sources: List of data source identifiers for each sample.
sample_uids: List of sample uids corresponding to each sample.
infos_dict: Dictionary mapping variable names to lists of values for each sample.
seed: Random seed for bootstrap sampling. Defaults to 42.
Returns:
A nested dictionary with the structure:
{
data_source: {
variable_name: {
metric_name: value
}
}
}
Where metric_name includes:
- "mean@N": Mean value across N samples
- "std@N": Standard deviation across N samples
- "best@N/mean": Mean of the best values in bootstrap samples of size N
- "best@N/std": Standard deviation of the best values in bootstrap samples
- "worst@N/mean": Mean of the worst values in bootstrap samples
- "worst@N/std": Standard deviation of the worst values in bootstrap samples
- "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists)
- "maj@N/std": Standard deviation of majority voting results (if "pred" exists)
Example:
>>> data_sources = ["source1", "source1", "source2"]
>>> sample_uids = ["uid1", "uid1", "uid2"]
>>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}
>>> result = process_validation_metrics(data_sources, sample_uids, infos_dict)
>>> # result will contain statistics for each data source and variable
"""
# Group metrics by data source, prompt and variable
data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
for sample_idx, data_source in enumerate(data_sources):
uid = sample_uids[sample_idx]
var2vals = data_src2uid2var2vals[data_source][uid]
for var_name, var_vals in infos_dict.items():
var2vals[var_name].append(var_vals[sample_idx])
# Calculate metrics for each group
data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
for data_source, uid2var2vals in data_src2uid2var2vals.items():
for uid, var2vals in uid2var2vals.items():
for var_name, var_vals in var2vals.items():
if isinstance(var_vals[0], str):
continue
metric = {}
n_resps = len(var_vals)
metric[f"mean@{n_resps}"] = np.mean(var_vals)
if n_resps > 1:
metric[f"std@{n_resps}"] = np.std(var_vals)
ns = []
n = 2
while n < n_resps:
ns.append(n)
n *= 2
ns.append(n_resps)
for n in ns:
[(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(
data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed
)
metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std
metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std
if var2vals.get("pred", None) is not None:
vote_data = [
{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"], strict=True)
]
[(maj_n_mean, maj_n_std)] = bootstrap_metric(
data=vote_data,
subset_size=n,
reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")],
seed=seed,
)
metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std
data_src2uid2var2metric[data_source][uid][var_name] = metric
# Aggregate metrics across uids
data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
for data_source, uid2var2metric in data_src2uid2var2metric.items():
for uid, var2metric in uid2var2metric.items():
for var_name, metric in var2metric.items():
for metric_name, metric_val in metric.items():
data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val)
data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items():
for var_name, metric2uid_vals in var2metric2uid_vals.items():
for metric_name, uid_vals in metric2uid_vals.items():
data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals)
return data_src2var2metric2val

Rollout 过程
内容繁多且重要,具体见笔记 AgentLoop Rollout 笔记
奖励计算
- 环境给出奖励
- 奖励函数给出奖励
- reward model 计算奖励
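一个极简示意(非 verl 实际代码),演示"按数据 tag / data_source 分发到不同奖励来源"的思路,其中 run_testcases_stub、rm_score_stub 均为假设的占位函数:

def run_testcases_stub(code, tests):
    return 0.0            # 占位:真实场景应把代码送入沙箱跑测试用例

def rm_score_stub(text):
    return 0.5            # 占位:真实场景应调用 reward model 打分

def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    if data_source == "math_rule":          # 规则打分
        return float(solution_str.strip() == str(ground_truth))
    elif data_source == "code_sandbox":     # 代码类:送入沙箱
        return run_testcases_stub(solution_str, ground_truth)
    else:                                   # 其余走 reward model
        return rm_score_stub(solution_str)

print(compute_score("math_rule", "42", 42))   # 1.0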
环境给出奖励
在agent_loop结束时,调用环境eval,把reward_score放到AgentLoopOutput里。
reward_score = await call_env_evaluate(xxx)
output = AgentLoopOutput(
prompt_ids=prompt_ids,
response_ids=response_ids[: self.response_length],
response_mask=agent_data.response_mask[: self.response_length],
reward_score=reward_score,
multi_modal_data={},
response_logprobs=agent_data.response_logprobs[: self.response_length]
if agent_data.response_logprobs
else None,
num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
metrics=agent_data.metrics,
extra_fields={},
)

在reward function计算的时候,直接返回该值
if "rm_scores" in data.batch.keys():
if return_dict:
reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info}
else:
return data.batch["rm_scores"]

reward function 计算
见下文naive reward manager
# 调用reward_fn计算score
score = self.compute_score(
data_source=data_source,
solution_str=response_str,
ground_truth=ground_truth,
extra_info=extra_info,
)

reward model 计算
pass
Gen RM 计算
pass
优势计算
pass
core_algos优势计算主入口
def compute_advantage(
data: DataProto,
adv_estimator: AdvantageEstimator,
gamma: float = 1.0,
lam: float = 1.0,
num_repeat: int = 1,
norm_adv_by_std_in_grpo: bool = True,
config: Optional[AlgoConfig] = None,
) -> DataProto:
"""Compute advantage estimates for policy optimization.
This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc.
The advantage estimates are used to guide policy optimization in RL algorithms.
Args:
data (DataProto): The data containing batched model outputs and inputs.
adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++).
gamma (float, optional): Discount factor for future rewards. Defaults to 1.0.
lam (float, optional): Lambda parameter for GAE. Defaults to 1.0.
num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1.
norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in
GRPO. Defaults to True.
config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None.
Returns:
DataProto: The updated data with computed advantages and returns.
"""
# Back-compatible with trainers that do not compute response mask in fit
if "response_mask" not in data.batch.keys():
data.batch["response_mask"] = compute_response_mask(data)
# prepare response group
if adv_estimator == AdvantageEstimator.GAE:
# Compute advantages and returns using Generalized Advantage Estimation (GAE)
advantages, returns = core_algos.compute_gae_advantage_return(
token_level_rewards=data.batch["token_level_rewards"],
values=data.batch["values"],
response_mask=data.batch["response_mask"],
gamma=gamma,
lam=lam,
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
if config.get("use_pf_ppo", False):
data = core_algos.compute_pf_ppo_reweight_data(
data,
config.pf_ppo.get("reweight_method"),
config.pf_ppo.get("weight_pow"),
)
elif adv_estimator == AdvantageEstimator.GRPO:
# Initialize the mask for GRPO calculation
grpo_calculation_mask = data.batch["response_mask"]
# Call compute_grpo_outcome_advantage with parameters matching its definition
advantages, returns = core_algos.compute_grpo_outcome_advantage(
token_level_rewards=data.batch["token_level_rewards"],
response_mask=grpo_calculation_mask,
index=data.non_tensor_batch["uid"],
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
else:
# handle all other adv estimator type other than GAE and GRPO
adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator)
adv_kwargs = {
"token_level_rewards": data.batch["token_level_rewards"],
"response_mask": data.batch["response_mask"],
"config": config,
}
if "uid" in data.non_tensor_batch: # optional
adv_kwargs["index"] = data.non_tensor_batch["uid"]
if "reward_baselines" in data.batch: # optional
adv_kwargs["reward_baselines"] = data.batch["reward_baselines"]
# calculate advantage estimator
advantages, returns = adv_estimator_fn(**adv_kwargs)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
return data

GAE 优势估计
理论笔记
核心数学公式
- TD error:δ_t = r_t + γ · V(s_{t+1}) - V(s_t)
- GAE 反向迭代计算公式:A_t = δ_t + γλ · A_{t+1}
核心思想
- 反向迭代计算每步的优势:先计算 TD error,再计算优势(数值小例子见下方)
- 对于 pad、环境返回等位置的值,利用 response_mask 不做计算,即当前值不做更新
- 计算 returns = advantages + values(即 advantages = returns - values)
- 对优势做标准化处理
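一个极简数值示意(非 verl 实际代码):response_length=3,gamma=lam=1.0,奖励只落在最后一个 token 上,演示上述反向递推(未包含最后的白化步骤),实际实现见下方 compute_gae_advantage_return:

import torch

rewards = torch.tensor([0.0, 0.0, 1.0])   # 只有最后一个 token 有 reward
values = torch.tensor([0.2, 0.4, 0.6])
gamma, lam = 1.0, 1.0

lastgaelam, nextvalue = 0.0, 0.0
adv_reversed = []
for t in reversed(range(3)):
    delta = rewards[t] + gamma * nextvalue - values[t]     # TD error
    lastgaelam = delta + gamma * lam * lastgaelam          # GAE 反向递推
    nextvalue = values[t]
    adv_reversed.append(lastgaelam)

advantages = torch.stack(adv_reversed[::-1])
returns = advantages + values
print(advantages, returns)   # advantages=[0.8, 0.6, 0.4], returns=[1.0, 1.0, 1.0]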
@register_adv_est(AdvantageEstimator.GAE) # or simply: @register_adv_est("gae")
def compute_gae_advantage_return(
token_level_rewards: torch.Tensor,
values: torch.Tensor,
response_mask: torch.Tensor,
gamma: torch.Tensor,
lam: torch.Tensor,
):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
Args:
token_level_rewards: `(torch.Tensor)`
shape is (bs, response_length)
values: `(torch.Tensor)`
shape is (bs, response_length)
response_mask: `(torch.Tensor)`
shape is (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
gamma is `(float)`
discounted factor used in RL
lam: `(float)`
lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
with torch.no_grad():
# V(s_{t+1})
nextvalues = 0
# A_{t+1}
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]
for t in reversed(range(gen_len)):
# TD-Error 计算公式
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
# GAE 反向迭代计算公式
lastgaelam_ = delta + gamma * lam * lastgaelam
# skip values and TD-error on observation tokens,仅llm回复的才做计算
nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues
lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
# 回报
returns = advantages + values
# 白化,标准化,减均值、除标准差
advantages = verl_F.masked_whiten(advantages, response_mask)
return advantages, returns

GRPO 组优势计算
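一个极简数值示意(非 verl 实际代码):同一 uid(同一 prompt)的 4 个 response 的 outcome reward 做组内去均值、除以标准差,得到每条 response 的优势;实际实现见下方 compute_grpo_outcome_advantage:

import torch

rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])     # 同组 4 个 rollout 的 outcome reward
mean, std = rewards.mean(), rewards.std()        # 组内均值 0.5,标准差约 0.577
advantages = (rewards - mean) / (std + 1e-6)
print(advantages)                                # 约为 [0.866, -0.866, 0.866, -0.866]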
@register_adv_est(AdvantageEstimator.GRPO)
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
norm_adv_by_std_in_grpo: bool = True,
config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response).
Args:
token_level_rewards: `(torch.Tensor)`
shape is (bs, response_length)
response_mask: `(torch.Tensor)`
shape is (bs, response_length)
index: `(np.ndarray)`
index array for grouping
epsilon: `(float)`
small value to avoid division by zero
norm_adv_by_std_in_grpo: `(bool)`
whether to scale the GRPO advantage
config: `(Optional[AlgoConfig])`
algorithm configuration object
Note:
If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO.
If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).
Returns:
advantages: `(torch.Tensor)`
shape is (bs, response_length)
Returns: `(torch.Tensor)`
shape is (bs, response_length)
"""
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean = {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0]
# id: [1, 1, 1, 2, 2, 2, ...]
# scores: [0, 0.5, 0, 0, 1, 0, ...]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
scores_tensor = torch.stack(id2score[idx])
id2mean[idx] = torch.mean(scores_tensor)
id2std[idx] = torch.std(scores_tensor)
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
if norm_adv_by_std_in_grpo:
# local std
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
else:
# 不做std
scores[i] = scores[i] - id2mean[index[i]]
# 仅在response位置,才有adv
scores = scores.unsqueeze(-1) * response_mask
# advantages, returns
return scores, scores

Compute LogProb (最核心,所有内容都在这)
主入口 (megatron_workers)
- 为每个data设置meta_info信息:
  - micro_batch_size = rollout.log_prob_micro_batch_size_per_gpu
  - max_token_len = rollout.log_prob_max_token_len_per_gpu
- 调用actor计算log_prob和熵。
- 返回output(old_log_probs)、entropys。
ActorRolloutRefWorker.compute_log_prob
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@GPUMemoryLogger(role="compute_log_prob", logger=logger)
@DistProfiler.annotate(color="blue")
def compute_log_prob(self, data: DataProto):
assert self._is_actor
if self._is_offload_param:
load_megatron_model_to_gpu(self.actor_module, load_grad=False)
log_gpu_memory_usage("After load actor params and grad during compute_log_prob", logger=logger)
# 关键配置
# we should always recompute old_log_probs when it is HybridEngine
data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz
data.meta_info["temperature"] = self.config.rollout.temperature
# 核心部分,调用megatron_actor计算
output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
# 返回格式
output = DataProto.from_dict(
tensors={"old_log_probs": output, "entropys": entropys},
meta_info={"temperature": self.config.rollout.temperature},
)
output = output.to("cpu")
# clear kv cache
if self._is_offload_param:
offload_megatron_model_to_cpu(self.actor_module)
log_gpu_memory_usage("After offload actor params and grad during compute_log_prob", logger=logger)
aggressive_empty_cache(force_sync=True)
return output

核心compute_log_prob (MegatronPPOActor)
- 从data读取meta_info:micro_batch_size、max_token_len、use_dynamic_bsz。
- 读取data中的responses、input_ids、attention_mask、position_ids。
- 调用forward_backward_batch去做真实计算,具体见下文。
  - 传入后处理函数compute_logprobs_fn,根据真实response_length选择log_probs(位移切片的含义见下方示意)。
- pipeline逻辑:流水线计算,完成后在pp ranks间广播。
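一个极简示意(非 verl 实际代码,张量布局与 verl 内部不完全一致),演示"位置 t 的 logits 预测 token t+1,因此 response 的 log_prob 取自 [-response_length-1 : -1] 这段位移切片"的含义:

import torch
import torch.nn.functional as F

prompt_len, response_len, vocab = 3, 2, 8
input_ids = torch.randint(0, vocab, (1, prompt_len + response_len))
logits = torch.randn(1, prompt_len + response_len, vocab)   # 假设的模型输出

log_probs_all = F.log_softmax(logits, dim=-1)
# 位置 i 的分布预测 token i+1,因此用 input_ids[:, 1:] 去 gather 位置 [:-1] 的 log_prob
token_log_probs = torch.gather(
    log_probs_all[:, :-1], dim=-1, index=input_ids[:, 1:].unsqueeze(-1)
).squeeze(-1)
# response 各 token 的 log_prob 即最后 response_len 个位置,对应 [-response_len-1:-1] 的位移含义
response_log_probs = token_log_probs[:, -response_len:]
print(response_log_probs.shape)   # torch.Size([1, 2])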
MegatronPPOActor.compute_log_prob
@GPUMemoryLogger(role="megatron actor", logger=logger)
def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:
"""Compute the log probability of the responses given input_ids, attention_mask and position_ids
Args:
data (DataProto): a DataProto containing keys
``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.
``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.
``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.
``responses``: tensor of shape [batch_size, response_length]. torch.int64.
Returns:
DataProto: torch.Tensor: the log_prob tensor
"""
use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False)
micro_batch_size = data.meta_info.get("micro_batch_size", None)
max_token_len = data.meta_info.get("max_token_len", None)
if use_dynamic_bsz:
assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True"
max_token_len = max_token_len * self.config.megatron.context_parallel_size
else:
assert micro_batch_size is not None, (
"micro batch size is needed for forward compute when use_dynamic_bsz is False"
)
def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None):
response = data["responses"]
response_length = response.size(1)
log_probs = output["log_probs"][:, -response_length - 1 : -1].contiguous()
return {"log_probs": log_probs}
# We make recompute_old_log_prob by default here.
# TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be
# handled by user outside
recompute_old_log_prob = self.config.get("recompute_old_log_prob", True)
entropys = torch.Tensor()
if recompute_old_log_prob:
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
batch = data.select(batch_keys=select_keys).batch
input_ids = batch["input_ids"]
batch_size = input_ids.size(0)
response = batch["responses"]
response_length = response.size(1)
with torch.no_grad():
output = self.forward_backward_batch(
data,
forward_only=True,
post_process_fn=compute_logprobs_fn,
calculate_entropy=calculate_entropy,
use_dynamic_bsz=use_dynamic_bsz,
micro_batch_size=micro_batch_size,
max_token_len=max_token_len,
)
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# only on last rank. It should be on every tp rank
if calculate_entropy:
log_probs = [o[0]["log_probs"] for o in output["output"]] # (bs, seq_size) #
else:
log_probs = [o["log_probs"] for o in output["output"]] # (bs, seq_size) #
log_probs = torch.cat(log_probs, dim=0).to(torch.float32)
if use_dynamic_bsz:
indices = output["indices"]
indices = list(itertools.chain.from_iterable(indices))
assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
log_probs = log_probs[revert_indices]
else:
log_probs = torch.empty(
size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
)
# broadcast across pp ranks
log_probs = log_probs.to(get_device_id())
torch.distributed.broadcast(
tensor=log_probs,
src=mpu.get_pipeline_model_parallel_last_rank(),
group=mpu.get_pipeline_model_parallel_group(),
async_op=False,
)
log_probs = log_probs.to("cpu")
if calculate_entropy:
# Note that o[0] is metrics, o[1] is entropy
if mpu.is_pipeline_last_stage(ignore_virtual=True):
entropys = torch.cat([o[1] for o in output["output"]], dim=0)
entropys = entropys.to(torch.float32)
if use_dynamic_bsz:
indices = output["indices"]
indices = list(itertools.chain.from_iterable(indices))
assert len(indices) == entropys.size(0), f"{len(indices)} vs. {entropys.size()}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
entropys = entropys[revert_indices]
else:
entropys = torch.empty(
size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
)
# broadcast across pp ranks
entropys = entropys.to(get_device_id())
torch.distributed.broadcast(
tensor=entropys,
src=mpu.get_pipeline_model_parallel_last_rank(),
group=mpu.get_pipeline_model_parallel_group(),
async_op=False,
)
entropys = entropys.to("cpu")
# add empty cache after each compute
get_torch_device().empty_cache()
return log_probs, entropys
Core forward_backward_batch main flow (MegatronPPOActor)
- The incoming data is the mini_batch; it is split into multiple micro_batches according to micro_batch_size (see the splitting sketch below).
  - micro_batch_size = log_prob_micro_batch_size_per_gpu, see the main entry above.
- Compute n_micro_batch and build the batch_generator that drives the subsequent computation.
- Compute total_seqlen = micro_batch_size * seq_len (the length of input_ids).
- Call get_forward_backward_func from megatron.core.pipeline_parallel.
  - Pass in the custom forward_step_func, i.e. forward_step, which contains loss_func (details below).
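A minimal sketch (plain dict of tensors instead of a TensorDict) of the non-dynamic splitting path: every tensor in the mini-batch is chunked along the batch dimension, mirroring mini_batch.batch.split(micro_batch_size) in the code below.
import torch

# Hypothetical mini-batch: a dict of tensors sharing the batch dimension.
mini_batch = {
    "input_ids": torch.randint(0, 100, (8, 16)),
    "attention_mask": torch.ones(8, 16, dtype=torch.bool),
}
micro_batch_size = 2

# Split every tensor along dim 0, then regroup per micro-batch.
chunks = {k: torch.split(v, micro_batch_size, dim=0) for k, v in mini_batch.items()}
micro_batches = [{k: c[i] for k, c in chunks.items()} for i in range(len(chunks["input_ids"]))]

n_micro_batch = len(micro_batches)                                   # 8 / 2 = 4
total_seqlen = micro_batch_size * mini_batch["input_ids"].shape[1]   # 2 * 16 = 32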
MegatronPPOActor.forward_backward_batch
def forward_backward_batch(
self,
data: DataProto,
forward_only=False,
post_process_fn=None,
calculate_entropy=False,
use_dynamic_bsz=False,
micro_batch_size=None,
max_token_len=None,
mini_batch_size=None,
):
"""
We assume:
- The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input
- The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled
"""
# broadcast from last pp rank to all other pp ranks
# TODO: actually, we just need to control the sampling order.
data.to(get_device_id())
data.batch = data.batch.contiguous()
mini_batch = data
broadcast_dict_tensor(
mini_batch.batch,
src=mpu.get_pipeline_model_parallel_last_rank(),
group=mpu.get_pipeline_model_parallel_group(),
)
mini_batch.to("cpu")
# split into micro-batches
mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool)
self.has_multi_modal_inputs = "multi_modal_inputs" in mini_batch.non_tensor_batch.keys()
if self.has_multi_modal_inputs:
mini_batch.batch["multi_modal_inputs"] = mini_batch.non_tensor_batch["multi_modal_inputs"]
mini_batch.batch["multi_modal_inputs_idx"] = torch.Tensor(
list(range(len(mini_batch.non_tensor_batch["multi_modal_inputs"])))
).to(torch.int64)
if mini_batch.batch["position_ids"].dim() == 3: # qwen2vl mrope [bs, 3, seq_len]
mini_batch.batch["position_ids"] = mini_batch.batch["position_ids"][
:, 0
] # mcore patch recompute qwen2vl's pos ids during forward
indices = None
temperature = data.meta_info["temperature"]
if use_dynamic_bsz:
assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True"
vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
if vpp_size is not None and vpp_size > 1:
microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage
micro_batches, indices = rearrange_micro_batches(
batch=mini_batch.batch,
num_batches_divided_by=microbatch_group_size_per_vp_stage,
max_token_len=max_token_len,
)
assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, (
f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage "
f"{microbatch_group_size_per_vp_stage} for megatron backend"
)
else:
micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)
total_seqlen = max_token_len
else:
assert micro_batch_size is not None, (
"micro_batch_size is needed to be passed in when not using dynamic batch size"
)
micro_batches = mini_batch.batch.split(micro_batch_size)
seq_len = micro_batches[0]["input_ids"].shape[1]
total_seqlen = micro_batch_size * seq_len
# compute input shapes for pp stages
n_micro_batch = len(micro_batches)
# from megatron.core.pipeline_parallel import get_forward_backward_func
forward_backward_func = get_forward_backward_func()
# batch should be a list of batches inside micro-batches
batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.actor_module))
# TODO: we may use the new schedule instead
# for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
if mpu.get_pipeline_model_parallel_world_size() > 1:
losses_reduced = forward_backward_func(
forward_step_func=forward_step,  # custom forward_step
data_iterator=batch_generator,
model=self.actor_module,
num_microbatches=n_micro_batch,
seq_length=total_seqlen, # no use when input_shapes was set
micro_batch_size=1, # no use when input_shapes was set
forward_only=forward_only,
)
else:
losses_reduced = forward_backward_func(
forward_step_func=forward_step,  # custom forward_step
data_iterator=batch_generator,
model=self.actor_module,
num_microbatches=n_micro_batch,
seq_length=total_seqlen, # in use for pp = 1 #
micro_batch_size=1, # in use for pp = 1 #
forward_only=forward_only,
)
# loss_reduces contains the stats returned from loss_func
if self.has_multi_modal_inputs:
data.batch.pop("multi_modal_inputs")
data.batch.pop("multi_modal_inputs_idx")
data.non_tensor_batch.pop("multi_modal_inputs")
losses_reduced = {"output": losses_reduced}
if use_dynamic_bsz:
losses_reduced["indices"] = indices
return losses_reduced
Core forward_step (forward_backward_batch)
- Per-step batch data:
  - Each step pulls one micro-batch from the batch_generator built above.
    - micro_batch_size = log_prob_micro_batch_size_per_gpu
  - Read input_ids, attention_mask, position_ids and responses.
  - Build label and label_mask; the label is the responses (see the sketch below).
- Call get_mcore_forward_fn with hf_config to obtain forward_fn.
- Define a custom logits_processor that takes logits, label and label_mask and computes log_probs.
  - It computes the log-probs of the given label (the LLM response) and zeroes out positions where label_mask is False.
  - The details are below; essentially it gathers the log_probs at index = row_labels.
- Call forward_fn to obtain output.
  - Arguments: model, input_ids, attention_mask, position_ids and logits_processor.
- Finally, return loss_func bound with the batch data and meta_info to compute the loss.
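A small, self-contained sketch (made-up sizes) of the label / label_mask layout built inside forward_step below: the label at position t is the token expected at position t+1, so the responses are written into the slice [-response_length - 1 : -1], and every position outside that slice, plus the final one, is masked out.
import torch

batch_size, prompt_len, response_len = 1, 4, 3
seq_len = prompt_len + response_len

input_ids = torch.arange(seq_len).unsqueeze(0)       # stand-in prompt+response ids
responses = input_ids[:, -response_len:]
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

# label[t] is the token the model should predict at position t (i.e. token t+1)
label = torch.zeros_like(input_ids)
label[:, -response_len - 1 : -1] = responses

label_mask = attention_mask.clone()
label_mask[:, : -response_len - 1] = False   # prompt positions do not contribute
label_mask[:, -1] = False                    # last position has no next token
assert label_mask.sum() == response_len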
MegatronPPOActor.forward_backward_batch.forward_step
def forward_step(batch_iter, model, return_schedule_plan: bool = False):
"""
Args:
batch_iter: the batch iterator
model: the model
return_schedule_plan: whether to return the schedule plan, for 1f1b overlap
"""
batch = next(batch_iter)
batch = batch.to(get_device_id())
batch = batch.contiguous()
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"].to(bool)
position_ids = batch["position_ids"]
multi_modal_inputs = {}
if "multi_modal_inputs" in batch:
from verl.utils.model import extract_multi_modal_inputs
indices = batch.get("multi_modal_inputs_idx", None)
multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices)
responses = batch["responses"]
response_length = responses.size(1)
label = position_ids.clone()
label[:, -response_length - 1 : -1] = responses
label_mask = attention_mask.clone()
label_mask[:, : -response_length - 1] = False
label_mask[:, -1] = False
from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn
if self.use_fused_kernels:
forward_fn = get_mcore_forward_fused_fn(self.hf_config)
if return_schedule_plan:
forward_fn = gptmodel_forward_1f1b_overlap
# return dict of [logits, entropy]
output = forward_fn(
model=model,
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
labels=label,
labels_mask=label_mask,
temperature=temperature,
multi_modal_inputs=multi_modal_inputs,
)
else:
forward_fn = get_mcore_forward_fn(self.hf_config)
def logits_processor(logits, label, label_mask):
assert logits.shape[:2] == label.shape[:2]
assert label.shape == label_mask.shape
logits.div_(temperature)
ret = {}
if calculate_entropy:
logits_bak = logits.clone()
logger.warning_once(
"For memory-efficient computation, enable fused kernels via "
"`actor_rollout_ref.model.use_fused_kernels=True`. "
"The current `clone()` operation ensures correctness but increases memory usage."
)
entropy = vocab_parallel_entropy(logits)
ret["entropy"] = entropy
else:
logits_bak = logits
# compute log-probs of the given label (the LLM response); positions where label_mask is False are set to 0
log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)
log_probs = log_probs.masked_fill(~label_mask, 0.0)
ret["log_probs"] = log_probs
return ret
logits_processor_args = {"label": label, "label_mask": label_mask}
output = forward_fn(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
multi_modal_inputs=multi_modal_inputs,
logits_processor=logits_processor,
logits_processor_args=logits_processor_args,
)
if forward_only:
meta_info = None
else:
clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
meta_info = {
"clip_ratio": self.config.clip_ratio,
"entropy_coeff": self.config.entropy_coeff,
"clip_ratio_c": clip_ratio_c,
}
return output, partial(loss_func, data=batch, meta_info=meta_info)
Entropy and log-prob computation
- VocabParallelEntropy: computes the entropy from the logits (see the sanity check below).
  - Derivation and implementation of the entropy formula (verl + FSDP): see the separate notes.
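A quick numerical sanity check (plain PyTorch, no tensor parallelism) of the identity the class below implements: H = logsumexp(z) - sum(softmax(z) * z), which equals -sum(p * log p).
import torch

torch.manual_seed(0)
logits = torch.randn(2, 7, 32)  # (batch, seq, vocab); arbitrary sizes

p = torch.softmax(logits, dim=-1)
# Identity used by the forward pass below
entropy_a = torch.logsumexp(logits, dim=-1) - (p * logits).sum(dim=-1)
# Definition of entropy
entropy_b = -(p * torch.log_softmax(logits, dim=-1)).sum(dim=-1)
assert torch.allclose(entropy_a, entropy_b, atol=1e-5)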
class _VocabParallelEntropy(torch.autograd.Function):
def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
def mul_reduce(a, b):
return (a * b).sum(dim=-1, keepdim=True)
# numerical stability: subtract the max to avoid overflow
logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values
normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max
# exp_logits
normalized_exp_logits = normalized_vocab_parallel_logits.exp_()
# sum_exp_logits
normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)
# softmax probability of each token
softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits)
# p_i * logits_i
sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits)
# final entropy: logits_max + log(normalized_sum_exp_logits) - sum_softmax_times_logits
entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits
ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)
return entropy.squeeze(dim=-1)
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors
# reuse softmax_logits as grad
vocab_parallel_logits.sub_(sum_softmax_times_logits)
softmax_logits.mul_(vocab_parallel_logits)
softmax_logits.mul_(grad_output.unsqueeze(dim=-1))
# recover vocab_parallel_logits
vocab_parallel_logits.add_(sum_softmax_times_logits)
softmax_logits.mul_(-1)
return softmax_logits
FSDP implementation - torch_functional.py
- flash_attention implementation: follow the call into the code.
- There is also a simpler implementation, logprobs_from_logits_v2.
  - It essentially gathers the log_probs at index = row_labels (see the equivalence check below).
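A quick check (plain PyTorch, made-up shapes) of what "gather the log_probs at index = row_labels" means, and how it relates to a per-token cross-entropy:
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)                 # (tokens, vocab)
labels = torch.randint(0, 10, (4,))

# log_softmax + gather, the same shape of computation as logprobs_from_logits_v2
logp = F.log_softmax(logits, dim=-1)
logp_labels = logp.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

# Identical to a negative per-token cross-entropy
assert torch.allclose(logp_labels, -F.cross_entropy(logits, labels, reduction="none"), atol=1e-6)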
Megatron implementation
- To be read.
def logprobs_from_logits(logits, labels, inplace_backward=True):
"""
Compute per-token log-probabilities for the given labels.
Uses a Flash-Attention–based cross-entropy (if available) for efficient backward,
otherwise falls back to a standard log-softmax+gather approach.
See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
Args:
logits (Tensor): Model outputs of shape (..., vocab_size).
labels (LongTensor): True class indices of shape matching logits[..., :-1].
inplace_backward (bool): If True and Flash-Attn is available, perform backward in-place.
Returns:
Tensor: Log-probabilities of the target labels, shape logits.shape[:-1].
"""
if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:
batch_dim = logits.shape[:-1]
last_dim = logits.shape[-1]
logits = logits.reshape(-1, last_dim)
labels = labels.reshape(-1)
output = logprobs_from_logits_flash_attn(logits, labels, inplace_backward=inplace_backward)
output = output.view(*batch_dim)
elif NPU_CROSS_ENTROPY_LOSS_AVAILABLE:
output = logprobs_from_logits_torch_npu(logits, labels)
else:
output = logprobs_from_logits_v2(logits, labels)
return output
Core method
def logprobs_from_logits_v2(logits: torch.FloatTensor, labels):
"""
A memory efficient implementation of logprobs_from_logits
"""
if logits.dtype in [torch.float32, torch.float64]:
logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
# loop to reduce peak mem consumption
logsumexp_values = torch.stack([torch.logsumexp(logit, dim=-1) for logit in logits])
logprobs_labels = logits_labels - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
else:
# logsumexp approach is unstable with bfloat16, fall back to a slightly less efficient approach
logprobs_labels = []
for row_logits, row_labels in zip(logits, labels, strict=True): # loop to reduce peak mem consumption #
row_logprobs = F.log_softmax(row_logits, dim=-1)
# gather the logprobs at index = row_labels
row_logprobs_labels = row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
logprobs_labels.append(row_logprobs_labels)
logprobs_labels = torch.stack(logprobs_labels)
return logprobs_labels
Megatron implementation: heavily wrapped, hard to read further.
def vocab_parallel_log_probs_from_logits(logits, labels):
"""TODO(zhangchi.usc1992): We may change the implementation later"""
from megatron.core import tensor_parallel
return -tensor_parallel.vocab_parallel_cross_entropy(vocab_parallel_logits=logits, target=labels)
Core policy loss computation (forward_backward_batch)
- In forward_step above, the loss is computed after the entropy and log_probs of the output have been produced (a combined sketch follows below).
  - If forward_only, no loss is computed.
  - Otherwise, the loss is computed.
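A compact, self-contained sketch (plain PyTorch with made-up tensors; a simple k1 estimator stands in for kl_penalty) of how the three terms named in the theory notes below are combined into the final policy loss:
import torch

torch.manual_seed(0)
bsz, resp_len = 2, 5
log_prob = torch.randn(bsz, resp_len) * 0.1          # current policy log-probs
old_log_prob = log_prob + torch.randn(bsz, resp_len) * 0.05
ref_log_prob = log_prob + torch.randn(bsz, resp_len) * 0.05
advantages = torch.randn(bsz, resp_len)
entropy = torch.rand(bsz, resp_len)                   # per-token entropy
mask = torch.ones(bsz, resp_len, dtype=torch.bool)

clip_ratio, entropy_coeff, kl_loss_coef = 0.2, 0.001, 0.01

def masked_mean(x, m):
    return (x * m).sum() / m.sum()

# 1) PPO-clip policy-gradient term
ratio = torch.exp(log_prob - old_log_prob)
pg_losses = torch.maximum(-advantages * ratio,
                          -advantages * torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio))
pg_loss = masked_mean(pg_losses, mask)

# 2) Entropy bonus (subtracted: higher entropy lowers the loss)
entropy_loss = masked_mean(entropy, mask)

# 3) KL penalty against the reference policy (k1 estimator for simplicity)
kl = masked_mean(log_prob - ref_log_prob, mask)

policy_loss = pg_loss - entropy_coeff * entropy_loss + kl_loss_coef * kl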
def forward_step(batch_iter, model, return_schedule_plan: bool = False):
# ...
# ...
return output, partial(loss_func, data=batch, meta_info=meta_info)
Theory notes
- For the detailed theory, see the PPO-Actor notes.
- Policy Loss = PPO loss - entropy bonus + KL penalty
  - Policy Loss = pg_loss - entropy_coeff * entropy_loss + kl_loss_coef * kl_loss
PG Loss
- advantage * importance ratio with PPO-Clip; see below.
Entropy bonus
- Policy Loss = Policy Loss - entropy_coeff * entropy_loss
KL penalty
- Policy Loss = Policy Loss + kl_loss_coef * kl_loss
Return value of loss_func
- policy_loss, [metrics, ret_entropy]
Reading the data
- Read the current log_probs and entropy from output.
- Read old_log_probs and advantages from data.
- Read the relevant parameters from the config: entropy_coeff, kl_loss_coef, loss_agg_mode, etc.
Computing the PG loss
- Obtain policy_loss_fn from the config; the default is compute_policy_loss_vanilla.
- Inputs: old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, etc.
Computing the entropy bonus
- Scale by entropy_coeff and subtract it from the total loss.
Computing the KL penalty
- Compute the KL loss from ref_log_prob and log_prob, scale it by kl_loss_coef and add it to the total loss.
MegatronPPOActor.forward_backward_batch.loss_func
def loss_func(output, data, meta_info):
# For memory efficiency
# We move calculation of entropy to compute_log_probs, forward_only == True
log_probs = None
entropy = None
if isinstance(output, dict):
log_probs = output["log_probs"]
if "entropy" in output:
entropy = output["entropy"]
else:
assert isinstance(output, torch.Tensor)
log_probs = output
device = log_probs.device
metrics = {}
if forward_only:
if post_process_fn is None:
pass
# metrics["logits"] = output
else:
stats = post_process_fn(output, data)
metrics.update(stats)
if not calculate_entropy:
return torch.tensor(1.0, device=device), metrics
responses = data["responses"]
response_length = responses.size(1)
response_mask = data["response_mask"].to(bool)
loss_agg_mode = self.config.loss_agg_mode
# compute policy loss
log_prob = log_probs[:, -response_length - 1 : -1].contiguous()
ret_entropy = None
stats = {}
if not forward_only:
old_log_prob = data["old_log_probs"]
advantages = data["advantages"]
entropy_coeff = self.config.entropy_coeff
loss_agg_mode = self.config.loss_agg_mode
# call the policy_loss_fn from core_algos.py, see below
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
policy_loss_fn = get_policy_loss_fn(loss_mode)
# Extract pre-computed rollout correction weights if present
# Weights are computed centrally in trainer and added when algorithm.rollout_is=True
rollout_is_weights = data.get("rollout_is_weights", None)
pg_loss, pg_metrics = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
rollout_is_weights=rollout_is_weights,
)
stats.update(pg_metrics)
# Skip if using pure rollout correction mode (metrics already in pg_metrics)
rollout_log_prob = data.get("rollout_log_probs", None)
if loss_mode != "rollout_correction" and rollout_log_prob is not None:
# Compute metrics using CURRENT policy π_θ vs π_rollout
# Tracks evolving off-policy gap as π_θ updates during mini-batch training
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_corr_metrics_from_logprobs
rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs(
log_prob=log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
)
stats.update(rollout_corr_metrics)
stats["actor/pg_loss"] = pg_loss.detach().item()
policy_loss = pg_loss
if calculate_entropy:
# entropy bonus
entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
if not forward_only:
entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
entropy_coeff = meta_info["entropy_coeff"]
policy_loss = pg_loss - entropy_coeff * entropy_loss
else:
ret_entropy = entropy
if forward_only:
policy_loss = torch.tensor(1.0, device=device)
else:
if self.config.use_kl_loss:
# KL penalty
ref_log_prob = data["ref_log_prob"]
# compute kl loss
kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)
policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
metrics["actor/kl_loss"] = kl_loss.detach().item()
metrics["actor/kl_coef"] = self.config.kl_loss_coef
# return loss and stats
append_to_dict(metrics, stats)
return policy_loss, [metrics, ret_entropy]
Policy-gradient (PG) loss computation (core_algos.py)
PPO Loss
- advantage * importance ratio, with PPO-Clip and PPO-Dual-Clip.
- Loss aggregation modes: token-mean, seq-mean-token-mean, etc.
- Inputs: old_log_prob, log_prob, advantages, response_mask, loss_agg_mode, etc.
Computation steps
- Compute the importance-sampling ratio and the ppo_kl monitoring metric:
  - ratio = torch.exp(log_prob - old_log_prob)
  - ppo_kl is the masked mean of -(log_prob - old_log_prob) over response_mask, an approximate KL.
- Compute the plain, unclipped pg_losses1 = -advantages * ratio.
- Compute clip_pg_losses1:
  - Clip the IS ratio to the clip range (low and high) and multiply by the advantage to get pg_losses2.
  - Take the element-wise maximum of pg_losses1 and pg_losses2.
- Compute clip_pg_losses2 (the dual-clip loss); see the dual-clip-loss notes for the rationale:
  - Multiply the advantage by clip_ratio_c (default 3.0) and take the element-wise minimum with clip_pg_losses1.
- Compute the final pg_losses:
  - For tokens with advantage < 0, use clip_pg_losses2; otherwise keep clip_pg_losses1.
  - Aggregate according to loss_agg_mode.
core_algos.compute_policy_loss_vanilla
@register_policy_loss("vanilla") # type: ignore[arg-type]
def compute_policy_loss_vanilla(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""
Compute the clipped policy objective and related metrics for PPO.
Adapted from
https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
config: `(verl.trainer.config.ActorConfig)`:
config for the actor.
rollout_log_probs: `(torch.Tensor)`:
log probabilities of actions under the rollout policy, shape (batch_size, response_length).
"""
assert config is not None
assert not isinstance(config, AlgoConfig)
clip_ratio = config.clip_ratio # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
clip_ratio_c = config.get( # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
"clip_ratio_c", 3.0
)
cliprange = clip_ratio
cliprange_low = clip_ratio_low
cliprange_high = clip_ratio_high
assert clip_ratio_c > 1.0, (
"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
+ f" but get the value: {clip_ratio_c}."
)
negative_approx_kl = log_prob - old_log_prob
# KL Clip for stability
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
# importance sampling (IS) ratio
ratio = torch.exp(negative_approx_kl)
# approximate KL divergence (masked mean), used as a monitoring metric
ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
# plain unclipped term: advantage * importance ratio
pg_losses1 = -advantages * ratio
if cliprange_low is None:
cliprange_low = cliprange
if cliprange_high is None:
cliprange_high = cliprange
# PPO-clip second term: clamp the ratio, then multiply by advantages
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) # - clip(ratio, 1-cliprange, 1+cliprange) * A #
# max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
# clip fraction: how often pg_losses2 > pg_losses1
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
# dual-clip PPO: extra clip when the advantage is negative
pg_losses3 = -advantages * clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
# dual-clip fraction
pg_clipfrac_lower = verl_F.masked_mean(
torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
)
# dual-clip PPO: apply the extra clip only where advantage < 0
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
# Apply rollout correction weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
pg_metrics = {
"actor/pg_clipfrac": pg_clipfrac.detach().item(),
"actor/ppo_kl": ppo_kl.detach().item(),
"actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
}
return pg_loss, pg_metrics
Loss aggregation modes
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
"""
Aggregate the loss matrix into a scalar.
Args:
loss_mat: `(torch.Tensor)`:
shape: (bs, response_length)
loss_mask: `(torch.Tensor)`:
shape: (bs, response_length)
loss_agg_mode: (str) choices:
method to aggregate the loss matrix into a scalar.
Returns:
loss: `a scalar torch.Tensor`
aggregated loss
"""
if loss_agg_mode == "token-mean":
loss = verl_F.masked_mean(loss_mat, loss_mask)
elif loss_agg_mode == "seq-mean-token-sum":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum
seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences
loss = verl_F.masked_mean(seq_losses, seq_mask) # seq-mean
elif loss_agg_mode == "seq-mean-token-mean":
seq_mask = torch.sum(loss_mask, dim=-1) # per-sequence token count
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8) # token-mean
seq_mask = (seq_mask > 0).float() # exclude fully masked sequences
loss = verl_F.masked_mean(seq_losses, seq_mask) # seq-mean
elif loss_agg_mode == "seq-mean-token-sum-norm":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
loss = torch.sum(seq_losses) / loss_mask.shape[-1] # The divisor
# (loss_mask.shape[-1]) should ideally be constant
# throughout training to well-replicate the DrGRPO paper.
# TODO: Perhaps add user-defined normalizer argument to
# agg_loss to ensure divisor stays constant throughout.
else:
raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
return loss
Update Critic
def fit(self):
# ....
# ....
# update critic
if self.use_critic:
with marked_timer("update_critic", timing_raw, color="pink"):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
update_critic main flow (Megatron CriticWorker)
- Set meta_info on data: micro_batch_size = ppo_micro_batch_size_per_gpu.
- Build the dataloader.
  - It is built from ppo_mini_batch_size and ppo_epochs, which is very important (code below).
- Pass the dataloader into the critic's update_critic method to update the critic (details below).
- Return some metrics.
CriticWorker.update_critic
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
@DistProfiler.annotate(color="pink")
def update_critic(self, data: DataProto):
data = data.to(get_device_id())
if self._is_offload_param:
load_megatron_model_to_gpu(self.critic_module)
if self._is_offload_optimizer:
load_megatron_optimizer(self.critic_optimizer)
dataloader = self.critic.make_minibatch_iterator(data)
with Timer(name="update_critic", logger=None) as timer:
metrics = self.critic.update_critic(dataloader=dataloader)
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
from verl.utils.megatron.optimizer import get_megatron_last_lr
metrics["critic/lr"] = get_megatron_last_lr(self.critic_optimizer)
self.critic_optimizer_scheduler.step(1)
output = DataProto(batch=None, meta_info={"metrics": metrics})
if self._is_offload_param:
offload_megatron_model_to_cpu(self.critic_module)
if self._is_offload_optimizer:
offload_megatron_optimizer(self.critic_optimizer)
output = output.to("cpu")
return output
make_minibatch_iterator code
- Pass in data and call data.make_iterator; the key arguments are mini_batch_size=self.config.ppo_mini_batch_size and epochs=self.config.ppo_epochs (see the sketch below).
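A tiny arithmetic sketch (made-up sizes, ignoring data-parallel sharding) of what these two arguments imply for one update call: the iterator yields batch_size / ppo_mini_batch_size mini-batches per epoch, each mini-batch triggers one optimizer step, and each mini-batch is further split into micro-batches for gradient accumulation.
# Hypothetical sizes, for illustration only.
batch_size = 1024                   # samples handed to update_critic / update_policy
ppo_mini_batch_size = 256           # one optimizer.step() per mini-batch
ppo_epochs = 1                      # passes over the batch
ppo_micro_batch_size_per_gpu = 8    # gradient-accumulation chunk

optimizer_steps = batch_size // ppo_mini_batch_size * ppo_epochs               # 4
micro_batches_per_step = ppo_mini_batch_size // ppo_micro_batch_size_per_gpu   # 32
print(optimizer_steps, micro_batches_per_step)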
MegatronPPOCritic.make_minibatch_iterator
def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"]
data = data.select(batch_keys=select_keys)
return data.make_iterator(
mini_batch_size=self.config.ppo_mini_batch_size,
epochs=self.config.ppo_epochs,
seed=self.config.data_loader_seed,
dataloader_kwargs={"shuffle": self.config.shuffle},
)
update_critic (MegatronPPOCritic)
- Iterate over the dataloader and, for each mini-batch of data, compute and update repeatedly.
  - micro_batch_size = ppo_micro_batch_size_per_gpu, mini_batch_size = ppo_mini_batch_size
- Per-step update:
  - critic_optimizer.zero_grad(): zero the gradients.
  - Call forward_backward_batch to compute the metrics (including the loss), see above.
  - critic_optimizer.step(): update the critic.
MegatronPPOCritic.update_critic
@GPUMemoryLogger("megatron critic", logger=logger)
def update_critic(self, dataloader: Iterable[DataProto]):
metrics = {}
for data in dataloader:
self.critic_optimizer.zero_grad()
# use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
for chunk in self.critic_module:
chunk.zero_grad_buffer()
micro_batch_size = self.config.ppo_micro_batch_size_per_gpu
max_token_len = None
if self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size
# the core call
metric_micro_batch = self.forward_backward_batch(
data,
forward_only=False,
use_dynamic_bsz=self.config.use_dynamic_bsz,
micro_batch_size=micro_batch_size,
max_token_len=max_token_len,
mini_batch_size=self.config.ppo_mini_batch_size,
)
metric_micro_batch = metric_micro_batch["output"]
update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step()
learning_rate = self.critic_optimizer.param_groups[-1]["lr"]
data = {"critic/grad_norm": grad_norm, "critic/lr": learning_rate}
append_to_dict(metrics, data)
if update_successful:
# allgather already execute in optimizer.step in new megatron
pass
else:
raise NotImplementedError
for metric in metric_micro_batch:
append_to_dict(metrics, metric) # append the metric from this micro-batch to global metrics.
# add empty cache after each compute
get_torch_device().empty_cache()
return metrics
forward_backward_batch (MegatronPPOCritic)
- The incoming data is the mini_batch; it is split into multiple micro_batches according to micro_batch_size.
  - micro_batch_size = ppo_micro_batch_size_per_gpu, see the main entry above.
- Compute n_micro_batch and build the batch_generator that drives the subsequent computation.
  - Compute total_seqlen = micro_batch_size * seq_len (the length of input_ids).
- Call get_forward_backward_func from megatron.core.pipeline_parallel.
  - Pass in the custom forward_step_func, i.e. forward_step, which contains loss_func (details below).
MegatronPPOCritic.forward_backward_batch
def forward_backward_batch(
self,
data: DataProto,
forward_only=False,
use_dynamic_bsz=False,
micro_batch_size=None,
max_token_len=None,
mini_batch_size=None,
):
# broadcast from last pp rank to all other pp ranks
data.to(get_device_id())
mini_batch = data
mini_batch.batch = mini_batch.batch.contiguous()
broadcast_dict_tensor(
mini_batch.batch,
src=mpu.get_pipeline_model_parallel_last_rank(),
group=mpu.get_pipeline_model_parallel_group(),
)
mini_batch.to("cpu")
# split into micro-batches
mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool)
indices = None
if use_dynamic_bsz:
assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True"
vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
if vpp_size is not None and vpp_size > 1:
microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage
micro_batches, indices = rearrange_micro_batches(
batch=mini_batch.batch,
num_batches_divided_by=microbatch_group_size_per_vp_stage,
max_token_len=max_token_len,
)
assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, (
f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage "
f"{microbatch_group_size_per_vp_stage} for megatron backend"
)
else:
micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)
total_seqlen = max_token_len
else:
assert micro_batch_size is not None, (
"micro_batch_size is needed to be passed in when not using dynamic batch size"
)
micro_batches = mini_batch.batch.split(micro_batch_size)
seq_len = micro_batches[0]["input_ids"].shape[1]
total_seqlen = micro_batch_size * seq_len
n_micro_batch = len(micro_batches)
forward_backward_func = get_forward_backward_func()
# batch should be a list of batches inside micro-batches
batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.critic_module))
# TODO: we may use the new schedule instead
# for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
if mpu.get_pipeline_model_parallel_world_size() > 1:
losses_reduced = forward_backward_func(
forward_step_func=forward_step,
data_iterator=batch_generator,
model=self.critic_module,
num_microbatches=n_micro_batch,
seq_length=total_seqlen, # no use when input_shapes was set
micro_batch_size=1, # no use when input_shapes was set
forward_only=forward_only,
)
else:
losses_reduced = forward_backward_func(
forward_step_func=forward_step,
data_iterator=batch_generator,
model=self.critic_module,
num_microbatches=n_micro_batch,
seq_length=total_seqlen, # in use for pp = 1 #
micro_batch_size=1, # in use for pp = 1 #
forward_only=forward_only,
)
# loss_reduces contains the stats returned from loss_func
losses_reduced = {"output": losses_reduced}
if use_dynamic_bsz:
losses_reduced["indices"] = indices
return losses_reduced
forward_step (forward_backward_batch)
- Current dataloader info: the batch is turned into a dataloader using the two parameters mini_batch_size = ppo_mini_batch_size and epochs = ppo_epochs (see above).
- Per-step batch data:
  - Each step pulls one micro-batch from the batch_generator built above.
    - micro_batch_size = ppo_micro_batch_size_per_gpu
  - Read input_ids, attention_mask and position_ids.
- Call forward_fn to obtain output (with value_model=True, so the forward returns value predictions).
  - Arguments: model, input_ids, attention_mask, position_ids.
- Finally, return loss_func bound with the batch data and meta_info to compute the loss.
MegatronPPOCritic.forward_backward_batch.forward_step
def forward_step(batch_iter, model):
batch = next(batch_iter)
batch = batch.to(get_device_id())
batch = batch.contiguous()
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
position_ids = batch["position_ids"]
from verl.models.mcore import get_mcore_forward_fn
forward_fn = get_mcore_forward_fn(self.hf_config)
output = forward_fn(
model,
input_ids,
attention_mask,
position_ids,
{}, # multi_modal_inputs
value_model=True,
)
# 计算loss
return output, partial(loss_func, data=batch, meta_info={})
loss_func (forward_backward_batch)
Notes
Actual computation
- The critic loss is computed from the value predictions (vpreds) and the returns, while the old values are kept for value clipping so that the update does not move too far (see the sketch below).
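A plain-PyTorch illustration (made-up numbers; torch.clamp stands in for verl_F.clip_by_value) of the clipped value loss computed by compute_value_loss below:
import torch

returns = torch.tensor([1.0, 1.0, 1.0])
values  = torch.tensor([0.0, 0.0, 0.0])   # old value predictions (baseline)
vpreds  = torch.tensor([0.1, 0.5, 2.0])   # new value predictions
cliprange_value = 0.2

# Keep the new prediction within +/- cliprange_value of the old value
vpred_clipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)

vf_losses1 = (vpreds - returns) ** 2
vf_losses2 = (vpred_clipped - returns) ** 2
# The pessimistic (larger) loss is used, so moving far away from the old value is not rewarded
vf_loss = 0.5 * torch.max(vf_losses1, vf_losses2).mean()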
MegatronPPOCritic.forward_backward_batch.loss_func
def loss_func(output, data, meta_info):
nonlocal use_dynamic_bsz
if forward_only:
return torch.tensor(1.0, device=output.device), {"vpreds": output}
responses = data["responses"]
attention_mask = data["attention_mask"]
# the targets are the returns; the old values are kept for clipping to prevent overly large updates
values = data["values"]
returns = data["returns"]
response_length = responses.size(1)
# response mask
response_mask = attention_mask[:, -response_length:]
# value clip
cliprange_value = self.config.cliprange_value
vpreds = output # (bs, sequence_length)
vpreds = vpreds[:, -response_length - 1 : -1]
vf_loss, vf_clipfrac = core_algos.compute_value_loss(
vpreds=vpreds,
values=values,
returns=returns,
response_mask=response_mask,
cliprange_value=cliprange_value,
loss_agg_mode=self.config.loss_agg_mode,
)
stats = {
"critic/vf_loss": vf_loss.detach().item(),
"critic/vf_clipfrac": vf_clipfrac.detach().item(),
"critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
}
return vf_loss, stats
compute_value_loss (core_algos.py)
core_algos.compute_value_loss
def compute_value_loss(
vpreds: torch.Tensor,
returns: torch.Tensor,
values: torch.Tensor,
response_mask: torch.Tensor,
cliprange_value: float,
loss_agg_mode: str = "token-mean",
):
"""
Compute the clipped value-function loss for PPO.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151
Args:
vpreds (torch.FloatTensor):
Predicted values from the value head, shape (batch_size, response_length).
values (torch.FloatTensor):
Old (baseline) values from the value head, shape (batch_size, response_length).
returns (torch.FloatTensor):
Ground-truth returns, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the value loss calculation.
cliprange_value (float):
Clip range for value prediction updates.
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
Returns:
vf_loss (torch.FloatTensor):
A scalar tensor containing the aggregated value-function loss.
vf_clipfrac (float):
Fraction of elements where the clipped loss was used.
"""
# clip the new value predictions around the old values to prevent overly large updates
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
# clip
vf_losses1 = (vpreds - returns) ** 2
vf_losses2 = (vpredclipped - returns) ** 2
clipped_vf_losses = torch.max(vf_losses1, vf_losses2)
# aggregate the loss; the 0.5 cancels the factor 2 in the squared-error gradient
vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
return vf_loss, vf_clipfrac
Update Actor
RayPPOTrainer.fit()里的actor更新代码
def fit(self):
# ....
# ....
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with marked_timer("update_actor", timing_raw, color="red"):
rollout_config = self.config.actor_rollout_ref.rollout
batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
# TODO: Make "temperature" single source of truth from generation.
batch.meta_info["temperature"] = rollout_config.temperature
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
update_actor main entry (megatron_worker)
- Set meta_info on data: micro_batch_size = ppo_micro_batch_size_per_gpu.
- Build the dataloader.
  - It is built from ppo_mini_batch_size and ppo_epochs, which is very important.
- Pass the dataloader into the actor's update_policy method to update the policy (details below).
- Return some metrics.
ActorRolloutRefWorker.update_actor
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@GPUMemoryLogger(role="update_actor", logger=logger)
@DistProfiler.annotate(color="red")
def update_actor(self, data: DataProto):
assert self._is_actor
if self._is_offload_param:
load_megatron_model_to_gpu(self.actor_module)
log_gpu_memory_usage("After load actor params and grad during update_actor", logger=logger)
if self._is_offload_optimizer:
load_megatron_optimizer(self.actor_optimizer)
log_gpu_memory_usage("After load actor optimizer during update_actor", logger=logger)
micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu
data.meta_info["micro_batch_size"] = micro_batch_size
dataloader = self.actor.make_minibatch_iterator(data=data)
with Timer(name="update_policy", logger=None) as timer:
metrics = self.actor.update_policy(dataloader=dataloader)
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3)
metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3)
metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)
from verl.utils.megatron.optimizer import get_megatron_last_lr
metrics["actor/lr"] = get_megatron_last_lr(self.actor_optimizer)
self.actor_optimizer_scheduler.step(1)
# TODO: here, we should return all metrics
output = DataProto(meta_info={"metrics": metrics})
output = output.to("cpu")
if self._is_offload_param:
offload_megatron_model_to_cpu(self.actor_module)
log_gpu_memory_usage("After offload actor params and grad during update_actor", logger=logger)
if self._is_offload_optimizer:
offload_megatron_optimizer(self.actor_optimizer)
log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)
aggressive_empty_cache(force_sync=True)
return output
make_minibatch_iterator code
- Pass in data and call data.make_iterator; the key arguments are mini_batch_size=self.config.ppo_mini_batch_size and epochs=self.config.ppo_epochs.
MegatronPPOActor.make_minibatch_iterator
def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
"""Make minibatch iterator for updating the actor
Args:
data (DataProto): a DataProto containing keys
``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where
``sequence_length = prompt_length + response_length``
``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64
``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64
``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that
responses = input_ids[:, -response_length:]
``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability
of responses.
``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of
responses.
See PPO paper for details. https://arxiv.org/abs/1707.06347
Returns:
"""
select_keys = [
"responses",
"input_ids",
"attention_mask",
"response_mask",
"position_ids",
"old_log_probs",
"advantages",
]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
# Include pre-computed IS weights if present in batch
# Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
if "rollout_is_weights" in data.batch.keys():
select_keys.append("rollout_is_weights")
self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
if self.has_multi_modal_inputs:
data = data.select(select_keys, ["multi_modal_inputs"])
else:
data = data.select(batch_keys=select_keys)
return data.make_iterator(
mini_batch_size=self.config.ppo_mini_batch_size,
epochs=self.config.ppo_epochs,
seed=self.config.data_loader_seed,
dataloader_kwargs={"shuffle": self.config.shuffle},
)
update_policy (MegatronPPOActor)
MegatronPPOActor.update_policy
- Current dataloader info: the batch is turned into a dataloader using the two parameters mini_batch_size = ppo_mini_batch_size and epochs = actor_rollout_ref.actor.ppo_epochs (see above).
- Iterate over the dataloader and, for each mini-batch of data, compute and update repeatedly.
  - micro_batch_size = ppo_micro_batch_size_per_gpu
- Per-step update:
  - actor_optimizer.zero_grad(): zero the gradients.
  - Call forward_backward_batch to compute the metrics (including the loss), see above.
  - actor_optimizer.step(): update the actor.
@GPUMemoryLogger(role="megatron actor", logger=logger)
def update_policy(self, dataloader: Iterable[DataProto]) -> dict:
"""Update the policy with an iterator of DataProto
Args:
dataloader (Iterable[DataProto]): an iterator over the DataProto that returns by ``make_minibatch_iterator``
The keys of each data batch is described in the make_minibatch_iterator.
Returns:
Dict: a dictionary containing the statistics. Note that the statistics are only valid in the last pp stage
and users have to combine the output in each dp rank manually.
"""
metrics = {}
if self.use_torch_profiler and self.prof and self.prof.enable:
self.prof.start()
for data in dataloader:
self.actor_optimizer.zero_grad()
# use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
for chunk in self.actor_module:
# if use distributed optimizer, zero grad buffer will be handled by optimizer
chunk.zero_grad_buffer()
calculate_entropy = self.config.entropy_coeff != 0
if data.meta_info.get("micro_batch_size", None) is not None:
micro_batch_size = data.meta_info["micro_batch_size"]
else:
micro_batch_size = self.config.ppo_micro_batch_size_per_gpu
max_token_len = None
if self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size
# call forward_backward_batch, described above
metric_micro_batch = self.forward_backward_batch(
data,
calculate_entropy=calculate_entropy,
use_dynamic_bsz=self.config.use_dynamic_bsz,
micro_batch_size=micro_batch_size,
max_token_len=max_token_len,
mini_batch_size=self.config.ppo_mini_batch_size,
)
# collect the loss and other metrics
metric_micro_batch = metric_micro_batch["output"]
for metric in metric_micro_batch:
# Note that o[0] is metrics, o[1] is entropy, o[2] is response_mask
append_to_dict(metrics, metric[0]) # append the metric from this micro-batch to global metrics.
# update the model
update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step()
data = {"actor/grad_norm": grad_norm}
append_to_dict(metrics, data)
if update_successful:
# allgather already execute in optimizer.step in new megatron
pass
else:
raise NotImplementedError
if self.use_torch_profiler and self.prof and self.prof.enable:
self.prof.step()
# add empty cache after each compute
if self.use_torch_profiler and self.prof and self.prof.enable:
self.prof.stop_and_save()
self.prof.stop_trace()
get_torch_device().empty_cache()
return metrics
Loading-related code
Loading the reward manager
load_reward_manager
def load_reward_manager(
config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any
) -> AbstractRewardManager:
"""
Load and initialize a reward manager based on the configuration.
Args:
config: PPO trainer configuration object containing reward_model fields.
tokenizer: Tokenizer object used for processing text.
num_examine: Number of samples to examine.
**reward_kwargs: Additional keyword arguments for the reward manager.
Returns:
An instance of the specified reward manager class.
"""
# Try to get a custom reward function based on the configuration
# user defined reward manager can be registered in custom_reward_fn
compute_score = get_custom_reward_fn(config)
final_compute_score = compute_score
# The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:
# naive: NaiveRewardManager
# prime: PrimeRewardManager
# batch: BatchRewardManager
# dapo: DAPORewardManager
# Note(haibin.lin): For custom reward managers, please make sure they are imported and
# registered via `verl.workers.reward_manager.register`
# By default reward_manager is set to naive (NaiveRewardManager)
reward_manager_name = config.reward_model.get("reward_manager", "naive")
reward_manager_cls = get_reward_manager_cls(reward_manager_name)
if compute_score is None:
sandbox_config = config.reward_model.get("sandbox_fusion")
sandbox_url = sandbox_config.get("url") if sandbox_config else None
memory_limit_mb = sandbox_config.get("memory_limit_mb", 1024) if sandbox_config else 1024
if sandbox_url:
sandbox_manager = multiprocessing.Manager()
# Create a semaphore to control concurrent access to the sandbox
_concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64))
final_compute_score = partial(
default_compute_score,
sandbox_fusion_url=sandbox_url,
concurrent_semaphore=_concurrent_semaphore,
memory_limit_mb=memory_limit_mb,
)
else:
final_compute_score = default_compute_score
# Instantiate and return the reward manager with the specified parameters
# RewardLoopManagerBase subclasses (like RateLimitedRewardLoopManager) don't accept num_examine
# while AbstractRewardManager subclasses (like NaiveRewardManager) do
if RewardLoopManagerBase is not None and issubclass(reward_manager_cls, RewardLoopManagerBase):
# RewardLoopManagerBase-based managers use a different signature
return reward_manager_cls(
config=config,
tokenizer=tokenizer,
compute_score=final_compute_score,
**reward_kwargs,
)
else:
# Traditional AbstractRewardManager-based managers
return reward_manager_cls(
tokenizer=tokenizer,
num_examine=num_examine,
compute_score=final_compute_score,
reward_fn_key=config.data.reward_fn_key,
**reward_kwargs,
)
NaiveRewardManager
from collections import defaultdict
from typing import Any
import torch
from verl import DataProto
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager import register
from verl.workers.reward_manager.abstract import AbstractRewardManager
@register("naive")
class NaiveRewardManager(AbstractRewardManager):
"""The reward manager."""
def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None:
"""
Initialize the NaiveRewardManager instance.
Args:
tokenizer: The tokenizer used to decode token IDs into text.
num_examine: The number of batches of decoded responses to print to the console for debugging purpose.
compute_score: A function to compute the reward score. If None, `default_compute_score` will be used.
reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to
"data_source".
"""
self.tokenizer = tokenizer # Store the tokenizer for decoding token IDs
self.num_examine = num_examine # the number of batches of decoded responses to print to the console
self.compute_score = compute_score or default_compute_score
self.reward_fn_key = reward_fn_key # Store the key for accessing the data source
def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]:
"""We will expand this function gradually based on the available datasets"""
# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
if "rm_scores" in data.batch.keys():
if return_dict:
reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info}
else:
return data.batch["rm_scores"]
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
reward_extra_info = defaultdict(list)
already_print_data_sources = {}
for i in range(len(data)):
data_item = data[i] # DataProtoItem
prompt_ids = data_item.batch["prompts"]
prompt_length = prompt_ids.shape[-1]
valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
valid_prompt_ids = prompt_ids[-valid_prompt_length:]
response_ids = data_item.batch["responses"]
valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
valid_response_ids = response_ids[:valid_response_length]
# decode
prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"]
data_source = data_item.non_tensor_batch[self.reward_fn_key]
extra_info = data_item.non_tensor_batch.get("extra_info", {})
num_turns = data_item.non_tensor_batch.get("__num_turns__", None)
rollout_reward_scores = data_item.non_tensor_batch.get("reward_scores", {})
extra_info["num_turns"] = num_turns
extra_info["rollout_reward_scores"] = rollout_reward_scores
# call the reward fn to compute the score
score = self.compute_score(
data_source=data_source,
solution_str=response_str,
ground_truth=ground_truth,
extra_info=extra_info,
)
if isinstance(score, dict):
# a dict may carry multiple metrics
reward = score["score"]
# Store the information including original reward
for key, value in score.items():
reward_extra_info[key].append(value)
else:
# otherwise a plain scalar score
reward = score
reward_tensor[i, valid_response_length - 1] = reward
if data_source not in already_print_data_sources:
already_print_data_sources[data_source] = 0
if already_print_data_sources[data_source] < self.num_examine:
already_print_data_sources[data_source] += 1
print("[prompt]", prompt_str)
print("[response]", response_str)
print("[ground_truth]", ground_truth)
if isinstance(score, dict):
for key, value in score.items():
print(f"[{key}]", value)
else:
print("[score]", score)
if return_dict:
return {
"reward_tensor": reward_tensor,
"reward_extra_info": reward_extra_info,
}
else:
return reward_tensor
get_custom_reward_fn
A user-defined reward_fn.py (see the sketch below for what it can look like)
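A minimal, hypothetical example of such a file; the function name and signature follow the compute_score interface used by default_compute_score further down, and it would be wired in via custom_reward_function.path and custom_reward_function.name.
# my_reward_fn.py -- hypothetical custom reward file
def compute_score(data_source, solution_str, ground_truth, extra_info=None, **kwargs):
    """Return 1.0 if the final token of the response matches the ground truth, else 0.0."""
    answer = solution_str.strip().split()[-1] if solution_str.strip() else ""
    return float(answer == str(ground_truth).strip())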
def get_custom_reward_fn(config: DictConfig) -> Optional[RawRewardFn]:
"""Load and return a custom reward function from external file.
Dynamically imports a reward function from a specified file path and wraps
it with additional keyword arguments from the configuration.
Args:
config (dict): Configuration dictionary containing custom_reward_function
settings with 'path', 'name', and 'reward_kwargs' fields.
Returns:
callable or None: Wrapped reward function with merged kwargs, or None
if no custom reward function is configured.
Raises:
FileNotFoundError: If the specified reward function file doesn't exist.
RuntimeError: If there's an error loading the module from file.
AttributeError: If the specified function name isn't found in the module.
"""
reward_fn_config = config.get("custom_reward_function") or {}
file_path = reward_fn_config.get("path")
if not file_path:
return None
function_name = reward_fn_config.get("name")
assert function_name is not None
module = sys.modules.get("custom_module", None)
if module is None:
if not os.path.exists(file_path):
raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
spec = importlib.util.spec_from_file_location("custom_module", file_path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
try:
sys.modules["custom_module"] = module
assert spec.loader is not None
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e
if not hasattr(module, function_name):
raise AttributeError(f"Reward function '{function_name}' not found in '{module.__file__}'.")
print(f"using customized reward function '{function_name}' from '{module.__file__}'")
raw_fn = getattr(module, function_name)
reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {}))
if not inspect.iscoroutinefunction(raw_fn):
return partial(_call_with_kwargs, raw_fn, reward_kwargs)
else:
return partial(_call_with_kwargs_async, raw_fn, reward_kwargs)
default_compute_score
Built-in scoring functions (see the usage sketch below)
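A small usage sketch (hypothetical strings) of how the dispatcher below is called; the data_source string selects the dataset-specific scorer.
from verl.utils.reward_score import default_compute_score

# GSM8K-style exact-match scoring; the exact answer format is dataset specific.
score = default_compute_score(
    data_source="openai/gsm8k",
    solution_str="... the answer is #### 42",
    ground_truth="42",
)
print(score)  # 1.0 if the extracted answer matches, otherwise a lower score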
def default_compute_score(
data_source,
solution_str,
ground_truth,
extra_info=None,
sandbox_fusion_url=None,
concurrent_semaphore=None,
memory_limit_mb=None,
**kwargs,
):
"""Compute the score for a given solution based on the data source.
Args:
data_source (str): The source dataset identifier which determines the scoring method.
solution_str (str): The solution string to be evaluated.
ground_truth (str): The ground truth answer for comparison.
extra_info (dict, optional): Additional information that might be needed for scoring. Defaults to None.
Returns:
float: The computed score as a floating point number. If the result is a dictionary,
it returns the dictionary instead.
Raises:
NotImplementedError: If the reward function is not implemented for the given data source.
"""
if data_source == "openai/gsm8k":
from . import gsm8k
res = gsm8k.compute_score(solution_str, ground_truth)
elif data_source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval", "HuggingFaceH4/MATH-500"]:
from . import math_reward
res = math_reward.compute_score(solution_str, ground_truth)
# [Optional] Math-Verify Integration
# For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify).
# Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`.
# To use it, override the `compute_score` function with the following implementation:
# from . import math_verify
# res = math_verify.compute_score(solution_str, ground_truth)
elif data_source in ["math_dapo", "math", "math_dapo_reasoning"] or data_source.startswith("aime"):
from . import math_dapo
res = math_dapo.compute_score(solution_str, ground_truth)
elif data_source in [
"numina_aops_forum",
"numina_synthetic_math",
"numina_amc_aime",
"numina_synthetic_amc",
"numina_cn_k12",
"numina_olympiads",
]:
from . import prime_math
res = prime_math.compute_score(solution_str, ground_truth)
elif data_source in ["codecontests", "apps", "codeforces", "taco"]:
# Use the passed sandbox_fusion_url if available
if sandbox_fusion_url:
from . import sandbox_fusion
# Pass the URL directly, ground_truth likely contains test cases here
res = sandbox_fusion.compute_score(
sandbox_fusion_url, concurrent_semaphore, memory_limit_mb, solution_str, ground_truth, continuous=True
)
else:
# If no sandbox URL is provided, fall back to prime_code or raise error
from . import prime_code
# Assuming prime_code doesn't need the URL
res = prime_code.compute_score(solution_str, ground_truth, continuous=True)
elif data_source in ["hiyouga/geometry3k"]:
from . import geo3k
res = geo3k.compute_score(solution_str, ground_truth)
elif data_source in [
"searchR1_nq",
"searchR1_triviaqa",
"searchR1_popqa",
"searchR1_hotpotqa",
"searchR1_2wikimultihopqa",
"searchR1_musique",
"searchR1_bamboogle",
]:
from . import search_r1_like_qa_em
res = search_r1_like_qa_em.compute_score(solution_str, ground_truth)
else:
raise NotImplementedError(f"Reward function is not implemented for {data_source=}")
if isinstance(res, dict):
return res
elif isinstance(res, int | float | bool):
return float(res)
else:
return float(res[0])
Initializing the workers - main flow (RayPPOTrainer)
def init_workers(self):
"""Initialize distributed training workers using Ray backend.
Creates:
1. Ray resource pools from configuration
2. Worker groups for each role (actor, critic, etc.)
"""
self.resource_pool_manager.create_resource_pool()
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
# create actor and rollout
if self.hybrid_engine:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
actor_rollout_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[Role.ActorRollout],
config=self.config.actor_rollout_ref,
role=str(Role.ActorRollout),
)
self.resource_pool_to_cls[resource_pool][str(Role.ActorRollout)] = actor_rollout_cls
else:
raise NotImplementedError
# create critic
if self.use_critic:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cfg = omega_conf_to_dataclass(self.config.critic)
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls
# create reference policy if needed
if self.use_reference_policy:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(
self.role_worker_mapping[Role.RefPolicy],
config=self.config.actor_rollout_ref,
role=str(Role.RefPolicy),
)
self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls
# create a reward model if reward_fn is None
if self.use_rm:
# we create a RM here
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
# initialize WorkerGroup
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
# you should not use `create_colocated_worker_cls`.
# Instead, directly pass different resource pool to different worker groups.
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
all_wg = {}
wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
if OmegaConf.select(self.config.global_profiler, "steps") is not None:
wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
# Only require nsight worker options when tool is nsys
if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
assert (
OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
is not None
), "worker_nsight_options must be set when using nsys with profile_steps"
wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
)
wg_kwargs["device_name"] = self.device_name
for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = self.ray_worker_group_cls(
resource_pool=resource_pool,
ray_cls_with_init=worker_dict_cls,
**wg_kwargs,
)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)
if self.use_critic:
self.critic_wg = all_wg[str(Role.Critic)]
self.critic_wg.init_model()
if self.use_reference_policy and not self.ref_in_actor:
self.ref_policy_wg = all_wg[str(Role.RefPolicy)]
self.ref_policy_wg.init_model()
self.rm_wg = None
# initialization of rm_wg will be deprecated in the future
if self.use_rm:
self.rm_wg = all_wg[str(Role.RewardModel)]
self.rm_wg.init_model()
# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
self.actor_rollout_wg = all_wg[str(Role.ActorRollout)]
self.actor_rollout_wg.init_model()
# create async rollout manager and request scheduler
self.async_rollout_mode = False
if self.config.actor_rollout_ref.rollout.mode == "async":
from verl.experimental.agent_loop import AgentLoopManager
self.async_rollout_mode = True
self.async_rollout_manager = AgentLoopManager(
config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
)
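Conceptually, init_workers builds a two-level mapping (resource pool → role → worker class), creates one colocated worker group per pool, and then hands out per-role handles. The toy sketch below illustrates only that bookkeeping, without Ray; ToyPool and ToyWorkerGroup are illustrative stand-ins, not verl's API.

```python
# Toy illustration of the resource_pool -> {role: worker_cls} bookkeeping in init_workers.
# No Ray involved; ToyPool / ToyWorkerGroup are illustrative stand-ins.
class ToyPool:
    def __init__(self, name: str, n_gpus: int):
        self.name, self.n_gpus = name, n_gpus

class ToyWorkerGroup:
    """Pretends to colocate several roles on one resource pool."""
    def __init__(self, pool: ToyPool, class_dict: dict):
        self.pool, self.class_dict = pool, class_dict

    def spawn(self, prefix_set):
        # one handle per role, all sharing the same pool (colocation)
        return {role: f"{self.pool.name}/{role}" for role in prefix_set}

pools = {"global_pool": ToyPool("global_pool", n_gpus=8)}
resource_pool_to_cls = {pool: {} for pool in pools.values()}

# register roles on the pool, mirroring the actor/critic/ref/rm branches above
pool = pools["global_pool"]
resource_pool_to_cls[pool]["actor_rollout"] = object
resource_pool_to_cls[pool]["critic"] = object

# spawn one colocated worker group per pool and collect per-role handles
all_wg = {}
for pool, class_dict in resource_pool_to_cls.items():
    wg = ToyWorkerGroup(pool, class_dict)
    all_wg.update(wg.spawn(prefix_set=class_dict.keys()))

print(all_wg)  # {'actor_rollout': 'global_pool/actor_rollout', 'critic': 'global_pool/critic'}
```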
Initializing the Megatron workers
MegatronActorRolloutRefWorker constructor
class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
or a hybrid engine based on the config.rollout
"""
def __init__(self, config: DictConfig, role: str, **kwargs):
Worker.__init__(self)
self.config = config
if repatch is not None:
# NPU MindSpeed patch, will be refactored with MindSpeedEngine.
repatch(self.config.actor.megatron.get("override_transformer_config", {}))
# NOTE(sgm): We utilize colocate WorkerGroup by default.
# As a result, Workers for different model share the same process.
# Therefore, we only require one distribute initialization.
# To utilize different parallel strategy in different models:
# 1, users should disable WorkerDict; 2.assign different ResourcePool to different models,
# 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385
if not torch.distributed.is_initialized():
set_numa_affinity()
rank = int(os.environ["LOCAL_RANK"])
torch.distributed.init_process_group(
backend=get_nccl_backend(),
timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)),
init_method=os.environ.get("DIST_INIT_METHOD", None),
)
get_torch_device().set_device(rank)
mpu.initialize_model_parallel(
tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size,
pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=self.config.actor.megatron.virtual_pipeline_model_parallel_size,
use_sharp=False,
context_parallel_size=self.config.actor.megatron.context_parallel_size,
expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size,
expert_tensor_parallel_size=self.config.actor.megatron.expert_tensor_parallel_size,
nccl_communicator_config_path=None,
)
is_collect = (
mpu.get_tensor_model_parallel_rank() == 0
and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
and mpu.get_context_parallel_rank() == 0
)
self._register_dispatch_collect_info(
mesh_name="actor", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
)
set_random_seed(seed=self.config.actor.megatron.seed)
self.role = role
assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]
self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
self._is_ref = self.role in ["ref", "actor_rollout_ref"]
if self._is_actor:
omega_profiler_config = config.actor.get("profiler", {})
elif self._is_rollout:
# NOTE: In colocation mode, rollout config may not take effect (follow the actor config)
# This is for extendability in AsyncRL cases
omega_profiler_config = config.rollout.get("profiler", {})
elif self._is_ref:
omega_profiler_config = config.ref.get("profiler", {})
else:
raise ValueError(
f"Invalid role {self.role}, should be one of "
"['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']"
)
# omega_profiler_config is DictConfig
# profiler_config is a ProfilerConfig dataclass
profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
tool_config = omega_conf_to_dataclass(
omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
)
else:
tool_config = None
DistProfilerExtension.__init__(
self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)
)
# TODO(sgm): Currently, we only support reference model param offload
# will support other offload later
self._is_offload_param = False
self._is_offload_grad = False
self._is_offload_optimizer = False
# normalize config
if self._is_actor and self._is_rollout:
self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()
if self.config.actor.get("ppo_micro_batch_size", None):
self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()
self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size
self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size
self._is_offload_param = self.config.actor.megatron.get("param_offload", False)
self._is_offload_grad = self.config.actor.megatron.get("grad_offload", False)
self._is_offload_optimizer = self.config.actor.megatron.get("optimizer_offload", False)
elif self._is_ref:
if self.config.ref.get("log_prob_micro_batch_size", None):
self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size
else:
assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, (
"Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and "
"`log_prob_micro_batch_size` should not be None at the same time."
)
self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False)
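The "normalize config" block above rescales the driver-facing batch sizes into per-DP-rank values: the mini-batch is first multiplied by `rollout.n` (responses per prompt) and then divided by the data-parallel world size, and the micro-batch becomes a per-GPU value. A small worked example with assumed numbers:

```python
# Worked example of the batch-size normalization above (all numbers are assumed).
ppo_mini_batch_size = 256   # prompts per PPO mini-batch, as set in the config
rollout_n = 4               # responses sampled per prompt (config.rollout.n)
dp_world_size = 8           # mpu.get_data_parallel_world_size()
ppo_micro_batch_size = 64   # driver-facing micro-batch size

# mini-batch is counted in responses, then split across DP ranks
ppo_mini_batch_size = ppo_mini_batch_size * rollout_n // dp_world_size   # 256*4//8 = 128
# micro-batch is split across DP ranks and becomes the per-GPU value
ppo_micro_batch_size_per_gpu = ppo_micro_batch_size // dp_world_size     # 64//8 = 8

print(ppo_mini_batch_size, ppo_micro_batch_size_per_gpu)  # 128 8
```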
MegatronActorRolloutRefWorker model initialization (init_model)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
if self.config.model.get("external_lib", None) is not None:
# This is used to import external_lib into the huggingface systems
import importlib
importlib.import_module(self.config.model.external_lib)
from verl.utils.torch_dtypes import PrecisionType
override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
if self._is_actor:
override_transformer_config = OmegaConf.to_container(
OmegaConf.create(self.config.actor.megatron.get("override_transformer_config", {}))
)
override_ddp_config = OmegaConf.to_container(
OmegaConf.create(self.config.actor.megatron.get("override_ddp_config", {}))
)
elif self._is_ref:
override_transformer_config = OmegaConf.to_container(
OmegaConf.create(self.config.ref.megatron.get("override_transformer_config", {}))
)
else:
override_transformer_config = {}
self.param_dtype = PrecisionType.to_dtype(self.config.actor.megatron.dtype)
log_gpu_memory_usage("Before init actor model and optimizer", logger=logger)
self.dtype = PrecisionType.to_dtype(self.param_dtype)
if self._is_actor:
# we need the model for actor and rollout
optim_config = self.config.actor.optim if self._is_actor else None
(
self.actor_module,
self.actor_optimizer,
self.actor_optimizer_scheduler,
self.actor_model_config,
self.actor_optim_config,
) = self._build_model_optimizer(
model_path=self.config.model.path,
optim_config=optim_config,
override_model_config=override_model_config,
override_transformer_config=override_transformer_config,
override_ddp_config=override_ddp_config,
)
if self._is_offload_param:
offload_megatron_model_to_cpu(self.actor_module)
log_gpu_memory_usage("After offload actor params and grad during init", logger=logger)
if self._is_offload_optimizer:
offload_megatron_optimizer(self.actor_optimizer)
log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)
if self._is_actor:
actor_cfg = omega_conf_to_dataclass(self.config.actor)
self.actor = MegatronPPOActor(
config=actor_cfg,
model_config=self.actor_model_config,
hf_config=self.hf_config,
tf_config=self.tf_config,
actor_module=self.actor_module,
actor_optimizer=self.actor_optimizer,
)
log_gpu_memory_usage("After MegatronPPOActor init", logger=logger)
if self._is_rollout:
self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False))
log_gpu_memory_usage("After rollout init", logger=logger)
if self._is_ref:
self.ref_module, self.ref_model_config = self._build_model_optimizer(
model_path=self.config.model.path,
optim_config=None,
override_model_config=override_model_config,
override_transformer_config=override_transformer_config,
)
log_gpu_memory_usage("After ref model init", logger=logger)
self.ref_policy = MegatronPPOActor(
config=self.config.ref,
model_config=self.ref_model_config,
hf_config=self.hf_config,
tf_config=self.tf_config,
actor_module=self.ref_module,
actor_optimizer=None,
)
if self._ref_is_offload_param:
offload_megatron_model_to_cpu(self.ref_module)
log_gpu_memory_usage("After offload ref params during init", logger=logger)
if self._is_actor:
self.flops_counter = FlopsCounter(self.actor_model_config)
self.checkpoint_mananager = MegatronCheckpointManager(
config=self.config,
checkpoint_config=self.config.actor.checkpoint,
model_config=self.actor_model_config,
transformer_config=self.tf_config,
role="actor",
model=self.actor_module,
arch=self.architectures[0],
hf_config=self.hf_config,
param_dtype=self.param_dtype,
share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
processing_class=self.processor if self.processor is not None else self.tokenizer,
optimizer=self.actor_optimizer,
optimizer_scheduler=self.actor_optimizer_scheduler,
use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,
use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler,
bridge=self.bridge,
use_dist_checkpointing=self.config.actor.megatron.use_dist_checkpointing,
)
self.layer_name_mapping = {
"qkv_layer_name": "self_attention.linear_qkv.",
"gate_proj_layer_name": "linear_fc1.",
}
self.weight_converter = None
if not self.config.actor.megatron.use_mbridge:
self.weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
get_torch_device().empty_cache()
log_gpu_memory_usage("After init_model finish", logger=logger)MegatronPPOActor 初始化
MegatronPPOActor initialization
class MegatronPPOActor(BasePPOActor):
def __init__(
self,
config,
model_config,
hf_config,
tf_config,
actor_module: nn.ModuleList,
actor_optimizer: DistributedOptimizer,
):
"""MeagtronPPOActor class. This class implements the simple PPO logics when the model is built with Megatron.
Args:
config (OmegaConf): the basic config that contains the hyper-parameters of PPO Actor. It must contain
``ppo_micro_batch_size_per_gpu``: micro batch size when updating ppo.
``ppo_mini_batch_size``: minibatch size when updating ppo using the batch data.
``ppo_epochs``: number of epochs to update the actor using the batch data.
``shuffle``: whether to shuffle the data after each ppo epoch.
``clip_ratio``: clip ratio of the ppo algorithm. See https://arxiv.org/abs/1707.06347.
``entropy_coeff``: entropy coefficient of the PPO loss. See https://arxiv.org/abs/1707.06347.
model_config (OmegaConf): model configuration. It must contain ``model_config.vocab_size`` and
``model_config.hidden_size``
hf_config (PretrainedConfig): huggingface config
tf_config (TransformerConfig): mcore transformer config
actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this
pp stage.
each nn.Module in this rank holds a vpp module chunk. See https://arxiv.org/pdf/2104.04473.pdf for
more details.
The actor module has some constraints to follow in order to use the update logic implemented here
1. It must implement unpad_input before any computation and pad_input after all the computation.
Remove padding is an
optimization that removes the padding tokens. See unpad_input and pad_input function in flash-attn
(https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py).
2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size],
where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size
of the hidden state is [total_nnz // tp, 1, hidden_size].
actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron.
It implements
zero1 optimizer that shards the optimizer state across dp ranks.
>>> from megatron.training import get_model
>>> from megatron.optimizer import get_megatron_optimizer
>>> actor_module = get_model(megatron_actor_model_provider, wrap_with_ddp=True)
>>> actor_module = nn.ModuleList(actor_module)
>>> actor_optimizer = get_megatron_optimizer(actor_module)
>>> actor = MegatronPPOActor(config=config,
>>> model_config=actor_model_config,
>>> hf_config=hf_config,
>>> tf_config=tf_config,
>>> actor_module=actor_module,
>>> actor_optimizer=actor_optimizer)
"""
super().__init__(config)
self._validate_config(config)
self.model_config = model_config
self.hf_config = hf_config
self.tf_config = tf_config
self.actor_module = actor_module
self.actor_optimizer: DistributedOptimizer = actor_optimizer
self.use_torch_profiler = self.config.profiler.get("tool") == "torch"
if self.use_torch_profiler:
self.prof = Profiler(
self.config.profiler, tool_config=self.config.profiler.get("tool_config", {}).get("torch", {})
)
else:
self.prof = None
self.use_fused_kernels = self.config.get("use_fused_kernels", False)
if self.use_fused_kernels and not getattr(self.config, "overlap_moe_expert_parallel_comm", False):
# do not patch if overlap_moe_expert_parallel_comm is enabled
from verl.models.mcore.model_forward_fused import patch_fused_forward
for model in self.actor_module:
patch_fused_forward(model)
self.optimizer_step_args = OmegaConf.create(
{
"skip_grad": None,
"overlap_dp_param_comm": False,
"overlap_dp_grad_comm": False,
"gradient_accumulation_steps": 1,
"sequence_parallel": self.tf_config.sequence_parallel,
"DDP_impl": "local",
"layernorm_allreduce_bucket_threshold": 0,
"reduce_grads_use_alltoall": False,
}
)
config = get_model_config(self.actor_module[0])
print(config)
config.finalize_model_grads_func = finalize_model_grads
The BasePPOActor interface is shown below; its core methods are compute_log_prob and update_policy.
class BasePPOActor(ABC):
def __init__(self, config):
"""The base class for PPO actor
Args:
config (DictConfig): a config passed to the PPOActor. We expect the type to be
DictConfig (https://omegaconf.readthedocs.io/), but it can be any namedtuple in general.
"""
super().__init__()
self.config = config
@abstractmethod
def compute_log_prob(self, data: DataProto) -> torch.Tensor:
"""Compute logits given a batch of data.
Args:
data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```,
```attention_mask``` and ```position_ids```.
Returns:
DataProto: a DataProto containing the key ```log_probs```
"""
pass
@abstractmethod
def update_policy(self, data: DataProto) -> dict:
"""Update the policy with an iterator of DataProto
Args:
data (DataProto): an iterator over the DataProto returned by
```make_minibatch_iterator```
Returns:
Dict: a dictionary that can contain anything. Typically, it contains statistics collected while updating the model,
such as ```loss```, ```grad_norm```, etc.
"""
pass
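A toy subclass makes the contract concrete: an actor only needs to provide compute_log_prob and update_policy. The sketch below re-declares a minimal ABC of the same shape so it runs standalone; it is illustrative, not verl's implementation.

```python
# Toy illustration of the BasePPOActor contract (standalone re-declaration, not verl's code).
from abc import ABC, abstractmethod

class ToyBasePPOActor(ABC):
    def __init__(self, config: dict):
        self.config = config

    @abstractmethod
    def compute_log_prob(self, data: dict) -> list:
        ...

    @abstractmethod
    def update_policy(self, data: dict) -> dict:
        ...

class ToyActor(ToyBasePPOActor):
    def compute_log_prob(self, data: dict) -> list:
        # pretend every token has the same log-prob
        return [-0.5] * len(data["input_ids"])

    def update_policy(self, data: dict) -> dict:
        # a real actor would run PPO mini-batch updates here and return stats
        return {"loss": 0.12, "grad_norm": 1.3}

actor = ToyActor(config={"ppo_epochs": 1})
print(actor.compute_log_prob({"input_ids": [1, 2, 3]}))  # [-0.5, -0.5, -0.5]
print(actor.update_policy({"input_ids": [1, 2, 3]}))     # {'loss': 0.12, 'grad_norm': 1.3}
```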
Loading the Megatron TransformerConfig
Initializing hf_config and tf_config (MegatronWorker)
class MegatronWorker(Worker):
def _init_hf_config_and_tf_config(
self,
model_path,
tokenizer_or_path,
dtype,
override_model_config,
override_transformer_config,
trust_remote_code=False,
use_mbridge=False,
):
from transformers import AutoConfig
from verl.models.mcore import hf_to_mcore_config
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.model import update_model_config
# Step 1: initialize the tokenizer
self.local_path = copy_to_local(model_path)
if tokenizer_or_path is None:
self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code)
self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code)
elif isinstance(tokenizer_or_path, str):
self.tokenizer = hf_tokenizer(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code)
self.processor = hf_processor(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code)
else:
self.tokenizer = tokenizer_or_path
self.processor = tokenizer_or_path
if self.config.model.get("custom_chat_template", None) is not None:
if self.processor is not None:
self.processor.chat_template = self.config.model.custom_chat_template
else:
self.tokenizer.chat_template = self.config.model.custom_chat_template
# Step 2: get the hf
hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code)
# Step 3: override the hf config
override_config_kwargs = {
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_model_config.get("model_config", {}))
self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False)
update_model_config(hf_config, override_config_kwargs=override_config_kwargs)
self.architectures = getattr(hf_config, "architectures", None)
if self.rank == 0:
print(f"Model config after override: {hf_config}")
from verl.models.mcore.config_converter import mapping_string_to_attn_backend
# todo: remove this line after mcore adopt mbridge 0.15, now for compatibility
override_transformer_config = mapping_string_to_attn_backend(override_transformer_config)
if use_mbridge:
from verl.models.mcore.mbridge import AutoBridge
bridge = AutoBridge.from_config(hf_config)
bridge.set_extra_args(**override_transformer_config)
tf_config = bridge.config
self.bridge = bridge
else:
tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config)
self.bridge = None
print(f"TF config: {tf_config}")
self.hf_config = hf_config
self.tf_config = tf_config
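Steps 2 and 3 above load the HuggingFace config and patch a few fields (special-token ids plus any user overrides) before it is converted. A minimal sketch of that override step; LlamaConfig is used as a stand-in so the example runs without downloading a checkpoint, and plain setattr stands in for verl's update_model_config helper.

```python
# Minimal sketch of overriding a HuggingFace config before converting it to mcore.
# LlamaConfig is a stand-in so the example runs without downloading a model;
# setattr stands in for verl's update_model_config helper.
from transformers import LlamaConfig

hf_config = LlamaConfig()  # in verl this comes from AutoConfig.from_pretrained(local_path)

override_config_kwargs = {
    "bos_token_id": 1,
    "eos_token_id": 2,
    "pad_token_id": 0,
    # plus anything from override_model_config["model_config"]
}
for key, value in override_config_kwargs.items():
    setattr(hf_config, key, value)

share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False)
print(hf_config.eos_token_id, share_embeddings_and_output_weights)
```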
Converting to TransformerConfig
- Essentially this converts the HuggingFace config (hf_config) into Megatron's TransformerConfig.
Supported models
# Registry for model configuration converters
MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
SupportedModel.LLAMA: hf_to_mcore_config_dense,
SupportedModel.QWEN2: hf_to_mcore_config_dense,
SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe,
SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3,
SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral,
SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,
SupportedModel.LLAMA4: hf_to_mcore_config_llama4,
SupportedModel.QWEN3: hf_to_mcore_config_dense,
SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: hf_to_mcore_config_dense,
}
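The registry above is just a dict from model family to a converter callable, and conversion dispatches on the architecture read from hf_config. A toy sketch of that dispatch pattern; the enum members and converter here are invented for illustration.

```python
# Toy sketch of the architecture -> converter registry pattern (names are invented).
from enum import Enum
from typing import Callable

class ToySupportedModel(Enum):
    LLAMA = "LlamaForCausalLM"
    QWEN2 = "Qwen2ForCausalLM"

def convert_dense(hf_cfg: dict) -> dict:
    return {"num_layers": hf_cfg["num_hidden_layers"], "gated_linear_unit": True}

REGISTRY: dict = {
    ToySupportedModel.LLAMA: convert_dense,
    ToySupportedModel.QWEN2: convert_dense,
}

def hf_to_mcore(hf_cfg: dict) -> dict:
    arch = hf_cfg["architectures"][0]
    model = ToySupportedModel(arch)  # raises ValueError for unsupported architectures
    return REGISTRY[model](hf_cfg)

print(hf_to_mcore({"architectures": ["Qwen2ForCausalLM"], "num_hidden_layers": 24}))
# {'num_layers': 24, 'gated_linear_unit': True}
```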
Dense example
def hf_to_mcore_config_dense(
hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
) -> TransformerConfig:
# for LlamaForCausalLM or Qwen2ForCausalLM
qkv_bias = True if "Qwen2" in hf_config.architectures[0] else getattr(hf_config, "attention_bias", False)
qk_layernorm = True if "Qwen3" in hf_config.architectures[0] else False
args: dict = _get_base_transformer_config(
hf_config=hf_config,
dtype=dtype,
use_cpu_initialization=False,
add_bias_linear=False,
add_qkv_bias=qkv_bias,
qk_layernorm=qk_layernorm,
)
# override_transformer_config_kwargs as kwargs shall never be none
args.update(override_transformer_config_kwargs)
return check_and_construct_configs(args, TransformerConfig)
_get_base_transformer_config
def _get_base_transformer_config(
hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
) -> dict:
"""
Create a base TransformerConfig with common parameters across different model architectures.
TODO: (ycl) use dataclass or converter config?
Args:
hf_config: HuggingFace model configuration
dtype: Data type for the model
override_transformer_config_kwargs: Additional parameters to override defaults
Returns:
TransformerConfig with common parameters
"""
# Common parallel state parameters
overlap_p2p_comm = (
mpu.get_virtual_pipeline_model_parallel_world_size() is not None
and mpu.get_virtual_pipeline_model_parallel_world_size() > 1
)
batch_p2p_comm = False
# Base configuration with common parameters
base_config = {
# Model architecture parameters
"num_layers": hf_config.num_hidden_layers,
"hidden_size": hf_config.hidden_size,
"num_attention_heads": hf_config.num_attention_heads,
"num_query_groups": hf_config.num_key_value_heads,
"ffn_hidden_size": hf_config.intermediate_size,
"attention_dropout": hf_config.attention_dropout,
"hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0),
"kv_channels": getattr(hf_config, "head_dim", None),
"layernorm_epsilon": hf_config.rms_norm_eps,
"add_bias_linear": True,
# Activation and normalization
"activation_func": F.silu,
"normalization": "RMSNorm",
"gated_linear_unit": True,
# Data types
"pipeline_dtype": dtype,
"params_dtype": dtype,
"bf16": dtype is torch.bfloat16,
# Parallel configuration
"tensor_model_parallel_size": mpu.get_tensor_model_parallel_world_size(),
"pipeline_model_parallel_size": mpu.get_pipeline_model_parallel_world_size(),
"expert_model_parallel_size": mpu.get_expert_model_parallel_world_size(),
"expert_tensor_parallel_size": mpu.get_expert_tensor_parallel_world_size(),
"virtual_pipeline_model_parallel_size": mpu.get_virtual_pipeline_model_parallel_world_size(),
"context_parallel_size": mpu.get_context_parallel_world_size(),
"overlap_p2p_comm": overlap_p2p_comm,
"batch_p2p_comm": batch_p2p_comm,
"sequence_parallel": mpu.get_tensor_model_parallel_world_size() > 1,
# Common settings
"variable_seq_lengths": True,
"masked_softmax_fusion": True,
"moe_token_dispatcher_type": "alltoall",
}
# Update with any provided overrides
# override_transformer_config_kwargs as kwargs shall never be none
base_config.update(override_transformer_config_kwargs)
return base_config
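The base dict above is mostly derived from hf_config and the current parallel state, and the final `update` call is where `override_transformer_config` wins over those defaults. A tiny standalone sketch of that precedence, with assumed values:

```python
# Standalone sketch of how override_transformer_config takes precedence over derived defaults.
base_config = {
    "num_layers": 32,               # from hf_config.num_hidden_layers
    "sequence_parallel": True,      # derived from the parallel state (TP > 1)
    "masked_softmax_fusion": True,  # common default
}
override_transformer_config_kwargs = {
    "sequence_parallel": False,     # user override from the verl config
    "recompute_granularity": "selective",
}
base_config.update(override_transformer_config_kwargs)
print(base_config)
# {'num_layers': 32, 'sequence_parallel': False, 'masked_softmax_fusion': True,
#  'recompute_granularity': 'selective'}
```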
TransformerConfig source code
This is fairly long. I originally planned to trim it, but some of the parameters may be useful later, so the full definition is kept here for easy reference.
| Parallelism / feature | Core mutual exclusion / restriction | Brief reason |
|---|---|---|
| TP (tensor parallel) | num_attention_heads must be divisible by TP | A head cannot be physically split across GPUs |
| PP (pipeline parallel) | Does not support CPU offloading | The complex scheduling logic is hard to reconcile with CPU swap-in/swap-out |
| PP (pipeline parallel) | pipeline_model_parallel_layout and num_layers_in_first/last_pipeline_stage are mutually exclusive | Two conflicting ways of defining the partition |
| VPP (virtual pipeline) | Layer count must be divisible by VPP | Each virtual chunk must contain whole layers |
| MoE (experts) | num_moe_experts must be set when EP > 1 | Logical completeness check |
| CPU offload | Incompatible with recompute, PP, and CUDA graphs | Conflicting memory-management mechanisms |
| CUDA graphs | Incompatible with recompute | Static graph capture cannot handle the dynamic nature of recomputation |
Parallelism parameters
Key parameters
- tensor_model_parallel_size, num_attention_heads
- num_query_groups (number of GQA groups), sequence_parallel
Constraints (a worked example follows the check below)
- num_attention_heads % tensor_model_parallel_size == 0
- num_query_groups % tensor_model_parallel_size == 0
- If sequence_parallel is enabled, distribute_saved_activations must be set to False
if self.num_attention_heads % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
if self.num_query_groups is None:
self.num_query_groups = self.num_attention_heads
if self.num_query_groups % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_query_groups ({self.num_query_groups}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
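A quick numeric check of the two divisibility rules above; the model shape is assumed (64 attention heads, 8 GQA groups, roughly a Qwen2-72B-like configuration).

```python
# Worked example of the TP divisibility checks (assumed model shape).
num_attention_heads = 64
num_query_groups = 8  # GQA groups

for tp in (2, 4, 8, 16):
    ok = (num_attention_heads % tp == 0) and (num_query_groups % tp == 0)
    print(f"tensor_model_parallel_size={tp:>2} -> {'ok' if ok else 'invalid'}")
# tp=16 is invalid here because 8 query groups cannot be split across 16 TP ranks
```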
Related parameters
- pipeline_model_parallel_size: classic pipeline-parallel (PP) degree
- virtual_pipeline_model_parallel_size: virtual PP degree, used for interleaved scheduling to reduce pipeline bubbles
Constraints
- num_layers % pipeline_model_parallel_size == 0
- The layer counts of the first/last stages and of the middle stages must be divisible by the VPP size (a worked example follows the snippet below).
- cpu_offloading does not support pipeline parallelism; pipeline_model_parallel_size must be 1.
elif (self.num_layers_in_first_pipeline_stage is not None
or self.num_layers_in_last_pipeline_stage is not None
):
pipeline_parallel_size = self.pipeline_model_parallel_size
num_layers = self.num_layers
if self.num_layers_in_first_pipeline_stage is not None:
if self.num_layers_in_first_pipeline_stage <= 0:
raise ValueError("num_layers_in_first_pipeline_stage must be larger than 0")
if self.virtual_pipeline_model_parallel_size is not None:
if (
self.num_layers_in_first_pipeline_stage
% self.virtual_pipeline_model_parallel_size
!= 0
):
raise ValueError(
f"number of layers at first stage: "
f"{self.num_layers_in_first_pipeline_stage}"
f"must be divisible by virtual pipeline"
f"parallel degree {self.virtual_pipeline_model_parallel_size}"
)
num_layers -= self.num_layers_in_first_pipeline_stage
pipeline_parallel_size -= 1
if self.num_layers_in_last_pipeline_stage is not None:
if self.num_layers_in_last_pipeline_stage <= 0:
raise ValueError("num_layers_in_last_pipeline_stage must be larger than 0")
if self.virtual_pipeline_model_parallel_size is not None:
if (
self.num_layers_in_last_pipeline_stage
% self.virtual_pipeline_model_parallel_size
!= 0
):
raise ValueError(
f"number of layers at last stage: "
f"{self.num_layers_in_last_pipeline_stage}"
f"must be divisible by virtual pipeline"
f"parallel degree {self.virtual_pipeline_model_parallel_size}"
)
num_layers -= self.num_layers_in_last_pipeline_stage
pipeline_parallel_size -= 1
# Here pipeline_parallel_size is the number of middle PP stages. If there are middle
# PP stages, check number of layers at middle stage is divisible by middle PP size.
if pipeline_parallel_size and not num_layers % pipeline_parallel_size == 0:
raise ValueError(
f"number of layers at middle stage: {num_layers} must be divisible by"
f"the middle pipeline model parallel size {pipeline_parallel_size}"
)
# If there are middle PP stages, check number of layers
# on each middle PP rank is divisible by VPP size.
if pipeline_parallel_size and self.virtual_pipeline_model_parallel_size is not None:
num_layers_per_middle_pipeline_rank = num_layers // pipeline_parallel_size
if (
not num_layers_per_middle_pipeline_rank
% self.virtual_pipeline_model_parallel_size
== 0
):
raise ValueError(
f"number of layers on each middle pipeline rank:"
f"{num_layers_per_middle_pipeline_rank} must be divisible by virtual"
f"pipeline parallel degree {self.virtual_pipeline_model_parallel_size}"
)
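A worked example of the first/last/middle pipeline-stage checks above; all numbers are illustrative, chosen so that every divisibility condition holds.

```python
# Worked example of the first/last/middle pipeline-stage checks (illustrative numbers).
num_layers = 30
pipeline_model_parallel_size = 4
virtual_pipeline_model_parallel_size = 3
num_layers_in_first_pipeline_stage = 3
num_layers_in_last_pipeline_stage = 3

middle_layers = num_layers - num_layers_in_first_pipeline_stage - num_layers_in_last_pipeline_stage
middle_pp_size = pipeline_model_parallel_size - 2

assert num_layers_in_first_pipeline_stage % virtual_pipeline_model_parallel_size == 0
assert num_layers_in_last_pipeline_stage % virtual_pipeline_model_parallel_size == 0
assert middle_layers % middle_pp_size == 0
layers_per_middle_rank = middle_layers // middle_pp_size
assert layers_per_middle_rank % virtual_pipeline_model_parallel_size == 0

print(f"middle ranks hold {layers_per_middle_rank} layers each "
      f"({layers_per_middle_rank // virtual_pipeline_model_parallel_size} per vpp chunk)")
# middle ranks hold 12 layers each (4 per vpp chunk)
```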
Related parameters
- expert_model_parallel_size: expert-parallel (EP) degree
- num_moe_experts: total number of experts
- moe_router_num_groups, moe_router_group_topk: group-limited routing configuration
- moe_token_dispatcher_type: token dispatch method, one of allgather, alltoall, flex
Constraints
- When expert_model_parallel_size > 1, num_moe_experts must be set.
- num_moe_experts % moe_router_num_groups == 0
Related parameters
- context_parallel_size: context-parallel (CP) degree
- cp_comm_type: communication type; possible values are p2p, all_gather, a2a
Constraints
- If cp_comm_type is a list (a separate communication type per layer), its length must equal num_layers.
Key memory-related parameters
Recomputation
- To save GPU memory, parts of the model are recomputed during the backward pass, trading compute time for memory.
- selective: recompute only the parts that use a lot of memory but are cheap to recompute.
- This is typically core_attn: attention activations take a lot of memory, yet recomputing them is relatively cheap.
- Other options include mlp, moe, layernorm, etc.
Configuration example (see the sketch after this list)
- recompute_granularity=selective
- recompute_modules='["core_attn", "mlp"]'
- actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity="${recompute_granularity}"
- actor_rollout_ref.actor.megatron.override_transformer_config.recompute_modules="${recompute_modules}"
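Those CLI-style overrides end up as entries of `override_transformer_config`, which is later merged into the TransformerConfig kwargs (see `_get_base_transformer_config` above). A small sketch of how the same two settings could be expressed with OmegaConf's dotlist parsing; the launch-script wiring is omitted and this is only an illustration, not verl's exact code path.

```python
# Sketch: the two recompute overrides expressed as an OmegaConf dotlist (illustrative only).
from omegaconf import OmegaConf

overrides = OmegaConf.from_dotlist(
    [
        "actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=selective",
        'actor_rollout_ref.actor.megatron.override_transformer_config.recompute_modules=["core_attn","mlp"]',
    ]
)
otc = overrides.actor_rollout_ref.actor.megatron.override_transformer_config
print(OmegaConf.to_container(otc))
# {'recompute_granularity': 'selective', 'recompute_modules': ['core_attn', 'mlp']}
```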
TransformerConfig post_init source code
@dataclass
class TransformerConfig(ModelParallelConfig):
"""Configuration object for megatron-core transformers.
The initialization function has an argument for each parameter,
including those in ModelParallelConfig.
"""
####################
# model architecture
####################
num_layers: int = 0
"""Number of transformer layers in a transformer block."""
mtp_num_layers: Optional[int] = None
"""Number of Multi-Token Prediction (MTP) Layers."""
mtp_loss_scaling_factor: Optional[float] = None
"""Weighting factor of Multi-Token Prediction (MTP) loss."""
num_layers_in_first_pipeline_stage: Optional[int] = None
"""Number of transformer layers on first pipeline stage.
None implies equal layer division across PP ranks."""
num_layers_in_last_pipeline_stage: Optional[int] = None
"""Number of transformer layers on last pipeline stage.
None implies equal layer division across PP ranks."""
pipeline_model_parallel_layout: Optional[Union[str, list, PipelineParallelLayerLayout]] = None
"""Custom definition of the pipeline parallel partitioning.
Support type:
- str: e.g., 'Et*3|(tt|)*29,m|L'. Stages are split by '|', replicated stages or layers
can be described with multiplication. Commas can be used cosmetically.
- list: e.g., [['embedding', 'decoder'], ['decoder', 'decoder', 'decoder', 'loss']].
- PipelineParallelLayerLayout: a PipelineParallelLayerLayout object.
If given either a string or a list, it will be transferred into a PipelineParallelLayerLayout
in post init. Let i = a * pp_size + b, then layout[i] gives a list of the layers
in the a-th vpp stage and the b-th pp stage, i.e., vpp(0)pp(0), vpp(0)pp(1), ...,
vpp(i)pp(j), vpp(i)pp(j+1), ..., vpp(-1)pp(-2), vpp(-1)pp(-1).
In the inner lists of layers, 'embedding' or 'E' denotes the embedding layer, 'loss' or 'L'
denotes the loss function, and 'decoder' or 't' denotes the transformer decoder layer.
Examples:
[['embedding', 'decoder'], ['decoder', 'decoder', 'decoder', 'loss']]:
pp = 2, vpp = None
pp rank 0 holds: embedding, decoder
pp rank 1 holds: decoder*3, loss
'E|(tt|)*2,(t|)*4,mL':
pp = 2, vpp = 4
vpp rank 0 pp rank 0 holds: embedding
vpp rank 0 pp rank 1~2 holds: decoder*2
vpp rank 0 pp rank 3 holds: decoder
vpp rank 1 pp rank 0~2 holds: decoder
vpp rank 1 pp rank 3 holds: mtp, loss"""
account_for_embedding_in_pipeline_split: bool = False
"""If set, the embedding layer will be treated as a standard transformer
layer in the context of partition and placement for pipeline parallelism."""
account_for_loss_in_pipeline_split: bool = False
"""If set, the loss layer will be treated as a standard transformer
layer in the context of partition and placement for pipeline parallelism."""
hidden_size: int = 0
"""Transformer hidden size."""
num_attention_heads: int = 0
"""Number of transformer attention heads."""
attention_backend: AttnBackend = AttnBackend.auto
"""Attention backend to run. By default we let transformer engine
decide the best backend to run (except in the case of local).
If attention backend is local we use the local pytorch implementation in mcore.
Users can specify exact backend by changing this config. """
softmax_scale: Optional[float] = None
"""Softmax scale for attention scaling."""
num_query_groups: Optional[int] = None
"""Number of query groups for group query attention. If None, normal attention is used."""
ffn_hidden_size: Optional[int] = None
"""Transformer Feed-Forward Network hidden size. This is set to 4*hidden_size
if not provided."""
kv_channels: Optional[int] = None
"""Projection weights dimension in multi-head attention. This is set to hidden_size //
num_attention_heads if not provided."""
hidden_dropout: float = 0.1
"""Dropout probability for transformer hidden state."""
attention_dropout: float = 0.1
"""Post attention dropout probability."""
fp32_residual_connection: bool = False
"""If true, move residual connections to fp32."""
# @jcasper should we keep this option?
apply_residual_connection_post_layernorm: bool = False
"""If True, uses the original BERT residule connection ordering."""
layernorm_epsilon: float = 1e-5
"""Epsilon value for any LayerNorm operations."""
layernorm_zero_centered_gamma: bool = False
"""If set to True, the LayerNorm is adjusted to center the gamma values around 0. This improves
numerical stability."""
add_bias_linear: bool = True
"""Include a bias term in all linear layers (QKV projections, after core attention, and two in
MLP layer)."""
add_qkv_bias: bool = False
"""Add a bias term only for QKV projections."""
gated_linear_unit: bool = False
"""Use a gated linear unit for the first linear layer in the MLP."""
activation_func: Callable = F.gelu
"""Activation function to use for the non-linearity in the MLP."""
activation_func_fp8_input_store: bool = False
"""Store the input of MLP activation function in FP8 for backprop to save memory.
The stored input is casted back to the original precision before backprop compuatation."""
num_moe_experts: Optional[int] = None
"""Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None
for no MoE."""
rotary_interleaved: bool = False
"""True is rotate pairs of even and odd dimensions (RoFormer style), False is rotate pairs of
first half and second half (LLaMa style). Default to False."""
window_size: Optional[Tuple[int, int]] = None
"""If not None, then will use sliding window attention. The size of the window is specified by
the numbers inside the tuple; -1 is special value meaning "infinite window size"."""
normalization: str = "LayerNorm"
"""Which norm to use for normalization layers, valid options are `LayerNorm` and `RMSNorm`."""
qk_layernorm: bool = False
"""Whether to apply `normalization` type of normalization to the query and key embeddings."""
test_mode: bool = False
"""Whether to run real-time tests."""
calculate_per_token_loss: bool = False
"""Whether cross entropy loss is calculated over the actual number of non-padded tokens in the
global batch, versus the default behavior of assuming all tokens are non-padded."""
multi_latent_attention: bool = False
"""Whether to use multi-latent attention."""
no_rope_freq: Optional[Union[int, List[int]]] = None
"""Controls which layers perform Rotary Position Embedding (RoPE). Accepts either:
An integer N: Creates a pattern where RoPE is skipped every N-1 layers. For example,
no_rope=4 means RoPE is applied for 3 layers, then skipped for 1 layer, repeating this pattern.
A list of integers: Defines a custom pattern where 1 means skip RoPE and 0 means apply RoPE.
For example, [0,1,1,0] means: apply RoPE, skip RoPE, skip RoPE, apply RoPE."""
moe_deepep_num_sms: int = 20
"""Number of SMs to use for DeepEP."""
####################
# initialization
####################
init_method: Optional[Callable] = None
"""Method to initialize weights. Note that bias is always set to zero. Should be a function that
takes a single Tensor and initializes it. If None, will be set to
megatron.core.utils.init_method_normal(init_method_std) which is torch nn init normal with
mean=0.0 and std=init_method_std."""
output_layer_init_method: Optional[Callable] = None
"""Method to initialize weights of the output layer of both attention and MLP blocks. If None,
will be set to megatron.core.utils.scaled_init_method_normal(init_method_std) which is torch nn
init normal with mean=0.0 and std=init_method_std / math.sqrt(2.0 * num_layers)."""
init_method_std: float = 0.02
"""Standard deviation of the zero mean normal for the default initialization method, not used if
init_method and output_layer_init_method are provided."""
init_model_with_meta_device: bool = False
"""
If True, initializes the model with the meta device. This is helpful for
training of very large models. This feature is only works when custom fsdp is turned on.
"""
####################
# mixed-precision
####################
apply_query_key_layer_scaling: bool = False
"""If true, scale Q * K^T by 1 / layer-number. This improve numeric stability when training with
fp16."""
attention_softmax_in_fp32: bool = True
"""If True, run attention masking and softmax in fp32. This should be True if
apply_query_key_layer_scaling is True."""
disable_bf16_reduced_precision_matmul: bool = False
"""If True, sets torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction=False to
prevent matmul from using reduced precision accumulation when using BF16."""
####################
# fusion
####################
bias_activation_fusion: bool = False
"""If True, fuses bias addition and the activation function when possible."""
masked_softmax_fusion: bool = False
"""If True, uses softmax fusion."""
persist_layer_norm: bool = False
"""If True, uses the persistent fused layer norm kernel. This kernel only supports a fixed set
of hidden sizes."""
memory_efficient_layer_norm: bool = False
"""If True, and using local layers (not from TransformerEngine), tells Apex to use the memory
efficient fused LayerNorm kernel. Ignored if not using LayerNorm."""
bias_dropout_fusion: bool = False # TODO: this should be bias_dropout_add_fusion?
"""If True, uses bias dropout fusion."""
apply_rope_fusion: bool = False
"""If True, use fused RoPE kernel."""
####################
# activation recomputation (reduces activation memory usage)
####################
recompute_granularity: Optional[str] = None
"""Determines which type of activation recompute to use. Megatron-core supports 'selective'
activation checkpointing where the submodules set in --recompute-modules is checkpointed.
The default is "core_attn" which is the memory intensive part of attention.
These memory intensive activations are also less compute intensive which makes activation
checkpointing more efficient for LLMs (20B+). See Reducing Activation Recomputation in Large
Transformer Models (https://arxiv.org/abs/2205.05198) for more details. 'full' will checkpoint
the entire transformer layer. If None, no recompute is performed and all activations are saved.
If set, must be 'selective' or 'full'. 'selective' always uses all layers.
"""
recompute_method: Optional[str] = None
"""Determines which transformer layers will be recomputed. uniform will uniformly divide the
total number of transformer layers in a transformer block and recompute the input activation of
each divided chunk at the specified granularity. block will recompute the input activations for
only a set number of transformer layers per pipeline stage. The rest of the layers in the
pipeline stage will not have any activations recomputed. If None, and recompute is enabled, all
layers will do recomputation. If set, must be 'uniform' or 'block'."""
recompute_num_layers: Optional[int] = None
"""When recompute_method is uniform, recompute_num_layers is the number of transformer layers in
each uniformly divided recompute unit. When recompute_method is block, recompute_num_layers is
the number of transformer layers to recompute within each pipeline stage. Must be None for
'selective' activation checkpointing."""
distribute_saved_activations: Optional[bool] = None
"""If True, distribute recomputed activations across the model parallel group."""
recompute_modules: Optional[List[str]] = None
"""The submodules to recompute.
choices: "core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe".
default: ["core_attn"].
"core_attn": recompute the core attention part of the transformer layer.
"moe_act": recompute the MoE MLP activation function.
"layernorm": recompute the input_layernorm and pre_mlp_layernorm.
"mla_up_proj": recompute the MLA up projection and RoPE applying parts.
"mlp": recompute the dense MLP submodule.
"moe": recompute the MoE layer.
"moe_act", "layernorm", and "mla_up_proj" use output-discarding checkpointing,
"core_attn", "mlp", and "moe" uses normal checkpointing.
"""
####################
# fp8 related
####################
fp8: Optional[str] = None
"""If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined
choices (1) 'e4m3' uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8
activation and weight tensors and e5m2 for all FP8 output activation gradient tensors."""
fp8_recipe: Optional[str] = "delayed"
"""If set, enables the use of FP8 precision through Transformer Engine. There are 3 predefined
choices (1) 'tensorwise' uses per tensor current scaling recipe, (2) 'delayed'
uses delayed scaling recipe, 3) 'mxfp8' for Blackwell architecture only,
4) 'blockwise' for blockwise scaling recipe."""
fp8_param: bool = False
"""If set, keep the parameters in fp8 precision to save memory. This option must be used
together with fp8 mode (i.e., TransformerConfig.fp8 is not None). Note that not all parameters
will be converted to fp8; for example, biases will remain unchanged. The parameters affected are
primarily the weights of GEMMs. The specific parameters that will be converted to fp8 are
determined by TE."""
fp8_margin: int = 0
"""Margin for the scaling factor computation."""
fp8_interval: int = 1
"""DEPRECATED from TransformerEngine v1.8.0. This flag is ignored.
Controls how often the scaling factor is recomputed.
"""
fp8_amax_history_len: int = 1
"""The length of the amax history window used for scaling factor computation."""
fp8_amax_compute_algo: str = "most_recent"
"""Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2
predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent`
always chooses the most recently seen value.
"""
fp8_wgrad: bool = True
"""When set to False, override FP8 config options and do the wgrad computation
in higher precision."""
fp8_dot_product_attention: bool = False
"""When set to True, use the FP8 implementation of Dot Product Attention."""
fp8_multi_head_attention: bool = False
"""When set to True, use the FP8 implementation of Multi Head Attention."""
tp_only_amax_red: bool = False
"""When set to True, reduce the FP8 AMAX only in the TP or TP-CP domain"""
first_last_layers_bf16: bool = False
"""If True, retains first and last N TransformerBlocks in BF16 as opposed to FP8."""
num_layers_at_start_in_bf16: int = 1
"""Number of layers at the start of the model to keep in BF16 precision when
first_last_layers_bf16 is True."""
num_layers_at_end_in_bf16: int = 1
"""Number of layers at the end of the model to keep in BF16 precision when
first_last_layers_bf16 is True."""
use_kitchen: bool = False
"""Use the kitchen extension for transformer quantization."""
####################
# MoE related
####################
moe_shared_expert_intermediate_size: Optional[int] = None
"""Shared expert total ffn hidden size.
It should be equal to 'num_shared_experts * ffn_size_of_each_shared_expert' if
there are multiple shared experts.
None means no shared expert."""
moe_shared_expert_overlap: bool = False
"""Enable overlapping between shared expert computations and dispatcher communications.
Without this, the shared epxerts execute after the routed experts."""
moe_layer_freq: Union[int, List[int]] = 1
"""Frequency between MoE layers and Dense layers. Accepts either:
- An integer N: Represents a 1:N ratio, meaning one expert layer for every N-1 dense layers.
- A list that defines a custom pattern, e.g.: [1,1,1,0,1,1,1,0,1,1,1,0]"""
moe_ffn_hidden_size: Optional[int] = None
"""MoE Feed-Forward Network hidden size"""
moe_router_load_balancing_type: str = "aux_loss"
"""The load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss
used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the load balancing loss used
in DeepSeekV2 and DeepSeekV3, which computes the loss for each individual sample; "sinkhorn"
corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing.
The default is "aux_loss"."""
moe_router_topk: int = 2
"""Number of experts to route to for each token."""
moe_router_topk_limited_devices: Optional[int] = None
"""Number of EP ranks to consider for each token in group-limited routing,
DEPRECATED and replaced by moe_router_num_groups and moe_router_group_topk.
"""
moe_router_padding_for_fp8: Optional[bool] = False
"""Whether to pad the routing_map to make sure the number of tokens each expert received
is a multiple of 16/32 for FP8 precision. This can remove the explicit padding in the
GroupedMLP layer."""
moe_router_num_groups: Optional[int] = None
"""Number of groups to divide experts into for group-limited routing.
When using group-limited routing:
1. Experts are divided into 'moe_router_num_groups' equal-sized groups
2. For each token, 'moe_router_group_topk' groups are selected based on sum of
top-('moe_router_topk'/'moe_router_group_topk') routing scores within each group
3. From these selected groups, 'moe_router_topk' individual experts are chosen
Two common use cases:
- Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP)
to limit each token to experts on a subset of devices
(See DeepSeek-V2: https://arxiv.org/pdf/2405.04434)
- Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group
to limit each token to experts on a subset of nodes
(See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)
"""
moe_router_group_topk: Optional[int] = None
"""Number of selected groups for group-limited routing."""
moe_router_pre_softmax: bool = False
"""Enable pre-softmax(pre-sigmoid) routing for MoE, which means softmax is before the
top-k selection.
By default, softmax is done after top-k."""
moe_router_topk_scaling_factor: Optional[float] = None
"""Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax
enabled. Defaults to None, which means no scaling."""
moe_router_score_function: str = "softmax"
"""Score function for MoE routing. Can be "softmax" or "sigmoid"."""
moe_router_dtype: Optional[str] = None
"""Data type for routing and expert output weighted averaging. Using fp32 or fp64 can
improve stability especially when the number of experts is large (e.g. finegrained-moe).
None means no changes for dtype."""
moe_router_enable_expert_bias: bool = False
"""TopK routing with dynamic per-expert bias in the aux-loss-free load balancing strategy.
The routing decision is based on the sum of the routing scores and the expert bias.
See https://arxiv.org/abs/2408.15664 for details."""
moe_router_bias_update_rate: float = 1e-3
"""The expert bias is updated based on the number of assigned tokens to each expert
in a global batch, where the bias is increased for the experts with less assigned tokens
and decreased for the experts with more assigned tokens.
The default value 1e-3 is same as that used in DeepSeekV3."""
moe_router_force_load_balancing: bool = False
"""[Experimental] Force load balancing with random logits for MoE router, supports naive topk
and group-limited topk. This is an experimental feature and only for benchmark."""
moe_grouped_gemm: bool = False
"""When there are multiple experts per rank, compress multiple local (potentially small) gemms
in a single kernel launch to improve the utilization and performance by leveraging the Grouped
GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).
"""
moe_use_legacy_grouped_gemm: bool = False
"""Use legacy GroupedMLP rather than TEGroupedMLP.
Note: The legacy one will be deprecated soon."""
moe_aux_loss_coeff: float = 0 # 1e-2 would be a good start value for load balance loss.
"""Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended."""
moe_z_loss_coeff: Optional[float] = None # 1e-3 would be a good start value for z-loss
"""Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended."""
moe_input_jitter_eps: Optional[float] = None
"""Add noise to the input tensor by applying jitter with a specified epsilon value."""
moe_token_dropping: bool = False
"""This feature involves selectively dropping and padding tokens for each expert to achieve a
specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note that this is
currently unsupported so should remain False."""
moe_token_dispatcher_type: str = "allgather"
"""The type of token dispatcher to use. The default is 'allgather'.
Options are 'allgather','alltoall' and 'flex'."""
moe_enable_deepep: bool = False
"""[Experimental] Enable DeepEP for efficient token dispatching and combine in MoE models."""
moe_per_layer_logging: bool = False
"""Enable per-layer logging for MoE, currently supports auxiliary loss and z loss."""
moe_expert_capacity_factor: Optional[float] = None
"""moe_expert_capacity_factor (float): The capacity factor for each expert, None means no token
will be dropped. The default is None."""
moe_pad_expert_input_to_capacity: bool = False
"""moe_pad_expert_input_to_capacity (bool): If True, pads the input for each expert to match
the expert capacity length, effective only after the moe_expert_capacity_factor is set. The
default setting is False."""
moe_token_drop_policy: str = "probs"
"""The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with
the lowest probabilities will be dropped. If "position", tokens at the end of each batch will
be dropped.
"""
moe_layer_recompute: bool = False
"""Memory optimization: checkpointing moe_layer to save actiavtion memory."""
moe_permute_fusion: bool = False
"""Fuse token rearrangement ops during token dispatching."""
moe_apply_probs_on_input: bool = False
"""Apply probs on input of experts instead of applying after activation and glu."""
##################
# Context Parallel
##################
cp_comm_type: Optional[Union[str, List[str]]] = None
"""Inter-gpu communication type for context parallelism.
str: all layers share same communication type.
List[str]: each layer has its separate communication type.
cp_comm_type of each layer can be "p2p" or "all_gather" or "a2a" or "a2a+p2p".
"p2p": Exchange KV chunks with P2P communications in ring topology. P2P is async and can be
overlapped with attention compute.
"all_gather": All-gather to get full sequence of KV before attention. The all-gather is not
async, and cannot be overlapped.
"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP group, and gather to get
full sequence of QKV.
"a2a+p2p": A hierarchical implementation of context parallelism to attention.
It uses A2A communications in low-level CP groups (e.g., via NVLink),
and P2P communications in high-level CP groups (e.g., via IBLink).
"""
##################
# Cuda Graphs
##################
enable_cuda_graph: bool = False
"""When set to true, TransformerLayer layers are swapped with a CUDA graphed version."""
cuda_graph_use_single_mempool: bool = False
"""When set to true, cudagraphs will be captured inside a single mempool, in which all
cudagraphs may only be used once per step. If false, cudagraphs may be reused across
microbatches. Enabling may reduce cudagraph memory overheads due to memory fragmentation,
however may greatly increase the number of cudagraphs created when the number of microbatches
is high."""
cuda_graph_retain_backward_graph: bool = False
"""When set to true, cudagraph backward passes will be graph captured with 'retain_grad=True'
This may enable cudagraphs for certain modules that are not completely cudagraph safe. For
more details, see: https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html."""
cuda_graph_warmup_steps: int = 3
"""Number of warmup steps for CUDA graphs"""
external_cuda_graph: bool = False
"""When set to true, TransformerLayer layers are swapped with user provided CUDA graphs."""
cuda_graph_scope: str = "full"
"""When external_cuda_graph is set to true, cuda_graph_scope determines the CUDA graphs
capturing scope. Valid values are "full" and "attn". "Full" scope captures a whole Transformer
layer. "Attn" scope only captures operations in TransformerLayer._forward_attention()."""
####################
# miscellaneous
####################
clone_scatter_output_in_embedding: bool = True
"""When set to True, clone the output of scatter_to_sequence_parallel_region in embedding layer
to facilitate garbage collection of input."""
disable_parameter_transpose_cache: bool = False
"""When set to true, the parameter transposes are not cached for subsequent iterations."""
config_logger_dir: str = ""
"""When non-empty, dumps entry-point configs to config_logger_dir"""
flash_decode: bool = False
""" Use the optimized flash decoding kernel during inference. """
use_te_rng_tracker: bool = False
""" Whether to use the TE or MCore version of the RNG tracker. """
inference_rng_tracker: bool = False
""" Whether we should instantiate a separate RNG tracker for inference. """
symmetric_ar_type: Optional[str] = None
"""Type of symmetric all reduce to use"""
mrope_section: Optional[List[int]] = None
""" Multimodal rope section is for channel dimension of temporal, height and width
in rope calculation. """
is_hybrid_model: bool = False
""" Indicates whether this is a hybrid model. """
mamba_state_dim: int = 128
"""The dimensionality of the state representation in Mamba layers."""
mamba_head_dim: int = 64
"""The dimensionality of the heads in the Mamba layers."""
mamba_num_groups: int = 8
"""The number of groups used in Mamba layers."""
mamba_num_heads: Optional[int] = None
"""The number of heads used in Mamba layers.
If None, the number of heads will be hidden_size * expand // mamba_head_dim."""
use_mamba_mem_eff_path: bool = True
"""If True, use the memory efficient path for Mamba layers."""
mlp_chunks_for_prefill: int = 1
"""The number of chunks along the sequence dimension to use for MLP computation
during prefill."""
heterogeneous_block_specs: bool = False
"""Whether to use heterogeneous block specs (nemotron-nas architecture)."""
hetereogenous_dist_checkpoint: bool = False
"""Whether to use heterogenous layers in distributed checkpoint."""
####################
# Quantization
####################
quant_recipe: Optional[RecipeConfig] = None
"""Configuration of any quantization to be applied to the model"""
def __post_init__(self):
"""Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
details.
"""
super().__post_init__()
if self.fp16 and self.bf16:
raise ValueError(
f"Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True."
)
# Apply BF16 matmul precision setting if needed
if self.bf16 and self.disable_bf16_reduced_precision_matmul:
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
if self.num_attention_heads % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
if self.ffn_hidden_size is None:
self.ffn_hidden_size = 4 * self.hidden_size
if self.kv_channels is None:
self.kv_channels = self.hidden_size // self.num_attention_heads
if self.num_query_groups is None:
self.num_query_groups = self.num_attention_heads
if self.num_query_groups % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_query_groups ({self.num_query_groups}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
if self.fp8:
# cannot support first last layer bf16 with delayed scaling
if self.first_last_layers_bf16 and self.fp8_recipe == Fp8Recipe.delayed:
raise ValueError("Delayed scaling does not support first / last layer in BF16.")
# max bf16 layers per pipeline stage
max_bf16_layers_per_pipeline_stage = (
self.num_layers // self.pipeline_model_parallel_size
)
# check start/end bf16 layer counts are valid
if self.first_last_layers_bf16:
if (
self.num_layers_at_start_in_bf16 < 0
or self.num_layers_at_start_in_bf16 > max_bf16_layers_per_pipeline_stage
):
raise ValueError(
f"num_layers_at_start_in_bf16 ({self.num_layers_at_start_in_bf16}) must be "
f"between 0 and number of layers per pipeline stage "
f"({max_bf16_layers_per_pipeline_stage})."
)
if (
self.num_layers_at_end_in_bf16 < 0
or self.num_layers_at_end_in_bf16 > max_bf16_layers_per_pipeline_stage
):
raise ValueError(
f"num_layers_at_end_in_bf16 ({self.num_layers_at_end_in_bf16}) must be "
f"between 0 and number of layers per pipeline stage "
f"({max_bf16_layers_per_pipeline_stage})."
)
if self.fp8_param and not self.fp8:
raise ValueError("fp8_param must be used together with fp8 mode.")
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
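# MoE sanity checks: expert parallelism requires num_moe_experts, and moe_ffn_hidden_size
# falls back to ffn_hidden_size when experts are enabled.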
if self.expert_model_parallel_size > 1 and self.num_moe_experts is None:
raise ValueError("num_moe_experts must be non None to use expert-parallel.")
if self.num_moe_experts is not None and self.num_moe_experts <= 0:
raise ValueError("num_moe_experts must be non-negative.")
if self.num_moe_experts is not None and self.moe_ffn_hidden_size is None:
self.moe_ffn_hidden_size = self.ffn_hidden_size
warnings.warn("moe_ffn_hidden_size is not set, using ffn_hidden_size instead.")
if self.num_moe_experts is None:
assert (
self.moe_ffn_hidden_size is None
), "moe_ffn_hidden_size must be None when num_experts is not set."
if self.moe_enable_deepep:
if self.moe_token_dispatcher_type != "flex":
raise ValueError("DeepEP backend is only supported with flex token dispatcher.")
if self.moe_token_dispatcher_type == "flex":
if self.moe_pad_expert_input_to_capacity:
raise ValueError(
"Flex token dispatcher does not support moe_pad_expert_input_to_capacity"
)
if self.moe_shared_expert_intermediate_size is not None:
if self.moe_shared_expert_intermediate_size <= 0:
raise ValueError(
f"moe_shared_expert_intermediate_size must be "
f"num_shared_experts * ffn_size_of_each_shared_expert, "
f"but got {self.moe_shared_expert_intermediate_size}"
)
if self.moe_shared_expert_overlap and self.moe_token_dispatcher_type not in [
"alltoall"
]:
raise ValueError(
f"moe_shared_expert_overlap only works with alltoall token dispatcher."
)
if self.moe_expert_capacity_factor is not None:
if self.moe_expert_capacity_factor < 0:
self.moe_expert_capacity_factor = None
if self.moe_router_load_balancing_type not in ["aux_loss", "seq_aux_loss", "none"]:
raise ValueError(
"moe_expert_capacity_factor only works with aux_loss or none load balancing"
)
if self.moe_pad_expert_input_to_capacity:
if self.moe_expert_capacity_factor is None:
raise ValueError(
"moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity"
)
if self.cpu_offloading and (
self.cpu_offloading_num_layers < 0 or self.cpu_offloading_num_layers >= self.num_layers
):
raise ValueError(
f"CPU offloading can be done only for layers less than {self.num_layers}"
)
if self.cpu_offloading and self.pipeline_model_parallel_size > 1:
raise ValueError(
"Currently there is no support for Pipeline parallelism with CPU offloading"
)
if self.cpu_offloading and self.recompute_granularity is not None:
raise ValueError(
"CPU offloading does not work when activation recomputation is enabled"
)
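# Activation recomputation: "full" granularity recomputes whole layers and needs recompute_method /
# recompute_num_layers; "selective" only recomputes the modules listed in recompute_modules.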
if self.recompute_granularity is not None:
if self.recompute_granularity not in ["full", "selective"]:
raise ValueError(
f'recompute_granularity: {self.recompute_granularity} must be "full" '
'or "selective".'
)
if self.recompute_method is not None:
if self.recompute_method not in ["block", "uniform"]:
raise ValueError(
f'recompute_method: {self.recompute_method} must be "block" or "uniform".'
)
elif self.recompute_granularity != "selective":
raise ValueError(
f"Using recompute_granularity: {self.recompute_granularity} so "
'recompute_method must be "block" or "uniform"'
)
if self.recompute_granularity != "selective" and self.recompute_num_layers is None:
raise ValueError(
f"When using recompute_granularity: {self.recompute_granularity} "
"recompute_num_layers must be between "
"1 and num_layers_per_pipeline_rank: "
f"{self.num_layers // self.pipeline_model_parallel_size}"
)
elif (
self.recompute_granularity == "selective" and self.recompute_num_layers is not None
):
raise ValueError(
f"When using recompute_granularity: {self.recompute_granularity} "
"recompute_num_layers must be None."
)
if self.distribute_saved_activations and self.sequence_parallel:
raise ValueError(
f"distribute_saved_activations: {self.distribute_saved_activations} must be "
f"false when sequence parallel is enabled: {self.sequence_parallel}"
)
if self.recompute_modules is None:
self.recompute_modules = ["core_attn"]
if self.recompute_granularity == "selective":
if len(self.recompute_modules) > 0:
allowed_modules = {"core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe"}
invalid_modules = set(self.recompute_modules) - allowed_modules
assert not invalid_modules, (
f"Invalid choices for recompute_modules: {invalid_modules}. "
f"Allowed modules are: {allowed_modules}"
)
if "moe_act" in self.recompute_modules and not self.moe_grouped_gemm:
raise ValueError(
"moe_act in recompute_modules is only supported with moe_grouped_gemm."
)
if "mla_up_proj" in self.recompute_modules and not self.multi_latent_attention:
raise ValueError(
"mla_up_proj in recompute_modules is only supported with "
"multi_latent_attention."
)
if "core_attn" in self.recompute_modules:
warnings.warn(
"If you are using transformer_engine as the transformer implementation, "
"the core_attn is from transformer_engine and may be the fused version. "
"For fused attention, you have no need to set 'core_attn' to recompute. "
"Please check that the core_attn recompute is really needed."
)
if self.fp8:
if "moe_act" in self.recompute_modules or "layernorm" in self.recompute_modules:
raise ValueError("moe_act and layernorm recompute cannot work with fp8.")
if self.moe_layer_recompute:
warnings.warn(
"--moe-layer-recompute is deprecated. "
"Use --recompute-granularity selective --recompute-modules moe_layer instead."
)
if self.recompute_granularity == "full":
raise ValueError(
"Do not set --moe-layer-recompute with full recompute granularity. "
)
self.recompute_granularity = "selective"
if "moe" not in self.recompute_modules:
self.recompute_modules.append("moe")
if (
self.num_layers_in_first_pipeline_stage is not None
or self.num_layers_in_last_pipeline_stage is not None
) and (
self.account_for_embedding_in_pipeline_split or self.account_for_loss_in_pipeline_split
):
raise ValueError(
"num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage cannot be"
"set at the same time with account_for_embedding_in_pipeline_split"
"and account_for_loss_in_pipeline_split"
)
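# The branches below cover three mutually exclusive ways of assigning layers to pipeline stages:
# an explicit pipeline_model_parallel_layout, uneven first/last stage sizes, or counting the
# embedding / loss layers as pipeline stages; otherwise layers must divide evenly across stages.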
# PP layout
if self.pipeline_model_parallel_layout is not None:
# If pipeline layout is set, we will check the conflicts
# with other pipeline layout arguments.
any_conflict = (
self.num_layers_in_first_pipeline_stage is not None
or self.num_layers_in_last_pipeline_stage is not None
or self.account_for_embedding_in_pipeline_split
or self.account_for_loss_in_pipeline_split
)
if any_conflict:
raise ValueError(
"pipeline_model_parallel_layout cannot be set"
" with other pipeline layout arguments."
f" {self.num_layers_in_first_pipeline_stage=},"
f" {self.num_layers_in_last_pipeline_stage=},"
f" {self.account_for_embedding_in_pipeline_split=},"
f" {self.account_for_loss_in_pipeline_split=}."
)
# Transfer pipeline_model_parallel_layout from str or list to
# PipelineParallelLayerLayout
if isinstance(self.pipeline_model_parallel_layout, str):
self.pipeline_model_parallel_layout = PipelineParallelLayerLayout.from_str(
layout=self.pipeline_model_parallel_layout,
pipeline_model_parallel_size=self.pipeline_model_parallel_size,
)
elif isinstance(self.pipeline_model_parallel_layout, list):
# Since list is not hashable, the initialization will not be cached.
self.pipeline_model_parallel_layout = PipelineParallelLayerLayout(
layout=self.pipeline_model_parallel_layout,
pipeline_model_parallel_size=self.pipeline_model_parallel_size,
)
# Check whether the input VPP size conflicts with the PP layout
detected_vpp_size = (
self.pipeline_model_parallel_layout.virtual_pipeline_model_parallel_size
)
if self.virtual_pipeline_model_parallel_size is not None:
assert self.virtual_pipeline_model_parallel_size == detected_vpp_size, (
f"virtual_pipeline_model_parallel_size conflicts with"
f" pipeline_model_parallel_layout,"
f" ({self.virtual_pipeline_model_parallel_size=}, "
f" {detected_vpp_size=})"
)
elif detected_vpp_size > 1:
self.virtual_pipeline_model_parallel_size = detected_vpp_size
# Check whether the layout is valid.
self.pipeline_model_parallel_layout.validate_layer_layout(num_layers=self.num_layers)
# Uneven PP
elif (
self.num_layers_in_first_pipeline_stage is not None
or self.num_layers_in_last_pipeline_stage is not None
):
pipeline_parallel_size = self.pipeline_model_parallel_size
num_layers = self.num_layers
if self.num_layers_in_first_pipeline_stage is not None:
if self.num_layers_in_first_pipeline_stage <= 0:
raise ValueError("num_layers_in_first_pipeline_stage must be larger than 0")
if self.virtual_pipeline_model_parallel_size is not None:
if (
self.num_layers_in_first_pipeline_stage
% self.virtual_pipeline_model_parallel_size
!= 0
):
raise ValueError(
f"number of layers at first stage: "
f"{self.num_layers_in_first_pipeline_stage}"
f"must be divisible by virtual pipeline"
f"parallel degree {self.virtual_pipeline_model_parallel_size}"
)
num_layers -= self.num_layers_in_first_pipeline_stage
pipeline_parallel_size -= 1
if self.num_layers_in_last_pipeline_stage is not None:
if self.num_layers_in_last_pipeline_stage <= 0:
raise ValueError("num_layers_in_last_pipeline_stage must be larger than 0")
if self.virtual_pipeline_model_parallel_size is not None:
if (
self.num_layers_in_last_pipeline_stage
% self.virtual_pipeline_model_parallel_size
!= 0
):
raise ValueError(
f"number of layers at last stage: "
f"{self.num_layers_in_last_pipeline_stage}"
f"must be divisible by virtual pipeline"
f"parallel degree {self.virtual_pipeline_model_parallel_size}"
)
num_layers -= self.num_layers_in_last_pipeline_stage
pipeline_parallel_size -= 1
# Here pipeline_parallel_size is the number of middle PP stages. If there are middle
# PP stages, check number of layers at middle stage is divisible by middle PP size.
if pipeline_parallel_size and not num_layers % pipeline_parallel_size == 0:
raise ValueError(
f"number of layers at middle stage: {num_layers} must be divisible by"
f"the middle pipeline model parallel size {pipeline_parallel_size}"
)
# If there are middle PP stages, check number of layers
# on each middle PP rank is divisible by VPP size.
if pipeline_parallel_size and self.virtual_pipeline_model_parallel_size is not None:
num_layers_per_middle_pipeline_rank = num_layers // pipeline_parallel_size
if (
not num_layers_per_middle_pipeline_rank
% self.virtual_pipeline_model_parallel_size
== 0
):
raise ValueError(
f"number of layers on each middle pipeline rank:"
f"{num_layers_per_middle_pipeline_rank} must be divisible by virtual"
f"pipeline parallel degree {self.virtual_pipeline_model_parallel_size}"
)
elif (
self.account_for_embedding_in_pipeline_split or self.account_for_loss_in_pipeline_split
):
if self.virtual_pipeline_model_parallel_size is None:
num_layers = self.num_layers
if self.account_for_embedding_in_pipeline_split:
num_layers += 1
if self.account_for_loss_in_pipeline_split:
num_layers += 1
if not num_layers % self.pipeline_model_parallel_size == 0:
raise ValueError(
f"number of middle layers: {num_layers} must be divisible by "
f"middle pipeline_model_parallel_size {self.pipeline_model_parallel_size}"
)
else:
num_layers = self.num_layers
if self.account_for_embedding_in_pipeline_split:
num_layers += 1
if self.account_for_loss_in_pipeline_split:
num_layers += 1
if not num_layers % self.pipeline_model_parallel_size == 0:
raise ValueError(
f"num_layers: {num_layers} after enable"
f"account_for_embedding_in_pipeline_split or "
f"account_for_loss_in_pipeline_split must be divisible"
f"by pipeline_model_parallel_size "
f"{self.pipeline_model_parallel_size}"
)
num_layers_per_pipeline_rank = num_layers // self.pipeline_model_parallel_size
if (
not num_layers_per_pipeline_rank % self.virtual_pipeline_model_parallel_size
== 0
):
raise ValueError(
f"number of layers on each pipeline rank: {num_layers_per_pipeline_rank}"
f"(after enable account_for_embedding_in_pipeline_split or "
f"account_for_loss_in_pipeline_split) must be divisible by"
f"virtual_pipeline_model_parallel_size"
f"{self.virtual_pipeline_model_parallel_size}"
)
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
if self.bias_activation_fusion:
if self.activation_func not in [F.gelu, F.silu]:
raise ValueError(
"When bias_activation_fusion is True, activation function should be either "
"gelu or swiglu"
)
if (
self.activation_func == F.gelu
and not self.gated_linear_unit
and not self.add_bias_linear
):
raise ValueError(
"When bias_activation_fusion is True, gated_linear_unit is False, "
"and activation function is gelu, add_bias_linear must also be True."
)
if self.activation_func_fp8_input_store:
if self.activation_func != F.silu or not self.gated_linear_unit:
raise ValueError("Storing activation input in FP8 is supported only for SwiGLU.")
if self.apply_rope_fusion:
if self.multi_latent_attention:
warnings.warn(
"apply_rope_fusion for multi-latent attention only supports training. "
"It is experimental and may change in future versions."
)
else:
if self.rotary_interleaved:
if not is_te_min_version("2.3.0.dev0"):
raise ValueError(
"rotary_interleaved does not work with apply_rope_fusion for "
"TE < 2.3.0.dev0. Please install TE >= 2.3.0.dev0"
)
from megatron.core.models.common.embeddings.rope_utils import (
fused_apply_rotary_pos_emb,
fused_apply_rotary_pos_emb_thd,
)
if fused_apply_rotary_pos_emb is None and fused_apply_rotary_pos_emb_thd is None:
raise ValueError(
"apply_rope_fusion is not available. Please install TE >= 1.4 or Apex."
)
if self.multi_latent_attention and self.rotary_interleaved:
raise ValueError("rotary_interleaved does not work with multi_latent_attention.")
if self.init_method is None:
self.init_method = init_method_normal(self.init_method_std)
if self.output_layer_init_method is None:
self.output_layer_init_method = scaled_init_method_normal(
self.init_method_std,
self.num_layers,
multiplier=2.0 if not self.is_hybrid_model else 1.0,
)
if self.num_moe_experts is not None:
assert not self.add_bias_linear, "Bias is not supported for MoE"
if self.moe_router_enable_expert_bias and self.moe_router_score_function != "sigmoid":
raise ValueError(
"Expert bias for aux-loss-free routing only supports sigmoid score function."
"Please set --moe-router-score-function sigmoid for sigmoid score function."
)
if self.num_moe_experts and self.fp8:
# TE versions below 1.7.0 raise an error when an expert receives zero tokens
if not is_te_min_version("1.7.0.dev0"):
raise ValueError(
"Only transformer-engine>=1.7.0 supports MoE FP8 training, "
f"but your version is {get_te_version()}."
)
if self.moe_grouped_gemm and not is_te_min_version("1.11.0"):
raise ValueError(
"Only transformer-engine>=1.11.0 supports FP8 grouped gemm, "
f"but your version is {get_te_version()}."
)
if self.moe_router_padding_for_fp8:
if self.fp8 is None:
raise ValueError("fp8 must be specified when moe_router_padding_for_fp8 is True.")
if self.moe_token_dispatcher_type in ["allgather", "alltoall_seq"]:
raise ValueError(
"allgather and alltoall_seq dispatcher does not support "
"moe_router_padding_for_fp8."
)
if (
self.moe_router_topk == 1
and self.moe_router_score_function == "softmax"
and not self.moe_router_pre_softmax
and self.moe_router_load_balancing_type != "sinkhorn"
):
# Requires applying softmax before selecting the top-k when k is 1,
# since softmax on a [num_tokens, 1] would yield a zero gradient.
raise ValueError("Please use --moe-router-pre-softmax when topk is 1.")
if self.moe_router_group_topk:
if self.moe_router_topk_limited_devices:
raise ValueError(
"moe_router_topk_limited_devices is deprecated and replaced by "
"moe_router_group_topk and moe_router_num_groups."
)
if not self.moe_router_num_groups:
raise ValueError(
"When using group limited routing, moe_router_num_groups must be specified."
)
else:
assert self.num_moe_experts % self.moe_router_num_groups == 0, (
f"num_moe_experts ({self.num_moe_experts}) should be divisible by "
f"moe_router_num_groups ({self.moe_router_num_groups})."
)
assert self.moe_router_group_topk <= self.moe_router_num_groups, (
f"moe_router_group_topk ({self.moe_router_group_topk}) should be smaller than "
f"moe_router_num_groups ({self.moe_router_num_groups})."
)
elif self.moe_router_topk_limited_devices:
warnings.warn(
"moe_router_topk_limited_devices is deprecated. Use moe_router_group_topk and "
"moe_router_num_groups instead."
)
self.moe_router_group_topk = self.moe_router_topk_limited_devices
self.moe_router_num_groups = self.expert_model_parallel_size
if self.enable_cuda_graph:
if self.cpu_offloading:
raise ValueError("CUDA graphs not supported with CPU offloading.")
if self.recompute_granularity:
raise ValueError("CUDA graphs not supported with activation recomputation.")
if self.moe_token_dispatcher_type in ["allgather"]:
if self.variable_seq_lengths is True:
raise ValueError(
f"Token dispatcher type: {self.moe_token_dispatcher_type} does not support "
f"variable sequence length, please use alltoall dispatcher instead."
)
if self.moe_permute_fusion:
from megatron.core.transformer.moe.moe_utils import (
fused_permute,
fused_permute_with_probs,
fused_sort_chunks_by_index,
fused_sort_chunks_by_index_with_probs,
fused_unpermute,
)
if (
fused_permute is None
or fused_permute_with_probs is None
or fused_sort_chunks_by_index is None
or fused_sort_chunks_by_index_with_probs is None
or fused_unpermute is None
):
raise ValueError("fused permutation is not available. Please install TE >= 2.1.0.")
if self.context_parallel_size > 1 and self.cp_comm_type is not None:
if isinstance(self.cp_comm_type, list):
assert len(self.cp_comm_type) == self.num_layers, (
f"Length of cp_comm_type ({len(self.cp_comm_type)}) should equal to "
f"the total number of transformer layers ({self.num_layers})!"
)
else:
assert isinstance(
self.cp_comm_type, str
), "Unsupported communication type for context parallelism!"
assert (
self.pipeline_model_parallel_size > 0
), f"Pipeline model parallel size must be larger than 0 \
when enable --standalone-embedding-stage and --standalone-loss-stage"
if (
self.num_moe_experts is not None
and self.num_moe_experts >= 32
and not self.moe_router_dtype
):
warnings.warn(
"Using a large number of experts (e.g. >=32) without fp32 routing. "
"Consider enabling moe_router_dtype for better numerical stability."
)
if self.symmetric_ar_type is not None:
if not HAVE_PACKAGING:
raise ImportError(
"packaging is not installed. Please install it with `pip install packaging`."
)
assert is_torch_min_version("2.7.0a0"), "Must have at least torch version 2.7 or higher"
assert is_te_min_version("2.3.0") or get_te_version() == PkgVersion(
"2.3.0.dev0+39c0e70"
), "Must have at least TE version 2.3 or higher to use symmetric memory all reduce"
if self.no_rope_freq:
assert not self.flash_decode, "flash_decode cannot be used with no_rope."
if isinstance(self.no_rope_freq, int):
assert self.num_layers % self.no_rope_freq == 0, (
f"no_rope_freq={self.no_rope_freq} should be "
f"divisible by num_layers={self.num_layers}."
)
# Convert integer pattern to list pattern
# e.g. no_rope_freq=4 with num_layers=8 becomes [0,0,0,1,0,0,0,1]
pattern = [0] * (self.no_rope_freq - 1) + [1]
self.no_rope_freq = pattern * (self.num_layers // self.no_rope_freq)
else:
assert len(self.no_rope_freq) == self.num_layers, (
f"Length of no_rope list ({len(self.no_rope_freq)}) must match "