Commit d963050
【PIR API adaptor No.45-47】Migrate some ops into pir (#58682)
1 parent cbafa02

6 files changed: +74, -39 lines
python/paddle/nn/functional/loss.py

Lines changed: 1 addition & 1 deletion
@@ -271,7 +271,7 @@ def base_softmax_with_cross_entropy(
         )
     if input_dims - 1 == label_dims:
         label = paddle.unsqueeze(label, axis=axis)
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         softmax, loss = _C_ops.cross_entropy_with_softmax(
             logits,
             label,
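This one-line change is the standard PIR migration pattern: in_dynamic_mode() covers only eager execution, while in_dynamic_or_pir_mode() is also true while a program is being built under the new PIR static graph, so eager and PIR share the unified _C_ops entry point and only the legacy static graph falls through to the old LayerHelper branch. A minimal sketch of the dispatch shape (not Paddle source; _run_c_op and _build_legacy_op are hypothetical stand-ins, and the import path assumes Paddle ~2.6):

from paddle.base.framework import in_dynamic_or_pir_mode


def _run_c_op(x):
    return x  # stand-in for a _C_ops.* call


def _build_legacy_op(x):
    return x  # stand-in for the old LayerHelper/append_op path


def my_op(x):
    # Eager mode and PIR program building both take the unified C++ path;
    # only the legacy static graph still builds ops the old way.
    if in_dynamic_or_pir_mode():
        return _run_c_op(x)
    return _build_legacy_op(x)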

python/paddle/tensor/linalg.py

Lines changed: 1 addition & 1 deletion
@@ -1467,7 +1467,7 @@ def cross(x, y, axis=9, name=None):
             [0., 0., 0.],
             [0., 0., 0.]])
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         axis = K_DEFAULT_DIM if axis is None else axis
         return _C_ops.cross(x, y, axis)
     else:
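The public cross-product API is unchanged by this migration; only the dispatch condition widens. A quick eager-mode sanity check of paddle.cross:

import paddle

x = paddle.to_tensor([[1.0, 0.0, 0.0]])
y = paddle.to_tensor([[0.0, 1.0, 0.0]])
print(paddle.cross(x, y, axis=1))  # [[0., 0., 1.]]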

python/paddle/tensor/manipulation.py

Lines changed: 9 additions & 3 deletions
@@ -781,10 +781,16 @@ def crop(x, shape=None, offsets=None, name=None):
         x, 'x', ['float32', 'float64', 'int32', 'int64'], 'crop_tensor'
     )
     check_type(
-        shape, 'shape', (list, tuple, Variable, type(None)), 'crop_tensor'
+        shape,
+        'shape',
+        (list, tuple, Variable, type(None), paddle.pir.OpResult),
+        'crop_tensor',
     )
     check_type(
-        offsets, 'offsets', (list, tuple, Variable, type(None)), 'crop_tensor'
+        offsets,
+        'offsets',
+        (list, tuple, Variable, type(None), paddle.pir.OpResult),
+        'crop_tensor',
     )
 
     if offsets is None:
@@ -793,7 +799,7 @@ def crop(x, shape=None, offsets=None, name=None):
     if shape is None:
         shape = x.shape
 
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.crop(x, shape, offsets)
 
     out = helper.create_variable_for_type_inference(x.dtype)
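Widening check_type to accept paddle.pir.OpResult matters because values in a PIR program (for example, the output of paddle.static.data or of another op) are OpResult objects rather than legacy Variables, so a tensor-valued shape or offsets would otherwise be rejected before ever reaching _C_ops.crop. A hedged sketch of a program exercising this path, assuming a Paddle ~2.6 build where paddle.pir_utils.IrGuard switches program building to PIR:

import numpy as np
import paddle
from paddle.pir_utils import IrGuard

paddle.enable_static()
with IrGuard():  # build the block below as a PIR program
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data('x', [3, 3], 'float32')
        # shape is a pir.OpResult here, admitted by the widened check_type
        shape = paddle.assign(np.array([2, 2], dtype='int32'))
        out = paddle.crop(x, shape=shape, offsets=[0, 1])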

test/legacy_test/test_crop_tensor_op.py

Lines changed: 4 additions & 4 deletions
@@ -81,10 +81,10 @@ def initTestCase(self):
         self.offsets = [1, 2]
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_pir=True)
 
 
 class TestCase1(TestCropTensorOp):
@@ -182,10 +182,10 @@ def initTestCase(self):
         self.shape_attr = [0, 0]
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(["X"], "Out")
+        self.check_grad(["X"], "Out", check_pir=True)
 
 
 class TestCropTensorOpTensorAttrCase1(TestCropTensorOpTensorAttr):
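In these tests, check_pir=True asks the OpTest harness to repeat the forward and gradient checks under the new PIR executor in addition to the legacy paths. A condensed sketch of the pattern (modelled on this file; op_test is the harness module in test/legacy_test, and the exact inputs/attrs below are illustrative):

import numpy as np
import paddle
from op_test import OpTest


class TestCropTensorPir(OpTest):
    def setUp(self):
        self.op_type = 'crop_tensor'
        self.python_api = paddle.crop
        x = np.random.random((10, 10)).astype('float64')
        self.inputs = {'X': x}
        self.attrs = {'shape': [2, 2], 'offsets': [1, 2]}
        self.outputs = {'Out': x[1:3, 2:4]}

    def test_check_output(self):
        # forward result is checked under both the legacy and PIR executors
        self.check_output(check_pir=True)

    def test_check_grad_normal(self):
        # the gradient is likewise verified against the PIR program
        self.check_grad(['X'], 'Out', check_pir=True)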

test/legacy_test/test_cross_op.py

Lines changed: 27 additions & 11 deletions
@@ -19,7 +19,8 @@
 
 import paddle
 from paddle import base
-from paddle.base import Program, core, program_guard
+from paddle.base import core
+from paddle.pir_utils import test_with_pir_api
 
 
 class TestCrossOp(OpTest):
@@ -47,10 +48,10 @@ def init_output(self):
         self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_pir=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out')
+        self.check_grad(['X', 'Y'], 'Out', check_pir=True)
 
 
 class TestCrossOpCase1(TestCrossOp):
@@ -116,13 +117,15 @@ def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             if core.is_bfloat16_supported(place):
-                self.check_output_with_place(place)
+                self.check_output_with_place(place, check_pir=True)
 
     def test_check_grad_normal(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             if core.is_bfloat16_supported(place):
-                self.check_grad_with_place(place, ['X', 'Y'], 'Out')
+                self.check_grad_with_place(
+                    place, ['X', 'Y'], 'Out', check_pir=True
+                )
 
 
 class TestCrossAPI(unittest.TestCase):
@@ -134,43 +137,56 @@ def input_data(self):
             [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
         ).astype('float32')
 
+    @test_with_pir_api
     def test_cross_api(self):
         self.input_data()
 
+        main = paddle.static.Program()
+        startup = paddle.static.Program()
         # case 1:
-        with program_guard(Program(), Program()):
+        with paddle.static.program_guard(main, startup):
             x = paddle.static.data(name='x', shape=[-1, 3], dtype="float32")
             y = paddle.static.data(name='y', shape=[-1, 3], dtype="float32")
             z = paddle.cross(x, y, axis=1)
             exe = base.Executor(base.CPUPlace())
             (res,) = exe.run(
+                main,
                 feed={'x': self.data_x, 'y': self.data_y},
-                fetch_list=[z.name],
+                fetch_list=[z],
                 return_numpy=False,
             )
             expect_out = np.array(
                 [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
             )
             np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
 
+        main = paddle.static.Program()
+        startup = paddle.static.Program()
         # case 2:
-        with program_guard(Program(), Program()):
+        with paddle.static.program_guard(main, startup):
             x = paddle.static.data(name='x', shape=[-1, 3], dtype="float32")
             y = paddle.static.data(name='y', shape=[-1, 3], dtype="float32")
             z = paddle.cross(x, y)
             exe = base.Executor(base.CPUPlace())
             (res,) = exe.run(
+                main,
                 feed={'x': self.data_x, 'y': self.data_y},
-                fetch_list=[z.name],
+                fetch_list=[z],
                 return_numpy=False,
             )
             expect_out = np.array(
                 [[-1.0, -1.0, -1.0], [2.0, 2.0, 2.0], [-1.0, -1.0, -1.0]]
             )
             np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
 
-        # case 3:
-        with program_guard(Program(), Program()):
+    def test_cross_api1(self):
+        self.input_data()
+
+        main = paddle.static.Program()
+        startup = paddle.static.Program()
+
+        # case 1:
+        with paddle.static.program_guard(main, startup):
             x = paddle.static.data(name="x", shape=[-1, 3], dtype="float32")
             y = paddle.static.data(name='y', shape=[-1, 3], dtype='float32')
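The import swap and the fetch_list change above go together: test_with_pir_api (from paddle/pir_utils.py) runs the decorated test body twice, once building a legacy program and once building a PIR program, and a PIR OpResult has no .name attribute, so the test must fetch by object instead of by name. The former case 3 is split into the undecorated test_cross_api1, which stays on the legacy path only. A minimal illustrative use of the decorator (not from this file):

import unittest

import numpy as np
import paddle
from paddle.pir_utils import test_with_pir_api


class MyStaticApiTest(unittest.TestCase):
    @test_with_pir_api
    def test_double(self):
        paddle.enable_static()
        main = paddle.static.Program()
        with paddle.static.program_guard(main):
            x = paddle.static.data('x', [3], 'float32')
            z = x * 2.0
            exe = paddle.static.Executor(paddle.CPUPlace())
            # fetch by object: valid for both legacy Variable and pir.OpResult
            (res,) = exe.run(
                main, feed={'x': np.ones(3, 'float32')}, fetch_list=[z]
            )
        np.testing.assert_allclose(res, 2 * np.ones(3, 'float32'))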

test/legacy_test/test_softmax_with_cross_entropy_op.py

Lines changed: 32 additions & 19 deletions
@@ -153,27 +153,30 @@ def setUp(self):
 
     def test_check_output(self):
         if self.python_api is not None:
-            self.check_output()
-        self.check_output()
+            self.check_output(check_pir=True)
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_rocm():
             if self.python_api is not None:
                 self.check_grad(
-                    ["Logits"],
-                    "Loss",
-                    max_relative_error=5e-1,
+                    ["Logits"], "Loss", max_relative_error=5e-1, check_pir=False
                 )
             # HIP will have accuracy fail when using float32 in CPU place
-            self.check_grad(["Logits"], "Loss", max_relative_error=5e-1)
+            self.check_grad(
+                ["Logits"], "Loss", max_relative_error=5e-1, check_pir=False
+            )
         else:
             if self.python_api is not None:
                 self.check_grad(
                     ["Logits"],
                     "Loss",
                     numeric_grad_delta=0.001,
+                    check_pir=False,
                 )
-            self.check_grad(["Logits"], "Loss", numeric_grad_delta=0.001)
+            self.check_grad(
+                ["Logits"], "Loss", numeric_grad_delta=0.001, check_pir=False
+            )
 
 
 class TestSoftmaxWithCrossEntropyOpInt32(TestSoftmaxWithCrossEntropyOp):
@@ -509,13 +512,15 @@ def setUp(self):
 
     def test_check_output(self):
         if self.python_api is not None:
-            self.check_output()
-        self.check_output()
+            self.check_output(check_pir=True)
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
         if self.python_api is not None:
-            self.check_grad(["Logits"], "Loss")
-        self.check_grad(["Logits"], "Loss", max_relative_error=0.1)
+            self.check_grad(["Logits"], "Loss", check_pir=False)
+        self.check_grad(
+            ["Logits"], "Loss", max_relative_error=0.1, check_pir=False
+        )
 
 
 class TestSoftmaxWithCrossEntropyOpNoCudnnFp16(
@@ -534,8 +539,12 @@ def initParams(self):
 
     def test_check_grad(self):
         if self.python_api is not None:
-            self.check_grad(["Logits"], "Loss", max_relative_error=0.1)
-        self.check_grad(["Logits"], "Loss", max_relative_error=0.1)
+            self.check_grad(
+                ["Logits"], "Loss", max_relative_error=0.1, check_pir=False
+            )
+        self.check_grad(
+            ["Logits"], "Loss", max_relative_error=0.1, check_pir=False
+        )
 
 
 class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp):
@@ -557,19 +566,23 @@ def initParams(self):
 
     def test_check_output(self):
         if self.python_api is not None:
-            self.check_output()
-        self.check_output()
+            self.check_output(check_pir=True)
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_rocm():
             # HIP will have accuracy fail when using float32 in CPU place
             if self.python_api is not None:
-                self.check_grad(["Logits"], "Loss", max_relative_error=0.1)
-            self.check_grad(["Logits"], "Loss", max_relative_error=0.1)
+                self.check_grad(
+                    ["Logits"], "Loss", max_relative_error=0.1, check_pir=False
+                )
+            self.check_grad(
+                ["Logits"], "Loss", max_relative_error=0.1, check_pir=False
+            )
         else:
             if self.python_api is not None:
-                self.check_grad(["Logits"], "Loss")
-            self.check_grad(["Logits"], "Loss")
+                self.check_grad(["Logits"], "Loss", check_pir=False)
+            self.check_grad(["Logits"], "Loss", check_pir=False)
 
 
 class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp):
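Note the split in this file: the forward checks gain check_pir=True while every gradient check explicitly passes check_pir=False, so this commit validates only the forward kernel of cross_entropy_with_softmax under PIR (a reading of the diff; the gradient checks presumably get enabled in a follow-up migration). A quick eager-mode sanity check of the public API backed by the migrated base_softmax_with_cross_entropy:

import paddle
import paddle.nn.functional as F

logits = paddle.randn([4, 10])
labels = paddle.randint(0, 10, shape=[4, 1])  # int64 class indices
loss = F.softmax_with_cross_entropy(logits, labels)
print(loss.shape)  # [4, 1]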
