Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 81 additions & 46 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class AgentLoopOutput(BaseModel):
"""Number of chat turns, including user, assistant, tool."""
metrics: AgentLoopMetrics
"""Auxiliary performance metrics"""
processed_tensors: dict = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can save the padded ids directly in `prompt_ids`, `response_ids`, and `response_mask`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But what about tensors like `attention_mask`?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My fault — they can all be computed from those three tensors. Should we move the calculation into the postprocess?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we should move pad logic of prompt_ids, response_ids, response_mask in _postprocess to _run_agent_loop.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I will do this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

"""Pre-processed tensors for batching."""


# make hydra.utils.instantiate happy
Expand Down Expand Up @@ -318,55 +320,88 @@ async def _run_agent_loop(
tokenizer=self.tokenizer,
)
output = await agent_loop.run(messages, sampling_params)

# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
# prompt_ids: left padded with zeros (e.g., [0,0,0,0,1,2,3,4])
# response_ids: right padded with zeros (e.g., [5,6,7,8,0,0,0,0])
# input_ids: concatenation of prompt + response
# Mask:
# For example, if the prompt is [1,2,3,4] and the response is [5,6,7,(tool start)8,9(tool end),10,11,12]
# - prompt_attention_mask: 0s for padding, 1s for tokens
# e.g., [0,0,0,0,1,1,1,1]
# - response_attention_mask: 0s for padding, 1s for tokens
# e.g., [1,1,1,1,1,1,1,1,1,1,1,0,0,0,0]
# attention_mask: concatenation of prompt_attention_mask and response_attention_mask
# e.g., [0,0,0,0,1,1,1,1(prompt),1,1,1,1,1,1,1,1,1,1,1,0,0,0,0(response)]
# - response_mask: 1s for LLM generated tokens, 0 for tool response/padding tokens
# e.g., [1,1,1,1,1,1,1,(tool start),0,0(tool end),1,1,0,0,0,0]
# - position_ids: sequential positions for tokens, starting at 0
# e.g., [0,0,0,0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,0,0,0,0]

self.tokenizer.padding_side = "left"
prompt_output = self.tokenizer.pad(
{"input_ids": output.prompt_ids},
padding="max_length",
max_length=self.config.actor_rollout_ref.rollout.prompt_length,
return_tensors="pt",
return_attention_mask=True,
)

# Ensure we have a batch dimension
if prompt_output["input_ids"].dim() == 1:
prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0)
prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0)

self.tokenizer.padding_side = "right"
response_output = self.tokenizer.pad(
{"input_ids": output.response_ids},
padding="max_length",
max_length=self.config.actor_rollout_ref.rollout.response_length,
return_tensors="pt",
return_attention_mask=True,
)

if response_output["input_ids"].dim() == 1:
response_output["input_ids"] = response_output["input_ids"].unsqueeze(0)
response_output["attention_mask"] = response_output["attention_mask"].unsqueeze(0)

response_mask_output = self.tokenizer.pad(
{"input_ids": output.response_mask},
padding="max_length",
max_length=self.config.actor_rollout_ref.rollout.response_length,
return_tensors="pt",
return_attention_mask=False,
)

if response_mask_output["input_ids"].dim() == 1:
response_mask_output["input_ids"] = response_mask_output["input_ids"].unsqueeze(0)

response_mask = response_mask_output["input_ids"] * response_output["attention_mask"]

input_ids = torch.cat([prompt_output["input_ids"], response_output["input_ids"]], dim=1)
attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1)
position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask

output.processed_tensors = {
"prompt_ids": prompt_output["input_ids"],
"response_ids": response_output["input_ids"],
"response_mask": response_mask,
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}

return output

def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto:
# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
# prompts: left pad
# responses: right pad
# input_ids: prompt + response
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]

# prompts
self.tokenizer.padding_side = "left"
outputs = self.tokenizer.pad(
[{"input_ids": input.prompt_ids} for input in inputs],
padding="max_length",
max_length=self.config.actor_rollout_ref.rollout.prompt_length,
return_tensors="pt",
return_attention_mask=True,
)
prompt_ids, prompt_attention_mask = outputs["input_ids"], outputs["attention_mask"]

# responses
self.tokenizer.padding_side = "right"
outputs = self.tokenizer.pad(
[{"input_ids": input.response_ids} for input in inputs],
padding="max_length",
max_length=self.config.actor_rollout_ref.rollout.response_length,
return_tensors="pt",
return_attention_mask=True,
)
response_ids, response_attention_mask = outputs["input_ids"], outputs["attention_mask"]

# response_mask
outputs = self.tokenizer.pad(
[{"input_ids": input.response_mask} for input in inputs],
padding="max_length",
max_length=self.config.actor_rollout_ref.rollout.response_length,
return_tensors="pt",
return_attention_mask=False,
)
response_mask = outputs["input_ids"]
assert response_ids.shape == response_mask.shape, (
f"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}"
)
response_mask = response_mask * response_attention_mask

input_ids = torch.cat([prompt_ids, response_ids], dim=1)
attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1)
position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask
"""Process the padded outputs from _run_agent_loop and combine them into a batch."""
# Extract pre-processed tensors from each output
prompt_ids = torch.cat([input.processed_tensors["prompt_ids"] for input in inputs], dim=0)
response_ids = torch.cat([input.processed_tensors["response_ids"] for input in inputs], dim=0)
response_mask = torch.cat([input.processed_tensors["response_mask"] for input in inputs], dim=0)
input_ids = torch.cat([input.processed_tensors["input_ids"] for input in inputs], dim=0)
attention_mask = torch.cat([input.processed_tensors["attention_mask"] for input in inputs], dim=0)
position_ids = torch.cat([input.processed_tensors["position_ids"] for input in inputs], dim=0)

batch = TensorDict(
{
Expand Down