
Commit 504a590

【AutoParallel】Unify the fp16 and bf16 in auto-parallel (#60514)
* unify the fp16 and bf16
* change white_list in AMP
* add dtype support
* fix bug in dtype
1 parent 620e371 commit 504a590
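
The core of the change is collapsing the separate fp16 and bf16 code paths into a single dtype-parameterized one. A minimal sketch of that pattern, using toy names rather than Paddle's actual classes:

# Toy sketch of the unification: one list type keyed by dtype instead of
# parallel fp16/bf16 classes (names here are illustrative, not Paddle's API).
class ToyAMPLists:
    def __init__(self, white_list, black_list, black_varnames=None, dtype="float16"):
        assert dtype in ("float16", "bfloat16"), f"unsupported dtype {dtype}"
        self.dtype = dtype
        self.white_list = set(white_list)            # ops cast down to fp16/bf16
        self.black_list = set(black_list)            # ops kept in float32
        self.black_varnames = set(black_varnames or ())

# The same constructor serves both precisions; callers no longer branch.
fp16_lists = ToyAMPLists({"matmul_v2"}, {"softmax"}, dtype="float16")
bf16_lists = ToyAMPLists({"matmul_v2"}, {"softmax"}, dtype="bfloat16")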

File tree

3 files changed: +34 −78 lines changed

python/paddle/amp/amp_lists.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,8 @@
     'max_pool2d_with_index',
     'mul',
     'fused_gemm_epilogue',
+    "fused_rotary_position_embedding",
+    "flash_attn",
 }

 # The set of ops that support fp16, and bf16 was unsupported.
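
White-list membership is what lets the AMP pass cast an op's float32 inputs down to the low-precision dtype, so adding fused_rotary_position_embedding and flash_attn means those ops now run in fp16/bf16 under AMP. A hedged sketch of that decision with illustrative list contents and helper (not Paddle's internal logic):

# Illustrative-only lists and helper; the real sets live in
# python/paddle/amp/amp_lists.py and the casting logic in the AMP passes.
WHITE_LIST = {"mul", "fused_gemm_epilogue", "fused_rotary_position_embedding", "flash_attn"}
BLACK_LIST = {"exp"}  # hypothetical entry for the example

def runs_in_low_precision(op_type):
    if op_type in BLACK_LIST:      # black-listed ops always stay float32
        return False
    return op_type in WHITE_LIST   # white-listed ops are cast to fp16/bf16

assert runs_in_low_precision("flash_attn")
assert not runs_in_low_precision("exp")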

python/paddle/distributed/passes/auto_parallel_amp.py

Lines changed: 29 additions & 70 deletions
@@ -26,10 +26,6 @@
 )
 from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
 from paddle.framework import core
-from paddle.static.amp.bf16.amp_utils import (
-    AutoMixedPrecisionListsBF16,
-    _is_in_fp32_varnames,
-)
 from paddle.static.amp.fp16_utils import (
     AutoMixedPrecisionLists,
     _is_in_black_varnames,
@@ -88,48 +84,26 @@ def __init__(
         black_varnames=None,
         dtype="float16",
     ):
-        self._amp_list = None
-        if dtype == "float16":
-            self._amp_list = AutoMixedPrecisionLists(
-                set(white_list), set(black_list), set(black_varnames)
-            )
-        elif dtype == "bfloat16":
-            self._amp_list = AutoMixedPrecisionListsBF16(
-                set(white_list), set(black_list), set(black_varnames)
-            )
-
-        assert self._amp_list is not None
+        self._amp_list = AutoMixedPrecisionLists(
+            set(white_list), set(black_list), set(black_varnames), dtype=dtype
+        )
         self._dtype = dtype
-        self._is_float16 = dtype == "float16"

     @property
     def white_list(self):
-        if self._is_float16:
-            return self._amp_list.white_list
-        else:
-            return self._amp_list.bf16_list
+        return self._amp_list.white_list

     @property
     def black_list(self):
-        if self._is_float16:
-            return self._amp_list.black_list
-        else:
-            return self._amp_list.fp32_list
+        return self._amp_list.black_list

     @property
     def gray_list(self):
         return self._amp_list.gray_list

     @property
     def black_varnames(self):
-        if self._is_float16:
-            return self._amp_list.black_varnames
-        else:
-            return self._amp_list.fp32_varnames
-
-    @property
-    def is_fp16(self):
-        return self._is_float16
+        return self._amp_list.black_varnames

     @property
     def dtype(self):
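
The removed branching existed because the bf16 list exposed differently named attributes (bf16_list, fp32_list, fp32_varnames); once dtype is handled inside AutoMixedPrecisionLists, both precisions expose the same attribute names and the wrapper can forward them directly. A before/after sketch of just that property shape (class names are stand-ins, not the pass's real classes):

# Stand-in classes showing why the _is_float16 branch could be dropped.
class BeforeWrapper:
    def __init__(self, amp_list, is_fp16):
        self._amp_list, self._is_float16 = amp_list, is_fp16

    @property
    def white_list(self):
        # fp16 and bf16 lists used different attribute names
        return self._amp_list.white_list if self._is_float16 else self._amp_list.bf16_list

class AfterWrapper:
    def __init__(self, amp_list):
        self._amp_list = amp_list  # already dtype-aware

    @property
    def white_list(self):
        return self._amp_list.white_list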
@@ -140,36 +114,17 @@ def amp_list(self):
         return self._amp_list

     def _is_in_black_fp32_varnames(self, op):
-        if self._is_float16:
-            return _is_in_black_varnames(op, self._amp_list)
-        else:
-            return _is_in_fp32_varnames(op, self._amp_list)
+        return _is_in_black_varnames(op, self._amp_list)

     def _op_keep_fp32_input(self, op, in_name):
         if not op.amp_options.enable:
             return True
-        if self._is_float16:
-            return _keep_fp32_input(op, in_name)
-        else:
-            if op.type in ['batch_norm', 'layer_norm']:
-                return in_name != 'X'
-            if op.type == 'fused_bn_add_activation':
-                return in_name not in {'X', 'Z'}
-            return False
+        return _keep_fp32_input(op, in_name)

     def _op_keep_fp32_output(self, op, out_name):
         if not op.amp_options.enable:
             return True
-        if self._is_float16:
-            return _keep_fp32_output(op, out_name)
-        else:
-            if op.type in [
-                'batch_norm',
-                'fused_bn_add_activation',
-                'layer_norm',
-            ]:
-                return out_name != 'Y'
-            return False
+        return _keep_fp32_output(op, out_name)


 class AMPState:
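
The deleted bf16 branch hard-coded the keep-fp32 rules for normalization ops; after this change the shared _keep_fp32_input/_keep_fp32_output helpers are assumed to express the same policy for both dtypes. A self-contained restatement of the rules as they appear in the deleted code:

# Restated from the removed bf16 branch; a sketch, not Paddle's helpers.
def keep_fp32_input(op_type, in_name):
    if op_type in ('batch_norm', 'layer_norm'):
        return in_name != 'X'               # only the data input is cast
    if op_type == 'fused_bn_add_activation':
        return in_name not in {'X', 'Z'}
    return False

def keep_fp32_output(op_type, out_name):
    if op_type in ('batch_norm', 'fused_bn_add_activation', 'layer_norm'):
        return out_name != 'Y'              # only the main output is cast
    return False

assert keep_fp32_input('layer_norm', 'Scale')
assert not keep_fp32_output('batch_norm', 'Y')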
@@ -324,12 +279,12 @@ def _cast_block(self, block):
                     self.dist_context,
                 )
             elif self._is_fp16_op(op.desc.original_id()) is True:
-                if self.amp_dtype == "bfloat16":
-                    if (
-                        op.has_attr('dtype')
-                        and op.attr('dtype') == core.VarDesc.VarType.FP32
-                    ):
-                        op._set_attr('dtype', core.VarDesc.VarType.BF16)
+                # deal with op with attribute 'dtype', such as 'fill_constant'
+                if (
+                    op.has_attr('dtype')
+                    and op.attr('dtype') == core.VarDesc.VarType.FP32
+                ):
+                    op._set_attr('dtype', _str_to_dtype(self.amp_dtype))
                 num_cast_ops = self._insert_cast_op_forward(
                     block,
                     op,
@@ -362,16 +317,13 @@ def _cast_block(self, block):
                         self.dist_context,
                         appended_grad_times,
                     )
-                elif (
-                    self._is_fp16_op(op.desc.original_id()) is True
-                ):  # fp16/bf16
-                    if self.amp_dtype == "bfloat16":
-                        if (
-                            op.has_attr('dtype')
-                            and op.attr('dtype')
-                            == core.VarDesc.VarType.FP32
-                        ):
-                            op._set_attr('dtype', core.VarDesc.VarType.BF16)
+                elif self._is_fp16_op(op.desc.original_id()) is True:
+                    # deal with op with attribute 'dtype', such as 'fill_constant'
+                    if (
+                        op.has_attr('dtype')
+                        and op.attr('dtype') == core.VarDesc.VarType.FP32
+                    ):
+                        op._set_attr('dtype', _str_to_dtype(self.amp_dtype))
                     num_cast_ops = self._insert_cast_op_backward(
                         block,
                         op,
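
Both the forward and backward passes now normalize an op's 'dtype' attribute (e.g. on fill_constant) through _str_to_dtype(self.amp_dtype) instead of a bf16-only special case. A toy version of that mapping and rewrite; the real _str_to_dtype helper is imported elsewhere in this file, so this only sketches its assumed behaviour:

from paddle.framework import core

def str_to_dtype_sketch(dtype_str):
    # Assumed mapping from the AMP dtype string to the framework enum.
    return {
        "float16": core.VarDesc.VarType.FP16,
        "bfloat16": core.VarDesc.VarType.BF16,
    }[dtype_str]

def lower_dtype_attr(op, amp_dtype):
    # Mirrors the diff: only rewrite attributes that are currently float32.
    if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
        op._set_attr('dtype', str_to_dtype_sketch(amp_dtype))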
@@ -522,6 +474,7 @@ def _insert_cast_op_forward(
                         op._set_attr(
                             'out_dtype', _str_to_dtype(self.amp_dtype)
                         )
+
         return num_cast_ops

     def _insert_cast_op_backward(
@@ -699,6 +652,12 @@ def _keep_fp32_output(op, out_name):
                 else:
                     assert out_var.dtype == dst_dtype

+        if (
+            op.has_attr('dtype')
+            and op.attr('dtype') == core.VarDesc.VarType.FP32
+        ):
+            op._set_attr('dtype', _str_to_dtype(self.amp_dtype))
+
         return num_cast_ops


python/paddle/distributed/passes/auto_parallel_fp16.py

Lines changed: 3 additions & 8 deletions
@@ -16,6 +16,7 @@
 from collections import defaultdict

 import paddle
+import paddle.static.amp.fp16_utils as amp_utils
 from paddle.common_ops_import import check_type, check_variable_and_dtype
 from paddle.distributed.auto_parallel.static.dist_attribute import (
     OperatorDistAttr,
@@ -831,19 +832,12 @@ def _apply_single_impl(self, main_program, startup_program, context):
         if self.use_optimizer_fp16 is None:
             self.use_optimizer_fp16 = self.get_attr("level", None) == "o3"

+        AMPList = amp_utils.AutoMixedPrecisionLists
         # swith enviroment for fp16 / bf16.
         if self.target_dtype == "float16":
-            import paddle.static.amp.fp16_utils as amp_utils
-
-            AMPList = amp_utils.AutoMixedPrecisionLists
             __target_dtype = core.VarDesc.VarType.FP16
-
         elif self.target_dtype == "bfloat16":
-            from paddle.static.amp.bf16 import amp_utils
-
-            AMPList = amp_utils.AutoMixedPrecisionListsBF16
             __target_dtype = core.VarDesc.VarType.BF16
-
         else:
             raise NotImplementedError(
                 f"target dtype [{self.target_dtype}] is for amp o2 not supported yet."
@@ -856,6 +850,7 @@ def _apply_single_impl(self, main_program, startup_program, context):
             set(self.get_attr("custom_white_list")),
             set(self.get_attr("custom_black_list")),
             None,
+            dtype=self.target_dtype,
         )

         # NOTE don't not change input data dtype, since it is controled by dataloader
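
With AutoMixedPrecisionLists accepting dtype directly, the pass keeps a single module-level import and only the target VarType still depends on the branch. A hedged sketch of the resulting construction, following the names used in the diff:

import paddle.static.amp.fp16_utils as amp_utils
from paddle.framework import core

def build_amp_list(target_dtype, custom_white_list, custom_black_list):
    # Sketch only: mirrors how the pass now builds one list for either dtype.
    target_vartype = {
        "float16": core.VarDesc.VarType.FP16,
        "bfloat16": core.VarDesc.VarType.BF16,
    }.get(target_dtype)
    if target_vartype is None:
        raise NotImplementedError(
            f"target dtype [{target_dtype}] is for amp o2 not supported yet."
        )
    amp_list = amp_utils.AutoMixedPrecisionLists(
        set(custom_white_list), set(custom_black_list), None, dtype=target_dtype
    )
    return amp_list, target_vartype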
