-
Notifications
You must be signed in to change notification settings - Fork 2.4k
[megatron] fix: resolve backward propagation error in megatron_actor due to shared logits tensor in-place modification #2484
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[megatron] fix: resolve backward propagation error in megatron_actor due to shared logits tensor in-place modification #2484
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This PR fixes a critical bug in the backward propagation step when entropy regularization is enabled, by cloning the logits tensor. I've added a high-severity comment suggesting an optimization to make the tensor clone conditional, preventing unnecessary memory allocation and computation when entropy is not calculated.
verl/workers/actor/megatron_actor.py
Outdated
| logits_bak = logits.clone() | ||
| if calculate_entropy: | ||
| entropy = vocab_parallel_entropy(logits) | ||
| ret["entropy"] = entropy | ||
| log_probs = vocab_parallel_log_probs_from_logits(logits, label) | ||
| log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label) | ||
| log_probs = log_probs.masked_fill(~label_mask, 0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logits.clone() operation is performed unconditionally. This clone is only necessary when calculate_entropy is True to prevent in-place modification from corrupting the computation graph. When calculate_entropy is False, this clone is redundant and introduces unnecessary overhead. To optimize, perform the clone operation conditionally.
| logits_bak = logits.clone() | |
| if calculate_entropy: | |
| entropy = vocab_parallel_entropy(logits) | |
| ret["entropy"] = entropy | |
| log_probs = vocab_parallel_log_probs_from_logits(logits, label) | |
| log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label) | |
| log_probs = log_probs.masked_fill(~label_mask, 0.0) | |
| if calculate_entropy: | |
| logits_bak = logits.clone() | |
| entropy = vocab_parallel_entropy(logits_bak) | |
| ret["entropy"] = entropy | |
| log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label) | |
| else: | |
| log_probs = vocab_parallel_log_probs_from_logits(logits, label) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for contribution, @HelloWorld686 Could you consider this advice?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea, thanks
|
Notice that here verl/verl/utils/megatron/tensor_parallel.py Line 133 in d9a6a31
And verl/verl/utils/megatron/tensor_parallel.py Lines 136 to 137 in d9a6a31
|
During the forward pass of |
|
In megatron, entropy is only computed for observation, not for regularization. logits.clone consumes too much memory and is not affordable. We need a kernel to make entropy regularization happen |
|
#2210 has been merged as mcore side fused kernel function. Shall we fix the correctness here? |
Thank you for your feedback. While memory efficiency is important, ensuring functional correctness should be our top priority—especially since the use_fused_kernels path doesn’t cover all use cases. The clone operation prevents bugs in the non-use_fused_kernels path. (We can’t just force-enable fused kernels to avoid bugs, can we?) What do you think about prioritizing correctness across all scenarios first, then optimizing memory afterward? |
dddbf6c to
9756987
Compare
9756987 to
503ea75
Compare
- Issue: vocab_parallel_log_probs_from_logits modified original logits tensor during log_probs computation,
causing gradient errors when entropy calculation (entropy_coeff ≠ 0) used the same tensor in backward pass.
- Fix: Compute entropy using cloned logits tensor to prevent in-place modification side effects.
…necessary memory allocation and computation when entropy is not calculated
…nel to allow memory efficient calculation.
108a375 to
7fb2eca
Compare
|
@ETOgaosion , I've rebased my PR with the latest code and now some of the tests have passed, including the one that previously failed. However, the PR run was cancelled a few days ago and there hasn't been any further activity. Could you please re-review the PR when you have a chance? Thank you! |
…due to shared logits tensor in-place modification (volcengine#2484) ### What does this PR do? Fixes gradient computation conflict in `verl/workers/actor/megatron_actor.py` when entropy regularization is enabled: - **Root Cause**: The entropy calculation `entropy = vocab_parallel_entropy(logits)` fails during backward propagation because `log_probs = vocab_parallel_log_probs_from_logits(logits, label)` performs in-place modifications on the logits tensor earlier in the code. This corrupts the original computation graph needed for gradient calculation. - **Fix**: Decouples tensor dependencies by cloning logits before entropy calculation to preserve the original computation graph while maintaining existing log_probs computation flow. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test 1. Run modified training script: ```bash examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh \ --actor_rollout_ref.actor.entropy_coeff=0.01 ``` 2. The following error is observed (before repair): <img width="1396" height="605" alt="image" src="https://github.com/user-attachments/assets/0ed0f9f8-f4eb-41d3-9db8-c8f2163de910" /> ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
…due to shared logits tensor in-place modification (volcengine#2484) ### What does this PR do? Fixes gradient computation conflict in `verl/workers/actor/megatron_actor.py` when entropy regularization is enabled: - **Root Cause**: The entropy calculation `entropy = vocab_parallel_entropy(logits)` fails during backward propagation because `log_probs = vocab_parallel_log_probs_from_logits(logits, label)` performs in-place modifications on the logits tensor earlier in the code. This corrupts the original computation graph needed for gradient calculation. - **Fix**: Decouples tensor dependencies by cloning logits before entropy calculation to preserve the original computation graph while maintaining existing log_probs computation flow. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test 1. Run modified training script: ```bash examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh \ --actor_rollout_ref.actor.entropy_coeff=0.01 ``` 2. The following error is observed (before repair): <img width="1396" height="605" alt="image" src="https://github.com/user-attachments/assets/0ed0f9f8-f4eb-41d3-9db8-c8f2163de910" /> ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
* origin/mindspeed: (39 commits) [perf] feat: add optional role selection in discrete mode for NPU Profiler (volcengine#2750) [rollout] feat: remove chat scheduler (volcengine#2725) [trainer] refactor: Make sure to keep the type checking (volcengine#2634) [doc] style: change resize handle from gradient to plain color (volcengine#2746) [CI] feat: add `mypy` to pre-commit (volcengine#2614) [megatron] feat: a bunch of optimzation on vram, sequence packing (volcengine#2678) [docker] feat: upgrade to torch 2.7, sglang 0.4.8 (volcengine#2617) [doc] feat: add resizable sidebar and improve layout (volcengine#2577) [ci] fix: release ascend test time, fix one step off-policy CI (volcengine#2731) [recipe] chore: add retool training script (volcengine#2732) [ci] fix: checkpoint_convertor ci miss a hf model download (volcengine#2730) [doc] feat: Add agent-lightning in the list of "awesome works using verl (volcengine#2726) [tool] fix: geo3k create return str instead of tuple (volcengine#2714) [megatron] fix: resolve backward propagation error in megatron_actor due to shared logits tensor in-place modification (volcengine#2484) [misc] chore: bump main branch version to v0.5.0.dev (volcengine#2718) [sglang] fix: Adding strict naming sanity for sglang (volcengine#2719) [ray] feat: RayWorkerGroup support set worker env (volcengine#2685) [ci] test: add CriticWorker unit test, make some util CPU friendly (volcengine#2717) [cfg] refactor: add ActorConfig, EngineConfig, and ActorWorker unit test, refactor validation code (volcengine#2621) [misc] chore: bump version to v0.5.0 (volcengine#2716) ...
…due to shared logits tensor in-place modification (volcengine#2484) ### What does this PR do? Fixes gradient computation conflict in `verl/workers/actor/megatron_actor.py` when entropy regularization is enabled: - **Root Cause**: The entropy calculation `entropy = vocab_parallel_entropy(logits)` fails during backward propagation because `log_probs = vocab_parallel_log_probs_from_logits(logits, label)` performs in-place modifications on the logits tensor earlier in the code. This corrupts the original computation graph needed for gradient calculation. - **Fix**: Decouples tensor dependencies by cloning logits before entropy calculation to preserve the original computation graph while maintaining existing log_probs computation flow. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test 1. Run modified training script: ```bash examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh \ --actor_rollout_ref.actor.entropy_coeff=0.01 ``` 2. The following error is observed (before repair): <img width="1396" height="605" alt="image" src="https://github.com/user-attachments/assets/0ed0f9f8-f4eb-41d3-9db8-c8f2163de910" /> ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
…due to shared logits tensor in-place modification (volcengine#2484) ### What does this PR do? Fixes gradient computation conflict in `verl/workers/actor/megatron_actor.py` when entropy regularization is enabled: - **Root Cause**: The entropy calculation `entropy = vocab_parallel_entropy(logits)` fails during backward propagation because `log_probs = vocab_parallel_log_probs_from_logits(logits, label)` performs in-place modifications on the logits tensor earlier in the code. This corrupts the original computation graph needed for gradient calculation. - **Fix**: Decouples tensor dependencies by cloning logits before entropy calculation to preserve the original computation graph while maintaining existing log_probs computation flow. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test 1. Run modified training script: ```bash examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh \ --actor_rollout_ref.actor.entropy_coeff=0.01 ``` 2. The following error is observed (before repair): <img width="1396" height="605" alt="image" src="https://github.com/user-attachments/assets/0ed0f9f8-f4eb-41d3-9db8-c8f2163de910" /> ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
…due to shared logits tensor in-place modification (volcengine#2484) ### What does this PR do? Fixes gradient computation conflict in `verl/workers/actor/megatron_actor.py` when entropy regularization is enabled: - **Root Cause**: The entropy calculation `entropy = vocab_parallel_entropy(logits)` fails during backward propagation because `log_probs = vocab_parallel_log_probs_from_logits(logits, label)` performs in-place modifications on the logits tensor earlier in the code. This corrupts the original computation graph needed for gradient calculation. - **Fix**: Decouples tensor dependencies by cloning logits before entropy calculation to preserve the original computation graph while maintaining existing log_probs computation flow. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test 1. Run modified training script: ```bash examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh \ --actor_rollout_ref.actor.entropy_coeff=0.01 ``` 2. The following error is observed (before repair): <img width="1396" height="605" alt="image" src="https://github.com/user-attachments/assets/0ed0f9f8-f4eb-41d3-9db8-c8f2163de910" /> ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
What does this PR do?
Fixes gradient computation conflict in
verl/workers/actor/megatron_actor.pywhen entropy regularization is enabled:entropy = vocab_parallel_entropy(logits)fails during backward propagation becauselog_probs = vocab_parallel_log_probs_from_logits(logits, label)performs in-place modifications on the logits tensor earlier in the code. This corrupts the original computation graph needed for gradient calculation.Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace.