support tensor index. #34824

Conversation
✅ This PR's description meets the template requirements!
Thanks for your contribution!
std::vector<int64_t> output_dim(input_dim.size() + index_dim.size() - 1);

for (int i = 0; i < static_cast<int>(output_dim.size()); i++) {
use size_t i = 0 directly?
revert the modification of index_select_op.
LGTM. Dygraph mode currently reuses too much Python logic; the performance overhead this introduces will need to be addressed in follow-up work.
OK.
if isinstance(item, np.ndarray):
    return True
if not isinstance(item, (tuple, list)):
do we need to support set?
revert the modification of index_select_op.
y = y * y
loss = y.sum()
loss.backward()
grad_torch = np.array([[[0., 2.], [4., 6.], [8., 10.]],
rename this var
revert the modification of index_select_op.
# Remove Variable to skip bug when counting Ellipsis
item_remove_var = [ele for ele in item if not isinstance(ele, Variable)]
item_remove_var = [
Why is this skip needed?
item_remove_var.count(Ellipsis) counts how many Ellipsis entries the index contains; if an element is a Variable or an ndarray, the count call raises an error.
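The failure described in the reply above can be sketched with plain NumPy, using an ndarray in place of a Paddle Variable (Paddle is not imported here; exact behavior varies across NumPy versions):

```python
import numpy as np

# list.count compares each element to Ellipsis with ==. For a
# multi-element ndarray that comparison can yield a boolean array,
# whose truth value is ambiguous, so count() may raise ValueError
# (older NumPy versions instead return a scalar False with a warning).
item = (0, Ellipsis, np.array([1, 2, 3]))

try:
    n = list(item).count(Ellipsis)  # may raise on newer NumPy
except ValueError:
    n = None

# Filtering out the array-like elements first makes the count safe:
item_remove_var = [ele for ele in item if not isinstance(ele, np.ndarray)]
assert item_remove_var.count(Ellipsis) == 1
```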
        "to be in range of [-%d, %d]. But received Attr(dim) = %d.",
        input_dim.size(), input_dim.size() - 1, dim));

PADDLE_ENFORCE_EQ(
If we relax this restriction, do we need some additional checks? Could index_dim receive invalid input?
revert the modification of index_select_op.
def index_tensor(tensor, offsets, strides):
    from . import layers
    from .framework import Variable
Why not import at the top of the file?
Changed to reference these via paddle.xx instead.
def getitem_list_index(var, list_index):
    from . import layers
    from .framework import Variable
Same as above.
Done, thx.
def setitem_list_index(var, index_list, value):
    from . import layers
Same as above.
Done, thx.
for i in slice_item:
    if not isinstance(i, (int, bool)):
        raise TypeError("Only support int or bool in index list.")
    if not isinstance(i, (int, bool, list)):
Isn't it a tuple here?
Done, thx.
    return False

if contain_tensor(item):
    # 1. Call _getitem_impl_ when item contains tensor.
get -> set
Done, thx.
def update(self, index):
    if is_list_tuple(index, int) or isinstance(
            index, (paddle.fluid.core.VarBase, paddle.fluid.Variable,
Keeping only Variable is enough; in dygraph mode, VarBase is a Variable.
Done, thx.
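The reviewer's point can be sketched with hypothetical stand-in classes (these are not Paddle's real class definitions; the assumption is only that the dygraph VarBase type is, or is registered as, a Variable):

```python
# Stand-ins for paddle.fluid.Variable / paddle.fluid.core.VarBase,
# used only to illustrate the review comment above.
class Variable:
    pass

class VarBase(Variable):  # assumption: VarBase behaves as a Variable
    pass

idx = VarBase()

# Under that assumption, listing both classes in the isinstance check
# is redundant: the base class alone already covers both.
verbose = isinstance(idx, (VarBase, Variable))
simple = isinstance(idx, Variable)
assert verbose is True and simple is True
```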
if is_list_tuple(index, int) or isinstance(
        index, (paddle.fluid.core.VarBase, paddle.fluid.Variable,
                np.ndarray)):  # Tensor
    if not isinstance(index, (paddle.fluid.core.VarBase,
Same as above; fix all the other occurrences too.
Done, thx.
    "When index contains a Tensor, its length must be 1, but received {}.".
    format(len(item)))
elif isinstance(slice_item, np.ndarray):
    # delete
What does the "# delete" comment mean?
Forgot to delete the comment; it has been removed.
def contain_tensor(item):
    if not isinstance(item, tuple):
        item = [item]

    for slice_item in item:
        if isinstance(slice_item, slice):
            if isinstance(slice_item.start, Variable) \
                    or isinstance(slice_item.stop, Variable) \
                    or isinstance(slice_item.step, Variable):
                return True
        else:
            if isinstance(slice_item, Variable):
                return True
    return False
Could this share a single contain_tensor implementation with __getitem__?
Done, thx.
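One way to share the helper between __getitem__ and __setitem__ is a single module-level function, sketched below with a stand-in Variable class (hypothetical names, not Paddle's actual module layout):

```python
# Sketch of one shared contain_tensor helper that both __getitem__ and
# __setitem__ could import, so the logic isn't duplicated.
# `Variable` here is a stand-in class, not Paddle's real one.
class Variable:
    pass

def contain_tensor(item):
    if not isinstance(item, tuple):
        item = (item,)
    for slice_item in item:
        if isinstance(slice_item, slice):
            # a slice whose start/stop/step is a tensor also counts
            if any(isinstance(s, Variable)
                   for s in (slice_item.start, slice_item.stop,
                             slice_item.step)):
                return True
        elif isinstance(slice_item, Variable):
            return True
    return False

assert contain_tensor(Variable()) is True
assert contain_tensor((0, slice(None, Variable(), None))) is True
assert contain_tensor((1, slice(0, 3))) is False
```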
    "only support list/tensor index, but received {}.".format(
        type(index)))

# if len(self.indexes)>1:
delete?
Done, thx.
return reduce(lambda x, y: x * y, shape)

def get_offset_stride(self, tensor_shape):
    for i in range(len(self.indexes)):
use index in self.indexes directly?
Done, thx.
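The suggested change is the standard Python iteration idiom; a minimal illustration:

```python
# Iterate over the sequence directly instead of indexing by position.
# (Use enumerate when the position is genuinely needed.)
indexes = ['a', 'b', 'c']

by_position = [indexes[i] for i in range(len(indexes))]
direct = [index for index in indexes]

assert by_position == direct == ['a', 'b', 'c']
```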
index_shape = [2, 3, 4, 5, 6]
index = np.arange(self.numel(index_shape)).reshape(index_shape)
for i in range(len(inps_shape) - 1):
Are the test cases in the loop all the same?
Done, thx.
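The reviewer is asking whether the loop body actually depends on the loop variable; if it doesn't, every iteration repeats the identical case. A sketch of making each pass distinct (the shapes are hypothetical, not the PR's actual test data):

```python
import numpy as np

# If the loop body never uses `i`, all iterations run the same test.
# Deriving the case from `i` (or iterating over an explicit list of
# cases) makes each pass exercise something different.
inps_shape = [2, 3, 4, 5, 6]  # hypothetical, mirrors the style above

cases = []
for i in range(len(inps_shape) - 1):
    shape = inps_shape[:i + 2]  # the shape actually varies with i
    cases.append(np.arange(np.prod(shape)).reshape(shape))

# every generated case has a different shape
shapes = [c.shape for c in cases]
assert len(set(shapes)) == len(shapes)
```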
index_shape = [3, 3, 2, 1]
index = np.arange(self.numel(index_shape)).reshape(index_shape)

for i in range(3):
Are the test cases in the loop all the same?
Done, thx.
value_np = np.arange(
    self.numel(value_shape), dtype='float32').reshape(value_shape) + 100

for i in range(3):
Same as above.
Done, thx.
value_np = np.arange(
    self.numel(value_shape), dtype='float32').reshape(value_shape) + 100

for i in range(4):
Same as above.
Done, thx.
value_shape = [4]
value_np = np.arange(
    self.numel(value_shape), dtype='float32').reshape(value_shape) + 100
for zz_ in range(3):
Same as above.
Done, thx.
index2 = np.arange(
    self.numel(index_shape), dtype='int32').reshape(index_shape) + 2

for zz_ in range(3):
Same as above.
Done, thx.
index2 = np.arange(
    self.numel(index_shape), dtype='int32').reshape(index_shape) + 2

for zz_ in range(3):
Same as above.
Done, thx.
chenwhql left a comment
LGTM. Dygraph mode currently reuses too much Python logic; the performance overhead this introduces will need to be addressed in follow-up work.
PR types
Function optimization
PR changes
APIs
Describe
1. Overall support:
2. Specific feature points supported:
3. Support tensor-type indexing; example as follows:
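The PR description's original example is not preserved in this excerpt. As a hedged illustration of the tensor-indexing semantics being added, here is the equivalent behavior in NumPy (which Paddle's advanced indexing is designed to mirror):

```python
import numpy as np

# Illustration only: not the PR's original example. Tensor (advanced)
# indexing gathers along the first axis with an integer index array:
x = np.arange(12).reshape(3, 4)
index = np.array([2, 0, 2])

y = x[index]  # rows 2, 0, 2 of x; result is a copy
assert y.shape == (3, 4)
assert (y[0] == x[2]).all() and (y[1] == x[0]).all()

# Assigning through a tensor index writes back to those rows:
x2 = x.copy()
x2[np.array([0, 1])] = 100
assert (x2[0] == 100).all() and (x2[1] == 100).all()
```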