Skip to content

Commit 63889be

Browse files
update UT
1 parent 1938ad4 commit 63889be

File tree

4 files changed

+110
-0
lines changed

4 files changed

+110
-0
lines changed

test/legacy_test/test_empty.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,17 @@ def test_Tensor_new_empty(self):
252252
if isinstance(dtype, paddle.dtype):
253253
self.assertEqual(x.dtype, dtype)
254254

255+
x = paddle.empty(
256+
[2],
257+
).new_empty(
258+
*shape,
259+
dtype=dtype,
260+
requires_grad=requires_grad,
261+
device=device,
262+
pin_memory=pin_memory,
263+
)
264+
self.assertEqual(x.shape, shape)
265+
255266
def new_empty(
256267
x, shape, dtype, requires_grad, device, pin_memory
257268
):
@@ -283,6 +294,30 @@ def new_empty(
283294
if isinstance(dtype, paddle.dtype):
284295
self.assertEqual(x.dtype, dtype)
285296

297+
def new_empty_size_arg(
298+
x, shape, dtype, requires_grad, device, pin_memory
299+
):
300+
return x.new_empty(
301+
*shape,
302+
dtype=dtype,
303+
requires_grad=requires_grad,
304+
device=device,
305+
pin_memory=pin_memory,
306+
)
307+
308+
st_f = paddle.jit.to_static(
309+
new_empty_size_arg, full_graph=True, backend=None
310+
)
311+
x = st_f(
312+
paddle.randn([1]),
313+
shape,
314+
dtype=dtype,
315+
requires_grad=requires_grad,
316+
device=device,
317+
pin_memory=pin_memory,
318+
)
319+
self.assertEqual(x.shape, shape)
320+
286321

287322
class TestCreationOut(unittest.TestCase):
288323
def setUp(self):

test/legacy_test/test_math_op_patch_pir.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,11 @@ def test_new_xxx(self):
788788
(output_x,) = exe.run(main_program, fetch_list=[x_new])
789789
self.assertEqual(output_x.shape, (2, 3))
790790

791+
x_new = x.new_zeros(2, 3)
792+
self.assertEqual(x_new.shape, [2, 3])
793+
(output_x,) = exe.run(main_program, fetch_list=[x_new])
794+
self.assertEqual(output_x.shape, (2, 3))
795+
791796
# test mT with dynamic shape
792797
with paddle.pir_utils.IrGuard():
793798
main_program, exe, program_guard = new_program()

test/legacy_test/test_ones.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,17 @@ def test_Tensor_new_ones(self):
246246
if isinstance(dtype, paddle.dtype):
247247
self.assertEqual(x.dtype, dtype)
248248

249+
x = paddle.ones(
250+
[2],
251+
).new_ones(
252+
*shape,
253+
dtype=dtype,
254+
requires_grad=requires_grad,
255+
device=device,
256+
pin_memory=pin_memory,
257+
)
258+
self.assertEqual(x.shape, shape)
259+
249260
def new_ones(
250261
x, shape, dtype, requires_grad, device, pin_memory
251262
):
@@ -277,6 +288,30 @@ def new_ones(
277288
if isinstance(dtype, paddle.dtype):
278289
self.assertEqual(x.dtype, dtype)
279290

291+
def new_ones_size_arg(
292+
x, shape, dtype, requires_grad, device, pin_memory
293+
):
294+
return x.new_ones(
295+
*shape,
296+
dtype=dtype,
297+
requires_grad=requires_grad,
298+
device=device,
299+
pin_memory=pin_memory,
300+
)
301+
302+
st_f = paddle.jit.to_static(
303+
new_ones_size_arg, full_graph=True, backend=None
304+
)
305+
x = st_f(
306+
paddle.randn([1]),
307+
shape,
308+
dtype=dtype,
309+
requires_grad=requires_grad,
310+
device=device,
311+
pin_memory=pin_memory,
312+
)
313+
self.assertEqual(x.shape, shape)
314+
280315

281316
class TestCreationOut(unittest.TestCase):
282317
def setUp(self):

test/legacy_test/test_zeros.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,17 @@ def test_Tensor_new_zeros(self):
246246
if isinstance(dtype, paddle.dtype):
247247
self.assertEqual(x.dtype, dtype)
248248

249+
x = paddle.zeros(
250+
[2],
251+
).new_zeros(
252+
*shape,
253+
dtype=dtype,
254+
requires_grad=requires_grad,
255+
device=device,
256+
pin_memory=pin_memory,
257+
)
258+
self.assertEqual(x.shape, shape)
259+
249260
def new_zeros(
250261
x, shape, dtype, requires_grad, device, pin_memory
251262
):
@@ -277,6 +288,30 @@ def new_zeros(
277288
if isinstance(dtype, paddle.dtype):
278289
self.assertEqual(x.dtype, dtype)
279290

291+
def new_zeros_size_arg(
292+
x, shape, dtype, requires_grad, device, pin_memory
293+
):
294+
return x.new_zeros(
295+
*shape,
296+
dtype=dtype,
297+
requires_grad=requires_grad,
298+
device=device,
299+
pin_memory=pin_memory,
300+
)
301+
302+
st_f = paddle.jit.to_static(
303+
new_zeros_size_arg, full_graph=True, backend=None
304+
)
305+
x = st_f(
306+
paddle.randn([1]),
307+
shape,
308+
dtype=dtype,
309+
requires_grad=requires_grad,
310+
device=device,
311+
pin_memory=pin_memory,
312+
)
313+
self.assertEqual(x.shape, shape)
314+
280315

281316
class TestCreationOut(unittest.TestCase):
282317
def setUp(self):

0 commit comments

Comments
 (0)