diff --git a/paddle/phi/kernels/cpu/take_along_axis_kernel.cc b/paddle/phi/kernels/cpu/take_along_axis_kernel.cc index 2621fe9dad0e7d..45af898c38d809 100644 --- a/paddle/phi/kernels/cpu/take_along_axis_kernel.cc +++ b/paddle/phi/kernels/cpu/take_along_axis_kernel.cc @@ -18,6 +18,7 @@ #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/gather_scatter_functor.h" namespace phi { @@ -28,6 +29,16 @@ void TakeAlongAxisKernel(const Context& dev_ctx, const DenseTensor& index, int axis, DenseTensor* out) { + if (index.numel() == 0) { + dev_ctx.template Alloc(out); + return; + } + if (x.numel() == 0) { + phi::Full( + dev_ctx, common::vectorize(out->dims()), static_cast(0), out); + return; + } + out->Resize(index.dims()); dev_ctx.template Alloc(out); diff --git a/paddle/phi/kernels/gpu/take_along_axis_kernel.cu b/paddle/phi/kernels/gpu/take_along_axis_kernel.cu index d4bbeb6320abe3..b2d123bbe26576 100644 --- a/paddle/phi/kernels/gpu/take_along_axis_kernel.cu +++ b/paddle/phi/kernels/gpu/take_along_axis_kernel.cu @@ -18,6 +18,7 @@ #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/utils/data_type.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/gather_scatter_functor.h" namespace phi { @@ -28,6 +29,16 @@ void TakeAlongAxisKernel(const Context& dev_ctx, const DenseTensor& index, int axis, DenseTensor* out) { + if (index.numel() == 0) { + dev_ctx.template Alloc(out); + return; + } + if (x.numel() == 0) { + phi::Full( + dev_ctx, common::vectorize(out->dims()), static_cast(0), out); + return; + } + out->Resize(index.dims()); dev_ctx.template Alloc(out); diff --git a/paddle/phi/kernels/xpu/take_along_axis_kernel.cc b/paddle/phi/kernels/xpu/take_along_axis_kernel.cc index 31dcfb16d179f1..6ab824e644d07c 100644 --- a/paddle/phi/kernels/xpu/take_along_axis_kernel.cc +++ b/paddle/phi/kernels/xpu/take_along_axis_kernel.cc @@ -19,6 +19,7 @@ #include "paddle/common/layout.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" namespace phi { @@ -28,6 +29,16 @@ void TakeAlongAxisKernel(const Context& dev_ctx, const DenseTensor& index, int axis, DenseTensor* out) { + if (index.numel() == 0) { + dev_ctx.template Alloc(out); + return; + } + if (x.numel() == 0) { + phi::Full( + dev_ctx, common::vectorize(out->dims()), static_cast(0), out); + return; + } + out->Resize(index.dims()); dev_ctx.template Alloc(out); diff --git a/test/legacy_test/test_take_along_axis_op.py b/test/legacy_test/test_take_along_axis_op.py index fba92803d382ed..f14649e497cbd6 100644 --- a/test/legacy_test/test_take_along_axis_op.py +++ b/test/legacy_test/test_take_along_axis_op.py @@ -26,6 +26,58 @@ paddle.enable_static() +class TestTakeAlongAxis0Size(OpTest): + def setUp(self): + self.python_api = paddle.take_along_axis + self.op_type = "take_along_axis" + self.dtype = "float64" + self.check_pir = True + + x = np.zeros((2, 0, 5)).astype(self.dtype) + indices = np.zeros((2, 3, 5)).astype("int64") + + self.inputs = {'Input': x, 'Index': indices} + self.attrs = {'Axis': 1} + + output = np.zeros((2, 3, 5)).astype(self.dtype) + self.outputs = {'Result': output} + + def test_check_output(self): + self.check_output(check_pir=self.check_pir) + + def test_check_grad(self): + self.check_grad(['Input'], 'Result', check_pir=self.check_pir) + + +class TestTakeAlongAxis0Size2(OpTest): + def setUp(self): + self.python_api = paddle.take_along_axis + self.op_type = "take_along_axis" + self.dtype = "float64" + self.check_pir = True + + x = np.random.rand(2, 3, 5).astype(self.dtype) + indices = np.zeros((2, 0, 5)).astype("int64") + + self.inputs = {'Input': x, 'Index': indices} + self.attrs = {'Axis': 1} + + output = np.zeros((2, 0, 5)).astype(self.dtype) + self.outputs = {'Result': output} + + def test_check_output(self): + self.check_output(check_pir=self.check_pir) + + def test_check_grad(self): + self.grad = np.zeros_like(self.outputs['Result']).astype(self.dtype) + self.check_grad( + ['Input'], + 'Result', + user_defined_grads=[self.grad], + check_pir=self.check_pir, + ) + + class TestTakeAlongAxisOp(OpTest): def setUp(self): self.init_data()