Skip to content

Commit 715b5a5

Browse files
committed
update slice-check
1 parent 85e6ece commit 715b5a5

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

paddle/fluid/pybind/slice_utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,6 +1240,11 @@ static void ApplyGetitem(const int index_size,
12401240
&transed_index_int64);
12411241

12421242
AdvancedIndex ad = AdvancedIndex(*transed_tensor, transed_index_int64);
1243+
// is_combined:
1244+
// Distinguishes between regular indexing (single index) and combined
1245+
// indexing (multiple indices). When false (single index case), enables
1246+
// optimized backward pass using IndexPutWithSortKernel for better
1247+
// performance.
12431248
const bool is_combined = (index_size == 1) ? false : true;
12441249
const bool accumulate = true;
12451250
*out = index_elementwise_get_ad_func(*self_tensor,

paddle/phi/ops/yaml/op_compat.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2032,9 +2032,10 @@
20322032
- op : index_elementwise_get
20332033
backward : index_elementwise_get_grad, index_elementwise_get_double_grad
20342034
inputs :
2035-
{x : x, index : index, input_dims : input_dims, input_strides : input_strides, index_dims : index_dims, index_stride : index_stride, slice_offset : slice_offset, accumulate : accumulate, is_combined : is_combined}
2035+
{x : x, index : index, input_dims : input_dims, input_strides : input_strides, index_dims : index_dims, index_stride : index_stride}
20362036
outputs :
20372037
out : Out
2038+
attrs : {slice_offset : slice_offset, accumulate : accumulate, is_combined : is_combined}
20382039

20392040
- op : index_sample
20402041
inputs :

0 commit comments

Comments
 (0)