@@ -32,6 +32,7 @@
 
 import numpy as np
 
+import paddle
 import paddle.version as paddle_version
 
 from .. import pir
@@ -230,19 +231,19 @@ def __setattr__(self, name, val):
 }
 
 paddle_type_to_proto_type = {
-    DataType.BOOL: core.VarDesc.VarType.BOOL,
-    DataType.FLOAT16: core.VarDesc.VarType.FP16,
-    DataType.UINT16: core.VarDesc.VarType.BF16,
-    DataType.BFLOAT16: core.VarDesc.VarType.BF16,
-    DataType.FLOAT32: core.VarDesc.VarType.FP32,
-    DataType.FLOAT64: core.VarDesc.VarType.FP64,
-    DataType.INT8: core.VarDesc.VarType.INT8,
-    DataType.INT16: core.VarDesc.VarType.INT16,
-    DataType.INT32: core.VarDesc.VarType.INT32,
-    DataType.INT64: core.VarDesc.VarType.INT64,
-    DataType.UINT8: core.VarDesc.VarType.UINT8,
-    DataType.COMPLEX64: core.VarDesc.VarType.COMPLEX64,
-    DataType.COMPLEX128: core.VarDesc.VarType.COMPLEX128,
+    DataType.BOOL: paddle.bool,
+    DataType.FLOAT16: paddle.float16,
+    DataType.UINT16: paddle.bfloat16,
+    DataType.BFLOAT16: paddle.bfloat16,
+    DataType.FLOAT32: paddle.float32,
+    DataType.FLOAT64: paddle.float64,
+    DataType.INT8: paddle.int8,
+    DataType.INT16: paddle.int16,
+    DataType.INT32: paddle.int32,
+    DataType.INT64: paddle.int64,
+    DataType.UINT8: paddle.uint8,
+    DataType.COMPLEX64: paddle.complex64,
+    DataType.COMPLEX128: paddle.complex128,
 }
 
 
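For context, a minimal sketch of what the remapped table now yields. This is illustrative only; both import paths below are assumptions that may differ across Paddle versions:

```python
# Illustrative sketch, not part of the patch; import paths are assumptions.
import paddle
from paddle.base.framework import paddle_type_to_proto_type
from paddle.base.libpaddle import DataType  # assumed binding path for the DataType enum

# Lookups now yield paddle dtype objects instead of core.VarDesc.VarType members.
assert paddle_type_to_proto_type[DataType.FLOAT32] == paddle.float32
assert paddle_type_to_proto_type[DataType.UINT16] == paddle.bfloat16
```

Note that both `UINT16` and `BFLOAT16` map to `paddle.bfloat16`, mirroring the old table where both mapped to `VarType.BF16`.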
@@ -1245,31 +1246,31 @@ def convert_np_dtype_to_dtype_(np_dtype):
     dtype = np.dtype(np_dtype)
 
     if dtype == np.float32:
-        return core.VarDesc.VarType.FP32
+        return paddle.float32
     elif dtype == np.float64:
-        return core.VarDesc.VarType.FP64
+        return paddle.float64
     elif dtype == np.float16:
-        return core.VarDesc.VarType.FP16
+        return paddle.float16
     elif dtype == np.int32:
-        return core.VarDesc.VarType.INT32
+        return paddle.int32
     elif dtype == np.int16:
-        return core.VarDesc.VarType.INT16
+        return paddle.int16
     elif dtype == np.int64:
-        return core.VarDesc.VarType.INT64
+        return paddle.int64
     elif dtype == np.bool_:
-        return core.VarDesc.VarType.BOOL
+        return paddle.bool
     elif dtype == np.uint16:
         # since there is still no support for bfloat16 in NumPy,
         # uint16 is used for casting bfloat16
-        return core.VarDesc.VarType.BF16
+        return paddle.bfloat16
     elif dtype == np.uint8:
-        return core.VarDesc.VarType.UINT8
+        return paddle.uint8
     elif dtype == np.int8:
-        return core.VarDesc.VarType.INT8
+        return paddle.int8
     elif dtype == np.complex64:
-        return core.VarDesc.VarType.COMPLEX64
+        return paddle.complex64
     elif dtype == np.complex128:
-        return core.VarDesc.VarType.COMPLEX128
+        return paddle.complex128
     else:
         raise ValueError("Not supported numpy dtype %s" % dtype)
 
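A quick usage sketch of the converted return values; `convert_np_dtype_to_dtype_` is an internal helper, so the import path is an assumption:

```python
# Illustrative sketch; internal helper, import path assumed.
import numpy as np
import paddle
from paddle.base.framework import convert_np_dtype_to_dtype_

assert convert_np_dtype_to_dtype_(np.float32) == paddle.float32
# NumPy has no native bfloat16, so uint16 stands in for it:
assert convert_np_dtype_to_dtype_(np.uint16) == paddle.bfloat16
```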
@@ -1288,9 +1289,9 @@ def dtype_is_floating(dtype):
         dtype = convert_np_dtype_to_dtype_(dtype)
 
     return dtype in [
-        core.VarDesc.VarType.FP16,
-        core.VarDesc.VarType.FP32,
-        core.VarDesc.VarType.FP64,
+        paddle.float16,
+        paddle.float32,
+        paddle.float64,
     ]
 
 
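And the floating-point check, sketched with NumPy inputs so the conversion guard above the `return` is exercised (import path again assumed):

```python
import numpy as np
from paddle.base.framework import dtype_is_floating  # assumed path

assert dtype_is_floating(np.float16)    # converted to paddle.float16 -> True
assert not dtype_is_floating(np.int64)  # paddle.int64 is not floating-point
```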
@@ -1327,7 +1328,7 @@ def _create_tensor(
             dtype = convert_np_dtype_to_dtype_(dtype)
 
     eager_tensor = core.eager.Tensor(
-        dtype if dtype else core.VarDesc.VarType.FP32,
+        dtype if dtype else paddle.float32,
         list(shape) if shape else [],
         name,
         type if type else core.VarDesc.VarType.LOD_TENSOR,
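Only the default constant changes here; the fallback behaviour is the same. A hedged end-user illustration of that float32 default (not a call into the internal `_create_tensor`):

```python
import paddle

# Tensors created without an explicit dtype default to float32.
x = paddle.to_tensor([1.0, 2.0])
assert x.dtype == paddle.float32
```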
@@ -2785,7 +2786,7 @@ def size(self):
             name=unique_name.generate_with_ignorable_key(
                 self.name + "_size"
             ),
-            dtype=core.VarDesc.VarType.INT64,
+            dtype=paddle.int64,
         )
 
         self.block.append_op(
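A hypothetical usage sketch for the `size` property under the legacy static-graph IR; whether `paddle.static.data` returns this `Variable` type depends on the Paddle version and IR mode, so treat this as an assumption:

```python
import paddle

paddle.enable_static()  # the property below belongs to static-graph Variables
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
n = x.size  # a new Variable holding the element count, now tagged paddle.int64
```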
@@ -7570,7 +7571,7 @@ def __init__(self, shape, dtype, **kwargs):
             shape = shape.numpy()
 
         super().__init__(
-            dtype if dtype else core.VarDesc.VarType.FP32,
+            dtype if dtype else paddle.float32,
             list(shape) if shape else [],
             name,
             core.VarDesc.VarType.LOD_TENSOR,
@@ -8154,13 +8155,13 @@ def _get_paddle_place_list(places):
 
 
 def dtype_to_str(in_dtype):
-    if in_dtype == core.VarDesc.VarType.FP16:
+    if in_dtype == paddle.float16:
         return "fp16"
-    elif in_dtype == core.VarDesc.VarType.BF16:
+    elif in_dtype == paddle.bfloat16:
         return "bf16"
-    elif in_dtype == core.VarDesc.VarType.FP32:
+    elif in_dtype == paddle.float32:
         return "fp32"
-    elif in_dtype == core.VarDesc.VarType.FP64:
+    elif in_dtype == paddle.float64:
         return "fp64"
     else:
         return None
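Finally, a sketch of `dtype_to_str` with the new constants (internal helper, import path assumed):

```python
import paddle
from paddle.base.framework import dtype_to_str  # assumed path

assert dtype_to_str(paddle.bfloat16) == "bf16"
assert dtype_to_str(paddle.int32) is None  # non-float dtypes fall through to None
```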