 )
 from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
 from paddle.framework import core
-from paddle.static.amp.bf16.amp_utils import (
-    AutoMixedPrecisionListsBF16,
-    _is_in_fp32_varnames,
-)
 from paddle.static.amp.fp16_utils import (
     AutoMixedPrecisionLists,
     _is_in_black_varnames,
@@ -88,48 +84,26 @@ def __init__(
         black_varnames=None,
         dtype="float16",
     ):
-        self._amp_list = None
-        if dtype == "float16":
-            self._amp_list = AutoMixedPrecisionLists(
-                set(white_list), set(black_list), set(black_varnames)
-            )
-        elif dtype == "bfloat16":
-            self._amp_list = AutoMixedPrecisionListsBF16(
-                set(white_list), set(black_list), set(black_varnames)
-            )
-
-        assert self._amp_list is not None
+        self._amp_list = AutoMixedPrecisionLists(
+            set(white_list), set(black_list), set(black_varnames), dtype=dtype
+        )
         self._dtype = dtype
-        self._is_float16 = dtype == "float16"

     @property
     def white_list(self):
-        if self._is_float16:
-            return self._amp_list.white_list
-        else:
-            return self._amp_list.bf16_list
+        return self._amp_list.white_list

     @property
     def black_list(self):
-        if self._is_float16:
-            return self._amp_list.black_list
-        else:
-            return self._amp_list.fp32_list
+        return self._amp_list.black_list

     @property
     def gray_list(self):
         return self._amp_list.gray_list

     @property
     def black_varnames(self):
-        if self._is_float16:
-            return self._amp_list.black_varnames
-        else:
-            return self._amp_list.fp32_varnames
-
-    @property
-    def is_fp16(self):
-        return self._is_float16
+        return self._amp_list.black_varnames

     @property
     def dtype(self):
@@ -140,36 +114,17 @@ def amp_list(self):
         return self._amp_list

     def _is_in_black_fp32_varnames(self, op):
-        if self._is_float16:
-            return _is_in_black_varnames(op, self._amp_list)
-        else:
-            return _is_in_fp32_varnames(op, self._amp_list)
+        return _is_in_black_varnames(op, self._amp_list)

     def _op_keep_fp32_input(self, op, in_name):
         if not op.amp_options.enable:
             return True
-        if self._is_float16:
-            return _keep_fp32_input(op, in_name)
-        else:
-            if op.type in ['batch_norm', 'layer_norm']:
-                return in_name != 'X'
-            if op.type == 'fused_bn_add_activation':
-                return in_name not in {'X', 'Z'}
-            return False
+        return _keep_fp32_input(op, in_name)

     def _op_keep_fp32_output(self, op, out_name):
         if not op.amp_options.enable:
             return True
-        if self._is_float16:
-            return _keep_fp32_output(op, out_name)
-        else:
-            if op.type in [
-                'batch_norm',
-                'fused_bn_add_activation',
-                'layer_norm',
-            ]:
-                return out_name != 'Y'
-            return False
+        return _keep_fp32_output(op, out_name)


 class AMPState:
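With the branches above removed, the bfloat16 special cases for batch_norm, layer_norm and fused_bn_add_activation are assumed to be covered by the shared _keep_fp32_input / _keep_fp32_output helpers from paddle.static.amp.fp16_utils. For reference, the deleted bf16-only rules amount to the following standalone predicates (a restatement of the removed code, not new behavior):

def bf16_keep_fp32_input(op_type, in_name):
    # batch_norm / layer_norm: only the data input 'X' is cast; Scale, Bias,
    # Mean and Variance stay in float32.
    if op_type in ('batch_norm', 'layer_norm'):
        return in_name != 'X'
    # fused_bn_add_activation additionally casts the residual input 'Z'.
    if op_type == 'fused_bn_add_activation':
        return in_name not in {'X', 'Z'}
    return False

def bf16_keep_fp32_output(op_type, out_name):
    # Only the primary output 'Y' is produced in low precision for these ops.
    if op_type in ('batch_norm', 'fused_bn_add_activation', 'layer_norm'):
        return out_name != 'Y'
    return False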
@@ -324,12 +279,12 @@ def _cast_block(self, block):
                     self.dist_context,
                 )
             elif self._is_fp16_op(op.desc.original_id()) is True:
-                if self.amp_dtype == "bfloat16":
-                    if (
-                        op.has_attr('dtype')
-                        and op.attr('dtype') == core.VarDesc.VarType.FP32
-                    ):
-                        op._set_attr('dtype', core.VarDesc.VarType.BF16)
+                # deal with op with attribute 'dtype', such as 'fill_constant'
+                if (
+                    op.has_attr('dtype')
+                    and op.attr('dtype') == core.VarDesc.VarType.FP32
+                ):
+                    op._set_attr('dtype', _str_to_dtype(self.amp_dtype))
                 num_cast_ops = self._insert_cast_op_forward(
                     block,
                     op,
@@ -362,16 +317,13 @@ def _cast_block(self, block):
                             self.dist_context,
                             appended_grad_times,
                         )
-                    elif (
-                        self._is_fp16_op(op.desc.original_id()) is True
-                    ):  # fp16/bf16
-                        if self.amp_dtype == "bfloat16":
-                            if (
-                                op.has_attr('dtype')
-                                and op.attr('dtype')
-                                == core.VarDesc.VarType.FP32
-                            ):
-                                op._set_attr('dtype', core.VarDesc.VarType.BF16)
+                    elif self._is_fp16_op(op.desc.original_id()) is True:
+                        # deal with op with attribute 'dtype', such as 'fill_constant'
+                        if (
+                            op.has_attr('dtype')
+                            and op.attr('dtype') == core.VarDesc.VarType.FP32
+                        ):
+                            op._set_attr('dtype', _str_to_dtype(self.amp_dtype))
                         num_cast_ops = self._insert_cast_op_backward(
                             block,
                             op,
@@ -522,6 +474,7 @@ def _insert_cast_op_forward(
                         op._set_attr(
                             'out_dtype', _str_to_dtype(self.amp_dtype)
                         )
+
         return num_cast_ops

     def _insert_cast_op_backward(
@@ -699,6 +652,12 @@ def _keep_fp32_output(op, out_name):
                 else:
                     assert out_var.dtype == dst_dtype

+        if (
+            op.has_attr('dtype')
+            and op.attr('dtype') == core.VarDesc.VarType.FP32
+        ):
+            op._set_attr('dtype', _str_to_dtype(self.amp_dtype))
+
         return num_cast_ops

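The repeated snippet in the hunks above gives the forward pass, the backward pass, and _insert_cast_op_backward the same handling for ops that materialize their output from a 'dtype' attribute (for example fill_constant): when such an op has been selected for low precision but its attribute still says FP32, the attribute is rewritten to the target dtype. A self-contained sketch of the pattern, assuming _str_to_dtype maps the strings 'float16' / 'bfloat16' to the matching core.VarDesc.VarType values as it does elsewhere in this file:

# Sketch only: `op` is a static-graph Operator, `amp_dtype` is the pass's
# target dtype string ("float16" or "bfloat16"); the helper function name
# rewrite_dtype_attr is hypothetical and does not appear in the diff.
def rewrite_dtype_attr(op, amp_dtype):
    # Casting the inputs is not enough for ops like fill_constant, whose
    # output dtype comes solely from the 'dtype' attribute.
    if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
        op._set_attr('dtype', _str_to_dtype(amp_dtype))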