【AutoParallel】Add master grad in AMP-O2 of AutoParallel #59987
Merged
Commits (16):
- ac603bb add master_grad in auto-parallel (heavyrain-lzy)
- bf1831c merge develop (heavyrain-lzy)
- 9d05ba4 Merge remote-tracking branch 'upstream/develop' into add_master_grad (heavyrain-lzy)
- db1ce6d reset third_party (heavyrain-lzy)
- 4dd8e9f fix coverage (heavyrain-lzy)
- 7e84af8 Merge remote-tracking branch 'upstream/develop' into add_master_grad (heavyrain-lzy)
- 4aac829 Merge remote-tracking branch 'upstream/develop' into add_master_grad (heavyrain-lzy)
- 0548062 support bf16 master_grad (heavyrain-lzy)
- 8c7e198 Merge remote-tracking branch 'upstream/develop' into add_master_grad (heavyrain-lzy)
- de188bc Merge remote-tracking branch 'upstream/develop' into add_master_grad (heavyrain-lzy)
- ad871a6 fix bug in master_grad (heavyrain-lzy)
- 6eb307f Merge remote-tracking branch 'upstream/develop' into add_master_grad (heavyrain-lzy)
- 1d5a936 change code according to review (heavyrain-lzy)
- 1d0567f Merge remote-tracking branch 'upstream/develop' into add_master_grad (heavyrain-lzy)
- 4c96caf Merge remote-tracking branch 'upstream/develop' into add_master_grad (heavyrain-lzy)
- 21a48b0 change the way to find optimizer op (heavyrain-lzy)
`python/paddle/distributed/passes/auto_parallel_master_grad.py` (239 additions, 0 deletions)
```python
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from collections import OrderedDict
from typing import List, Tuple

from paddle.base import Variable
from paddle.distributed.auto_parallel.static.utils import (
    is_backward_op,
    is_gradient_clip_op,
    is_optimize_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import (
    OP_ROLE_KEY,
    OpRole,
)
from paddle.framework import core
from paddle.static import program_guard

from ..utils.log_utils import get_logger
from .auto_parallel_sharding import _supported_optimizer_type
from .pass_base import PassBase, register_pass

logger = get_logger(logging.INFO, "MasterGradPass")


def get_output_in_varlist(op, var_names) -> List[str]:
    grad_names = []
    for output_name in op.output_arg_names:
        if output_name in var_names:
            grad_names.append(output_name)
    return grad_names


@register_pass("auto_parallel_master_grad_pass")
class MasterGradPass(PassBase):
    """
    Use the high precision gradient to replace the low precision gradient in optimizer to avoid inf/nan values of low precision.
    The high precision gradient 'master grad' will be used by communication operator, `update_loss_scaling`, `GradClip` and `optimizer`.
    """

    def __init__(self):
        super().__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, context):
        self._completer = self.get_attr("completer")
        dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")
        logger.debug(f"Origin main_program: {main_program}")
        self._add_master_grad(main_program, params_grads, dist_context)
        self._regenerate_optimizer(
            main_program, startup_program, params_grads, dist_context
        )
        logger.debug(f"After main program: {main_program}")

    def _add_cast_op(self, cur_block, grad_names: List[str], dist_context):
        grad_first_ids = OrderedDict()
        for idx, op in enumerate(cur_block.ops):
            if is_optimize_op(op):
                break
            elif is_backward_op(op):
                var_names = get_output_in_varlist(op, grad_names)
                for var_name in var_names:
                    if var_name not in grad_first_ids:
                        grad_first_ids[var_name] = idx
            # Communication operators such as 'allreduce_sum' use input var as output.
            else:
                pass

        # insert cast op
        for grad_name, idx in reversed(grad_first_ids.items()):
            grad_var = cur_block.var(grad_name)
            if (
                grad_var.dtype == core.VarDesc.VarType.FP16
                or grad_var.dtype == core.VarDesc.VarType.BF16
            ):
                is_fp16 = grad_var.dtype == core.VarDesc.VarType.FP16
                producer_op = cur_block.ops[idx]
                producer_op_dist_attr = (
                    dist_context.get_op_dist_attr_for_program(producer_op)
                )
                assert (
                    producer_op_dist_attr is not None
                ), f"The op: '{producer_op}' should be distributed"
                ref_output_dist_attr = (
                    producer_op_dist_attr.get_output_dist_attr(grad_name)
                )
                assert (
                    ref_output_dist_attr is not None
                ), f"The output: '{grad_name}' should be distributed"
                ref_mesh = ref_output_dist_attr.process_mesh
                ref_dims_mapping = ref_output_dist_attr.dims_mapping
                ref_chunk_id = producer_op_dist_attr.chunk_id
                grad_half_precision_name = (
                    grad_name + '@tmp_fp16'
                    if is_fp16
                    else grad_name + '@tmp_bf16'
                )
                grad_half_precision = cur_block.create_var(
                    name=grad_half_precision_name,
                    dtype=grad_var.dtype,
                    shape=grad_var.shape,
                    persistable=False,
                    stop_gradient=False,
                )
                set_var_dist_attr(
                    dist_context,
                    grad_half_precision,
                    ref_dims_mapping,
                    ref_mesh,
                    chunk_id=ref_chunk_id,
                )
                producer_op._rename_output(grad_name, grad_half_precision.name)
                grad_var.desc.set_dtype(core.VarDesc.VarType.FP32)
                cast_op = cur_block._insert_op_without_sync(
                    idx + 1,
                    type="cast",
                    inputs={"X": grad_half_precision},
                    outputs={"Out": grad_var},
                    attrs={
                        "in_dtype": grad_half_precision.dtype,
                        "out_dtype": grad_var.dtype,
                    },
                )
                cast_op._set_attr(OP_ROLE_KEY, OpRole.Backward)
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    cast_op,
                    ref_mesh,
                    ref_dims_mapping,
                    dist_context,
                    chunk_id=ref_chunk_id,
                )
        cur_block._sync_with_cpp()

    def _regenerate_optimizer(
        self,
        main_program,
        startup_program,
        params_grads: List[Tuple[Variable, Variable]],
        dist_context,
    ):
        grad_names = [g.name for _, g in params_grads]
        # 1. delete the origin optimizer op
        # 1.1 delete the var and op associated with the optimizer op in main_program
        main_ops = main_program.global_block().ops
        main_ops_len = len(main_ops)
        first_optimize_idx = main_ops_len
        for idx, op in enumerate(main_ops):
            # We don't delete the operators for check_nan_inf
            if is_optimize_op(op) and is_gradient_clip_op(op):
                first_optimize_idx = idx
                break
        assert (
            first_optimize_idx < main_ops_len
        ), "The first optimizer op is not found!"
        deleted_temp_var_names = []
        deleted_persist_var_names = []
        reserved_var_names = []
        for idx in range(main_ops_len - 1, first_optimize_idx - 1, -1):
            op = main_ops[idx]
            inout_arg_names = op.input_arg_names + op.output_arg_names
            if op.type in _supported_optimizer_type:
                param_names = op.input("Param")
                skip_update_names = op.input("SkipUpdate")
                for reserved_name in param_names + skip_update_names:
                    if reserved_name not in reserved_var_names:
                        reserved_var_names.append(reserved_name)
            for input_name in inout_arg_names:
                if input_name in grad_names:
                    continue
                var = main_program.global_block().var(input_name)
                if (
                    var.persistable
                    and input_name not in deleted_persist_var_names
                ):
                    deleted_persist_var_names.append(input_name)
                elif (
                    not var.persistable
                    and input_name not in deleted_temp_var_names
                ):
                    deleted_temp_var_names.append(input_name)
            main_program.global_block()._remove_op(idx)

        for var_name in deleted_temp_var_names + deleted_persist_var_names:
            if var_name not in reserved_var_names:
                main_program.global_block()._remove_var(var_name)
        main_program.global_block()._sync_with_cpp()

        # 1.2 delete the var and op in startup_program
        for reserved_name in reserved_var_names:
            if reserved_name in deleted_persist_var_names:
                deleted_persist_var_names.remove(reserved_name)
        startup_global_block = startup_program.global_block()
        for var_name in deleted_persist_var_names:
            if startup_global_block.has_var(var_name):
                startup_global_block._remove_var(var_name)
        for idx, op in reversed(list(enumerate(startup_global_block.ops))):
            inout_arg_names = op.input_arg_names + op.output_arg_names
            for var_name in inout_arg_names:
                if var_name in deleted_persist_var_names:
                    startup_program.global_block()._remove_op(idx)
                    break

        # 2. re-generate new optimizer op
        serial_optimizer = copy.deepcopy(dist_context._serial_optimizer)
        serial_optimizer._learning_rate = (
            dist_context._serial_optimizer._learning_rate
        )
        serial_optimizer._sorted = False
        with program_guard(main_program, startup_program):
            with main_program.switch_name_generator_guard("opt_"):
                _ = serial_optimizer.apply_gradients(params_grads)
        self._completer.complete_update_annotation(main_program)

    def _add_master_grad(self, main_program, params_grads, dist_context):
        grad_names = [g.name for _, g in params_grads]
        for sub_block in main_program.blocks:
            self._add_cast_op(sub_block, grad_names, dist_context)
```
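To make the core mechanism easier to follow, here is a minimal, self-contained sketch in plain Python of what `_add_cast_op` does conceptually: record the op that first produces each fp16/bf16 gradient, redirect that producer to a temporary low-precision variable, and insert a cast so the original gradient name becomes the fp32 master grad that communication, gradient clipping, and the optimizer read. This is a toy model, not Paddle's Block/Operator API and not code added by this PR; the op dicts, the `var_dtype` map, and the function name `insert_master_grad_casts` are illustrative stand-ins.

```python
from collections import OrderedDict


def insert_master_grad_casts(ops, grad_names, var_dtype):
    """Toy model: ops is a list of dicts, var_dtype maps var name -> dtype string."""
    # 1. Record the index of the op that first produces each gradient.
    #    Later ops (e.g. allreduce_sum) may reuse the gradient name as an output,
    #    so only the first producer is kept, as in the real pass.
    grad_first_ids = OrderedDict()
    for idx, op in enumerate(ops):
        if op["role"] != "backward":
            continue
        for out in op["outputs"]:
            if out in grad_names and out not in grad_first_ids:
                grad_first_ids[out] = idx

    # 2. Walk the gradients from the last producer to the first and splice in casts.
    for grad_name, idx in reversed(list(grad_first_ids.items())):
        if var_dtype[grad_name] not in ("fp16", "bf16"):
            continue
        tmp_name = f"{grad_name}@tmp_{var_dtype[grad_name]}"
        var_dtype[tmp_name] = var_dtype[grad_name]
        var_dtype[grad_name] = "fp32"  # the original name becomes the fp32 master grad
        producer = ops[idx]
        # The producer now writes the temporary low-precision variable ...
        producer["outputs"] = [
            tmp_name if name == grad_name else name for name in producer["outputs"]
        ]
        # ... and a cast right after it re-creates the original grad name in fp32,
        # which is what downstream communication, clipping, and optimizer ops read.
        ops.insert(
            idx + 1,
            {
                "type": "cast",
                "role": "backward",
                "inputs": [tmp_name],
                "outputs": [grad_name],
            },
        )
    return ops


if __name__ == "__main__":
    var_dtype = {"w@GRAD": "bf16", "w": "fp32"}
    toy_ops = [
        {"type": "matmul_grad", "role": "backward", "inputs": ["x"], "outputs": ["w@GRAD"]},
        {"type": "allreduce_sum", "role": "backward", "inputs": ["w@GRAD"], "outputs": ["w@GRAD"]},
        {"type": "adamw", "role": "optimize", "inputs": ["w", "w@GRAD"], "outputs": ["w"]},
    ]
    for op in insert_master_grad_casts(toy_ops, {"w@GRAD"}, var_dtype):
        print(op["type"], op.get("inputs", []), "->", op["outputs"])
    print(var_dtype)
```

As in the pass itself, gradients are processed in reverse producer-index order, so each insertion only shifts ops that come after the producers still to be visited and the recorded indices stay valid.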