Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/paddle/tensor/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
elif isinstance(shape, paddle.pir.Value):
pass
else:
TypeError("Shape only supports OpResult, or list, or tuple.")
TypeError("Shape only supports Value, or list, or tuple.")

if out is None:
out = _C_ops.full(shape, value, dtype, place)
Expand Down
180 changes: 112 additions & 68 deletions test/legacy_test/test_zero_dim_no_backward_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ class TestNoBackwardAPIStatic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.exe = paddle.static.Executor()
self.shape = []
Copy link
Member

@SigureMo SigureMo Mar 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我们不应该在setUp中加载任何数据,因为在这时还没确定是哪种 IR 模式运行

这个默认的 self.shape 是针对什么情况的?如果所有 case 都有 init_data 的话,就不存在需要 self.shape 默认值的情况

另外,既然现在是在具体 test case 里调用的,就不要写成带有 side effect 的函数了,直接作为返回值返回即可,能写成纯函数就写成纯函数,函数名改名为 create_dynamic_shape

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


def init_data(self):
self.shape = [
paddle.full([], 2, 'int32'),
paddle.full([], 3, 'int32'),
Expand Down Expand Up @@ -312,6 +315,7 @@ def test_arange(self):
np.testing.assert_array_equal(res, [1.0, 2.0, 3.0, 4.0, 5.0])

def test_normal(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个单测还有别的问题,下个pr看

self.init_data()
mean = paddle.full([], 0.0)
std = paddle.full([], 0.0)
out1 = paddle.normal(mean, std)
Expand All @@ -325,25 +329,35 @@ def test_normal(self):
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))

@test_with_pir_api
def test_rand(self):
out1 = paddle.rand([])
out2 = paddle.rand(self.shape)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
self.init_data()
out1 = paddle.rand([])
out2 = paddle.rand(self.shape)

res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))
res = paddle.static.Executor().run(
main_program, fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))

@test_with_pir_api
def test_randn(self):
out1 = paddle.randn([])
out2 = paddle.randn(self.shape)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
self.init_data()
out1 = paddle.randn([])
out2 = paddle.randn(self.shape)

res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))
res = paddle.static.Executor().run(
main_program, fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))

@test_with_pir_api
def test_randint(self):
Expand Down Expand Up @@ -381,76 +395,106 @@ def test_randint_like(self):
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())

@test_with_pir_api
def test_standard_normal(self):
out1 = paddle.standard_normal([])
out2 = paddle.standard_normal(self.shape)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
self.init_data()
out1 = paddle.standard_normal([])
out2 = paddle.standard_normal(self.shape)

res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))
res = paddle.static.Executor().run(
main_program, fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))

@test_with_pir_api
def test_uniform(self):
out1 = paddle.uniform([])
out2 = paddle.uniform(self.shape)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
self.init_data()
out1 = paddle.uniform([])
out2 = paddle.uniform(self.shape)

res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))
res = paddle.static.Executor().run(
main_program, fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))

@test_with_pir_api
def test_empty_and_empty_like(self):
out1 = paddle.empty([])
out2 = paddle.empty_like(out1)
out3 = paddle.empty(self.shape)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
self.init_data()
out1 = paddle.empty([])
out2 = paddle.empty_like(out1)
out3 = paddle.empty(self.shape)

res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
res = paddle.static.Executor().run(
main_program, fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))

@test_with_pir_api
def test_full_and_full_like(self):
out1 = paddle.full([], 0.5)
out2 = paddle.full_like(out1, 0.5)
out3 = paddle.full(self.shape, 0.5)
out4 = paddle.full(self.shape, paddle.full([], 0.5))

res = self.exe.run(
paddle.static.default_main_program(),
fetch_list=[out1, out2, out3, out4],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
self.assertEqual(res[3].shape, (2, 3, 4))
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
self.init_data()
out1 = paddle.full([], 0.5)
out2 = paddle.full_like(out1, 0.5)
out3 = paddle.full(self.shape, 0.5)
out4 = paddle.full(self.shape, paddle.full([], 0.5))

res = paddle.static.Executor().run(
main_program,
fetch_list=[out1, out2, out3, out4],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
self.assertEqual(res[3].shape, (2, 3, 4))

@test_with_pir_api
def test_ones_and_ones_like(self):
out1 = paddle.ones([])
out2 = paddle.ones_like(out1)
out3 = paddle.ones(self.shape)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
self.init_data()
out1 = paddle.ones([])
out2 = paddle.ones_like(out1)
out3 = paddle.ones(self.shape)

res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
res = paddle.static.Executor().run(
main_program, fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))

@test_with_pir_api
def test_zeros_and_zeros_like(self):
out1 = paddle.zeros([])
out2 = paddle.zeros_like(out1)
out3 = paddle.zeros(self.shape)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
self.init_data()
out1 = paddle.zeros([])
out2 = paddle.zeros_like(out1)
out3 = paddle.zeros(self.shape)

res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
res = paddle.static.Executor().run(
main_program, fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))

@test_with_pir_api
def test_embedding(self):
Expand Down