Skip to content

Commit a225636

Browse files
authored
[static getitem]Support index is list bool for getitem in static mode (#33298)
1 parent 11b5776 commit a225636

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

python/paddle/fluid/tests/unittests/test_variable.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,34 @@ def _test_slice_index_ellipsis(self, place):
245245
with self.assertRaises(TypeError):
246246
res = x[[1.2, 0]]
247247

248+
def _test_slice_index_list_bool(self, place):
249+
data = np.random.rand(2, 3).astype("float32")
250+
prog = paddle.static.Program()
251+
with paddle.static.program_guard(prog):
252+
x = paddle.assign(data)
253+
idx0 = [True, False]
254+
idx1 = [False, True]
255+
idx2 = [False, False]
256+
idx3 = [True, True]
257+
258+
out0 = x[idx0]
259+
out1 = x[idx1]
260+
out2 = x[idx2]
261+
out3 = x[idx3]
262+
263+
exe = paddle.static.Executor(place)
264+
result = exe.run(prog, fetch_list=[out0, out1, out2, out3])
265+
266+
expected = [data[idx0], data[idx1], data[idx2], data[idx3]]
267+
268+
self.assertTrue((result[0] == expected[0]).all())
269+
self.assertTrue((result[1] == expected[1]).all())
270+
self.assertTrue((result[2] == expected[2]).all())
271+
self.assertTrue((result[3] == expected[3]).all())
272+
273+
with self.assertRaises(TypeError):
274+
res = x[[True, 0]]
275+
248276
def test_slice(self):
249277
places = [fluid.CPUPlace()]
250278
if core.is_compiled_with_cuda():
@@ -255,6 +283,7 @@ def test_slice(self):
255283
self._test_slice_index_tensor(place)
256284
self._test_slice_index_list(place)
257285
self._test_slice_index_ellipsis(place)
286+
self._test_slice_index_list_bool(place)
258287

259288
def _tostring(self):
260289
b = default_main_program().current_block()

python/paddle/fluid/variable_index.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,19 +140,36 @@ def _getitem_impl_(var, item):
140140
end = MAX_INTEGER if end is None else end
141141

142142
elif isinstance(slice_item, list):
143+
is_bool_list = False
143144
for i in slice_item:
144-
if not isinstance(i, int):
145-
raise TypeError("Only support int value in list")
145+
if not isinstance(i, (int, bool)):
146+
raise TypeError("Only support int or bool in index list.")
147+
148+
if isinstance(i, bool):
149+
is_bool_list = True
150+
break
146151

147152
if len(item) != 1:
148153
raise IndexError(
149154
"When index contains a list, its length must be 1, but received {}".
150155
format(len(item)))
151156

157+
if is_bool_list:
158+
new_slice_item = []
159+
for idx, ele in enumerate(slice_item):
160+
if not isinstance(ele, bool):
161+
raise TypeError(
162+
"Mixed bool index with other types is not supported."
163+
)
164+
165+
if ele is True:
166+
new_slice_item.append(idx)
167+
slice_item = new_slice_item
168+
152169
from .layers import assign
153170
from ..tensor import index_select
154171

155-
idx = assign(np.array(slice_item))
172+
idx = assign(np.array(slice_item).astype("int32"))
156173
return index_select(var, index=idx, axis=0)
157174

158175
elif isinstance(slice_item, Variable):

0 commit comments

Comments
 (0)