-
Notifications
You must be signed in to change notification settings - Fork 2.4k
[megatron, worker] fix: use extract_multi_modal_inputs method for handling multi_modal_inputs
#3641
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[megatron, worker] fix: use extract_multi_modal_inputs method for handling multi_modal_inputs
#3641
Conversation
…andling `multi_modal_inputs` Follow up for #3206 Without those changes in #3315, the error when we train the mixture modal dataset will remain unresolved, so it would be a good idea to add them back. ```logs File "verl/workers/actor/megatron_actor.py", line 639, in update_policy metric_micro_batch = self.forward_backward_batch( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "verl/workers/actor/megatron_actor.py", line 587, in forward_backward_batch losses_reduced = forward_backward_func( ^^^^^^^^^^^^^^^^^^^^^^ File "megatron/core/pipeline_parallel/schedules.py", line 595, in forward_backward_no_pipelining output_tensor, num_tokens = forward_step( ^^^^^^^^^^^^^ File "megatron/core/pipeline_parallel/schedules.py", line 402, in forward_step output_tensor, loss_func = forward_step_func(data_iterator, model) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "verl/workers/actor/megatron_actor.py", line 497, in forward_step multi_modal_inputs[key] = torch.cat( ^^^^^^^^^^ RuntimeError: torch.cat(): expected a non-empty list of Tensors ``` Signed-off-by: Hollow Man <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request refactors the handling of multi-modal inputs in megatron_actor.py by introducing and using the extract_multi_modal_inputs utility function. This change effectively fixes a RuntimeError that occurred with mixed-modality datasets. The new approach is cleaner, more robust against edge cases like empty input lists, and centralizes the data processing logic. While the fix is correct, I've suggested moving the local import to the top of the file to adhere to best practices and improve code maintainability, especially in a performance-sensitive area.
…andling `multi_modal_inputs` (volcengine#3641) Follow up for volcengine#3553 ### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. Without those changes in volcengine#3315, the error when we train the mixture modal dataset will remain unresolved, so it would be a good idea to add them back. ```logs File "verl/workers/actor/megatron_actor.py", line 639, in update_policy metric_micro_batch = self.forward_backward_batch( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "verl/workers/actor/megatron_actor.py", line 587, in forward_backward_batch losses_reduced = forward_backward_func( ^^^^^^^^^^^^^^^^^^^^^^ File "megatron/core/pipeline_parallel/schedules.py", line 595, in forward_backward_no_pipelining output_tensor, num_tokens = forward_step( ^^^^^^^^^^^^^ File "megatron/core/pipeline_parallel/schedules.py", line 402, in forward_step output_tensor, loss_func = forward_step_func(data_iterator, model) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "verl/workers/actor/megatron_actor.py", line 497, in forward_step multi_modal_inputs[key] = torch.cat( ^^^^^^^^^^ RuntimeError: torch.cat(): expected a non-empty list of Tensors ``` ### Checklist Before Starting - [X] Search for similar PRs. Paste at least one query link here: ... - [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [X] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [X] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) Signed-off-by: Hollow Man <[email protected]>
…andling `multi_modal_inputs` (volcengine#3641) Follow up for volcengine#3553 ### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. Without those changes in volcengine#3315, the error when we train the mixture modal dataset will remain unresolved, so it would be a good idea to add them back. ```logs File "verl/workers/actor/megatron_actor.py", line 639, in update_policy metric_micro_batch = self.forward_backward_batch( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "verl/workers/actor/megatron_actor.py", line 587, in forward_backward_batch losses_reduced = forward_backward_func( ^^^^^^^^^^^^^^^^^^^^^^ File "megatron/core/pipeline_parallel/schedules.py", line 595, in forward_backward_no_pipelining output_tensor, num_tokens = forward_step( ^^^^^^^^^^^^^ File "megatron/core/pipeline_parallel/schedules.py", line 402, in forward_step output_tensor, loss_func = forward_step_func(data_iterator, model) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "verl/workers/actor/megatron_actor.py", line 497, in forward_step multi_modal_inputs[key] = torch.cat( ^^^^^^^^^^ RuntimeError: torch.cat(): expected a non-empty list of Tensors ``` ### Checklist Before Starting - [X] Search for similar PRs. Paste at least one query link here: ... - [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [X] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [X] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) Signed-off-by: Hollow Man <[email protected]>
…andling `multi_modal_inputs` (volcengine#3641) Follow up for volcengine#3553 ### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. Without those changes in volcengine#3315, the error when we train the mixture modal dataset will remain unresolved, so it would be a good idea to add them back. ```logs File "verl/workers/actor/megatron_actor.py", line 639, in update_policy metric_micro_batch = self.forward_backward_batch( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "verl/workers/actor/megatron_actor.py", line 587, in forward_backward_batch losses_reduced = forward_backward_func( ^^^^^^^^^^^^^^^^^^^^^^ File "megatron/core/pipeline_parallel/schedules.py", line 595, in forward_backward_no_pipelining output_tensor, num_tokens = forward_step( ^^^^^^^^^^^^^ File "megatron/core/pipeline_parallel/schedules.py", line 402, in forward_step output_tensor, loss_func = forward_step_func(data_iterator, model) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "verl/workers/actor/megatron_actor.py", line 497, in forward_step multi_modal_inputs[key] = torch.cat( ^^^^^^^^^^ RuntimeError: torch.cat(): expected a non-empty list of Tensors ``` ### Checklist Before Starting - [X] Search for similar PRs. Paste at least one query link here: ... - [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [X] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [X] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) Signed-off-by: Hollow Man <[email protected]>
…andling `multi_modal_inputs` (volcengine#3641) Follow up for volcengine#3553 ### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. Without those changes in volcengine#3315, the error when we train the mixture modal dataset will remain unresolved, so it would be a good idea to add them back. ```logs File "verl/workers/actor/megatron_actor.py", line 639, in update_policy metric_micro_batch = self.forward_backward_batch( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "verl/workers/actor/megatron_actor.py", line 587, in forward_backward_batch losses_reduced = forward_backward_func( ^^^^^^^^^^^^^^^^^^^^^^ File "megatron/core/pipeline_parallel/schedules.py", line 595, in forward_backward_no_pipelining output_tensor, num_tokens = forward_step( ^^^^^^^^^^^^^ File "megatron/core/pipeline_parallel/schedules.py", line 402, in forward_step output_tensor, loss_func = forward_step_func(data_iterator, model) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "verl/workers/actor/megatron_actor.py", line 497, in forward_step multi_modal_inputs[key] = torch.cat( ^^^^^^^^^^ RuntimeError: torch.cat(): expected a non-empty list of Tensors ``` ### Checklist Before Starting - [X] Search for similar PRs. Paste at least one query link here: ... - [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [X] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [X] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) Signed-off-by: Hollow Man <[email protected]>
Follow up for #3553
What does this PR do?
Without those changes in #3315, the error when we train the mixture modal dataset will remain unresolved, so it would be a good idea to add them back.
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)