Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2332,12 +2332,14 @@ void ScatterInferMeta(const MetaTensor& x,
const auto& index_dims = index.dims();

if (index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(index_dims[1],
1,
common::errors::InvalidArgument(
"The last dim of the index should be 1 when the "
"index is a 2D tensor, but we get %d.",
index_dims[1]));
if (index_dims[1] != 0) {
PADDLE_ENFORCE_EQ(index_dims[1],
1,
common::errors::InvalidArgument(
"The last dim of the index should be 1 when the "
"index is a 2D tensor, but we get %d.",
index_dims[1]));
}
} else {
PADDLE_ENFORCE_EQ(index_dims.size() == 1 || index_dims.size() == 0,
true,
Expand Down
15 changes: 14 additions & 1 deletion paddle/phi/kernels/cpu/scatter_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/phi/kernels/funcs/scatter.h"

namespace phi {

template <typename T, typename Context>
Expand All @@ -30,6 +30,19 @@ void ScatterGradKernel(const Context &dev_ctx,
bool overwrite UNUSED,
DenseTensor *x_grad,
DenseTensor *updates_grad) {
if (out_grad.numel() == 0) {
if (x_grad) {
dev_ctx.template Alloc<T>(x_grad);
}
if (updates_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(updates_grad->dims())),
0,
updates_grad);
}
return;
}
const auto &index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/kernels/cpu/scatter_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ void ScatterKernel(const Context &dev_ctx,
const DenseTensor &updates,
bool overwrite,
DenseTensor *out) {
if (index.numel() == 0) {
dev_ctx.template Alloc<T>(out);
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
return;
}
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
// In place output: Out = X, Out[Ids] = Updates
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
// Apply ScatterUpdate: Out[index] = Updates[:]
Expand Down
15 changes: 14 additions & 1 deletion paddle/phi/kernels/cpu/scatter_nd_add_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/gather.h"

namespace phi {

template <typename T, typename Context>
Expand All @@ -28,6 +28,19 @@ void ScatterNdAddGradKernel(const Context &dev_ctx,
const DenseTensor &out_grad,
DenseTensor *x_grad,
DenseTensor *updates_grad) {
if (out_grad.numel() == 0) {
if (x_grad) {
dev_ctx.template Alloc<T>(x_grad);
}
if (updates_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(updates_grad->dims())),
0,
updates_grad);
}
return;
}
if (x_grad) {
Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
}
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/cpu/scatter_nd_add_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ void ScatterNdAddKernel(const Context &dev_ctx,
const DenseTensor &index,
const DenseTensor &updates,
DenseTensor *out) {
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
// In place output: Out = X
Copy(dev_ctx, x, dev_ctx.GetPlace(), true, out);
const auto &index_type = index.dtype();
Expand Down
14 changes: 14 additions & 0 deletions paddle/phi/kernels/gpu/scatter_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"

Expand All @@ -31,6 +32,19 @@ void ScatterGradKernel(const Context &dev_ctx,
bool overwrite,
DenseTensor *x_grad,
DenseTensor *updates_grad) {
if (out_grad.numel() == 0) {
if (x_grad) {
dev_ctx.template Alloc<T>(x_grad);
}
if (updates_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(updates_grad->dims())),
0,
updates_grad);
}
return;
}
auto index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/kernels/gpu/scatter_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ void ScatterKernel(const Context &dev_ctx,
const DenseTensor &updates,
bool overwrite,
DenseTensor *out) {
if (index.numel() == 0) {
dev_ctx.template Alloc<T>(out);
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
return;
}
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
// use template class to support int32_t and int64_t
auto index_type = index.dtype();
Expand Down
15 changes: 14 additions & 1 deletion paddle/phi/kernels/gpu/scatter_nd_add_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"

namespace phi {

template <typename T, typename Context>
Expand All @@ -29,6 +29,19 @@ void ScatterNdAddGradKernel(const Context &dev_ctx,
const DenseTensor &out_grad,
DenseTensor *x_grad,
DenseTensor *updates_grad) {
if (out_grad.numel() == 0) {
if (x_grad) {
dev_ctx.template Alloc<T>(x_grad);
}
if (updates_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(updates_grad->dims())),
0,
updates_grad);
}
return;
}
if (x_grad) {
Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
}
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ void ScatterNdAddKernel(const Context &dev_ctx,
const DenseTensor &index,
const DenseTensor &updates,
DenseTensor *out) {
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
const auto &index_type = index.dtype();
bool index_type_match =
Expand Down
14 changes: 14 additions & 0 deletions paddle/phi/kernels/xpu/scatter_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"

namespace phi {

Expand All @@ -27,6 +28,19 @@ void ScatterGradKernel(const Context &dev_ctx,
bool overwrite,
DenseTensor *x_grad,
DenseTensor *updates_grad) {
if (out_grad.numel() == 0) {
if (x_grad) {
dev_ctx.template Alloc<T>(x_grad);
}
if (updates_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(updates_grad->dims())),
0,
updates_grad);
}
return;
}
using XPUType = typename XPUTypeTrait<T>::Type;

const auto &index_type = index.dtype();
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/kernels/xpu/scatter_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ void ScatterKernel(const Context &dev_ctx,
const DenseTensor &updates,
bool overwrite,
DenseTensor *out) {
if (index.numel() == 0) {
dev_ctx.template Alloc<T>(out);
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
return;
}
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
using XPUTypeT = typename XPUTypeTrait<T>::Type;
out->Resize(x.dims());
auto *x_data = reinterpret_cast<const XPUTypeT *>(x.data<T>());
Expand Down
14 changes: 14 additions & 0 deletions paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"

namespace phi {
template <typename T, typename Context>
Expand All @@ -25,6 +26,19 @@ void ScatterNdAddGradKernel(const Context &dev_ctx,
const DenseTensor &out_grad,
DenseTensor *x_grad,
DenseTensor *updates_grad) {
if (out_grad.numel() == 0) {
if (x_grad) {
dev_ctx.template Alloc<T>(x_grad);
}
if (updates_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(updates_grad->dims())),
0,
updates_grad);
}
return;
}
using XPUType = typename XPUTypeTrait<T>::Type;
int ret = 0;
const T *out_grad_data = out_grad.data<T>();
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ void ScatterNdAddKernel(const Context &dev_ctx,
const DenseTensor &index,
const DenseTensor &updates,
DenseTensor *out) {
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
const T *x_ptr = x.data<T>();
const T *updates_ptr = updates.data<T>();

Expand Down
3 changes: 3 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4374,6 +4374,9 @@ def scatter_nd_add(
f"x and updates must have same data type but x.dtype={convert_dtype(x.dtype)}, updates.dtype={convert_dtype(updates.dtype)}"
)

if in_dynamic_mode():
if index.size == 0:
return x.clone() + updates
if in_dynamic_or_pir_mode():
return _C_ops.scatter_nd_add(x, index, updates)
else:
Expand Down
68 changes: 63 additions & 5 deletions test/legacy_test/test_scatter_nd_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import unittest

import numpy as np
from op_test import OpTest, convert_float_to_uint16
from op_test import OpTest, convert_float_to_uint16, get_places
from utils import static_guard

import paddle
Expand Down Expand Up @@ -183,12 +183,12 @@ def setUp(self):
def _set_dtype(self):
self.dtype = np.float64

def test_check_output(self):
def _test_check_output(self):
self.check_output(
check_cinn=True, check_pir=True, check_symbol_infer=False
)

def test_check_grad(self):
def _test_check_grad(self):
self.check_grad(
['X', 'Updates'],
'Out',
Expand Down Expand Up @@ -220,12 +220,12 @@ class TestScatterNdAddWithEmptyIndexBF16(TestScatterNdAddWithEmptyIndex):
def _set_dtype(self):
self.dtype = np.uint16

def test_check_output(self):
def _test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, check_pir=True)

def test_check_grad(self):
def _test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
Expand Down Expand Up @@ -581,6 +581,64 @@ def test_dygraph_1(self):
output = paddle.scatter_nd_add(x, index, updates)


class TestScatterNd_ZeroSize(unittest.TestCase):
    """Dygraph scatter_nd with a 0-size index: the output equals `updates`
    and the gradient w.r.t. `updates` is all ones."""

    def test_dygraph(self):
        for place in get_places():
            with base.dygraph.guard(place):
                # A (0, 1) index tensor selects no positions at all.
                idx = paddle.to_tensor(np.random.random([0, 1]))
                idx.stop_gradient = False
                ups = paddle.rand(shape=[4], dtype='float32')
                ups.stop_gradient = False
                out = paddle.scatter_nd(idx, ups, [4])
                np.testing.assert_allclose(out.numpy(), ups.numpy())
                out.sum().backward()
                np.testing.assert_allclose(ups.grad.numpy(), np.ones([4]))


class TestScatterNdAdd_ZeroSize(unittest.TestCase):
    """Dygraph scatter_nd_add where `x` is 0-size: the op leaves the empty
    `x` unchanged and every gradient comes back as zeros."""

    def test_dygraph(self):
        for place in get_places():
            with base.dygraph.guard(place):
                # Zero-size destination tensor (leading dim is 0).
                src = paddle.randn([0, 2, 3])
                src.stop_gradient = False
                ups = paddle.rand(shape=[2], dtype='float32')
                ups.stop_gradient = False
                idx = paddle.to_tensor(np.random.random([2, 3]))
                res = paddle.scatter_nd_add(src, idx, ups)
                np.testing.assert_allclose(res.numpy(), src.numpy())
                res.sum().backward()
                np.testing.assert_allclose(
                    src.grad.numpy(), np.zeros(src.shape)
                )
                np.testing.assert_allclose(
                    ups.grad.numpy(), np.zeros(ups.shape)
                )


class TestScatterNdAdd_ZeroSize2(unittest.TestCase):
    """Dygraph scatter_nd_add where `index` is 0-size: the result is
    x + updates and both gradients are all ones."""

    def test_dygraph(self):
        for place in get_places():
            with base.dygraph.guard(place):
                src = paddle.randn([1, 2])
                src.stop_gradient = False
                ups = paddle.rand(shape=[1, 2], dtype='float32')
                ups.stop_gradient = False
                # Empty (0, 3) index — no scatter positions to apply.
                idx = paddle.to_tensor(np.random.random([0, 3]))
                res = paddle.scatter_nd_add(src, idx, ups)
                np.testing.assert_allclose(
                    res.numpy(), (src + ups).numpy()
                )
                res.sum().backward()
                np.testing.assert_allclose(
                    src.grad.numpy(), np.ones(src.shape)
                )
                np.testing.assert_allclose(
                    ups.grad.numpy(), np.ones(ups.shape)
                )


if __name__ == "__main__":
    # Switch to static-graph mode before running the suite; the dygraph
    # test cases above re-enter dygraph explicitly via base.dygraph.guard.
    paddle.enable_static()
    unittest.main()
Loading
Loading