Commit 23aa105
[training_utils] fix: enforce 1D object array shape for non-tensor data in collate_fn (#2741)
### What does this PR do?
This PR updates the `collate_fn` logic inside
`verl.utils.dataset.rl_dataset` to consistently handle non-tensor fields
as 1D object arrays, preventing runtime errors during concatenation in
downstream code such as `recipe/dapo/dapo_ray_trainer.py`.
### Checklist Before Starting
- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
- Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
### Test
* Tested at: https://github.com/kibitzing/verl/tree/test_tool_n1
* Note: This branch is for testing purposes only and is not intended for
merge.
* The data used for testing comes from the `train.parquet` and
`test.parquet` files released by the [Tool N1
repository](https://github.com/NVlabs/Tool-N1).
* Part of the training script:
```bash
python3 -m recipe.dapo.main_dapo \
data.train_files=$HOME/Tool-N1/verl/verl/data/train.parquet \
data.val_files=$HOME/Tool-N1/verl/verl/data/test.parquet \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=2048 \
data.max_response_length=4096 \
data.gen_batch_size=32 \
data.train_batch_size=24 \
actor_rollout_ref.rollout.n=5 \
algorithm.adv_estimator=grpo \
algorithm.filter_groups.enable=True \
algorithm.filter_groups.max_num_gen_batches=10 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
...
```
### Before vs After Behavior (Real Output Logs)
* Before: Inconsistent Shape
```
(TaskRunner pid=114826) Training from scratch
(TaskRunner pid=114826) new_batch.non_tensor_batch["conversations"].shape=(32, 1)
(TaskRunner pid=114826) num_prompt_in_batch=3 < prompt_bsz=24
(TaskRunner pid=114826) num_gen_batches=1. Keep generating...
(TaskRunner pid=114826) new_batch.non_tensor_batch["conversations"].shape=(32, 1)
(TaskRunner pid=114826) num_prompt_in_batch=8 < prompt_bsz=24
(TaskRunner pid=114826) num_gen_batches=2. Keep generating...
(TaskRunner pid=114826) new_batch.non_tensor_batch["conversations"].shape=(32, 1)
(TaskRunner pid=114826) num_prompt_in_batch=13 < prompt_bsz=24
(TaskRunner pid=114826) num_gen_batches=3. Keep generating...
(TaskRunner pid=114826) new_batch.non_tensor_batch["conversations"].shape=(32,)
ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 1 has 1 dimension(s)
```
The shape changed from `(32, 1)` to `(32,)` at the fourth generation
batch, so concatenating the accumulated batches raised the `ValueError`
shown above.
* After: Consistent (32,) Shape
```
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
(TaskRunner pid=133725) num_prompt_in_batch=4 < prompt_bsz=24
(TaskRunner pid=133725) num_gen_batches=1. Keep generating...
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
(TaskRunner pid=133725) num_prompt_in_batch=10 < prompt_bsz=24
(TaskRunner pid=133725) num_gen_batches=2. Keep generating...
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
(TaskRunner pid=133725) num_prompt_in_batch=12 < prompt_bsz=24
(TaskRunner pid=133725) num_gen_batches=3. Keep generating...
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
(TaskRunner pid=133725) num_prompt_in_batch=15 < prompt_bsz=24
(TaskRunner pid=133725) num_gen_batches=4. Keep generating...
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
(TaskRunner pid=133725) num_prompt_in_batch=19 < prompt_bsz=24
(TaskRunner pid=133725) num_gen_batches=5. Keep generating...
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
(TaskRunner pid=133725) num_prompt_in_batch=23 < prompt_bsz=24
(TaskRunner pid=133725) num_gen_batches=6. Keep generating...
(TaskRunner pid=133725) new_batch.non_tensor_batch["conversations"].shape=(32,)
```
With the updated logic, the shape is consistently (32,).
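* The issue was traced back to the `"conversations"` field in the Tool
N1 dataset. This key contains a list of human–gpt messages. In most
examples, it's a single-turn conversation (list with length 1), but in
some cases, it's a multi-turn conversation (list with length > 1).
A batch made up only of single-turn examples is therefore converted to a
`(32, 1)` array, while a batch containing any multi-turn example becomes
a `(32,)` array of lists.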
### Design & Code Changes
The current `collate_fn` processes non-tensor values with:
https://github.com/volcengine/verl/blob/1df03f3abf96f59cb90c684f93a71ee0bbb57f49/verl/utils/dataset/rl_dataset.py#L62-L63
While this generally works, it leads to a subtle issue:
If `val` is a list of lists and all inner lists happen to be of the same
length, NumPy will interpret it as a 2D array with shape (N, L).
However, in many RL scenarios, the structure of non-tensor data (e.g.
variable-length lists across batches) is not guaranteed to be uniform,
which means:
- One batch may produce shape `(N, L)`
- Another may produce `(N,)` where each element is a list of different
lengths
- Another may have shape `(N, L')`
This causes downstream errors like:
`ValueError: all the input arrays must have same number of dimensions,
but the array at index 0 has 2 dimension(s) and the array at index 1 has
1 dimension(s)`
Specifically, this occurs when multiple step-wise batches are
concatenated with:
https://github.com/volcengine/verl/blob/1df03f3abf96f59cb90c684f93a71ee0bbb57f49/recipe/dapo/dapo_ray_trainer.py#L240
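The failure mode described above can be reproduced in isolation. The sketch below is illustrative only: it assumes the original conversion is equivalent to `np.array(val, dtype=object)` and that batches are eventually joined with `np.concatenate`. The message dicts and their `"from"`/`"value"` keys are made-up stand-ins for the dataset records, not the actual Tool N1 schema.

```python
import numpy as np

# Hypothetical records: every conversation has exactly one message.
single_turn = [[{"from": "human", "value": "q"}] for _ in range(3)]

# Hypothetical records: the second conversation has two messages.
mixed_turns = [
    [{"from": "human", "value": "q"}],
    [{"from": "human", "value": "q"}, {"from": "gpt", "value": "a"}],
    [{"from": "human", "value": "q"}],
]

# Uniform inner lengths are expanded into a second dimension.
batch_a = np.array(single_turn, dtype=object)
# Ragged inner lengths stay as a 1D array of list objects.
batch_b = np.array(mixed_turns, dtype=object)

error_message = None
try:
    np.concatenate([batch_a, batch_b])
except ValueError as exc:
    # Mixing a 2D and a 1D array is rejected by NumPy.
    error_message = str(exc)

print(batch_a.shape, batch_b.shape)
print(error_message)
```

Whether a given batch comes out 2D or 1D depends only on its contents, which is why the error appears intermittently during training.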
To enforce consistent 1D object arrays regardless of content, this PR
replaces the original line with:
```python
for key, val in non_tensors.items():
    # Pre-allocate a 1D object array so nested lists are never expanded.
    non_tensors[key] = np.empty(len(val), dtype=object)
    non_tensors[key][:] = val
```
This ensures that `non_tensors[key]` always has shape `(N,)`, which makes
concatenation in downstream logic safer.
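As a rough check of this behavior, the same pattern can be wrapped in a small helper and applied to both uniform and ragged inputs. This is a minimal sketch, not the code in `rl_dataset.py`: the helper name `to_1d_object_array` and the message dicts are hypothetical.

```python
import numpy as np


def to_1d_object_array(values):
    """Pack a list of arbitrary Python objects into a 1D object array.

    Hypothetical helper; it mirrors the three-line pattern shown above.
    """
    arr = np.empty(len(values), dtype=object)
    # Assigning into a 1D target stores each top-level item as one element,
    # so inner lists are kept as list objects.
    arr[:] = values
    return arr


# Made-up message dicts; only the nesting depth matters here.
single_turn = [[{"from": "human", "value": "q"}] for _ in range(3)]
mixed_turns = [
    [{"from": "human", "value": "q"}],
    [{"from": "human", "value": "q"}, {"from": "gpt", "value": "a"}],
    [{"from": "human", "value": "q"}],
]

batch_a = to_1d_object_array(single_turn)
batch_b = to_1d_object_array(mixed_turns)

# Both batches are 1D, so concatenation no longer depends on their contents.
merged = np.concatenate([batch_a, batch_b])

print(batch_a.shape, batch_b.shape, merged.shape)
```

Each element of `merged` is still the original Python list, so downstream code that reads a conversation by index should see the same objects as before.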
### Checklist Before Submitting
> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

1 parent 2cccd7f commit 23aa105
File tree: 7 files changed, +103 -20 lines changed
- recipe/spin
- tests/utils/dataset
- verl
  - experimental/agent_loop
  - utils/dataset
  - workers
    - rollout
      - sglang_rollout
      - vllm_rollout