@@ -340,9 +340,9 @@ def _cast_block(self, block):
340340 out_var = block.var(out_var_name)
341341 in_var = block._find_var_recursive(in_var_name)
342342 for in_var_name in op.input_arg_names:
343-     assert (
344-         in_var.dtype == block.var(in_var_name).dtype
345-     ), f"{in_var}, {block.var(in_var_name)}, {op}"
343+     assert in_var.dtype == block.var(in_var_name).dtype, (
344+         f"{in_var}, {block.var(in_var_name)}, {op}"
345+     )
346346 out_var.desc.set_dtype(in_var.dtype)
347347 elif int(op.attr('op_role')) == 257:
348348 pass
@@ -545,9 +545,9 @@ def _keep_fp32_output(op, out_name):
545545 cast_name, in_var_dist_attr
546546 )
547547 else:
548-     assert (
549-         in_var.dtype == dst_dtype
550-     ), f"op [{op.type}] expect input [{in_name}] to be dtype [{dst_dtype}] BUT got [{in_var.dtype}]. {op}"
548+     assert in_var.dtype == dst_dtype, (
549+         f"op [{op.type}] expect input [{in_name}] to be dtype [{dst_dtype}] BUT got [{in_var.dtype}]. {op}"
550+     )
551551
552552 for out_name in op.output_names:
553553 if src_dtype == paddle.float32 and _keep_fp32_output(op, out_name):
@@ -1158,13 +1158,13 @@ def _update_loss_scaling(self, grads, found_inf):
11581158 e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
11591159 )
11601160 if e.dtype == paddle.float16:
1161-     assert (
1162-         self._loss_scaling.dtype == paddle.float32
1163-     ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
1161+     assert self._loss_scaling.dtype == paddle.float32, (
1162+         "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
1163+     )
11641164 else:
1165-     assert (
1166-         self._loss_scaling.dtype == e.dtype
1167-     ), "The dtype of prev_loss_scaling should be equal to the dtype of x."
1165+     assert self._loss_scaling.dtype == e.dtype, (
1166+         "The dtype of prev_loss_scaling should be equal to the dtype of x."
1167+     )
11681168
11691169 inputs = {
11701170 'X': grads,
0 commit comments