Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/advance/agent_loop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ could do whatever user wants, such as

class AgentLoopBase(ABC):
@abstractmethod
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
"""Run agent loop to interact with LLM server and environment.

Args:
messages (List[Dict[str, Any]]): Input messages.
sampling_params (Dict[str, Any]): LLM sampling params.
**kwargs: dataset fields from `verl.utils.dataset.RLHFDataset`.

Returns:
AgentLoopOutput: Agent loop output.
Expand Down
4 changes: 3 additions & 1 deletion recipe/langgraph_agent/react_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def build_graph(cls) -> StateGraph:
graph = workflow.compile()
return graph

async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
messages = list(kwargs["raw_prompt"])

model_path = self.config.actor_rollout_ref.model.path
model_name = "/".join(model_path.split("/")[-2:])

Expand Down
1 change: 1 addition & 0 deletions recipe/langgraph_agent/test_react_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def init_config() -> DictConfig:
config.actor_rollout_ref.rollout.n = 4
config.actor_rollout_ref.rollout.agent.num_workers = 2

config.actor_rollout_ref.actor.use_dynamic_bsz = True
# test sleep/wake_up with fsdp offload
config.actor_rollout_ref.actor.fsdp_config.param_offload = True
config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
Expand Down
27 changes: 11 additions & 16 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,12 @@ def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer, **kwargs):
cls._class_initialized = True

@abstractmethod
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
"""Run agent loop to interact with LLM server and environment.

Args:
messages (List[Dict[str, Any]]): Input messages.
sampling_params (Dict[str, Any]): LLM sampling params.
**kwargs: dataset fields from `verl.utils.dataset.RLHFDataset`.

Returns:
AgentLoopOutput: Agent loop output.
Expand Down Expand Up @@ -285,36 +285,31 @@ async def generate_sequences(self, batch: DataProto) -> DataProto:
if "agent_name" not in batch.non_tensor_batch:
batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object)

tasks = []
agent_names = batch.non_tensor_batch["agent_name"]
raw_prompts = batch.non_tensor_batch["raw_prompt"]
if "index" in batch.non_tensor_batch:
index = batch.non_tensor_batch["index"]
else:
index = np.arange(len(raw_prompts))
index = np.arange(len(batch))

trajectory_info = await get_trajectory_info(
batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False)
)

for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True):
if not isinstance(messages, list | np.ndarray):
raise TypeError(f"messages must be a list or numpy array, got {type(messages)}")

tasks.append(
asyncio.create_task(self._run_agent_loop(agent_name, list(messages), sampling_params, trajectory))
)
tasks = []
for i in range(len(batch)):
kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}
tasks.append(asyncio.create_task(self._run_agent_loop(sampling_params, trajectory_info[i], **kwargs)))
outputs = await asyncio.gather(*tasks)

output = self._postprocess(outputs)
return output

async def _run_agent_loop(
self,
agent_name: str,
messages: list[dict[str, Any]],
sampling_params: dict[str, Any],
trajectory: dict[str, Any],
*,
agent_name: str,
**kwargs,
) -> _InternalAgentLoopOutput:
with rollout_trace_attr(
step=trajectory["step"],
Expand All @@ -334,7 +329,7 @@ async def _run_agent_loop(
server_manager=self.server_manager,
tokenizer=self.tokenizer,
)
output = await agent_loop.run(messages, sampling_params)
output = await agent_loop.run(sampling_params, **kwargs)

# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
# prompt_ids: left padded with zeros (e.g., [0,0,0,0,1,2,3,4])
Expand Down
4 changes: 3 additions & 1 deletion verl/experimental/agent_loop/single_turn_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def __init__(self, *args, **kwargs):
self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length
self.response_length = self.config.actor_rollout_ref.rollout.response_length

async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
messages = list(kwargs["raw_prompt"])

metrics = {}
request_id = uuid4().hex
prompt_ids = await self.loop.run_in_executor(
Expand Down
3 changes: 2 additions & 1 deletion verl/experimental/agent_loop/tool_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def init_class(cls, config, tokenizer, **kwargs):
cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)

@rollout_trace_op
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
messages = list(kwargs["raw_prompt"])
metrics = {}
request_id = uuid4().hex
prompt_ids = await self.loop.run_in_executor(
Expand Down
Loading