Skip to content

Commit e5af9fb

Browse files
HollowMan6techkang
authored andcommitted
[megatron, worker] fix: use extract_multi_modal_inputs method for handling multi_modal_inputs (volcengine#3641)
Follow up for volcengine#3553 ### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. Without those changes in volcengine#3315, the error when we train the mixture modal dataset will remain unresolved, so it would be a good idea to add them back. ```logs File "verl/workers/actor/megatron_actor.py", line 639, in update_policy metric_micro_batch = self.forward_backward_batch( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "verl/workers/actor/megatron_actor.py", line 587, in forward_backward_batch losses_reduced = forward_backward_func( ^^^^^^^^^^^^^^^^^^^^^^ File "megatron/core/pipeline_parallel/schedules.py", line 595, in forward_backward_no_pipelining output_tensor, num_tokens = forward_step( ^^^^^^^^^^^^^ File "megatron/core/pipeline_parallel/schedules.py", line 402, in forward_step output_tensor, loss_func = forward_step_func(data_iterator, model) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "verl/workers/actor/megatron_actor.py", line 497, in forward_step multi_modal_inputs[key] = torch.cat( ^^^^^^^^^^ RuntimeError: torch.cat(): expected a non-empty list of Tensors ``` ### Checklist Before Starting - [X] Search for similar PRs. Paste at least one query link here: ... - [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [X] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [X] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) Signed-off-by: Hollow Man <[email protected]>
1 parent 4e020a8 commit e5af9fb

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

verl/workers/actor/megatron_actor.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -491,12 +491,10 @@ def forward_step(batch_iter, model):
491491

492492
multi_modal_inputs = {}
493493
if "multi_modal_inputs" in batch:
494-
for key in batch["multi_modal_inputs"][0].keys():
495-
idxs = batch["multi_modal_inputs_idx"]
496-
mmi = batch["multi_modal_inputs"]
497-
multi_modal_inputs[key] = torch.cat(
498-
[mmi[idx].get(key) for idx in idxs if mmi[idx].get(key) is not None], dim=0
499-
)
494+
from verl.utils.model import extract_multi_modal_inputs
495+
496+
indices = batch.get("multi_modal_inputs_idx", None)
497+
multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices)
500498
responses = batch["responses"]
501499
response_length = responses.size(1)
502500
label = position_ids.clone()

0 commit comments

Comments
 (0)