Skip to content

Commit f15a1ca

Browse files
committed
[CINN] Add unittest for gather_nd op
1 parent 8b79000 commit f15a1ca

File tree

1 file changed

+126
-0
lines changed

1 file changed

+126
-0
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy
18+
import utils
19+
20+
import paddle
21+
22+
23+
class TestGatherNd(unittest.TestCase):
    # Note that GatherNd is also used in index_put, so we can test it by
    # using index_put.

    def eval(self, dy_compute, inputs, input_spec=None):
        """Run ``dy_compute`` eagerly and under to_static (with CINN), then
        check that the two results match elementwise.

        Args:
            dy_compute: callable executed in dynamic (eager) mode.
            inputs: list of positional tensor arguments for ``dy_compute``.
            input_spec: optional list of ``paddle.static.InputSpec`` for the
                static conversion; ``None`` means let paddle infer the specs.
        """
        dy_out = dy_compute(*inputs)

        static_compute = utils.apply_to_static(
            # BUGFIX: forward the caller-supplied input_spec instead of
            # hard-coding None — previously the specs built by
            # get_input_spec() were silently ignored, so the dynamic-shape
            # path was never exercised.
            dy_compute,
            use_cinn=True,
            input_spec=input_spec,
        )
        st_out = static_compute(*inputs)

        for a, b in zip(
            paddle.utils.flatten(dy_out), paddle.utils.flatten(st_out)
        ):
            numpy.testing.assert_allclose(a, b, atol=1e-6, rtol=1e-6)

    @staticmethod
    def get_input(x_shape, indices_shape, value_shape, has_negative_index=True):
        """Build ``(x, indices, value, dout)`` tensors for an index_put test.

        ``indices`` is a tuple holding one int64 index tensor (of length
        ``indices_shape[0]``) per indexed leading dimension of ``x``.  When
        ``has_negative_index`` is True the indices may be negative
        (python-style wrap-around indexing).
        """
        n_indices = indices_shape[0]
        # A 1-D indices_shape means a single indexed dimension.
        index_dim_size = indices_shape[1] if len(indices_shape) > 1 else 1

        x_pd = paddle.randn(x_shape)
        x_pd.stop_gradient = False

        indices_pd = tuple(
            paddle.randint(
                # Lower bound -x_shape[i] permits wrap-around indices.
                -x_shape[i] if has_negative_index else 0,
                x_shape[i],
                [n_indices],
            )
            # index_dim_size is already >= 1 by construction above.
            for i in range(index_dim_size)
        )
        value_pd = paddle.randn(value_shape)
        value_pd.stop_gradient = False

        # Upstream gradient w.r.t. the output; index_put output has x's shape.
        dout_pd = paddle.randn(x_shape)
        dout_pd.stop_gradient = False
        return x_pd, indices_pd, value_pd, dout_pd

    @staticmethod
    def get_input_spec(indice_dim, x_ndim=2, value_ndim=2):
        """Build fully-dynamic InputSpecs for ``index_put_grad``'s arguments.

        Args:
            indice_dim: number of index tensors in the indices tuple.
            x_ndim: rank of ``x`` (and of ``dout``, which shares its shape).
            value_ndim: rank of ``value``.

        The ndim parameters default to 2 for backward compatibility; callers
        with non-2-D tensors must pass the real ranks so the specs match the
        inputs fed to the static function.
        """
        return [
            paddle.static.InputSpec(shape=[-1] * x_ndim, dtype="float32"),
            tuple(
                paddle.static.InputSpec(shape=[-1], dtype="int64")
                for _ in range(indice_dim)
            ),
            paddle.static.InputSpec(shape=[-1] * value_ndim, dtype="float32"),
            # dout has the same shape (hence rank) as x.
            paddle.static.InputSpec(shape=[-1] * x_ndim, dtype="float32"),
        ]

    @staticmethod
    def index_put_grad(x, indices, v, dy):
        """index_put with accumulate=True, returning grads w.r.t. x and v."""
        y = paddle.index_put(x, indices, v, True)
        return paddle.grad(y, [x, v], dy)

    def test_index_put_grad_non_negative_index(self):
        x_pd, indices_pd, value_pd, dout_pd = self.get_input(
            [12, 13, 14], [88, 2], [88, 14], False
        )

        self.eval(
            TestGatherNd.index_put_grad,
            [x_pd, indices_pd, value_pd, dout_pd],
            # x/dout are 3-D here, value is 2-D.
            input_spec=self.get_input_spec(2, x_ndim=3, value_ndim=2),
        )

    def test_index_put_grad_negative_index_1(self):
        x_pd, indices_pd, value_pd, dout_pd = self.get_input(
            [12, 13, 14], [88, 1], [88, 13, 14]
        )

        self.eval(
            TestGatherNd.index_put_grad,
            [x_pd, indices_pd, value_pd, dout_pd],
            # Single indexed dimension; x/dout and value are all 3-D.
            input_spec=self.get_input_spec(1, x_ndim=3, value_ndim=3),
        )

    def test_index_put_grad_negative_index_2(self):
        x_pd, indices_pd, value_pd, dout_pd = self.get_input(
            [16, 16], [20, 2], [20]
        )

        self.eval(
            TestGatherNd.index_put_grad,
            [x_pd, indices_pd, value_pd, dout_pd],
            # Fully-indexed 2-D x, so value is a 1-D vector of updates.
            input_spec=self.get_input_spec(2, x_ndim=2, value_ndim=1),
        )

    def test_gather_nd_fusion(self):
        # gather_nd fused with an elementwise producer (x * y) and an
        # elementwise consumer (+ z); indices may be negative.
        x_pd = paddle.randn([256, 128])
        y_pd = paddle.randn_like(x_pd)
        z_pd = paddle.randn([100])
        indices_pd = paddle.randint(-128, 128, [100, 2])

        def func(x, y, z, indices):
            return paddle.gather_nd(x * y, indices) + z

        self.eval(func, [x_pd, y_pd, z_pd, indices_pd])
124+
125+
if __name__ == "__main__":
    # Run the whole suite when the file is executed directly as a script.
    unittest.main()

0 commit comments

Comments
 (0)