diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py
index 2d2073f293ed79..bcc64a50ae2187 100644
--- a/python/paddle/distributed/auto_parallel/constants.py
+++ b/python/paddle/distributed/auto_parallel/constants.py
@@ -78,6 +78,7 @@ def set_field_default_config(category, field, default_value):
 set_field_default_config(AMP, "custom_black_varnames", [])
 set_field_default_config(AMP, "use_fp16_guard", False)
 set_field_default_config(AMP, "use_bf16_guard", False)
+set_field_default_config(AMP, "use_master_grad", False)
 
 #########################################
 # sharding configuration
diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
index 73dd1de8508bf9..7f38ebb9f6bedd 100644
--- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
+++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
@@ -252,14 +252,15 @@ def _generate_optimizer(
         # but optimizer will be called repeatedly in re-launch, so optimizer need to be copied.
         # 2. lr_scheduler cannot be deepcopy, cause 'deepcopy' will lead to difference of learning_rate between executor and engine.
         learning_rate = optimizer._learning_rate
-        optimizer = copy.deepcopy(optimizer)
+        new_optimizer = copy.deepcopy(optimizer)
+        new_optimizer._learning_rate = learning_rate
+        new_optimizer._sorted = False
         self._dist_context._serial_optimizer = optimizer
         self._dist_context._serial_optimizer._learning_rate = learning_rate
-        optimizer._sorted = False
 
         with program_guard(main_program, startup_program):
             with main_program.switch_name_generator_guard("opt_"):
-                optimizer_ops = optimizer.apply_gradients(params_grads)
+                optimizer_ops = new_optimizer.apply_gradients(params_grads)
             self._completer.complete_update_annotation(main_program)
         return optimizer_ops
@@ -380,6 +381,21 @@ def _apply_post_optimization(
                 [main_program], [startup_program], self._pass_context
             )
 
+        # apply master grad pass
+        if self._strategy.amp.enable:
+            amp_config = copy.deepcopy(self._strategy.amp.to_dict())
+            config = {}
+            config["dist_context"] = self._dist_context
+            config["params_grads"] = params_grads
+            config["completer"] = self._completer
+            if amp_config['level'] == "o2" and amp_config["use_master_grad"]:
+                master_grad_pass = new_pass(
+                    "auto_parallel_master_grad_pass", config
+                )
+                master_grad_pass.apply(
+                    [main_program], [startup_program], self._pass_context
+                )
+
         # data parallel optimization
         if self._strategy.dp_optimization.enable:
             config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py
index 107fe74a569d0a..e78cc5bbd0081d 100644
--- a/python/paddle/distributed/passes/__init__.py
+++ b/python/paddle/distributed/passes/__init__.py
@@ -17,6 +17,7 @@
 from .auto_parallel_gradient_merge import *  # noqa: F403
 from .auto_parallel_sharding import *  # noqa: F403
 from .auto_parallel_amp import *  # noqa: F403
+from .auto_parallel_master_grad import *  # noqa: F403
 from .auto_parallel_fp16 import *  # noqa: F403
 from .auto_parallel_recompute import *  # noqa: F403
 from .auto_parallel_quantization import *  # noqa: F403
diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py
index 92259dee3ae057..cd29cbbacc2cef 100644
--- a/python/paddle/distributed/passes/auto_parallel_fp16.py
+++ b/python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -79,7 +79,7 @@ def set_auto_cast_attr(cast_op, block):
     ), f"in_var {in_name} or out_var {out_name} is None of cast op"
     if is_forward_op(cast_op):
         cast_op._set_attr('in_dtype', in_var.dtype)
-        cast_op._set_attr('out_dtype', out_var.dtype)
+        out_var.desc.set_dtype(paddle.dtype(cast_op.attr('out_dtype')))
     elif is_backward_op(cast_op):
         in_var_fw = block._find_var_recursive(in_name[: in_name.find("@")])
         out_var_fw = block._find_var_recursive(out_name[: out_name.find("@")])
diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
index c793639c5ba013..51a781b6f0f85b 100644
--- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
+++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
@@ -178,6 +178,7 @@ def _append_gradient_merge_backward_op(
         for out_name in op.desc.output_arg_names():
             if out_name in grad_to_params_grads:
                 param = grad_to_params_grads[out_name][0]
+                grad = grad_to_params_grads[out_name][1]
                 assert param is not None
                 ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                     param
@@ -188,8 +189,8 @@
                 # Add persistable gradient variables in main_program
                 gradient_merge_var = main_block.create_var(
                     name=param.name + "@GRAD@MERGE",
-                    shape=param.shape,
-                    dtype=param.dtype,
+                    shape=grad.shape,
+                    dtype=grad.dtype,
                     persistable=True,
                 )
                 ref_process_mesh = ref_dist_attr.process_mesh
@@ -205,8 +206,8 @@
                 # Add persistable gradient variables in startup_program
                 startup_gradient_merge_var = startup_block.create_var(
                     name=param.name + "@GRAD@MERGE",
-                    shape=param.shape,
-                    dtype=param.dtype,
+                    shape=grad.shape,
+                    dtype=grad.dtype,
                     persistable=True,
                 )
                 # Initial persistable gradient variables in startup_program
@@ -214,8 +215,8 @@
                     type="fill_constant",
                     outputs={"Out": startup_gradient_merge_var},
                     attrs={
-                        "shape": param.shape,
-                        "dtype": param.dtype,
+                        "shape": grad.shape,
+                        "dtype": grad.dtype,
                         "value": float(0),
                     },
                 )
diff --git a/python/paddle/distributed/passes/auto_parallel_master_grad.py b/python/paddle/distributed/passes/auto_parallel_master_grad.py
new file mode 100644
index 00000000000000..9d105acade045b
--- /dev/null
+++ b/python/paddle/distributed/passes/auto_parallel_master_grad.py
@@ -0,0 +1,239 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+from collections import OrderedDict
+from typing import List, Tuple
+
+from paddle.base import Variable
+from paddle.distributed.auto_parallel.static.utils import (
+    is_backward_op,
+    is_gradient_clip_op,
+    is_optimize_op,
+    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
+    set_var_dist_attr,
+)
+from paddle.distributed.fleet.meta_optimizers.common import (
+    OP_ROLE_KEY,
+    OpRole,
+)
+from paddle.framework import core
+from paddle.static import program_guard
+
+from ..utils.log_utils import get_logger
+from .auto_parallel_sharding import _supported_optimizer_type
+from .pass_base import PassBase, register_pass
+
+logger = get_logger(logging.INFO, "MasterGradPass")
+
+
+def get_output_in_varlist(op, var_names) -> List[str]:
+    grad_names = []
+    for output_name in op.output_arg_names:
+        if output_name in var_names:
+            grad_names.append(output_name)
+    return grad_names
+
+
+@register_pass("auto_parallel_master_grad_pass")
+class MasterGradPass(PassBase):
+    """
+    Replace each low-precision gradient in the optimizer with a high-precision ("master") gradient to avoid inf/nan values caused by low precision.
+    The master gradient is consumed by communication ops, `update_loss_scaling`, `GradClip`, and the `optimizer`.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def _check_self(self):
+        return True
+
+    def _check_conflict(self, other_pass):
+        return True
+
+    def _apply_single_impl(self, main_program, startup_program, context):
+        self._completer = self.get_attr("completer")
+        dist_context = self.get_attr("dist_context")
+        params_grads = self.get_attr("params_grads")
+        logger.debug(f"Origin main_program: {main_program}")
+        self._add_master_grad(main_program, params_grads, dist_context)
+        self._regenerate_optimizer(
+            main_program, startup_program, params_grads, dist_context
+        )
+        logger.debug(f"After main program: {main_program}")
+
+    def _add_cast_op(self, cur_block, grad_names: List[str], dist_context):
+        grad_first_ids = OrderedDict()
+        for idx, op in enumerate(cur_block.ops):
+            if is_optimize_op(op):
+                break
+            elif is_backward_op(op):
+                var_names = get_output_in_varlist(op, grad_names)
+                for var_name in var_names:
+                    if var_name not in grad_first_ids:
+                        grad_first_ids[var_name] = idx
+                    # Communication operators such as 'allreduce_sum' use input var as output.
+                    else:
+                        pass
+
+        # insert cast op
+        for grad_name, idx in reversed(grad_first_ids.items()):
+            grad_var = cur_block.var(grad_name)
+            if (
+                grad_var.dtype == core.VarDesc.VarType.FP16
+                or grad_var.dtype == core.VarDesc.VarType.BF16
+            ):
+                is_fp16 = grad_var.dtype == core.VarDesc.VarType.FP16
+                producer_op = cur_block.ops[idx]
+                producer_op_dist_attr = (
+                    dist_context.get_op_dist_attr_for_program(producer_op)
+                )
+                assert (
+                    producer_op_dist_attr is not None
+                ), f"The op: '{producer_op}' should be distributed"
+                ref_output_dist_attr = (
+                    producer_op_dist_attr.get_output_dist_attr(grad_name)
+                )
+                assert (
+                    ref_output_dist_attr is not None
+                ), f"The output: '{grad_name}' should be distributed"
+                ref_mesh = ref_output_dist_attr.process_mesh
+                ref_dims_mapping = ref_output_dist_attr.dims_mapping
+                ref_chunk_id = producer_op_dist_attr.chunk_id
+                grad_half_precision_name = (
+                    grad_name + '@tmp_fp16'
+                    if is_fp16
+                    else grad_name + '@tmp_bf16'
+                )
+                grad_half_precision = cur_block.create_var(
+                    name=grad_half_precision_name,
+                    dtype=grad_var.dtype,
+                    shape=grad_var.shape,
+                    persistable=False,
+                    stop_gradient=False,
+                )
+                set_var_dist_attr(
+                    dist_context,
+                    grad_half_precision,
+                    ref_dims_mapping,
+                    ref_mesh,
+                    chunk_id=ref_chunk_id,
+                )
+                producer_op._rename_output(grad_name, grad_half_precision.name)
+                grad_var.desc.set_dtype(core.VarDesc.VarType.FP32)
+                cast_op = cur_block._insert_op_without_sync(
+                    idx + 1,
+                    type="cast",
+                    inputs={"X": grad_half_precision},
+                    outputs={"Out": grad_var},
+                    attrs={
+                        "in_dtype": grad_half_precision.dtype,
+                        "out_dtype": grad_var.dtype,
+                    },
+                )
+                cast_op._set_attr(OP_ROLE_KEY, OpRole.Backward)
+                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+                    cast_op,
+                    ref_mesh,
+                    ref_dims_mapping,
+                    dist_context,
+                    chunk_id=ref_chunk_id,
+                )
+        cur_block._sync_with_cpp()
+
+    def _regenerate_optimizer(
+        self,
+        main_program,
+        startup_program,
+        params_grads: List[Tuple[Variable, Variable]],
+        dist_context,
+    ):
+        grad_names = [g.name for _, g in params_grads]
+        # 1. delete the origin optimizer op
+        # 1.1 delete the var and op associated with the optimizer op in main_program
+        main_ops = main_program.global_block().ops
+        main_ops_len = len(main_ops)
+        first_optimize_idx = main_ops_len
+        for idx, op in enumerate(main_ops):
+            # We don't delete the operators for check_nan_inf
+            if is_optimize_op(op) and is_gradient_clip_op(op):
+                first_optimize_idx = idx
+                break
+        assert (
+            first_optimize_idx < main_ops_len
+        ), "The first optimizer op is not found!"
+        deleted_temp_var_names = []
+        deleted_persist_var_names = []
+        reserved_var_names = []
+        for idx in range(main_ops_len - 1, first_optimize_idx - 1, -1):
+            op = main_ops[idx]
+            inout_arg_names = op.input_arg_names + op.output_arg_names
+            if op.type in _supported_optimizer_type:
+                param_names = op.input("Param")
+                skip_update_names = op.input("SkipUpdate")
+                for reserved_name in param_names + skip_update_names:
+                    if reserved_name not in reserved_var_names:
+                        reserved_var_names.append(reserved_name)
+            for input_name in inout_arg_names:
+                if input_name in grad_names:
+                    continue
+                var = main_program.global_block().var(input_name)
+                if (
+                    var.persistable
+                    and input_name not in deleted_persist_var_names
+                ):
+                    deleted_persist_var_names.append(input_name)
+                elif (
+                    not var.persistable
+                    and input_name not in deleted_temp_var_names
+                ):
+                    deleted_temp_var_names.append(input_name)
+            main_program.global_block()._remove_op(idx)
+
+        for var_name in deleted_temp_var_names + deleted_persist_var_names:
+            if var_name not in reserved_var_names:
+                main_program.global_block()._remove_var(var_name)
+        main_program.global_block()._sync_with_cpp()
+
+        # 1.2 delete the var and op in startup_program
+        for reserved_name in reserved_var_names:
+            if reserved_name in deleted_persist_var_names:
+                deleted_persist_var_names.remove(reserved_name)
+        startup_global_block = startup_program.global_block()
+        for var_name in deleted_persist_var_names:
+            if startup_global_block.has_var(var_name):
+                startup_global_block._remove_var(var_name)
+        for idx, op in reversed(list(enumerate(startup_global_block.ops))):
+            inout_arg_names = op.input_arg_names + op.output_arg_names
+            for var_name in inout_arg_names:
+                if var_name in deleted_persist_var_names:
+                    startup_program.global_block()._remove_op(idx)
+                    break
+
+        # 2. re-generate new optimizer op
+        serial_optimizer = copy.deepcopy(dist_context._serial_optimizer)
+        serial_optimizer._learning_rate = (
+            dist_context._serial_optimizer._learning_rate
+        )
+        serial_optimizer._sorted = False
+        with program_guard(main_program, startup_program):
+            with main_program.switch_name_generator_guard("opt_"):
+                _ = serial_optimizer.apply_gradients(params_grads)
+        self._completer.complete_update_annotation(main_program)
+
+    def _add_master_grad(self, main_program, params_grads, dist_context):
+        grad_names = [g.name for _, g in params_grads]
+        for sub_block in main_program.blocks:
+            self._add_cast_op(sub_block, grad_names, dist_context)
diff --git a/test/auto_parallel/amp_o2_pass.py b/test/auto_parallel/amp_o2_pass.py
index a770be6d1e4283..501d1f92cae658 100644
--- a/test/auto_parallel/amp_o2_pass.py
+++ b/test/auto_parallel/amp_o2_pass.py
@@ -39,7 +39,7 @@ def get_cuda_version():
         return -1
 
 
-def apply_pass(use_amp=False, amp_dtype="bfloat16"):
+def apply_pass(use_amp=False, use_master_grad=False, amp_dtype="bfloat16"):
     strategy = auto.Strategy()
     strategy.auto_mode = "semi"
     strategy.reinit = True
@@ -54,6 +54,8 @@ def apply_pass(use_amp=False, amp_dtype="bfloat16"):
             'elementwise_div',
             'reduce_sum',
         ]
+        if use_master_grad:
+            amp.use_master_grad = True
 
     return strategy
@@ -77,10 +79,12 @@ def init(self, engine):
         place = paddle.base.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
         engine._executor = paddle.static.Executor(place)
 
-    def get_engine(self, use_amp=False, amp_dtype="bfloat16"):
+    def get_engine(
+        self, use_amp=False, use_master_grad=False, amp_dtype="bfloat16"
+    ):
         reset_prog()
 
-        strategy = apply_pass(use_amp, amp_dtype)
+        strategy = apply_pass(use_amp, use_master_grad, amp_dtype)
         clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
         opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
         model, loss = generate_model("mp")
@@ -105,6 +109,23 @@ def check_bf16(self, program):
         self.assertEqual(num_fp16, 0)
         self.assertEqual(num_fp32, 10)
 
+    def check_fp16(self, program):
+        num_bf16 = 0
+        num_fp16 = 0
+        num_fp32 = 0
+
+        for p in program.all_parameters():
+            if p.dtype == core.VarDesc.VarType.FP32:
+                num_fp32 += 1
+            if p.dtype == core.VarDesc.VarType.FP16:
+                num_fp16 += 1
+            if p.dtype == core.VarDesc.VarType.BF16:
+                num_bf16 += 1
+
+        self.assertEqual(num_bf16, 0)
+        self.assertEqual(num_fp16, 26)
+        self.assertEqual(num_fp32, 10)
+
     def test_param_grad_fuse_overlap(self):
         # std
         mp_engine = self.get_engine(use_amp=False)
@@ -139,6 +160,39 @@ def test_param_grad_fuse_overlap(self):
 
         self.check_bf16(mp_bf16_engine.main_program)
 
+    def test_master_grad(self):
+        # fp16
+        mp_fp16_engine = self.get_engine(use_amp=True, amp_dtype="float16")
+        if not (paddle.amp.is_float16_supported()):
+            return
+
+        mp_fp16_history = mp_fp16_engine.fit(
+            self.dataset,
+            3,
+            epochs=1,
+            steps_per_epoch=self.batch_num,
+            log_freq=1,
+            batch_size=self.batch_size,
+        )
+        loss1 = mp_fp16_history.history['loss'][0]
+        self.check_fp16(mp_fp16_engine.main_program)
+        # fp16 + master_grad
+        mp_fp16_master_grad_engine = self.get_engine(
+            use_amp=True, use_master_grad=True, amp_dtype="float16"
+        )
+        mp_fp16_master_grad_history = mp_fp16_master_grad_engine.fit(
+            self.dataset,
+            3,
+            epochs=1,
+            steps_per_epoch=self.batch_num,
+            log_freq=1,
+            batch_size=self.batch_size,
+        )
+        loss2 = mp_fp16_master_grad_history.history['loss'][0]
+        np.testing.assert_allclose(loss1, loss2, atol=1e-3, rtol=1e-2)
+
+        self.check_fp16(mp_fp16_master_grad_engine.main_program)
+
 
 if __name__ == "__main__":
     unittest.main()
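
Usage note (not part of the patch): the sketch below shows how the new `use_master_grad` switch is expected to be enabled from user code. The strategy fields mirror those exercised in test/auto_parallel/amp_o2_pass.py above; the model, loss, optimizer, and dataset objects, as well as the commented `auto.Engine` wiring, are illustrative placeholders rather than part of this change.

    from paddle.distributed.fleet import auto

    # Build a semi-auto parallel strategy, as the amp_o2_pass.py test does.
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    amp = strategy.amp
    amp.enable = True
    amp.dtype = "float16"       # the pass is applied only for AMP level "o2"
    amp.level = "o2"
    amp.use_master_grad = True  # new switch added by this patch; defaults to False

    # Illustrative engine wiring, following the existing test:
    # engine = auto.Engine(model, loss, optimizer, strategy=strategy)
    # engine.fit(dataset, epochs=1, batch_size=2)

With this configuration, `_apply_post_optimization` registers and applies `auto_parallel_master_grad_pass`, which casts each fp16/bf16 gradient to an fp32 master gradient right after its producer op and then re-generates the optimizer ops against the fp32 gradients.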