-
Notifications
You must be signed in to change notification settings - Fork 2.5k
[misc] refactor: deprecate sharding manager (part 1) #2912
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[misc] refactor: deprecate sharding manager (part 1) #2912
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request refactors the fsdp_workers to deprecate the use of sharding_manager for data preprocessing and postprocessing, as part of a larger effort. The changes introduce a new dispatch mechanism using make_nd_compute_dataproto_dispatch_fn and _register_dispatch_collect_info to handle data distribution, which is a solid approach. The modifications are extensive and consistently applied across actor, critic, and reward model workers. Additionally, there are improvements to loss metric scaling in the actor and critic workers. My main feedback is to address a case of code duplication to improve maintainability.
verl/workers/fsdp_workers.py
Outdated
| is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0 | ||
| self._register_dispatch_collect_info( | ||
| "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block of code for registering dispatch information is identical to the one for the vllm rollout case (lines 537-540). To improve maintainability and avoid potential bugs from inconsistent updates, consider extracting this logic into a helper method or refactoring the conditional structure to avoid duplication.
|
/gemini review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request is a significant refactoring to deprecate the sharding_manager in favor of a new dispatch mechanism based on device_mesh. This is a positive change that should simplify the sharding logic. The changes are applied consistently across the affected files.
My review has identified two critical issues related to incorrect metric scaling in dp_actor.py and dp_critic.py. These bugs could lead to misleading metrics and affect model development and evaluation. I have provided detailed explanations and code suggestions to address these issues. The rest of the refactoring appears to be solid and well-executed.
verl/workers/actor/dp_actor.py
Outdated
| "actor/pg_loss": pg_loss.detach().item() | ||
| * (response_mask.shape[0] / self.config.ppo_mini_batch_size), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's an inconsistency in how the loss is scaled for backpropagation versus how the actor/pg_loss metric is scaled.
When self.config.use_dynamic_bsz is False, the loss for backward() is scaled by 1 / self.gradient_accumulation (line 459). However, the actor/pg_loss metric is always scaled by response_mask.shape[0] / self.config.ppo_mini_batch_size.
These two scaling factors are not equivalent when self.config.ppo_mini_batch_size is not perfectly divisible by self.config.ppo_micro_batch_size_per_gpu, because self.gradient_accumulation uses integer division (//).
For example, if ppo_mini_batch_size=10 and ppo_micro_batch_size_per_gpu=3:
self.gradient_accumulationwould be10 // 3 = 3.- The loss scaling for backpropagation would be
1 / 3. - The metric scaling would be
3 / 10.
This discrepancy will lead to misleading metrics for actor/pg_loss. The metric scaling should match the loss scaling used for backpropagation to ensure correctness.
A similar issue exists in verl/workers/critic/dp_critic.py for critic/vf_loss.
"actor/pg_loss": pg_loss.detach().item()
* (
(response_mask.shape[0] / self.config.ppo_mini_batch_size)
if self.config.use_dynamic_bsz
else (1.0 / self.gradient_accumulation)
)
verl/workers/critic/dp_critic.py
Outdated
| "critic/vf_loss": vf_loss.detach().item() | ||
| * (response_mask.shape[0] / self.config.ppo_mini_batch_size), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the issue in dp_actor.py, there's an inconsistency in how the critic loss is scaled for backpropagation versus how the critic/vf_loss metric is scaled.
When self.config.use_dynamic_bsz is False, the loss for backward() is scaled by 1 / self.gradient_accumulation (line 243). However, the critic/vf_loss metric is always scaled by response_mask.shape[0] / self.config.ppo_mini_batch_size.
This discrepancy, which arises when ppo_mini_batch_size is not perfectly divisible by ppo_micro_batch_size_per_gpu, will lead to misleading metrics for critic/vf_loss. The metric scaling should match the loss scaling used for backpropagation.
"critic/vf_loss": vf_loss.detach().item()
* (
(response_mask.shape[0] / self.config.ppo_mini_batch_size)
if self.config.use_dynamic_bsz
else (1.0 / self.gradient_accumulation)
)|
/gemini review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request refactors the sharding manager out of the FSDP workers, which is a significant and positive architectural change. The new dispatch mechanism using make_nd_compute_dataproto_dispatch_fn is consistently applied across the codebase.
My review has identified two main issues:
- A critical issue in
verl/single_controller/base/worker.pywhere dispatch information is stored in class attributes, which will cause conflicts when multiple workers are used in the same process. - A high-severity issue in
verl/workers/actor/dp_actor.pyregarding inconsistent scaling of logged loss metrics, which could lead to incorrect monitoring and analysis.
The rest of the changes appear to correctly implement the intended refactoring. Addressing these points will improve the robustness and correctness of the new implementation.
| if mesh_name in self.__dispatch_dp_rank or mesh_name in self.__collect_dp_rank: | ||
| raise ValueError(f"mesh_name {mesh_name} has been registered") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check is a good addition for safety. However, it highlights a more fundamental issue with the current implementation. __dispatch_dp_rank and __collect_dp_rank are defined as class attributes on the Worker class (lines 73-74), which means they are shared across all Worker instances within the same process.
This will lead to issues when multiple workers are instantiated in the same process (e.g., an actor worker and a critic worker), as they will attempt to write to the same shared dictionaries. This will either raise a ValueError due to this new check or, worse, lead to silent bugs from overwritten dispatch information.
These attributes should be instance-specific. The correct fix is to initialize them as instance attributes in Worker.__init__:
# In Worker.__init__
self.__dispatch_dp_rank = {}
self.__collect_dp_rank = {}And remove the class-level definitions. Since __init__ and the class attribute definitions are not part of this diff, I cannot suggest the change directly, but this is a critical issue that needs to be addressed to ensure correctness.
| micro_batch_metrics.update( | ||
| { | ||
| "actor/pg_loss": pg_loss.detach().item(), | ||
| "actor/pg_loss": pg_loss.detach().item() * loss_scale_factor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You've correctly scaled the pg_loss metric by loss_scale_factor to be consistent with the loss value used for backpropagation. This is a good improvement for metric correctness.
However, there's an inconsistency. If self.config.use_kl_loss is true, kl_loss is also a component of the total policy_loss, but it is logged without being scaled by loss_scale_factor (on line 452). For consistent and interpretable metrics, all reported loss components should be scaled in the same way as the total loss. Please consider scaling kl_loss as well when it's added to micro_batch_metrics.
|
/gemini review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request is a significant refactoring to deprecate the sharding_manager in favor of a more explicit, mesh-based dispatch mechanism. The changes are applied consistently across fsdp_workers and megatron_workers.
The most important change is a critical correctness fix in the Worker base class, where shared class attributes for dispatch information were replaced with instance attributes. This prevents state corruption when multiple worker groups are co-located.
Overall, the refactoring improves code clarity, maintainability, and correctness.
| loss = policy_loss * loss_scale_factor | ||
| else: | ||
| loss = policy_loss / self.gradient_accumulation | ||
| loss = policy_loss * loss_scale_factor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a question, policy_loss is mean loss of all tokens of micro batch samples, why we need loss_scale_factor here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It needs to divide gradient accumulation
### What does this PR do?
- Since we introduce register device_mesh inside the worker, there is no
need to use sharding manager any longer. We will remove the usage of
sharding manager gradually in the main branch.
- This PR removes the sharding manager usage inside fsdp_workers
### Checklist Before Starting
- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
- Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
### Test
> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.
### API and Usage Example
> Demonstrate how the API changes if any, and provide usage example(s)
if possible.
```python
# Add code snippet or script demonstrating how to use this
```
### Design & Code Changes
> Demonstrate the high-level design if this PR is complex, and list the
specific changes.
### Checklist Before Submitting
> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.
- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
### What does this PR do?
- Since we introduce register device_mesh inside the worker, there is no
need to use sharding manager any longer. We will remove the usage of
sharding manager gradually in the main branch.
- This PR removes the sharding manager usage inside fsdp_workers
### Checklist Before Starting
- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
- Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
### Test
> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.
### API and Usage Example
> Demonstrate how the API changes if any, and provide usage example(s)
if possible.
```python
# Add code snippet or script demonstrating how to use this
```
### Design & Code Changes
> Demonstrate the high-level design if this PR is complex, and list the
specific changes.
### Checklist Before Submitting
> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.
- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
### What does this PR do?
- Since we introduce register device_mesh inside the worker, there is no
need to use sharding manager any longer. We will remove the usage of
sharding manager gradually in the main branch.
- This PR removes the sharding manager usage inside fsdp_workers
### Checklist Before Starting
- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
- Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
### Test
> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.
### API and Usage Example
> Demonstrate how the API changes if any, and provide usage example(s)
if possible.
```python
# Add code snippet or script demonstrating how to use this
```
### Design & Code Changes
> Demonstrate the high-level design if this PR is complex, and list the
specific changes.
### Checklist Before Submitting
> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.
- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
What does this PR do?
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)