Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@
scatter_,
scatter_nd,
scatter_nd_add,
scatter_reduce,
select_scatter,
shard_index,
slice,
Expand Down Expand Up @@ -1171,6 +1172,7 @@
'renorm',
'renorm_',
'take_along_axis',
'scatter_reduce',
'put_along_axis',
'select_scatter',
'multigammaln',
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@
scatter_,
scatter_nd,
scatter_nd_add,
scatter_reduce,
select_scatter,
shard_index,
slice,
Expand Down Expand Up @@ -796,6 +797,7 @@
'moveaxis',
'repeat_interleave',
'take_along_axis',
'scatter_reduce',
'put_along_axis',
'select_scatter',
'put_along_axis_',
Expand Down
61 changes: 61 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6786,6 +6786,67 @@ def take_along_axis(
return result


def scatter_reduce(
input: Tensor,
dim: int,
index: Tensor,
src: Tensor,
reduce: Literal['sum', 'prod', 'mean', 'amin', 'amax'],
include_self: bool = True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch这个是指定关键字用法,我们也按这样放到*后面去

) -> Tensor:
"""
Scatter the values of the source tensor to the target tensor according to the given indices, and perform a reduction operation along the designated axis.

Args:
input (Tensor) : The Input Tensor. Supported data types are bfloat16, float16, float32, float64,
int32, int64, uint8.
dim (int) : The axis to scatter 1d slices along.
index (Tensor) : Indices to scatter along each 1d slice of input. This must match the dimension of input,
Supported data type are int32 and int64.
src (Tensor) : The value element(s) to scatter. The data types should be same as input.
reduce (str): The reduce operation, support 'sum', 'prod', 'mean', 'amin', 'amax'.
include_self (bool, optional): whether to reduce with the elements of input, default is 'True'.

Returns:
Tensor, The indexed element, same dtype with input

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([[10, 20, 30], [40, 50, 60]])
>>> indices = paddle.zeros((2,3)).astype("int32")
>>> values = paddle.to_tensor([[1, 2, 3],[4, 5, 6]]).astype(x.dtype)
>>> result = paddle.scatter_reduce(x, 0, indices, values, "sum", True)
>>> print(result)
Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
[[15, 27, 39],
[40, 50, 60]])

>>> result = paddle.scatter_reduce(x, 0, indices, values, "prod", True)
>>> print(result)
Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
[[40 , 200, 540],
[40 , 50 , 60 ]])

>>> result = paddle.scatter_reduce(x, 0, indices, values, "mean", True)
>>> print(result)
Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
[[5 , 9 , 13],
[40, 50, 60]])

"""

if reduce == 'sum':
reduce = 'add'
if reduce == 'prod':
reduce = 'multiply'
return put_along_axis(
input, index, src, dim, reduce, include_self, broadcast=False
)


def put_along_axis(
arr: Tensor,
indices: Tensor,
Expand Down
Loading
Loading