
Commit f84fbdd

【AutoParallel】Add master grad in AMP-O2 of AutoParallel (#59987)
* add master_grad in auto-parallel
* reset third_party
* fix coverage
* support bf16 master_grad
* fix bug in master_grad
* change code according to review
* change the way to find optimizer op
1 parent 8ed3d18 commit f84fbdd

File tree: 7 files changed, +325 −13 lines

python/paddle/distributed/auto_parallel/constants.py

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ def set_field_default_config(category, field, default_value):
 set_field_default_config(AMP, "custom_black_varnames", [])
 set_field_default_config(AMP, "use_fp16_guard", False)
 set_field_default_config(AMP, "use_bf16_guard", False)
+set_field_default_config(AMP, "use_master_grad", False)

 #########################################
 # sharding configuration
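For context, a minimal sketch of how the new flag might be switched on through the auto-parallel strategy, assuming the `paddle.distributed.fleet.auto` Engine API; the field names mirror the defaults registered above, and the Engine/fit lines are illustrative only:

# Hypothetical usage sketch: enable AMP-O2 with master grad in auto parallel.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.amp.enable = True
strategy.amp.level = "o2"            # master grad only takes effect at O2
strategy.amp.dtype = "float16"       # or "bfloat16"
strategy.amp.use_master_grad = True  # the flag registered in constants.py above

# engine = auto.Engine(model, loss, optimizer, strategy=strategy)  # illustrative
# engine.fit(train_dataset, epochs=1, batch_size=8)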

python/paddle/distributed/auto_parallel/static/parallelizer_v2.py

Lines changed: 19 additions & 3 deletions
@@ -252,14 +252,15 @@ def _generate_optimizer(
         # but optimizer will be called repeatedly in re-launch, so optimizer need to be copied.
         # 2. lr_scheduler cannot be deepcopy, cause 'deepcopy' will lead to difference of learning_rate between executor and engine.
         learning_rate = optimizer._learning_rate
-        optimizer = copy.deepcopy(optimizer)
+        new_optimizer = copy.deepcopy(optimizer)
+        new_optimizer._learning_rate = learning_rate
+        new_optimizer._sorted = False
         self._dist_context._serial_optimizer = optimizer
         self._dist_context._serial_optimizer._learning_rate = learning_rate
-        optimizer._sorted = False

         with program_guard(main_program, startup_program):
             with main_program.switch_name_generator_guard("opt_"):
-                optimizer_ops = optimizer.apply_gradients(params_grads)
+                optimizer_ops = new_optimizer.apply_gradients(params_grads)
         self._completer.complete_update_annotation(main_program)
         return optimizer_ops

@@ -380,6 +381,21 @@ def _apply_post_optimization(
                 [main_program], [startup_program], self._pass_context
             )

+        # apply master grad pass
+        if self._strategy.amp.enable:
+            amp_config = copy.deepcopy(self._strategy.amp.to_dict())
+            config = {}
+            config["dist_context"] = self._dist_context
+            config["params_grads"] = params_grads
+            config["completer"] = self._completer
+            if amp_config['level'] == "o2" and amp_config["use_master_grad"]:
+                master_grad_pass = new_pass(
+                    "auto_parallel_master_grad_pass", config
+                )
+                master_grad_pass.apply(
+                    [main_program], [startup_program], self._pass_context
+                )
+
         # data parallel optimization
         if self._strategy.dp_optimization.enable:
             config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
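The first hunk changes which object ends up in `_dist_context._serial_optimizer`: previously the name `optimizer` was rebound to its deepcopy before being stored, so the context held the copy; now the copy lives in `new_optimizer` and the context keeps the original. A minimal, Paddle-free sketch of that aliasing difference (toy `Opt` class, hypothetical names):

import copy

class Opt:
    def __init__(self):
        self._learning_rate = 0.1
        self._sorted = True

user_opt = Opt()

# Old behaviour: the copy shadows the original, so the context keeps the copy.
optimizer = user_opt
optimizer = copy.deepcopy(optimizer)
serial_optimizer_old = optimizer
assert serial_optimizer_old is not user_opt

# New behaviour: keep the original for the context, apply gradients with the copy.
new_optimizer = copy.deepcopy(user_opt)
new_optimizer._learning_rate = user_opt._learning_rate
new_optimizer._sorted = False
serial_optimizer_new = user_opt
assert serial_optimizer_new is user_opt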

python/paddle/distributed/passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from .auto_parallel_gradient_merge import *  # noqa: F403
 from .auto_parallel_sharding import *  # noqa: F403
 from .auto_parallel_amp import *  # noqa: F403
+from .auto_parallel_master_grad import *  # noqa: F403
 from .auto_parallel_fp16 import *  # noqa: F403
 from .auto_parallel_recompute import *  # noqa: F403
 from .auto_parallel_quantization import *  # noqa: F403

python/paddle/distributed/passes/auto_parallel_fp16.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ def set_auto_cast_attr(cast_op, block):
     ), f"in_var {in_name} or out_var {out_name} is None of cast op"
     if is_forward_op(cast_op):
         cast_op._set_attr('in_dtype', in_var.dtype)
-        cast_op._set_attr('out_dtype', out_var.dtype)
+        out_var.desc.set_dtype(paddle.dtype(cast_op.attr('out_dtype')))
     elif is_backward_op(cast_op):
         in_var_fw = block._find_var_recursive(in_name[: in_name.find("@")])
         out_var_fw = block._find_var_recursive(out_name[: out_name.find("@")])
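This one-line change reverses the direction of synchronization for forward cast ops: rather than overwriting the op's `out_dtype` attribute from the output variable, the output variable's dtype is now taken from the attribute, which stays correct once master grad retypes gradient variables to fp32. A toy, Paddle-free illustration of the idea (plain dicts as stand-ins, hypothetical names):

# Toy stand-ins for a cast op and its output variable.
cast_attrs = {"in_dtype": "float16", "out_dtype": "float32"}
out_var = {"name": "x@CAST", "dtype": "float16"}  # stale dtype on the variable

# Old direction: the attribute was overwritten from the (possibly stale) variable,
# which would lose the intended fp32:
#   cast_attrs["out_dtype"] = out_var["dtype"]

# New direction: the variable follows the op attribute.
out_var["dtype"] = cast_attrs["out_dtype"]
assert out_var["dtype"] == "float32"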

python/paddle/distributed/passes/auto_parallel_gradient_merge.py

Lines changed: 7 additions & 6 deletions
@@ -178,6 +178,7 @@ def _append_gradient_merge_backward_op(
         for out_name in op.desc.output_arg_names():
             if out_name in grad_to_params_grads:
                 param = grad_to_params_grads[out_name][0]
+                grad = grad_to_params_grads[out_name][1]
                 assert param is not None
                 ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                     param
@@ -188,8 +189,8 @@ def _append_gradient_merge_backward_op(
                 # Add persistable gradient variables in main_program
                 gradient_merge_var = main_block.create_var(
                     name=param.name + "@GRAD@MERGE",
-                    shape=param.shape,
-                    dtype=param.dtype,
+                    shape=grad.shape,
+                    dtype=grad.dtype,
                     persistable=True,
                 )
                 ref_process_mesh = ref_dist_attr.process_mesh
@@ -205,17 +206,17 @@ def _append_gradient_merge_backward_op(
                 # Add persistable gradient variables in startup_program
                 startup_gradient_merge_var = startup_block.create_var(
                     name=param.name + "@GRAD@MERGE",
-                    shape=param.shape,
-                    dtype=param.dtype,
+                    shape=grad.shape,
+                    dtype=grad.dtype,
                     persistable=True,
                 )
                 # Initial persistable gradient variables in startup_program
                 startup_block.append_op(
                     type="fill_constant",
                     outputs={"Out": startup_gradient_merge_var},
                     attrs={
-                        "shape": param.shape,
-                        "dtype": param.dtype,
+                        "shape": grad.shape,
+                        "dtype": grad.dtype,
                         "value": float(0),
                     },
                 )
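With master grad enabled, the accumulated gradient can be fp32 while the parameter stays fp16/bf16, so the merge buffer now follows the gradient's shape and dtype rather than the parameter's. A small sketch of the distinction, using the same static-graph `create_var` call (variable names are illustrative only):

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    block = main.global_block()
    # fp16 parameter with an fp32 master gradient, as produced by the master grad pass
    param = block.create_var(name="w", shape=[8, 8], dtype="float16", persistable=True)
    grad = block.create_var(name="w@GRAD", shape=[8, 8], dtype="float32")
    # The merge buffer must match the gradient; otherwise fp32 values would be
    # accumulated into an fp16 buffer.
    merge_var = block.create_var(
        name=param.name + "@GRAD@MERGE",
        shape=grad.shape,
        dtype=grad.dtype,
        persistable=True,
    )
    assert merge_var.dtype == grad.dtype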
python/paddle/distributed/passes/auto_parallel_master_grad.py

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@ (new file)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from collections import OrderedDict
from typing import List, Tuple

from paddle.base import Variable
from paddle.distributed.auto_parallel.static.utils import (
    is_backward_op,
    is_gradient_clip_op,
    is_optimize_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import (
    OP_ROLE_KEY,
    OpRole,
)
from paddle.framework import core
from paddle.static import program_guard

from ..utils.log_utils import get_logger
from .auto_parallel_sharding import _supported_optimizer_type
from .pass_base import PassBase, register_pass

logger = get_logger(logging.INFO, "MasterGradPass")


def get_output_in_varlist(op, var_names) -> List[str]:
    grad_names = []
    for output_name in op.output_arg_names:
        if output_name in var_names:
            grad_names.append(output_name)
    return grad_names


@register_pass("auto_parallel_master_grad_pass")
class MasterGradPass(PassBase):
    """
    Use the high precision gradient to replace the low precision gradient in optimizer to avoid inf/nan values of low precision.
    The high precision gradient 'master grad' will be used by communication operator, `update_loss_scaling`, `GradClip` and `optimizer`.
    """

    def __init__(self):
        super().__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, context):
        self._completer = self.get_attr("completer")
        dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")
        logger.debug(f"Origin main_program: {main_program}")
        self._add_master_grad(main_program, params_grads, dist_context)
        self._regenerate_optimizer(
            main_program, startup_program, params_grads, dist_context
        )
        logger.debug(f"After main program: {main_program}")

    def _add_cast_op(self, cur_block, grad_names: List[str], dist_context):
        grad_first_ids = OrderedDict()
        for idx, op in enumerate(cur_block.ops):
            if is_optimize_op(op):
                break
            elif is_backward_op(op):
                var_names = get_output_in_varlist(op, grad_names)
                for var_name in var_names:
                    if var_name not in grad_first_ids:
                        grad_first_ids[var_name] = idx
                    # Communication operators such as 'allreduce_sum' use input var as output.
                    else:
                        pass

        # insert cast op
        for grad_name, idx in reversed(grad_first_ids.items()):
            grad_var = cur_block.var(grad_name)
            if (
                grad_var.dtype == core.VarDesc.VarType.FP16
                or grad_var.dtype == core.VarDesc.VarType.BF16
            ):
                is_fp16 = grad_var.dtype == core.VarDesc.VarType.FP16
                producer_op = cur_block.ops[idx]
                producer_op_dist_attr = (
                    dist_context.get_op_dist_attr_for_program(producer_op)
                )
                assert (
                    producer_op_dist_attr is not None
                ), f"The op: '{producer_op}' should be distributed"
                ref_output_dist_attr = (
                    producer_op_dist_attr.get_output_dist_attr(grad_name)
                )
                assert (
                    ref_output_dist_attr is not None
                ), f"The output: '{grad_name}' should be distributed"
                ref_mesh = ref_output_dist_attr.process_mesh
                ref_dims_mapping = ref_output_dist_attr.dims_mapping
                ref_chunk_id = producer_op_dist_attr.chunk_id
                grad_half_precision_name = (
                    grad_name + '@tmp_fp16'
                    if is_fp16
                    else grad_name + '@tmp_bf16'
                )
                grad_half_precision = cur_block.create_var(
                    name=grad_half_precision_name,
                    dtype=grad_var.dtype,
                    shape=grad_var.shape,
                    persistable=False,
                    stop_gradient=False,
                )
                set_var_dist_attr(
                    dist_context,
                    grad_half_precision,
                    ref_dims_mapping,
                    ref_mesh,
                    chunk_id=ref_chunk_id,
                )
                producer_op._rename_output(grad_name, grad_half_precision.name)
                grad_var.desc.set_dtype(core.VarDesc.VarType.FP32)
                cast_op = cur_block._insert_op_without_sync(
                    idx + 1,
                    type="cast",
                    inputs={"X": grad_half_precision},
                    outputs={"Out": grad_var},
                    attrs={
                        "in_dtype": grad_half_precision.dtype,
                        "out_dtype": grad_var.dtype,
                    },
                )
                cast_op._set_attr(OP_ROLE_KEY, OpRole.Backward)
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    cast_op,
                    ref_mesh,
                    ref_dims_mapping,
                    dist_context,
                    chunk_id=ref_chunk_id,
                )
        cur_block._sync_with_cpp()

    def _regenerate_optimizer(
        self,
        main_program,
        startup_program,
        params_grads: List[Tuple[Variable, Variable]],
        dist_context,
    ):
        grad_names = [g.name for _, g in params_grads]
        # 1. delete the origin optimizer op
        # 1.1 delete the var and op associated with the optimizer op in main_program
        main_ops = main_program.global_block().ops
        main_ops_len = len(main_ops)
        first_optimize_idx = main_ops_len
        for idx, op in enumerate(main_ops):
            # We don't delete the operators for check_nan_inf
            if is_optimize_op(op) and is_gradient_clip_op(op):
                first_optimize_idx = idx
                break
        assert (
            first_optimize_idx < main_ops_len
        ), "The first optimizer op is not found!"
        deleted_temp_var_names = []
        deleted_persist_var_names = []
        reserved_var_names = []
        for idx in range(main_ops_len - 1, first_optimize_idx - 1, -1):
            op = main_ops[idx]
            inout_arg_names = op.input_arg_names + op.output_arg_names
            if op.type in _supported_optimizer_type:
                param_names = op.input("Param")
                skip_update_names = op.input("SkipUpdate")
                for reserved_name in param_names + skip_update_names:
                    if reserved_name not in reserved_var_names:
                        reserved_var_names.append(reserved_name)
            for input_name in inout_arg_names:
                if input_name in grad_names:
                    continue
                var = main_program.global_block().var(input_name)
                if (
                    var.persistable
                    and input_name not in deleted_persist_var_names
                ):
                    deleted_persist_var_names.append(input_name)
                elif (
                    not var.persistable
                    and input_name not in deleted_temp_var_names
                ):
                    deleted_temp_var_names.append(input_name)
            main_program.global_block()._remove_op(idx)

        for var_name in deleted_temp_var_names + deleted_persist_var_names:
            if var_name not in reserved_var_names:
                main_program.global_block()._remove_var(var_name)
        main_program.global_block()._sync_with_cpp()

        # 1.2 delete the var and op in startup_program
        for reserved_name in reserved_var_names:
            if reserved_name in deleted_persist_var_names:
                deleted_persist_var_names.remove(reserved_name)
        startup_global_block = startup_program.global_block()
        for var_name in deleted_persist_var_names:
            if startup_global_block.has_var(var_name):
                startup_global_block._remove_var(var_name)
        for idx, op in reversed(list(enumerate(startup_global_block.ops))):
            inout_arg_names = op.input_arg_names + op.output_arg_names
            for var_name in inout_arg_names:
                if var_name in deleted_persist_var_names:
                    startup_program.global_block()._remove_op(idx)
                    break

        # 2. re-generate new optimizer op
        serial_optimizer = copy.deepcopy(dist_context._serial_optimizer)
        serial_optimizer._learning_rate = (
            dist_context._serial_optimizer._learning_rate
        )
        serial_optimizer._sorted = False
        with program_guard(main_program, startup_program):
            with main_program.switch_name_generator_guard("opt_"):
                _ = serial_optimizer.apply_gradients(params_grads)
        self._completer.complete_update_annotation(main_program)

    def _add_master_grad(self, main_program, params_grads, dist_context):
        grad_names = [g.name for _, g in params_grads]
        for sub_block in main_program.blocks:
            self._add_cast_op(sub_block, grad_names, dist_context)
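Conceptually, the new pass promotes each low-precision gradient to an fp32 "master" copy at the point where it is produced, so that clipping, loss-scaling checks and the optimizer update all run in fp32. A minimal dynamic-graph sketch of that idea, not the pass itself (which rewrites the static program as shown above):

import paddle

# fp16 gradient as the backward pass would produce it
grad_fp16 = paddle.to_tensor([1.0e-4, 3.0e-4, -2.0e-4], dtype="float16")

# master grad: promote to fp32 before clipping / the optimizer update
master_grad = grad_fp16.astype("float32")

# downstream math (e.g. global-norm clipping) happens in fp32, avoiding the
# overflow/underflow that fp16 accumulation could introduce
global_norm = paddle.linalg.norm(master_grad)
print(global_norm)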
