diff --git a/paddle/fluid/pybind/slice_utils.h b/paddle/fluid/pybind/slice_utils.h
index 918d2eeae4272a..f4eef3af16bcf1 100644
--- a/paddle/fluid/pybind/slice_utils.h
+++ b/paddle/fluid/pybind/slice_utils.h
@@ -25,6 +25,7 @@
 #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/scope_guard.h"
+#include "paddle/fluid/operators/common_infer_shape_functions.h"
 #include "paddle/fluid/operators/utils.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/compat/convert_utils.h"
@@ -483,13 +484,15 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
     i++;
   }
 
-  auto bool_2_idx = nonzero_ad_func(bool_index);
-
   const phi::distributed::ProcessMesh* mesh = nullptr;
-  if (InputsContainDistTensor(&mesh, tensor, bool_2_idx)) {
-    ConvertAllInputsToDistTensor(mesh, tensor, bool_2_idx);
+  if (InputsContainDistTensor(&mesh, tensor, bool_index)) {
+    ConvertAllInputsToDistTensor(mesh, tensor, bool_index);
   }
 
+  if (bool_index.shape().size() == tensor_shape.size()) {
+    return masked_select_ad_func(tensor, bool_index);
+  }
+  auto bool_2_idx = nonzero_ad_func(bool_index);
   return gather_nd_ad_func(tensor, bool_2_idx);
 }
 
@@ -504,10 +507,30 @@ static void ParseBoolAndBroadcastIndices(
     }
   }
   if (advanced_index->size() > 1) {
-    // Here advanced_index has been checked ContainDistTensor
-    // and transed in dealWithAdvancedIndex
-    auto broadcasted_index = broadcast_tensors_ad_func(*advanced_index);
-    advanced_index->assign(broadcasted_index.begin(), broadcasted_index.end());
+    bool need_broadcast = false;
+    common::DDim common_shape = common::make_ddim((*advanced_index)[0].shape());
+    for (size_t i = 1; i < advanced_index->size(); ++i) {
+      common::DDim current_shape =
+          common::make_ddim((*advanced_index)[i].shape());
+      if (current_shape != common_shape) {
+        need_broadcast = true;
+        common_shape = operators::details::BroadcastTwoDims(
+            current_shape, common_shape, -1);
+      }
+    }
+
+    if (need_broadcast) {
+      // Here advanced_index has been checked ContainDistTensor
+      // and transed in dealWithAdvancedIndex
+      auto common_shape_vec = common::vectorize(common_shape);
+      for (size_t i = 0; i < advanced_index->size(); ++i) {
+        auto current_shape = (*advanced_index)[i].shape();
+        if (current_shape != common_shape_vec) {
+          (*advanced_index)[i] =
+              expand_ad_func((*advanced_index)[i], common_shape_vec);
+        }
+      }
+    }
   }
 }
 
diff --git a/paddle/phi/kernels/cpu/masked_select_grad_kernel.cc b/paddle/phi/kernels/cpu/masked_select_grad_kernel.cc
index fa120de4b79521..49b1de9446c3e4 100644
--- a/paddle/phi/kernels/cpu/masked_select_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/masked_select_grad_kernel.cc
@@ -100,7 +100,15 @@ PD_REGISTER_KERNEL(masked_select_grad,
                    CPU,
                    ALL_LAYOUT,
                    phi::MaskedSelectGradKernel,
+                   bool,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int8_t,
+                   int64_t,
+                   int16_t,
+                   uint8_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/masked_select_kernel.cc b/paddle/phi/kernels/cpu/masked_select_kernel.cc
index 8e9e3bbebecd4d..7c7c134248bd4a 100644
--- a/paddle/phi/kernels/cpu/masked_select_kernel.cc
+++ b/paddle/phi/kernels/cpu/masked_select_kernel.cc
@@ -87,9 +87,17 @@ PD_REGISTER_KERNEL(masked_select,
                    CPU,
                    ALL_LAYOUT,
                    phi::MaskedSelectKernel,
+                   bool,
                    float,
                    double,
                    int,
-                   int64_t) {
+                   int8_t,
+                   int64_t,
+                   int16_t,
+                   uint8_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {
   kernel->InputAt(1).SetDataType(phi::DataType::BOOL);
 }
diff --git a/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu b/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu
index 4bf5949f084fe5..0e717ecc13ff8d 100644
--- a/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu
@@ -108,9 +108,15 @@ PD_REGISTER_KERNEL(masked_select_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::MaskedSelectGradKernel,
+                   bool,
                    float,
                    double,
                    int,
+                   int8_t,
                    int64_t,
+                   int16_t,
+                   uint8_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/masked_select_kernel.cu b/paddle/phi/kernels/gpu/masked_select_kernel.cu
index 9739f9799a4ec1..0bf8a8789d0a18 100644
--- a/paddle/phi/kernels/gpu/masked_select_kernel.cu
+++ b/paddle/phi/kernels/gpu/masked_select_kernel.cu
@@ -94,11 +94,17 @@ PD_REGISTER_KERNEL(masked_select,
                    GPU,
                    ALL_LAYOUT,
                    phi::MaskedSelectKernel,
+                   bool,
                    float,
                    double,
                    int,
+                   int8_t,
                    int64_t,
+                   int16_t,
+                   uint8_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {
   kernel->InputAt(1).SetDataType(phi::DataType::BOOL);
 }
diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py
index f47afddde84f0a..0df9ebc5513dac 100644
--- a/python/paddle/base/variable_index.py
+++ b/python/paddle/base/variable_index.py
@@ -134,7 +134,8 @@ def get_value_for_bool_tensor(var, item):
                 )
             )
             i += 1
-
+    if len(item.shape) == len(var.shape):
+        return paddle.masked_select(var, item)
     bool_2_idx = paddle.nonzero(item)
     return paddle.gather_nd(var, bool_2_idx)
 
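Note: a minimal Python sketch of the dispatch this patch introduces in get_value_for_bool_tensor / getValueForBoolTensor, using existing public APIs (paddle.masked_select, paddle.nonzero, paddle.gather_nd); the tensor values and shapes below are made-up illustrations, not taken from the patch. A boolean mask with the same rank as the indexed tensor can be answered by masked_select directly; a lower-rank mask still takes the nonzero + gather_nd path.

    import paddle

    # Illustrative input, not from the patch.
    x = paddle.arange(6, dtype='float32').reshape([2, 3])  # [[0., 1., 2.], [3., 4., 5.]]

    # Mask with the same rank as x: the new masked_select fast path.
    full_mask = x > 2.0
    print(x[full_mask])                        # [3., 4., 5.]
    print(paddle.masked_select(x, full_mask))  # same result

    # Mask with lower rank than x: still resolved via nonzero + gather_nd,
    # selecting whole rows of x.
    row_mask = paddle.to_tensor([True, False])
    print(x[row_mask])                                    # [[0., 1., 2.]]
    print(paddle.gather_nd(x, paddle.nonzero(row_mask)))  # same result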