[misc] refactor: deprecate sharding manager (part 1) #2912
Changes from all commits: fa528ab, 0c2c1de, a4004c1, ef2920f, 5665405, 6bfda69, 29e13b9, d1cf8ad, fa9389b, 4667b6f, 87dcba2, 34c9752, 9f74633, f89d8d6
```diff
@@ -410,6 +410,11 @@ def update_policy(self, data: DataProto):
                 entropy_coeff = self.config.entropy_coeff
                 loss_agg_mode = self.config.loss_agg_mode

+                if self.config.use_dynamic_bsz:
+                    loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size
+                else:
+                    loss_scale_factor = 1 / self.gradient_accumulation
+
                 # all return: (bsz, response_length)
                 calculate_entropy = False
                 if entropy_coeff != 0:
```
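The new branch can be exercised on its own. A minimal sketch follows; the standalone helper name and signature are hypothetical (in the PR this logic lives inline in `update_policy`):

```python
def compute_loss_scale_factor(micro_batch_size: int,
                              use_dynamic_bsz: bool,
                              ppo_mini_batch_size: int,
                              gradient_accumulation: int) -> float:
    """Mirror of the branch added in the diff, extracted as a hypothetical helper.

    With a dynamic batch size, each micro-batch is weighted by its share of the
    mini-batch; otherwise every micro-batch contributes 1 / gradient_accumulation.
    """
    if use_dynamic_bsz:
        return micro_batch_size / ppo_mini_batch_size
    return 1 / gradient_accumulation

# A micro-batch of 4 responses out of a mini-batch of 16, or 4 accumulation steps:
print(compute_loss_scale_factor(4, True, 16, 4))   # 0.25
print(compute_loss_scale_factor(4, False, 16, 4))  # 0.25
```

Note that `response_mask.shape[0]` in the diff is the micro-batch size, so the two branches coincide whenever every micro-batch holds exactly `ppo_mini_batch_size / gradient_accumulation` responses.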
```diff
@@ -449,19 +454,19 @@ def update_policy(self, data: DataProto):
                     kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

                     policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
-                    micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item()
+                    micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() * loss_scale_factor
                     micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef

                 if self.config.use_dynamic_bsz:
-                    # relative to the dynamic bsz
-                    loss = policy_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size)
+                    loss = policy_loss * loss_scale_factor
                 else:
-                    loss = policy_loss / self.gradient_accumulation
+                    loss = policy_loss * loss_scale_factor
```
|
Collaborator: Just a question,

Author: It needs to be divided by gradient accumulation.
```diff
                 loss.backward()

                 micro_batch_metrics.update(
                     {
-                        "actor/pg_loss": pg_loss.detach().item(),
+                        "actor/pg_loss": pg_loss.detach().item() * loss_scale_factor,
```
|
Contributor: You've correctly scaled the … However, there's an inconsistency. If …
```diff
                         "actor/pg_clipfrac": pg_clipfrac.detach().item(),
                         "actor/ppo_kl": ppo_kl.detach().item(),
                         "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
```
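Scaling the detached metric by the same `loss_scale_factor` applied to the loss means that summing the logged values over all micro-batches reproduces a mini-batch-level aggregate. A toy illustration with hypothetical numbers:

```python
def aggregate_metric(per_micro_batch_values, loss_scale_factor):
    # Each micro-batch logs value * loss_scale_factor; summing the scaled
    # contributions recovers the per-mini-batch aggregate.
    return sum(v * loss_scale_factor for v in per_micro_batch_values)

pg_losses = [0.5, 0.7, 0.3, 0.9]   # hypothetical per-micro-batch pg_loss values
# With loss_scale_factor = 1 / num_micro_batches, the sum is the mini-batch mean:
print(aggregate_metric(pg_losses, 1 / len(pg_losses)))
```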
This check is a good addition for safety. However, it highlights a more fundamental issue with the current implementation.

`__dispatch_dp_rank` and `__collect_dp_rank` are defined as class attributes on the `Worker` class (lines 73-74), which means they are shared across all `Worker` instances within the same process. This will lead to issues when multiple workers are instantiated in the same process (e.g., an actor worker and a critic worker), as they will attempt to write to the same shared dictionaries. This will either raise a `ValueError` due to this new check or, worse, lead to silent bugs from overwritten dispatch information.

These attributes should be instance-specific. The correct fix is to initialize them as instance attributes in `Worker.__init__` and remove the class-level definitions. Since `__init__` and the class attribute definitions are not part of this diff, I cannot suggest the change directly, but this is a critical issue that needs to be addressed to ensure correctness.