Commit c0042fb

add filter_sync_parameters
1 parent 26e51d7 commit c0042fb

3 files changed: 35 additions & 41 deletions

paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py

Lines changed: 1 addition & 1 deletion
@@ -308,7 +308,7 @@ def load_resolved_archive_file(
             if model_state_dict[key_name[0]].dtype != paddle.float32:
                 key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
             else:
-                # for moe gate with float32 dtype.
+                # for parameters with float32 dtype, no need to have fp32 master weights.
                 key_name = "_".join([static_name, key_name[1]])
         else:
             key_name = "_".join([static_name, key_name[1]])
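For context, the branch this hunk touches chooses between two optimizer-state key layouts: low-precision parameters carry an extra FP32_MASTER segment pointing at their float32 master copy, while native-float32 parameters do not. A minimal sketch of that naming, where the FP32_MASTER value and the parameter names are illustrative assumptions, not read from PaddleNLP:

# Sketch of the key naming this hunk changes; FP32_MASTER's value and the
# parameter names below are illustrative assumptions.
FP32_MASTER = "fp32_master_0"

def optimizer_key(static_name, state_suffix, dtype):
    if dtype != "float32":
        # low-precision params keep a float32 master copy, so the key carries FP32_MASTER
        return "_".join([static_name, FP32_MASTER, state_suffix])
    # params already stored in float32 need no master weight, hence the shorter key
    return "_".join([static_name, state_suffix])

print(optimizer_key("linear_0.w_0", "moment1_0", "bfloat16"))  # linear_0.w_0_fp32_master_0_moment1_0
print(optimizer_key("linear_0.w_0", "moment1_0", "float32"))   # linear_0.w_0_moment1_0

The comment rewrite reflects that the shorter form applies to any float32 parameter, not only MoE gates.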

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

Lines changed: 6 additions & 40 deletions
@@ -67,6 +67,7 @@
     FP32_MASTER,
     UnifiedCheckpointOption,
     filter_params,
+    filter_sync_parameters,
     gather_sharded_object,
     generate_base_static_name,
     get_expected_state_dict,
@@ -218,25 +219,9 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
         for key in list(master_weights.keys()):
             master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)

-        no_sync_kname = []
-        model_state_dict = get_expected_state_dict(model)
-        for k, v in model_state_dict.items():
-            if getattr(v, "no_sync", False):
-                no_sync_kname.append(k)
-
-        hcg = fleet.get_hybrid_communicate_group()
-        dp_group = hcg.get_data_parallel_group()
-        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
         if self.args.use_expert_parallel:
-            for k in list(optim_state_dict.keys()):
-                model_k = k.split("/")[0]
-                if dp_rank > 0 and model_k not in no_sync_kname:
-                    optim_state_dict.pop(k)
-            if master_weights is not None:
-                for k in list(master_weights.keys()):
-                    model_k = k.split("/")[0]
-                    if dp_rank > 0 and model_k not in no_sync_kname:
-                        master_weights.pop(k)
+            model_state_dict = get_expected_state_dict(model)
+            filter_sync_parameters(model_state_dict, optim_state_dict, master_weights, is_model_weight=False)

         optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
         master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix)
@@ -518,12 +503,7 @@ def unified_checkpoint_into_shards(

     if args.use_expert_parallel:
         # ignore saving `no_sync=False` tensors when using expert_parallel under dp_rank > 0.
-        hcg = fleet.get_hybrid_communicate_group()
-        dp_group = hcg.get_data_parallel_group()
-        dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
-        for key in list(state_dict.keys()):
-            if dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
-                state_dict.pop(key)
+        filter_sync_parameters(state_dict, is_model_weight=True)

     if config_to_save.tensor_parallel_degree > 1:
         if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
@@ -631,25 +611,11 @@ def unified_optimizer_into_shards(
     filter_master_keys = filter_params(model, master_weights, args, is_optimizer=True)
     filter_optim_keys = filter_params(model, optim_state_dict, args, is_optimizer=True)

-    hcg = fleet.get_hybrid_communicate_group()
-    tp_group = hcg.get_model_parallel_group()
-    dp_group = hcg.get_data_parallel_group()
+    tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
     tp_size = tp_group.nranks
-    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0

     if args.use_expert_parallel:
-        no_sync_kname = []
-        for k, v in state_dict.items():
-            if getattr(state_dict[k], "no_sync", False):
-                no_sync_kname.append(k)
-        for key in list(optim_state_dict.keys()):
-            model_key = key.split("/")[0]
-            if dp_rank > 0 and model_key not in no_sync_kname:
-                optim_state_dict.pop(key)
-        if master_weights is not None:
-            for key in list(master_weights.keys()):
-                if dp_rank > 0 and key not in no_sync_kname:
-                    master_weights.pop(key)
+        filter_sync_parameters(state_dict, optim_state_dict, master_weights, is_model_weight=False)

     if tp_size > 1:
         # get tp_actions
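The optimizer-state call sites rewritten above lean on one convention: entries are keyed as "<parameter name>/<state suffix>", so splitting on "/" maps each optimizer entry back to its owning parameter. A tiny self-contained illustration (the key below is hypothetical):

# Optimizer-state keys take the form "<model_param>/<state_suffix>";
# split("/", 1) recovers the owning parameter name. The key is hypothetical.
key = "mlp.experts.0.weight/moment1_0"
model_key, state_suffix = key.split("/", 1)
assert model_key == "mlp.experts.0.weight"
assert state_suffix == "moment1_0"

This mapping is what lets filter_sync_parameters filter optimizer states and master weights against the model's no_sync parameter names.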

paddlenlp/trainer/unified_checkpoint/utils.py

Lines changed: 28 additions & 0 deletions
@@ -758,3 +758,31 @@ def save_config(model_to_save):
     # save generation config
     if model_to_save.can_generate():
         model_to_save.generation_config.save_pretrained(save_directory)
+
+
+def filter_sync_parameters(model_state_dict, optim_state_dict=None, master_weights=None, is_model_weight=True):
+    """Filter sync parameters under expert parallel mode."""
+
+    hcg = fleet.get_hybrid_communicate_group()
+    dp_group = hcg.get_data_parallel_group()
+    dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
+
+    if is_model_weight:
+        for key in list(model_state_dict.keys()):
+            if dp_rank > 0 and not getattr(model_state_dict[key], "no_sync", False):
+                model_state_dict.pop(key)
+    else:
+        no_sync_kname = []
+        for k, v in model_state_dict.items():
+            if getattr(v, "no_sync", False):
+                no_sync_kname.append(k)
+
+        for key in list(optim_state_dict.keys()):
+            model_key = key.split("/")[0]
+            if dp_rank > 0 and model_key not in no_sync_kname:
+                optim_state_dict.pop(key)
+
+        if master_weights is not None:
+            for key in list(master_weights.keys()):
+                if dp_rank > 0 and key not in no_sync_kname:
+                    master_weights.pop(key)
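To see what the new helper keeps and drops, here is a self-contained toy rerun of its filtering logic, with the fleet rank lookup replaced by an explicit dp_rank argument; all names and values are illustrative:

# Toy rerun of filter_sync_parameters' logic without paddle/fleet.
# "no_sync" marks expert-parallel-local tensors that every dp rank must save
# itself; everything else is replicated, so only dp_rank 0 keeps it.
from types import SimpleNamespace

def filter_sync_toy(model_sd, optim_sd=None, master=None, is_model_weight=True, dp_rank=1):
    if is_model_weight:
        for key in list(model_sd.keys()):
            if dp_rank > 0 and not getattr(model_sd[key], "no_sync", False):
                model_sd.pop(key)
    else:
        no_sync_kname = [k for k, v in model_sd.items() if getattr(v, "no_sync", False)]
        for key in list(optim_sd.keys()):
            if dp_rank > 0 and key.split("/")[0] not in no_sync_kname:
                optim_sd.pop(key)
        if master is not None:
            for key in list(master.keys()):
                if dp_rank > 0 and key not in no_sync_kname:
                    master.pop(key)

model_sd = {"expert.w": SimpleNamespace(no_sync=True),   # unique per dp rank
            "shared.w": SimpleNamespace(no_sync=False)}  # replicated across dp ranks
optim_sd = {"expert.w/moment1_0": 1, "shared.w/moment1_0": 2}
master = {"expert.w": 3, "shared.w": 4}

filter_sync_toy(model_sd, optim_sd, master, is_model_weight=False, dp_rank=1)
print(sorted(optim_sd))  # ['expert.w/moment1_0']
print(sorted(master))    # ['expert.w']

On dp_rank 0 nothing is dropped; on higher dp ranks only the expert-local (no_sync) tensors survive, matching the duplicated blocks this commit removes from unified_checkpoint.py.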
