@@ -218,6 +218,26 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem):
218218 )
219219
220220
221+ def slice_is_same_to_original (start , end , step ):
222+ if start is None and end is None and step is None :
223+ return True
224+
225+ # If there is Variable, we cannot determine whether it is the same to original.
226+ if isinstance (
227+ start , (paddle .base .Variable , paddle .pir .Value , paddle .pir .OpResult )
228+ ):
229+ return False
230+ if isinstance (
231+ end , (paddle .base .Variable , paddle .pir .Value , paddle .pir .OpResult )
232+ ):
233+ return False
234+ if isinstance (
235+ step , (paddle .base .Variable , paddle .pir .Value , paddle .pir .OpResult )
236+ ):
237+ return False
238+ return start == 0 and end == MAX_INTEGER and step == 1
239+
240+
221241def parse_index (x , indices ):
222242 advanced_index = [None ] * 2 * len (x .shape ) # content is (dim, index)
223243 # for set_value / slice / strided_slice OP
@@ -282,18 +302,26 @@ def parse_index(x, indices):
282302 start = slice_item .start
283303 end = slice_item .stop
284304 step = slice_item .step
285- estimated_dim += 1
286- dim += 1
305+
287306 if start is None and end is None and step is None :
307+ estimated_dim += 1
308+ dim += 1
288309 continue
289310
290311 step = 1 if step is None else step
291312 if start is None :
292313 start = 0 if step > 0 else MAX_INTEGER
293314 if end is None :
294315 end = MAX_INTEGER if step > 0 else - 1
295- if x .shape [dim ] != - 1 and end >= x .shape [dim ]:
296- end = MAX_INTEGER
316+
317+ if not (
318+ isinstance (end , (paddle .base .Variable , paddle .pir .Value ))
319+ or isinstance (step , (paddle .base .Variable , paddle .pir .Value ))
320+ ):
321+ if x .shape [dim ] != - 1 and end >= x .shape [dim ]:
322+ end = MAX_INTEGER if step > 0 else - 1
323+ estimated_dim += 1
324+ dim += 1
297325
298326 elif isinstance (slice_item , (list , tuple )):
299327 advanced_index [estimated_dim ] = (
@@ -357,10 +385,7 @@ def parse_index(x, indices):
357385 slice_item
358386 )
359387 )
360- if not (
361- (start is None or end is None or step is None )
362- or (start == 0 and end == MAX_INTEGER and step == 1 )
363- ):
388+ if not slice_is_same_to_original (start , end , step ):
364389 starts .append (start )
365390 ends .append (end )
366391 steps .append (step )
0 commit comments