diff --git a/docs/advance/agent_loop.rst b/docs/advance/agent_loop.rst index d5fb5442d18..cb07c62f573 100644 --- a/docs/advance/agent_loop.rst +++ b/docs/advance/agent_loop.rst @@ -44,12 +44,12 @@ could do whatever user wants, such as class AgentLoopBase(ABC): @abstractmethod - async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: """Run agent loop to interact with LLM server and environment. Args: - messages (List[Dict[str, Any]]): Input messages. sampling_params (Dict[str, Any]): LLM sampling params. + **kwargs: dataset fields from `verl.utils.dataset.RLHFDataset`. Returns: AgentLoopOutput: Agent loop output. diff --git a/recipe/langgraph_agent/react_agent_loop.py b/recipe/langgraph_agent/react_agent_loop.py index 578968a92cc..fcb6aa00a8c 100644 --- a/recipe/langgraph_agent/react_agent_loop.py +++ b/recipe/langgraph_agent/react_agent_loop.py @@ -99,7 +99,9 @@ def build_graph(cls) -> StateGraph: graph = workflow.compile() return graph - async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: + messages = list(kwargs["raw_prompt"]) + model_path = self.config.actor_rollout_ref.model.path model_name = "/".join(model_path.split("/")[-2:]) diff --git a/recipe/langgraph_agent/test_react_agent_loop.py b/recipe/langgraph_agent/test_react_agent_loop.py index 0cdc919593a..4dfed31e572 100644 --- a/recipe/langgraph_agent/test_react_agent_loop.py +++ b/recipe/langgraph_agent/test_react_agent_loop.py @@ -41,6 +41,7 @@ def init_config() -> DictConfig: config.actor_rollout_ref.rollout.n = 4 config.actor_rollout_ref.rollout.agent.num_workers = 2 + config.actor_rollout_ref.actor.use_dynamic_bsz = True # test sleep/wake_up with fsdp offload config.actor_rollout_ref.actor.fsdp_config.param_offload = True config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index 3239e2963b0..0f5298e86de 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -183,12 +183,12 @@ def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer, **kwargs): cls._class_initialized = True @abstractmethod - async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: """Run agent loop to interact with LLM server and environment. Args: - messages (List[Dict[str, Any]]): Input messages. sampling_params (Dict[str, Any]): LLM sampling params. + **kwargs: dataset fields from `verl.utils.dataset.RLHFDataset`. Returns: AgentLoopOutput: Agent loop output. @@ -285,25 +285,19 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: if "agent_name" not in batch.non_tensor_batch: batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object) - tasks = [] - agent_names = batch.non_tensor_batch["agent_name"] - raw_prompts = batch.non_tensor_batch["raw_prompt"] if "index" in batch.non_tensor_batch: index = batch.non_tensor_batch["index"] else: - index = np.arange(len(raw_prompts)) + index = np.arange(len(batch)) trajectory_info = await get_trajectory_info( batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False) ) - for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True): - if not isinstance(messages, list | np.ndarray): - raise TypeError(f"messages must be a list or numpy array, got {type(messages)}") - - tasks.append( - asyncio.create_task(self._run_agent_loop(agent_name, list(messages), sampling_params, trajectory)) - ) + tasks = [] + for i in range(len(batch)): + kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()} + tasks.append(asyncio.create_task(self._run_agent_loop(sampling_params, trajectory_info[i], **kwargs))) outputs = await asyncio.gather(*tasks) output = self._postprocess(outputs) @@ -311,10 +305,11 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: async def _run_agent_loop( self, - agent_name: str, - messages: list[dict[str, Any]], sampling_params: dict[str, Any], trajectory: dict[str, Any], + *, + agent_name: str, + **kwargs, ) -> _InternalAgentLoopOutput: with rollout_trace_attr( step=trajectory["step"], @@ -334,7 +329,7 @@ async def _run_agent_loop( server_manager=self.server_manager, tokenizer=self.tokenizer, ) - output = await agent_loop.run(messages, sampling_params) + output = await agent_loop.run(sampling_params, **kwargs) # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py # prompt_ids: left padded with zeros (e.g., [0,0,0,0,1,2,3,4]) diff --git a/verl/experimental/agent_loop/single_turn_agent_loop.py b/verl/experimental/agent_loop/single_turn_agent_loop.py index 411388e7321..ea9ad77d7d5 100644 --- a/verl/experimental/agent_loop/single_turn_agent_loop.py +++ b/verl/experimental/agent_loop/single_turn_agent_loop.py @@ -32,7 +32,9 @@ def __init__(self, *args, **kwargs): self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length self.response_length = self.config.actor_rollout_ref.rollout.response_length - async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: + messages = list(kwargs["raw_prompt"]) + metrics = {} request_id = uuid4().hex prompt_ids = await self.loop.run_in_executor( diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py index 3437c0be5ab..1f6208dce90 100644 --- a/verl/experimental/agent_loop/tool_agent_loop.py +++ b/verl/experimental/agent_loop/tool_agent_loop.py @@ -56,7 +56,8 @@ def init_class(cls, config, tokenizer, **kwargs): cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True) @rollout_trace_op - async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: + messages = list(kwargs["raw_prompt"]) metrics = {} request_id = uuid4().hex prompt_ids = await self.loop.run_in_executor(