Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/paddle/fluid/dygraph/varbase_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,8 +553,9 @@ def contain_tensor(item):
or isinstance(slice_item.step, Variable):
return True
else:
if isinstance(slice_item,
Variable) and Variable.dtype != paddle.bool:
if isinstance(
slice_item,
(Variable, np.ndarray)) and Variable.dtype != paddle.bool:
return True
return False

Expand Down
9 changes: 9 additions & 0 deletions python/paddle/fluid/tests/unittests/test_var_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,11 @@ def _test_list_index(self):
[0., 0., 42., 42., 42., 0.]])
self.assertTrue(np.array_equal(res, exp))

# case3:
row = np.array([0, 1, 2])
col = np.array([2, 1, 3])
self.assertTrue(np.array_equal(array[row, col], x[row, col].numpy()))

def test_slice(self):
with fluid.dygraph.guard():
self._test_slice()
Expand All @@ -834,6 +839,10 @@ def test_slice(self):
with self.assertRaises(IndexError):
y = var[0 - self.shape[0] - 1]

with self.assertRaises(IndexError):
mask = np.array([1, 0, 1, 0], dtype=bool)
var[paddle.to_tensor([0, 1]), mask]

def test_var_base_to_np(self):
with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array)
Expand Down
28 changes: 21 additions & 7 deletions python/paddle/fluid/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class SliceInfo:
def __init__(self):
self.pre_shape = None
self.indexes = []
self.dtype = None

def update(self, index):
if is_list_tuple(index, int) or isinstance(index, (
Expand All @@ -75,6 +76,14 @@ def update(self, index):
if not isinstance(index, paddle.fluid.Variable):
index = paddle.assign(index)

if self.dtype is None:
self.dtype = index.dtype
else:
if index.dtype != self.dtype:
raise IndexError(
"Data type of Tensor/List index should be same. The current data type is {}, but the previous data type is {}.".
format(index.dtype, self.dtype))

self.indexes.append(index)

if self.pre_shape is None:
Expand Down Expand Up @@ -214,6 +223,16 @@ def replace_ellipsis(var, item):
return item


def replace_ndarray(item):
new_item = []
for slice_item in item:
if isinstance(slice_item, np.ndarray):
new_item.append(paddle.assign(slice_item))
else:
new_item.append(slice_item)
return new_item


def replace_none(item):
new_item = []
none_axes = []
Expand Down Expand Up @@ -278,6 +297,7 @@ def _getitem_impl_(var, item):
reverse_axes = []

use_strided_slice = False
item = replace_ndarray(item)
item, none_axes = replace_none(item)
item = replace_ellipsis(var, item)
slice_info = SliceInfo()
Expand Down Expand Up @@ -361,9 +381,6 @@ def _getitem_impl_(var, item):
idx = assign(np.array(slice_item).astype("int32"))
return index_select(var, index=idx, axis=0)

elif isinstance(slice_item, np.ndarray):
slice_info.update(slice_item)
continue
elif isinstance(slice_item, (Variable)):
if len(item) == 1:

Expand Down Expand Up @@ -499,6 +516,7 @@ def _setitem_impl_(var, item, value):
ends = []
steps = []

item = replace_ndarray(item)
item, none_axes = replace_none(item)
item = replace_ellipsis(var, item)
slice_info = SliceInfo()
Expand Down Expand Up @@ -556,10 +574,6 @@ def _setitem_impl_(var, item, value):
idx_tensor = assign(slice_item)
return set_value_for_bool_tensor(var, idx_tensor, value)

elif isinstance(slice_item, np.ndarray):
slice_info.update(slice_item)
continue

elif isinstance(slice_item, Variable):
if slice_item.dtype == core.VarDesc.VarType.BOOL:
if len(item) != 1:
Expand Down