# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from collections import OrderedDict
from typing import List, Tuple

from paddle.base import Variable
from paddle.distributed.auto_parallel.static.utils import (
    is_backward_op,
    is_gradient_clip_op,
    is_optimize_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import (
    OP_ROLE_KEY,
    OpRole,
)
from paddle.framework import core
from paddle.static import program_guard

from ..utils.log_utils import get_logger
from .auto_parallel_sharding import _supported_optimizer_type
from .pass_base import PassBase, register_pass

logger = get_logger(logging.INFO, "MasterGradPass")


def get_output_in_varlist(op, var_names) -> List[str]:
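    """Return the output names of `op` that also appear in `var_names`."""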
    grad_names = []
    for output_name in op.output_arg_names:
        if output_name in var_names:
            grad_names.append(output_name)
    return grad_names


@register_pass("auto_parallel_master_grad_pass")
class MasterGradPass(PassBase):
    """
    Replace each low-precision gradient in the optimizer with a high-precision
    gradient to avoid the inf/nan values that low precision can produce. The
    high-precision "master" gradient is then consumed by the communication
    operators, `update_loss_scaling`, `GradClip`, and the optimizer.
    """
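    # A rough usage sketch (not taken from this file; it assumes the standard
    # `paddle.distributed.passes` registry API and that `completer`,
    # `dist_context`, and `params_grads` are provided by the surrounding
    # auto-parallel engine):
    #
    #   from paddle.distributed.passes import PassContext, new_pass
    #
    #   attrs = {
    #       "completer": completer,
    #       "dist_context": dist_context,
    #       "params_grads": params_grads,
    #   }
    #   master_grad_pass = new_pass("auto_parallel_master_grad_pass", attrs)
    #   master_grad_pass.apply([main_program], [startup_program], PassContext())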

    def __init__(self):
        super().__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, context):
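        """
        Promote the low-precision gradients to FP32 master gradients, then
        regenerate the optimizer ops so they consume the promoted gradients.
        """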
        self._completer = self.get_attr("completer")
        dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")
        logger.debug(f"Original main_program: {main_program}")
        self._add_master_grad(main_program, params_grads, dist_context)
        self._regenerate_optimizer(
            main_program, startup_program, params_grads, dist_context
        )
        logger.debug(f"main_program after MasterGradPass: {main_program}")

    def _add_cast_op(self, cur_block, grad_names: List[str], dist_context):
        grad_first_ids = OrderedDict()
        for idx, op in enumerate(cur_block.ops):
            if is_optimize_op(op):
                break
            elif is_backward_op(op):
                var_names = get_output_in_varlist(op, grad_names)
                for var_name in var_names:
                    if var_name not in grad_first_ids:
                        grad_first_ids[var_name] = idx
                    # Ignore later writers: communication operators such as
                    # 'allreduce_sum' reuse the input var as their output, so
                    # only the first producer is recorded.
                    else:
                        pass

        # insert cast op
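        # For each low-precision gradient, rewrite its producer to write into a
        # temporary half-precision var, retype the original grad var to FP32,
        # and insert a cast right after the producer, e.g.:
        #
        #   before: matmul_grad -> x@GRAD (fp16)
        #   after:  matmul_grad -> x@GRAD@tmp_fp16 (fp16)
        #           cast        -> x@GRAD (fp32)
        #
        # (Illustrative op/var names only.)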
        for grad_name, idx in reversed(grad_first_ids.items()):
            grad_var = cur_block.var(grad_name)
            if (
                grad_var.dtype == core.VarDesc.VarType.FP16
                or grad_var.dtype == core.VarDesc.VarType.BF16
            ):
                is_fp16 = grad_var.dtype == core.VarDesc.VarType.FP16
                producer_op = cur_block.ops[idx]
                producer_op_dist_attr = (
                    dist_context.get_op_dist_attr_for_program(producer_op)
                )
                assert (
                    producer_op_dist_attr is not None
                ), f"The op: '{producer_op}' should be distributed"
                ref_output_dist_attr = (
                    producer_op_dist_attr.get_output_dist_attr(grad_name)
                )
                assert (
                    ref_output_dist_attr is not None
                ), f"The output: '{grad_name}' should be distributed"
                ref_mesh = ref_output_dist_attr.process_mesh
                ref_dims_mapping = ref_output_dist_attr.dims_mapping
                ref_chunk_id = producer_op_dist_attr.chunk_id
                grad_half_precision_name = (
                    grad_name + '@tmp_fp16'
                    if is_fp16
                    else grad_name + '@tmp_bf16'
                )
                grad_half_precision = cur_block.create_var(
                    name=grad_half_precision_name,
                    dtype=grad_var.dtype,
                    shape=grad_var.shape,
                    persistable=False,
                    stop_gradient=False,
                )
                set_var_dist_attr(
                    dist_context,
                    grad_half_precision,
                    ref_dims_mapping,
                    ref_mesh,
                    chunk_id=ref_chunk_id,
                )
                producer_op._rename_output(grad_name, grad_half_precision.name)
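                # NOTE: grad_half_precision was created above with the original
                # FP16/BF16 dtype, so the original grad var can now be promoted
                # to FP32 in place.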
                grad_var.desc.set_dtype(core.VarDesc.VarType.FP32)
                cast_op = cur_block._insert_op_without_sync(
                    idx + 1,
                    type="cast",
                    inputs={"X": grad_half_precision},
                    outputs={"Out": grad_var},
                    attrs={
                        "in_dtype": grad_half_precision.dtype,
                        "out_dtype": grad_var.dtype,
                    },
                )
                cast_op._set_attr(OP_ROLE_KEY, OpRole.Backward)
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    cast_op,
                    ref_mesh,
                    ref_dims_mapping,
                    dist_context,
                    chunk_id=ref_chunk_id,
                )
        cur_block._sync_with_cpp()

    def _regenerate_optimizer(
        self,
        main_program,
        startup_program,
        params_grads: List[Tuple[Variable, Variable]],
        dist_context,
    ):
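        """
        Delete the stale optimizer ops (everything from the first gradient-clip
        op onward) together with their non-gradient, non-parameter vars, then
        call apply_gradients again so new optimizer ops are generated against
        the FP32 master gradients.
        """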
        grad_names = [g.name for _, g in params_grads]
        # 1. delete the original optimizer ops
        # 1.1 delete the vars and ops associated with the optimizer ops in main_program
        main_ops = main_program.global_block().ops
        main_ops_len = len(main_ops)
        first_optimize_idx = main_ops_len
        for idx, op in enumerate(main_ops):
            # Don't delete the check_nan_inf operators, which appear before
            # the first gradient-clip op.
            if is_optimize_op(op) and is_gradient_clip_op(op):
                first_optimize_idx = idx
                break
        assert (
            first_optimize_idx < main_ops_len
        ), "The first optimizer op is not found!"
        deleted_temp_var_names = []
        deleted_persist_var_names = []
        reserved_var_names = []
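        # Walk the optimizer ops backwards, recording which vars they touch:
        # "Param" and "SkipUpdate" inputs are preserved, gradients are left
        # alone, and all other inputs/outputs are scheduled for removal along
        # with the op itself.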
        for idx in range(main_ops_len - 1, first_optimize_idx - 1, -1):
            op = main_ops[idx]
            inout_arg_names = op.input_arg_names + op.output_arg_names
            if op.type in _supported_optimizer_type:
                param_names = op.input("Param")
                skip_update_names = op.input("SkipUpdate")
                for reserved_name in param_names + skip_update_names:
                    if reserved_name not in reserved_var_names:
                        reserved_var_names.append(reserved_name)
            for input_name in inout_arg_names:
                if input_name in grad_names:
                    continue
                var = main_program.global_block().var(input_name)
                if (
                    var.persistable
                    and input_name not in deleted_persist_var_names
                ):
                    deleted_persist_var_names.append(input_name)
                elif (
                    not var.persistable
                    and input_name not in deleted_temp_var_names
                ):
                    deleted_temp_var_names.append(input_name)
            main_program.global_block()._remove_op(idx)

        for var_name in deleted_temp_var_names + deleted_persist_var_names:
            if var_name not in reserved_var_names:
                main_program.global_block()._remove_var(var_name)
        main_program.global_block()._sync_with_cpp()

        # 1.2 delete the vars and ops in startup_program
        for reserved_name in reserved_var_names:
            if reserved_name in deleted_persist_var_names:
                deleted_persist_var_names.remove(reserved_name)
        startup_global_block = startup_program.global_block()
        for var_name in deleted_persist_var_names:
            if startup_global_block.has_var(var_name):
                startup_global_block._remove_var(var_name)
        for idx, op in reversed(list(enumerate(startup_global_block.ops))):
            inout_arg_names = op.input_arg_names + op.output_arg_names
            for var_name in inout_arg_names:
                if var_name in deleted_persist_var_names:
                    startup_program.global_block()._remove_op(idx)
                    break

        # 2. re-generate the optimizer ops
        serial_optimizer = copy.deepcopy(dist_context._serial_optimizer)
        serial_optimizer._learning_rate = (
            dist_context._serial_optimizer._learning_rate
        )
        serial_optimizer._sorted = False
        with program_guard(main_program, startup_program):
            with main_program.switch_name_generator_guard("opt_"):
                _ = serial_optimizer.apply_gradients(params_grads)
        self._completer.complete_update_annotation(main_program)

    def _add_master_grad(self, main_program, params_grads, dist_context):
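        """
        Insert cast ops in every block so each FP16/BF16 gradient in
        `params_grads` is promoted to an FP32 master gradient.
        """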
        grad_names = [g.name for _, g in params_grads]
        for sub_block in main_program.blocks:
            self._add_cast_op(sub_block, grad_names, dist_context)