-
Notifications
You must be signed in to change notification settings - Fork 3k
[Unified Checkpoint] Fix expert parallel #9821
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
37f3be1
8bed006
4a58f61
26e51d7
c0042fb
7ec8bdc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -516,6 +516,15 @@ def unified_checkpoint_into_shards( | |
|
|
||
| config_to_save = copy.deepcopy(model_to_save.config) | ||
|
|
||
| if args.use_expert_parallel: | ||
| # ignore saving `no_sync=False` tensors when using expert_parallel under dp_rank > 0. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
之前的写法主要问题是?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 之前的写法把删除no_sync=False的逻辑只放在了 merge_tp之类的函数里面,所以只有开启TP训练的时候才会不冗余保存。这次修复了这块,把这个删除逻辑挪到前面,从而即使不开TP,也可以实现不冗余保存。 |
||
| hcg = fleet.get_hybrid_communicate_group() | ||
| dp_group = hcg.get_data_parallel_group() | ||
| dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 | ||
| for key in list(state_dict.keys()): | ||
| if dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): | ||
| state_dict.pop(key) | ||
|
|
||
| if config_to_save.tensor_parallel_degree > 1: | ||
| if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM): | ||
| tp_actions = model_to_save._get_tensor_parallel_convert_actions( | ||
|
|
@@ -622,8 +631,25 @@ def unified_optimizer_into_shards( | |
| filter_master_keys = filter_params(model, master_weights, args, is_optimizer=True) | ||
| filter_optim_keys = filter_params(model, optim_state_dict, args, is_optimizer=True) | ||
|
|
||
| tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() | ||
| hcg = fleet.get_hybrid_communicate_group() | ||
| tp_group = hcg.get_model_parallel_group() | ||
| dp_group = hcg.get_data_parallel_group() | ||
| tp_size = tp_group.nranks | ||
| dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 | ||
|
|
||
| if args.use_expert_parallel: | ||
| no_sync_kname = [] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 加注释。为什么好几处出现了 filter 参数操作,可以看看能否集中起来处理。
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 嗯嗯,统一到一个函数里面了 |
||
| for k, v in state_dict.items(): | ||
| if getattr(state_dict[k], "no_sync", False): | ||
| no_sync_kname.append(k) | ||
| for key in list(optim_state_dict.keys()): | ||
| model_key = key.split("/")[0] | ||
| if dp_rank > 0 and model_key not in no_sync_kname: | ||
| optim_state_dict.pop(key) | ||
| if master_weights is not None: | ||
| for key in list(master_weights.keys()): | ||
| if dp_rank > 0 and key not in no_sync_kname: | ||
| master_weights.pop(key) | ||
|
|
||
| if tp_size > 1: | ||
| # get tp_actions | ||
|
|
@@ -643,7 +669,6 @@ def unified_optimizer_into_shards( | |
| optim_state_dict, | ||
| tp_actions, | ||
| filter_optim_keys, | ||
| state_dict if args.use_expert_parallel else None, | ||
| ) | ||
| paddle.device.cuda.empty_cache() | ||
|
|
||
|
|
@@ -653,7 +678,6 @@ def unified_optimizer_into_shards( | |
| master_weights, | ||
| tp_actions, | ||
| filter_master_keys, | ||
| state_dict if args.use_expert_parallel else None, | ||
| ) | ||
| paddle.device.cuda.empty_cache() | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,176 @@ | ||
| # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import os | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from paddlenlp.utils.downloader import get_path_from_url_with_filelock | ||
| from tests.parallel_launch import TestMultipleGpus | ||
| from tests.testing_utils import require_paddle_at_least_8_gpu, skip_for_none_ce_case | ||
| from tests.trainer.test_unified_checkpoint import remove_ckpt, remove_logs | ||
| from tests.trainer.trainer_utils import get_pretrain_arguments | ||
|
|
||
# Process environment applied by the test ``setUp`` via ``os.environ.update``.
# Every value is a string because it is written straight into ``os.environ``.
# The key spellings (including the mixed ``FLAGS_``/``Flags_`` prefixes) are
# runtime strings and are kept exactly as the consumers expect them.
environment_variables = dict(
    NCCL_ALGO="Tree",
    NVIDIA_TF32_OVERRIDE="0",
    NCCL_IB_TIMEOUT="22",
    NCCL_DEBUG="INFO",
    FLAGS_embedding_deterministic="1",
    FLAGS_cudnn_deterministic="1",
    Flags_mp_aysnc_allreduce="1",
    Flags_skip_mp_c_identity="1",
    FLAGS_shard_norm_align_dp="0",
    FLAGS_shard_use_reduce="1",
    test_ci_no_save_model="1",
)
|
|
||
# Baseline arguments for the Qwen2-MoE unified-checkpoint runs. The test
# ``setUp`` hands this mapping to ``get_pretrain_arguments`` to derive the
# per-parallelism configurations (e.g. "TP4DP2", "TP2Sharding4").
moe_arguments = {
    # --- model / data / output locations ---
    "model_name_or_path": "./tests/trainer/unified-ckpt-qwen2moe",
    "dataset_name_or_path": "./unified_checkpoint/peft_input/data/",
    "output_dir": "./unified_checkpoint/checkpoints/qwen2moe_sft_ckpts",
    # --- batch sizes and accumulation ---
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "per_device_eval_batch_size": 8,
    "eval_accumulation_steps": 16,
    # --- schedule: 10 steps in total, checkpoint written every 6 steps ---
    "learning_rate": 3e-4,
    "max_steps": 10,
    "save_steps": 6,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "no",
    "save_strategy": "steps",
    # --- sequence lengths ---
    "src_length": 1024,
    "max_length": 2048,
    # --- precision ---
    "bf16": "true",
    "fp16_opt_level": "O2",
    # --- run mode ---
    "do_train": "true",
    "do_eval": "false",
    "disable_tqdm": "true",
    "eval_with_do_generation": "false",
    "recompute": "true",
    "recompute_granularity": "full",
    "save_total_limit": 1,
    # --- parallelism defaults ---
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "sharding": "",
    # --- feature switches ---
    "lora": "false",
    "zero_padding": "false",
    "use_flash_attention": "false",
    "unified_checkpoint": 1,
    "continue_training": 0,
    "sequence_parallel": 0,
}
|
|
||
|
|
||
def check_acc(log_dir="log"):
    """Collect the values logged at ``global_step: 10`` from the first worker log.

    Reads ``<log_dir>/workerlog.n0.c0`` and, for every line that contains the
    text ``global_step: 10``, takes the second comma-separated field of the
    line and converts the sixth whitespace-separated token of that field to a
    float. NOTE(review): with a ``[date time,ms] [ INFO] - loss: X, ...`` log
    layout that token is the loss value -- confirm against the real log format.

    The parsing is done in Python rather than through a ``grep | awk | awk``
    shell pipeline, so the path is never interpolated into a shell command
    (directories containing spaces or shell metacharacters work) and no
    external tools are required.

    Args:
        log_dir: Directory that contains ``workerlog.n0.c0``.

    Returns:
        list[float]: One value per matching line, in file order. Matching
        lines that lack the second field or the sixth token are skipped.

    Raises:
        FileNotFoundError: If the log file does not exist.
        ValueError: If the selected token is not a valid float.
    """
    file_path = os.path.join(log_dir, "workerlog.n0.c0")
    marker = "global_step: 10"

    values = []
    # errors="replace": tolerate stray non-text bytes in the log instead of
    # failing on decode. newline="\n": split records on "\n" only.
    with open(file_path, encoding="utf-8", errors="replace", newline="\n") as log_file:
        for line in log_file:
            if marker not in line:
                continue
            comma_fields = line.split(",")
            if len(comma_fields) < 2:
                continue
            tokens = comma_fields[1].split()
            if len(tokens) < 6:
                continue
            values.append(float(tokens[5]))

    return values
|
|
||
|
|
||
# Fixed seed and the module-level NumPy generator built from it.
seed = 2024
rng = np.random.default_rng(seed)
|
|
||
|
|
||
@pytest.mark.xdist_group(name="UC")
class TestUnifiedCheckpointBase(TestMultipleGpus):
    """Base scenarios for the unified-checkpoint runs with the MoE arguments.

    Every test clears old logs and checkpoints, launches ``self.run_file``
    twice with the same arguments (``runfirst`` then ``rerun``) and, when
    ``need_allclose`` is set, compares the two values ``check_acc`` extracts
    from the worker log.

    The base class itself is not collected as a test; every subclass is.
    """

    # pytest reads ``__test__`` to decide whether to collect a class. This
    # used to be a ``@classmethod`` stacked on ``@property``, a combination
    # deprecated in Python 3.11 and removed in 3.13. A plain attribute plus
    # ``__init_subclass__`` gives the same result on every Python 3 version.
    __test__ = False

    def __init_subclass__(cls, **kwargs):
        """Mark each subclass as collectable unless it sets ``__test__`` itself."""
        super().__init_subclass__(**kwargs)
        if "__test__" not in cls.__dict__:
            cls.__test__ = True

    def setUp(self):
        """
        1. update runfirst and rerun to run defined different config
        2. update need_allclose to True if you want to check the result
        3. update rtol to the relative value you want to check
        """

        self.configs = get_pretrain_arguments(moe_arguments)
        os.environ.update(environment_variables)

        # Fetch the dataset archive only when it is not already on disk.
        file_ = "https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz"
        input_dir = "unified_checkpoint/peft_input/"
        os.makedirs(input_dir, exist_ok=True)
        file_path = os.path.join(input_dir, "AdvertiseGen.tar.gz")
        if not os.path.exists(file_path):
            get_path_from_url_with_filelock(file_, root_dir=input_dir)

        self.need_allclose = True
        self.rtol = 1e-7

        self.run_file = "llm/run_finetune.py"

    def runfirst(self, train_args):
        """Launch the first run on one node with 8 cards."""
        self.run_n1c8(self.run_file, **train_args)

    def rerun(self, train_args):
        """Launch the second run with the same launcher and arguments."""
        self.run_n1c8(self.run_file, **train_args)

    def _run_twice_and_compare(self, config_name):
        """Run ``config_name`` twice from a clean state and compare the logged values."""
        remove_logs()
        remove_ckpt(moe_arguments["output_dir"])

        train_args = self.configs[config_name]
        self.runfirst(train_args)
        self.rerun(train_args)

        if self.need_allclose:
            res = check_acc()
            # One value is expected from each of the two runs.
            assert len(res) == 2
            np.testing.assert_allclose(res[0], res[1], rtol=self.rtol)

    @require_paddle_at_least_8_gpu
    def testTP4DP2(self):
        self._run_twice_and_compare("TP4DP2")

    @skip_for_none_ce_case
    @require_paddle_at_least_8_gpu
    def testTP2Sharding4(self):
        self._run_twice_and_compare("TP2Sharding4")
|
|
||
|
|
||
@pytest.mark.xdist_group(name="UC")
class TestUnifiedCheckpointFull(TestUnifiedCheckpointBase):
    """Extra scenario on top of the base class: TP2Sharding4 with split_param sharding."""

    @skip_for_none_ce_case
    @require_paddle_at_least_8_gpu
    def testTP2Sharding4V2(self):
        """Run TP2Sharding4 twice with split_param + master grad and compare the logged values."""
        # Start from a clean state so only this test's runs end up in the log.
        remove_logs()
        remove_ckpt(moe_arguments["output_dir"])

        train_args = self.configs["TP2Sharding4"]
        train_args.update(
            {
                "sharding_parallel_config": "split_param",
                "amp_master_grad": True,
            }
        )
        self.runfirst(train_args)
        self.rerun(train_args)

        if not self.need_allclose:
            return

        logged = check_acc()
        # One value is expected from each of the two runs.
        assert len(logged) == 2
        first, second = logged[0], logged[1]
        np.testing.assert_allclose(first, second, self.rtol)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| { | ||
| "architectures": [ | ||
| "Qwen2MoeForCausalLM" | ||
| ], | ||
| "attention_dropout": 0.0, | ||
| "bos_token_id": 151643, | ||
| "decoder_sparse_step": 1, | ||
| "eos_token_id": 151643, | ||
| "hidden_act": "silu", | ||
| "hidden_size": 3584, | ||
| "initializer_range": 0.02, | ||
| "intermediate_size": 18944, | ||
| "max_position_embeddings": 131072, | ||
| "max_window_layers": 28, | ||
| "model_type": "qwen2_moe", | ||
| "moe_intermediate_size": 2560, | ||
| "norm_topk_prob": false, | ||
| "num_attention_heads": 28, | ||
| "num_experts": 8, | ||
| "num_experts_per_tok": 2, | ||
| "num_hidden_layers": 8, | ||
| "num_key_value_heads": 4, | ||
| "output_router_logits": false, | ||
| "rms_norm_eps": 1e-06, | ||
| "rope_theta": 1000000.0, | ||
| "router_aux_loss_coef": 0.001, | ||
| "shared_expert_intermediate_size": 20480, | ||
| "sliding_window": 131072, | ||
| "tie_word_embeddings": false, | ||
| "dtype": "bfloat16", | ||
| "use_cache": true, | ||
| "use_sliding_window": false, | ||
| "vocab_size": 151936 | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| { | ||
| "bos_token_id": 151643, | ||
| "pad_token_id": 151643, | ||
| "eos_token_id": [ | ||
| 151645, | ||
| 151643 | ||
| ] | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里应该是 fp32 bf16 混合的情况对吧,注释一下,fp32参数,不需要master weight 即可。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done