python/paddle/jit/translated_layer.py (2 changes: 1 addition & 1 deletion)
@@ -503,7 +503,7 @@ def _preprocess(self, program_desc):
     def _append_scale_to_output(self, program):
         # 0. scale don't support bool output, we skip append scale for it
         for out_desc in self._output_descs:
-            if out_desc.dtype() == core.VarDesc.VarType.BOOL:
+            if out_desc.dtype() == paddle.bool:
                 return
 
         # 1. append scale & save var
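The guard above now compares a `VarDesc` dtype against the public `paddle.bool` alias instead of the `core.VarDesc.VarType.BOOL` enum. A minimal standalone sketch of that comparison, assuming a legacy (non-PIR) static-graph program and assuming the alias compares equal to the matching enum value, which is exactly what the new check relies on:

```python
import paddle
from paddle.base import core

paddle.enable_static()

prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    mask = paddle.static.data(name="mask", shape=[8], dtype="bool")

# out_desc.dtype() in the diff is a VarDesc method returning a VarType value;
# the updated code compares it against paddle.bool rather than the enum.
out_desc = mask.desc
print(out_desc.dtype() == paddle.bool)                # new-style comparison
print(out_desc.dtype() == core.VarDesc.VarType.BOOL)  # old-style comparison
```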
python/paddle/optimizer/optimizer.py (5 changes: 1 addition & 4 deletions)
@@ -1922,10 +1922,7 @@ def _is_dtype_fp16_or_bf16(self, dtype):
             dtype, (core.VarDesc.VarType, core.DataType)
         ), "The dtype should be an instance of core.VarDesc.VarType or core.DataType."
         if isinstance(dtype, core.VarDesc.VarType):
-            return (
-                dtype == core.VarDesc.VarType.FP16
-                or dtype == core.VarDesc.VarType.BF16
-            )
+            return dtype == paddle.float16 or dtype == paddle.bfloat16
         else:
             return (
                 dtype == core.DataType.FLOAT16 or dtype == core.DataType.UINT16
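For reference, a sketch of the simplified helper in isolation (extracted here as a free function purely for illustration); like the new code, it assumes `paddle.float16` and `paddle.bfloat16` compare equal to the matching `core.VarDesc.VarType` enum values:

```python
import paddle
from paddle.base import core


def is_dtype_fp16_or_bf16(dtype):
    # Mirrors Optimizer._is_dtype_fp16_or_bf16 after this change: the
    # VarDesc.VarType branch now compares against the public dtype aliases.
    assert isinstance(
        dtype, (core.VarDesc.VarType, core.DataType)
    ), "The dtype should be an instance of core.VarDesc.VarType or core.DataType."
    if isinstance(dtype, core.VarDesc.VarType):
        return dtype == paddle.float16 or dtype == paddle.bfloat16
    return dtype == core.DataType.FLOAT16 or dtype == core.DataType.UINT16


print(is_dtype_fp16_or_bf16(core.VarDesc.VarType.FP16))  # expected: True
print(is_dtype_fp16_or_bf16(core.VarDesc.VarType.FP32))  # expected: False
```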
python/paddle/static/amp/amp_nn.py (3 changes: 2 additions & 1 deletion)
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import paddle
 from paddle import _C_ops
 from paddle.base import core
 from paddle.base.data_feeder import check_type, check_variable_and_dtype
@@ -132,7 +133,7 @@ def update_loss_scaling(
         )
         if e.dtype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]:
             assert (
-                prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
+                prev_loss_scaling.dtype == paddle.float32
             ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16 or bfloat16."
         else:
             assert (
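The assertion above pins a dtype contract: when any input tensor is float16 or bfloat16, the loss-scaling tensor itself must stay float32. A small dygraph sketch of that contract; the helper name is hypothetical and only restates the branch shown in the hunk:

```python
import paddle


def check_loss_scaling_dtype(xs, prev_loss_scaling):
    # Hypothetical helper restating the fp16/bf16 branch of the assertion
    # in update_loss_scaling shown above.
    for e in xs:
        if e.dtype in [paddle.float16, paddle.bfloat16]:
            assert (
                prev_loss_scaling.dtype == paddle.float32
            ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16 or bfloat16."


xs = [paddle.ones([2, 2]).astype(paddle.float16)]
prev_loss_scaling = paddle.to_tensor([2.0**15], dtype="float32")
check_loss_scaling_dtype(xs, prev_loss_scaling)  # passes: fp32 scale for fp16 inputs
```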
python/paddle/static/amp/bf16/amp_utils.py (32 changes: 12 additions & 20 deletions)
@@ -19,6 +19,7 @@
 
 import numpy as np
 
+import paddle
 from paddle.base import core, framework, global_scope
 from paddle.base.log_helper import get_logger
 from paddle.base.wrapped_decorator import signature_safe_contextmanager
@@ -60,7 +61,7 @@ def _dtype_to_str(dtype):
     Args:
         dtype (VarType): Variable type.
     """
-    if dtype == core.VarDesc.VarType.BF16:
+    if dtype == paddle.bfloat16:
         return 'bf16'
     else:
         return 'fp32'
@@ -83,7 +84,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
     num_cast_ops = 0
 
     for in_name in op.input_names:
-        if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
+        if src_dtype == paddle.float32 and op.type in [
             'batch_norm',
             'fused_bn_add_activation',
             'layer_norm',
@@ -120,10 +121,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
             else:
                 if op.has_attr('in_dtype'):
                     op._set_attr('in_dtype', dest_dtype)
-    if (
-        src_dtype == core.VarDesc.VarType.FP32
-        and dest_dtype == core.VarDesc.VarType.BF16
-    ):
+    if src_dtype == paddle.float32 and dest_dtype == paddle.bfloat16:
         for out_name in op.output_names:
             if (
                 op.type
@@ -135,7 +133,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                 out_var = block.var(out_var_name)
                 if out_var.type not in _valid_types:
                     continue
-                if out_var.dtype == core.VarDesc.VarType.FP32:
+                if out_var.dtype == paddle.float32:
                     out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
                     if op.has_attr('out_dtype'):
                         op._set_attr('out_dtype', core.VarDesc.VarType.BF16)
@@ -282,17 +280,14 @@ def cast_initializers_to_bf16(
 
             if change_op and are_post_ops_bf16(op_post_ops, keep_fp32_ops):
                 for out_var in op_out_vars:
-                    if out_var.dtype == core.VarDesc.VarType.FP32:
+                    if out_var.dtype == paddle.float32:
                         out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
                     if (
                         to_bf16_var_names is not None
                         and out_var.name in to_bf16_var_names
                     ):
                         to_bf16_var_names.remove(out_var.name)
-                if (
-                    op.has_attr('dtype')
-                    and op.attr('dtype') == core.VarDesc.VarType.FP32
-                ):
+                if op.has_attr('dtype') and op.attr('dtype') == paddle.float32:
                     op._set_attr('dtype', core.VarDesc.VarType.BF16)
 
 
@@ -352,7 +347,7 @@ def cast_model_to_bf16(
                     if in_var is None or in_var.type not in _valid_types:
                         continue
 
-                    if in_var.dtype == core.VarDesc.VarType.FP32:
+                    if in_var.dtype == paddle.float32:
                         in_var.desc.set_dtype(core.VarDesc.VarType.BF16)
                         to_bf16_var_names.add(in_var_name)
 
@@ -386,7 +381,7 @@ def cast_model_to_bf16(
                     if out_var is None or out_var.type not in _valid_types:
                         continue
 
-                    if out_var.dtype == core.VarDesc.VarType.FP32:
+                    if out_var.dtype == paddle.float32:
                         out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
 
                     _logger.debug(
@@ -397,7 +392,7 @@ def cast_model_to_bf16(
             for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
                 if (
                     op.has_attr(attr_name)
-                    and op.attr(attr_name) == core.VarDesc.VarType.FP32
+                    and op.attr(attr_name) == paddle.float32
                 ):
                     op._set_attr(attr_name, core.VarDesc.VarType.BF16)
 
@@ -444,7 +439,7 @@ def cast_model_to_bf16(
                         out_var = block.vars.get(out_var_name)
                         if out_var is None or out_var.type not in _valid_types:
                             continue
-                        if out_var.dtype == core.VarDesc.VarType.BF16:
+                        if out_var.dtype == paddle.bfloat16:
                             out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
                             post_ops = find_true_post_op(ops, op, out_var_name)
                             for post_op in post_ops:
@@ -589,10 +584,7 @@ def rewrite_program_bf16(main_prog, amp_lists=None):
                 core.VarDesc.VarType.FP32,
             )
         elif op in bf16_op_set:
-            if (
-                op.has_attr('dtype')
-                and op.attr('dtype') == core.VarDesc.VarType.FP32
-            ):
+            if op.has_attr('dtype') and op.attr('dtype') == paddle.float32:
                 op._set_attr('dtype', core.VarDesc.VarType.BF16)
 
             num_cast_ops = _insert_cast_op(
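Throughout these hunks the same pattern repeats: dtype reads are compared against the public aliases (`paddle.float32`, `paddle.bfloat16`), while dtype writes through `desc.set_dtype` and `op._set_attr` still pass the `core.VarDesc.VarType` enum. A hedged sketch of that read-compare-then-write step on a single variable, assuming a legacy (non-PIR) static-graph `Program` and assuming alias-to-enum equality holds as these hunks rely on:

```python
import paddle
from paddle.base import core

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[4, 16], dtype="float32")

# Compare with the public alias (the new style in this diff) ...
if x.dtype == paddle.float32:
    # ... but keep the VarDesc enum when writing the dtype back, matching
    # out_var.desc.set_dtype(core.VarDesc.VarType.BF16) in the hunks above.
    x.desc.set_dtype(core.VarDesc.VarType.BF16)

print(x.dtype == paddle.bfloat16)  # the variable now reports a bf16 dtype
```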
python/paddle/static/amp/debugging.py (6 changes: 3 additions & 3 deletions)
@@ -35,11 +35,11 @@ def update(self, dtype):
         if dtype is None:
             self.other_calls = self.other_calls + 1
         else:
-            if dtype == paddle.base.core.VarDesc.VarType.FP32:
+            if dtype == paddle.float32:
                 self.fp32_calls = self.fp32_calls + 1
-            elif dtype == paddle.base.core.VarDesc.VarType.FP16:
+            elif dtype == paddle.float16:
                 self.fp16_calls = self.fp16_calls + 1
-            elif dtype == paddle.base.core.VarDesc.VarType.BF16:
+            elif dtype == paddle.bfloat16:
                 self.bf16_calls = self.bf16_calls + 1
             else:
                 self.other_calls = self.other_calls + 1
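A usage sketch of the updated branching: the counter class here is an illustrative stand-in (the real class around this `update` method is not shown in the hunk), but the branch logic is the same:

```python
import paddle


class DtypeCallCounter:
    # Illustrative stand-in for the stats object whose update() is shown above.
    def __init__(self):
        self.fp32_calls = 0
        self.fp16_calls = 0
        self.bf16_calls = 0
        self.other_calls = 0

    def update(self, dtype):
        # Same branching as the diff, using the public paddle dtype aliases.
        if dtype is None:
            self.other_calls += 1
        elif dtype == paddle.float32:
            self.fp32_calls += 1
        elif dtype == paddle.float16:
            self.fp16_calls += 1
        elif dtype == paddle.bfloat16:
            self.bf16_calls += 1
        else:
            self.other_calls += 1


stats = DtypeCallCounter()
for d in [paddle.float32, paddle.float16, paddle.bfloat16, None, paddle.int32]:
    stats.update(d)
print(stats.fp32_calls, stats.fp16_calls, stats.bf16_calls, stats.other_calls)
# expected: 1 1 1 2
```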