Skip to content

Commit 21ef318

Browse files
hbwx24AnnaTrainingG
authored andcommitted
support numpy.ndarray index. (PaddlePaddle#35748)
* support numpy.ndarray index. * polish code.
1 parent d414e59 commit 21ef318

File tree

3 files changed

+33
-9
lines changed

3 files changed

+33
-9
lines changed

python/paddle/fluid/dygraph/varbase_patch_methods.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,9 @@ def contain_tensor(item):
553553
or isinstance(slice_item.step, Variable):
554554
return True
555555
else:
556-
if isinstance(slice_item,
557-
Variable) and Variable.dtype != paddle.bool:
556+
if isinstance(
557+
slice_item,
558+
(Variable, np.ndarray)) and Variable.dtype != paddle.bool:
558559
return True
559560
return False
560561

python/paddle/fluid/tests/unittests/test_var_base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,11 @@ def _test_list_index(self):
813813
[0., 0., 42., 42., 42., 0.]])
814814
self.assertTrue(np.array_equal(res, exp))
815815

816+
# case3:
817+
row = np.array([0, 1, 2])
818+
col = np.array([2, 1, 3])
819+
self.assertTrue(np.array_equal(array[row, col], x[row, col].numpy()))
820+
816821
def test_slice(self):
817822
with fluid.dygraph.guard():
818823
self._test_slice()
@@ -834,6 +839,10 @@ def test_slice(self):
834839
with self.assertRaises(IndexError):
835840
y = var[0 - self.shape[0] - 1]
836841

842+
with self.assertRaises(IndexError):
843+
mask = np.array([1, 0, 1, 0], dtype=bool)
844+
var[paddle.to_tensor([0, 1]), mask]
845+
837846
def test_var_base_to_np(self):
838847
with fluid.dygraph.guard():
839848
var = fluid.dygraph.to_variable(self.array)

python/paddle/fluid/variable_index.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class SliceInfo:
6767
def __init__(self):
6868
self.pre_shape = None
6969
self.indexes = []
70+
self.dtype = None
7071

7172
def update(self, index):
7273
if is_list_tuple(index, int) or isinstance(index, (
@@ -75,6 +76,14 @@ def update(self, index):
7576
if not isinstance(index, paddle.fluid.Variable):
7677
index = paddle.assign(index)
7778

79+
if self.dtype is None:
80+
self.dtype = index.dtype
81+
else:
82+
if index.dtype != self.dtype:
83+
raise IndexError(
84+
"Data type of Tensor/List index should be same. The current data type is {}, but the previous data type is {}.".
85+
format(index.dtype, self.dtype))
86+
7887
self.indexes.append(index)
7988

8089
if self.pre_shape is None:
@@ -214,6 +223,16 @@ def replace_ellipsis(var, item):
214223
return item
215224

216225

226+
def replace_ndarray(item):
227+
new_item = []
228+
for slice_item in item:
229+
if isinstance(slice_item, np.ndarray):
230+
new_item.append(paddle.assign(slice_item))
231+
else:
232+
new_item.append(slice_item)
233+
return new_item
234+
235+
217236
def replace_none(item):
218237
new_item = []
219238
none_axes = []
@@ -278,6 +297,7 @@ def _getitem_impl_(var, item):
278297
reverse_axes = []
279298

280299
use_strided_slice = False
300+
item = replace_ndarray(item)
281301
item, none_axes = replace_none(item)
282302
item = replace_ellipsis(var, item)
283303
slice_info = SliceInfo()
@@ -361,9 +381,6 @@ def _getitem_impl_(var, item):
361381
idx = assign(np.array(slice_item).astype("int32"))
362382
return index_select(var, index=idx, axis=0)
363383

364-
elif isinstance(slice_item, np.ndarray):
365-
slice_info.update(slice_item)
366-
continue
367384
elif isinstance(slice_item, (Variable)):
368385
if len(item) == 1:
369386

@@ -499,6 +516,7 @@ def _setitem_impl_(var, item, value):
499516
ends = []
500517
steps = []
501518

519+
item = replace_ndarray(item)
502520
item, none_axes = replace_none(item)
503521
item = replace_ellipsis(var, item)
504522
slice_info = SliceInfo()
@@ -556,10 +574,6 @@ def _setitem_impl_(var, item, value):
556574
idx_tensor = assign(slice_item)
557575
return set_value_for_bool_tensor(var, idx_tensor, value)
558576

559-
elif isinstance(slice_item, np.ndarray):
560-
slice_info.update(slice_item)
561-
continue
562-
563577
elif isinstance(slice_item, Variable):
564578
if slice_item.dtype == core.VarDesc.VarType.BOOL:
565579
if len(item) != 1:

0 commit comments

Comments
 (0)