Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/paddle/tensor/manipulation.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该文件下的 _fill_diagonal_tensor_impl 貌似没有适配?该函数是 fill_diagonal_tensor 的实现。需要注意的是:
fill_diagonal_tensor: 静态图和动态图都可以运行
fill_diagonal_tensor_ : 只有动态图能运行
[Translation: It looks like `_fill_diagonal_tensor_impl` in this file has not been adapted? That function is the implementation of `fill_diagonal_tensor`. Note that `fill_diagonal_tensor` runs in both static-graph and dynamic-graph modes, while `fill_diagonal_tensor_` runs only in dynamic-graph mode.]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

收到 [Translation: Acknowledged.]

Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
[[-1]
[ 1]]
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.shard_index(
input, index_num, nshards, shard_id, ignore_value
)
Expand Down Expand Up @@ -980,7 +980,7 @@ def fill_diagonal_(x, value, offset=0, wrap=False, name=None):
>>> print(x.tolist())
[[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0]]
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
if len(x.shape) == 2:
return _C_ops.fill_diagonal_(x, value, offset, wrap)
return _C_ops.fill_diagonal_(x, value, offset, True)
Expand Down Expand Up @@ -1780,7 +1780,7 @@ def roll(x, shifts, axis=None, name=None):
else:
axis = []

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.roll(x, shifts, axis)
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -4326,7 +4326,7 @@ def strided_slice(x, axes, starts, ends, strides, name=None):
>>> sliced_2 = paddle.strided_slice(x, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides_2)
>>> # sliced_2 is x[:, 1:3:1, 0:2:1, 2:4:2].
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.strided_slice(x, axes, starts, ends, strides)
else:
helper = LayerHelper('strided_slice', **locals())
Expand Down
107 changes: 73 additions & 34 deletions test/legacy_test/test_roll_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

import paddle
from paddle import base
from paddle.base import Program, core, program_guard
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


class TestRollOp(OpTest):
Expand Down Expand Up @@ -48,10 +49,10 @@ def init_dtype_type(self):
self.axis = [0, -2]

def test_check_output(self):
self.check_output(check_prim=True)
self.check_output(check_prim=True, check_pir=True)

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', check_prim=True)
self.check_grad(['X'], 'Out', check_prim=True, check_pir=True)


class TestRollOpCase2(TestRollOp):
Expand Down Expand Up @@ -108,10 +109,14 @@ def init_dtype_type(self):
self.place = core.CUDAPlace(0)

def test_check_output(self):
self.check_output_with_place(self.place, check_prim=True)
self.check_output_with_place(
self.place, check_prim=True, check_pir=True
)

def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['X'], 'Out', check_prim=True)
self.check_grad_with_place(
self.place, ['X'], 'Out', check_prim=True, check_pir=True
)


@unittest.skipIf(
Expand All @@ -128,10 +133,14 @@ def init_dtype_type(self):
self.place = core.CUDAPlace(0)

def test_check_output(self):
self.check_output_with_place(self.place, check_prim=True)
self.check_output_with_place(
self.place, check_prim=True, check_pir=True
)

def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['X'], 'Out', check_prim=True)
self.check_grad_with_place(
self.place, ['X'], 'Out', check_prim=True, check_pir=True
)


@unittest.skipIf(
Expand All @@ -148,10 +157,14 @@ def init_dtype_type(self):
self.place = core.CUDAPlace(0)

def test_check_output(self):
self.check_output_with_place(self.place, check_prim=True)
self.check_output_with_place(
self.place, check_prim=True, check_pir=True
)

def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['X'], 'Out', check_prim=True)
self.check_grad_with_place(
self.place, ['X'], 'Out', check_prim=True, check_pir=True
)


class TestRollAPI(unittest.TestCase):
Expand All @@ -160,37 +173,53 @@ def input_data(self):
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
)

def test_roll_op_api(self):
self.input_data()

@test_with_pir_api
def test_roll_op_api_case1(self):
paddle.enable_static()
# case 1:
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='x', shape=[-1, 3], dtype='float32')
x.desc.set_need_check_feed(False)
data_x = np.array(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
).astype('float32')
z = paddle.roll(x, shifts=1)
exe = base.Executor(base.CPUPlace())
exe = paddle.static.Executor(paddle.CPUPlace())
(res,) = exe.run(
feed={'x': self.data_x}, fetch_list=[z.name], return_numpy=False
paddle.static.default_main_program(),
feed={'x': data_x},
fetch_list=[z],
return_numpy=False,
)
expect_out = np.array(
[[9.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
)
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
paddle.disable_static()

# case 2:
with program_guard(Program(), Program()):
@test_with_pir_api
def test_roll_op_api_case2(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='x', shape=[-1, 3], dtype='float32')
x.desc.set_need_check_feed(False)
data_x = np.array(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
).astype('float32')
z = paddle.roll(x, shifts=1, axis=0)
exe = base.Executor(base.CPUPlace())
exe = paddle.static.Executor(paddle.CPUPlace())
(res,) = exe.run(
feed={'x': self.data_x}, fetch_list=[z.name], return_numpy=False
paddle.static.default_main_program(),
feed={'x': data_x},
fetch_list=[z],
return_numpy=False,
)
expect_out = np.array(
[[7.0, 8.0, 9.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
)
expect_out = np.array(
[[7.0, 8.0, 9.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
)
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
paddle.disable_static()

def test_dygraph_api(self):
self.input_data()
Expand All @@ -214,22 +243,27 @@ def test_dygraph_api(self):
)
np.testing.assert_allclose(expect_out, np_z, rtol=1e-05)

@test_with_pir_api
def test_roll_op_false(self):
self.input_data()

def test_axis_out_range():
with program_guard(Program(), Program()):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='x', shape=[-1, 3], dtype='float32')
x.desc.set_need_check_feed(False)
data_x = np.array(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
).astype('float32')
z = paddle.roll(x, shifts=1, axis=10)
exe = base.Executor(base.CPUPlace())
(res,) = exe.run(
feed={'x': self.data_x},
fetch_list=[z.name],
feed={'x': data_x},
fetch_list=[z],
return_numpy=False,
)

self.assertRaises(ValueError, test_axis_out_range)
paddle.disable_static()

def test_shifts_as_tensor_dygraph(self):
with base.dygraph.guard():
Expand All @@ -241,23 +275,28 @@ def test_shifts_as_tensor_dygraph(self):
expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
np.testing.assert_allclose(out, expected_out, rtol=1e-05)

@test_with_pir_api
def test_shifts_as_tensor_static(self):
with program_guard(Program(), Program()):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.arange(9).reshape([3, 3]).astype('float32')
shape = paddle.shape(x)
shifts = shape // 2
axes = [0, 1]
out = paddle.roll(x, shifts=shifts, axis=axes)
expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])

exe = base.Executor(base.CPUPlace())
exe = paddle.static.Executor(paddle.CPUPlace())
[out_np] = exe.run(fetch_list=[out])
np.testing.assert_allclose(out_np, expected_out, rtol=1e-05)

if paddle.is_compiled_with_cuda():
exe = base.Executor(base.CPUPlace())
[out_np] = exe.run(fetch_list=[out])
np.testing.assert_allclose(out_np, expected_out, rtol=1e-05)
paddle.disable_static()


if __name__ == "__main__":
Expand Down
Loading