-
Notifications
You must be signed in to change notification settings - Fork 3k
[Unified Checkpoint] Checkpoint compression #9183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
cd4e5e0
7684576
afcecad
d8f3351
434bd4c
a98fb8b
2e5c73b
6b1f3bf
c4a80e7
ae305a9
fd6ad57
f766d15
e74b68b
a7b053d
10b1064
ad1dc75
fb2c2e9
a1c35af
f8530c0
4e21fb9
a602fe5
55b8639
3a87734
a3073aa
8a8aca7
bab5235
c3c500d
a45c7f6
a4a3e23
2330839
3fcd471
f57aab5
50ee148
75a1011
ffd0823
4947a8c
3eaebbb
a6b2236
a5d0afa
fdd92a8
b2b20be
432e97c
5eb201c
b2bcf16
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -333,7 +333,10 @@ def from_pretrained( | |
| pre_tensor_parallel_split = True | ||
| tp_actions = prefix_model._get_tensor_parallel_convert_actions(is_split=True) | ||
| state_dict = load_state_dict( | ||
| shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys | ||
| shard_file, | ||
| tp_actions if pre_tensor_parallel_split else None, | ||
| expected_keys, | ||
| ckpt_quant_stage=model.config.ckpt_quant_stage, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 同上
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. done |
||
| ) | ||
| error_msgs += _load_state_dict_into_model(prefix_model.prefix_encoder, state_dict, "") | ||
| del state_dict | ||
|
|
||
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1179,6 +1179,17 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): | |
| self.state.epoch = epoch + (step + 1) / steps_in_epoch | ||
| self.control = self.callback_handler.on_step_end(args, self.state, self.control) | ||
| self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs) | ||
| if self.state.global_step != 0 and (self.state.global_step) % self.args.save_steps == 0: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这个地方具体是啥?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. done |
||
| # 假设你想要获取的环境变量名为 'MY_ENV_VAR' | ||
| env_var_name = "BREAK" | ||
|
|
||
| # 使用 os.getenv() 方法获取环境变量的值 | ||
| # 如果环境变量不存在,可以设置一个默认值 | ||
| env_var_value = os.getenv(env_var_name, "0") | ||
|
|
||
| print(f"环境变量 {env_var_name} 的值为: {env_var_value}") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 记得去掉 |
||
| if env_var_value == "1": | ||
| exit(0) | ||
| self._print_timer() | ||
| step_control = 0 | ||
| else: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1623,6 +1623,7 @@ def is_segment_parallel_supported(): | |
| if x not in [ | ||
| "skip_save_model_weight", | ||
| "master_weight_compatible", | ||
| "remove_master_weight", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 在840行那里加一下这个配置项的说明
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. done |
||
| "async_save", | ||
| "enable_all_options", | ||
| "ignore_merge_optimizer", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -269,6 +269,9 @@ class LlmMetaConfig: | |
| ), | ||
| ("recompute_use_reentrant", bool, False, "recompute_use_reentrant"), | ||
| ] | ||
| checkpoint_compression = [ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 新增修改建议:这个配置放到training_args.py里面,因为是和训练相关的配置。同时checkpoint保存的时候可以在optimizer.safetensors.index.json里面保留此信息。
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. done |
||
| ("ckpt_quant_stage", str, "O0", "checkpoint quantization stage."), | ||
| ] | ||
|
|
||
| @classmethod | ||
| def _get_defaults(cls): | ||
|
|
@@ -277,6 +280,7 @@ def _get_defaults(cls): | |
| cls.op_fusion_attributes, | ||
| cls.hybrid_parallel_attributes, | ||
| cls.recompute_attributes, | ||
| cls.checkpoint_compression, | ||
| ]: | ||
| for attr in attrs: | ||
| # return dict of key and default values | ||
|
|
@@ -290,6 +294,7 @@ def _get_all_meta(cls): | |
| cls.op_fusion_attributes, | ||
| cls.hybrid_parallel_attributes, | ||
| cls.recompute_attributes, | ||
| cls.checkpoint_compression, | ||
| ]: | ||
| for attr in attrs: | ||
| # return dict of key and default values | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,6 +42,7 @@ | |
| ) | ||
| from huggingface_hub.utils import EntryNotFoundError | ||
| from paddle import Tensor | ||
| from paddle.distributed import fleet | ||
| from paddle.distributed.fleet.meta_parallel.parallel_layers import ( | ||
| PipelineLayer, | ||
| SharedLayerDesc, | ||
|
|
@@ -53,8 +54,12 @@ | |
| from tqdm.auto import tqdm | ||
|
|
||
| from paddlenlp.utils.env import ( | ||
| ASYMMETRY_QUANT_SCALE_MAX, | ||
| ASYMMETRY_QUANT_SCALE_MIN, | ||
| CONFIG_NAME, | ||
| LEGACY_CONFIG_NAME, | ||
| MOMENT1_KEYNAME, | ||
| MOMENT2_KEYNAME, | ||
| PADDLE_WEIGHTS_INDEX_NAME, | ||
| PADDLE_WEIGHTS_NAME, | ||
| PYTORCH_WEIGHTS_INDEX_NAME, | ||
|
|
@@ -63,11 +68,18 @@ | |
| SAFE_PEFT_WEIGHTS_INDEX_NAME, | ||
| SAFE_WEIGHTS_INDEX_NAME, | ||
| SAFE_WEIGHTS_NAME, | ||
| SYMMETRY_QUANT_SCALE, | ||
| ) | ||
| from paddlenlp.utils.log import logger | ||
|
|
||
| from ..generation import GenerationConfig, GenerationMixin | ||
| from ..utils import device_guard | ||
| from ..utils.checkpoint_quantization_utils import ( | ||
| asymmetry_qdq_weight, | ||
| group_wise_quant_dequant, | ||
| qdq_weight, | ||
| split_int8, | ||
| ) | ||
| from ..utils.download import resolve_file_path | ||
| from .configuration_utils import PretrainedConfig | ||
| from .conversion_utils import ConversionMixin | ||
|
|
@@ -320,11 +332,19 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype: | |
|
|
||
|
|
||
| def load_state_dict( | ||
| checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu" | ||
| checkpoint_file: Union[str, os.PathLike], | ||
| tensor_parallel_split_mapping=None, | ||
| fliter_dict_keys=None, | ||
| device="cpu", | ||
| ckpt_quant_stage="O0", | ||
| ): | ||
| """ | ||
| Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise. | ||
| """ | ||
| quant = False | ||
| if ckpt_quant_stage != "O0": | ||
| quant = "optimizer" in checkpoint_file | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这个有点 hack了。 |
||
|
|
||
| if tensor_parallel_split_mapping is None: | ||
| tensor_parallel_split_mapping = {} | ||
|
|
||
|
|
@@ -344,6 +364,7 @@ def load_state_dict( | |
| raise ValueError("Currently unsupport paddle weights file, use numpy instead.") | ||
| if metadata.get("format", "np") == "np": | ||
| state_dict = {} | ||
| scale_dict = {} | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 不能直接加代码进来,需要结合下面多线程重新改。 |
||
| with safe_open(checkpoint_file, framework="np") as f: | ||
| for key in f.keys(): | ||
| if fliter_dict_keys is not None and key not in fliter_dict_keys: | ||
|
|
@@ -358,11 +379,108 @@ def load_state_dict( | |
| weight = paddle.Tensor(weight, zero_copy=True) | ||
| weight = weight._copy_to(paddle.framework._current_expected_place(), False) | ||
| state_dict[key] = weight | ||
| for key in f.keys(): | ||
| if key.endswith(SYMMETRY_QUANT_SCALE): | ||
| scale = f.get_tensor(key) | ||
| with device_guard(): | ||
| scale = paddle.Tensor(scale, zero_copy=True) | ||
| scale_dict[key] = scale | ||
|
|
||
| if device == "cpu": | ||
| for k in list(state_dict.keys()): | ||
| with device_guard(): | ||
| state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True) | ||
| if quant: | ||
|
ZHUI marked this conversation as resolved.
Outdated
ZHUI marked this conversation as resolved.
Outdated
|
||
| rank, world_size = -1, 1 | ||
| if paddle.distributed.get_world_size() > 1: | ||
| hcg = fleet.get_hybrid_communicate_group() | ||
| tp_group = hcg.get_model_parallel_group() | ||
| rank, world_size = tp_group.rank, tp_group.nranks | ||
|
|
||
| if ckpt_quant_stage == "O1": | ||
| # set eps | ||
| eps = 1e-8 | ||
| for quant_key in state_dict.keys(): | ||
| is_moment1 = MOMENT1_KEYNAME in quant_key | ||
| is_moment2 = MOMENT2_KEYNAME in quant_key | ||
| if is_moment1: | ||
| # dequant m1 | ||
| scale_key = quant_key + SYMMETRY_QUANT_SCALE | ||
| weight = state_dict[quant_key] | ||
| scales = scale_dict[scale_key] | ||
| weight, _ = qdq_weight( | ||
| weight, | ||
| scales=scales, | ||
| quant_bit=8, | ||
| dequant=True, | ||
| rank=rank, | ||
| world_size=world_size, | ||
| peek=True, | ||
| ) | ||
| state_dict[quant_key] = weight | ||
| elif is_moment2: | ||
| # dequant ratio | ||
| weight = state_dict[quant_key] | ||
| min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN | ||
| max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX | ||
| mins, maxs = scale_dict[min_scale_key], scale_dict[max_scale_key] | ||
| weight, _ = asymmetry_qdq_weight( | ||
| weight, | ||
| mins=mins, | ||
| maxs=maxs, | ||
| quant_bit=8, | ||
| dequant=True, | ||
| rank=rank, | ||
| world_size=world_size, | ||
| peek=True, | ||
| ) | ||
| # cal m2 | ||
| weight = paddle.square(1.0 / weight - eps) | ||
| state_dict[quant_key] = weight | ||
| elif ckpt_quant_stage == "O2": | ||
| # set eps | ||
| eps = 1e-8 | ||
| m1_state_dict = {} | ||
| for quant_key in state_dict.keys(): | ||
| if state_dict[quant_key].dtype != paddle.int8: | ||
| logger.info(f"{quant_key} skip.") | ||
| continue | ||
| # split int8 | ||
| weight = state_dict[quant_key] | ||
| m1_quant, ratio_quant = split_int8(weight.numpy()) | ||
| # dequant ratio | ||
| ratio_min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN | ||
| ratio_max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX | ||
| m1_scale_key = quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME + SYMMETRY_QUANT_SCALE | ||
| m1_codebook = scale_dict[m1_scale_key] | ||
| ratio_mins, ratio_maxs = scale_dict[ratio_min_scale_key], scale_dict[ratio_max_scale_key] | ||
| m1_weight = group_wise_quant_dequant( | ||
| m1_quant, | ||
| mins=m1_codebook, | ||
| maxs=None, | ||
| quant_bits=4, | ||
| quant=False, | ||
| rank=rank, | ||
| world_size=world_size, | ||
| use_pd=True, | ||
| symetry=True, | ||
| ) | ||
| ratio_weight = group_wise_quant_dequant( | ||
| ratio_quant, | ||
| mins=ratio_mins, | ||
| maxs=ratio_maxs, | ||
| quant_bits=4, | ||
| quant=False, | ||
| rank=rank, | ||
| world_size=world_size, | ||
| use_pd=True, | ||
| ) | ||
|
|
||
| ratio_weight = paddle.square(1.0 / ratio_weight - eps) | ||
| state_dict[quant_key] = ratio_weight | ||
| m1_state_dict[quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME] = m1_weight | ||
|
|
||
| state_dict.update(m1_state_dict) | ||
|
|
||
| return state_dict | ||
|
|
||
|
|
@@ -1965,7 +2083,10 @@ def _fuse_or_split_keys( | |
| filter_dict_keys = None | ||
|
|
||
| state_dict = load_state_dict( | ||
| shard_file, tp_actions if pre_tensor_parallel_split else None, filter_dict_keys | ||
|
ZHUI marked this conversation as resolved.
|
||
| shard_file, | ||
| tp_actions if pre_tensor_parallel_split else None, | ||
| filter_dict_keys, | ||
| ckpt_quant_stage=config.ckpt_quant_stage, | ||
| ) | ||
|
|
||
| # convert for fusing or splitting weights | ||
|
|
@@ -2288,9 +2409,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): | |
| with safe_open(resolved_archive_file, framework="np", device="cpu") as f: | ||
| loaded_keys = f.keys() | ||
| tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys) | ||
| state_dict = load_state_dict(resolved_archive_file, tp_actions) | ||
| state_dict = load_state_dict( | ||
| resolved_archive_file, tp_actions, ckpt_quant_stage=config.ckpt_quant_stage | ||
| ) | ||
| else: | ||
| state_dict = load_state_dict(resolved_archive_file) | ||
| state_dict = load_state_dict(resolved_archive_file, ckpt_quant_stage=config.ckpt_quant_stage) | ||
|
|
||
| logger.info("Loaded weights file from disk, setting weights to model.") | ||
|
|
||
|
|
@@ -2792,7 +2915,7 @@ def load_tp_checkpoint(folder, cls, config, return_numpy=False): | |
| with safe_open(safe_model_path, framework="np", device="cpu") as f: | ||
| loaded_keys = f.keys() | ||
| tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys) | ||
| state_dict = load_state_dict(safe_model_path, tp_actions) | ||
| state_dict = load_state_dict(safe_model_path, tp_actions, ckpt_quant_stage=config.ckpt_quant_stage) | ||
| else: # shard files safetensors | ||
| resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path( | ||
| pretrained_model_name_or_path=folder, | ||
|
|
@@ -2808,6 +2931,7 @@ def load_tp_checkpoint(folder, cls, config, return_numpy=False): | |
| shard_file, | ||
| tp_actions, | ||
| loaded_state_dict_keys, | ||
| ckpt_quant_stage=config.ckpt_quant_stage, | ||
| ) | ||
| state_dict.update(shard_state_dict) | ||
| if return_numpy: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块为啥需要传这个ckpt_quant_stage进来,默认O0的话就不用传吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done