Merged
Changes from 4 commits
128 changes: 80 additions & 48 deletions verl/experimental/agent_loop/agent_loop.py
@@ -115,16 +115,19 @@ class AgentLoopOutput(BaseModel):
"""Agent loop output."""

prompt_ids: list[int]
"""Prompt token ids."""
"""Prompt token ids. Is overwritten with padded ids in AgentLoopWorker."""
response_ids: list[int]
"""Response token ids including LLM generated token, tool response token."""
"""Response token ids. Is overwritten with padded ids in AgentLoopWorker."""
response_mask: list[int]
"""Response mask, 1 for LLM generated token, 0 for tool response token."""
"""Response mask. Is overwritten with padded mask in AgentLoopWorker."""
num_turns: int = 0
"""Number of chat turns, including user, assistant, tool."""
metrics: AgentLoopMetrics
"""Auxiliary performance metrics"""

# Populated in AgentLoopWorker. Made optional to prevent validation errors.
attention_mask: list[int] = None


# make hydra.utils.instantiate happy
class _DummyConfig:
@@ -318,54 +321,83 @@ async def _run_agent_loop(
tokenizer=self.tokenizer,
)
output = await agent_loop.run(messages, sampling_params)

# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
# prompt_ids: left padded with zeros (e.g., [0,0,0,0,1,2,3,4])
# response_ids: right padded with zeros (e.g., [5,6,7,8,0,0,0,0])
# input_ids: concatenation of prompt + response
# Mask:
# For example, if the prompt is [1,2,3,4] and the response is [5,6,7,8,9,10,11,(tool start)12,13(tool end),14,15]
# - prompt_attention_mask: 0s for padding, 1s for tokens
# e.g., [0,0,0,0,1,1,1,1]
# - response_attention_mask: 0s for padding, 1s for tokens
# e.g., [1,1,1,1,1,1,1,1,1,1,1,0,0,0,0]
# attention_mask: concatenation of prompt_attention_mask and response_attention_mask
# e.g., [0,0,0,0,1,1,1,1(prompt),1,1,1,1,1,1,1,1,1,1,1,0,0,0,0(response)]
# - response_mask: 1s for LLM generated tokens, 0 for tool response/padding tokens
# e.g., [1,1,1,1,1,1,1,(tool start)0,0(tool end),1,1,0,0,0,0]
# - position_ids: sequential positions for tokens, starting at 0
# e.g., [0,0,0,0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,0,0,0,0]

self.tokenizer.padding_side = "left"
prompt_output = self.tokenizer.pad(
{"input_ids": output.prompt_ids},
padding="max_length",
max_length=self.config.actor_rollout_ref.rollout.prompt_length,
return_tensors="pt",
return_attention_mask=True,
)

# Ensure we have a batch dimension
if prompt_output["input_ids"].dim() == 1:
prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0)
prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0)

self.tokenizer.padding_side = "right"
response_output = self.tokenizer.pad(
{"input_ids": output.response_ids},
padding="max_length",
max_length=self.config.actor_rollout_ref.rollout.response_length,
return_tensors="pt",
return_attention_mask=True,
)

if response_output["input_ids"].dim() == 1:
response_output["input_ids"] = response_output["input_ids"].unsqueeze(0)
response_output["attention_mask"] = response_output["attention_mask"].unsqueeze(0)

response_mask_output = self.tokenizer.pad(
{"input_ids": output.response_mask},
padding="max_length",
max_length=self.config.actor_rollout_ref.rollout.response_length,
return_tensors="pt",
return_attention_mask=False,
)

if response_mask_output["input_ids"].dim() == 1:
response_mask_output["input_ids"] = response_mask_output["input_ids"].unsqueeze(0)

response_mask = response_mask_output["input_ids"] * response_output["attention_mask"]
attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1)

# Overwrite with padded data, converted to lists for safe serialization.
output.prompt_ids = prompt_output["input_ids"].squeeze(0).tolist()
Collaborator: We can keep tensor as it is.

Collaborator: _run_agent_loop can return other class instead of AgentLoopOutput

Contributor Author: fixed! Thx

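A minimal sketch of the reviewer's suggestion, under the assumption that the public AgentLoopOutput should stay unpadded: _run_agent_loop could return an internal subclass that carries the padded fields. The class name and layout below are illustrative, not code from this PR.

# Hypothetical internal model illustrating the review suggestion above;
# the name _PaddedAgentLoopOutput is an assumption, not part of this PR.
class _PaddedAgentLoopOutput(AgentLoopOutput):
    """Internal-only result of _run_agent_loop carrying padded, fixed-length fields."""

    attention_mask: list[int]
    """Concatenated prompt + response attention mask after padding."""

With such a class, the `attention_mask: list[int] = None` workaround on the public model would be unnecessary, since only the internal type declares the extra field.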
output.response_ids = response_output["input_ids"].squeeze(0).tolist()
output.response_mask = response_mask.squeeze(0).tolist()
output.attention_mask = attention_mask.squeeze(0).tolist()

return output

def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto:
# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
# prompts: left pad
# responses: right pad
# input_ids: prompt + response
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]

# prompts
self.tokenizer.padding_side = "left"
outputs = self.tokenizer.pad(
[{"input_ids": input.prompt_ids} for input in inputs],
padding="max_length",
max_length=self.config.actor_rollout_ref.rollout.prompt_length,
return_tensors="pt",
return_attention_mask=True,
)
prompt_ids, prompt_attention_mask = outputs["input_ids"], outputs["attention_mask"]

# responses
self.tokenizer.padding_side = "right"
outputs = self.tokenizer.pad(
[{"input_ids": input.response_ids} for input in inputs],
padding="max_length",
max_length=self.config.actor_rollout_ref.rollout.response_length,
return_tensors="pt",
return_attention_mask=True,
)
response_ids, response_attention_mask = outputs["input_ids"], outputs["attention_mask"]

# response_mask
outputs = self.tokenizer.pad(
[{"input_ids": input.response_mask} for input in inputs],
padding="max_length",
max_length=self.config.actor_rollout_ref.rollout.response_length,
return_tensors="pt",
return_attention_mask=False,
)
response_mask = outputs["input_ids"]
assert response_ids.shape == response_mask.shape, (
f"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}"
)
response_mask = response_mask * response_attention_mask

"""Process the padded outputs from _run_agent_loop and combine them into a batch."""
# Convert lists back to tensors and stack them to create a batch.
prompt_ids = torch.stack([torch.tensor(i.prompt_ids, dtype=torch.long) for i in inputs], dim=0)
response_ids = torch.stack([torch.tensor(i.response_ids, dtype=torch.long) for i in inputs], dim=0)
response_mask = torch.stack([torch.tensor(i.response_mask, dtype=torch.long) for i in inputs], dim=0)
attention_mask = torch.stack([torch.tensor(i.attention_mask, dtype=torch.long) for i in inputs], dim=0)

# Derive input_ids and position_ids on the fly.
input_ids = torch.cat([prompt_ids, response_ids], dim=1)
attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1)
position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask

batch = TensorDict(
@@ -377,7 +409,7 @@ def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto:
"attention_mask": attention_mask, # [bsz, prompt_length + response_length]
"position_ids": position_ids, # [bsz, prompt_length + response_length]
},
batch_size=len(input_ids),
batch_size=len(inputs),
)

num_turns = np.array([input.num_turns for input in inputs], dtype=np.int32)
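To make the padding and position-id convention described in the NOTE comments above concrete, here is a standalone toy check (all lengths and mask values are made up for illustration; only torch is required):

import torch

# Prompt padded on the left to length 8, response padded on the right to length 15,
# matching the example in the NOTE: 4 prompt tokens and 11 response tokens.
prompt_attention_mask = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1]])
response_attention_mask = torch.tensor([[1] * 11 + [0] * 4])

# response_mask keeps 1s only for LLM-generated tokens; tool-response and
# padding positions are zeroed out by the element-wise product.
llm_vs_tool_mask = torch.tensor([[1] * 7 + [0, 0] + [1, 1] + [0] * 4])
response_mask = llm_vs_tool_mask * response_attention_mask

attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1)
position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask

print(position_ids.tolist())
# [[0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0, 0]]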