-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Support getitem by None index in dynamic mode #34338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4eba782
48414a7
db60dc2
0fe6bcb
429eb8c
9f1a225
741e0cc
ca9fd0d
7df1f8f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -420,6 +420,7 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, | |
| std::vector<int> *slice_ends, | ||
| std::vector<int> *slice_strides, | ||
| std::vector<int> *decrease_axis, | ||
| std::vector<int> *none_axes, | ||
| std::vector<int> *infer_flags) { | ||
| // We allow indexing by Integers, Slices, and tuples of those | ||
| // types. | ||
|
|
@@ -443,10 +444,6 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, | |
| } | ||
| } | ||
|
|
||
| PADDLE_ENFORCE_EQ( | ||
| size <= rank, true, | ||
| platform::errors::InvalidArgument( | ||
| "too many indices (%d) for tensor of dimension %d", size, rank)); | ||
| for (int i = 0, dim = 0; i < size; ++i) { | ||
| PyObject *slice_item = PyTuple_GetItem(index, i); | ||
|
|
||
|
|
@@ -491,14 +488,24 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, | |
| dim++; | ||
| } else if (slice_item == Py_Ellipsis) { | ||
| dim += rank - specified_dims; | ||
| } else if (slice_item == Py_None) { | ||
| none_axes->push_back(dim); | ||
| } else { | ||
| PADDLE_THROW(platform::errors::InvalidArgument( | ||
| "Currently, VarBase.__getitem__() only allows " | ||
| "indexing by Integers, Slices, Ellipsis, and tuples of " | ||
| "Currently, VarBase.__getitem__() only allows indexing " | ||
| "by Integers, Slices, Ellipsis, None and tuples of " | ||
| "these types, but received %s in %dth slice item", | ||
| std::string(Py_TYPE(slice_item)->tp_name), i + 1)); | ||
| } | ||
| } | ||
|
|
||
| // valid_indexs is the number of indices, excluding None indices | ||
| const int valid_indexs = size - none_axes->size(); | ||
| PADDLE_ENFORCE_EQ(valid_indexs <= rank, true, | ||
| platform::errors::InvalidArgument( | ||
| "Too many indices (%d) for tensor of dimension %d.", | ||
| valid_indexs, rank)); | ||
|
|
||
| if (!PyTuple_Check(_index)) Py_DecRef(index); | ||
| } | ||
|
|
||
|
|
@@ -790,9 +797,10 @@ void BindImperative(py::module *m_ptr) { | |
| // copys data to cpu place, which reduces performance. | ||
| if (parse_index && value_is_tensor) { | ||
| std::vector<int> axes, starts, ends, steps, decrease_axes, | ||
| infer_flags; | ||
| none_axes, infer_flags; | ||
| ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends, | ||
| &steps, &decrease_axes, &infer_flags); | ||
| &steps, &decrease_axes, &none_axes, | ||
| &infer_flags); | ||
|
|
||
| framework::AttributeMap attrs = { | ||
| {"axes", axes}, | ||
|
|
@@ -850,27 +858,29 @@ void BindImperative(py::module *m_ptr) { | |
| .def("_getitem_index_not_tensor", | ||
| [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) { | ||
| std::vector<int> slice_axes, slice_starts, slice_ends, | ||
| slice_strides, decrease_axis, infer_flags; | ||
| slice_strides, decrease_axis, none_axes, infer_flags; | ||
| auto tensor = | ||
| self->MutableVar()->GetMutable<framework::LoDTensor>(); | ||
| ParseIndexingSlice(tensor, _index.ptr(), &slice_axes, | ||
| &slice_starts, &slice_ends, &slice_strides, | ||
| &decrease_axis, &infer_flags); | ||
| &decrease_axis, &none_axes, &infer_flags); | ||
| // release gil and do tracing | ||
| py::gil_scoped_release release; | ||
| const auto &tracer = imperative::GetCurrentTracer(); | ||
| if (slice_axes.empty()) { | ||
| return self; | ||
| } else { | ||
|
|
||
| auto out = slice_axes.empty() | ||
| ? self | ||
| : std::shared_ptr<imperative::VarBase>( | ||
| new imperative::VarBase( | ||
| tracer->GenerateUniqueName())); | ||
| if (!slice_axes.empty()) { | ||
| imperative::NameVarBaseMap ins = {{"Input", {self}}}; | ||
| framework::AttributeMap attrs = { | ||
| {"axes", slice_axes}, | ||
| {"starts", slice_starts}, | ||
| {"ends", slice_ends}, | ||
| {"infer_flags", infer_flags}, | ||
| {"decrease_axis", decrease_axis}}; | ||
| auto out = std::shared_ptr<imperative::VarBase>( | ||
| new imperative::VarBase(tracer->GenerateUniqueName())); | ||
| imperative::NameVarBaseMap outs = {{"Out", {out}}}; | ||
| std::string op_type = "slice"; | ||
| for (auto stride : slice_strides) { | ||
|
|
@@ -882,8 +892,50 @@ void BindImperative(py::module *m_ptr) { | |
| } | ||
| } | ||
| tracer->TraceOp(op_type, ins, outs, std::move(attrs)); | ||
| return out; | ||
| } | ||
| if (!none_axes.empty()) { | ||
| // Deal with cases when all axes are decreased. | ||
| // After slice, the shape of out is [1], which should have been | ||
| // [], but Paddle doesn't support scalar. | ||
| // In order to ensure the correctness of the final shape of out, | ||
| // one dimension of out needs to be decreased. | ||
| // For example: | ||
| // # x.shape: (2,3,4) | ||
| // out = x[0, 1, 1, None] # out.shape : (1) | ||
| if (static_cast<int>(decrease_axis.size()) == | ||
| tensor->dims().size()) { | ||
| none_axes.pop_back(); | ||
| } | ||
| if (!none_axes.empty()) { | ||
| // Deal with cases that decrease_axes is not empty | ||
| // For example: | ||
| // # x.shape: (2,3,4) | ||
| // out = x[0, 0:2, None] # out.shape : (2, 1, 4) | ||
| for (auto &axis : none_axes) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 统一下遍历的写法?上面是
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里是对vector内变量进行更新,不需要下标信息,所以使用了变量引用的方式更新 |
||
| int len = 0; | ||
| for (int da : decrease_axis) { | ||
| if (da < axis) { | ||
| len++; | ||
| } | ||
| } | ||
| axis -= len; | ||
| } | ||
|
|
||
| imperative::NameVarBaseMap ins = {{"X", {out}}}; | ||
| framework::AttributeMap attrs = {{"axes", none_axes}}; | ||
| auto new_out = std::shared_ptr<imperative::VarBase>( | ||
| new imperative::VarBase(tracer->GenerateUniqueName())); | ||
| auto out_xshape = std::shared_ptr<imperative::VarBase>( | ||
| new imperative::VarBase(tracer->GenerateUniqueName())); | ||
| imperative::NameVarBaseMap outs = {{"Out", {new_out}}, | ||
| {"XShape", {out_xshape}}}; | ||
| tracer->TraceOp("unsqueeze2", ins, outs, std::move(attrs)); | ||
|
|
||
| return new_out; | ||
| } | ||
| } | ||
|
|
||
| return out; | ||
| }) | ||
| .def( | ||
| "_getitem_from_offset", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -689,6 +689,40 @@ def assert_getitem_ellipsis_index(var_tensor, var_np): | |
| assert_getitem_ellipsis_index(var_fp32, np_fp32_value) | ||
| assert_getitem_ellipsis_index(var_int, np_int_value) | ||
|
|
||
| def _test_none_index(self): | ||
| shape = (8, 64, 5, 256) | ||
| np_value = np.random.random(shape).astype('float32') | ||
| var_tensor = paddle.to_tensor(np_value) | ||
|
|
||
| var = [ | ||
| var_tensor[1, 0, None].numpy(), | ||
| var_tensor[None, ..., 1, 0].numpy(), | ||
| var_tensor[:, :, :, None].numpy(), | ||
| var_tensor[1, ..., 1, None].numpy(), | ||
| var_tensor[2, ..., None, None].numpy(), | ||
| var_tensor[None, 2, 0, ...].numpy(), | ||
| var_tensor[None, 2, None, 1].numpy(), | ||
| var_tensor[None].numpy(), | ||
| var_tensor[0, 0, None, 0, 0, None].numpy(), | ||
| var_tensor[0, 1:10:2, None, None, ...].numpy(), | ||
| ] | ||
|
|
||
| self.assertTrue(np.array_equal(var[0], np_value[1, 0, None])) | ||
| self.assertTrue(np.array_equal(var[1], np_value[None, ..., 1, 0])) | ||
| self.assertTrue(np.array_equal(var[2], np_value[:, :, :, None])) | ||
| self.assertTrue(np.array_equal(var[3], np_value[1, ..., 1, None])) | ||
| self.assertTrue(np.array_equal(var[4], np_value[2, ..., None, None])) | ||
| self.assertTrue(np.array_equal(var[5], np_value[None, 2, 0, ...])) | ||
| self.assertTrue(np.array_equal(var[6], np_value[None, 2, None, 1])) | ||
| self.assertTrue(np.array_equal(var[7], np_value[None])) | ||
| self.assertTrue( | ||
| np.array_equal(var[8], np_value[0, 0, None, 0, 0, None])) | ||
|
|
||
| # TODO(zyfncg) there is a bug of dimensions when slice step > 1 and | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个bug触发的时候报错是怎样的,用户可以get到这里有bug,暂时不能使用吗?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这是一个功能上与numpy不一致的问题,以例子说明: 当slice的step>1时,Paddle得到的结果中保留了长度为1的维度,在Numpy中则是去掉了这些维度,这个问题计划在后续的PR中进行修复 |
||
| # indexes have int type | ||
| # self.assertTrue( | ||
| # np.array_equal(var[9], np_value[0, 1:10:2, None, None, ...])) | ||
|
|
||
| def _test_for_var(self): | ||
| np_value = np.random.random((30, 100, 100)).astype('float32') | ||
| w = fluid.dygraph.to_variable(np_value) | ||
|
|
@@ -702,6 +736,7 @@ def test_slice(self): | |
| self._test_slice_for_tensor_attr() | ||
| self._test_for_var() | ||
| self._test_for_getitem_ellipsis_index() | ||
| self._test_none_index() | ||
|
|
||
| var = fluid.dygraph.to_variable(self.array) | ||
| self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :])) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果是
x[0, 1, 1, None, None]会怎么样There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x[0, 1, 1, None, None]的结果为[[element]],二维数组,与Numpy结果一致这里去掉一个None的原因在于Paddle中目前没有0维向量,当只有一个数值时,Paddle的结果仍为1维向量,如果再使用None索引升维,结果会比正常的维度多出一维,所以需要去掉一个None来保持结果维度数量的正确。