python/paddle/jit/translated_layer.py (2 changes: 1 addition & 1 deletion)
@@ -503,7 +503,7 @@ def _preprocess(self, program_desc):
     def _append_scale_to_output(self, program):
         # 0. scale don't support bool output, we skip append scale for it
         for out_desc in self._output_descs:
-            if out_desc.dtype() == core.VarDesc.VarType.BOOL:
+            if out_desc.dtype() == paddle.bool:
                 return
 
         # 1. append scale & save var
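A note on why this one-line swap (and the matching ones in the files below) is expected to be behavior-preserving: it assumes paddle's public dtype aliases compare equal to the legacy `VarDesc.VarType` values that program and variable descs still report. A minimal sanity-check sketch, not part of the PR:

import paddle
from paddle.base import core  # legacy enum location, used here only for the comparison

# Assumed equivalences that the refactor relies on; if any of these ever stop
# holding, the rewritten comparisons would change behavior.
assert paddle.bool == core.VarDesc.VarType.BOOL
assert paddle.float32 == core.VarDesc.VarType.FP32
assert paddle.bfloat16 == core.VarDesc.VarType.BF16
print("public dtype aliases compare equal to the legacy VarType values")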
python/paddle/static/amp/amp_nn.py (6 changes: 3 additions & 3 deletions)
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import paddle
 from paddle import _C_ops
-from paddle.base import core
 from paddle.base.data_feeder import check_type, check_variable_and_dtype
 from paddle.base.framework import Variable, in_dygraph_mode
 from paddle.base.layer_helper import LayerHelper
@@ -130,9 +130,9 @@ def update_loss_scaling(
             ['float16', 'float32', 'float64', 'uint16'],
             'update_loss_scaling',
         )
-        if e.dtype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]:
+        if e.dtype in [paddle.float16, paddle.bfloat16]:
             assert (
-                prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
+                prev_loss_scaling.dtype == paddle.float32
             ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16 or bfloat16."
         else:
             assert (
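To see the rewritten guard in context, a static-graph variable declared as float16 satisfies the same membership test used above. This is an illustrative sketch only, not part of the PR; the variable name and shape are made up:

import paddle

paddle.enable_static()
# Hypothetical fp16 input, used only to exercise the dtype guard from update_loss_scaling.
x = paddle.static.data(name='x', shape=[2, 2], dtype='float16')
print(x.dtype in [paddle.float16, paddle.bfloat16])  # expected: True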
python/paddle/static/amp/bf16/amp_utils.py (22 changes: 10 additions & 12 deletions)
@@ -19,6 +19,7 @@
 
 import numpy as np
 
+import paddle
 from paddle.base import core, framework, global_scope
 from paddle.base.log_helper import get_logger
 from paddle.base.wrapped_decorator import signature_safe_contextmanager
@@ -60,7 +61,7 @@ def _dtype_to_str(dtype):
     Args:
         dtype (VarType): Variable type.
     """
-    if dtype == core.VarDesc.VarType.BF16:
+    if dtype == paddle.bfloat16:
         return 'bf16'
     else:
         return 'fp32'
@@ -83,7 +84,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
     num_cast_ops = 0
 
     for in_name in op.input_names:
-        if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
+        if src_dtype == paddle.float32 and op.type in [
             'batch_norm',
             'fused_bn_add_activation',
             'layer_norm',
@@ -120,10 +121,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
         else:
             if op.has_attr('in_dtype'):
                 op._set_attr('in_dtype', dest_dtype)
-    if (
-        src_dtype == core.VarDesc.VarType.FP32
-        and dest_dtype == core.VarDesc.VarType.BF16
-    ):
+    if src_dtype == paddle.float32 and dest_dtype == paddle.bfloat16:
         for out_name in op.output_names:
             if (
                 op.type
@@ -135,7 +133,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                 out_var = block.var(out_var_name)
                 if out_var.type not in _valid_types:
                     continue
-                if out_var.dtype == core.VarDesc.VarType.FP32:
+                if out_var.dtype == paddle.float32:
                     out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
         if op.has_attr('out_dtype'):
             op._set_attr('out_dtype', core.VarDesc.VarType.BF16)
@@ -282,7 +280,7 @@ def cast_initializers_to_bf16(
 
         if change_op and are_post_ops_bf16(op_post_ops, keep_fp32_ops):
             for out_var in op_out_vars:
-                if out_var.dtype == core.VarDesc.VarType.FP32:
+                if out_var.dtype == paddle.float32:
                     out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
                 if (
                     to_bf16_var_names is not None
@@ -352,7 +350,7 @@ def cast_model_to_bf16(
                     if in_var is None or in_var.type not in _valid_types:
                         continue
 
-                    if in_var.dtype == core.VarDesc.VarType.FP32:
+                    if in_var.dtype == paddle.float32:
                         in_var.desc.set_dtype(core.VarDesc.VarType.BF16)
                         to_bf16_var_names.add(in_var_name)
 
@@ -386,7 +384,7 @@ def cast_model_to_bf16(
                     if out_var is None or out_var.type not in _valid_types:
                         continue
 
-                    if out_var.dtype == core.VarDesc.VarType.FP32:
+                    if out_var.dtype == paddle.float32:
                         out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
 
                     _logger.debug(
@@ -397,7 +395,7 @@ def cast_model_to_bf16(
             for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
                 if (
                     op.has_attr(attr_name)
-                    and op.attr(attr_name) == core.VarDesc.VarType.FP32
+                    and op.attr(attr_name) == paddle.float32
                 ):
                     op._set_attr(attr_name, core.VarDesc.VarType.BF16)
 
@@ -444,7 +442,7 @@ def cast_model_to_bf16(
                 out_var = block.vars.get(out_var_name)
                 if out_var is None or out_var.type not in _valid_types:
                     continue
-                if out_var.dtype == core.VarDesc.VarType.BF16:
+                if out_var.dtype == paddle.bfloat16:
                     out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
                 post_ops = find_true_post_op(ops, op, out_var_name)
                 for post_op in post_ops:
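As a small usage check for the `_dtype_to_str` helper changed at the top of this file, here is a standalone sketch that mirrors the hunk above (it redefines the helper locally rather than importing paddle internals; not part of the PR):

import paddle


def _dtype_to_str(dtype):
    # Same dispatch as the updated helper above: bf16 gets the 'bf16' suffix,
    # everything else is treated as fp32.
    if dtype == paddle.bfloat16:
        return 'bf16'
    else:
        return 'fp32'


assert _dtype_to_str(paddle.bfloat16) == 'bf16'
assert _dtype_to_str(paddle.float32) == 'fp32'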
python/paddle/static/amp/debugging.py (6 changes: 3 additions & 3 deletions)
@@ -35,11 +35,11 @@ def update(self, dtype):
         if dtype is None:
             self.other_calls = self.other_calls + 1
         else:
-            if dtype == paddle.base.core.VarDesc.VarType.FP32:
+            if dtype == paddle.float32:
                 self.fp32_calls = self.fp32_calls + 1
-            elif dtype == paddle.base.core.VarDesc.VarType.FP16:
+            elif dtype == paddle.float16:
                 self.fp16_calls = self.fp16_calls + 1
-            elif dtype == paddle.base.core.VarDesc.VarType.BF16:
+            elif dtype == paddle.bfloat16:
                 self.bf16_calls = self.bf16_calls + 1
             else:
                 self.other_calls = self.other_calls + 1
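For a quick runtime smoke test of the comparison style used by the counter above (illustrative only, not part of the PR; the tensor below is made up just to produce a dtype to classify):

import paddle

x = paddle.ones([2, 2], dtype='float32')
# Mirrors the fp32/fp16/bf16 branches in the hunk above.
if x.dtype == paddle.float32:
    print("counted as an fp32 call")
elif x.dtype == paddle.float16:
    print("counted as an fp16 call")
elif x.dtype == paddle.bfloat16:
    print("counted as a bf16 call")
else:
    print("counted as an 'other' call")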