1 change: 1 addition & 0 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -78,6 +78,7 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_fp16_guard", False)
set_field_default_config(AMP, "use_bf16_guard", False)
set_field_default_config(AMP, "use_master_grad", False)

#########################################
# sharding configuration
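
A note on the new field: `use_master_grad` defaults to `False` and sits alongside the other AMP options. A minimal sketch of switching it on through the auto-parallel `Strategy` (assuming the usual `paddle.distributed.fleet.auto` surface; the commented-out `Engine` call and its model/loss/optimizer arguments are placeholders, not part of this PR):

```python
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.amp.enable = True
strategy.amp.level = "o2"            # master grad only takes effect at O2 (see the parallelizer change below)
strategy.amp.use_master_grad = True  # field added by this PR, defaults to False

# engine = auto.Engine(model, loss, optimizer, strategy=strategy)  # placeholders
```
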
22 changes: 19 additions & 3 deletions python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
@@ -252,14 +252,15 @@ def _generate_optimizer(
        # but the optimizer will be called repeatedly on re-launch, so it needs to be copied.
        # 2. lr_scheduler cannot be deep-copied, because 'deepcopy' would make the learning_rate differ between the executor and the engine.
learning_rate = optimizer._learning_rate
optimizer = copy.deepcopy(optimizer)
new_optimizer = copy.deepcopy(optimizer)
new_optimizer._learning_rate = learning_rate
new_optimizer._sorted = False
self._dist_context._serial_optimizer = optimizer
self._dist_context._serial_optimizer._learning_rate = learning_rate
optimizer._sorted = False

with program_guard(main_program, startup_program):
with main_program.switch_name_generator_guard("opt_"):
optimizer_ops = optimizer.apply_gradients(params_grads)
optimizer_ops = new_optimizer.apply_gradients(params_grads)
self._completer.complete_update_annotation(main_program)
return optimizer_ops
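
The comment above is the key constraint behind this rewrite: the optimizer has to be copied so `apply_gradients` can be re-run on re-launch, but the learning-rate scheduler must stay the same object on both sides. A self-contained sketch of the sharing problem, using plain stand-in classes rather than Paddle types:

```python
import copy

class Scheduler:          # stands in for an LRScheduler shared by executor and engine
    def __init__(self):
        self.lr = 0.1

class Optimizer:
    def __init__(self, sched):
        self._learning_rate = sched

sched = Scheduler()
opt = Optimizer(sched)
new_opt = copy.deepcopy(opt)      # the copy now holds its own, private Scheduler
sched.lr = 0.05                   # updating the original no longer reaches the copy
print(new_opt._learning_rate.lr)  # 0.1 -> stale
new_opt._learning_rate = sched    # what the diff does: re-attach the shared scheduler
```
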

@@ -380,6 +381,21 @@ def _apply_post_optimization(
[main_program], [startup_program], self._pass_context
)

# apply master grad pass
if self._strategy.amp.enable:
amp_config = copy.deepcopy(self._strategy.amp.to_dict())
config = {}
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["completer"] = self._completer
if amp_config['level'] == "o2" and amp_config["use_master_grad"]:
master_grad_pass = new_pass(
"auto_parallel_master_grad_pass", config
)
master_grad_pass.apply(
[main_program], [startup_program], self._pass_context
)

# data parallel optimization
if self._strategy.dp_optimization.enable:
config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
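
The gate added above only fires for O2 AMP with the new flag set, and everything it checks comes from `strategy.amp.to_dict()`. A small runnable check of that condition (same hypothetical strategy fields as in the earlier sketch):

```python
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.amp.enable = True
strategy.amp.level = "o2"
strategy.amp.use_master_grad = True

amp_config = strategy.amp.to_dict()
# Same condition as in _apply_post_optimization above.
print(amp_config["level"] == "o2" and amp_config["use_master_grad"])  # expected: True
```
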
1 change: 1 addition & 0 deletions python/paddle/distributed/passes/__init__.py
@@ -17,6 +17,7 @@
from .auto_parallel_gradient_merge import * # noqa: F403
from .auto_parallel_sharding import * # noqa: F403
from .auto_parallel_amp import * # noqa: F403
from .auto_parallel_master_grad import * # noqa: F403
from .auto_parallel_fp16 import * # noqa: F403
from .auto_parallel_recompute import * # noqa: F403
from .auto_parallel_quantization import * # noqa: F403
2 changes: 1 addition & 1 deletion python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -79,7 +79,7 @@ def set_auto_cast_attr(cast_op, block):
), f"in_var {in_name} or out_var {out_name} is None of cast op"
if is_forward_op(cast_op):
cast_op._set_attr('in_dtype', in_var.dtype)
cast_op._set_attr('out_dtype', out_var.dtype)
out_var.desc.set_dtype(paddle.dtype(cast_op.attr('out_dtype')))
elif is_backward_op(cast_op):
in_var_fw = block._find_var_recursive(in_name[: in_name.find("@")])
out_var_fw = block._find_var_recursive(out_name[: out_name.find("@")])
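
The one-line change flips the direction of the fix-up: instead of overwriting the cast op's `out_dtype` attribute from the (possibly stale) output variable, the output variable's dtype is now derived from the attribute. A small static-graph example of the attribute/variable relationship this relies on (ordinary `paddle.cast` behaviour, not code from this PR):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data("x", shape=[2, 2], dtype="float16")
    y = paddle.cast(x, "float32")

cast_op = main.global_block().ops[-1]
# The op's dtype attributes and the output var's dtype describe the same thing;
# set_auto_cast_attr keeps them consistent after AMP rewrites replace variables.
print(cast_op.attr("in_dtype"), cast_op.attr("out_dtype"), y.dtype)
```
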
13 changes: 7 additions & 6 deletions python/paddle/distributed/passes/auto_parallel_gradient_merge.py
@@ -178,6 +178,7 @@ def _append_gradient_merge_backward_op(
for out_name in op.desc.output_arg_names():
if out_name in grad_to_params_grads:
param = grad_to_params_grads[out_name][0]
grad = grad_to_params_grads[out_name][1]
assert param is not None
ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(
param
@@ -188,8 +189,8 @@
# Add persistable gradient variables in main_program
gradient_merge_var = main_block.create_var(
name=param.name + "@GRAD@MERGE",
shape=param.shape,
dtype=param.dtype,
shape=grad.shape,
dtype=grad.dtype,
persistable=True,
)
ref_process_mesh = ref_dist_attr.process_mesh
@@ -205,17 +206,17 @@
# Add persistable gradient variables in startup_program
startup_gradient_merge_var = startup_block.create_var(
name=param.name + "@GRAD@MERGE",
shape=param.shape,
dtype=param.dtype,
shape=grad.shape,
dtype=grad.dtype,
persistable=True,
)
# Initial persistable gradient variables in startup_program
startup_block.append_op(
type="fill_constant",
outputs={"Out": startup_gradient_merge_var},
attrs={
"shape": param.shape,
"dtype": param.dtype,
"shape": grad.shape,
"dtype": grad.dtype,
"value": float(0),
},
)
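
These shape/dtype switches matter once master grad is enabled: the parameter may stay float16 while its gradient is float32, so the `@GRAD@MERGE` buffer has to follow the gradient. A minimal sketch of the corrected buffer creation (standalone block; the variable names and shapes are illustrative):

```python
import paddle

paddle.enable_static()
block = paddle.static.Program().global_block()

param = block.create_var(name="w", shape=[8, 8], dtype="float16")
grad = block.create_var(name="w@GRAD", shape=[8, 8], dtype="float32")  # master grad

merged = block.create_var(
    name=param.name + "@GRAD@MERGE",
    shape=grad.shape,   # follow the gradient, not the parameter
    dtype=grad.dtype,   # a float32 buffer; param.dtype would truncate the accumulation
    persistable=True,
)
```
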
239 changes: 239 additions & 0 deletions python/paddle/distributed/passes/auto_parallel_master_grad.py
@@ -0,0 +1,239 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from collections import OrderedDict
from typing import List, Tuple

from paddle.base import Variable
from paddle.distributed.auto_parallel.static.utils import (
is_backward_op,
is_gradient_clip_op,
is_optimize_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import (
OP_ROLE_KEY,
OpRole,
)
from paddle.framework import core
from paddle.static import program_guard

from ..utils.log_utils import get_logger
from .auto_parallel_sharding import _supported_optimizer_type
from .pass_base import PassBase, register_pass

logger = get_logger(logging.INFO, "MasterGradPass")


def get_output_in_varlist(op, var_names) -> List[str]:
grad_names = []
for output_name in op.output_arg_names:
if output_name in var_names:
grad_names.append(output_name)
return grad_names


@register_pass("auto_parallel_master_grad_pass")
class MasterGradPass(PassBase):
"""
    Replace the low-precision gradients consumed by the optimizer with high-precision ('master') gradients, to avoid inf/nan values caused by low precision.
    The master gradients are used by communication operators, `update_loss_scaling`, `GradClip` and the `optimizer`.
"""

def __init__(self):
super().__init__()

def _check_self(self):
return True

def _check_conflict(self, other_pass):
return True

def _apply_single_impl(self, main_program, startup_program, context):
self._completer = self.get_attr("completer")
dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
logger.debug(f"Origin main_program: {main_program}")
self._add_master_grad(main_program, params_grads, dist_context)
self._regenerate_optimizer(
main_program, startup_program, params_grads, dist_context
)
logger.debug(f"After main program: {main_program}")

def _add_cast_op(self, cur_block, grad_names: List[str], dist_context):
grad_first_ids = OrderedDict()
for idx, op in enumerate(cur_block.ops):
if is_optimize_op(op):
break
elif is_backward_op(op):
var_names = get_output_in_varlist(op, grad_names)
for var_name in var_names:
if var_name not in grad_first_ids:
grad_first_ids[var_name] = idx
# Communication operators such as 'allreduce_sum' use input var as output.
else:
pass

# insert cast op
for grad_name, idx in reversed(grad_first_ids.items()):
grad_var = cur_block.var(grad_name)
if (
grad_var.dtype == core.VarDesc.VarType.FP16
or grad_var.dtype == core.VarDesc.VarType.BF16
):
is_fp16 = grad_var.dtype == core.VarDesc.VarType.FP16
producer_op = cur_block.ops[idx]
producer_op_dist_attr = (
dist_context.get_op_dist_attr_for_program(producer_op)
)
assert (
producer_op_dist_attr is not None
), f"The op: '{producer_op}' should be distributed"
ref_output_dist_attr = (
producer_op_dist_attr.get_output_dist_attr(grad_name)
)
assert (
ref_output_dist_attr is not None
), f"The output: '{grad_name}' should be distributed"
ref_mesh = ref_output_dist_attr.process_mesh
ref_dims_mapping = ref_output_dist_attr.dims_mapping
ref_chunk_id = producer_op_dist_attr.chunk_id
grad_half_precision_name = (
grad_name + '@tmp_fp16'
if is_fp16
else grad_name + '@tmp_bf16'
)
grad_half_precision = cur_block.create_var(
name=grad_half_precision_name,
dtype=grad_var.dtype,
shape=grad_var.shape,
persistable=False,
stop_gradient=False,
)
set_var_dist_attr(
dist_context,
grad_half_precision,
ref_dims_mapping,
ref_mesh,
chunk_id=ref_chunk_id,
)
producer_op._rename_output(grad_name, grad_half_precision.name)
grad_var.desc.set_dtype(core.VarDesc.VarType.FP32)
cast_op = cur_block._insert_op_without_sync(
idx + 1,
type="cast",
inputs={"X": grad_half_precision},
outputs={"Out": grad_var},
attrs={
"in_dtype": grad_half_precision.dtype,
"out_dtype": grad_var.dtype,
},
)
cast_op._set_attr(OP_ROLE_KEY, OpRole.Backward)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op,
ref_mesh,
ref_dims_mapping,
dist_context,
chunk_id=ref_chunk_id,
)
cur_block._sync_with_cpp()

def _regenerate_optimizer(
self,
main_program,
startup_program,
params_grads: List[Tuple[Variable, Variable]],
dist_context,
):
grad_names = [g.name for _, g in params_grads]
        # 1. delete the original optimizer ops
        # 1.1 delete the vars and ops associated with the optimizer ops in main_program
main_ops = main_program.global_block().ops
main_ops_len = len(main_ops)
first_optimize_idx = main_ops_len
for idx, op in enumerate(main_ops):
# We don't delete the operators for check_nan_inf
if is_optimize_op(op) and is_gradient_clip_op(op):
first_optimize_idx = idx
break
assert (
first_optimize_idx < main_ops_len
), "The first optimizer op is not found!"
deleted_temp_var_names = []
deleted_persist_var_names = []
reserved_var_names = []
for idx in range(main_ops_len - 1, first_optimize_idx - 1, -1):
op = main_ops[idx]
inout_arg_names = op.input_arg_names + op.output_arg_names
if op.type in _supported_optimizer_type:
param_names = op.input("Param")
skip_update_names = op.input("SkipUpdate")
for reserved_name in param_names + skip_update_names:
if reserved_name not in reserved_var_names:
reserved_var_names.append(reserved_name)
for input_name in inout_arg_names:
if input_name in grad_names:
continue
var = main_program.global_block().var(input_name)
if (
var.persistable
and input_name not in deleted_persist_var_names
):
deleted_persist_var_names.append(input_name)
elif (
not var.persistable
and input_name not in deleted_temp_var_names
):
deleted_temp_var_names.append(input_name)
main_program.global_block()._remove_op(idx)

for var_name in deleted_temp_var_names + deleted_persist_var_names:
if var_name not in reserved_var_names:
main_program.global_block()._remove_var(var_name)
main_program.global_block()._sync_with_cpp()

# 1.2 delete the var and op in startup_program
for reserved_name in reserved_var_names:
if reserved_name in deleted_persist_var_names:
deleted_persist_var_names.remove(reserved_name)
startup_global_block = startup_program.global_block()
for var_name in deleted_persist_var_names:
if startup_global_block.has_var(var_name):
startup_global_block._remove_var(var_name)
for idx, op in reversed(list(enumerate(startup_global_block.ops))):
inout_arg_names = op.input_arg_names + op.output_arg_names
for var_name in inout_arg_names:
if var_name in deleted_persist_var_names:
startup_program.global_block()._remove_op(idx)
break

# 2. re-generate new optimizer op
serial_optimizer = copy.deepcopy(dist_context._serial_optimizer)
serial_optimizer._learning_rate = (
dist_context._serial_optimizer._learning_rate
)
serial_optimizer._sorted = False
with program_guard(main_program, startup_program):
with main_program.switch_name_generator_guard("opt_"):
_ = serial_optimizer.apply_gradients(params_grads)
self._completer.complete_update_annotation(main_program)

def _add_master_grad(self, main_program, params_grads, dist_context):
grad_names = [g.name for _, g in params_grads]
for sub_block in main_program.blocks:
self._add_cast_op(sub_block, grad_names, dist_context)
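
At its core, `_add_cast_op` renames the producer's low-precision output to a `@tmp_fp16`/`@tmp_bf16` temporary, retypes the original gradient variable to FP32, and inserts a `cast` that feeds it. A stripped-down sketch of that rewrite without the distributed attributes (variable names and shapes are illustrative):

```python
import paddle
from paddle.framework import core

paddle.enable_static()
block = paddle.static.Program().global_block()

grad = block.create_var(name="w@GRAD", shape=[4, 4], dtype="float16")
tmp = block.create_var(name="w@GRAD@tmp_fp16", shape=[4, 4], dtype="float16")

# In the pass, the producer op is re-pointed at the temporary:
#   producer_op._rename_output("w@GRAD", "w@GRAD@tmp_fp16")
grad.desc.set_dtype(core.VarDesc.VarType.FP32)  # the original name becomes the FP32 master grad
block.append_op(
    type="cast",
    inputs={"X": tmp},
    outputs={"Out": grad},
    attrs={"in_dtype": tmp.dtype, "out_dtype": grad.dtype},
)
```
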