diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py index f47afddde84f0a..87f9ba720abb63 100644 --- a/python/paddle/base/variable_index.py +++ b/python/paddle/base/variable_index.py @@ -218,6 +218,26 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem): ) +def slice_is_same_to_original(start, end, step): + if start is None and end is None and step is None: + return True + + # If there is Variable, we cannot determine whether it is the same to original. + if isinstance( + start, (paddle.base.Variable, paddle.pir.Value, paddle.pir.OpResult) + ): + return False + if isinstance( + end, (paddle.base.Variable, paddle.pir.Value, paddle.pir.OpResult) + ): + return False + if isinstance( + step, (paddle.base.Variable, paddle.pir.Value, paddle.pir.OpResult) + ): + return False + return start == 0 and end == MAX_INTEGER and step == 1 + + def parse_index(x, indices): advanced_index = [None] * 2 * len(x.shape) # content is (dim, index) # for set_value / slice / strided_slice OP @@ -282,9 +302,10 @@ def parse_index(x, indices): start = slice_item.start end = slice_item.stop step = slice_item.step - estimated_dim += 1 - dim += 1 + if start is None and end is None and step is None: + estimated_dim += 1 + dim += 1 continue step = 1 if step is None else step @@ -293,6 +314,16 @@ def parse_index(x, indices): if end is None: end = MAX_INTEGER if step > 0 else -1 + if not ( + is_tensor_array + or isinstance(end, (paddle.base.Variable, paddle.pir.Value)) + or isinstance(step, (paddle.base.Variable, paddle.pir.Value)) + ): + if x.shape[dim] != -1 and end >= x.shape[dim]: + end = MAX_INTEGER if step > 0 else -1 + estimated_dim += 1 + dim += 1 + elif isinstance(slice_item, (list, tuple)): advanced_index[estimated_dim] = ( estimated_dim, @@ -355,7 +386,7 @@ def parse_index(x, indices): slice_item ) ) - if not (start is None or end is None or step is None): + if not slice_is_same_to_original(start, end, step): starts.append(start) ends.append(end) steps.append(step)