
Notes on Verl AgentLoop Rollout

📅 Published 2025/11/26
🔄 Updated 2025/11/26

Basic AgentLoop

AgentLoop Core Functionality

Core functionality

  • Processes one environment-interaction sample
  • Covers preparing the initial messages, LLM generation, and calling the environment
  • Handles the data processing

Data Structure Definitions

AgentState Definition

You can define your own custom states.

python
class AgentState(Enum):
  PENDING = "pending"
  GENERATING = "generating"
  PROCESSING_TOOLS = "processing_tools"
  TERMINATED = "terminated"
  INTERACTING = "interacting"
  SKIIPED = "skipped" # 比如跳过一些case

AgentData Definition

You can customize the AgentData format for your own use and drop any fields you don't need.

python
class AgentData:
  """Encapsulates all state variables for the agent loop."""

  def __init__(
      self,
      messages: list[dict[str, Any]],
      image_data: Any,
      metrics: dict[str, Any],
      request_id: str,
      tools_kwargs: dict[str, Any],
      interaction: Optional[BaseInteraction] = None,
      interaction_kwargs: Optional[dict[str, Any]] = None,
  ):
      self.messages = messages
      self.image_data = image_data
      self.metrics = metrics
      self.request_id = request_id
      self.tools_kwargs = tools_kwargs
      self.interaction = interaction
      self.interaction_kwargs = interaction_kwargs or {}

      # State variables
      self.prompt_ids: list[int] = []
      self.response_ids: list[int] = []
      self.response_mask: list[int] = []
      self.response_logprobs: list[float] = []
      self.turn_scores: list[float] = []
      self.tool_rewards: list[float] = []
      self.user_turns = 0
      self.assistant_turns = 0

      # Temporary state for tool calls
      self.tool_calls: list[FunctionCall] = []

      # Extra fields for dynamic addition
      self.extra_fields: dict[str, Any] = {}

AgentLoopOutput (unpadded) Definition

python
class AgentLoopOutput(BaseModel):
  """Agent loop output."""

  prompt_ids: list[int] 
  """Prompt token ids."""
  response_ids: list[int] 
  """Response token ids including LLM generated token, tool response token."""
  response_mask: list[int] 
  """Response mask, 1 for LLM generated token, 0 for tool response token."""
  response_logprobs: Optional[list[float]] = None
  """Log probabilities for the response tokens."""
  multi_modal_data: Optional[dict[str, Any]] = None
  """Multi-modal data for multi-modal tools."""
  reward_score: Optional[float] = None
  """Reward score for the trajectory."""
  num_turns: int = 0
  """Number of chat turns, including user, assistant, tool."""
  metrics: AgentLoopMetrics
  """Auxiliary performance metrics"""
  extra_fields: dict[str, Any] = {}
  """Extra fields for dynamic addition."""

_InternalAgentLoopOutput (padded) Definition

python
class _InternalAgentLoopOutput(AgentLoopOutput):
  """Internal agent loop output with padded sequences."""

  model_config = ConfigDict(arbitrary_types_allowed=True)

  prompt_ids: torch.Tensor
  """Padded prompt token ids."""
  response_ids: torch.Tensor
  """Padded response token ids."""
  input_ids: torch.Tensor
  """Padded input ids(prompt_ids + response_ids)."""
  position_ids: torch.Tensor
  """Padded position ids."""
  response_mask: torch.Tensor
  """Padded response mask."""
  attention_mask: torch.Tensor
  """Padded attention mask."""
  response_logprobs: Optional[torch.Tensor] = None
  """Padded log probabilities for the response tokens."""
  multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None
  """Multi-modal inputs for processors (e.g., pixel_values, image_grid_thw)."""
  extra_fields: dict[str, Any] = {}
  """Extra fields for dynamic addition."""

DataProto

DataProto

Core components

  • batch: TensorDict
    • Holds tensor data such as input_ids, attention_mask, etc.
    • Supports pop, slicing, and similar operations.
  • non_tensor_batch: dict
    • Non-tensor data.
  • meta_info: dict
    • Meta information.

Common methods

  • from_dict, select, chunk, pop, etc.; a short usage sketch follows below.
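
Before the full definition, a minimal usage sketch of these methods (assuming DataProto can be imported from verl.protocol; the shapes and values are purely illustrative):

python
import torch
from verl.protocol import DataProto

# Build a DataProto from tensors plus non-tensor (object-dtype) data.
data = DataProto.from_dict(
    tensors={
        "input_ids": torch.randint(0, 100, (8, 16)),
        "attention_mask": torch.ones(8, 16, dtype=torch.long),
    },
    non_tensors={"uid": [f"sample-{i}" for i in range(8)]},
    meta_info={"global_steps": 0},
)

# select: keep a subset of keys (returns a new DataProto).
subset = data.select(batch_keys=["input_ids"], non_tensor_batch_keys=["uid"])

# pop: remove keys from `data` and return them as a new DataProto
# (this mirrors how _get_gen_batch builds the generation batch).
gen_batch = data.pop(batch_keys=["input_ids", "attention_mask"])

# chunk / slicing: split along the batch dimension.
parts = gen_batch.chunk(chunks=4)   # 4 DataProto objects, each with batch size 2
first_half = gen_batch[:4]          # slicing also returns a DataProto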

The full class below is fairly long; feel free to skim it, it is not the focus here.

python
@dataclass
class DataProto:
  """
  A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
  It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
  TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
  same batch size should be put inside batch.
  """

  batch: TensorDict = None
  non_tensor_batch: dict = field(default_factory=dict)
  meta_info: dict = field(default_factory=dict)

  def __post_init__(self):
      # perform necessary checking
      self.check_consistency()

  def __len__(self):
      if self.batch is not None:
          return self.batch.batch_size[0]
      elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
          random_key = list(self.non_tensor_batch.keys())[0]
          return self.non_tensor_batch[random_key].shape[0]
      else:
          return 0

  def __getitem__(self, item):
      """
      Enhanced indexing for DataProto objects.

      Args:
          item: Can be one of:
              - int: A single index
              - slice: A slice object (start:stop:step)
              - list: A list of indices
              - numpy.ndarray: An array of indices
              - torch.Tensor: A tensor of indices

      Returns:
          DataProto: For all indexing types except single integers
          DataProtoItem: Only for single integer indices
      """
      # Case 1: Slice object - use the slice method
      if isinstance(item, slice):
          return self.slice(item.start, item.stop, item.step)

      # Case 2: List, numpy array, or torch tensor - use sel_idxs
      elif isinstance(item, list | np.ndarray | torch.Tensor):
          return self.select_idxs(item)

      # Case 3: Single integer - return DataProtoItem for backward compatibility
      elif isinstance(item, int | np.integer):
          tensor_data = self.batch[item] if self.batch is not None else None
          non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
          return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)

      # # Case 4: Unsupported type
      else:
          raise TypeError(f"Indexing with {type(item)} is not supported")

  def print_size(self, prefix=""):
      size_of_tensordict = 0
      if self.batch is not None:
          for _, tensor in self.batch.items():
              size_of_tensordict += tensor.element_size() * tensor.numel()
      size_of_numpy_array = 0
      for _, numpy_array in self.non_tensor_batch.items():
          size_of_numpy_array += numpy_array.nbytes

      size_of_numpy_array /= 1024**3
      size_of_tensordict /= 1024**3

      message = f"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB"

      if prefix:
          message = f"{prefix}, " + message
      print(message)

  @classmethod
  def from_dict(
      cls,
      tensors: Optional[dict[str, torch.Tensor]] = None,
      non_tensors=None,
      meta_info=None,
      num_batch_dims=1,
      auto_padding=False,
  ):
      """Create a DataProto from a dict of tensors. This assumes that
      1. All the tensor in tensors have the same dim0
      2. Only dim0 is the batch dim
      """

      assert num_batch_dims > 0, "num_batch_dims must be greater than zero"
      if non_tensors is not None:
          assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None."

      if tensors is None:
          tensors = {}
      if meta_info is None:
          meta_info = {}
      if non_tensors is None:
          non_tensors = {}

      assert isinstance(non_tensors, dict)

      # get and check batch size
      batch_size = None
      pivot_key = None
      for key, tensor in tensors.items():
          if batch_size is None:
              batch_size = tensor.shape[:num_batch_dims]
              pivot_key = key
          else:
              current_batch = tensor.shape[:num_batch_dims]
              assert batch_size == current_batch, (
                  f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. "
                  f"Got {pivot_key} has {batch_size}, {key} has {current_batch}"
              )

      for key, val in non_tensors.items():
          if not isinstance(val, np.ndarray):
              non_tensors[key] = np.array(val, dtype=object)

      tensor_dict = TensorDict(source=tensors, batch_size=batch_size) if tensors else None
      if auto_padding:
          meta_info[DataProtoConfig.auto_padding_key] = True
      return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)

  @classmethod
  def from_tensordict(
      cls,
      tensor_dict: TensorDict = None,
      meta_info=None,
      num_batch_dims=1,
  ):
      """Create a DataProto from a TensorDict. This assumes that
      1. All the tensor in tensor_dict have the same dim0
      2. Only dim0 is the batch dim
      """
      assert version.parse(tensordict.__version__) >= version.parse("0.10.0"), (
          "Build DataProto from TensorDict at least requires tensordict version 0.10.0"
      )
      from tensordict import NonTensorData, NonTensorStack

      assert num_batch_dims > 0, "num_batch_dims must be greater than zero"
      if not all(isinstance(val, torch.Tensor) for val in tensor_dict.values()):
          assert num_batch_dims == 1, "only support num_batch_dims=1 when tensor_dict contains non tensor data."

      if meta_info is None:
          meta_info = {}
      batch = {}
      non_tensor_batch = {}
      batch_size = None
      for key, val in tensor_dict.items():
          if isinstance(val, torch.Tensor):
              batch[key] = val
              if batch_size is None:
                  batch_size = val.shape[:num_batch_dims]
          elif isinstance(val, NonTensorStack):
              non_tensor_batch[key] = np.array([elem.data for elem in val], dtype=object)
          elif isinstance(val, NonTensorData):
              meta_info[key] = val.data

      return cls(
          batch=TensorDict(batch, batch_size=batch_size),
          non_tensor_batch=non_tensor_batch,
          meta_info=meta_info,
      )

  def to(self, device) -> "DataProto":
      """move the batch to device

      Args:
          device (torch.device, str): torch device

      Returns:
          DataProto: the current DataProto

      """
      if self.batch is not None:
          self.batch = self.batch.to(device)
      return self

  def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> "DataProto":
      """Select a subset of the DataProto via batch_keys and meta_info_keys

      Args:
          batch_keys (list, optional): a list of strings indicating the keys in batch to select
          meta_info_keys (list, optional): a list of keys indicating the meta info to select

      Returns:
          DataProto: the DataProto with the selected batch_keys and meta_info_keys
      """
      # TODO (zhangchi.usc1992) whether to copy
      if batch_keys is not None:
          batch_keys = tuple(batch_keys)
          sub_batch = self.batch.select(*batch_keys)
      else:
          sub_batch = self.batch

      if non_tensor_batch_keys is not None:
          non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}
      else:
          non_tensor_batch = self.non_tensor_batch

      if deepcopy:
          non_tensor_batch = copy.deepcopy(non_tensor_batch)

      if meta_info_keys is not None:
          sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}
      else:
          sub_meta_info = self.meta_info

      if deepcopy:
          sub_meta_info = copy.deepcopy(sub_meta_info)

      return type(self)(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)

  def select_idxs(self, idxs):
      """
      Select specific indices from the DataProto.

      Args:
          idxs (torch.Tensor or numpy.ndarray or list): Indices to select

      Returns:
          DataProto: A new DataProto containing only the selected indices
      """
      if isinstance(idxs, list):
          idxs = torch.tensor(idxs)
          if idxs.dtype != torch.bool:
              idxs = idxs.type(torch.int32)

      if isinstance(idxs, np.ndarray):
          idxs_np = idxs
          idxs_torch = torch.from_numpy(idxs)
      else:  # torch.Tensor
          idxs_torch = idxs
          idxs_np = idxs.detach().cpu().numpy()

      batch_size = int(idxs_np.sum()) if idxs_np.dtype == bool else idxs_np.shape[0]

      if self.batch is not None:
          # Use TensorDict's built-in indexing capabilities
          selected_batch = TensorDict(
              source={key: tensor[idxs_torch] for key, tensor in self.batch.items()},
              batch_size=(batch_size,),
              device=self.batch.device,
          )
      else:
          selected_batch = None

      selected_non_tensor = {}
      for key, val in self.non_tensor_batch.items():
          selected_non_tensor[key] = val[idxs_np]

      return type(self)(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info)

  def slice(self, start=None, end=None, step=None):
      """
      Slice the DataProto and return a new DataProto object.
      This is an improved version of direct slicing which returns a DataProtoItem.

      Args:
          start (int, optional): Start index. Defaults to None (start from beginning).
          end (int, optional): End index (exclusive). Defaults to None (go to end).
          step (int, optional): Step size. Defaults to None (step=1).

      Returns:
          DataProto: A new DataProto containing the sliced data

      Examples:
          # Using the slice method directly
          sliced_data = data_proto.slice(10, 20)

          # Using enhanced indexing (returns DataProto)
          sliced_data = data_proto[10:20]
          sliced_data = data_proto[::2]  # Every other element

          # Using list indexing (returns DataProto)
          indices = [1, 5, 10]
          selected_data = data_proto[indices]

          # Single index still returns DataProtoItem
          single_item = data_proto[5]
      """
      # Create a slice object
      slice_obj = slice(start, end, step)

      # Handle the batch data
      if self.batch is not None:
          # Use TensorDict's built-in slicing capabilities
          sliced_batch = self.batch[slice_obj]
      else:
          sliced_batch = None

      # Handle the non-tensor batch data
      sliced_non_tensor = {}
      for key, val in self.non_tensor_batch.items():
          sliced_non_tensor[key] = val[slice_obj]

      # Return a new DataProto object
      return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info)

  def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto":
      """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`

      Args:
          batch_keys (list, optional): a list of strings indicating the keys in batch to pop
          meta_info_keys (list, optional): a list of keys indicating the meta info to pop

      Returns:
          DataProto: the DataProto with the poped batch_keys and meta_info_keys
      """
      if batch_keys is None:
          batch_keys = []
      if meta_info_keys is None:
          meta_info_keys = []
      if non_tensor_batch_keys is None:
          non_tensor_batch_keys = []

      tensors = {}
      # tensor batch
      for key in batch_keys:
          assert key in self.batch.keys()
          tensors[key] = self.batch.pop(key)
      non_tensors = {}
      # non tensor batch
      for key in non_tensor_batch_keys:
          assert key in self.non_tensor_batch.keys()
          non_tensors[key] = self.non_tensor_batch.pop(key)
      meta_info = {}
      for key in meta_info_keys:
          assert key in self.meta_info.keys()
          meta_info[key] = self.meta_info.pop(key)
      return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)

  def union(self, other: "DataProto") -> "DataProto":
      """Union with another DataProto. Union batch and meta_info separately.
      Throw an error if

      - there are conflict keys in batch and they are not equal
      - the batch size of two data batch is not the same
      - there are conflict keys in meta_info and they are not the same.

      Args:
          other (DataProto): another DataProto to union

      Returns:
          DataProto: the DataProto after union
      """
      self.batch = union_tensor_dict(self.batch, other.batch)
      self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)
      self.meta_info = union_two_dict(self.meta_info, other.meta_info)
      return self

  def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
      r"""Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch
      dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.


      Args:
          mini_batch_size (int): mini-batch size when iterating the dataset. We require that
              ``batch.batch_size[0] % mini_batch_size == 0``.
          epochs (int): number of epochs when iterating the dataset.
          dataloader_kwargs (Any): internally, it returns a DataLoader over the batch. The
              dataloader_kwargs is the kwargs passed to the DataLoader.

      Returns:
          Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration
              steps is ``self.batch.batch_size * epochs // mini_batch_size``
      """
      assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0"
      # we can directly create a dataloader from TensorDict
      if dataloader_kwargs is None:
          dataloader_kwargs = {}

      if seed is not None:
          generator = torch.Generator()
          generator.manual_seed(seed)
      else:
          generator = None

      assert isinstance(dataloader_kwargs, dict)
      train_dataloader = DataLoader(
          dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs
      )

      def get_data():
          for _ in range(epochs):
              for d in train_dataloader:
                  d.meta_info = self.meta_info
                  yield d

      return iter(get_data())


  def chunk(self, chunks: int) -> list["DataProto"]:
      """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.

      Args:
          chunks (int): the number of chunks to split on dim=0

      Returns:
          List[DataProto]: a list of DataProto after splitting
      """
      if not self.is_padding_enabled():
          assert len(self) % chunks == 0, (
              f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}."
          )

      bsz_in_batch = None
      if self.batch is not None:
          batch_lst = self.batch.chunk(chunks=chunks, dim=0)
          bsz_in_batch = np.array([batch.batch_size[0] for batch in batch_lst])
          chunk_indices = np.cumsum(bsz_in_batch)[:-1]
      else:
          batch_lst = [None for _ in range(chunks)]

      non_tensor_batch_lst = [{} for _ in range(chunks)]
      for key, val in self.non_tensor_batch.items():
          assert isinstance(val, np.ndarray)
          if bsz_in_batch is not None:
              non_tensor_lst = np.array_split(val, chunk_indices.tolist())
          else:
              non_tensor_lst = np.array_split(val, chunks)
          assert len(non_tensor_lst) == chunks
          for i in range(chunks):
              non_tensor_batch_lst[i][key] = non_tensor_lst[i]

      output = []
      for i in range(chunks):
          output.append(
              type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)
          )

      return output

  def split(self, split_size: int) -> list["DataProto"]:
      """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.

      Args:
          split_size (int): the size of each split

      Returns:
          List[DataProto]: a list of DataProto after splitting
      """
      return [self[i : i + split_size] for i in range(0, len(self), split_size)]

Interaction Flow

Overall State Transition Flow

Flow control is driven by state transitions.

python
# State machine loop
state = AgentState.PENDING
while state != AgentState.TERMINATED:
    if state == AgentState.PENDING:
        state = await self._handle_pending_state(agent_data, sampling_params)
    elif state == AgentState.GENERATING: 
        state = await self._handle_generating_state(agent_data, sampling_params) 
    elif state == AgentState.PROCESSING_TOOLS: 
        state = await self._handle_processing_tools_state(agent_data) 
    elif state == AgentState.INTERACTING:
        state = await self._handle_interacting_state(agent_data)
    else:
        logger.error(f"Invalid state: {state}")
        state = AgentState.TERMINATED

Preparing the Initial Messages

Read the messages from the dataset or the environment
  • Use them as the initial prompt and construct AgentData
  • Then enter the AgentLoop flow described above
python
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
  messages = list(kwargs["raw_prompt"])
  agent_data = AgentData(
      messages=messages,
      image_data=image_data,
      metrics=metrics,
      request_id=request_id,
      tools_kwargs=tools_kwargs,
      interaction=interaction,
      interaction_kwargs=interaction_kwargs,
  )
  # State machine loop
  state = AgentState.PENDING
  ....
Prepare agent_data.prompt_ids
  • Tokenize the initial prompt (apply the chat template)
python
async def _handle_pending_state(self, agent_data: AgentData, sampling_params: dict[str, Any]) -> AgentState:
    """Handle the pending state: prepare the prompt and start generation."""
    agent_data.prompt_ids = await self.loop.run_in_executor(
        None,
        lambda: self.tokenizer.apply_chat_template(
            agent_data.messages,
            tools=self.tool_schemas,
            add_generation_prompt=True,
            tokenize=True,
            **self.apply_chat_template_kwargs,
        ),
    )
    return AgentState.GENERATING

LLM Generation

LLM Generation
  • The core step is generating from the current prompt_ids, e.g. by calling SGLangHttpServer
  • The LLM output is appended to messages and prompt_ids, and the turn counter is incremented
  • response_mask is set to 1 for these tokens, since they participate in training
  • Tool calls are parsed from the current model response; if any are found, the state switches to tool processing
python
async def _handle_generating_state(
    self, agent_data: AgentData, sampling_params: dict[str, Any], ignore_termination: bool = False
) -> AgentState:
    """Handle the generating state: generate model response and check for tool calls."""
    add_messages: list[dict[str, Any]] = []

    with simple_timer("generate_sequences", agent_data.metrics):
        output = await self.server_manager.generate( 
            request_id=agent_data.request_id, 
            prompt_ids=agent_data.prompt_ids, 
            sampling_params=sampling_params, 
            image_data=agent_data.image_data, 
        )

    agent_data.assistant_turns += 1 
    agent_data.response_ids = output.token_ids
    # Append this turn's LLM output to prompt_ids
    agent_data.prompt_ids += agent_data.response_ids
    # Set response_mask to 1 for all of them: LLM output tokens are trained on
    agent_data.response_mask += [1] * len(agent_data.response_ids)
    if output.log_probs:
        agent_data.response_logprobs += output.log_probs
		
    # (Code that handles over-length termination is omitted here.)
    
    # Extract tool calls from the response; if any exist, jump to the tool-processing state.
    _, agent_data.tool_calls = await self.tool_parser.extract_tool_calls(agent_data.response_ids) 

    # Decode the current assistant message and append it to the global messages; the trajectory gains one step
    assistant_message = await self.loop.run_in_executor(
                None, lambda: self.tokenizer.decode(agent_data.response_ids, skip_special_tokens=True) 
            )
    add_messages.append({"role": "assistant", "content": assistant_message}) 
    agent_data.messages.extend(add_messages) 

    # Determine next state
    if agent_data.tool_calls:
        # Jump to the tool-processing state
        return AgentState.PROCESSING_TOOLS
    else:
        # No tool calls: terminate (some Interaction-related handling is omitted here)
        return AgentState.TERMINATED

Environment Interaction / Tool Calls

Environment interaction
  • The environment's results are put into messages and appended to prompt_ids
  • response_mask is set to 0, so these tokens do not participate in training
python
async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentState:
    """Handle the processing tools state: execute tool calls and prepare tool responses."""
    add_messages: list[dict[str, Any]] = []
    new_images_this_turn: list[Any] = []  # Local variable instead of agent_data attribute

    tasks = []
    tool_call_names = []
    for tool_call in agent_data.tool_calls[: self.max_parallel_calls]:  
        tasks.append(self._call_tool(tool_call, agent_data.tools_kwargs)) 
        tool_call_names.append(tool_call.name)

    with simple_timer("tool_calls", agent_data.metrics): 
        responses = await asyncio.gather(*tasks) 

    # Process tool responses and update multi_modal_data
    # Removed: agent_data.new_images_this_turn = []
    for tool_response, tool_reward, _ in responses:
        # Create message from tool response
        if tool_response.image or tool_response.video:
            # Multi-modal content with structured format
            content = []
            if tool_response.image:
                content.append({"type": "image"})
            if tool_response.video:
                content.append({"type": "video"})
            if tool_response.text:
                content.append({"type": "text", "text": tool_response.text})
            message = {"role": "tool", "content": content}
        else:
            # Text-only content
            message = {"role": "tool", "content": tool_response.text or ""} 

        add_messages.append(message) 

        # Handle image data
        if tool_response.image:
           pass

        if tool_reward is not None:
            agent_data.tool_rewards.append(tool_reward)

    agent_data.messages.extend(add_messages) 
    # Tokenize the environment/tool messages into response_ids
    response_ids = await self.loop.run_in_executor( 
        None, 
        lambda: self.tokenizer.apply_chat_template(add_messages, add_generation_prompt=True, tokenize=True), 
    ) 
    response_ids = response_ids[len(self.system_prompt) :] 
    if len(agent_data.response_mask) + len(response_ids) >= self.response_length:
        return AgentState.TERMINATED
    
    # Update prompt_ids and response_mask; response_mask is set to 0 so these tokens are not trained on
    agent_data.prompt_ids += response_ids 
    agent_data.response_mask += [0] * len(response_ids) 
    if agent_data.response_logprobs:
        agent_data.response_logprobs += [0.0] * len(response_ids)
    agent_data.user_turns += 1
    return AgentState.GENERATING

User Interaction

Interaction System for Multi-turn RL Training

User interaction
  • Get the user's input
  • Append it to messages and prompt_ids
  • response_mask is set to 0, so these tokens do not participate in training
python
async def _handle_interacting_state(self, agent_data: AgentData) -> AgentState:
    """Handle the interacting state: get user input from interaction."""
    (
        should_terminate_sequence,
        interaction_responses,
        reward,
        metrics,
    ) = await agent_data.interaction.generate_response( 
        agent_data.request_id, agent_data.messages, **agent_data.interaction_kwargs 
    ) 
    agent_data.user_turns += 1

    add_messages: list[dict[str, Any]] = [{"role": "user", "content": interaction_responses}] 
    agent_data.messages.extend(add_messages) 

    if reward is not None:
        agent_data.turn_scores.append(reward)

    # Update prompt with user responses (similar to _handle_processing_tools_state)
    response_ids = await self.loop.run_in_executor( 
        None, 
        lambda: self.tokenizer.apply_chat_template(add_messages,  add_generation_prompt=True, tokenize=True), 
    ) 
    response_ids = response_ids[len(self.system_prompt) :] 

    # Update prompt_ids and response_mask
    agent_data.prompt_ids += response_ids 
    agent_data.response_mask += [0] * len(response_ids) 
    if agent_data.response_logprobs:
        agent_data.response_logprobs += [0.0] * len(response_ids)

    # double check prompt
    # Check termination condition
    if should_terminate_sequence:
        return AgentState.TERMINATED
    else:
        return AgentState.GENERATING

Task Evaluation

Evaluation can be done after the AgentLoop finishes, by calling the environment interface to compute reward_score; a sketch follows below.
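
A minimal, hedged sketch of that idea. `score_trajectory` is a hypothetical environment-side scoring callable, not a verl API; the only point is that the score is computed after the state machine terminates and then carried into AgentLoopOutput.reward_score (or left as None and filled in by the asynchronous reward path in Data Post-Processing 2):

python
from typing import Any, Optional


async def evaluate_trajectory(
    score_trajectory,                  # hypothetical async callable: messages -> float
    messages: list[dict[str, Any]],
) -> Optional[float]:
    """Score a finished trajectory; return None if scoring fails."""
    try:
        return await score_trajectory(messages)
    except Exception:
        # Leave the score as None; the async reward path in _run_agent_loop
        # (see Data Post-Processing 2) can still fill it in later.
        return None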

Handling Abnormal Trajectories

You can simply drop the sample and construct a fake agent_data instead.

  • Mind the lengths: prompt_ids must be longer than response_ids, and response_mask must have the same length as response_ids (see the sanity check after the snippet below).
python
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
agent_data.prompt_ids = [pad_token_id] * 400     # full sequence, longer than the response part
agent_data.response_ids = [pad_token_id] * 200
agent_data.response_mask = [pad_token_id] * 200  # same length as response_ids
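
As a sanity check on those lengths, a small sketch that mirrors the finalization slicing described in Data Post-Processing 1 below (the numbers are illustrative only):

python
# Purely illustrative numbers matching the fake agent_data above.
prompt_ids = [0] * 400      # full sequence (prompt + response)
response_mask = [0] * 200   # same length as response_ids

# This mirrors the finalization slicing in Data Post-Processing 1.
response_ids = prompt_ids[-len(response_mask):]
prompt_only = prompt_ids[: len(prompt_ids) - len(response_mask)]
assert len(response_ids) == len(response_mask) == 200
assert len(prompt_only) + len(response_ids) == len(prompt_ids)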

Data Post-Processing

Data Post-Processing 1: Inside the AgentLoop (truncate + extract, AgentLoopOutput)

Per-sample truncation inside the AgentLoop

Where it happens

  • Inside the concrete AgentLoop, after the state-machine loop finishes
  • Returns: AgentLoopOutput

Core logic

  • From the full prompt_ids sequence in AgentData:
    • slice out the real response_ids using the length of response_mask
    • keep the remaining prefix as the real prompt_ids, i.e. the prompt built from the initial messages
  • Truncate response_ids and response_mask to response_length; anything beyond that is dropped.
  • Wrap everything into an AgentLoopOutput object.
python
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
  # init message
  # agent loop
  # Finalize output
  response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :] 
  prompt_ids = agent_data.prompt_ids[: len(agent_data.prompt_ids) - len(agent_data.response_mask)] 
  output = AgentLoopOutput(
      prompt_ids=prompt_ids,
      response_ids=response_ids[: self.response_length], 
      response_mask=agent_data.response_mask[: self.response_length], 
      reward_score=reward_score,  # omit if not available
      multi_modal_data={},
      response_logprobs=agent_data.response_logprobs[: self.response_length]
      if agent_data.response_logprobs
      else None,
      num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
      metrics=agent_data.metrics,
      extra_fields={},
  )

Data Post-Processing 2: AgentLoopWorker (pad + concat, build _InternalAgentLoopOutput)

Pad + concat, build _InternalAgentLoopOutput

Where it happens

  • AgentLoopWorker._run_agent_loop: after the concrete AgentLoop finishes, it applies the generic padding and concatenation.
  • Returns: _InternalAgentLoopOutput

Core logic

  • Pad prompt, response, and response_mask
    • left-pad the prompt to rollout.prompt_length
    • right-pad the response to rollout.response_length
    • padding also returns an attention_mask: 0 at pad positions, 1 everywhere else
  • Correct response_mask using the response attention_mask
    • pad positions become 0, env_response positions are 0, only llm_response positions stay 1
  • Concatenate prompt and response along the sequence dimension
    • concatenate input_ids, and concatenate attention_mask
  • Reward may be computed asynchronously here if the environment did not return one.
python
async def _run_agent_loop(
  self,
  sampling_params: dict[str, Any],
  trajectory: dict[str, Any],
  *,
  agent_name: str,
  **kwargs,
) -> _InternalAgentLoopOutput:

  agent_loop_config = _agent_loop_registry[agent_name]
  agent_loop = hydra.utils.instantiate(
      config=agent_loop_config,
      trainer_config=_DummyConfig(config=self.config),
      server_manager=self.server_manager,
      tokenizer=self.tokenizer,
      processor=self.processor,
  )
  output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs) 
  output.extra_fields["raw_prompt"] = kwargs["raw_prompt"]

  self.tokenizer.padding_side = "left"
  # Left-pad the input prompt to the maximum prompt length
  prompt_output = self.tokenizer.pad( 
      {"input_ids": output.prompt_ids}, 
      padding="max_length",
      max_length=self.config.actor_rollout_ref.rollout.prompt_length,
      return_tensors="pt",
      return_attention_mask=True, 
  )
  if prompt_output["input_ids"].dim() == 1:
      prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0)
      # prompt attention_mask: pad_token positions get attention=0 so pad tokens are excluded from computation
      prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0)

  self.tokenizer.padding_side = "right"
  # Right-pad response_ids to the maximum response length
  response_output = self.tokenizer.pad( 
      {"input_ids": output.response_ids},
      padding="max_length",
      max_length=self.config.actor_rollout_ref.rollout.response_length,
      return_tensors="pt",
      return_attention_mask=True, 
  )
  if response_output["input_ids"].dim() == 1:
      response_output["input_ids"] = response_output["input_ids"].unsqueeze(0)
      # response attention_mask: pad_token positions get attention=0 so pad tokens are excluded from computation
      response_output["attention_mask"] = response_output["attention_mask"].unsqueeze(0)
  # Right-pad response_mask, same as response_ids
  response_mask_output = self.tokenizer.pad(
      {"input_ids": output.response_mask},
      padding="max_length",
      max_length=self.config.actor_rollout_ref.rollout.response_length,
      return_tensors="pt",
      return_attention_mask=False,
  )
  if response_mask_output["input_ids"].dim() == 1:
      response_mask_output["input_ids"] = response_mask_output["input_ids"].unsqueeze(0)

  # response_mask: pad positions are 0, env_response positions are 0, only llm_response positions are 1
  response_mask = response_mask_output["input_ids"] * response_output["attention_mask"] 
  # attention_mask for the whole sequence: [prompt_attention_mask, response_attention_mask]
  attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1) 
  input_ids = torch.cat([prompt_output["input_ids"], response_output["input_ids"]], dim=1) 

  position_ids = compute_position_id_with_mask(attention_mask)  # (1, seq_len)
  enable_async_reward = (
      self.reward_router_address is not None and self.config.reward_model.enable_resource_pool
  ) or not self.config.reward_model.enable
  if output.reward_score is None and enable_async_reward:
      batch = TensorDict(
          {
              "prompts": prompt_output["input_ids"],  # [1, prompt_length]
              "responses": response_output["input_ids"],  # [1, response_length]
              "attention_mask": attention_mask,  # [1, prompt_length + response_length]
              "input_ids": input_ids,  # [1, prompt_length + response_length]
              "position_ids": position_ids,
          },
          batch_size=1,
      )
      non_tensor_batch = {
          **{k: np.array([v]) for k, v in kwargs.items()},
          "__num_turns__": np.array([output.num_turns]),
          "tool_extra_fields": np.array([output.extra_fields], dtype=object),
      }

      data = DataProto(
          batch=batch,
          non_tensor_batch=non_tensor_batch,
      )
      # Compute the reward
      result = await self.reward_manager_worker.compute_score.remote(data)
      output.reward_score = result["reward_score"]
      output.extra_fields["reward_extra_info"] = result["reward_extra_info"]

  return _InternalAgentLoopOutput(
      prompt_ids=prompt_output["input_ids"], 
      response_ids=response_output["input_ids"], 
      input_ids=input_ids,
      position_ids=position_ids,
      response_mask=response_mask, 
      attention_mask=attention_mask, 
      response_logprobs=response_logprobs,
      multi_modal_inputs=multi_modal_inputs,
      multi_modal_data=output.multi_modal_data,
      reward_score=output.reward_score,
      num_turns=output.num_turns,
      metrics=output.metrics,
      extra_fields=output.extra_fields,
  )
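
For reference, position_ids in the snippet above come purely from the attention mask. To my understanding, verl's compute_position_id_with_mask amounts to a cumulative sum over the mask; the sketch below illustrates the idea and should be read as an approximation rather than the library's exact code:

python
import torch


def compute_position_id_with_mask(mask: torch.Tensor) -> torch.Tensor:
    # Cumulative count of real (non-pad) tokens, shifted to start at 0;
    # left-pad positions are clipped to 0 so they never become negative.
    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)


# Example: a left-padded prompt of length 5 with 2 pad tokens.
mask = torch.tensor([[0, 0, 1, 1, 1]])
print(compute_position_id_with_mask(mask))  # tensor([[0, 0, 0, 1, 2]])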

Data Post-Processing 3: AgentLoopWorker (batch assembly)

Batch assembly

Where it happens

  • AgentLoopWorker.generate_sequences, which returns a DataProto
  • After all of this worker's samples (agent loop tasks) finish, they are reassembled into a single batch.

Core flow

  • Assemble multiple samples back into one batch
    • prompt_ids, response_ids, response_mask, attention_mask, input_ids, position_ids
  • Build each sample's rm_scores
    • rm_scores has the same shape as response_mask
      • each sample's real response_length is computed from its attention_mask
    • reward_score is written only at the response_length position (the last real token); every other position is 0
    • every reward_score in the batch must have a value (not None), otherwise this step will error out
python
def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
    """Process the padded outputs from _run_agent_loop and combine them into a batch."""
    # Convert lists back to tensors and stack them to create a batch.
    prompt_ids = torch.cat([input.prompt_ids for input in inputs], dim=0) 
    response_ids = torch.cat([input.response_ids for input in inputs], dim=0) 
    response_mask = torch.cat([input.response_mask for input in inputs], dim=0) 
    attention_mask = torch.cat([input.attention_mask for input in inputs], dim=0) 
    input_ids = torch.cat([input.input_ids for input in inputs], dim=0) 
    position_ids = torch.cat([input.position_ids for input in inputs], dim=0) 
    optional_outputs = {}
    if inputs[0].response_logprobs is not None:
        optional_outputs["rollout_log_probs"] = torch.cat([input.response_logprobs for input in inputs], dim=0)
	
    batch = TensorDict(
        {
            "prompts": prompt_ids,  # [bsz, prompt_length]
            "responses": response_ids,  # [bsz, response_length]
            "response_mask": response_mask,  # [bsz, response_length]
            "input_ids": input_ids,  # [bsz, prompt_length + response_length
            "attention_mask": attention_mask,  # [bsz, prompt_length + response_length
            # position_ids: [bsz, 3, prompt_length + response_length] or [bsz, prompt_length + response_length] 
            "position_ids": position_ids
            **optional_outputs,
        },
        batch_size=len(inputs),
    )

    scores = [input.reward_score for input in inputs]
    if all(score is not None for score in scores): 
        prompt_length = prompt_ids.size(1) 
        # actual response length (index of each sample's last real response token)
        response_length = attention_mask[:, prompt_length:].sum(dim=1) - 1
        # rm_scores has the same shape as response_mask
        rm_scores = torch.zeros_like(response_mask, dtype=torch.float32) 
        # advanced indexing: write reward_score only at each sample's response_length position, i.e. its last token
        rm_scores[torch.arange(response_mask.size(0)), response_length] = torch.tensor(scores, dtype=torch.float32) 
        batch["rm_scores"] = rm_scores 
    print(f"batch scores: {scores}")
    non_tensor_batch = {
        "__num_turns__": np.array([input.num_turns for input in inputs], dtype=np.int32),
    }

    # add reward_extra_info to non_tensor_batch
    reward_extra_infos = [input.extra_fields.get("reward_extra_info", {}) for input in inputs]
    reward_extra_keys = list(reward_extra_infos[0].keys())
    for key in reward_extra_keys:
        non_tensor_batch[key] = np.array([info[key] for info in reward_extra_infos])

    # Add multi_modal_inputs to non_tensor_batch if any samples have them
    multi_modal_inputs_list = [input.multi_modal_inputs for input in inputs]
    if any(mmi is not None for mmi in multi_modal_inputs_list):
        non_tensor_batch["multi_modal_inputs"] = np.array(multi_modal_inputs_list, dtype=object)

    metrics = [input.metrics.model_dump() for input in inputs]
    # Collect extra fields from all inputs and convert them to np.ndarray
    extra_fields = {}
    all_keys = set(key for input_item in inputs for key in input_item.extra_fields)
    for key in all_keys:
        temp_arr = np.empty(len(inputs), dtype=object)
        temp_arr[:] = [input.extra_fields.get(key) for input in inputs]
        extra_fields[key] = temp_arr

    non_tensor_batch.update(extra_fields)
    return DataProto(
        batch=batch,
        non_tensor_batch=non_tensor_batch,
        meta_info={"metrics": metrics, "reward_extra_keys": reward_extra_keys},
    )

Rollout Invocation Flow

Data Preprocessing

For the data processing that happens before the DataLoader, see the data loading notes.

DataLoader Construction

DataLoader
python
self.train_dataloader = StatefulDataLoader( 
    dataset=self.train_dataset, 
    batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),  
    num_workers=num_workers,
    drop_last=True,
    collate_fn=collate_fn,
    sampler=train_sampler,
)

Reading the gen batch

_get_gen_batch
  • Selects input_ids, attention_mask, position_ids, etc. to be used for rollout
python
def _get_gen_batch(self, batch: DataProto) -> DataProto:
    reward_model_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys()

    # pop those keys for generation
    batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]  
    non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys
    gen_batch = batch.pop(
        batch_keys=batch_keys_to_pop,
        non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop),
    )
   
    # For agent loop, we need reward model keys to compute score.
    if self.async_rollout_mode:
        gen_batch.non_tensor_batch.update(batch.non_tensor_batch)

    return gen_batch

Rollout Entry Point (RayPPOTrainer.fit)

Rollout entry point (RayPPOTrainer.fit)

Related notes

Core points

  • The data batch has length train_batch_size, and each sample is assigned a uid.

  • _get_gen_batch is called to extract input_ids, attention_mask, position_ids, etc.

  • Based on rollout.n, the batch is repeated n times, so the batch length becomes:

    • train_batch_size * rollout.n
    • training batch size * number of rollouts
  • async_rollout_manager is called to do the generation and environment interaction

    • i.e. AgentLoopManager performs the generation
python
for epoch in range(current_epoch, self.config.trainer.total_epochs):  
    for batch_dict in self.train_dataloader:  
      	# ....
        # ....
      	batch: DataProto = DataProto.from_single_dict(batch_dict)
        # add uid to batch
        batch.non_tensor_batch["uid"] = np.array(
            [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
        )
        gen_batch = self._get_gen_batch(batch)
        # pass global_steps to trace
        gen_batch.meta_info["global_steps"] = self.global_steps
        gen_batch_output = gen_batch.repeat( 
            repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
        ) 
        if not self.async_rollout_mode:
            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
        else:
            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)   
        timing_raw.update(gen_batch_output.meta_info["timing"])
        gen_batch_output.meta_info.pop("timing", None)

AgentLoopManager Dispatches Generation to Multiple Workers

AgentLoopManager.generate_sequences
  • Batch length

    • train_batch_size * rollout.n
    • training batch size * number of rollouts
  • The batch is split into chunks according to the number of agent_loop_workers

    • Hyperparameter: actor_rollout_ref.rollout.agent.num_workers=64
  • Each worker handles one chunk

  • worker.generate_sequences is called on each chunk

python
def generate_sequences(self, prompts: DataProto) -> DataProto:
    """Split input batch and dispatch to agent loop workers.

    Args:
        prompts (DataProto): Input batch.

    Returns:
        DataProto: Output batch.
    """

    if self.config.actor_rollout_ref.rollout.free_cache_engine:
        self.wake_up()
    if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine:
        self.reward_model_manager.wake_up()

    chunkes = prompts.chunk(len(self.agent_loop_workers)) 
    outputs = ray.get(
        [
            worker.generate_sequences.remote(chunk) 
            for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True)
        ]
    )
    output = DataProto.concat(outputs)
    if self.config.actor_rollout_ref.rollout.free_cache_engine:
        self.sleep()
    if self.reward_model_manager and self.config.reward_model.rollout.free_cache_engine:
        self.reward_model_manager.sleep()

    # calculate performance metrics
    metrics = [output.meta_info.pop("metrics") for output in outputs]  # List[List[Dict[str, str]]]
    timing = self._performance_metrics(metrics, output)

    output.meta_info = {"timing": timing, **outputs[0].meta_info}
    return output

AgentLoopWorker Processes a Chunk and Generates Sequences

The worker submits multiple agent loop tasks asynchronously

The worker asynchronously launches multiple agent loop tasks

Data handled

  • Batch length

    • train_batch_size * rollout.n / len(agent_loop_workers)
  • A single worker handles one chunk of data; here one batch is one chunk

Core flow

  • Read the configured sampling_params for this batch, used for LLM generation
  • Submit one asynchronous agent loop task per sample in the chunk
  • After all tasks complete, post-process the results and assemble them back into a batch
    • See Data Post-Processing 3 above
python
async def generate_sequences(self, batch: DataProto) -> DataProto:
    """Generate sequences from agent loop.

    Args:
        batch (DataProto): Input batch.

    Returns:
        DataProto: Output batch.
        - prompts: [bsz, prompt_length], prompt token ids from dataset.
        - responses: [bsz, response_length], output token ids include response tokens
          from LLM generation and observation tokens from tool_calls.
        - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.
        - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens
          and response tokens.
        - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.
        - position_ids: [bsz, prompt_length + response_length], incremental position ids.

        For multi-turn conversations:
        responses:     |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|
        response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|
    """
    config = self.config.actor_rollout_ref.rollout
    sampling_params = dict(
        temperature=config.temperature, 
        top_p=config.top_p, 
        repetition_penalty=1.0, 
        logprobs=config.calculate_log_probs,
    )

    # override sampling params for validation
    if batch.meta_info.get("validate", False):
        sampling_params["top_p"] = config.val_kwargs.top_p
        sampling_params["temperature"] = config.val_kwargs.temperature

    # by default, we assume it's a single turn agent
    if "agent_name" not in batch.non_tensor_batch:
        default_agent_loop = config.agent.default_agent_loop
        batch.non_tensor_batch["agent_name"] = np.array([default_agent_loop] * len(batch), dtype=object)

    if "index" in batch.non_tensor_batch:
        index = batch.non_tensor_batch["index"]
    else:
        index = np.arange(len(batch))

    trajectory_info = await get_trajectory_info(
        batch.meta_info.get("global_steps", -1), index.tolist(), batch.meta_info.get("validate", False)
    )

    tasks = []
    for i in range(len(batch)): 
        kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()} 
        tasks.append(asyncio.create_task(self._run_agent_loop(sampling_params, trajectory_info[i], **kwargs))) 
    outputs = await asyncio.gather(*tasks)

    output = self._postprocess(outputs)  
    return output

Worker Runs an AgentLoop on a Single Sample

Worker runs an agent loop for a single sample / single task

Data handled

  • One sample

Core flow

  • Based on the sample's agent_name, load its config and instantiate the corresponding AgentLoop
  • Asynchronously call that agent_loop to do the actual environment interaction / LLM generation
    • See the "Basic AgentLoop - Interaction Flow" section above
  • After the AgentLoop finishes, post-process the data; see "Basic AgentLoop - Data Post-Processing 2" above
python
async def _run_agent_loop(
    self,
    sampling_params: dict[str, Any],
    trajectory: dict[str, Any],
    *,
    agent_name: str,
    **kwargs,
) -> _InternalAgentLoopOutput:
    with rollout_trace_attr(
        step=trajectory["step"],
        sample_index=trajectory["sample_index"],
        rollout_n=trajectory["rollout_n"],
        validate=trajectory["validate"],
        name="agent_loop",
    ):
        assert agent_name in _agent_loop_registry, (
            f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}"
        )

        agent_loop_config = _agent_loop_registry[agent_name] 
        agent_loop = hydra.utils.instantiate(
            config=agent_loop_config,
            trainer_config=_DummyConfig(config=self.config),
            server_manager=self.server_manager, 
            tokenizer=self.tokenizer,
            processor=self.processor,
        )
        kwargs["validate"] = trajectory["validate"]
        output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs) 
        # ....
        # Everything below is data post-processing, mainly building the _InternalAgentLoopOutput
        # See the full code above in Data Post-Processing 2
        # ....
        # ....
        return _InternalAgentLoopOutput(
            prompt_ids=prompt_output["input_ids"], 
            response_ids=response_output["input_ids"], 
            input_ids=input_ids,
            position_ids=position_ids,
            response_mask=response_mask, 
            attention_mask=attention_mask, 
            response_logprobs=response_logprobs,
            multi_modal_inputs=multi_modal_inputs,
            multi_modal_data=output.multi_modal_data,
            reward_score=output.reward_score,
            num_turns=output.num_turns,
            metrics=output.metrics,
            extra_fields=output.extra_fields,
        )

AgentLoopManager Initialization (SGLangHttpServer, AgentLoopWorker)

Initialization entry point (initializes the LLM servers and the agent loop workers)

AgentLoopManager initialization

Initializing the LLM servers

  • world_size = n_gpus_per_node * nnodes
  • Compute rollout_world_size = TP * DP * PP
    • rollout.tensor_model_parallel_size * rollout.data_parallel_size * rollout.pipeline_model_parallel_size
    • The latter two default to 1; I usually tune the TP value first
  • Compute the number of rollout replicas: world_size / rollout_world_size
    • i.e. the number of SGLangReplica instances
  • Example
    • Rollout TP=8, everything else left at the default of 1, so rollout_world_size=8
    • Training on 6 nodes x 8 GPUs gives 48 / 8 = 6 rollout replicas
    • When SGLangReplica is initialized below, with this configuration each rollout replica lives on exactly one node

Initializing the agent loop workers

  • Multiple workers are created according to agent.num_workers
python
class AgentLoopManager:
    """Agent loop manager that manages a group of agent loop workers."""

    def __init__(self, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None):
        """Initialize agent loop manager.

        Args:
            config (DictConfig): trainer config.
            worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode.
        """
        self.config = config
        self.worker_group = worker_group
        self.reward_model_manager = None
        self.reward_router_address = None
        if self.config.reward_model.enable and self.config.reward_model.enable_resource_pool:
            from verl.experimental.reward import RewardModelManager

            self.reward_model_manager = RewardModelManager(config.reward_model, rm_wg)
            self.reward_router_address = self.reward_model_manager.get_router_address()

        # for recipe to change
        if not hasattr(self, "rollout_replica_class"):
            self.rollout_replica_class = get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name)
        if not hasattr(self, "agent_loop_workers_class"):
            self.agent_loop_workers_class = AgentLoopWorker

        self._initialize_llm_servers() 
        self._init_agent_loop_workers() 

        # Initially we're in sleep mode.
        if self.config.actor_rollout_ref.rollout.free_cache_engine:
            self.sleep()

    def _initialize_llm_servers(self):
        rollout_world_size = (
            self.config.actor_rollout_ref.rollout.tensor_model_parallel_size 
            * self.config.actor_rollout_ref.rollout.data_parallel_size 
            * self.config.actor_rollout_ref.rollout.pipeline_model_parallel_size 
        )
        world_size = (
            self.worker_group.world_size
            if self.worker_group
            else self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes
        )
        num_replicas = world_size // rollout_world_size

        rollout_config = self.config.actor_rollout_ref.rollout
        model_config = self.config.actor_rollout_ref.model
        self.rollout_replicas = [
            self.rollout_replica_class(
                replica_rank=replica_rank,
                config=rollout_config,
                model_config=model_config,
                gpus_per_node=self.config.trainer.n_gpus_per_node,
            )
            for replica_rank in range(num_replicas)
        ]
        if self.worker_group:
            self._run_all([server.init_hybrid(self.worker_group) for server in self.rollout_replicas])
        else:
            self._run_all([server.init_standalone() for server in self.rollout_replicas])
        self.server_handles = [server._server_handle for server in self.rollout_replicas]
        self.server_addresses = [server._server_address for server in self.rollout_replicas]

    def _init_agent_loop_workers(self):
        self.agent_loop_workers = []
        num_workers = self.config.actor_rollout_ref.rollout.agent.num_workers

        node_ids = [node["NodeID"] for node in ray.nodes() if node["Alive"] and node["Resources"].get("CPU", 0) > 0]
        for i in range(num_workers):
            # Round-robin scheduling over the all nodes
            node_id = node_ids[i % len(node_ids)]  
            self.agent_loop_workers.append(
                self.agent_loop_workers_class.options(
                    name=f"agent_loop_worker_{i}",
                    scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                        node_id=node_id, soft=True
                    ),
                ).remote(self.config, self.server_handles, self.reward_router_address)
            )

LLM Server Initialization - SGLangReplica

LLM server initialization - SGLangReplica
  • launch_servers
    • i.e. it creates and initializes the SGLangHttpServer actors
  • With the configuration above, one replica corresponds to one http server (one node per replica)
python
class SGLangReplica(RolloutReplica):
    def get_ray_class_with_init_args(self) -> RayClassWithInitArgs:
        """Get rollout worker actor class for colocated and standalone mode."""
        worker_dict_cls = RayClassWithInitArgs(
            cls=_rollout_worker_actor_cls,
            config=self.config,
            model_config=self.model_config,
            device_mesh=None,
        )
        return worker_dict_cls

    async def launch_servers(self):
        """Launch http server in each node."""
        assert len(self.workers) == self.world_size, (
            f"worker number {len(self.workers)} not equal to world size {self.world_size}"
        )

        # get (node_id, CUDA_VISIBLE_DEVICES) of all workers
        worker_infos = await asyncio.gather(
            *[
                worker.__ray_call__.remote(
                    lambda self: (ray.get_runtime_context().get_node_id(), os.environ["CUDA_VISIBLE_DEVICES"])
                )
                for worker in self.workers
            ]
        )
        worker_cuda_visible_devices = [worker_info[1] for worker_info in worker_infos]
        worker_node_ids = [worker_info[0] for worker_info in worker_infos]

        # create server actor in each node with node affinity and cuda visible devices
        for node_rank in range(self.nnodes): 
            workers = self.workers[node_rank * self.gpus_per_node : (node_rank + 1) * self.gpus_per_node]
            node_cuda_visible_devices = ",".join(
                worker_cuda_visible_devices[node_rank * self.gpus_per_node : (node_rank + 1) * self.gpus_per_node]
            )
            node_id = worker_node_ids[node_rank * self.gpus_per_node]
            name = (
                f"sglang_server_{self.replica_rank}_{node_rank}"
                if not self.is_reward_model
                else f"sglang_server_reward_{self.replica_rank}_{node_rank}"
            )
            server = SGLangHttpServer.options(
                scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                    node_id=node_id,
                    soft=False,
                ),
                runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}},
                name=name,
            ).remote(
                config=self.config,
                model_config=self.model_config,
                rollout_mode=self.rollout_mode,
                workers=workers,
                replica_rank=self.replica_rank,
                node_rank=node_rank,
                nnodes=self.nnodes,
                cuda_visible_devices=node_cuda_visible_devices,
            )
            self.servers.append(server)

        # launch http server in each node
        master_address, master_port = await self.servers[0].get_master_address.remote()
        await asyncio.gather(
            *[
                server.launch_server.remote(master_address=master_address, master_port=master_port)
                for server in self.servers
            ]
        )

        # get http server address from first server
        server_address, server_port = await self.servers[0].get_server_address.remote()
        self._server_handle = self.servers[0]
        self._server_address = (
            f"[{server_address}]:{server_port}"
            if is_valid_ipv6_address(server_address)
            else f"{server_address}:{server_port}"
        )

Base Class RolloutReplica Initialization

Base class RolloutReplica init

Core flow

  • Compute world_size: rollout.tp * rollout.dp * rollout.pp
  • Recompute gpus_per_node: the smaller of gpus_per_node and world_size
  • Recompute nnodes: world_size // gpus_per_node

Example

  • rollout.tensor_model_parallel_size = 8, everything else left at the default of 1
  • Training hardware: 6 nodes x 8 GPUs
  • Then world_size=8, gpus_per_node=8, nnodes = 8 // 8 = 1
  • Each replica lives on a single node; in practice there are multiple such replicas.
python
class RolloutReplica(ABC):
    """Rollout replica is an individual server instance, which may be deployed on single or multiple nodes.
    It is equivalent to launch server in each node with command line:

    SGLang:
    ```
    python -m sglang.launch_server --node-rank 0 --nnode 2 ...
    python -m sglang.launch_server --node-rank 1 --nnode 2 ...
    ```

    vLLM:
    ```
    vllm serve --data-parallel-size 16 --data-parallel-size-local 8 --data-parallel-start-rank 0 ...
    vllm serve --data-parallel-size 16 --data-parallel-size-local 8 --data-parallel-start-rank 8 ...
    ```

    Args:
        replica_rank: int, rank of this rollout replica.
        config: RolloutConfig, full config.
        gpus_per_node: int, number of gpus per node.
    """

    def __init__(
        self,
        replica_rank: int,
        config: RolloutConfig,
        model_config: HFModelConfig,
        gpus_per_node: int = 8,
        is_reward_model: bool = False,
    ) -> None:
        self.replica_rank = replica_rank
        self.config = omega_conf_to_dataclass(config)
        self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)

        self.world_size = (
            self.config.tensor_model_parallel_size
            * self.config.data_parallel_size
            * self.config.pipeline_model_parallel_size
        )
        self.gpus_per_node = min(gpus_per_node, self.world_size)
        assert self.world_size % self.gpus_per_node == 0, (
            f"world_size {self.world_size} must be divisible by gpus_per_node {self.gpus_per_node}"
        )
        self.nnodes = self.world_size // self.gpus_per_node
        self.is_reward_model = is_reward_model

        self.rollout_mode: RolloutMode = None
        self.workers: list[ActorHandle] = []
        self.resource_pool: RayResourcePool = None

        self.servers: list[ActorHandle] = []
        self._server_address: str = None
        self._server_handle: ActorHandle = None

LLM server initialization - SGLangHttpServer (SGLangReplica)

SGLangHttpServer initialization

Launch Server

  • Load the server by passing in the HF model_config and the rollout config

generate

python
@ray.remote(num_cpus=1)
class SGLangHttpServer:
    """SGLang http server in single node, this is equivalent to launch server with command line:
python -m sglang.launch_server --node-rank 0 --nnode 1 ...
```

Args:
    config (DictConfig): full config.
    rollout_mode (RolloutMode): rollout mode.
    replica_rank (int): replica rank, a replica may contain multiple nodes.
    node_rank (int): node rank.
    nnodes (int): number of nodes.
    cuda_visible_devices (str): cuda visible devices.
"""

    def __init__(
        self,
        config: RolloutConfig,
        model_config: HFModelConfig,
        rollout_mode: RolloutMode,
        workers: list[ActorHandle],
        replica_rank: int,
        node_rank: int,
        nnodes: int,
        cuda_visible_devices: str,
    ):
        print(f"SGLang http server: {rollout_mode=}, {replica_rank=}, {node_rank=}, {nnodes=}, {cuda_visible_devices=}")
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
        assert torch.cuda.is_available(), "SGLang http server should run on GPU node"

        self.config: RolloutConfig = omega_conf_to_dataclass(config)
        self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)
        self.config.max_model_len = self.config.prompt_length + self.config.response_length
        self.rollout_mode = rollout_mode
        self.workers = workers

        self.replica_rank = replica_rank
        self.node_rank = node_rank
        self.nnodes = nnodes

        if self.rollout_mode != RolloutMode.HYBRID and self.config.load_format == "dummy":
            logger.warning(f"rollout mode is {self.rollout_mode}, load_format is dummy, set to auto")
            self.config.load_format = "auto"

        # used for http server
        self._server_address = ray.util.get_node_ip_address().strip("[]")
        self._server_port = None

        # used for NCCL process group
        if self.node_rank == 0:
            self._master_address = self._server_address
            self._master_port, self._master_sock = get_free_port(self._server_address)
            logger.info(
                f"SGLangHttpServer, replica_rank: {self.replica_rank}, "
                f"master address: {self._master_address}, port: {self._master_port}"
            )
        else:
            self._master_address = None
            self._master_port = None

    async def launch_server(self, master_address: str = None, master_port: int = None):
        if self.node_rank != 0:
            assert master_address and master_port, "non-master node should provide master address and port"
            self._master_address = master_address
            self._master_port = master_port

        engine_kwargs = self.config.get("engine_kwargs", {}).get("sglang", {}) or {}
        attention_backend = engine_kwargs.pop("attention_backend", None)
        dist_init_addr = (
            f"[{self._master_address}]:{self._master_port}"
            if is_valid_ipv6_address(self._master_address)
            else f"{self._master_address}:{self._master_port}"
        )

        args = {
            "model_path": self.model_config.local_path,
            "dtype": self.config.dtype,
            "mem_fraction_static": self.config.gpu_memory_utilization,
            "disable_cuda_graph": self.config.enforce_eager,
            "enable_memory_saver": True,
            "base_gpu_id": 0,
            "gpu_id_step": 1,
            "tp_size": self.config.tensor_model_parallel_size,
            "dp_size": self.config.data_parallel_size,
            "ep_size": self.config.expert_parallel_size,
            "node_rank": self.node_rank,
            "load_format": self.config.load_format,
            "dist_init_addr": dist_init_addr,
            "nnodes": self.nnodes,
            "trust_remote_code": self.model_config.trust_remote_code,
            "max_running_requests": self.config.get("max_num_seqs", None),
            "log_level": "error",
            "mm_attention_backend": "fa3",
            "attention_backend": attention_backend if attention_backend is not None else "fa3",
            "skip_tokenizer_init": self.config.skip_tokenizer_init,
            **engine_kwargs,
        }
        # enable_weights_cpu_backup is supported in sglang>=0.5.3
        if "enable_weights_cpu_backup" in [f.name for f in dataclasses.fields(ServerArgs)]:
            enable_weights_cpu_backup = True if self.rollout_mode == RolloutMode.COLOCATED else False
            args["enable_weights_cpu_backup"] = enable_weights_cpu_backup

        # NOTE: We can't directly call SGLang's launch_server since it's not an async function.
        # https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py
        sglang.srt.entrypoints.engine._set_envs_and_config = _set_envs_and_config
        os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
        server_args = ServerArgs(**args)
        self.tokenizer_manager, self.template_manager, self.scheduler_info = _launch_subprocesses(
            server_args=server_args
        )

        # In multi-node cases, non-zero rank nodes should not launch http server.
        if self.node_rank > 0:
            return

        set_global_state(
            _GlobalState(
                tokenizer_manager=self.tokenizer_manager,
                template_manager=self.template_manager,
                scheduler_info=self.scheduler_info,
            )
        )
        app.is_single_tokenizer_mode = True
        self._server_port, self._server_task = await run_unvicorn(app, server_args, self._server_address)
        self.tokenizer_manager.server_status = ServerStatus.Up

    async def generate(
        self,
        prompt_ids: torch.Tensor,
        sampling_params: dict[str, Any],
        request_id: str,
        image_data: Optional[list[Any]] = None,
    ) -> TokenOutput:
        """Generate sequence with token-in-token-out."""
        # TODO(@wuxibin): switch to `/generate` http endpoint once multi-modal support ready.
        max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(prompt_ids) - 1)
        sampling_params["max_new_tokens"] = max_new_tokens
        return_logprob = sampling_params.pop("logprobs", False)

        request = GenerateReqInput(
            rid=request_id,
            input_ids=prompt_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            image_data=image_data,
        )
        output = await self.tokenizer_manager.generate_request(request, None).__anext__()
        if return_logprob:
            output_token_logprobs = output["meta_info"]["output_token_logprobs"]
            log_probs, token_ids = zip(
                *[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs], strict=True
            )
        else:
            token_ids = output["output_ids"]
            log_probs = None
        return TokenOutput(token_ids=token_ids, log_probs=log_probs)
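
For orientation, a minimal sketch (hypothetical sampling params and request id; server stands for an SGLangHttpServer actor handle) of how an agent-loop-side caller might drive generate:

python
import uuid

async def one_llm_turn(server, prompt_ids: list[int]):
    # generate is an async actor method, so the ObjectRef from .remote() can be awaited directly
    sampling_params = {"temperature": 1.0, "top_p": 1.0, "logprobs": True}
    out = await server.generate.remote(
        prompt_ids=prompt_ids,
        sampling_params=sampling_params,
        request_id=str(uuid.uuid4()),
    )
    # out.token_ids: generated token ids; out.log_probs: per-token logprobs when logprobs=True
    return list(out.token_ids), out.log_probs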