1919
2020import numpy as np
2121
22+ import paddle
2223from paddle .base import core , framework , global_scope
2324from paddle .base .log_helper import get_logger
2425from paddle .base .wrapped_decorator import signature_safe_contextmanager
@@ -60,7 +61,7 @@ def _dtype_to_str(dtype):
6061 Args:
6162 dtype (VarType): Variable type.
6263 """
63- if dtype == core . VarDesc . VarType . BF16 :
64+ if dtype == paddle . bfloat16 :
6465 return 'bf16'
6566 else :
6667 return 'fp32'
@@ -83,7 +84,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
8384 num_cast_ops = 0
8485
8586 for in_name in op .input_names :
86- if src_dtype == core . VarDesc . VarType . FP32 and op .type in [
87+ if src_dtype == paddle . float32 and op .type in [
8788 'batch_norm' ,
8889 'fused_bn_add_activation' ,
8990 'layer_norm' ,
@@ -120,10 +121,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
120121 else :
121122 if op .has_attr ('in_dtype' ):
122123 op ._set_attr ('in_dtype' , dest_dtype )
123- if (
124- src_dtype == core .VarDesc .VarType .FP32
125- and dest_dtype == core .VarDesc .VarType .BF16
126- ):
124+ if src_dtype == paddle .float32 and dest_dtype == paddle .bfloat16 :
127125 for out_name in op .output_names :
128126 if (
129127 op .type
@@ -135,7 +133,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
135133 out_var = block .var (out_var_name )
136134 if out_var .type not in _valid_types :
137135 continue
138- if out_var .dtype == core . VarDesc . VarType . FP32 :
136+ if out_var .dtype == paddle . float32 :
139137 out_var .desc .set_dtype (core .VarDesc .VarType .BF16 )
140138 if op .has_attr ('out_dtype' ):
141139 op ._set_attr ('out_dtype' , core .VarDesc .VarType .BF16 )
@@ -282,7 +280,7 @@ def cast_initializers_to_bf16(
282280
283281 if change_op and are_post_ops_bf16 (op_post_ops , keep_fp32_ops ):
284282 for out_var in op_out_vars :
285- if out_var .dtype == core . VarDesc . VarType . FP32 :
283+ if out_var .dtype == paddle . float32 :
286284 out_var .desc .set_dtype (core .VarDesc .VarType .BF16 )
287285 if (
288286 to_bf16_var_names is not None
@@ -352,7 +350,7 @@ def cast_model_to_bf16(
352350 if in_var is None or in_var .type not in _valid_types :
353351 continue
354352
355- if in_var .dtype == core . VarDesc . VarType . FP32 :
353+ if in_var .dtype == paddle . float32 :
356354 in_var .desc .set_dtype (core .VarDesc .VarType .BF16 )
357355 to_bf16_var_names .add (in_var_name )
358356
@@ -386,7 +384,7 @@ def cast_model_to_bf16(
386384 if out_var is None or out_var .type not in _valid_types :
387385 continue
388386
389- if out_var .dtype == core . VarDesc . VarType . FP32 :
387+ if out_var .dtype == paddle . float32 :
390388 out_var .desc .set_dtype (core .VarDesc .VarType .BF16 )
391389
392390 _logger .debug (
@@ -397,7 +395,7 @@ def cast_model_to_bf16(
397395 for attr_name in ['in_dtype' , 'out_dtype' , 'dtype' ]:
398396 if (
399397 op .has_attr (attr_name )
400- and op .attr (attr_name ) == core . VarDesc . VarType . FP32
398+ and op .attr (attr_name ) == paddle . float32
401399 ):
402400 op ._set_attr (attr_name , core .VarDesc .VarType .BF16 )
403401
@@ -444,7 +442,7 @@ def cast_model_to_bf16(
444442 out_var = block .vars .get (out_var_name )
445443 if out_var is None or out_var .type not in _valid_types :
446444 continue
447- if out_var .dtype == core . VarDesc . VarType . BF16 :
445+ if out_var .dtype == paddle . bfloat16 :
448446 out_var .desc .set_dtype (core .VarDesc .VarType .FP32 )
449447 post_ops = find_true_post_op (ops , op , out_var_name )
450448 for post_op in post_ops :
0 commit comments