Skip to content

Commit a44e8a5

Browse files
committed
fix code
1 parent 8066878 commit a44e8a5

1 file changed

Lines changed: 33 additions & 8 deletions

File tree

python/paddle/base/variable_index.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,26 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem):
218218
)
219219

220220

221+
def slice_is_same_to_original(start, end, step):
222+
if start is None and end is None and step is None:
223+
return True
224+
225+
# If there is Variable, we cannot determine whether it is the same to original.
226+
if isinstance(
227+
start, (paddle.base.Variable, paddle.pir.Value, paddle.pir.OpResult)
228+
):
229+
return False
230+
if isinstance(
231+
end, (paddle.base.Variable, paddle.pir.Value, paddle.pir.OpResult)
232+
):
233+
return False
234+
if isinstance(
235+
step, (paddle.base.Variable, paddle.pir.Value, paddle.pir.OpResult)
236+
):
237+
return False
238+
return start == 0 and end == MAX_INTEGER and step == 1
239+
240+
221241
def parse_index(x, indices):
222242
advanced_index = [None] * 2 * len(x.shape) # content is (dim, index)
223243
# for set_value / slice / strided_slice OP
@@ -282,18 +302,26 @@ def parse_index(x, indices):
282302
start = slice_item.start
283303
end = slice_item.stop
284304
step = slice_item.step
285-
estimated_dim += 1
286-
dim += 1
305+
287306
if start is None and end is None and step is None:
307+
estimated_dim += 1
308+
dim += 1
288309
continue
289310

290311
step = 1 if step is None else step
291312
if start is None:
292313
start = 0 if step > 0 else MAX_INTEGER
293314
if end is None:
294315
end = MAX_INTEGER if step > 0 else -1
295-
if x.shape[dim] != -1 and end >= x.shape[dim]:
296-
end = MAX_INTEGER
316+
317+
if not (
318+
isinstance(end, (paddle.base.Variable, paddle.pir.Value))
319+
or isinstance(step, (paddle.base.Variable, paddle.pir.Value))
320+
):
321+
if x.shape[dim] != -1 and end >= x.shape[dim]:
322+
end = MAX_INTEGER if step > 0 else -1
323+
estimated_dim += 1
324+
dim += 1
297325

298326
elif isinstance(slice_item, (list, tuple)):
299327
advanced_index[estimated_dim] = (
@@ -357,10 +385,7 @@ def parse_index(x, indices):
357385
slice_item
358386
)
359387
)
360-
if not (
361-
(start is None or end is None or step is None)
362-
or (start == 0 and end == MAX_INTEGER and step == 1)
363-
):
388+
if not slice_is_same_to_original(start, end, step):
364389
starts.append(start)
365390
ends.append(end)
366391
steps.append(step)

0 commit comments

Comments
 (0)