Skip to content

Commit fca8abb

Browse files
committed
Fix VarType
1 parent 0e3f93c commit fca8abb

File tree

3 files changed

+22
-24
lines changed

3 files changed

+22
-24
lines changed

test/legacy_test/test_variable.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,32 +34,31 @@ def setUp(self):
3434
np.random.seed(2022)
3535

3636
def test_np_dtype_convert(self):
37-
DT = core.VarDesc.VarType
3837
convert = convert_np_dtype_to_dtype_
39-
self.assertEqual(DT.FP32, convert(np.float32))
40-
self.assertEqual(DT.FP16, convert("float16"))
41-
self.assertEqual(DT.FP64, convert("float64"))
42-
self.assertEqual(DT.INT32, convert("int32"))
43-
self.assertEqual(DT.INT16, convert("int16"))
44-
self.assertEqual(DT.INT64, convert("int64"))
45-
self.assertEqual(DT.BOOL, convert("bool"))
46-
self.assertEqual(DT.INT8, convert("int8"))
47-
self.assertEqual(DT.UINT8, convert("uint8"))
38+
self.assertEqual(paddle.float32, convert(np.float32))
39+
self.assertEqual(paddle.float16, convert("float16"))
40+
self.assertEqual(paddle.float64, convert("float64"))
41+
self.assertEqual(paddle.int32, convert("int32"))
42+
self.assertEqual(paddle.int16, convert("int16"))
43+
self.assertEqual(paddle.int64, convert("int64"))
44+
self.assertEqual(paddle.bool, convert("bool"))
45+
self.assertEqual(paddle.int8, convert("int8"))
46+
self.assertEqual(paddle.uint8, convert("uint8"))
4847

4948
def test_var(self):
5049
b = default_main_program().current_block()
5150
w = b.create_var(
5251
dtype="float64", shape=[784, 100], lod_level=0, name="fc.w"
5352
)
5453
self.assertNotEqual(str(w), "")
55-
self.assertEqual(core.VarDesc.VarType.FP64, w.dtype)
54+
self.assertEqual(paddle.float64, w.dtype)
5655
self.assertEqual((784, 100), w.shape)
5756
self.assertEqual("fc.w", w.name)
5857
self.assertEqual("fc.w@GRAD", w.grad_name)
5958
self.assertEqual(0, w.lod_level)
6059

6160
w = b.create_var(name='fc.w')
62-
self.assertEqual(core.VarDesc.VarType.FP64, w.dtype)
61+
self.assertEqual(paddle.float64, w.dtype)
6362
self.assertEqual((784, 100), w.shape)
6463
self.assertEqual("fc.w", w.name)
6564
self.assertEqual("fc.w@GRAD", w.grad_name)
@@ -440,7 +439,7 @@ def test_variable_in_dygraph_mode(self):
440439

441440
self.assertTrue(var.name.startswith('_generated_var_'))
442441
self.assertEqual(var.shape, (1, 1))
443-
self.assertEqual(var.dtype, base.core.VarDesc.VarType.FP64)
442+
self.assertEqual(var.dtype, paddle.float64)
444443
self.assertEqual(var.type, base.core.VarDesc.VarType.LOD_TENSOR)
445444

446445
def test_create_selected_rows(self):

test/sequence/test_sequence_pad_op.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from op_test import OpTest
2020

2121
import paddle
22-
from paddle.base import core
2322

2423

2524
class TestSequencePadOp(OpTest):
@@ -191,7 +190,7 @@ def test_length_dtype(self):
191190
x=x, pad_value=pad_value
192191
)
193192
# check if the dtype of length is int64 in compile time
194-
self.assertEqual(length.dtype, core.VarDesc.VarType.INT64)
193+
self.assertEqual(length.dtype, paddle.int64)
195194

196195

197196
if __name__ == '__main__':

test/xpu/test_gaussian_random_op_xpu.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424

2525
import paddle
2626
from paddle import base
27+
from paddle.base import core
2728

2829
paddle.enable_static()
29-
from paddle.base import core
3030
from paddle.tensor import random
3131

3232
typeid_dict = {
@@ -285,22 +285,22 @@ def test_default_dtype(self):
285285
def test_default_fp16():
286286
paddle.framework.set_default_dtype('float16')
287287
out = paddle.tensor.random.gaussian([2, 3])
288-
self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP16)
288+
self.assertEqual(out.dtype, paddle.float16)
289289

290290
def test_default_bf16():
291291
paddle.framework.set_default_dtype('bfloat16')
292292
out = paddle.tensor.random.gaussian([2, 3])
293-
self.assertEqual(out.dtype, base.core.VarDesc.VarType.BF16)
293+
self.assertEqual(out.dtype, paddle.bfloat16)
294294

295295
def test_default_fp32():
296296
paddle.framework.set_default_dtype('float32')
297297
out = paddle.tensor.random.gaussian([2, 3])
298-
self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP32)
298+
self.assertEqual(out.dtype, paddle.float32)
299299

300300
def test_default_fp64():
301301
paddle.framework.set_default_dtype('float64')
302302
out = paddle.tensor.random.gaussian([2, 3])
303-
self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP64)
303+
self.assertEqual(out.dtype, paddle.float64)
304304

305305
test_default_fp64()
306306
test_default_fp32()
@@ -317,22 +317,22 @@ def test_default_dtype(self):
317317
def test_default_fp16():
318318
paddle.framework.set_default_dtype('float16')
319319
out = paddle.tensor.random.standard_normal([2, 3])
320-
self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP16)
320+
self.assertEqual(out.dtype, paddle.float16)
321321

322322
def test_default_bf16():
323323
paddle.framework.set_default_dtype('bfloat16')
324324
out = paddle.tensor.random.standard_normal([2, 3])
325-
self.assertEqual(out.dtype, base.core.VarDesc.VarType.BF16)
325+
self.assertEqual(out.dtype, paddle.bfloat16)
326326

327327
def test_default_fp32():
328328
paddle.framework.set_default_dtype('float32')
329329
out = paddle.tensor.random.standard_normal([2, 3])
330-
self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP32)
330+
self.assertEqual(out.dtype, paddle.float32)
331331

332332
def test_default_fp64():
333333
paddle.framework.set_default_dtype('float64')
334334
out = paddle.tensor.random.standard_normal([2, 3])
335-
self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP64)
335+
self.assertEqual(out.dtype, paddle.float64)
336336

337337
test_default_fp64()
338338
test_default_fp32()

0 commit comments

Comments
 (0)