Commit a0bbc99

Support getitem by None index in dynamic mode (#34338)
* Support getitem by ellipsis index in dynamic mode
* change some code style
* Support getitem by none index in dynamic mode
* modify a comments style and remove useless code
1 parent df27c26 commit a0bbc99
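
For readers skimming the diff, a minimal usage sketch (not part of the commit) of what this change enables in dygraph mode: indexing a Tensor with None inserts a size-1 axis at that position, matching NumPy's None/np.newaxis behavior. The expected shapes below follow the examples given in the code comments and tests of this commit.

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.random.random((2, 3, 4)).astype('float32'))

    # None inserts a size-1 axis at the indexed position.
    print(x[None].shape)           # [1, 2, 3, 4]
    print(x[:, :, :, None].shape)  # [2, 3, 4, 1]
    print(x[0, 0:2, None].shape)   # [2, 1, 4]
    print(x[0, 1, 1, None].shape)  # [1]  (all real axes decreased; 0-d output is not supported)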

File tree

2 files changed: +103 -16 lines changed
paddle/fluid/pybind/imperative.cc

Lines changed: 68 additions & 16 deletions
@@ -420,6 +420,7 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
                                std::vector<int> *slice_ends,
                                std::vector<int> *slice_strides,
                                std::vector<int> *decrease_axis,
+                               std::vector<int> *none_axes,
                                std::vector<int> *infer_flags) {
   // We allow indexing by Integers, Slices, and tuples of those
   // types.
@@ -443,10 +444,6 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
     }
   }
 
-  PADDLE_ENFORCE_EQ(
-      size <= rank, true,
-      platform::errors::InvalidArgument(
-          "too many indices (%d) for tensor of dimension %d", size, rank));
   for (int i = 0, dim = 0; i < size; ++i) {
     PyObject *slice_item = PyTuple_GetItem(index, i);
 
@@ -491,14 +488,24 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
       dim++;
     } else if (slice_item == Py_Ellipsis) {
       dim += rank - specified_dims;
+    } else if (slice_item == Py_None) {
+      none_axes->push_back(dim);
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
-          "Currently, VarBase.__getitem__() only allows "
-          "indexing by Integers, Slices, Ellipsis, and tuples of "
+          "Currently, VarBase.__getitem__() only allows indexing"
+          "by Integers, Slices, Ellipsis, None and tuples of "
          "these types, but received %s in %dth slice item",
          std::string(Py_TYPE(slice_item)->tp_name), i + 1));
     }
   }
+
+  // valid_index is the number of dimensions exclude None index
+  const int valid_indexs = size - none_axes->size();
+  PADDLE_ENFORCE_EQ(valid_indexs <= rank, true,
+                    platform::errors::InvalidArgument(
+                        "Too many indices (%d) for tensor of dimension %d.",
+                        valid_indexs, rank));
+
   if (!PyTuple_Check(_index)) Py_DecRef(index);
 }
 
@@ -790,9 +797,10 @@ void BindImperative(py::module *m_ptr) {
           // copys data to cpu place, which reduces performance.
           if (parse_index && value_is_tensor) {
             std::vector<int> axes, starts, ends, steps, decrease_axes,
-                infer_flags;
+                none_axes, infer_flags;
             ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
-                               &steps, &decrease_axes, &infer_flags);
+                               &steps, &decrease_axes, &none_axes,
+                               &infer_flags);
 
             framework::AttributeMap attrs = {
                 {"axes", axes},
@@ -850,27 +858,29 @@ void BindImperative(py::module *m_ptr) {
       .def("_getitem_index_not_tensor",
            [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
              std::vector<int> slice_axes, slice_starts, slice_ends,
-                 slice_strides, decrease_axis, infer_flags;
+                 slice_strides, decrease_axis, none_axes, infer_flags;
              auto tensor =
                  self->MutableVar()->GetMutable<framework::LoDTensor>();
              ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
                                 &slice_starts, &slice_ends, &slice_strides,
-                                &decrease_axis, &infer_flags);
+                                &decrease_axis, &none_axes, &infer_flags);
              // release gil and do tracing
              py::gil_scoped_release release;
              const auto &tracer = imperative::GetCurrentTracer();
-             if (slice_axes.empty()) {
-               return self;
-             } else {
+
+             auto out = slice_axes.empty()
+                            ? self
+                            : std::shared_ptr<imperative::VarBase>(
+                                  new imperative::VarBase(
+                                      tracer->GenerateUniqueName()));
+             if (!slice_axes.empty()) {
               imperative::NameVarBaseMap ins = {{"Input", {self}}};
               framework::AttributeMap attrs = {
                   {"axes", slice_axes},
                   {"starts", slice_starts},
                   {"ends", slice_ends},
                   {"infer_flags", infer_flags},
                   {"decrease_axis", decrease_axis}};
-              auto out = std::shared_ptr<imperative::VarBase>(
-                  new imperative::VarBase(tracer->GenerateUniqueName()));
               imperative::NameVarBaseMap outs = {{"Out", {out}}};
               std::string op_type = "slice";
               for (auto stride : slice_strides) {
@@ -882,8 +892,50 @@ void BindImperative(py::module *m_ptr) {
                 }
               }
               tracer->TraceOp(op_type, ins, outs, std::move(attrs));
-              return out;
             }
+            if (!none_axes.empty()) {
+              // Deal with cases when all axes are decreased.
+              // After slice, the shape of out is [1], which should have been
+              // [], but Paddle doesn't support scalar.
+              // In order to ensure the correctness of the final shape of out,
+              // one dimension of out needs to be decreased.
+              // For example:
+              //   # x.shape: (2,3,4)
+              //   out = x[0, 1, 1, None] # out.shape : (1)
+              if (static_cast<int>(decrease_axis.size()) ==
+                  tensor->dims().size()) {
+                none_axes.pop_back();
+              }
+              if (!none_axes.empty()) {
+                // Deal with cases that decrease_axes is not empty
+                // For example:
+                //   # x.shape: (2,3,4)
+                //   out = x[0, 0:2, None] # out.shape : (2, 1, 4)
+                for (auto &axis : none_axes) {
+                  int len = 0;
+                  for (int da : decrease_axis) {
+                    if (da < axis) {
+                      len++;
+                    }
+                  }
+                  axis -= len;
+                }
+
+                imperative::NameVarBaseMap ins = {{"X", {out}}};
+                framework::AttributeMap attrs = {{"axes", none_axes}};
+                auto new_out = std::shared_ptr<imperative::VarBase>(
+                    new imperative::VarBase(tracer->GenerateUniqueName()));
+                auto out_xshape = std::shared_ptr<imperative::VarBase>(
+                    new imperative::VarBase(tracer->GenerateUniqueName()));
+                imperative::NameVarBaseMap outs = {{"Out", {new_out}},
+                                                   {"XShape", {out_xshape}}};
+                tracer->TraceOp("unsqueeze2", ins, outs, std::move(attrs));
+
+                return new_out;
+              }
+            }
+
+            return out;
           })
       .def(
           "_getitem_from_offset",

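The axis arithmetic in the new none_axes handling is the subtle part: each None is recorded at its position in the original index tuple, but after the slice op the integer-indexed (decreased) axes are gone, so every None insertion point has to be shifted left by the number of decreased axes before it before it is applied with unsqueeze2. A rough Python rendering of that adjustment (illustrative only; adjust_none_axes is not a Paddle API):

    def adjust_none_axes(none_axes, decrease_axis):
        # Shift each None insertion point left by the number of
        # integer-indexed (decreased) axes that precede it, mirroring
        # the loop added in imperative.cc.
        adjusted = []
        for axis in none_axes:
            shift = sum(1 for da in decrease_axis if da < axis)
            adjusted.append(axis - shift)
        return adjusted

    # x.shape: (2, 3, 4); for x[0, 0:2, None] the None is recorded at
    # position 2 and the integer index decreases axis 0, so the unsqueeze
    # happens at axis 1 of the sliced (2, 4) result -> (2, 1, 4).
    print(adjust_none_axes([2], [0]))  # [1]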
python/paddle/fluid/tests/unittests/test_var_base.py

Lines changed: 35 additions & 0 deletions
@@ -689,6 +689,40 @@ def assert_getitem_ellipsis_index(var_tensor, var_np):
         assert_getitem_ellipsis_index(var_fp32, np_fp32_value)
         assert_getitem_ellipsis_index(var_int, np_int_value)
 
+    def _test_none_index(self):
+        shape = (8, 64, 5, 256)
+        np_value = np.random.random(shape).astype('float32')
+        var_tensor = paddle.to_tensor(np_value)
+
+        var = [
+            var_tensor[1, 0, None].numpy(),
+            var_tensor[None, ..., 1, 0].numpy(),
+            var_tensor[:, :, :, None].numpy(),
+            var_tensor[1, ..., 1, None].numpy(),
+            var_tensor[2, ..., None, None].numpy(),
+            var_tensor[None, 2, 0, ...].numpy(),
+            var_tensor[None, 2, None, 1].numpy(),
+            var_tensor[None].numpy(),
+            var_tensor[0, 0, None, 0, 0, None].numpy(),
+            var_tensor[0, 1:10:2, None, None, ...].numpy(),
+        ]
+
+        self.assertTrue(np.array_equal(var[0], np_value[1, 0, None]))
+        self.assertTrue(np.array_equal(var[1], np_value[None, ..., 1, 0]))
+        self.assertTrue(np.array_equal(var[2], np_value[:, :, :, None]))
+        self.assertTrue(np.array_equal(var[3], np_value[1, ..., 1, None]))
+        self.assertTrue(np.array_equal(var[4], np_value[2, ..., None, None]))
+        self.assertTrue(np.array_equal(var[5], np_value[None, 2, 0, ...]))
+        self.assertTrue(np.array_equal(var[6], np_value[None, 2, None, 1]))
+        self.assertTrue(np.array_equal(var[7], np_value[None]))
+        self.assertTrue(
+            np.array_equal(var[8], np_value[0, 0, None, 0, 0, None]))
+
+        # TODO(zyfncg) there is a bug of dimensions when slice step > 1 and
+        # indexs has int type
+        # self.assertTrue(
+        #     np.array_equal(var[9], np_value[0, 1:10:2, None, None, ...]))
+
     def _test_for_var(self):
         np_value = np.random.random((30, 100, 100)).astype('float32')
         w = fluid.dygraph.to_variable(np_value)
@@ -702,6 +736,7 @@ def test_slice(self):
         self._test_slice_for_tensor_attr()
         self._test_for_var()
         self._test_for_getitem_ellipsis_index()
+        self._test_none_index()
 
         var = fluid.dygraph.to_variable(self.array)
         self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :]))
