Skip to content

Commit ec94322

Browse files
committed
support tensor index.
1 parent d3d174f commit ec94322

File tree

4 files changed

+146
-15
lines changed

4 files changed

+146
-15
lines changed

paddle/fluid/operators/index_select_op.cc

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,32 @@ class IndexSelectOp : public framework::OperatorWithKernel {
4646
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
4747
input_dim.size(), input_dim.size() - 1, dim));
4848

49-
PADDLE_ENFORCE_EQ(
50-
index_dim.size() == 1 || (index_dim.size() == 2 && index_dim[1] == 1),
51-
true, platform::errors::InvalidArgument(
52-
"The 'shape' of Input(Index) must be 1-D tensor. "
53-
"But received: the 'shape' of Input(Index) is [%s], "
54-
"the dimension of Input(Index) is [%d].",
55-
index_dim, index_dim.size()));
56-
57-
auto output_dim = framework::vectorize(input_dim);
49+
// PADDLE_ENFORCE_EQ(
50+
// index_dim.size() == 1 || (index_dim.size() == 2 && index_dim[1] ==
51+
// 1),
52+
// true, platform::errors::InvalidArgument(
53+
// "The 'shape' of Input(Index) must be 1-D tensor. "
54+
// "But received: the 'shape' of Input(Index) is [%s], "
55+
// "the dimension of Input(Index) is [%d].",
56+
// index_dim, index_dim.size()));
57+
5858
if (dim < 0) {
5959
dim += input_dim.size();
6060
}
61-
output_dim[dim] = index_dim[0];
61+
// auto output_dim = framework::vectorize(input_dim);
62+
// output_dim[dim] = index_dim[0];
63+
std::vector<int64_t> output_dim(input_dim.size() + index_dim.size() - 1);
64+
65+
for (int i = 0; i < static_cast<int>(output_dim.size()); i++) {
66+
if (i < dim) {
67+
output_dim[i] = input_dim[i];
68+
} else if (i < dim + index_dim.size()) {
69+
output_dim[i] = index_dim[i - dim];
70+
} else {
71+
output_dim[i] = input_dim[i - index_dim.size() + 1];
72+
}
73+
}
74+
6275
ctx->SetOutputDim("Out", framework::make_ddim(output_dim));
6376
auto type = ctx->GetInputsVarType("X")[0];
6477
if (type == framework::proto::VarType::LOD_TENSOR) {

paddle/fluid/operators/index_select_op.cu

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ class IndexSelectCUDAKernel : public framework::OpKernel<T> {
7777
dim = dim >= 0 ? dim : dim + input_dim.size();
7878
auto stride_dim = framework::stride(input_dim);
7979
int64_t stride = stride_dim[dim];
80-
int64_t size = output_dim[dim];
80+
int64_t size = 1;
81+
for (int i = 0; i < index->dims().size(); i++) {
82+
size *= output_dim[dim + i];
83+
}
8184
int64_t delta = input_dim[dim] - size;
8285

8386
const auto& index_type = index->type();
@@ -143,7 +146,12 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
143146
dim = dim >= 0 ? dim : dim + input_dim.size();
144147
auto stride_dim = framework::stride(input_dim);
145148
int64_t stride = stride_dim[dim];
146-
int64_t size = input_dim[dim];
149+
150+
int64_t size = 1;
151+
for (int i = 0; i < index->dims().size(); i++) {
152+
size *= input_dim[dim + i];
153+
}
154+
147155
int64_t delta = output_dim[dim] - size;
148156

149157
const auto& index_type = index->type();

paddle/fluid/operators/index_select_op.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,12 @@ void IndexSelectInner(const framework::ExecutionContext& context,
4646
for (auto i = 0; i < dim; i++) {
4747
outer_nums *= input_dim[i];
4848
}
49-
50-
auto index_size = index.dims()[0];
49+
int index_size = 1;
50+
auto index_dim = index.dims();
51+
for (int i = 0; i < index_dim.size(); i++) {
52+
index_size *= index.dims()[i];
53+
}
54+
// auto index_size = index.dims()[0];
5155

5256
std::vector<T> input_vec;
5357
std::vector<IndexT> index_vec;
@@ -179,7 +183,18 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
179183
outer_nums *= input_dim[i];
180184
}
181185

182-
auto index_size = index->dims()[0];
186+
// auto index_size = index->dims()[0];cmake .. -DPY_VERSION=3.7 -DWITH_GPU=ON
187+
// -DWITH_TESTING=ON -DCMAKE_BUILD_TYPE=Release
188+
int index_size = index->numel();
189+
190+
for (int i = 0; i < index_size; i++) {
191+
PADDLE_ENFORCE_LE(
192+
index_data[i], input_dim[dim],
193+
platform::errors::InvalidArgument(
194+
"Element of index should be less than %d, but received %d.",
195+
input_dim[dim], index_data[dim]));
196+
}
197+
183198
VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums
184199
<< "; slice_size: " << slice_size << "; input_width: " << input_width
185200
<< "; output_width: " << output_width

python/paddle/fluid/tests/unittests/test_index_select_op.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from op_test import OpTest
2222
import paddle.fluid as fluid
2323
from paddle.fluid import Program, program_guard
24+
from functools import reduce
2425

2526

2627
class TestIndexSelectOp(OpTest):
@@ -131,5 +132,99 @@ def test_dygraph_api(self):
131132
self.assertTrue(np.allclose(expect_out, np_z))
132133

133134

135+
class TestTensorIndex(unittest.TestCase):
136+
def numel(self, shape):
137+
return reduce(lambda x, y: x * y, shape)
138+
139+
def test_dygraph(self):
140+
paddle.disable_static()
141+
142+
inps_shape = [7, 6, 5, 4, 3]
143+
array = np.arange(self.numel(inps_shape)).reshape(inps_shape)
144+
145+
index_shape = [7, 6, 5, 4, 3]
146+
index = np.arange(self.numel(index_shape)).reshape(index_shape)
147+
for i in range(len(inps_shape) - 1):
148+
149+
pt = paddle.to_tensor(array)
150+
index_mod = index % (array.shape[0])
151+
pindex = paddle.to_tensor(index_mod)
152+
getitem_np = array[index_mod]
153+
getitem_pp = pt[pindex]
154+
self.assertTrue(np.array_equal(getitem_np, getitem_pp.numpy()))
155+
array = array[0]
156+
index = index[0]
157+
158+
def test_static_graph(self):
159+
paddle.enable_static()
160+
inps_shape = [1, 2, 3, 4, 5]
161+
array = np.arange(
162+
self.numel(inps_shape), dtype='float32').reshape(inps_shape)
163+
164+
index_shape = [7, 6, 5, 4, 3]
165+
index = np.arange(self.numel(index_shape)).reshape(index_shape)
166+
167+
for i in range(len(inps_shape) - 1):
168+
index_mod = index % (array.shape[0])
169+
program = paddle.static.Program()
170+
171+
with paddle.static.program_guard(program):
172+
x = paddle.static.data(
173+
name='x', shape=array.shape, dtype='float32')
174+
if i % 2 == 0:
175+
index_dtype = 'int32'
176+
index_mod = index_mod.astype('int32')
177+
else:
178+
index_dtype = 'int64'
179+
index_mod = index_mod.astype('int64')
180+
181+
index_p = paddle.static.data(
182+
name='index', shape=index.shape, dtype=index_dtype)
183+
184+
y = x[index_p]
185+
186+
place = paddle.fluid.CPUPlace(
187+
) if not paddle.fluid.core.is_compiled_with_cuda(
188+
) else paddle.fluid.CUDAPlace(0)
189+
190+
prog = paddle.static.default_main_program()
191+
exe = paddle.static.Executor(place)
192+
193+
exe.run(paddle.static.default_startup_program())
194+
fetch_list = [y.name]
195+
getitem_pp = exe.run(
196+
prog,
197+
feed={x.name: array,
198+
index_p.name: index_mod},
199+
fetch_list=fetch_list)
200+
201+
getitem_np = array[index_mod]
202+
self.assertTrue(np.array_equal(getitem_np, getitem_pp[0]))
203+
204+
array = array[0]
205+
index = index[0]
206+
207+
def test_backward(self):
208+
paddle.disable_static()
209+
array = np.arange(4 * 3 * 2, dtype='float32').reshape([4, 3, 2])
210+
index = [[1, 2], [0, 3]]
211+
212+
index_p = paddle.to_tensor(index)
213+
pt = paddle.to_tensor(array, stop_gradient=False)
214+
215+
y = pt[index_p]
216+
y = y * y
217+
loss = y.sum()
218+
loss.backward()
219+
grad_torch = np.array([[[0., 2.], [4., 6.], [8., 10.]],
220+
[[12., 14.], [16., 18.], [20., 22.]],
221+
[[24., 26.], [28., 30.],
222+
[32., 34.]], [[36., 38.], [40., 42.],
223+
[44., 46.]]])
224+
self.assertTrue(
225+
np.array_equal(pt.grad.numpy(), grad_torch),
226+
msg='grad of index_select_op:\n{}'.format(pt.grad.numpy()))
227+
228+
134229
if __name__ == '__main__':
135230
unittest.main()

0 commit comments

Comments
 (0)