67 | 67 | FP32_MASTER, |
68 | 68 | UnifiedCheckpointOption, |
69 | 69 | filter_params, |
| 70 | + filter_sync_parameters, |
70 | 71 | gather_sharded_object, |
71 | 72 | generate_base_static_name, |
72 | 73 | get_expected_state_dict, |
@@ -218,25 +219,9 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp |
218 | 219 | for key in list(master_weights.keys()): |
219 | 220 | master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) |
220 | 221 |
221 | | - no_sync_kname = [] |
222 | | - model_state_dict = get_expected_state_dict(model) |
223 | | - for k, v in model_state_dict.items(): |
224 | | - if getattr(v, "no_sync", False): |
225 | | - no_sync_kname.append(k) |
226 | | - |
227 | | - hcg = fleet.get_hybrid_communicate_group() |
228 | | - dp_group = hcg.get_data_parallel_group() |
229 | | - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 |
230 | 222 | if self.args.use_expert_parallel: |
231 | | - for k in list(optim_state_dict.keys()): |
232 | | - model_k = k.split("/")[0] |
233 | | - if dp_rank > 0 and model_k not in no_sync_kname: |
234 | | - optim_state_dict.pop(k) |
235 | | - if master_weights is not None: |
236 | | - for k in list(master_weights.keys()): |
237 | | - model_k = k.split("/")[0] |
238 | | - if dp_rank > 0 and model_k not in no_sync_kname: |
239 | | - master_weights.pop(k) |
| 223 | + model_state_dict = get_expected_state_dict(model) |
| 224 | + filter_sync_parameters(model_state_dict, optim_state_dict, master_weights, is_model_weight=False) |
240 | 225 |
241 | 226 | optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) |
242 | 227 | master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix) |
@@ -518,12 +503,7 @@ def unified_checkpoint_into_shards( |
518 | 503 |
519 | 504 | if args.use_expert_parallel: |
520 | 505 | # ignore saving `no_sync=False` tensors when using expert_parallel under dp_rank > 0. |
521 | | - hcg = fleet.get_hybrid_communicate_group() |
522 | | - dp_group = hcg.get_data_parallel_group() |
523 | | - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 |
524 | | - for key in list(state_dict.keys()): |
525 | | - if dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): |
526 | | - state_dict.pop(key) |
| 506 | + filter_sync_parameters(state_dict, is_model_weight=True) |
527 | 507 |
528 | 508 | if config_to_save.tensor_parallel_degree > 1: |
529 | 509 | if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): |
@@ -631,25 +611,11 @@ def unified_optimizer_into_shards( |
631 | 611 | filter_master_keys = filter_params(model, master_weights, args, is_optimizer=True) |
632 | 612 | filter_optim_keys = filter_params(model, optim_state_dict, args, is_optimizer=True) |
633 | 613 |
634 | | - hcg = fleet.get_hybrid_communicate_group() |
635 | | - tp_group = hcg.get_model_parallel_group() |
636 | | - dp_group = hcg.get_data_parallel_group() |
| 614 | + tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() |
637 | 615 | tp_size = tp_group.nranks |
638 | | - dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 |
639 | 616 |
640 | 617 | if args.use_expert_parallel: |
641 | | - no_sync_kname = [] |
642 | | - for k, v in state_dict.items(): |
643 | | - if getattr(state_dict[k], "no_sync", False): |
644 | | - no_sync_kname.append(k) |
645 | | - for key in list(optim_state_dict.keys()): |
646 | | - model_key = key.split("/")[0] |
647 | | - if dp_rank > 0 and model_key not in no_sync_kname: |
648 | | - optim_state_dict.pop(key) |
649 | | - if master_weights is not None: |
650 | | - for key in list(master_weights.keys()): |
651 | | - if dp_rank > 0 and key not in no_sync_kname: |
652 | | - master_weights.pop(key) |
| 618 | + filter_sync_parameters(state_dict, optim_state_dict, master_weights, is_model_weight=False) |
653 | 619 |
654 | 620 | if tp_size > 1: |
655 | 621 | # get tp_actions |
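
All three hunks replace the same inline data-parallel filtering with a call to the newly imported `filter_sync_parameters` helper. The sketch below reconstructs what that helper plausibly does from the inline logic removed above; the name `filter_sync_parameters` and its two call signatures are taken from the diff, but the body is an assumption and may differ from the actual PaddleNLP implementation.

```python
# Hypothetical sketch of filter_sync_parameters, reconstructed from the inline
# logic removed in this diff. The real helper in PaddleNLP may differ.
from paddle.distributed import fleet


def filter_sync_parameters(model_state_dict, optim_state_dict=None, master_weights=None, is_model_weight=True):
    """Under expert parallel, dp_rank > 0 only keeps tensors marked `no_sync`
    (expert-parallel shards); everything synchronized across the data-parallel
    group is dropped so it is saved once, by dp_rank 0."""
    hcg = fleet.get_hybrid_communicate_group()
    dp_group = hcg.get_data_parallel_group()
    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
    if dp_rank == 0:
        return

    if is_model_weight:
        # Model weights: keep only tensors explicitly marked `no_sync`.
        for key in list(model_state_dict.keys()):
            if not getattr(model_state_dict[key], "no_sync", False):
                model_state_dict.pop(key)
    else:
        # Optimizer states / master weights: keep entries whose owning model
        # parameter is marked `no_sync` in the model state dict.
        no_sync_kname = [k for k, v in model_state_dict.items() if getattr(v, "no_sync", False)]
        for key in list(optim_state_dict.keys()):
            if key.split("/")[0] not in no_sync_kname:
                optim_state_dict.pop(key)
        if master_weights is not None:
            for key in list(master_weights.keys()):
                if key.split("/")[0] not in no_sync_kname:
                    master_weights.pop(key)
```

With a helper of this shape, the model-weight call site passes only the state dict (`filter_sync_parameters(state_dict, is_model_weight=True)`), while the optimizer call sites pass the model state dict plus the optimizer and master-weight dicts, matching the two usages introduced in the diff.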