
Commit 4458b3a

[Cleanup][A-12] clean some VarType for test (#61565)
1 parent: 9e193aa

4 files changed: 17 additions & 19 deletions
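
Across all four files the change is the same mechanical substitution: dtype comparisons against the legacy core.VarDesc.VarType enum are rewritten to use the public paddle dtype aliases. The rewrite is behavior-preserving because, at the time of this commit, the public aliases are the same enum values; a minimal sketch to illustrate (not part of the commit):

```python
import paddle
from paddle.base import core

# The public dtype aliases are the same enum values as the legacy
# VarDesc.VarType members, so equality checks behave identically.
assert paddle.bool == core.VarDesc.VarType.BOOL
assert paddle.float32 == core.VarDesc.VarType.FP32
assert paddle.bfloat16 == core.VarDesc.VarType.BF16
```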

File tree

  python/paddle/jit/translated_layer.py
  python/paddle/static/amp/amp_nn.py
  python/paddle/static/amp/bf16/amp_utils.py
  python/paddle/static/amp/debugging.py

python/paddle/jit/translated_layer.py

Lines changed: 1 addition & 1 deletion
@@ -503,7 +503,7 @@ def _preprocess(self, program_desc):
     def _append_scale_to_output(self, program):
         # 0. scale don't support bool output, we skip append scale for it
         for out_desc in self._output_descs:
-            if out_desc.dtype() == core.VarDesc.VarType.BOOL:
+            if out_desc.dtype() == paddle.bool:
                 return
 
         # 1. append scale & save var
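
The new comparison can be exercised in isolation; a minimal sketch (the variable name is illustrative, not from the module):

```python
import paddle

paddle.enable_static()

# A bool-typed static variable; its dtype is the same enum value the
# patched check compares against via paddle.bool.
x = paddle.static.data(name='x', shape=[2], dtype='bool')
assert x.dtype == paddle.bool
```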

python/paddle/static/amp/amp_nn.py

Lines changed: 3 additions & 3 deletions
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import paddle
 from paddle import _C_ops
-from paddle.base import core
 from paddle.base.data_feeder import check_type, check_variable_and_dtype
 from paddle.base.framework import Variable, in_dygraph_mode
 from paddle.base.layer_helper import LayerHelper

@@ -130,9 +130,9 @@ def update_loss_scaling(
             ['float16', 'float32', 'float64', 'uint16'],
             'update_loss_scaling',
         )
-        if e.dtype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]:
+        if e.dtype in [paddle.float16, paddle.bfloat16]:
             assert (
-                prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
+                prev_loss_scaling.dtype == paddle.float32
             ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16 or bfloat16."
         else:
             assert (
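
The patched guard is straightforward to restate outside the function; a minimal sketch with made-up tensor names (update_loss_scaling receives these as arguments):

```python
import paddle

paddle.enable_static()

# Mirror the patched check: fp16/bf16 inputs require an fp32 loss scaling.
e = paddle.static.data(name='e', shape=[4], dtype='float16')
prev_loss_scaling = paddle.static.data(name='scale', shape=[1], dtype='float32')

if e.dtype in [paddle.float16, paddle.bfloat16]:
    assert prev_loss_scaling.dtype == paddle.float32
```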

python/paddle/static/amp/bf16/amp_utils.py

Lines changed: 10 additions & 12 deletions
@@ -19,6 +19,7 @@
 
 import numpy as np
 
+import paddle
 from paddle.base import core, framework, global_scope
 from paddle.base.log_helper import get_logger
 from paddle.base.wrapped_decorator import signature_safe_contextmanager

@@ -60,7 +61,7 @@ def _dtype_to_str(dtype):
     Args:
         dtype (VarType): Variable type.
     """
-    if dtype == core.VarDesc.VarType.BF16:
+    if dtype == paddle.bfloat16:
         return 'bf16'
     else:
         return 'fp32'

@@ -83,7 +84,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
     num_cast_ops = 0
 
     for in_name in op.input_names:
-        if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
+        if src_dtype == paddle.float32 and op.type in [
             'batch_norm',
             'fused_bn_add_activation',
             'layer_norm',

@@ -120,10 +121,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
             else:
                 if op.has_attr('in_dtype'):
                     op._set_attr('in_dtype', dest_dtype)
-    if (
-        src_dtype == core.VarDesc.VarType.FP32
-        and dest_dtype == core.VarDesc.VarType.BF16
-    ):
+    if src_dtype == paddle.float32 and dest_dtype == paddle.bfloat16:
         for out_name in op.output_names:
             if (
                 op.type

@@ -135,7 +133,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
                 out_var = block.var(out_var_name)
                 if out_var.type not in _valid_types:
                     continue
-                if out_var.dtype == core.VarDesc.VarType.FP32:
+                if out_var.dtype == paddle.float32:
                     out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
         if op.has_attr('out_dtype'):
             op._set_attr('out_dtype', core.VarDesc.VarType.BF16)

@@ -282,7 +280,7 @@ def cast_initializers_to_bf16(
 
         if change_op and are_post_ops_bf16(op_post_ops, keep_fp32_ops):
             for out_var in op_out_vars:
-                if out_var.dtype == core.VarDesc.VarType.FP32:
+                if out_var.dtype == paddle.float32:
                     out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
                 if (
                     to_bf16_var_names is not None

@@ -352,7 +350,7 @@ def cast_model_to_bf16(
                 if in_var is None or in_var.type not in _valid_types:
                     continue
 
-                if in_var.dtype == core.VarDesc.VarType.FP32:
+                if in_var.dtype == paddle.float32:
                     in_var.desc.set_dtype(core.VarDesc.VarType.BF16)
                     to_bf16_var_names.add(in_var_name)
 

@@ -386,7 +384,7 @@ def cast_model_to_bf16(
                 if out_var is None or out_var.type not in _valid_types:
                     continue
 
-                if out_var.dtype == core.VarDesc.VarType.FP32:
+                if out_var.dtype == paddle.float32:
                     out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
 
                 _logger.debug(

@@ -397,7 +395,7 @@ def cast_model_to_bf16(
             for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
                 if (
                     op.has_attr(attr_name)
-                    and op.attr(attr_name) == core.VarDesc.VarType.FP32
+                    and op.attr(attr_name) == paddle.float32
                 ):
                     op._set_attr(attr_name, core.VarDesc.VarType.BF16)
 

@@ -444,7 +442,7 @@ def cast_model_to_bf16(
                 out_var = block.vars.get(out_var_name)
                 if out_var is None or out_var.type not in _valid_types:
                     continue
-                if out_var.dtype == core.VarDesc.VarType.BF16:
+                if out_var.dtype == paddle.bfloat16:
                     out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
                 post_ops = find_true_post_op(ops, op, out_var_name)
                 for post_op in post_ops:
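
_dtype_to_str, the smallest helper touched here, shows the pattern in miniature; its patched logic restated as a standalone snippet:

```python
import paddle

def _dtype_to_str(dtype):
    # Patched logic: bf16 gets its own suffix for cast-variable naming;
    # every other dtype falls back to 'fp32'.
    if dtype == paddle.bfloat16:
        return 'bf16'
    else:
        return 'fp32'

assert _dtype_to_str(paddle.bfloat16) == 'bf16'
assert _dtype_to_str(paddle.float32) == 'fp32'
```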

python/paddle/static/amp/debugging.py

Lines changed: 3 additions & 3 deletions
@@ -35,11 +35,11 @@ def update(self, dtype):
         if dtype is None:
             self.other_calls = self.other_calls + 1
         else:
-            if dtype == paddle.base.core.VarDesc.VarType.FP32:
+            if dtype == paddle.float32:
                 self.fp32_calls = self.fp32_calls + 1
-            elif dtype == paddle.base.core.VarDesc.VarType.FP16:
+            elif dtype == paddle.float16:
                 self.fp16_calls = self.fp16_calls + 1
-            elif dtype == paddle.base.core.VarDesc.VarType.BF16:
+            elif dtype == paddle.bfloat16:
                 self.bf16_calls = self.bf16_calls + 1
             else:
                 self.other_calls = self.other_calls + 1
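
A minimal standalone sketch of the patched dispatch; _OpCallStats is a hypothetical stand-in for the class owning update(), which this hunk does not show:

```python
import paddle

class _OpCallStats:
    # Hypothetical stand-in carrying the counters the patched update() uses.
    def __init__(self):
        self.fp32_calls = self.fp16_calls = 0
        self.bf16_calls = self.other_calls = 0

    def update(self, dtype):
        # Same branch structure as the patched method, on public dtype aliases.
        if dtype is None:
            self.other_calls += 1
        elif dtype == paddle.float32:
            self.fp32_calls += 1
        elif dtype == paddle.float16:
            self.fp16_calls += 1
        elif dtype == paddle.bfloat16:
            self.bf16_calls += 1
        else:
            self.other_calls += 1

stats = _OpCallStats()
stats.update(paddle.float16)
stats.update(None)
assert stats.fp16_calls == 1 and stats.other_calls == 1
```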
