diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc
index dfe41a7e006e96..40455c8fe55c2b 100644
--- a/paddle/fluid/pybind/eager_method.cc
+++ b/paddle/fluid/pybind/eager_method.cc
@@ -1601,19 +1601,10 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                                          &trans_dim,
                                          &out_is_view);
 
-    bool has_bool_index = false;
-    for (auto& index : transed_index) {
-      if (index.dtype() == phi::DataType::BOOL) {
-        has_bool_index = true;
-      }
-    }
     const int index_size = PyTuple_GET_SIZE(index_ptr);
-    const bool is_combined_bool = has_bool_index && index_size > 1;
-
     ApplyGetitem(index_size,
                  pos_of_new_dim,
                  rank_of_new_dim,
-                 is_combined_bool,
                  &transed_index,
                  &tensor,
                  &self->tensor,
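For context on the lines removed above: is_combined_bool marked getitem expressions that mix a boolean mask with other indices, and tensor__getitem_dygraph no longer needs to distinguish that case because the slice_utils.h changes below handle it uniformly. A minimal sketch of the two indexing flavors involved (illustrative Python; the tensor and mask here are my own, not fixtures from the PR):

import paddle

x = paddle.arange(24).reshape([2, 3, 4])
mask = paddle.to_tensor([True, False])

a = x[mask]         # plain bool index: keeps the rows where mask is True
b = x[mask, :, -1]  # bool mask combined with a slice and an int index;
                    # the case that previously set is_combined_bool
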
diff --git a/paddle/fluid/pybind/slice_utils.h b/paddle/fluid/pybind/slice_utils.h
index 07bea224531ff0..9253111dd5324b 100644
--- a/paddle/fluid/pybind/slice_utils.h
+++ b/paddle/fluid/pybind/slice_utils.h
@@ -17,6 +17,7 @@
 #include 
 #include 
+#include 
 #include "paddle/fluid/eager/api/all.h"
 #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
 #include "paddle/fluid/eager/utils.h"
@@ -30,6 +31,7 @@
 #include "paddle/phi/kernels/funcs/common_infer_shape_functions.h"
 #include "paddle/phi/kernels/funcs/slice_utils.h"
 #include "paddle/phi/kernels/funcs/strided_slice.h"
+#include "paddle/utils/pybind.h"
 #include "pybind11/numpy.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
@@ -519,21 +521,31 @@ static void ParseIndex(const paddle::Tensor& tensor,
         estimated_dim++;
       }
     } else {
+      *has_advanced_index = true;
       if (slice_tensor.dtype() == phi::DataType::BOOL) {
-        PADDLE_ENFORCE_EQ(slice_tensor.shape()[0],
-                          dim_len,
-                          common::errors::OutOfRange(
-                              "The shape of boolean index %d did not match"
-                              "indexed tensor %d along axis %d.",
-                              slice_tensor.shape()[0],
-                              dim_len,
-                              current_dim));
+        // bool tensor consumes (rank of index tensor) dimensions of input
+        // tensor
+        for (int i = 0; i < slice_tensor.shape().size(); i++) {
+          PADDLE_ENFORCE_EQ(slice_tensor.shape()[i],
+                            dim_len,
+                            common::errors::OutOfRange(
+                                "The shape of boolean index %d did not match"
+                                "indexed tensor %d along axis %d.",
+                                slice_tensor.shape()[0],
+                                dim_len,
+                                current_dim));
+          (*advanced_index_dim)[estimated_dim] = estimated_dim;
+          estimated_dim++;
+          current_dim++;
+          dim_len = shape[current_dim];
+        }
+      } else {
+        // int tensor consumes only one dimension of input tensor
+        (*advanced_index_dim)[estimated_dim] = estimated_dim;
+        estimated_dim++;
+        current_dim++;
       }
-      *has_advanced_index = true;
       advanced_index->push_back(std::move(slice_tensor));
-      (*advanced_index_dim)[estimated_dim] = estimated_dim;
-      estimated_dim++;
-      current_dim++;
     }
   } else {
@@ -648,17 +660,14 @@ static paddle::Tensor dealWithAdvancedIndex(
     int* rank_of_new_dim,
     std::vector<int>* trans_dim,
     bool* out_is_view) {
+  *rank_of_new_dim = 0;
   int p = 0;
-  bool int_tensor_only = true;
   for (size_t i = 0; i < advanced_index_dim->size(); ++i) {
     auto index_dim = (*advanced_index_dim)[i];
     if (index_dim != -1) {
-      // size of advanced_index is same to number of non -1 element in
-      // advanced_index_dim
+      // sum of each advanced_index_tensor's rank equals to number of non -1
+      // element in advanced_index_dim
      auto index = (*advanced_index)[p++];
-      if (index.dtype() == phi::DataType::BOOL) {
-        int_tensor_only = false;
-      }
 
       if (index_dim == 0) {
         // case 1: advanced indices at axis 0, the new dim will be at first.
@@ -671,11 +680,23 @@ static paddle::Tensor dealWithAdvancedIndex(
       } else {
         *pos_of_new_dim = std::min(index_dim, *pos_of_new_dim);
       }
-      *rank_of_new_dim =
-          std::max(*rank_of_new_dim, static_cast<int>(index.shape().size()));
-
-      trans_dim->push_back(index_dim);
-      transed_index->push_back(std::move(index));
+      if (index.dtype() == phi::DataType::BOOL) {
+        *rank_of_new_dim = std::max(*rank_of_new_dim, 1);
+        i--;
+        for (int j = 0; j < index.shape().size(); j++) {
+          i++;
+          index_dim = (*advanced_index_dim)[i];
+          trans_dim->push_back(index_dim);
+        }
+        transed_index->push_back(std::move(index));
+      } else {
+        *rank_of_new_dim =
+            std::max(*rank_of_new_dim, static_cast<int>(index.shape().size()));
+
+        trans_dim->push_back(index_dim);
+        transed_index->push_back(std::move(index));
+      }
     }
   }
@@ -695,8 +716,7 @@ static paddle::Tensor dealWithAdvancedIndex(
     transed_tensor = tensor;
   } else {
     *out_is_view = true;
-    if (FLAGS_use_stride_kernel && *pos_of_new_dim != 0 &&
-        (is_for_setitem || int_tensor_only)) {
+    if (FLAGS_use_stride_kernel && *pos_of_new_dim != 0) {
       transed_tensor = tensor;
     } else {
       transed_tensor = transpose_ad_func(tensor, *trans_dim);
@@ -731,9 +751,10 @@ static std::vector<paddle::Tensor> PrepareIndices(
 }
 
 static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
+                                            const paddle::Tensor& self_tensor,
                                             const paddle::Tensor& bool_index,
                                             const int64_t slice_offset,
-                                            const bool is_combined_bool) {
+                                            const int64_t pos_of_new_dim) {
   PADDLE_ENFORCE(bool_index.shape().size() <= tensor.shape().size(),
                  common::errors::InvalidArgument(
                      "The dims of bool index doesn't match indexed array, "
@@ -743,22 +764,37 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
                      bool_index.shape().size()));
   auto tensor_shape = tensor.shape();
   size_t i = 0;
-  while (i < bool_index.shape().size()) {
-    PADDLE_ENFORCE_EQ(
-        bool_index.shape()[i],
-        tensor_shape[i],
-        common::errors::OutOfRange(
-            "The dimension of bool index doesn't match indexed array along "
-            "dimension %d, the target dimension is %d, but received %d",
-            i,
-            tensor_shape[i],
-            bool_index.shape()[i]));
-    i++;
+  if (FLAGS_use_stride_kernel) {
+    while (i < bool_index.shape().size()) {
+      PADDLE_ENFORCE_EQ(
+          bool_index.shape()[i],
+          tensor_shape[i + pos_of_new_dim],
+          common::errors::OutOfRange(
+              "The dimension of bool index doesn't match indexed array along "
+              "dimension %d, the target dimension is %d, but received %d",
+              i,
+              tensor_shape[i + pos_of_new_dim],
+              bool_index.shape()[i]));
+      i++;
+    }
+  } else {
+    while (i < bool_index.shape().size()) {
+      PADDLE_ENFORCE_EQ(
+          bool_index.shape()[i],
+          tensor_shape[i],
+          common::errors::OutOfRange(
+              "The dimension of bool index doesn't match indexed array along "
+              "dimension %d, the target dimension is %d, but received %d",
+              i,
+              tensor_shape[i],
+              bool_index.shape()[i]));
+      i++;
+    }
   }
 
   const phi::distributed::ProcessMesh* mesh = nullptr;
-  if (InputsContainDistTensor(&mesh, tensor, bool_index)) {
-    ConvertAllInputsToDistTensor(mesh, tensor, bool_index);
+  if (InputsContainDistTensor(&mesh, tensor, self_tensor, bool_index)) {
+    ConvertAllInputsToDistTensor(mesh, tensor, self_tensor, bool_index);
   }
 
   if (bool_index.shape().size() == tensor_shape.size()) {
@@ -766,11 +802,14 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
   }
 
   auto bool_2_idx = nonzero_ad_func(bool_index);
-  if (FLAGS_use_stride_kernel && !is_combined_bool) {
+  if (FLAGS_use_stride_kernel) {
     std::vector<paddle::Tensor> indices =
         PrepareIndices(tensor, bool_2_idx, bool_index);
+    for (int i = 0; i < pos_of_new_dim; ++i) {
+      indices.insert(indices.begin(), paddle::Tensor());
+    }
     while (indices.size() < static_cast<size_t>(tensor.dims().size())) {
-      indices.emplace_back();
+      indices.emplace_back(paddle::Tensor());
     }
 
     std::vector<paddle::Tensor> indices_int64;
@@ -784,7 +823,7 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
     AdvancedIndex ad = AdvancedIndex(tensor, indices_int64);
     const bool accumulate = false;
 
-    return index_elementwise_get_ad_func(tensor,
+    return index_elementwise_get_ad_func(self_tensor,
                                          ad.indices,
                                          ad.src_sizes,
                                          ad.src_strides,
@@ -1172,7 +1211,6 @@ static void ApplySetitem(const std::vector<int> trans_dim,
 static void ApplyGetitem(const int index_size,
                          const int pos_of_new_dim,
                          const int rank_of_new_dim,
-                         const bool is_combined_bool,
                          std::vector<paddle::Tensor>* transed_index,
                          paddle::Tensor* tensor,
                          paddle::Tensor* self_tensor,
@@ -1201,9 +1239,18 @@ static void ApplyGetitem(const int index_size,
   if (transed_index->size() == 1 &&
       (*transed_index)[0].dtype() == phi::DataType::BOOL) {
     // get value for bool tensor
-    int64_t slice_offset = 0;
-    *out = getValueForBoolTensor(
-        *transed_tensor, (*transed_index)[0], slice_offset, is_combined_bool);
+    const int64_t slice_offset =
+        reinterpret_cast<const char*>(transed_tensor->data()) -
+        reinterpret_cast<const char*>(self_tensor->data());
+    *out = getValueForBoolTensor(*transed_tensor,
+                                 (*self_tensor),
+                                 (*transed_index)[0],
+                                 slice_offset,
+                                 pos_of_new_dim);
+    if (!FLAGS_use_stride_kernel) {
+      handle_transpose(*out);
+    }
+    return;
   } else {
     // get value for int tensor
     ParseBoolAndBroadcastIndices(transed_index);
@@ -1215,7 +1262,7 @@ static void ApplyGetitem(const int index_size,
     }
   }
 
-  if (FLAGS_use_stride_kernel && !is_combined_bool && !has_empty_index) {
+  if (FLAGS_use_stride_kernel && !has_empty_index) {
     const phi::distributed::ProcessMesh* mesh = nullptr;
     if (InputsContainDistTensor(
             &mesh, *self_tensor, *transed_tensor, *transed_index)) {
@@ -1223,6 +1270,7 @@ static void ApplyGetitem(const int index_size,
           mesh, *self_tensor, *transed_tensor, *transed_index);
     }
 
+    *transed_index = expandTensors(*transed_index);
     *transed_index = expand_outplace(*transed_index);
 
     std::vector<paddle::Tensor> transed_index_int64;
@@ -1277,7 +1325,6 @@ static void ApplyGetitem(const int index_size,
       return;
     }
   }
-  handle_transpose(*out);
 }
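The ParseIndex and dealWithAdvancedIndex changes above implement the rule stated in the new comments: a boolean index tensor of rank k consumes k dimensions of the indexed tensor, while an integer index tensor consumes exactly one. A short NumPy sketch of that rule as reference semantics (my example, not from the PR; after this patch Paddle's getitem is expected to match it):

import numpy as np

x = np.arange(30).reshape(2, 3, 5)

# An integer index tensor consumes one axis; its own shape is inserted
# in place of that axis.
int_idx = np.array([[0], [1]])          # shape (2, 1)
print(x[int_idx].shape)                 # (2, 1, 3, 5)

# A rank-2 boolean index consumes two axes; each True selects one
# (i, j) pair, so both axes collapse into one axis of length mask.sum().
mask = np.array([[True, False, False],
                 [False, True, True]])  # shape (2, 3), 3 True entries
print(x[mask].shape)                    # (3, 5)

This is why dealWithAdvancedIndex now pushes one trans_dim entry per consumed axis for a bool index while still pushing a single transed_index tensor.
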
diff --git a/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc b/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc
index 1c15f04f7dd411..9885dbec8ae781 100644
--- a/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc
@@ -28,6 +28,7 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad,
                    bool,
                    int64_t,
                    int16_t,
+                   int8_t,
                    int,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
diff --git a/paddle/phi/kernels/strided_slice_grad_kernel.cc b/paddle/phi/kernels/strided_slice_grad_kernel.cc
index 8c5c90783133c9..807fef9359d4e1 100644
--- a/paddle/phi/kernels/strided_slice_grad_kernel.cc
+++ b/paddle/phi/kernels/strided_slice_grad_kernel.cc
@@ -49,6 +49,8 @@ PD_REGISTER_KERNEL(strided_slice_grad,
                    phi::StridedSliceGradKernel,
                    bool,
                    int,
+                   int8_t,
+                   int16_t,
                    int64_t,
                    float,
                    double,
@@ -62,6 +64,8 @@ PD_REGISTER_KERNEL(strided_slice_grad,
                    phi::StridedSliceGradKernel,
                    bool,
                    int,
+                   int8_t,
+                   int16_t,
                    int64_t,
                    float,
                    double,
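The registration diffs above extend the strided-slice backward kernels to int8 (and int16 for strided_slice_grad), presumably to keep backward dtype coverage in line with the getitem paths this PR touches. A minimal sketch of the operation these kernels back (my example; an integer tensor carries no gradient, so this only exercises the small-integer dtype on the forward strided slice):

import numpy as np
import paddle

x = paddle.to_tensor(np.arange(12, dtype=np.int8).reshape(3, 4))
y = x[::2, 1::2]  # stepped basic slicing typically lowers to the
                  # strided-slice kernels
print(y.numpy())  # [[ 1  3]
                  #  [ 9 11]]
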
diff --git a/test/indexing/test_getitem_appendix.py b/test/indexing/test_getitem_appendix.py
index fdd012613065ac..45490300b75c7d 100644
--- a/test/indexing/test_getitem_appendix.py
+++ b/test/indexing/test_getitem_appendix.py
@@ -230,6 +230,16 @@ def test_combined(self):
         # case 6:
         # [[[4 , 5 ],[10, 11],[16, 17],[22, 23]]]
         self.accuracy_check(x[[True, False], :, -1], y[[True, False], :, -1])
+        # case 7:
+        # [[0, 3, 4, 5], [24, 26, 28, 29]]
+        index_np = np.array([[True, False], [False, True], [True, True]])
+        index_paddle = paddle.to_tensor(index_np)
+        self.accuracy_check(x[:, 0, index_np], y[:, 0, index_paddle])
+        # case 8:
+        # [[[[0, 1]], [[2, 3]], [[24, 25]], [[26, 27]]]]
+        index_np = np.array([[0], [1]])
+        index_paddle = paddle.to_tensor(index_np)
+        self.accuracy_check(x[:, 0, index_np], y[:, 0, index_paddle])
 
 
 class Test0DTensorIndexing(unittest.TestCase):
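Cases 7 and 8 assert NumPy parity for an index tensor that follows other indices. An end-to-end check in the same spirit (hedged sketch: the shapes and fixtures here are mine, not the x and y defined in the test file):

import numpy as np
import paddle

x_np = np.arange(24).reshape(2, 3, 4)
x_pd = paddle.to_tensor(x_np)

mask_np = np.array([[True, False, True, False],
                    [False, True, False, True],
                    [True, True, False, False]])  # shape (3, 4)
mask_pd = paddle.to_tensor(mask_np)

# The rank-2 mask consumes the trailing two axes; both libraries should
# produce a (2, 6) result with identical values.
np.testing.assert_array_equal(x_np[:, mask_np], x_pd[:, mask_pd].numpy())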