@@ -420,6 +420,7 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
                                std::vector<int> *slice_ends,
                                std::vector<int> *slice_strides,
                                std::vector<int> *decrease_axis,
+                               std::vector<int> *none_axes,
                                std::vector<int> *infer_flags) {
   // We allow indexing by Integers, Slices, and tuples of those
   // types.
@@ -443,10 +444,6 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
     }
   }

-  PADDLE_ENFORCE_EQ(
-      size <= rank, true,
-      platform::errors::InvalidArgument(
-          "too many indices (%d) for tensor of dimension %d", size, rank));
   for (int i = 0, dim = 0; i < size; ++i) {
     PyObject *slice_item = PyTuple_GetItem(index, i);

@@ -491,14 +488,24 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
       dim++;
     } else if (slice_item == Py_Ellipsis) {
       dim += rank - specified_dims;
+    } else if (slice_item == Py_None) {
+      none_axes->push_back(dim);
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
-          "Currently, VarBase.__getitem__() only allows "
-          "indexing by Integers, Slices, Ellipsis, and tuples of "
+          "Currently, VarBase.__getitem__() only allows indexing "
+          "by Integers, Slices, Ellipsis, None and tuples of "
           "these types, but received %s in %dth slice item",
           std::string(Py_TYPE(slice_item)->tp_name), i + 1));
     }
   }
+
+  // valid_indexs counts the indices that consume a dimension (None excluded)
+  const int valid_indexs = size - none_axes->size();
+  PADDLE_ENFORCE_EQ(valid_indexs <= rank, true,
+                    platform::errors::InvalidArgument(
+                        "Too many indices (%d) for tensor of dimension %d.",
+                        valid_indexs, rank));
+
   if (!PyTuple_Check(_index)) Py_DecRef(index);
 }

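To illustrate the parsing change above: a `None` item records the current dim position into `none_axes` without consuming a tensor dimension, and the rank check, now placed after the loop, compares only the non-`None` index count against the tensor rank. That is what lets an expression like `x[0, 1, 2, None]` pass on a rank-3 tensor even though it has four items. A standalone sketch of that bookkeeping (hypothetical `IndexKind`/`ClassifyIndex` names, not Paddle code, and Ellipsis left out):

```cpp
#include <cassert>
#include <vector>

// Simplified stand-ins for the Python index items handled above.
enum class IndexKind { kInt, kSlice, kNone };

// Returns the dim positions where new axes should be inserted and checks,
// like the PADDLE_ENFORCE_EQ above, that the non-None indices fit the rank.
std::vector<int> ClassifyIndex(const std::vector<IndexKind> &items, int rank) {
  std::vector<int> none_axes;
  int dim = 0;
  for (const auto &item : items) {
    if (item == IndexKind::kNone) {
      none_axes.push_back(dim);  // records the position, consumes no dimension
    } else {
      ++dim;  // an int or slice indexes exactly one tensor dimension
    }
  }
  const int valid = static_cast<int>(items.size()) -
                    static_cast<int>(none_axes.size());
  assert(valid <= rank && "too many indices for tensor rank");
  return none_axes;
}

int main() {
  // x.shape = (2, 3, 4); index expression x[0, 0:2, None]
  auto none_axes =
      ClassifyIndex({IndexKind::kInt, IndexKind::kSlice, IndexKind::kNone}, 3);
  assert(none_axes.size() == 1 && none_axes[0] == 2);  // None sits at dim 2
  // Four items, but only three real indices, so this still passes for rank 3:
  ClassifyIndex({IndexKind::kInt, IndexKind::kInt, IndexKind::kInt,
                 IndexKind::kNone}, 3);
  return 0;
}
```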
@@ -790,9 +797,10 @@ void BindImperative(py::module *m_ptr) {
            // copies data to the CPU place, which reduces performance.
            if (parse_index && value_is_tensor) {
              std::vector<int> axes, starts, ends, steps, decrease_axes,
-                 infer_flags;
+                 none_axes, infer_flags;
              ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
-                                &steps, &decrease_axes, &infer_flags);
+                                &steps, &decrease_axes, &none_axes,
+                                &infer_flags);

              framework::AttributeMap attrs = {
                  {"axes", axes},
@@ -850,27 +858,29 @@ void BindImperative(py::module *m_ptr) {
       .def("_getitem_index_not_tensor",
            [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
              std::vector<int> slice_axes, slice_starts, slice_ends,
-                 slice_strides, decrease_axis, infer_flags;
+                 slice_strides, decrease_axis, none_axes, infer_flags;
              auto tensor =
                  self->MutableVar()->GetMutable<framework::LoDTensor>();
              ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
                                 &slice_starts, &slice_ends, &slice_strides,
-                                &decrease_axis, &infer_flags);
+                                &decrease_axis, &none_axes, &infer_flags);
              // release gil and do tracing
              py::gil_scoped_release release;
              const auto &tracer = imperative::GetCurrentTracer();
-             if (slice_axes.empty()) {
-               return self;
-             } else {
+
+             auto out = slice_axes.empty()
+                            ? self
+                            : std::shared_ptr<imperative::VarBase>(
+                                  new imperative::VarBase(
+                                      tracer->GenerateUniqueName()));
+             if (!slice_axes.empty()) {
               imperative::NameVarBaseMap ins = {{"Input", {self}}};
               framework::AttributeMap attrs = {
                   {"axes", slice_axes},
                   {"starts", slice_starts},
                   {"ends", slice_ends},
                   {"infer_flags", infer_flags},
                   {"decrease_axis", decrease_axis}};
-              auto out = std::shared_ptr<imperative::VarBase>(
-                  new imperative::VarBase(tracer->GenerateUniqueName()));
               imperative::NameVarBaseMap outs = {{"Out", {out}}};
               std::string op_type = "slice";
               for (auto stride : slice_strides) {
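The loop that begins above chooses the traced op from the slice strides; its body falls between the two hunks, so the exact logic is not shown here, but it presumably switches from the default `"slice"` to `"strided_slice"` when any step differs from 1. A minimal sketch of that selection (hypothetical `PickSliceOp` helper, assumed behaviour):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Keep the plain "slice" op for unit steps; fall back to "strided_slice"
// as soon as any stride is not 1 (assumed to mirror the hidden loop body).
std::string PickSliceOp(const std::vector<int> &slice_strides) {
  for (int stride : slice_strides) {
    if (stride != 1) {
      return "strided_slice";  // non-unit step needs the strided kernel
    }
  }
  return "slice";
}

int main() {
  std::cout << PickSliceOp({1, 1}) << "\n";  // slice          e.g. x[0:2, 1:3]
  std::cout << PickSliceOp({1, 2}) << "\n";  // strided_slice  e.g. x[0:2, 1:4:2]
  return 0;
}
```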
@@ -882,8 +892,50 @@ void BindImperative(py::module *m_ptr) {
                 }
               }
               tracer->TraceOp(op_type, ins, outs, std::move(attrs));
-              return out;
             }
+             if (!none_axes.empty()) {
+               // Deal with the case where every axis is decreased.
+               // After the slice, out has shape [1] instead of the expected
+               // [], because Paddle does not support 0-D scalars.
+               // To keep the final shape correct, drop one None axis so the
+               // leftover size-1 dimension stands in for it.
+               // For example:
+               // # x.shape: (2, 3, 4)
+               // out = x[0, 1, 1, None]  # out.shape: (1)
+               if (static_cast<int>(decrease_axis.size()) ==
+                   tensor->dims().size()) {
+                 none_axes.pop_back();
+               }
+               if (!none_axes.empty()) {
+                 // Shift each None axis left by the number of decreased axes
+                 // before it, so it indexes the sliced output. For example:
+                 // # x.shape: (2, 3, 4)
+                 // out = x[0, 0:2, None]  # out.shape: (2, 1, 4)
+                 for (auto &axis : none_axes) {
+                   int len = 0;
+                   for (int da : decrease_axis) {
+                     if (da < axis) {
+                       len++;
+                     }
+                   }
+                   axis -= len;
+                 }
+
+                 imperative::NameVarBaseMap ins = {{"X", {out}}};
+                 framework::AttributeMap attrs = {{"axes", none_axes}};
+                 auto new_out = std::shared_ptr<imperative::VarBase>(
+                     new imperative::VarBase(tracer->GenerateUniqueName()));
+                 auto out_xshape = std::shared_ptr<imperative::VarBase>(
+                     new imperative::VarBase(tracer->GenerateUniqueName()));
+                 imperative::NameVarBaseMap outs = {{"Out", {new_out}},
+                                                    {"XShape", {out_xshape}}};
+                 tracer->TraceOp("unsqueeze2", ins, outs, std::move(attrs));
+
+                 return new_out;
+               }
+             }
+
+             return out;
           })
      .def(
          "_getitem_from_offset",
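A worked example of the axis arithmetic added above, detached from Paddle's types: per the parsing logic earlier in the diff, `x` of shape (2, 3, 4) indexed as `x[0, 0:2, None]` yields `decrease_axis = {0}` and `none_axes = {2}`. The slice drops axis 0, so the recorded None position must shift left by one before it is handed to `unsqueeze2`. A minimal self-contained sketch (values taken from the code comments above):

```cpp
#include <cassert>
#include <vector>

int main() {
  // x.shape = (2, 3, 4), index x[0, 0:2, None]:
  // the integer 0 decreases axis 0, and None was recorded at dim position 2.
  std::vector<int> decrease_axis = {0};
  std::vector<int> none_axes = {2};

  // Same compensation as the loop traced before unsqueeze2: shift every None
  // position left by the number of decreased axes that precede it.
  for (auto &axis : none_axes) {
    int len = 0;
    for (int da : decrease_axis) {
      if (da < axis) {
        ++len;
      }
    }
    axis -= len;
  }

  // The slice output has shape (2, 4); unsqueezing at the adjusted axis 1
  // gives (2, 1, 4), matching the expected result of x[0, 0:2, None].
  assert(none_axes.front() == 1);
  return 0;
}
```

The special case guarded first in the hunk, `decrease_axis.size() == tensor->dims().size()` as in `x[0, 1, 1, None]`, skips one unsqueeze entirely: the slice already leaves a placeholder shape of [1] because Paddle has no 0-D scalars, so inserting another axis would wrongly produce [1, 1].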