16 changes: 5 additions & 11 deletions python/paddle/amp/grad_scaler.py
@@ -296,15 +296,9 @@ def _unscale(self, optimizer):
for param in group['params']:
if param._grad_ivar() is not None:
param_grads.append(param._grad_ivar())
-if (
-    param._grad_ivar().dtype
-    == core.VarDesc.VarType.FP16
-):
+if param._grad_ivar().dtype == paddle.float16:
param_grads_fp16.append(param._grad_ivar())
-elif (
-    param._grad_ivar().dtype
-    == core.VarDesc.VarType.BF16
-):
+elif param._grad_ivar().dtype == paddle.bfloat16:
param_grads_bf16.append(param._grad_ivar())
else:
param_grads_fp32.append(param._grad_ivar())
@@ -328,17 +322,17 @@ def _unscale(self, optimizer):
param_grads_fp16 = [
param
for param in param_grads
-if param.dtype == core.VarDesc.VarType.FP16
+if param.dtype == paddle.float16
]
param_grads_bf16 = [
param
for param in param_grads
-if param.dtype == core.VarDesc.VarType.BF16
+if param.dtype == paddle.bfloat16
]
param_grads_fp32 = [
param
for param in param_grads
-if param.dtype == core.VarDesc.VarType.FP32
+if param.dtype == paddle.float32
]
self._found_inf = self._temp_found_inf_value_false
if len(param_grads_fp16):
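The hunks above switch the per-parameter dtype checks in _unscale from the legacy core.VarDesc.VarType enum to the public paddle.float16 / paddle.bfloat16 dtypes. Below is a minimal sketch of the same dtype-bucketing pattern using the public dtypes; it is illustrative only, not part of this PR, and the example tensors are made up.

import paddle

# Made-up gradients standing in for the param._grad_ivar() results.
grads = [
    paddle.zeros([2, 2], dtype='float16'),
    paddle.zeros([2, 2], dtype='float32'),
]

fp16_grads, fp32_grads = [], []
for g in grads:
    # Same comparison style as the new code: compare against paddle dtypes.
    if g.dtype == paddle.float16:
        fp16_grads.append(g)
    else:
        fp32_grads.append(g)

print(len(fp16_grads), len(fp32_grads))  # 1 1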
7 changes: 4 additions & 3 deletions python/paddle/base/dataset.py
@@ -15,6 +15,7 @@

from google.protobuf import text_format

+import paddle
from paddle.base.proto import data_feed_pb2

from ..utils import deprecated
@@ -271,11 +272,11 @@ def set_use_var(self, var_list):
if var.lod_level == 0:
slot_var.is_dense = True
slot_var.shape.extend(var.shape)
-if var.dtype == core.VarDesc.VarType.FP32:
+if var.dtype == paddle.float32:
slot_var.type = "float"
-elif var.dtype == core.VarDesc.VarType.INT64:
+elif var.dtype == paddle.int64:
slot_var.type = "uint64"
-elif var.dtype == core.VarDesc.VarType.INT32:
+elif var.dtype == paddle.int32:
slot_var.type = "uint32"
slot_var.type = "uint32"
else:
raise ValueError(
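The set_use_var change maps each variable's dtype to a DataFeed slot type via the public paddle dtypes instead of the VarType enum. A hedged sketch of that mapping follows; the helper name slot_type_for is made up, and an ordinary tensor stands in for the program variable.

import paddle

def slot_type_for(dtype):
    # Mirrors the mapping in set_use_var:
    # float32 -> "float", int64 -> "uint64", int32 -> "uint32".
    if dtype == paddle.float32:
        return "float"
    elif dtype == paddle.int64:
        return "uint64"
    elif dtype == paddle.int32:
        return "uint32"
    raise ValueError(f"unsupported dtype: {dtype}")

labels = paddle.to_tensor([1, 2, 3], dtype='int64')
print(slot_type_for(labels.dtype))  # uint64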
9 changes: 5 additions & 4 deletions python/paddle/base/framework.py
@@ -32,6 +32,7 @@

import numpy as np

+import paddle
import paddle.version as paddle_version

from .. import pir
@@ -8154,13 +8155,13 @@ def _get_paddle_place_list(places):


def dtype_to_str(in_dtype):
-if in_dtype == core.VarDesc.VarType.FP16:
+if in_dtype == paddle.float16:
return "fp16"
-elif in_dtype == core.VarDesc.VarType.BF16:
+elif in_dtype == paddle.bfloat16:
return "bf16"
-elif in_dtype == core.VarDesc.VarType.FP32:
+elif in_dtype == paddle.float32:
return "fp32"
-elif in_dtype == core.VarDesc.VarType.FP64:
+elif in_dtype == paddle.float64:
return "fp64"
return "fp64"
else:
return None
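With this change, dtype_to_str compares its argument against the public paddle dtypes. A short usage sketch, not from the PR; the import path follows the file shown in this diff and assumes the function stays importable from there.

import paddle
from paddle.base.framework import dtype_to_str

print(dtype_to_str(paddle.float16))   # fp16
print(dtype_to_str(paddle.bfloat16))  # bf16
print(dtype_to_str(paddle.int32))     # None, only the float dtypes above are handled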
@@ -17,7 +17,6 @@
import numpy as np

import paddle
-from paddle.base.core import VarDesc
from paddle.utils.flops import flops

from ..cluster import DeviceType, LinkType, get_default_cluster
@@ -968,11 +967,11 @@ def calc_time_by_cost_model(op, cluster=None):
), "Only GPU device is supported currently."

gflops = 0.0
-if dtype == VarDesc.VarType.FP64:
+if dtype == paddle.float64:
gflops = device.dp_gflops
-elif dtype == VarDesc.VarType.FP32:
+elif dtype == paddle.float32:
gflops = device.sp_gflops
-elif dtype == VarDesc.VarType.FP16 or dtype == VarDesc.VarType.BF16:
+elif dtype == paddle.float16 or dtype == paddle.bfloat16:
gflops = device.hp_gflops
gflops = device.hp_gflops
else:
raise ValueError(
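calc_time_by_cost_model now picks the device's peak GFLOPS figure by comparing the op's dtype against the public paddle dtypes. A hedged sketch of that dispatch follows; it is not part of the PR, and the DeviceSpec class and its numbers are made up for illustration.

from dataclasses import dataclass

import paddle

@dataclass
class DeviceSpec:
    dp_gflops: float  # double-precision peak
    sp_gflops: float  # single-precision peak
    hp_gflops: float  # half / bfloat16 peak

def peak_gflops(device, dtype):
    # Mirrors the dtype dispatch in calc_time_by_cost_model.
    if dtype == paddle.float64:
        return device.dp_gflops
    elif dtype == paddle.float32:
        return device.sp_gflops
    elif dtype == paddle.float16 or dtype == paddle.bfloat16:
        return device.hp_gflops
    raise ValueError(f"unsupported dtype for the cost model: {dtype}")

gpu = DeviceSpec(dp_gflops=9700.0, sp_gflops=19500.0, hp_gflops=312000.0)
print(peak_gflops(gpu, paddle.float16))  # 312000.0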