diff --git a/python/paddle/jit/translated_layer.py b/python/paddle/jit/translated_layer.py
index 53d3ff9a718c84..2805fcae597bb7 100644
--- a/python/paddle/jit/translated_layer.py
+++ b/python/paddle/jit/translated_layer.py
@@ -503,7 +503,7 @@ def _preprocess(self, program_desc):
     def _append_scale_to_output(self, program):
         # 0. scale don't support bool output, we skip append scale for it
         for out_desc in self._output_descs:
-            if out_desc.dtype() == core.VarDesc.VarType.BOOL:
+            if out_desc.dtype() == paddle.bool:
                 return
 
         # 1. append scale & save var
diff --git a/python/paddle/static/amp/amp_nn.py b/python/paddle/static/amp/amp_nn.py
index 304ad8d7d6c405..eba7c9a1924769 100644
--- a/python/paddle/static/amp/amp_nn.py
+++ b/python/paddle/static/amp/amp_nn.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import paddle
 from paddle import _C_ops
-from paddle.base import core
 from paddle.base.data_feeder import check_type, check_variable_and_dtype
 from paddle.base.framework import Variable, in_dygraph_mode
 from paddle.base.layer_helper import LayerHelper
@@ -130,9 +130,9 @@ def update_loss_scaling(
             ['float16', 'float32', 'float64', 'uint16'],
             'update_loss_scaling',
         )
-        if e.dtype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]:
+        if e.dtype in [paddle.float16, paddle.bfloat16]:
             assert (
-                prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
+                prev_loss_scaling.dtype == paddle.float32
             ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16 or bfloat16."
         else:
             assert (
diff --git a/python/paddle/static/amp/bf16/amp_utils.py b/python/paddle/static/amp/bf16/amp_utils.py
index 8a73b2c1f5514a..468c6256cddc21 100644
--- a/python/paddle/static/amp/bf16/amp_utils.py
+++ b/python/paddle/static/amp/bf16/amp_utils.py
@@ -19,6 +19,7 @@
 
 import numpy as np
 
+import paddle
 from paddle.base import core, framework, global_scope
 from paddle.base.log_helper import get_logger
 from paddle.base.wrapped_decorator import signature_safe_contextmanager
@@ -60,7 +61,7 @@ def _dtype_to_str(dtype):
     Args:
         dtype (VarType): Variable type.
     """
-    if dtype == core.VarDesc.VarType.BF16:
+    if dtype == paddle.bfloat16:
         return 'bf16'
     else:
         return 'fp32'
@@ -83,7 +84,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
     num_cast_ops = 0
 
     for in_name in op.input_names:
-        if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
+        if src_dtype == paddle.float32 and op.type in [
             'batch_norm',
             'fused_bn_add_activation',
             'layer_norm',
@@ -120,10 +121,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
         else:
             if op.has_attr('in_dtype'):
                 op._set_attr('in_dtype', dest_dtype)
-    if (
-        src_dtype == core.VarDesc.VarType.FP32
-        and dest_dtype == core.VarDesc.VarType.BF16
-    ):
+    if src_dtype == paddle.float32 and dest_dtype == paddle.bfloat16:
         for out_name in op.output_names:
             if (
                 op.type
@@ -135,7 +133,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                 out_var = block.var(out_var_name)
                 if out_var.type not in _valid_types:
                     continue
-                if out_var.dtype == core.VarDesc.VarType.FP32:
+                if out_var.dtype == paddle.float32:
                     out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
         if op.has_attr('out_dtype'):
             op._set_attr('out_dtype', core.VarDesc.VarType.BF16)
@@ -282,7 +280,7 @@ def cast_initializers_to_bf16(
 
         if change_op and are_post_ops_bf16(op_post_ops, keep_fp32_ops):
             for out_var in op_out_vars:
-                if out_var.dtype == core.VarDesc.VarType.FP32:
+                if out_var.dtype == paddle.float32:
                     out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
                 if (
                     to_bf16_var_names is not None
@@ -352,7 +350,7 @@ def cast_model_to_bf16(
                     if in_var is None or in_var.type not in _valid_types:
                         continue
 
-                    if in_var.dtype == core.VarDesc.VarType.FP32:
+                    if in_var.dtype == paddle.float32:
                         in_var.desc.set_dtype(core.VarDesc.VarType.BF16)
                         to_bf16_var_names.add(in_var_name)
 
@@ -386,7 +384,7 @@ def cast_model_to_bf16(
                     if out_var is None or out_var.type not in _valid_types:
                         continue
 
-                    if out_var.dtype == core.VarDesc.VarType.FP32:
+                    if out_var.dtype == paddle.float32:
                         out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
 
                     _logger.debug(
@@ -397,7 +395,7 @@ def cast_model_to_bf16(
             for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
                 if (
                     op.has_attr(attr_name)
-                    and op.attr(attr_name) == core.VarDesc.VarType.FP32
+                    and op.attr(attr_name) == paddle.float32
                 ):
                     op._set_attr(attr_name, core.VarDesc.VarType.BF16)
 
@@ -444,7 +442,7 @@ def cast_model_to_bf16(
                 out_var = block.vars.get(out_var_name)
                 if out_var is None or out_var.type not in _valid_types:
                     continue
-                if out_var.dtype == core.VarDesc.VarType.BF16:
+                if out_var.dtype == paddle.bfloat16:
                     out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
                 post_ops = find_true_post_op(ops, op, out_var_name)
                 for post_op in post_ops:
diff --git a/python/paddle/static/amp/debugging.py b/python/paddle/static/amp/debugging.py
index 9dcef221fbfc03..954a958d939dbf 100644
--- a/python/paddle/static/amp/debugging.py
+++ b/python/paddle/static/amp/debugging.py
@@ -35,11 +35,11 @@ def update(self, dtype):
         if dtype is None:
             self.other_calls = self.other_calls + 1
         else:
-            if dtype == paddle.base.core.VarDesc.VarType.FP32:
+            if dtype == paddle.float32:
                 self.fp32_calls = self.fp32_calls + 1
-            elif dtype == paddle.base.core.VarDesc.VarType.FP16:
+            elif dtype == paddle.float16:
                 self.fp16_calls = self.fp16_calls + 1
-            elif dtype == paddle.base.core.VarDesc.VarType.BF16:
+            elif dtype == paddle.bfloat16:
                 self.bf16_calls = self.bf16_calls + 1
             else:
                 self.other_calls = self.other_calls + 1