Commit ca59629

Fix VarType
1 parent 062b99e commit ca59629

5 files changed: 62 additions & 65 deletions

File tree

python/paddle/amp/grad_scaler.py
python/paddle/base/dataset.py
python/paddle/base/dygraph/math_op_patch.py
python/paddle/base/framework.py
python/paddle/distributed/auto_parallel/static/cost/base_cost.py
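Every hunk in this commit applies the same pattern: dtype checks that previously compared against core.VarDesc.VarType members now compare against the public paddle.* dtype aliases. A minimal before/after sketch of the idiom, assuming an eager-mode tensor (the tensor creation below is illustrative and not part of the commit):

    import paddle

    x = paddle.ones([2, 2], dtype='float16')

    # Old style: compare against the framework-internal enum.
    # from paddle.base import core
    # is_fp16 = x.dtype == core.VarDesc.VarType.FP16

    # New style: compare against the public dtype alias.
    is_fp16 = x.dtype == paddle.float16
    print(is_fp16)  # True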

python/paddle/amp/grad_scaler.py

Lines changed: 6 additions & 11 deletions
@@ -18,6 +18,7 @@
 
 import numpy as np
 
+import paddle
 from paddle import _C_ops, _legacy_C_ops
 from paddle.base import core
 from paddle.base.data_feeder import check_type
@@ -296,15 +297,9 @@ def _unscale(self, optimizer):
                 for param in group['params']:
                     if param._grad_ivar() is not None:
                         param_grads.append(param._grad_ivar())
-                        if (
-                            param._grad_ivar().dtype
-                            == core.VarDesc.VarType.FP16
-                        ):
+                        if param._grad_ivar().dtype == paddle.float16:
                             param_grads_fp16.append(param._grad_ivar())
-                        elif (
-                            param._grad_ivar().dtype
-                            == core.VarDesc.VarType.BF16
-                        ):
+                        elif param._grad_ivar().dtype == paddle.bfloat16:
                             param_grads_bf16.append(param._grad_ivar())
                         else:
                             param_grads_fp32.append(param._grad_ivar())
@@ -328,17 +323,17 @@ def _unscale(self, optimizer):
             param_grads_fp16 = [
                 param
                 for param in param_grads
-                if param.dtype == core.VarDesc.VarType.FP16
+                if param.dtype == paddle.float16
             ]
             param_grads_bf16 = [
                 param
                 for param in param_grads
-                if param.dtype == core.VarDesc.VarType.BF16
+                if param.dtype == paddle.bfloat16
             ]
             param_grads_fp32 = [
                 param
                 for param in param_grads
-                if param.dtype == core.VarDesc.VarType.FP32
+                if param.dtype == paddle.float32
             ]
         self._found_inf = self._temp_found_inf_value_false
         if len(param_grads_fp16):
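The rewritten _unscale collects gradients into fp16, bf16, and fp32 buckets by comparing against the public aliases. A standalone sketch of the same grouping, assuming eager mode (the zero tensors stand in for real gradients):

    import paddle

    grads = [
        paddle.zeros([4], dtype='float16'),
        paddle.zeros([4], dtype='bfloat16'),
        paddle.zeros([4], dtype='float32'),
    ]

    param_grads_fp16 = [g for g in grads if g.dtype == paddle.float16]
    param_grads_bf16 = [g for g in grads if g.dtype == paddle.bfloat16]
    param_grads_fp32 = [g for g in grads if g.dtype == paddle.float32]

    print(len(param_grads_fp16), len(param_grads_bf16), len(param_grads_fp32))  # 1 1 1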

python/paddle/base/dataset.py

Lines changed: 4 additions & 3 deletions
@@ -15,6 +15,7 @@
 
 from google.protobuf import text_format
 
+import paddle
 from paddle.base.proto import data_feed_pb2
 
 from ..utils import deprecated
@@ -271,11 +272,11 @@ def set_use_var(self, var_list):
             if var.lod_level == 0:
                 slot_var.is_dense = True
                 slot_var.shape.extend(var.shape)
-            if var.dtype == core.VarDesc.VarType.FP32:
+            if var.dtype == paddle.float32:
                 slot_var.type = "float"
-            elif var.dtype == core.VarDesc.VarType.INT64:
+            elif var.dtype == paddle.int64:
                 slot_var.type = "uint64"
-            elif var.dtype == core.VarDesc.VarType.INT32:
+            elif var.dtype == paddle.int32:
                 slot_var.type = "uint32"
             else:
                 raise ValueError(
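The updated set_use_var branch maps a slot variable's dtype to the DataFeedDesc type string. A small sketch of the same mapping (slot_type_for is a hypothetical helper, written here only to make the mapping runnable on its own):

    import paddle

    def slot_type_for(dtype):
        # Mirrors the branch in set_use_var: float32 slots become "float",
        # while int64 and int32 slots use the unsigned integer type names.
        if dtype == paddle.float32:
            return "float"
        elif dtype == paddle.int64:
            return "uint64"
        elif dtype == paddle.int32:
            return "uint32"
        raise ValueError(f"dtype {dtype} is not supported for a dataset slot")

    print(slot_type_for(paddle.int64))  # uint64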

python/paddle/base/dygraph/math_op_patch.py

Lines changed: 13 additions & 12 deletions
@@ -14,18 +14,19 @@
 
 import numpy as np
 
+import paddle
 from paddle import _C_ops, _legacy_C_ops
 
 from .. import core, framework
 from ..framework import convert_np_dtype_to_dtype_
 
 _supported_int_dtype_ = [
-    core.VarDesc.VarType.UINT8,
-    core.VarDesc.VarType.INT8,
-    core.VarDesc.VarType.INT16,
-    core.VarDesc.VarType.INT32,
-    core.VarDesc.VarType.INT64,
-    core.VarDesc.VarType.BOOL,
+    paddle.uint8,
+    paddle.int8,
+    paddle.int16,
+    paddle.int32,
+    paddle.int64,
+    paddle.bool,
 ]
 
 # NOTE(chenweihang): We currently do not fully support the type promotion
@@ -50,8 +51,8 @@
 ]
 
 _complex_dtypes = [
-    core.VarDesc.VarType.COMPLEX64,
-    core.VarDesc.VarType.COMPLEX128,
+    paddle.complex64,
+    paddle.complex128,
 ]
 
 _already_patch_eager_tensor = False
@@ -107,7 +108,7 @@ def _float_(var):
         ), "only one element variable can be converted to float."
         tensor = var.value().get_tensor()
         assert tensor._is_initialized(), "variable's tensor is not initialized"
-        if var.dtype == core.VarDesc.VarType.BF16:
+        if var.dtype == paddle.bfloat16:
             var = var.astype('float32')
         return float(np.array(var))
 
@@ -116,7 +117,7 @@ def _long_(var):
         assert numel == 1, "only one element variable can be converted to long."
         tensor = var.value().get_tensor()
         assert tensor._is_initialized(), "variable's tensor is not initialized"
-        if var.dtype == core.VarDesc.VarType.BF16:
+        if var.dtype == paddle.bfloat16:
             var = var.astype('float32')
         return int(np.array(var))
 
@@ -125,7 +126,7 @@ def _int_(var):
         assert numel == 1, "only one element variable can be converted to int."
         tensor = var.value().get_tensor()
         assert tensor._is_initialized(), "variable's tensor is not initialized"
-        if var.dtype == core.VarDesc.VarType.BF16:
+        if var.dtype == paddle.bfloat16:
             var = var.astype('float32')
         return int(np.array(var))
 
@@ -145,7 +146,7 @@ def _index_(var):
         ), "only one element variable can be converted to python index."
         tensor = var.value().get_tensor()
         assert tensor._is_initialized(), "variable's tensor is not initialized"
-        if var.dtype == core.VarDesc.VarType.BF16:
+        if var.dtype == paddle.bfloat16:
             var = var.astype('float32')
         return int(np.array(var))
 
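The patched _float_, _long_, _int_, and _index_ converters cast bfloat16 tensors to float32 first, since NumPy has no native bfloat16 type. A minimal sketch of that path, assuming eager mode (the tensor creation is illustrative):

    import numpy as np
    import paddle

    x = paddle.ones([1], dtype='bfloat16')

    # Same guard as the patched converters: route bfloat16 through float32
    # before handing the value to NumPy.
    if x.dtype == paddle.bfloat16:
        x = x.astype('float32')
    print(float(np.array(x)))  # 1.0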

python/paddle/base/framework.py

Lines changed: 36 additions & 35 deletions
@@ -32,6 +32,7 @@
 
 import numpy as np
 
+import paddle
 import paddle.version as paddle_version
 
 from .. import pir
@@ -230,19 +231,19 @@ def __setattr__(self, name, val):
 }
 
 paddle_type_to_proto_type = {
-    DataType.BOOL: core.VarDesc.VarType.BOOL,
-    DataType.FLOAT16: core.VarDesc.VarType.FP16,
-    DataType.UINT16: core.VarDesc.VarType.BF16,
-    DataType.BFLOAT16: core.VarDesc.VarType.BF16,
-    DataType.FLOAT32: core.VarDesc.VarType.FP32,
-    DataType.FLOAT64: core.VarDesc.VarType.FP64,
-    DataType.INT8: core.VarDesc.VarType.INT8,
-    DataType.INT16: core.VarDesc.VarType.INT16,
-    DataType.INT32: core.VarDesc.VarType.INT32,
-    DataType.INT64: core.VarDesc.VarType.INT64,
-    DataType.UINT8: core.VarDesc.VarType.UINT8,
-    DataType.COMPLEX64: core.VarDesc.VarType.COMPLEX64,
-    DataType.COMPLEX128: core.VarDesc.VarType.COMPLEX128,
+    DataType.BOOL: paddle.bool,
+    DataType.FLOAT16: paddle.float16,
+    DataType.UINT16: paddle.bfloat16,
+    DataType.BFLOAT16: paddle.bfloat16,
+    DataType.FLOAT32: paddle.float32,
+    DataType.FLOAT64: paddle.float64,
+    DataType.INT8: paddle.int8,
+    DataType.INT16: paddle.int16,
+    DataType.INT32: paddle.int32,
+    DataType.INT64: paddle.int64,
+    DataType.UINT8: paddle.uint8,
+    DataType.COMPLEX64: paddle.complex64,
+    DataType.COMPLEX128: paddle.complex128,
 }
 
 
@@ -1245,31 +1246,31 @@ def convert_np_dtype_to_dtype_(np_dtype):
     dtype = np.dtype(np_dtype)
 
     if dtype == np.float32:
-        return core.VarDesc.VarType.FP32
+        return paddle.float32
     elif dtype == np.float64:
-        return core.VarDesc.VarType.FP64
+        return paddle.float64
     elif dtype == np.float16:
-        return core.VarDesc.VarType.FP16
+        return paddle.float16
     elif dtype == np.int32:
-        return core.VarDesc.VarType.INT32
+        return paddle.int32
     elif dtype == np.int16:
-        return core.VarDesc.VarType.INT16
+        return paddle.int16
     elif dtype == np.int64:
-        return core.VarDesc.VarType.INT64
+        return paddle.int64
     elif dtype == np.bool_:
-        return core.VarDesc.VarType.BOOL
+        return paddle.bool
     elif dtype == np.uint16:
         # since there is still no support for bfloat16 in NumPy,
         # uint16 is used for casting bfloat16
-        return core.VarDesc.VarType.BF16
+        return paddle.bfloat16
     elif dtype == np.uint8:
-        return core.VarDesc.VarType.UINT8
+        return paddle.uint8
     elif dtype == np.int8:
-        return core.VarDesc.VarType.INT8
+        return paddle.int8
     elif dtype == np.complex64:
-        return core.VarDesc.VarType.COMPLEX64
+        return paddle.complex64
     elif dtype == np.complex128:
-        return core.VarDesc.VarType.COMPLEX128
+        return paddle.complex128
     else:
         raise ValueError("Not supported numpy dtype %s" % dtype)
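After this change, convert_np_dtype_to_dtype_ hands back the public aliases directly. A quick check, calling the private helper purely for illustration (module path as in the diff):

    import numpy as np
    import paddle
    from paddle.base import framework

    # NumPy dtypes map onto the public paddle dtype aliases.
    assert framework.convert_np_dtype_to_dtype_(np.float32) == paddle.float32
    # uint16 is the NumPy stand-in for bfloat16, as the comment in the hunk notes.
    assert framework.convert_np_dtype_to_dtype_(np.uint16) == paddle.bfloat16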

@@ -1288,9 +1289,9 @@ def dtype_is_floating(dtype):
         dtype = convert_np_dtype_to_dtype_(dtype)
 
     return dtype in [
-        core.VarDesc.VarType.FP16,
-        core.VarDesc.VarType.FP32,
-        core.VarDesc.VarType.FP64,
+        paddle.float16,
+        paddle.float32,
+        paddle.float64,
     ]
 
 
@@ -1327,7 +1328,7 @@ def _create_tensor(
         dtype = convert_np_dtype_to_dtype_(dtype)
 
     eager_tensor = core.eager.Tensor(
-        dtype if dtype else core.VarDesc.VarType.FP32,
+        dtype if dtype else paddle.float32,
         list(shape) if shape else [],
         name,
         type if type else core.VarDesc.VarType.LOD_TENSOR,
@@ -2785,7 +2786,7 @@ def size(self):
             name=unique_name.generate_with_ignorable_key(
                 self.name + "_size"
             ),
-            dtype=core.VarDesc.VarType.INT64,
+            dtype=paddle.int64,
         )
 
         self.block.append_op(
@@ -7570,7 +7571,7 @@ def __init__(self, shape, dtype, **kwargs):
             shape = shape.numpy()
 
         super().__init__(
-            dtype if dtype else core.VarDesc.VarType.FP32,
+            dtype if dtype else paddle.float32,
             list(shape) if shape else [],
             name,
             core.VarDesc.VarType.LOD_TENSOR,
@@ -8154,13 +8155,13 @@ def _get_paddle_place_list(places):
 
 
 def dtype_to_str(in_dtype):
-    if in_dtype == core.VarDesc.VarType.FP16:
+    if in_dtype == paddle.float16:
         return "fp16"
-    elif in_dtype == core.VarDesc.VarType.BF16:
+    elif in_dtype == paddle.bfloat16:
         return "bf16"
-    elif in_dtype == core.VarDesc.VarType.FP32:
+    elif in_dtype == paddle.float32:
         return "fp32"
-    elif in_dtype == core.VarDesc.VarType.FP64:
+    elif in_dtype == paddle.float64:
         return "fp64"
     else:
         return None
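dtype_to_str now keys the short names off the public aliases and returns None for anything outside the four float types. Usage, calling the module-level function for illustration:

    import paddle
    from paddle.base import framework

    print(framework.dtype_to_str(paddle.float16))   # fp16
    print(framework.dtype_to_str(paddle.bfloat16))  # bf16
    print(framework.dtype_to_str(paddle.int32))     # None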

python/paddle/distributed/auto_parallel/static/cost/base_cost.py

Lines changed: 3 additions & 4 deletions
@@ -17,7 +17,6 @@
 import numpy as np
 
 import paddle
-from paddle.base.core import VarDesc
 from paddle.utils.flops import flops
 
 from ..cluster import DeviceType, LinkType, get_default_cluster
@@ -968,11 +967,11 @@ def calc_time_by_cost_model(op, cluster=None):
     ), "Only GPU device is supported currently."
 
     gflops = 0.0
-    if dtype == VarDesc.VarType.FP64:
+    if dtype == paddle.float64:
        gflops = device.dp_gflops
-    elif dtype == VarDesc.VarType.FP32:
+    elif dtype == paddle.float32:
        gflops = device.sp_gflops
-    elif dtype == VarDesc.VarType.FP16 or dtype == VarDesc.VarType.BF16:
+    elif dtype == paddle.float16 or dtype == paddle.bfloat16:
        gflops = device.hp_gflops
     else:
        raise ValueError(
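calc_time_by_cost_model now selects the device's peak throughput by comparing the op dtype against the public aliases: double precision reads dp_gflops, single precision reads sp_gflops, and half/bfloat16 read hp_gflops. A sketch with a stand-in device record (the GFLOPS values are placeholders; only the field names come from the diff):

    from dataclasses import dataclass

    import paddle

    @dataclass
    class FakeDevice:
        dp_gflops: float = 10.0   # placeholder fp64 peak
        sp_gflops: float = 20.0   # placeholder fp32 peak
        hp_gflops: float = 40.0   # placeholder fp16/bf16 peak

    def peak_gflops(device, dtype):
        if dtype == paddle.float64:
            return device.dp_gflops
        elif dtype == paddle.float32:
            return device.sp_gflops
        elif dtype == paddle.float16 or dtype == paddle.bfloat16:
            return device.hp_gflops
        raise ValueError(f"dtype {dtype} is not covered by the cost model")

    print(peak_gflops(FakeDevice(), paddle.bfloat16))  # 40.0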
