Conversation

@hbwx24 (Contributor) commented Aug 11, 2021

PR types

Function optimization

PR changes

APIs

Describe

1. Overall support:

(screenshots omitted)

2. Specific supported features:

(screenshots omitted)

3. Tensor-type indices are supported; for example:

import numpy as np
import paddle

array = np.arange(4 * 3 * 2).reshape([4, 3, 2])
value = np.arange(12 * 3).reshape([3, 2, 3, 2])
index = [[0, 0], [3, 1]]

index_t = paddle.to_tensor(index)
index_np = np.array(index)
tt = paddle.to_tensor(array)

plist = paddle.index_select(tt, index_t, axis=0)

nplist = array[index_np]

print(np.array_equal(plist.numpy(), nplist))
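The same gather-along-axis-0 semantics can be checked with NumPy alone; this is only an illustration of the expected result, not Paddle code:

```python
import numpy as np

array = np.arange(4 * 3 * 2).reshape([4, 3, 2])
index_np = np.array([[0, 0], [3, 1]])

# Fancy indexing along axis 0 is a gather: each entry of index_np selects a
# sub-array array[i], so the result has shape index_np.shape + array.shape[1:].
gathered = np.take(array, index_np, axis=0)
print(np.array_equal(array[index_np], gathered))  # True
print(gathered.shape)                             # (2, 2, 3, 2)
```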

@paddle-bot-old (bot) commented Aug 11, 2021

✅ This PR's description meets the template requirements!
Please wait for other CI results.

@paddle-bot-old (bot)

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.


std::vector<int64_t> output_dim(input_dim.size() + index_dim.size() - 1);

for (int i = 0; i < static_cast<int>(output_dim.size()); i++) {
Contributor

use size_t i = 0 directly?

Contributor Author (@hbwx24, Aug 25, 2021)

revert the modification of index_select_op.
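The output rank computed in the snippet under review (input rank + index rank - 1) matches NumPy's behavior for a gather along a single axis; a minimal check:

```python
import numpy as np

a = np.arange(24).reshape(4, 3, 2)        # input_dim = [4, 3, 2]
idx = np.array([[0, 0], [3, 1]])          # index_dim = [2, 2]
out = np.take(a, idx, axis=0)             # gather along axis 0

# output rank = input rank + index rank - 1
print(out.ndim == a.ndim + idx.ndim - 1)  # True
print(out.shape)                          # (2, 2, 3, 2)
```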

Contributor Author

LGTM. The dynamic graph currently reuses too much Python logic; the performance issues this introduces will need to be addressed later.

OK.


if isinstance(item, np.ndarray):
return True
if not isinstance(item, (tuple, list)):
Contributor

do we need to support set?

Contributor Author

revert the modification of index_select_op.
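For comparison, NumPy itself rejects a set as an index, which suggests sets need not be supported here; a small illustration:

```python
import numpy as np

a = np.arange(4)
try:
    a[{0, 1}]  # a set is not a valid NumPy index
except (IndexError, TypeError) as e:
    print("rejected:", type(e).__name__)
```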

y = y * y
loss = y.sum()
loss.backward()
grad_torch = np.array([[[0., 2.], [4., 6.], [8., 10.]],
Contributor

rename this var

Contributor Author

revert the modification of index_select_op.


# Remove Variable to skip bug when counting Ellipsis
item_remove_var = [ele for ele in item if not isinstance(ele, Variable)]
item_remove_var = [
Contributor

Why is this skip needed?

Contributor Author

item_remove_var.count(Ellipsis) counts the occurrences of Ellipsis. If any element is a Variable or an ndarray, the count call will raise an error.
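The failure mode can be reproduced with plain NumPy: list.count compares with ==, and ndarray.__eq__ returns an array whose truth value is ambiguous, so count raises:

```python
import numpy as np

item = [0, np.array([1, 2]), Ellipsis]
try:
    item.count(Ellipsis)          # == against the ndarray yields a bool array;
except ValueError as e:           # its truth value is ambiguous
    print("count failed:", e)

# Filtering out array-like elements first makes count safe:
item_remove_var = [ele for ele in item if not isinstance(ele, np.ndarray)]
print(item_remove_var.count(Ellipsis))  # 1
```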

"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_dim.size(), input_dim.size() - 1, dim));

PADDLE_ENFORCE_EQ(
Contributor

If we open this up, do we need some additional checks? Could index_dim receive invalid input?

Contributor Author

revert the modification of index_select_op.


def index_tensor(tensor, offsets, strides):
from . import layers
from .framework import Variable
Contributor

Why not import at the beginning of the file?

Contributor Author

Changed to reference it via paddle.xx instead.


def getitem_list_index(var, list_index):
from . import layers
from .framework import Variable
Contributor

Same as above.

Contributor Author

Done, thx.



def setitem_list_index(var, index_list, value):
from . import layers
Contributor

Same as above.

Contributor Author

Done, thx.

for i in slice_item:
if not isinstance(i, (int, bool)):
raise TypeError("Only support int or bool in index list.")
if not isinstance(i, (int, bool, list)):
Contributor

Isn't it a tuple here?

Contributor Author

Done, thx.
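A hypothetical sketch of the element check under discussion, with tuple accepted alongside list as the reviewer suggests (the helper name is illustrative, not Paddle's actual code):

```python
def check_index_elements(slice_item):
    # Hypothetical helper mirroring the check under review: each element of a
    # list/tuple index must be an int, a bool, or a nested list/tuple.
    for i in slice_item:
        if not isinstance(i, (int, bool, list, tuple)):
            raise TypeError("Only support int, bool, list or tuple in index list.")

check_index_elements([1, True, [0, 2], (1,)])  # accepted
try:
    check_index_elements([1.5])
except TypeError as e:
    print(e)
```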

return False

if contain_tensor(item):
# 1. Call _getitem_impl_ when item contains tensor.
Contributor

get -> set

Contributor Author

Done, thx.


def update(self, index):
if is_list_tuple(index, int) or isinstance(
index, (paddle.fluid.core.VarBase, paddle.fluid.Variable,
Contributor

Keeping only Variable is enough; in dygraph mode, VarBase is a Variable.

Contributor Author

Done, thx.

if is_list_tuple(index, int) or isinstance(
index, (paddle.fluid.core.VarBase, paddle.fluid.Variable,
np.ndarray)): # Tensor
if not isinstance(index, (paddle.fluid.core.VarBase,
Contributor

Same as above; fix all the other places too.

Contributor Author

Done, thx.

"When index contains a Tensor, its length must be 1, but received {}.".
format(len(item)))
elif isinstance(slice_item, np.ndarray):
# delete
Contributor

What does the 'delete' comment mean?

Contributor Author

I forgot to delete the comment; it has been removed now.

Comment on lines 590 to 603
def contain_tensor(item):
if not isinstance(item, tuple):
item = [item]

for slice_item in item:
if isinstance(slice_item, slice):
if isinstance(slice_item.start, Variable) \
or isinstance(slice_item.stop, Variable) \
or isinstance(slice_item.step, Variable):
return True
else:
if isinstance(slice_item, Variable):
return True
return False
Contributor

Could this share a single copy of the contain_tensor code with __getitem__?

Contributor Author

Done, thx.
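A shared helper like the reviewer asks for could look like the sketch below, with a stand-in class in place of Paddle's Variable (illustrative only, not the merged code):

```python
class FakeVariable:
    """Stand-in for paddle's Variable, for illustration only."""

def contain_tensor(item, tensor_type=FakeVariable):
    # Treat a single index the same as a 1-tuple, then look for a tensor
    # either as a bare index or inside a slice's start/stop/step.
    if not isinstance(item, tuple):
        item = (item,)
    for slice_item in item:
        if isinstance(slice_item, slice):
            if any(isinstance(v, tensor_type)
                   for v in (slice_item.start, slice_item.stop, slice_item.step)):
                return True
        elif isinstance(slice_item, tensor_type):
            return True
    return False

print(contain_tensor(FakeVariable()))            # True
print(contain_tensor(slice(0, FakeVariable())))  # True
print(contain_tensor((0, slice(1, 3))))          # False
```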

"only support list/tensor index, but received {}.".format(
type(index)))

# if len(self.indexes)>1:
Contributor

delete?

Contributor Author

Done, thx.

return reduce(lambda x, y: x * y, shape)

def get_offset_stride(self, tensor_shape):
for i in range(len(self.indexes)):
Contributor

Use `for index in self.indexes` directly?

Contributor Author

Done, thx.
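The reviewer's suggestion is the usual Python idiom: iterate over the elements directly (or use enumerate when the position is also needed) rather than range(len(...)). A small illustration with hypothetical data:

```python
indexes = [[0, 1], [2, 3], [4, 5]]  # hypothetical stand-in for self.indexes

# range(len(...)) style:
firsts_a = [indexes[i][0] for i in range(len(indexes))]

# Direct iteration, as suggested:
firsts_b = [index[0] for index in indexes]

# enumerate when the position i is also needed:
firsts_c = [index[0] for i, index in enumerate(indexes)]

print(firsts_a == firsts_b == firsts_c)  # True
```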


index_shape = [2, 3, 4, 5, 6]
index = np.arange(self.numel(index_shape)).reshape(index_shape)
for i in range(len(inps_shape) - 1):
Contributor

Are the test cases the same in every iteration of the loop?

Contributor Author

Done, thx.

index_shape = [3, 3, 2, 1]
index = np.arange(self.numel(index_shape)).reshape(index_shape)

for i in range(3):
Contributor

Are the test cases the same in every iteration of the loop?

Contributor Author

Done, thx.

value_np = np.arange(
self.numel(value_shape), dtype='float32').reshape(value_shape) + 100

for i in range(3):
Contributor

Same as above.

Contributor Author

Done, thx.

value_np = np.arange(
self.numel(value_shape), dtype='float32').reshape(value_shape) + 100

for i in range(4):
Contributor

Same as above.

Contributor Author

Done, thx.

value_shape = [4]
value_np = np.arange(
self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
for zz_ in range(3):
Contributor

Same as above.

Contributor Author

Done, thx.

index2 = np.arange(
self.numel(index_shape), dtype='int32').reshape(index_shape) + 2

for zz_ in range(3):
Contributor

Same as above.

Contributor Author

Done, thx.

index2 = np.arange(
self.numel(index_shape), dtype='int32').reshape(index_shape) + 2

for zz_ in range(3):
Contributor

Same as above.

Contributor Author

Done, thx.

@chenwhql (Contributor) left a comment

LGTM. The dynamic graph currently reuses too much Python logic; the performance issues this introduces will need to be addressed later.

@hbwx24 hbwx24 merged commit e7df47e into PaddlePaddle:develop Aug 26, 2021