Skip to content

Commit 2499332

Browse files
authored
【PIR API adaptor No.71、186、203、215】 Migrate paddle.fill_diagonal_tensor, paddle.roll, paddle.shard_index, paddle.strided_slice into pir (#59091)
1 parent d7f0fb8 commit 2499332

File tree

4 files changed

+184
-122
lines changed

4 files changed

+184
-122
lines changed

python/paddle/tensor/manipulation.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
660660
[[-1]
661661
[ 1]]
662662
"""
663-
if in_dynamic_mode():
663+
if in_dynamic_or_pir_mode():
664664
return _C_ops.shard_index(
665665
input, index_num, nshards, shard_id, ignore_value
666666
)
@@ -1021,11 +1021,11 @@ def _fill_diagonal_tensor_impl(x, y, offset=0, dim1=0, dim2=1, inplace=False):
10211021
if len(y.shape) == 1:
10221022
y = y.reshape([1, -1])
10231023

1024-
if inplace:
1025-
return _C_ops.fill_diagonal_tensor_(x, y, offset, dim1, dim2)
1026-
1027-
if in_dynamic_mode():
1028-
return _C_ops.fill_diagonal_tensor(x, y, offset, dim1, dim2)
1024+
if in_dynamic_or_pir_mode():
1025+
if inplace:
1026+
return _C_ops.fill_diagonal_tensor_(x, y, offset, dim1, dim2)
1027+
else:
1028+
return _C_ops.fill_diagonal_tensor(x, y, offset, dim1, dim2)
10291029
else:
10301030
check_variable_and_dtype(
10311031
x,
@@ -1843,7 +1843,7 @@ def roll(x, shifts, axis=None, name=None):
18431843
else:
18441844
axis = []
18451845

1846-
if in_dynamic_mode():
1846+
if in_dynamic_or_pir_mode():
18471847
return _C_ops.roll(x, shifts, axis)
18481848
else:
18491849
check_variable_and_dtype(
@@ -4411,7 +4411,7 @@ def strided_slice(x, axes, starts, ends, strides, name=None):
44114411
>>> sliced_2 = paddle.strided_slice(x, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides_2)
44124412
>>> # sliced_2 is x[:, 1:3:1, 0:2:1, 2:4:2].
44134413
"""
4414-
if in_dynamic_mode():
4414+
if in_dynamic_or_pir_mode():
44154415
return _C_ops.strided_slice(x, axes, starts, ends, strides)
44164416
else:
44174417
helper = LayerHelper('strided_slice', **locals())

test/legacy_test/test_fill_diagonal_tensor_op.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ def init_kernel_type(self):
103103
self.dtype = np.float64
104104

105105
def test_check_output(self):
106-
self.check_output()
106+
self.check_output(check_pir=True)
107107

108108
def test_check_grad(self):
109-
self.check_grad(['X'], 'Out')
109+
self.check_grad(['X'], 'Out', check_pir=True)
110110

111111

112112
class TensorFillDiagTensor_Test2(TensorFillDiagTensor_Test):
@@ -193,11 +193,11 @@ def init_input_output(self):
193193

194194
def test_check_output(self):
195195
place = core.CUDAPlace(0)
196-
self.check_output_with_place(place)
196+
self.check_output_with_place(place, check_pir=True)
197197

198198
def test_check_grad(self):
199199
place = core.CUDAPlace(0)
200-
self.check_grad_with_place(place, ['X'], 'Out')
200+
self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)
201201

202202

203203
if __name__ == '__main__':

test/legacy_test/test_roll_op.py

Lines changed: 73 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
import paddle
2121
from paddle import base
22-
from paddle.base import Program, core, program_guard
22+
from paddle.base import core
23+
from paddle.pir_utils import test_with_pir_api
2324

2425

2526
class TestRollOp(OpTest):
@@ -48,10 +49,10 @@ def init_dtype_type(self):
4849
self.axis = [0, -2]
4950

5051
def test_check_output(self):
51-
self.check_output(check_prim=True)
52+
self.check_output(check_prim=True, check_pir=True)
5253

5354
def test_check_grad_normal(self):
54-
self.check_grad(['X'], 'Out', check_prim=True)
55+
self.check_grad(['X'], 'Out', check_prim=True, check_pir=True)
5556

5657

5758
class TestRollOpCase2(TestRollOp):
@@ -108,10 +109,14 @@ def init_dtype_type(self):
108109
self.place = core.CUDAPlace(0)
109110

110111
def test_check_output(self):
111-
self.check_output_with_place(self.place, check_prim=True)
112+
self.check_output_with_place(
113+
self.place, check_prim=True, check_pir=True
114+
)
112115

113116
def test_check_grad_normal(self):
114-
self.check_grad_with_place(self.place, ['X'], 'Out', check_prim=True)
117+
self.check_grad_with_place(
118+
self.place, ['X'], 'Out', check_prim=True, check_pir=True
119+
)
115120

116121

117122
@unittest.skipIf(
@@ -128,10 +133,14 @@ def init_dtype_type(self):
128133
self.place = core.CUDAPlace(0)
129134

130135
def test_check_output(self):
131-
self.check_output_with_place(self.place, check_prim=True)
136+
self.check_output_with_place(
137+
self.place, check_prim=True, check_pir=True
138+
)
132139

133140
def test_check_grad_normal(self):
134-
self.check_grad_with_place(self.place, ['X'], 'Out', check_prim=True)
141+
self.check_grad_with_place(
142+
self.place, ['X'], 'Out', check_prim=True, check_pir=True
143+
)
135144

136145

137146
@unittest.skipIf(
@@ -148,10 +157,14 @@ def init_dtype_type(self):
148157
self.place = core.CUDAPlace(0)
149158

150159
def test_check_output(self):
151-
self.check_output_with_place(self.place, check_prim=True)
160+
self.check_output_with_place(
161+
self.place, check_prim=True, check_pir=True
162+
)
152163

153164
def test_check_grad_normal(self):
154-
self.check_grad_with_place(self.place, ['X'], 'Out', check_prim=True)
165+
self.check_grad_with_place(
166+
self.place, ['X'], 'Out', check_prim=True, check_pir=True
167+
)
155168

156169

157170
class TestRollAPI(unittest.TestCase):
@@ -160,37 +173,53 @@ def input_data(self):
160173
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
161174
)
162175

163-
def test_roll_op_api(self):
164-
self.input_data()
165-
176+
@test_with_pir_api
177+
def test_roll_op_api_case1(self):
166178
paddle.enable_static()
167-
# case 1:
168-
with program_guard(Program(), Program()):
179+
with paddle.static.program_guard(
180+
paddle.static.Program(), paddle.static.Program()
181+
):
169182
x = paddle.static.data(name='x', shape=[-1, 3], dtype='float32')
170-
x.desc.set_need_check_feed(False)
183+
data_x = np.array(
184+
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
185+
).astype('float32')
171186
z = paddle.roll(x, shifts=1)
172-
exe = base.Executor(base.CPUPlace())
187+
exe = paddle.static.Executor(paddle.CPUPlace())
173188
(res,) = exe.run(
174-
feed={'x': self.data_x}, fetch_list=[z.name], return_numpy=False
189+
paddle.static.default_main_program(),
190+
feed={'x': data_x},
191+
fetch_list=[z],
192+
return_numpy=False,
175193
)
176194
expect_out = np.array(
177195
[[9.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
178196
)
179-
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
197+
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
198+
paddle.disable_static()
180199

181-
# case 2:
182-
with program_guard(Program(), Program()):
200+
@test_with_pir_api
201+
def test_roll_op_api_case2(self):
202+
paddle.enable_static()
203+
with paddle.static.program_guard(
204+
paddle.static.Program(), paddle.static.Program()
205+
):
183206
x = paddle.static.data(name='x', shape=[-1, 3], dtype='float32')
184-
x.desc.set_need_check_feed(False)
207+
data_x = np.array(
208+
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
209+
).astype('float32')
185210
z = paddle.roll(x, shifts=1, axis=0)
186-
exe = base.Executor(base.CPUPlace())
211+
exe = paddle.static.Executor(paddle.CPUPlace())
187212
(res,) = exe.run(
188-
feed={'x': self.data_x}, fetch_list=[z.name], return_numpy=False
213+
paddle.static.default_main_program(),
214+
feed={'x': data_x},
215+
fetch_list=[z],
216+
return_numpy=False,
217+
)
218+
expect_out = np.array(
219+
[[7.0, 8.0, 9.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
189220
)
190-
expect_out = np.array(
191-
[[7.0, 8.0, 9.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
192-
)
193221
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
222+
paddle.disable_static()
194223

195224
def test_dygraph_api(self):
196225
self.input_data()
@@ -214,22 +243,27 @@ def test_dygraph_api(self):
214243
)
215244
np.testing.assert_allclose(expect_out, np_z, rtol=1e-05)
216245

246+
@test_with_pir_api
217247
def test_roll_op_false(self):
218-
self.input_data()
219-
220248
def test_axis_out_range():
221-
with program_guard(Program(), Program()):
249+
paddle.enable_static()
250+
with paddle.static.program_guard(
251+
paddle.static.Program(), paddle.static.Program()
252+
):
222253
x = paddle.static.data(name='x', shape=[-1, 3], dtype='float32')
223-
x.desc.set_need_check_feed(False)
254+
data_x = np.array(
255+
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
256+
).astype('float32')
224257
z = paddle.roll(x, shifts=1, axis=10)
225258
exe = base.Executor(base.CPUPlace())
226259
(res,) = exe.run(
227-
feed={'x': self.data_x},
228-
fetch_list=[z.name],
260+
feed={'x': data_x},
261+
fetch_list=[z],
229262
return_numpy=False,
230263
)
231264

232265
self.assertRaises(ValueError, test_axis_out_range)
266+
paddle.disable_static()
233267

234268
def test_shifts_as_tensor_dygraph(self):
235269
with base.dygraph.guard():
@@ -241,23 +275,28 @@ def test_shifts_as_tensor_dygraph(self):
241275
expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
242276
np.testing.assert_allclose(out, expected_out, rtol=1e-05)
243277

278+
@test_with_pir_api
244279
def test_shifts_as_tensor_static(self):
245-
with program_guard(Program(), Program()):
280+
paddle.enable_static()
281+
with paddle.static.program_guard(
282+
paddle.static.Program(), paddle.static.Program()
283+
):
246284
x = paddle.arange(9).reshape([3, 3]).astype('float32')
247285
shape = paddle.shape(x)
248286
shifts = shape // 2
249287
axes = [0, 1]
250288
out = paddle.roll(x, shifts=shifts, axis=axes)
251289
expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
252290

253-
exe = base.Executor(base.CPUPlace())
291+
exe = paddle.static.Executor(paddle.CPUPlace())
254292
[out_np] = exe.run(fetch_list=[out])
255293
np.testing.assert_allclose(out_np, expected_out, rtol=1e-05)
256294

257295
if paddle.is_compiled_with_cuda():
258296
exe = base.Executor(base.CPUPlace())
259297
[out_np] = exe.run(fetch_list=[out])
260298
np.testing.assert_allclose(out_np, expected_out, rtol=1e-05)
299+
paddle.disable_static()
261300

262301

263302
if __name__ == "__main__":

0 commit comments

Comments (0)