
Commit 23aa105

[training_utils] fix: enforce 1D object array shape for non-tensor data in collate_fn (#2741)
### What does this PR do?

This PR updates the `collate_fn` logic inside `verl.utils.dataset.rl_dataset` to consistently handle non-tensor fields as 1D object arrays, preventing runtime errors during concatenation in downstream code such as `recipe/dapo/dapo_ray_trainer.py`.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

* Tested at: https://github.com/kibitzing/verl/tree/test_tool_n1
  * Note: This branch is for testing purposes only and is not intended for merge.
* The data used for testing comes from the `train.parquet` and `test.parquet` files released by the [Tool N1 repository](https://github.com/NVlabs/Tool-N1).
* Part of the training script:

```bash
python3 -m recipe.dapo.main_dapo \
    data.train_files=$HOME/Tool-N1/verl/verl/data/train.parquet \
    data.val_files=$HOME/Tool-N1/verl/verl/data/test.parquet \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=2048 \
    data.max_response_length=4096 \
    data.gen_batch_size=32 \
    data.train_batch_size=24 \
    actor_rollout_ref.rollout.n=5 \
    algorithm.adv_estimator=grpo \
    algorithm.filter_groups.enable=True \
    algorithm.filter_groups.max_num_gen_batches=10 \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
    ...
```

### Before vs After Behavior (Real Output Logs)

* Before: Inconsistent Shape

```
(TaskRunner pid=114826) Training from scratch
(TaskRunner pid=114826) new_batch.non_tensor_batch["conversations"].shape=(32, 1)
(TaskRunner pid=114826) num_prompt_in_batch=3 < prompt_bsz=24
(TaskRunner pid=114826) num_gen_batches=1. Keep generating...
(TaskRunner pid=114826) new_batch.non_tensor_batch["conversations"].shape=(32, 1)
(TaskRunner pid=114826) num_prompt_in_batch=8 < prompt_bsz=24
(TaskRunner pid=114826) num_gen_batches=2. Keep generating...
(TaskRunner pid=114826) new_batch.non_tensor_batch["conversations"].shape=(32, 1)
(TaskRunner pid=114826) num_prompt_in_batch=13 < prompt_bsz=24
(TaskRunner pid=114826) num_gen_batches=3. Keep generating...
(TaskRunner pid=114826) new_batch.non_tensor_batch["conversations"].shape=(32,)
ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 1 has 1 dimension(s)
```

This caused shape inconsistency across steps, leading to downstream errors during concatenation.

* After: Consistent (32,) Shape

```
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
(TaskRunner pid=133725) num_prompt_in_batch=4 < prompt_bsz=24
(TaskRunner pid=133725) num_gen_batches=1. Keep generating...
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
(TaskRunner pid=133725) num_prompt_in_batch=10 < prompt_bsz=24
(TaskRunner pid=133725) num_gen_batches=2. Keep generating...
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
(TaskRunner pid=133725) num_prompt_in_batch=12 < prompt_bsz=24
(TaskRunner pid=133725) num_gen_batches=3. Keep generating...
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
(TaskRunner pid=133725) num_prompt_in_batch=15 < prompt_bsz=24
(TaskRunner pid=133725) num_gen_batches=4. Keep generating...
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
(TaskRunner pid=133725) num_prompt_in_batch=19 < prompt_bsz=24
(TaskRunner pid=133725) num_gen_batches=5. Keep generating...
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
(TaskRunner pid=133725) num_prompt_in_batch=23 < prompt_bsz=24
(TaskRunner pid=133725) num_gen_batches=6. Keep generating...
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
```

With the updated logic, the shape is consistently (32,).

* The issue was traced back to the `"conversations"` field in the Tool N1 dataset. This key contains a list of human–gpt messages. In most examples, it's a single-turn conversation (list with length 1), but in some cases, it's a multi-turn conversation (list with length > 1).

### Design & Code Changes

The current `collate_fn` processes non-tensor values with:

https://github.com/volcengine/verl/blob/1df03f3abf96f59cb90c684f93a71ee0bbb57f49/verl/utils/dataset/rl_dataset.py#L62-L63

While this generally works, it leads to a subtle issue: if `val` is a list of lists and all inner lists happen to be of the same length, NumPy will interpret it as a 2D array with shape (N, L). However, in many RL scenarios, the structure of non-tensor data (e.g. variable-length lists across batches) is not guaranteed to be uniform, which means:

- One batch may produce shape `(N, L)`
- Another may produce `(N,)` where each element is a list of different lengths
- Another may have shape `(N, L')`

This causes downstream errors like:

`ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 1 has 1 dimension(s)`

Specifically, this occurs when multiple step-wise batches are concatenated with:

https://github.com/volcengine/verl/blob/1df03f3abf96f59cb90c684f93a71ee0bbb57f49/recipe/dapo/dapo_ray_trainer.py#L240

To enforce consistent 1D object arrays regardless of content, this PR replaces the original line with:

```python
for key, val in non_tensors.items():
    non_tensors[key] = np.empty(len(val), dtype=object)
    non_tensors[key][:] = val
```

This ensures that `non_tensors[key]` always has shape (N,), which makes concatenation in downstream logic safer.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
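The shape inconsistency described under "Design & Code Changes" can be reproduced outside of verl with plain NumPy. The following is a minimal sketch, not code from the PR: the two sample batches and their message contents are made up for illustration, and only `np.array` and `np.concatenate` are exercised.

```python
import numpy as np

# Hypothetical batches of a "conversations"-like field (contents are illustrative only).
# Every sample in this batch is single-turn, so all inner lists have length 1.
uniform_batch = [
    [{"from": "human", "value": "hi"}],
    [{"from": "human", "value": "hello"}],
]
# One sample in this batch is multi-turn, so the inner lists differ in length.
ragged_batch = [
    [{"from": "human", "value": "hi"}],
    [{"from": "human", "value": "a"}, {"from": "gpt", "value": "b"}],
]

# The pre-fix conversion: NumPy infers the dimensionality from the nesting.
uniform_arr = np.array(uniform_batch, dtype=object)  # shape (2, 1)
ragged_arr = np.array(ragged_batch, dtype=object)    # shape (2,)

# Concatenating arrays with different numbers of dimensions fails.
concat_failed = False
try:
    np.concatenate([uniform_arr, ragged_arr])
except ValueError:
    concat_failed = True

print(uniform_arr.shape, ragged_arr.shape, concat_failed)
```

The two arrays hold the same kind of data, yet only the second is one-dimensional, which is the situation the "Before" log shows once a batch with a multi-turn sample arrives.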
1 parent 2cccd7f commit 23aa105


7 files changed, +103 -20 lines changed


recipe/spin/fsdp_workers.py

Lines changed: 7 additions & 4 deletions

@@ -18,6 +18,7 @@
 import os
 import warnings

+import numpy as np
 import psutil
 import torch
 import torch.distributed
@@ -483,11 +484,13 @@ def _switch_chat_template(self, data: DataProto):
         rm_attention_mask = []

         for i in range(data.batch.batch_size[0]):
+            if not isinstance(data.non_tensor_batch["raw_prompt"][i], list | np.ndarray):
+                raise TypeError(
+                    f"raw_prompt must be a list or numpy array, got {type(data.non_tensor_batch['raw_prompt'][i])}"
+                )
+
             # extract raw prompt
-            if isinstance(data.non_tensor_batch["raw_prompt"][i], list):
-                chat: list = data.non_tensor_batch["raw_prompt"][i]
-            else:
-                chat: list = data.non_tensor_batch["raw_prompt"][i].tolist()
+            chat: list = list(data.non_tensor_batch["raw_prompt"][i])

             # extract response
             response_ids = data.batch["responses"][i]
Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+def test_rl_collate_fn():
+    from verl.utils.dataset.rl_dataset import collate_fn
+
+    max_prompt_length = 5
+
+    test_data = [
+        {
+            # test tensor
+            "input_ids": torch.randint(0, 10, (max_prompt_length,)),
+            # test fixed length (1) list within a batch
+            "messages": [{"role": "user", "content": "Hi."}],
+            # test variable length list within a batch
+            "raw_prompt_ids": [1, 2, 3, 4],
+            # test string
+            "ability": "math",
+            # test dict
+            "reward_model": {"ground_truth": 5, "style": "rule"},
+            # test empty dict
+            "tools_kwargs": {},
+        },
+        {
+            "input_ids": torch.randint(0, 10, (max_prompt_length,)),
+            "messages": [{"role": "user", "content": "Hello."}],
+            "raw_prompt_ids": [1, 2, 3],
+            "ability": "toolcall",
+            "reward_model": {
+                "ground_truth": '[{"name": "rgb_to_cmyk", "arguments": {"r": 0, "g": 0, "b": 255}}]',
+                "style": "rule",
+            },
+            "tools_kwargs": {},
+        },
+    ]
+
+    batch_size = len(test_data)
+    batch = collate_fn(test_data)
+
+    # Tensor part
+    assert batch["input_ids"].shape == (batch_size, max_prompt_length)
+    assert isinstance(batch["input_ids"], torch.Tensor)
+
+    # Non-tensor parts
+    expected_types = {
+        "messages": list,
+        "raw_prompt_ids": list,
+        "ability": str,
+        "reward_model": dict,
+        "tools_kwargs": dict,
+    }
+
+    for key, dtype in expected_types.items():
+        assert batch[key].shape == (batch_size,), (
+            f"Expected shape {(batch_size,)} for '{key}', but got {batch[key].shape}"
+        )
+        assert isinstance(batch[key][0], dtype), (
+            f"'{key}' should contain elements of type {dtype}, but got {type(batch[key][0])}"
+        )

verl/experimental/agent_loop/agent_loop.py

Lines changed: 4 additions & 1 deletion

@@ -298,8 +298,11 @@ async def generate_sequences(self, batch: DataProto) -> DataProto:
         )

         for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True):
+            if not isinstance(messages, list | np.ndarray):
+                raise TypeError(f"messages must be a list or numpy array, got {type(messages)}")
+
             tasks.append(
-                asyncio.create_task(self._run_agent_loop(agent_name, messages.tolist(), sampling_params, trajectory))
+                asyncio.create_task(self._run_agent_loop(agent_name, list(messages), sampling_params, trajectory))
             )
         outputs = await asyncio.gather(*tasks)

verl/utils/dataset/rl_dataset.py

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ def collate_fn(data_list: list[dict]) -> dict:
         tensors[key] = torch.stack(val, dim=0)

     for key, val in non_tensors.items():
-        non_tensors[key] = np.array(val, dtype=object)
+        non_tensors[key] = np.fromiter(val, dtype=object, count=len(val))

     return {**tensors, **non_tensors}
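Note that the PR description proposes `np.empty` plus slice assignment, whereas the merged line in this hunk uses `np.fromiter`. The sketch below follows the diff. It assumes NumPy 1.23 or newer, the release that added `dtype=object` support to `np.fromiter`; `to_object_array` is a hypothetical stand-in for the changed line, not a verl function, and the sample values are made up.

```python
import numpy as np


def to_object_array(values):
    """Hypothetical stand-in for the changed line: always returns shape (len(values),)."""
    # fromiter stores each item as one opaque Python object, so nesting is never unpacked.
    return np.fromiter(values, dtype=object, count=len(values))


# Inner lists of equal length: np.array(..., dtype=object) would give shape (2, 1) here.
single = to_object_array([["turn 1"], ["turn 1"]])
# Inner lists of different lengths.
mixed = to_object_array([["turn 1"], ["turn 1", "turn 2"]])

# Both are 1D, so batches from different steps can be concatenated.
merged = np.concatenate([single, mixed])

print(single.shape, mixed.shape, merged.shape)
```

Each element of the result is the original Python object, which is what the new unit test checks with `isinstance(batch[key][0], dtype)`.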

verl/workers/fsdp_workers.py

Lines changed: 7 additions & 4 deletions

@@ -22,6 +22,7 @@
 from dataclasses import asdict
 from typing import Any

+import numpy as np
 import psutil
 import torch
 import torch.distributed
@@ -1526,11 +1527,13 @@ def _switch_chat_template(self, data: DataProto):
         rm_attention_mask = []

         for i in range(data.batch.batch_size[0]):
+            if not isinstance(data.non_tensor_batch["raw_prompt"][i], list | np.ndarray):
+                raise TypeError(
+                    f"raw_prompt must be a list or numpy array, got {type(data.non_tensor_batch['raw_prompt'][i])}"
+                )
+
             # extract raw prompt
-            if isinstance(data.non_tensor_batch["raw_prompt"][i], list):
-                chat: list = data.non_tensor_batch["raw_prompt"][i]
-            else:
-                chat: list = data.non_tensor_batch["raw_prompt"][i].tolist()
+            chat: list = list(data.non_tensor_batch["raw_prompt"][i])

             # extract response
             response_ids = data.batch["responses"][i]

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 8 additions & 5 deletions

@@ -660,15 +660,15 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP
                 {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
             ]

-        # Ensure token IDs are lists or numpy arrays
         for input_data in sglang_inputs:
-            if isinstance(input_data["prompt_token_ids"], np.ndarray):
-                input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist()
-            elif not isinstance(input_data["prompt_token_ids"], list):
+            # Ensure token IDs are lists or numpy arrays
+            if not isinstance(input_data["prompt_token_ids"], list | np.ndarray):
                 raise TypeError(
                     f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}"
                 )

+            input_data["prompt_token_ids"] = list(input_data["prompt_token_ids"])
+
         # Extract token IDs and image data for SGLang Engine
         idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs]
         image_list = [input_data.get("image_data", None) for input_data in sglang_inputs]
@@ -1266,12 +1266,15 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in
             else:
                 _interaction_kwargs = {}

+            if not isinstance(raw_prompt, list | np.ndarray):
+                raise TypeError(f"raw_prompt must be a list or numpy array, got {type(raw_prompt)}")
+
             req = AsyncRolloutRequest(
                 batch_data_id=data_idx,
                 rollout_offset=0,
                 request_id=str(uuid4()),
                 state=AsyncRolloutRequestStateEnum.PENDING,
-                messages=raw_prompt.tolist(),
+                messages=list(raw_prompt),
                 multi_modal_data=multi_modal_data,
                 tool_schemas=_tool_schemas,
                 tools_kwargs=_tools_kwargs,

verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py

Lines changed: 4 additions & 5 deletions

@@ -276,16 +276,15 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
                 {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
             ]

-        # ensure the type of `prompt_token_ids` passed to vllm is list[int]
-        # https://github.com/volcengine/verl/pull/772
         for input_data in vllm_inputs:
-            if isinstance(input_data["prompt_token_ids"], np.ndarray):
-                input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist()
-            elif not isinstance(input_data["prompt_token_ids"], list):
+            # Ensure token IDs are lists or numpy arrays
+            if not isinstance(input_data["prompt_token_ids"], list | np.ndarray):
                 raise TypeError(
                     f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}"
                 )

+            input_data["prompt_token_ids"] = list(input_data["prompt_token_ids"])
+
         do_sample = prompts.meta_info.get("do_sample", True)
         is_validate = prompts.meta_info.get("validate", False)
         if not do_sample:
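The rollout hunks replace `.tolist()` with `list(...)`, and the two are not identical for numeric NumPy arrays. The sketch below illustrates the difference in element types with made-up token IDs. It makes no claim about how the rollout engines treat NumPy integer scalars, which the commit does not state; with the updated `collate_fn`, a `raw_prompt_ids` entry that was a Python list in the dataset stays a Python list, in which case `list(...)` is a plain copy.

```python
import numpy as np

# Illustrative token IDs held in a numeric NumPy array.
ids = np.array([1, 2, 3])

via_list = list(ids)      # iterates the array: elements are NumPy integer scalars
via_tolist = ids.tolist() # converts recursively: elements are built-in Python ints

list_gives_numpy_scalars = all(isinstance(x, np.integer) for x in via_list)
tolist_gives_python_ints = all(type(x) is int for x in via_tolist)

# When the element is already a Python list, list() is just a shallow copy.
already_list = [1, 2, 3]
copied = list(already_list)
```

If a caller still passes numeric arrays and needs built-in `int` values, `.tolist()` remains the conversion that guarantees them.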
