Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,26 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem):
)


def slice_is_same_to_original(start, end, step):
if start is None and end is None and step is None:
return True

# If there is Variable, we cannot determine whether it is the same to original.
if isinstance(
start, (paddle.base.Variable, paddle.pir.Value, paddle.pir.OpResult)
):
return False
if isinstance(
end, (paddle.base.Variable, paddle.pir.Value, paddle.pir.OpResult)
):
return False
if isinstance(
step, (paddle.base.Variable, paddle.pir.Value, paddle.pir.OpResult)
):
return False
return start == 0 and end == MAX_INTEGER and step == 1


def parse_index(x, indices):
advanced_index = [None] * 2 * len(x.shape) # content is (dim, index)
# for set_value / slice / strided_slice OP
Expand Down Expand Up @@ -282,9 +302,10 @@ def parse_index(x, indices):
start = slice_item.start
end = slice_item.stop
step = slice_item.step
estimated_dim += 1
dim += 1

if start is None and end is None and step is None:
estimated_dim += 1
dim += 1
continue

step = 1 if step is None else step
Expand All @@ -293,6 +314,16 @@ def parse_index(x, indices):
if end is None:
end = MAX_INTEGER if step > 0 else -1

if not (
is_tensor_array
or isinstance(end, (paddle.base.Variable, paddle.pir.Value))
or isinstance(step, (paddle.base.Variable, paddle.pir.Value))
):
if x.shape[dim] != -1 and end >= x.shape[dim]:
end = MAX_INTEGER if step > 0 else -1
estimated_dim += 1
dim += 1

elif isinstance(slice_item, (list, tuple)):
advanced_index[estimated_dim] = (
estimated_dim,
Expand Down Expand Up @@ -355,7 +386,7 @@ def parse_index(x, indices):
slice_item
)
)
if not (start is None or end is None or step is None):
if not slice_is_same_to_original(start, end, step):
starts.append(start)
ends.append(end)
steps.append(step)
Expand Down