Skip to content

Commit ef06488

Browse files
authored
[API compatibility] add scatter_add_ api (#74632)
* add scatter_add inplace api * change position
1 parent dab96d2 commit ef06488

File tree

4 files changed

+204
-0
lines changed

4 files changed

+204
-0
lines changed

python/paddle/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@
369369
scatter,
370370
scatter_,
371371
scatter_add,
372+
scatter_add_,
372373
scatter_nd,
373374
scatter_nd_add,
374375
scatter_reduce,
@@ -1272,6 +1273,7 @@ def __dir__(self):
12721273
'multigammaln_',
12731274
'nan_to_num',
12741275
'nan_to_num_',
1276+
'scatter_add_',
12751277
'heaviside',
12761278
'tril_indices',
12771279
'index_add',

python/paddle/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@
206206
scatter,
207207
scatter_,
208208
scatter_add,
209+
scatter_add_,
209210
scatter_nd,
210211
scatter_nd_add,
211212
scatter_reduce,
@@ -830,6 +831,7 @@
830831
'bernoulli_',
831832
'exponential_',
832833
'heaviside',
834+
'scatter_add_',
833835
'index_add',
834836
"index_add_",
835837
'index_put',

python/paddle/tensor/manipulation.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7346,6 +7346,22 @@ def put_along_axis_(
73467346
)
73477347

73487348

7349+
def scatter_add_(
7350+
input: Tensor,
7351+
dim: int,
7352+
index: Tensor,
7353+
src: Tensor,
7354+
) -> Tensor:
7355+
"""
7356+
Inplace version of ``scatter_add`` API, the output Tensor will be inplaced with input ``input``.
7357+
Please refer to :ref:`api_paddle_scatter_add`.
7358+
"""
7359+
7360+
return put_along_axis_(
7361+
input, index, src, dim, 'add', include_self=True, broadcast=False
7362+
)
7363+
7364+
73497365
def index_add(
73507366
x: Tensor, index: Tensor, axis: int, value: Tensor, name: str | None = None
73517367
) -> Tensor:
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import copy
16+
import unittest
17+
18+
import numpy as np
19+
from op_test import get_places
20+
21+
import paddle
22+
from paddle.framework import core
23+
24+
25+
class TestScatterAddInplaceAPI(unittest.TestCase):
26+
def setUp(self):
27+
np.random.seed(0)
28+
self.shape = [10, 10]
29+
self.index_shape = [10, 10]
30+
self.index_np = np.random.randint(0, 10, (10, 10)).astype('int64')
31+
self.x_np = np.random.random(self.shape).astype(np.float32)
32+
self.place = get_places()
33+
self.axis = 0
34+
self.value_np = np.random.randint(0, 10, (10, 10)).astype(np.float32)
35+
self.value_shape = [10, 10]
36+
37+
def test_inplace_dygraph(self):
38+
def run(place):
39+
paddle.disable_static(place)
40+
x_tensor = paddle.to_tensor(self.x_np)
41+
index_tensor = paddle.to_tensor(self.index_np)
42+
value_tensor = paddle.to_tensor(self.value_np)
43+
44+
x_tensor.scatter_add_(self.axis, index_tensor, value_tensor)
45+
46+
out_ref = copy.deepcopy(self.x_np)
47+
for i in range(10):
48+
for j in range(10):
49+
out_ref[self.index_np[i, j], j] += self.value_np[i, j]
50+
51+
np.testing.assert_allclose(x_tensor.numpy(), out_ref, rtol=0.001)
52+
53+
paddle.enable_static()
54+
55+
for place in self.place:
56+
run(place)
57+
58+
59+
@unittest.skipIf(
60+
not core.is_compiled_with_cuda(),
61+
"core is not compiled with CUDA",
62+
)
63+
class TestScatterAddInplaceAPILargeCase(unittest.TestCase):
64+
def setUp(self):
65+
np.random.seed(0)
66+
self.shape = [64, 102400]
67+
self.index_shape = [64, 102400]
68+
self.index_np = np.random.randint(0, 64, (64, 102400)).astype('int64')
69+
self.x_np = np.random.random(self.shape).astype(np.float32)
70+
self.axis = 1
71+
self.value_np = np.random.randint(0, 50, (64, 102400)).astype(
72+
np.float32
73+
)
74+
self.place = [paddle.CUDAPlace(0)]
75+
76+
def test_inplace_dygraph(self):
77+
def run(place):
78+
paddle.disable_static(place)
79+
x_tensor = paddle.to_tensor(self.x_np)
80+
index_tensor = paddle.to_tensor(self.index_np)
81+
value_tensor = paddle.to_tensor(self.value_np)
82+
83+
x_tensor.scatter_add_(self.axis, index_tensor, value_tensor)
84+
85+
out_ref = copy.deepcopy(self.x_np)
86+
for i in range(64):
87+
for j in range(102400):
88+
out_ref[i, self.index_np[i, j]] += self.value_np[i, j]
89+
90+
np.testing.assert_allclose(x_tensor.numpy(), out_ref, rtol=0.001)
91+
92+
paddle.enable_static()
93+
94+
for place in self.place:
95+
run(place)
96+
97+
98+
class TestScatterAddInplaceAPIOtherCase(unittest.TestCase):
99+
def setUp(self):
100+
np.random.seed(0)
101+
self.shape = [3, 5]
102+
self.index1_shape = [1, 4]
103+
self.index_np1 = np.array([[0, 1, 2, 0]]).astype('int64')
104+
self.index2_shape = [2, 3]
105+
self.index_np2 = np.array([[0, 1, 2], [0, 1, 4]]).astype('int64')
106+
self.x_np = np.zeros((3, 5)).astype(np.float32)
107+
self.value_shape = [2, 5]
108+
self.value = (
109+
np.arange(1, 11).reshape(self.value_shape).astype(np.float32)
110+
)
111+
self.place = get_places()
112+
113+
def test_api_dygraph(self):
114+
def run_inplace(place):
115+
paddle.disable_static(place)
116+
out1 = paddle.to_tensor(self.x_np)
117+
index_tensor1 = paddle.to_tensor(self.index_np1)
118+
value_tensor = paddle.to_tensor(self.value)
119+
out1.scatter_add_(0, index_tensor1, value_tensor)
120+
out_ref = copy.deepcopy(self.x_np)
121+
for i in range(self.index1_shape[0]):
122+
for j in range(self.index1_shape[1]):
123+
out_ref[self.index_np1[i, j], j] += self.value[i, j]
124+
np.testing.assert_allclose(out1.numpy(), out_ref, rtol=0.001)
125+
126+
index_tensor2 = paddle.to_tensor(self.index_np2)
127+
out2 = paddle.to_tensor(self.x_np)
128+
out2.scatter_add_(1, index_tensor2, value_tensor)
129+
out_ref = copy.deepcopy(self.x_np)
130+
for i in range(self.index2_shape[0]):
131+
for j in range(self.index2_shape[1]):
132+
out_ref[i, self.index_np2[i, j]] += self.value[i, j]
133+
np.testing.assert_allclose(out2.numpy(), out_ref, rtol=0.001)
134+
135+
paddle.enable_static()
136+
137+
for place in self.place:
138+
run_inplace(place)
139+
140+
def test_error(self):
141+
tensorx = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]).astype("float32")
142+
indices = paddle.to_tensor([[1, 0, 1], [0, 1, 1]]).astype("int32")
143+
values = paddle.to_tensor([1])
144+
145+
try:
146+
tensorx.scatter_add_(0, indices, values)
147+
except Exception as error:
148+
self.assertIsInstance(error, ValueError)
149+
150+
indices = paddle.to_tensor([1]).astype("int32")
151+
values = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
152+
153+
try:
154+
tensorx.scatter_add_(0, indices, values)
155+
except Exception as error:
156+
self.assertIsInstance(error, ValueError)
157+
158+
indices = paddle.to_tensor(
159+
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
160+
).astype("int32")
161+
# indices too large
162+
try:
163+
tensorx.scatter_add_(0, indices, values)
164+
except Exception as error:
165+
self.assertIsInstance(error, RuntimeError)
166+
167+
indices = paddle.to_tensor([[3, 0, 4], [0, 5, 10]]).astype("int32")
168+
# the element of indices out of range
169+
try:
170+
tensorx.scatter_add_(0, indices, values)
171+
except Exception as error:
172+
self.assertIsInstance(error, RuntimeError)
173+
174+
def test_index_type_error(self):
175+
tensorx = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]).astype("float32")
176+
indices = paddle.to_tensor([[1, 0, 1], [0, 1, 1]]).astype("float32")
177+
values = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
178+
with self.assertRaises(TypeError):
179+
tensorx.scatter_add_(0, indices, values)
180+
181+
182+
if __name__ == "__main__":
183+
paddle.enable_static()
184+
unittest.main()

0 commit comments

Comments
 (0)