Skip to content

Commit 166ced5

Browse files
committed
fix bug of indexing with ellipsis
1 parent 228eb89 commit 166ced5

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

paddle/fluid/pybind/imperative.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,13 +549,20 @@ static void ParseIndexingSlice(
549549
// specified_dims is the number of dimensions which indexed by Interger,
550550
// Slices.
551551
int specified_dims = 0;
552+
int ell_count = 0;
552553
for (int dim = 0; dim < size; ++dim) {
553554
PyObject *slice_item = PyTuple_GetItem(index, dim);
554555
if (PyCheckInteger(slice_item) || PySlice_Check(slice_item)) {
555556
specified_dims++;
557+
} else if (slice_item == Py_Ellipsis) {
558+
ell_count++;
556559
}
557560
}
558561

562+
PADDLE_ENFORCE_LE(ell_count, 1,
563+
platform::errors::InvalidArgument(
564+
"An index can only have a single ellipsis ('...')"));
565+
559566
for (int i = 0, dim = 0; i < size; ++i) {
560567
PyObject *slice_item = PyTuple_GetItem(index, i);
561568

@@ -660,7 +667,7 @@ static void ParseIndexingSlice(
660667
}
661668

662669
// valid_index is the number of dimensions exclude None index
663-
const int valid_indexs = size - none_axes->size();
670+
const int valid_indexs = size - none_axes->size() - ell_count;
664671
PADDLE_ENFORCE_EQ(valid_indexs <= rank, true,
665672
platform::errors::InvalidArgument(
666673
"Too many indices (%d) for tensor of dimension %d.",

python/paddle/fluid/tests/unittests/test_var_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,11 @@ def assert_getitem_ellipsis_index(var_tensor, var_np):
702702
assert_getitem_ellipsis_index(var_fp32, np_fp32_value)
703703
assert_getitem_ellipsis_index(var_int, np_int_value)
704704

705+
# test 1 dim tensor
706+
var_one_dim = paddle.to_tensor([1, 2, 3, 4])
707+
self.assertTrue(
708+
np.array_equal(var_one_dim[..., 0].numpy(), np.array([1])))
709+
705710
def _test_none_index(self):
706711
shape = (8, 64, 5, 256)
707712
np_value = np.random.random(shape).astype('float32')

python/paddle/fluid/tests/unittests/test_variable.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,19 +226,22 @@ def _test_slice_index_ellipsis(self, place):
226226
prog = paddle.static.Program()
227227
with paddle.static.program_guard(prog):
228228
x = paddle.assign(data)
229+
y = paddle.assign([1, 2, 3, 4])
229230
out1 = x[0:, ..., 1:]
230231
out2 = x[0:, ...]
231232
out3 = x[..., 1:]
232233
out4 = x[...]
233234
out5 = x[[1, 0], [0, 0]]
234235
out6 = x[([1, 0], [0, 0])]
236+
out7 = y[..., 0]
235237

236238
exe = paddle.static.Executor(place)
237-
result = exe.run(prog, fetch_list=[out1, out2, out3, out4, out5, out6])
239+
result = exe.run(prog,
240+
fetch_list=[out1, out2, out3, out4, out5, out6, out7])
238241

239242
expected = [
240243
data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...],
241-
data[[1, 0], [0, 0]], data[([1, 0], [0, 0])]
244+
data[[1, 0], [0, 0]], data[([1, 0], [0, 0])], np.array([1])
242245
]
243246

244247
self.assertTrue((result[0] == expected[0]).all())
@@ -247,6 +250,7 @@ def _test_slice_index_ellipsis(self, place):
247250
self.assertTrue((result[3] == expected[3]).all())
248251
self.assertTrue((result[4] == expected[4]).all())
249252
self.assertTrue((result[5] == expected[5]).all())
253+
self.assertTrue((result[6] == expected[6]).all())
250254

251255
with self.assertRaises(IndexError):
252256
res = x[[1.2, 0]]

0 commit comments

Comments
 (0)