Skip to content

Commit c522530

Browse files
authored
fix safe bug of scatter/scatter_nd (#33858)
* fix safe bug of scatter/scatter_nd
1 parent 57aabba commit c522530

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

paddle/fluid/operators/scatter.cu.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ __global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
3333
int indices_i = i / slice_size;
3434
int slice_i = i - indices_i * slice_size; // offset inside the slice
3535
IndexT scatter_i = indices[indices_i];
36+
37+
PADDLE_ENFORCE(scatter_i >= 0,
38+
"The index is out of bounds, "
39+
"please check whether the dimensions of index and "
40+
"input meet the requirements. It should "
41+
"be greater than or equal to 0, but received [%d]",
42+
scatter_i);
43+
3644
IndexT out_i = scatter_i * slice_size + slice_i;
3745
*(output + out_i) = static_cast<T>(0);
3846
}
@@ -46,6 +54,14 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
4654
int indices_i = i / slice_size;
4755
int slice_i = i - indices_i * slice_size; // offset inside the slice
4856
IndexT scatter_i = indices[indices_i];
57+
58+
PADDLE_ENFORCE(scatter_i >= 0,
59+
"The index is out of bounds, "
60+
"please check whether the dimensions of index and "
61+
"input meet the requirements. It should "
62+
"be greater than or equal to 0, but received [%d]",
63+
scatter_i);
64+
4965
IndexT out_i = scatter_i * slice_size + slice_i;
5066
if (overwrite) {
5167
*(output + out_i) = *(params + i);
@@ -67,6 +83,15 @@ __global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices,
6783
int64_t temp = slice_size;
6884
for (int64_t j = end_size - 1; j >= 0; --j) {
6985
IndexT index_value = indices[indices_i * end_size + j];
86+
87+
PADDLE_ENFORCE(
88+
index_value >= 0 && index_value < output_dims[j],
89+
"The index is out of bounds, "
90+
"please check whether the dimensions of index and "
91+
"input meet the requirements. It should "
92+
"be less than [%d] and greater or equal to 0, but received [%d]",
93+
output_dims[j], index_value);
94+
7095
gather_i += (index_value * temp);
7196
temp *= output_dims[j];
7297
}

paddle/fluid/operators/scatter.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,15 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
118118

119119
for (int i = 0; i < index_size; ++i) {
120120
IndexT index_ = p_index[i];
121+
122+
PADDLE_ENFORCE_GE(index_, 0,
123+
platform::errors::OutOfRange(
124+
"The index is out of bounds, "
125+
"please check whether the dimensions of index and "
126+
"input meet the requirements. It should "
127+
"be greater than or equal to 0, but received [%d]",
128+
index_));
129+
121130
memcpy(p_output + index_ * slice_size, p_src + i * slice_size, slice_bytes);
122131
}
123132
}
@@ -173,6 +182,15 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
173182
// if not in overwrite mode, need to init output data
174183
for (int i = 0; i < index_size; ++i) {
175184
const IndexT& index_ = p_index[i];
185+
186+
PADDLE_ENFORCE_GE(index_, 0,
187+
platform::errors::OutOfRange(
188+
"The index is out of bounds, "
189+
"please check whether the dimensions of index and "
190+
"input meet the requirements. It should "
191+
"be greater than or equal to 0, but received [%d]",
192+
index_));
193+
176194
elementwise_inner_add<T, IndexT>(ctx, p_src, p_output, result_p_output, src,
177195
output, i, index_, slice_size,
178196
slice_bytes);
@@ -233,6 +251,15 @@ void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
233251
IndexT temp = 1;
234252
for (int64_t j = end_size - 1; j >= 0; --j) {
235253
IndexT index_value = p_index[i * end_size + j];
254+
PADDLE_ENFORCE_EQ(
255+
(index_value >= 0 && index_value < output_dims[j]), true,
256+
platform::errors::OutOfRange(
257+
"The index is out of bounds, "
258+
"please check whether the dimensions of index and "
259+
"input meet the requirements. It should "
260+
"be less than [%d] and greater or equal to 0, but received [%d]",
261+
output_dims[j], index_value));
262+
236263
index_ += (index_value * temp);
237264
temp *= output_dims[j];
238265
}

0 commit comments

Comments
 (0)