Skip to content

Commit 0eaff6f

Browse files
committed
Fix: support 0-size tensors in scatter / scatter_nd_add kernels (CPU/GPU/XPU) — early-return with zero-filled grads when out_grad is empty, and relax the ScatterInferMeta 2-D index check when index_dims[1] == 0
1 parent 29d0104 commit 0eaff6f

16 files changed

+251
-14
lines changed

paddle/phi/infermeta/ternary.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2332,12 +2332,14 @@ void ScatterInferMeta(const MetaTensor& x,
23322332
const auto& index_dims = index.dims();
23332333

23342334
if (index_dims.size() == 2) {
2335-
PADDLE_ENFORCE_EQ(index_dims[1],
2336-
1,
2337-
common::errors::InvalidArgument(
2338-
"The last dim of the index should be 1 when the "
2339-
"index is a 2D tensor, but we get %d.",
2340-
index_dims[1]));
2335+
if (index_dims[1] != 0) {
2336+
PADDLE_ENFORCE_EQ(index_dims[1],
2337+
1,
2338+
common::errors::InvalidArgument(
2339+
"The last dim of the index should be 1 when the "
2340+
"index is a 2D tensor, but we get %d.",
2341+
index_dims[1]));
2342+
}
23412343
} else {
23422344
PADDLE_ENFORCE_EQ(index_dims.size() == 1 || index_dims.size() == 0,
23432345
true,

paddle/phi/kernels/cpu/scatter_grad_kernel.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
#include "paddle/phi/backends/cpu/cpu_context.h"
1818
#include "paddle/phi/core/kernel_registry.h"
1919
#include "paddle/phi/core/tensor_utils.h"
20+
#include "paddle/phi/kernels/full_kernel.h"
2021
#include "paddle/phi/kernels/funcs/gather.h"
2122
#include "paddle/phi/kernels/funcs/scatter.h"
22-
2323
namespace phi {
2424

2525
template <typename T, typename Context>
@@ -30,6 +30,19 @@ void ScatterGradKernel(const Context &dev_ctx,
3030
bool overwrite UNUSED,
3131
DenseTensor *x_grad,
3232
DenseTensor *updates_grad) {
33+
if (out_grad.numel() == 0) {
34+
if (x_grad) {
35+
dev_ctx.template Alloc<T>(x_grad);
36+
}
37+
if (updates_grad) {
38+
phi::Full<T, Context>(
39+
dev_ctx,
40+
phi::IntArray(common::vectorize(updates_grad->dims())),
41+
0,
42+
updates_grad);
43+
}
44+
return;
45+
}
3346
const auto &index_type = index.dtype();
3447
bool index_type_match =
3548
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;

paddle/phi/kernels/cpu/scatter_kernel.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ void ScatterKernel(const Context &dev_ctx,
2828
const DenseTensor &updates,
2929
bool overwrite,
3030
DenseTensor *out) {
31+
if (index.numel() == 0) {
32+
dev_ctx.template Alloc<T>(out);
33+
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
34+
return;
35+
}
36+
if (out && out->numel() == 0) {
37+
dev_ctx.template Alloc<T>(out);
38+
return;
39+
}
3140
// In place output: Out = X, Out[Ids] = Updates
3241
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
3342
// Apply ScatterUpdate: Out[index] = Updates[:]

paddle/phi/kernels/cpu/scatter_nd_add_grad_kernel.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
#include "paddle/phi/backends/cpu/cpu_context.h"
1818
#include "paddle/phi/core/kernel_registry.h"
1919
#include "paddle/phi/core/tensor_utils.h"
20+
#include "paddle/phi/kernels/full_kernel.h"
2021
#include "paddle/phi/kernels/funcs/gather.h"
21-
2222
namespace phi {
2323

2424
template <typename T, typename Context>
@@ -28,6 +28,19 @@ void ScatterNdAddGradKernel(const Context &dev_ctx,
2828
const DenseTensor &out_grad,
2929
DenseTensor *x_grad,
3030
DenseTensor *updates_grad) {
31+
if (out_grad.numel() == 0) {
32+
if (x_grad) {
33+
dev_ctx.template Alloc<T>(x_grad);
34+
}
35+
if (updates_grad) {
36+
phi::Full<T, Context>(
37+
dev_ctx,
38+
phi::IntArray(common::vectorize(updates_grad->dims())),
39+
0,
40+
updates_grad);
41+
}
42+
return;
43+
}
3144
if (x_grad) {
3245
Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
3346
}

paddle/phi/kernels/cpu/scatter_nd_add_kernel.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ void ScatterNdAddKernel(const Context &dev_ctx,
2727
const DenseTensor &index,
2828
const DenseTensor &updates,
2929
DenseTensor *out) {
30+
if (out && out->numel() == 0) {
31+
dev_ctx.template Alloc<T>(out);
32+
return;
33+
}
3034
// In place output: Out = X
3135
Copy(dev_ctx, x, dev_ctx.GetPlace(), true, out);
3236
const auto &index_type = index.dtype();

paddle/phi/kernels/gpu/scatter_grad_kernel.cu

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "paddle/phi/common/bfloat16.h"
1919
#include "paddle/phi/core/kernel_registry.h"
2020
#include "paddle/phi/core/tensor_utils.h"
21+
#include "paddle/phi/kernels/full_kernel.h"
2122
#include "paddle/phi/kernels/funcs/gather.cu.h"
2223
#include "paddle/phi/kernels/funcs/scatter.cu.h"
2324

@@ -31,6 +32,19 @@ void ScatterGradKernel(const Context &dev_ctx,
3132
bool overwrite,
3233
DenseTensor *x_grad,
3334
DenseTensor *updates_grad) {
35+
if (out_grad.numel() == 0) {
36+
if (x_grad) {
37+
dev_ctx.template Alloc<T>(x_grad);
38+
}
39+
if (updates_grad) {
40+
phi::Full<T, Context>(
41+
dev_ctx,
42+
phi::IntArray(common::vectorize(updates_grad->dims())),
43+
0,
44+
updates_grad);
45+
}
46+
return;
47+
}
3448
auto index_type = index.dtype();
3549
bool index_type_match =
3650
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;

paddle/phi/kernels/gpu/scatter_kernel.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ void ScatterKernel(const Context &dev_ctx,
2929
const DenseTensor &updates,
3030
bool overwrite,
3131
DenseTensor *out) {
32+
if (index.numel() == 0) {
33+
dev_ctx.template Alloc<T>(out);
34+
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
35+
return;
36+
}
37+
if (out && out->numel() == 0) {
38+
dev_ctx.template Alloc<T>(out);
39+
return;
40+
}
3241
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
3342
// use template class to support int32_t and int64_t
3443
auto index_type = index.dtype();

paddle/phi/kernels/gpu/scatter_nd_add_grad_kernel.cu

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
#include "paddle/phi/common/bfloat16.h"
1919
#include "paddle/phi/core/kernel_registry.h"
2020
#include "paddle/phi/core/tensor_utils.h"
21+
#include "paddle/phi/kernels/full_kernel.h"
2122
#include "paddle/phi/kernels/funcs/gather.cu.h"
22-
2323
namespace phi {
2424

2525
template <typename T, typename Context>
@@ -29,6 +29,19 @@ void ScatterNdAddGradKernel(const Context &dev_ctx,
2929
const DenseTensor &out_grad,
3030
DenseTensor *x_grad,
3131
DenseTensor *updates_grad) {
32+
if (out_grad.numel() == 0) {
33+
if (x_grad) {
34+
dev_ctx.template Alloc<T>(x_grad);
35+
}
36+
if (updates_grad) {
37+
phi::Full<T, Context>(
38+
dev_ctx,
39+
phi::IntArray(common::vectorize(updates_grad->dims())),
40+
0,
41+
updates_grad);
42+
}
43+
return;
44+
}
3245
if (x_grad) {
3346
Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
3447
}

paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ void ScatterNdAddKernel(const Context &dev_ctx,
2828
const DenseTensor &index,
2929
const DenseTensor &updates,
3030
DenseTensor *out) {
31+
if (out && out->numel() == 0) {
32+
dev_ctx.template Alloc<T>(out);
33+
return;
34+
}
3135
Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
3236
const auto &index_type = index.dtype();
3337
bool index_type_match =

paddle/phi/kernels/xpu/scatter_grad_kernel.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/phi/backends/xpu/enforce_xpu.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/full_kernel.h"
1920

2021
namespace phi {
2122

@@ -27,6 +28,19 @@ void ScatterGradKernel(const Context &dev_ctx,
2728
bool overwrite,
2829
DenseTensor *x_grad,
2930
DenseTensor *updates_grad) {
31+
if (out_grad.numel() == 0) {
32+
if (x_grad) {
33+
dev_ctx.template Alloc<T>(x_grad);
34+
}
35+
if (updates_grad) {
36+
phi::Full<T, Context>(
37+
dev_ctx,
38+
phi::IntArray(common::vectorize(updates_grad->dims())),
39+
0,
40+
updates_grad);
41+
}
42+
return;
43+
}
3044
using XPUType = typename XPUTypeTrait<T>::Type;
3145

3246
const auto &index_type = index.dtype();

0 commit comments

Comments
 (0)