@@ -132,6 +132,8 @@ def product(dim):
         tensor_to_check_dtype = np.float16
         # set delta as np.float16, will automatic convert to float32, float64
         delta = np.array(delta).astype(np.float16)
+    elif tensor_to_check_dtype == core.VarDesc.VarType.BF16:
+        tensor_to_check_dtype = np.float32
     else:
         raise ValueError("Not supported data type " + str(
             tensor_to_check_dtype))
@@ -140,9 +142,10 @@ def get_output():
         sum = []
         op.run(scope, place)
         for output_name in output_names:
-            sum.append(
-                np.array(scope.find_var(output_name).get_tensor()).astype(
-                    tensor_to_check_dtype).mean())
+            output_numpy = np.array(scope.find_var(output_name).get_tensor())
+            if tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
+                output_numpy = convert_uint16_to_float(output_numpy)
+            sum.append(output_numpy.astype(tensor_to_check_dtype).mean())
         return tensor_to_check_dtype(np.array(sum).sum() / len(output_names))

     gradient_flat = np.zeros(shape=(tensor_size, ), dtype=tensor_to_check_dtype)
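
The added branch relies on convert_uint16_to_float from the op-test utilities to turn BF16 outputs (stored as uint16 bit patterns) into float32 before the mean is taken. A minimal sketch of what that helper presumably does, assuming bfloat16 values are stored as the upper 16 bits of a float32 (the helper name below is illustrative):

    import numpy as np

    def convert_uint16_to_float_sketch(arr):
        # bfloat16 keeps only the high 16 bits of an IEEE-754 float32, so
        # shifting the stored bits into the upper half of a uint32 word and
        # reinterpreting the buffer as float32 recovers the value.
        arr = np.asarray(arr, dtype=np.uint16)
        return (arr.astype(np.uint32) << 16).view(np.float32)
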
@@ -152,6 +155,11 @@ def __get_elem__(tensor, i):
             numpy_tensor = np.array(tensor).astype(np.float16)
             numpy_tensor = numpy_tensor.flatten()
             return numpy_tensor[i]
+        elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
+            numpy_tensor = np.array(tensor).astype(np.uint16)
+            numpy_tensor = numpy_tensor.flatten()
+            return struct.unpack('<f', struct.pack('<I', numpy_tensor[i]
+                                                   << 16))[0]
         elif tensor_to_check_dtype == np.float32:
             return tensor._get_float_element(i)
         elif tensor_to_check_dtype == np.float64:
@@ -168,6 +176,13 @@ def __set_elem__(tensor, i, e):
             numpy_tensor[i] = e
             numpy_tensor = numpy_tensor.reshape(shape)
             tensor.set(numpy_tensor, place)
+        elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
+            numpy_tensor = np.array(tensor).astype(np.uint16)
+            shape = numpy_tensor.shape
+            numpy_tensor = numpy_tensor.flatten()
+            numpy_tensor[i] = np.uint16(copy_bits_from_float_to_uint16(e))
+            numpy_tensor = numpy_tensor.reshape(shape)
+            tensor.set(numpy_tensor, place)
         elif tensor_to_check_dtype == np.float32:
             tensor._set_float_element(i, e)
         elif tensor_to_check_dtype == np.float64:
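
Together, the struct-based read in __get_elem__ and the copy_bits_from_float_to_uint16 write in __set_elem__ perturb single BF16 elements by going through float32 bit patterns. A self-contained sketch of that round trip (the helper names are illustrative, not the framework's):

    import struct
    import numpy as np

    def bf16_bits_to_float(bits):
        # Widen the 16 stored bits into the high half of a 32-bit word and
        # reinterpret them as float32; this mirrors the struct call above.
        return struct.unpack('<f', struct.pack('<I', int(bits) << 16))[0]

    def float_to_bf16_bits(value):
        # Keep only the upper 16 bits of the float32 representation
        # (truncation), which is what copy_bits_from_float_to_uint16
        # presumably does when writing the perturbed element back.
        return np.uint16(
            struct.unpack('<I', struct.pack('<f', float(value)))[0] >> 16)

    # Values exactly representable in bfloat16 survive the round trip.
    assert bf16_bits_to_float(float_to_bf16_bits(1.5)) == 1.5
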
@@ -1353,6 +1368,8 @@ def _assert_is_close(self, numeric_grads, analytic_grads, names,
                 abs_a[abs_a < 1e-10] = 1e-3
                 abs_a[np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)] *= 1e4
                 abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2
+            elif self.is_bfloat16_op():
+                abs_a[abs_a < 1e-2] = 1
             else:
                 abs_a[abs_a < 1e-3] = 1

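
Here abs_a is presumably the denominator of the relative-error check, so clamping entries below 1e-2 to 1 makes the comparison fall back to an absolute tolerance wherever the reference gradient is too small for bfloat16 (about two to three significant decimal digits) to resolve. A rough sketch under that assumption:

    import numpy as np

    def max_relative_diff_sketch(numeric, analytic, is_bf16):
        # abs_a acts as the denominator; small entries are clamped so the
        # check degrades to an absolute-error comparison for tiny gradients.
        abs_a = np.abs(numeric).astype(np.float64)
        abs_a[abs_a < (1e-2 if is_bf16 else 1e-3)] = 1
        return np.max(np.abs(numeric - analytic) / abs_a)
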
@@ -1500,6 +1517,13 @@ def check_grad_with_place(self,
             dygraph_grad = self._get_dygraph_grad(
                 inputs_to_check, place, output_names, user_defined_grad_outputs,
                 no_grad_set)
+            fp32_grads = []
+            for grad in dygraph_grad:
+                if grad.dtype == np.uint16:
+                    grad = convert_uint16_to_float(grad)
+                    max_relative_error = 0.03
+                fp32_grads.append(grad)
+            dygraph_grad = fp32_grads
             self._assert_is_close(numeric_grads, dygraph_grad, inputs_to_check,
                                   max_relative_error,
                                   "Gradient Check On %s" % str(place))
@@ -1544,6 +1568,21 @@ def _get_dygraph_grad(self,
                 outputs=outputs,
                 attrs=attrs_outputs if hasattr(self, "attrs") else None)

+            if self.dtype == np.uint16:
+                cast_inputs = self._find_var_in_dygraph(outputs,
+                                                        output_names[0])
+                cast_outputs = block.create_var(
+                    dtype="float32", shape=cast_inputs[0].shape)
+                cast_op = block.append_op(
+                    inputs={"X": cast_inputs},
+                    outputs={"Out": cast_outputs},
+                    type="cast",
+                    attrs={
+                        "in_dtype": core.VarDesc.VarType.BF16,
+                        "out_dtype": core.VarDesc.VarType.FP32
+                    })
+                outputs = {output_names[0]: cast_outputs}
+
             outputs_valid = {}
             for output_name in output_names:
                 outputs_valid[output_name] = self._find_var_in_dygraph(
@@ -1659,6 +1698,21 @@ def _get_gradient(self,
         feed_dict = self.feed_var(inputs, place)

         if user_defined_grad_outputs is None:
+            if self.dtype == np.uint16:
+                cast_inputs = list(map(block.var, output_names))
+                cast_outputs = block.create_var(
+                    dtype="float32", shape=cast_inputs[0].shape)
+                cast_op = block.append_op(
+                    inputs={"X": cast_inputs},
+                    outputs={"Out": cast_outputs},
+                    type="cast",
+                    attrs={
+                        "in_dtype": core.VarDesc.VarType.BF16,
+                        "out_dtype": core.VarDesc.VarType.FP32
+                    })
+                cast_op.desc.infer_var_type(block.desc)
+                cast_op.desc.infer_shape(block.desc)
+                output_names = [cast_outputs.name]
             loss = append_loss_ops(block, output_names)
             param_grad_list = append_backward(
                 loss=loss,
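
Both this branch and the dygraph branch above apply the same idea: when the op under test runs in BF16 (self.dtype == np.uint16), a cast op to FP32 is appended after the output, and the mean-based loss plus the backward pass are built on the cast result instead of the raw BF16 tensor. A hypothetical illustration in terms of the public fluid API (the function and variable names below are mine, not the patch's):

    import paddle.fluid as fluid

    def build_fp32_loss(bf16_output):
        # Illustration only: cast the BF16 result to float32 and reduce it,
        # which is roughly what the appended "cast" op plus append_loss_ops
        # accomplish inside the test framework.
        out_fp32 = fluid.layers.cast(bf16_output, dtype="float32")
        return fluid.layers.mean(out_fp32)
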