
Commit 9530472

Tavish9 authored and techkang committed
[rollout, tool] feat: export rollout rewards to total rewards (volcengine#3563)
### What does this PR do?

This PR exports rollout rewards, including tool-calling rewards and interaction rewards, to the `compute_score` fn. Currently, the rollout `reward_scores` is calculated but not used in the final `compute_score`:
https://github.com/volcengine/verl/blob/96e7071de1bc6a7c0e12f4999f97556da9310cc3/verl/workers/rollout/sglang_rollout/sglang_rollout.py#L1320-L1324

Fixes volcengine#3525

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent 2b69c1a commit 9530472

File tree

5 files changed: 32 additions, 10 deletions


verl/experimental/agent_loop/agent_loop.py

Lines changed: 3 additions & 1 deletion
@@ -701,7 +701,9 @@ def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
         extra_fields = {}
         all_keys = set(key for input_item in inputs for key in input_item.extra_fields)
         for key in all_keys:
-            extra_fields[key] = np.array([input.extra_fields.get(key) for input in inputs], dtype=object)
+            temp_arr = np.empty(len(inputs), dtype=object)
+            temp_arr[:] = [input.extra_fields.get(key) for input in inputs]
+            extra_fields[key] = temp_arr

         non_tensor_batch.update(extra_fields)
         return DataProto(

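A note on the `agent_loop.py` change above: when every per-sample value is a list of the same length (as `turn_scores` or the new `tool_rewards` can be), `np.array(..., dtype=object)` silently merges them into a 2-D object array instead of keeping one list per sample. A small self-contained sketch of the difference, with illustrative values:

```python
import numpy as np

values = [[0.1, 0.2], [0.3, 0.4]]  # e.g. per-sample turn_scores of equal length

direct = np.array(values, dtype=object)
print(direct.shape)  # (2, 2) -- the rows are merged into a 2-D object array

safe = np.empty(len(values), dtype=object)
safe[:] = values
print(safe.shape)    # (2,)   -- a 1-D array holding one Python list per sample
```

The `np.empty` plus slice-assignment pattern keeps the per-sample lists intact regardless of their lengths.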
verl/experimental/agent_loop/tool_agent_loop.py

Lines changed: 17 additions & 7 deletions
@@ -68,6 +68,7 @@ def __init__(
         self.response_mask: list[int] = []
         self.response_logprobs: list[float] = []
         self.turn_scores: list[float] = []
+        self.tool_rewards: list[float] = []
         self.user_turns = 0
         self.assistant_turns = 0

@@ -175,7 +176,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
             metrics=agent_data.metrics,
             extra_fields={},
         )
-        output.extra_fields.update({"turn_scores": agent_data.turn_scores})
+        output.extra_fields.update({"turn_scores": agent_data.turn_scores, "tool_rewards": agent_data.tool_rewards})
         return output

     async def _handle_pending_state(self, agent_data: AgentData, sampling_params: dict[str, Any]) -> AgentState:
@@ -268,7 +269,7 @@ async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentSt

         # Process tool responses and update multi_modal_data
         # Removed: agent_data.new_images_this_turn = []
-        for tool_response in responses:
+        for tool_response, tool_reward, _ in responses:
             # Create message from tool response
             if tool_response.image or tool_response.video:
                 # Multi-modal content with structured format
@@ -321,6 +322,9 @@ async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentSt
                         "Multimedia type 'video' is not currently supported. Only 'image' is supported."
                     )

+            if tool_reward is not None:
+                agent_data.tool_rewards.append(tool_reward)
+
             # Update prompt with tool responses
             if self.processor is not None:
                 raw_tool_response = await self.loop.run_in_executor(
@@ -403,7 +407,9 @@ async def _handle_interacting_state(self, agent_data: AgentData) -> AgentState:
         else:
             return AgentState.GENERATING

-    async def _call_tool(self, tool_call: FunctionCall, tools_kwargs: dict[str, Any]) -> ToolResponse:
+    async def _call_tool(
+        self, tool_call: FunctionCall, tools_kwargs: dict[str, Any]
+    ) -> tuple[ToolResponse, float, dict]:
         """Call tool and return tool response."""
         tool, instance_id = None, None
         try:
@@ -413,11 +419,15 @@ async def _call_tool(self, tool_call: FunctionCall, tools_kwargs: dict[str, Any]
             tool = self.tools[tool_name]
             kwargs = tools_kwargs.get(tool_name, {})
             instance_id, _ = await tool.create(create_kwargs=kwargs.get("create_kwargs", {}))
-            tool_execution_response, _, _ = await tool.execute(instance_id, tool_args)
+            tool_execution_response, tool_reward, res = await tool.execute(instance_id, tool_args)
         except Exception as e:
             logger.warning(f"Error when executing tool: {e}")
-            return ToolResponse(
-                text=f"Error when executing tool: {e}",
+            return (
+                ToolResponse(
+                    text=f"Error when executing tool: {e}",
+                ),
+                0.0,
+                {},
             )
         finally:
             if tool and instance_id:
@@ -443,7 +453,7 @@ async def _call_tool(self, tool_call: FunctionCall, tools_kwargs: dict[str, Any]
             if attr_value is not None:
                 tool_response_kwargs[attr_name] = attr_value

-        return ToolResponse(**tool_response_kwargs)
+        return ToolResponse(**tool_response_kwargs), tool_reward, res

     @classmethod
     def _initialize_interactions(cls, interaction_config_file):

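To see where `tool_reward` originates, here is a hedged sketch of a tool whose `execute()` returns a per-call reward in the `(response, reward, metadata)` shape that `_call_tool` now unpacks. The class name, the `query` argument, and the import path are illustrative assumptions, not the exact verl tool interface.

```python
from verl.tools.schemas import ToolResponse  # assumed import path for ToolResponse


class EchoTool:
    """Toy tool; only the (response, reward, metadata) return shape mirrors the diff."""

    async def create(self, create_kwargs=None):
        # Return an instance id plus optional metadata, as _call_tool expects.
        return "echo-instance-0", None

    async def execute(self, instance_id, tool_args):
        query = tool_args.get("query", "")
        response = ToolResponse(text=f"echo: {query}")
        reward = 1.0 if query else 0.0  # collected into agent_data.tool_rewards
        return response, reward, {"instance_id": instance_id}
```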
verl/workers/reward_manager/batch.py

Lines changed: 5 additions & 1 deletion
@@ -61,7 +61,11 @@ def verify(self, data):

         ground_truths = [item.non_tensor_batch["reward_model"].get("ground_truth", None) for item in data]
         data_sources = data.non_tensor_batch[self.reward_fn_key]
-        extras = data.non_tensor_batch.get("extra_info", [None] * len(data))
+        rollout_reward_scores = data.non_tensor_batch.get("reward_scores", [{} for _ in range(len(data))])
+        extras = data.non_tensor_batch.get("extra_info", [{} for _ in range(len(data))])
+
+        for i in range(len(data)):
+            extras[i]["rollout_reward_scores"] = rollout_reward_scores[i]

         scores = self.compute_score(
             data_sources=data_sources,

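Unlike the per-sample managers below, `BatchRewardManager` passes one list per field, so each element of `extras` now carries its own `rollout_reward_scores`. A minimal sketch of a batch-level reward fn; apart from `data_sources`, the keyword names here are assumptions chosen for illustration, not the exact calling convention.

```python
def batch_compute_score(data_sources, solution_strs, ground_truths, extra_infos, **kwargs):
    """Illustrative batch reward fn: one score per sample, using rollout rewards."""
    scores = []
    for solution, truth, info in zip(solution_strs, ground_truths, extra_infos):
        rollout_scores = (info or {}).get("rollout_reward_scores", {})
        # Sum whatever numeric rewards the rollout attached for this sample.
        bonus = sum(v for v in rollout_scores.values() if isinstance(v, (int, float)))
        scores.append((1.0 if solution == truth else 0.0) + 0.1 * bonus)
    return scores
```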
verl/workers/reward_manager/dapo.py

Lines changed: 5 additions & 1 deletion
@@ -92,7 +92,11 @@ def __call__(self, data: DataProto, return_dict: bool = False):

             data_source = data_item.non_tensor_batch[self.reward_fn_key]

-            extra_info = data_item.non_tensor_batch.get("extra_info", None)
+            extra_info = data_item.non_tensor_batch.get("extra_info", {})
+
+            rollout_reward_scores = data_item.non_tensor_batch.get("reward_scores", {})
+
+            extra_info["rollout_reward_scores"] = rollout_reward_scores

             result = self.compute_score(
                 data_source=data_source,

verl/workers/reward_manager/naive.py

Lines changed: 2 additions & 0 deletions
@@ -82,7 +82,9 @@ def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor |
             data_source = data_item.non_tensor_batch[self.reward_fn_key]
             extra_info = data_item.non_tensor_batch.get("extra_info", {})
             num_turns = data_item.non_tensor_batch.get("__num_turns__", None)
+            rollout_reward_scores = data_item.non_tensor_batch.get("reward_scores", {})
             extra_info["num_turns"] = num_turns
+            extra_info["rollout_reward_scores"] = rollout_reward_scores

             score = self.compute_score(
                 data_source=data_source,

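With all three reward managers now injecting `rollout_reward_scores` into `extra_info`, a user-defined `compute_score` can fold the rollout-side rewards into the final score. A minimal sketch: the blending weight, the assumed shape of `rollout_reward_scores`, and the exact-match task reward are illustrative assumptions, not part of this PR.

```python
def compute_score(data_source, solution_str, ground_truth, extra_info=None, **kwargs):
    """Illustrative custom reward fn that also uses rollout-side rewards."""
    extra_info = extra_info or {}
    # Injected by the reward managers above; assumed to be a (possibly empty)
    # mapping of reward sources (e.g. tools) to scalar or list-valued rewards.
    rollout_scores = extra_info.get("rollout_reward_scores") or {}

    # Placeholder task reward: exact match against the ground truth.
    task_reward = 1.0 if solution_str.strip() == str(ground_truth).strip() else 0.0

    # Flatten scalar or list-valued rewards and average them.
    flat = []
    for value in rollout_scores.values():
        flat.extend(value if isinstance(value, (list, tuple)) else [value])
    rollout_bonus = sum(flat) / len(flat) if flat else 0.0

    return task_reward + 0.1 * rollout_bonus  # 0.1 is an arbitrary example weight
```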