[Unified Checkpoint] Checkpoint compression#9183
Conversation
self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
def quant_unified_optimizer(self, state_dict, state_dict_type, ckpt_quant_stage):
Suggest pulling this block out into its own file. I'm currently refactoring unified_checkpoint.py and will be splitting quite a bit of logic out of it.
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
if self.state.global_step != 0 and (self.state.global_step) % self.args.save_steps == 0:
shard_file,
tp_actions if pre_tensor_parallel_split else None,
expected_keys,
ckpt_quant_stage=model.config.ckpt_quant_stage,
Why does ckpt_quant_stage need to be passed in here? With the default of O0 it shouldn't need to be passed at all.
self._lock,
state_dict_type,
self.global_rank,
ckpt_quant_stage,
If only optimizer_weight needs compression, and the others such as model_weight and master_weight don't, then this variable doesn't need to be passed in at all.
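A minimal sketch of that guard, reusing the names quoted in this thread (the exact call site is an assumption, not the PR's code):

```python
# Hedged sketch: gate quantization on the state-dict type so the stage
# only ever applies to optimizer state; model_weight and master_weight
# save paths never see ckpt_quant_stage.
if state_dict_type == "optimizer_weight" and ckpt_quant_stage != "O0":
    state_dict = self.quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage)
```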
| if "skip_save_model_weight" in self.args.unified_checkpoint_config | ||
| else state_dict_type, | ||
| self.global_rank, | ||
| ckpt_quant_stage, |
lock,
state_dict_type,
global_rank,
ckpt_quant_stage,
Just make it an optional parameter, e.g. ckpt_quant_stage="O0".
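The suggested signature could look like this sketch (the body is an assumption for illustration, not the PR's implementation):

```python
def quant_unified_optimizer(self, state_dict, state_dict_type, ckpt_quant_stage="O0"):
    # "O0" is the no-quantization default, so existing callers can omit it
    if ckpt_quant_stage == "O0":
        return state_dict
    # non-"O0" stages would quantize moment1/moment2 etc. here (omitted)
    return state_dict
```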
path=os.path.join(save_directory, shard_file),
is_sync=is_sync_save,
state_dict_type="model_weight",
ckpt_quant_stage=model_to_save.config.ckpt_quant_stage,
path=os.path.join(output_dir, master_weights_name),
is_sync=is_sync_save,
state_dict_type="master_weight",
ckpt_quant_stage=model.config.ckpt_quant_stage,
path=os.path.join(save_directory, shard_master_weight_file),
is_sync=is_sync_save,
state_dict_type="master_weight",
ckpt_quant_stage=model.config.ckpt_quant_stage,
…nto ckpt-compress Conflicts: paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
return returned_state_dict

state_dict_optim = load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected_keys)
index = {}
new_name = static2struct_name_mappings[static_name] + "/" + type_name
optim_state_dict[new_name] = optim_state_dict.pop(key)

if UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value in self.args.unified_checkpoint_config:
The REMOVE_MASTER_WEIGHT check shouldn't be written inside this function; instead, ensure the master_weights passed into save_non_merge_optimizer is already None.
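Sketched at the call site, the suggestion might read as follows (the signature of save_non_merge_optimizer is an assumption; only the option check is quoted from the diff):

```python
# Hedged call-site sketch: the caller decides whether master weights
# survive, and the save function just serializes what it receives.
if UnifiedCheckpointOption.REMOVE_MASTER_WEIGHT.value in self.args.unified_checkpoint_config:
    master_weights = None  # dropped by the caller; the callee stays agnostic
self.save_non_merge_optimizer(model, optim_state_dict, master_weights, output_dir)
```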
return last_dtype

def dequant_unified_optimizer(self, state_dict, ckpt_quant_stage, scale_dict):
…nto ckpt-compress Conflicts: paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
| """ | ||
| quant = False | ||
| if ckpt_quant_stage != "O0": | ||
| quant = "optimizer" in checkpoint_file |
return last_dtype

def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
for key in keys:
-    if fliter_dict_keys is not None and key not in fliter_dict_keys:
+    # non merge ckpt loading dont have filter key.
+    if key.endswith(SYMMETRY_QUANT_SCALE) or (fliter_dict_keys is not None and key not in fliter_dict_keys):
Suggested change:
-    if key.endswith(SYMMETRY_QUANT_SCALE) or (fliter_dict_keys is not None and key not in fliter_dict_keys):
+    if key.endswith(SYMMETRY_QUANT_SCALE):
+        continue
+    if fliter_dict_keys is not None and key not in fliter_dict_keys:
+        continue
| MOMENT2_KEYNAME = "moment2_0" | ||
| BETA1_KEYNAME = "beta1_pow_acc_0" | ||
| BETA2_KEYNAME = "beta2_pow_acc_0" | ||
| SYMMETRY_QUANT_SCALE = "_codebook" |
)
},
)
ckpt_quant_stage: str = field(
Consider whether this should be configured inside unified_checkpoint_config, since it is only used together with UC.
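A hedged sketch of the field as quoted above (the default value and help text are assumptions; the thread suggests folding it into unified_checkpoint_config instead):

```python
from dataclasses import dataclass, field

@dataclass
class TrainingArguments:
    ckpt_quant_stage: str = field(
        default="O0",  # "O0" leaves checkpoints unquantized
        metadata={"help": "Checkpoint quantization stage, e.g. O0/O1/O2."},
    )
```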
@@ -0,0 +1,303 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
abs_max_values = np.where(
    abs_max_values == np.array(0, dtype=inputs.dtype), np.array(1e-8, dtype=inputs.dtype), abs_max_values
)
return abs_max_values
Doesn't using a bare 1e-8 here ignore the training dtype? bf16, float16, and float32 have quite different representable ranges.
In group-wise quantization a group can be all zeros, which would cause a divide-by-zero during quantization; the 1e-8 here is a small bias to guard against that.
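One dtype-aware variant of that epsilon, sketched under the assumption that NumPy dtype info is available (note vanilla NumPy has no bfloat16 dtype):

```python
import numpy as np

def safe_abs_max(abs_max_values, dtype):
    # Use the smallest positive normal number of the training dtype instead
    # of a hard-coded 1e-8, which underflows to 0 in float16 (whose smallest
    # normal value is about 6.1e-5).
    eps = np.finfo(dtype).tiny
    return np.where(abs_max_values == 0, np.array(eps, dtype=dtype), abs_max_values)
```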
import numpy as np
import paddle
Every important function needs a comment, and its args need documenting as well.
For the quantization algorithms referenced, add the arXiv links.
# channel-wise abs max calculation
def cal_abs_max_channel(inputs, quant_axis=1):
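A minimal self-contained version of that helper might look like this sketch (the reduction axes are an assumption for 2-D weights; it keeps the diff's 1e-8 guard, with the dtype caveat discussed above):

```python
import numpy as np

# channel-wise abs max: reduce over every axis except the quantization axis
def cal_abs_max_channel(inputs, quant_axis=1):
    reduce_axis = tuple(i for i in range(len(inputs.shape)) if i != quant_axis)
    abs_max_values = np.max(np.abs(inputs), axis=reduce_axis)
    # guard all-zero channels so the later division by the scale is safe
    return np.where(abs_max_values == 0, np.array(1e-8, dtype=inputs.dtype), abs_max_values)
```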
qdq_x = (
    quant_x
    / bnt
    * scales[rank * scales.shape[0] // world_size : (rank + 1) * scales.shape[0] // world_size]
This variable name is confusing: world_size normally means the total number of training cards, but here it denotes the size of the tensor-parallel communication group. Watch the naming.
A related question: it looks like quantization is applied to all parameters, but Norm parameters are not sharded across ranks; can they still be quantized this way?
if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]:
    qdq_x = (quant_x / bnt * scales) + mins
else:
    qdq_x = (
int4_high = np.where(int4_high > 8, int4_high - 16, int4_high)

high_tensor = paddle.Tensor(int4_high, zero_copy=True)
low_tensor = paddle.Tensor(int4_low, zero_copy=True)
cpu->gpu,已去除 zero_copy
m1_quant, codebook = qdq_weight(state_dict[m1_key], quant_bit=8)
quant_weight, mins, maxs = asymmetry_qdq_weight(ratio, quant_bit=8)
state_dict[m1_key] = m1_quant
codebook_dict[m1_key + SYMMETRY_QUANT_SCALE] = codebook
dist.all_reduce(quant_bits)

model_numel = all_bits / 4
all_bits = model_numel * 7.0
* checkpoint compression init * add ckpt quant argument * add ckpt quant ci * fix ci * fix lint * remove stage O2, change O3 --> O2 * support async save * file adjustment * magic string remove * ci fix * ci fix, code refinement * function extraction * fix ci * code refinement * fix ci * fix ci * support non merge tp ckpt quantization * fix ci * update * fix bug * code refactor * fix lint * fix ci * del old uc.py * fix lint * add mgpu ci * fix ci * multi thread loading * fix lint * fix bug * refactor code * add comment * fix lint * add comment * add comment * fix bug * fix bugs when ckpt no quant and no master weight * remove uni-test Conflicts: paddlenlp/transformers/model_utils.py
PR types
PR changes
Description
Implements checkpoint compression.
New arguments