Skip to content

Commit c66c543

Browse files
authored
【PIR API adaptor No.146】inner (#58931)
1 parent 97b3447 commit c66c543

File tree

2 files changed

+33
-14
lines changed

2 files changed

+33
-14
lines changed

python/paddle/tensor/math.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2478,7 +2478,7 @@ def inner(x, y, name=None):
24782478
24792479
24802480
"""
2481-
if x.size == 1 or y.size == 1:
2481+
if in_dynamic_mode() and (x.size == 1 or y.size == 1):
24822482
return multiply(x, y)
24832483
else:
24842484
xshape = x.shape
@@ -2488,8 +2488,10 @@ def inner(x, y, name=None):
24882488
nx = x.reshape((-1, xshape[-1]))
24892489
ny = y.reshape((-1, yshape[-1]))
24902490

2491-
if in_dynamic_mode():
2492-
return _C_ops.matmul(nx, ny.T, False, False).reshape(dstshape)
2491+
if in_dynamic_or_pir_mode():
2492+
return _C_ops.matmul(
2493+
nx, paddle.transpose(ny, [1, 0]), False, False
2494+
).reshape(dstshape)
24932495
else:
24942496

24952497
def __check_input(x, y):
@@ -2513,7 +2515,6 @@ def __check_input(x, y):
25132515
)
25142516

25152517
__check_input(nx, ny)
2516-
25172518
helper = LayerHelper('inner', **locals())
25182519
out = helper.create_variable_for_type_inference(dtype=nx.dtype)
25192520
helper.append_op(

test/legacy_test/test_inner.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
import numpy as np
1818

1919
import paddle
20-
from paddle.static import Program, program_guard
20+
from paddle.pir_utils import test_with_pir_api
2121

2222

2323
class TestMultiplyApi(unittest.TestCase):
2424
def _run_static_graph_case(self, x_data, y_data):
25-
with program_guard(Program(), Program()):
25+
with paddle.static.program_guard(
26+
paddle.static.Program(), paddle.static.Program()
27+
):
2628
paddle.enable_static()
2729
x = paddle.static.data(
2830
name='x', shape=x_data.shape, dtype=x_data.dtype
@@ -53,45 +55,52 @@ def _run_dynamic_graph_case(self, x_data, y_data):
5355
res = paddle.inner(x, y)
5456
return res.numpy()
5557

56-
def test_multiply(self):
57-
np.random.seed(7)
58-
58+
@test_with_pir_api
59+
def test_multiply_static_case1(self):
5960
# test static computation graph: 3-d array
6061
x_data = np.random.rand(2, 10, 10).astype(np.float64)
6162
y_data = np.random.rand(2, 5, 10).astype(np.float64)
6263
res = self._run_static_graph_case(x_data, y_data)
6364
np.testing.assert_allclose(res, np.inner(x_data, y_data), rtol=1e-05)
6465

66+
@test_with_pir_api
67+
def test_multiply_static_case2(self):
6568
# test static computation graph: 2-d array
6669
x_data = np.random.rand(200, 5).astype(np.float64)
6770
y_data = np.random.rand(50, 5).astype(np.float64)
6871
res = self._run_static_graph_case(x_data, y_data)
6972
np.testing.assert_allclose(res, np.inner(x_data, y_data), rtol=1e-05)
7073

74+
@test_with_pir_api
75+
def test_multiply_static_case3(self):
7176
# test static computation graph: 1-d array
7277
x_data = np.random.rand(50).astype(np.float64)
7378
y_data = np.random.rand(50).astype(np.float64)
7479
res = self._run_static_graph_case(x_data, y_data)
7580
np.testing.assert_allclose(res, np.inner(x_data, y_data), rtol=1e-05)
7681

82+
def test_multiply_dynamic_case1(self):
7783
# test dynamic computation graph: 3-d array
7884
x_data = np.random.rand(5, 10, 10).astype(np.float64)
7985
y_data = np.random.rand(2, 10).astype(np.float64)
8086
res = self._run_dynamic_graph_case(x_data, y_data)
8187
np.testing.assert_allclose(res, np.inner(x_data, y_data), rtol=1e-05)
8288

89+
def test_multiply_dynamic_case2(self):
8390
# test dynamic computation graph: 2-d array
8491
x_data = np.random.rand(20, 50).astype(np.float64)
8592
y_data = np.random.rand(50).astype(np.float64)
8693
res = self._run_dynamic_graph_case(x_data, y_data)
8794
np.testing.assert_allclose(res, np.inner(x_data, y_data), rtol=1e-05)
8895

96+
def test_multiply_dynamic_case3(self):
8997
# test dynamic computation graph: Scalar
9098
x_data = np.random.rand(20, 10).astype(np.float32)
9199
y_data = np.random.rand(1).astype(np.float32).item()
92100
res = self._run_dynamic_graph_case(x_data, y_data)
93101
np.testing.assert_allclose(res, np.inner(x_data, y_data), rtol=1e-05)
94102

103+
def test_multiply_dynamic_case4(self):
95104
# test dynamic computation graph: 2-d array Complex
96105
x_data = np.random.rand(20, 50).astype(
97106
np.float64
@@ -102,6 +111,7 @@ def test_multiply(self):
102111
res = self._run_dynamic_graph_case(x_data, y_data)
103112
np.testing.assert_allclose(res, np.inner(x_data, y_data), rtol=1e-05)
104113

114+
def test_multiply_dynamic_case5(self):
105115
# test dynamic computation graph: 3-d array Complex
106116
x_data = np.random.rand(5, 10, 10).astype(
107117
np.float64
@@ -114,41 +124,49 @@ def test_multiply(self):
114124

115125

116126
class TestMultiplyError(unittest.TestCase):
117-
def test_errors(self):
127+
def test_errors_static_case1(self):
118128
# test static computation graph: dtype can not be int8
119129
paddle.enable_static()
120-
with program_guard(Program(), Program()):
130+
with paddle.static.program_guard(
131+
paddle.static.Program(), paddle.static.Program()
132+
):
121133
x = paddle.static.data(name='x', shape=[100], dtype=np.int8)
122134
y = paddle.static.data(name='y', shape=[100], dtype=np.int8)
123135
self.assertRaises(TypeError, paddle.inner, x, y)
124136

137+
def test_errors_static_case2(self):
125138
# test static computation graph: inputs must be broadcastable
126-
with program_guard(Program(), Program()):
139+
paddle.enable_static()
140+
with paddle.static.program_guard(
141+
paddle.static.Program(), paddle.static.Program()
142+
):
127143
x = paddle.static.data(name='x', shape=[20, 50], dtype=np.float64)
128144
y = paddle.static.data(name='y', shape=[20], dtype=np.float64)
129145
self.assertRaises(ValueError, paddle.inner, x, y)
130146

131-
np.random.seed(7)
132-
147+
def test_errors_dynamic_case1(self):
133148
# test dynamic computation graph: inputs must be broadcastable
134149
x_data = np.random.rand(20, 5)
135150
y_data = np.random.rand(10, 2)
136151
x = paddle.to_tensor(x_data)
137152
y = paddle.to_tensor(y_data)
138153
self.assertRaises(ValueError, paddle.inner, x, y)
139154

155+
def test_errors_dynamic_case2(self):
140156
# test dynamic computation graph: dtype must be Tensor type
141157
x_data = np.random.randn(200).astype(np.float64)
142158
y_data = np.random.randn(200).astype(np.float64)
143159
y = paddle.to_tensor(y_data)
144160
self.assertRaises(TypeError, paddle.inner, x_data, y)
145161

162+
def test_errors_dynamic_case3(self):
146163
# test dynamic computation graph: dtype must be Tensor type
147164
x_data = np.random.randn(200).astype(np.float64)
148165
y_data = np.random.randn(200).astype(np.float64)
149166
x = paddle.to_tensor(x_data)
150167
self.assertRaises(TypeError, paddle.inner, x, y_data)
151168

169+
def test_errors_dynamic_case4(self):
152170
# test dynamic computation graph: dtype must be Tensor type
153171
x_data = np.random.randn(200).astype(np.float32)
154172
y_data = np.random.randn(200).astype(np.float32)

0 commit comments

Comments
 (0)