32 changes: 14 additions & 18 deletions paddle/fluid/pybind/slice_utils.h
@@ -782,6 +782,7 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
}

AdvancedIndex ad = AdvancedIndex(tensor, indices_int64);
const bool is_combined = false;
const bool accumulate = false;

return index_elementwise_get_ad_func(tensor,
@@ -791,7 +792,8 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
ad.indexed_sizes,
ad.indexed_strides,
slice_offset,
accumulate);
accumulate,
is_combined);
} else {
if (bool_index.shape().size() == 1)
return gather_ad_func(tensor, bool_2_idx);
@@ -1238,23 +1240,17 @@ static void ApplyGetitem(const int index_size,
&transed_index_int64);

AdvancedIndex ad = AdvancedIndex(*transed_tensor, transed_index_int64);
if (index_size == 1) {
paddle::Tensor flattened_tensor =
flatten_ad_func((*transed_index)[0], 0, -1);
*out = gather_ad_func(*transed_tensor, flattened_tensor);
*out = reshape_ad_func(*out, ad.src_sizes);
} else {
const bool accumulate = true;
*out = index_elementwise_get_ad_func(*self_tensor,
ad.indices,
ad.src_sizes,
ad.src_strides,
ad.indexed_sizes,
ad.indexed_strides,
slice_offset,
accumulate);
}

// is_combined distinguishes a single plain index (false) from combined
// indexing (true); a single plain index lets the backward pass use the
// faster IndexPutWithSortKernel.
const bool is_combined = (index_size != 1);
Contributor

What does is_combined mean? Please add a comment explaining it.

Contributor Author

is_combined distinguishes a plain index from combined indexing. When there is only a single plain index, the backward pass uses the more performant IndexPutWithSortKernel. A comment has been added.
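To illustrate the distinction (an editorial sketch, not code from this PR; all names below are hypothetical): with a single plain index, sorting the gather positions groups duplicate destination indices together, which is what lets a sort-based backward kernel accumulate gradients without scattered atomic writes.

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical CPU sketch of the sort-based accumulation idea behind the
// single-plain-index backward path; not Paddle's IndexPutWithSortKernel.
void AccumulateGradBySortedIndex(const std::vector<int64_t>& index,
                                 const std::vector<float>& out_grad,
                                 std::vector<float>* x_grad) {
  // Sort gather positions by index value (the role RadixSortPairs plays on
  // GPU), so duplicate destination indices become adjacent.
  std::vector<size_t> order(index.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](size_t a, size_t b) { return index[a] < index[b]; });
  // Each run of equal indices can now be reduced sequentially; a GPU kernel
  // would assign one segment per run instead of using atomicAdd.
  for (size_t k = 0; k < order.size(); ++k) {
    (*x_grad)[static_cast<size_t>(index[order[k]])] += out_grad[order[k]];
  }
}
```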

const bool accumulate = true;
*out = index_elementwise_get_ad_func(*self_tensor,
ad.indices,
ad.src_sizes,
ad.src_strides,
ad.indexed_sizes,
ad.indexed_strides,
slice_offset,
accumulate,
is_combined);
return;
} else {
paddle::Tensor transed_advanced_index_tensor;
1 change: 1 addition & 0 deletions paddle/phi/infermeta/backward.cc
@@ -2168,6 +2168,7 @@ void IndexElementwiseGetGradInferMeta(
const std::vector<int64_t>& index_strides,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
MetaTensor* x_grad) {
if (x_grad) {
x_grad->share_meta(x);
1 change: 1 addition & 0 deletions paddle/phi/infermeta/backward.h
@@ -788,5 +788,6 @@ void IndexElementwiseGetGradInferMeta(
const std::vector<int64_t>& index_strides,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
MetaTensor* x_grad);
} // namespace phi
1 change: 1 addition & 0 deletions paddle/phi/infermeta/binary.cc
@@ -2599,6 +2599,7 @@ void IndexElementwiseGetInferMeta(const MetaTensor& x,
const std::vector<int64_t>& index_stride,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
MetaTensor* out) {
out->set_dims(common::make_ddim(input_dims));
out->set_dtype(x.dtype());
1 change: 1 addition & 0 deletions paddle/phi/infermeta/binary.h
@@ -493,6 +493,7 @@ void IndexElementwiseGetInferMeta(const MetaTensor& x,
const std::vector<int64_t>& index_stride,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
MetaTensor* out);

void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
@@ -131,6 +131,7 @@ void IndexElementwiseGetGradKernel(const Context& dev_ctx,
const std::vector<int64_t>& index_strides,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
auto dxt = phi::EigenVector<T>::Flatten(*x_grad);
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/index_elementwise_get_kernel.cc
@@ -100,6 +100,7 @@ void IndexElementwiseGetKernel(const Context& dev_ctx,
const std::vector<int64_t>& index_stride,
const int64_t slice_offset,
const bool accumulate,
const bool is_combined,
DenseTensor* out) {
const auto& index_type = index[0]->dtype();
PADDLE_ENFORCE_EQ(index_type == phi::DataType::INT64,
117 changes: 117 additions & 0 deletions paddle/phi/kernels/funcs/radix_sort.cu
@@ -0,0 +1,117 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/funcs/radix_sort.h"
#include "paddle/phi/common/memory_utils.h"

namespace phi {
namespace funcs {

namespace {

template <typename T>
struct CudaType {
using type = T;
};

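// Map int64_t to long long so the CUB instantiation stays well-defined
// whether the platform defines int64_t as long or long long.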
template <>
struct CudaType<int64_t> {
using type = long long; // NOLINT
};

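// CUB device-wide algorithms are two-phase: the first call, made with a null
// temp-storage pointer, only computes temp_storage_bytes; the second call
// performs the actual sort using the allocated scratch buffer.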
#define PADDLE_CUB_WRAPPER(func, ...) \
do { \
size_t temp_storage_bytes = 0; \
func(nullptr, temp_storage_bytes, __VA_ARGS__); \
auto temp_storage = \
phi::memory_utils::Alloc(dev_ctx.GetPlace(), temp_storage_bytes); \
func(temp_storage->ptr(), temp_storage_bytes, __VA_ARGS__); \
} while (0)

} // namespace

template <typename key_t, int value_size>
void RadixSortPairsImpl(const phi::GPUContext& dev_ctx,
const key_t* keys_in,
key_t* keys_out,
const OpaqueTypeRadix<value_size>* values_in,
OpaqueTypeRadix<value_size>* values_out,
int64_t n,
bool descending,
int64_t begin_bit,
int64_t end_bit) {
PADDLE_ENFORCE_LE(
n,
std::numeric_limits<int>::max(),
phi::errors::InvalidArgument(
"CUB sort does not support sorting more than INT_MAX elements"));

using key_t_ = typename CudaType<key_t>::type;

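// CUB's SortPairs writes sorted keys to a separate output buffer; if the
// caller passed no keys_out, allocate scratch space for it here.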
phi::Allocator::AllocationPtr keys_out_owner;
if (keys_out == nullptr) {
keys_out_owner =
phi::memory_utils::Alloc(dev_ctx.GetPlace(), n * sizeof(key_t));
keys_out = reinterpret_cast<key_t*>(keys_out_owner->ptr());
}

const key_t_* keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
key_t_* keys_out_ = reinterpret_cast<key_t_*>(keys_out);

if (descending) {
PADDLE_CUB_WRAPPER(cub::DeviceRadixSort::SortPairsDescending,
keys_in_,
keys_out_,
values_in,
values_out,
static_cast<int>(n),
begin_bit,
end_bit,
dev_ctx.stream());
} else {
PADDLE_CUB_WRAPPER(cub::DeviceRadixSort::SortPairs,
keys_in_,
keys_out_,
values_in,
values_out,
static_cast<int>(n),
begin_bit,
end_bit,
dev_ctx.stream());
}
}

#define INSTANTIATE_SORT_PAIRS(key_t, value_size) \
template void RadixSortPairsImpl<key_t, value_size>( \
const phi::GPUContext&, \
const key_t*, \
key_t*, \
const OpaqueTypeRadix<value_size>*, \
OpaqueTypeRadix<value_size>*, \
int64_t, \
bool, \
int64_t, \
int64_t);

INSTANTIATE_SORT_PAIRS(int32_t, 1)
INSTANTIATE_SORT_PAIRS(int32_t, 2)
INSTANTIATE_SORT_PAIRS(int32_t, 4)
INSTANTIATE_SORT_PAIRS(int64_t, 1)
INSTANTIATE_SORT_PAIRS(int64_t, 2)
INSTANTIATE_SORT_PAIRS(int64_t, 4)
INSTANTIATE_SORT_PAIRS(int32_t, 8)
INSTANTIATE_SORT_PAIRS(int64_t, 8)

} // namespace funcs
} // namespace phi
80 changes: 80 additions & 0 deletions paddle/phi/kernels/funcs/radix_sort.h
@@ -0,0 +1,80 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cub/cub.cuh>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {
namespace funcs {

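// Fixed-size opaque byte blob used to type-erase sort values: RadixSortPairs
// reinterprets any trivially copyable value_t as OpaqueTypeRadix<sizeof(value_t)>,
// so radix_sort.cu needs one instantiation per value size rather than per type.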
template <int kValueSize>
struct OpaqueTypeRadix {
uint8_t data[kValueSize];
__device__ __host__ OpaqueTypeRadix() = default;
};

template <typename key_t, int kValueSize>
void RadixSortPairsImpl(const phi::GPUContext& dev_ctx,
const key_t* keys_in,
key_t* keys_out,
const OpaqueTypeRadix<kValueSize>* values_in,
OpaqueTypeRadix<kValueSize>* values_out,
int64_t n,
bool descending = false,
int64_t begin_bit = 0,
int64_t end_bit = sizeof(key_t) * 8);

template <typename key_t, typename value_t>
void RadixSortPairs(const phi::GPUContext& dev_ctx,
const key_t* keys_in,
key_t* keys_out,
const value_t* values_in,
value_t* values_out,
int64_t n,
bool descending = false,
int64_t begin_bit = 0,
int64_t end_bit = sizeof(key_t) * 8) {
PADDLE_ENFORCE_EQ(
std::is_trivially_copyable<value_t>::value,
true,
phi::errors::InvalidArgument(
"RadixSortPairs value type must be trivially copyable"));

using opaque_t = OpaqueTypeRadix<sizeof(value_t)>;
PADDLE_ENFORCE_EQ(
sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0,
true,
phi::errors::InvalidArgument(
"Unsupported value_t size (must be 1, 2, 4, or 8 bytes)"));
PADDLE_ENFORCE_EQ(
sizeof(value_t),
alignof(value_t),
phi::errors::InvalidArgument("Expected value_t to be size-aligned"));

RadixSortPairsImpl<key_t, sizeof(value_t)>(
dev_ctx,
keys_in,
keys_out,
reinterpret_cast<const opaque_t*>(values_in),
reinterpret_cast<opaque_t*>(values_out),
n,
descending,
begin_bit,
end_bit);
}

} // namespace funcs
} // namespace phi
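For reference, a minimal usage sketch of the new RadixSortPairs helper (editorial, not part of the PR; assumes a live GPUContext and valid device buffers of n elements each):

```cpp
#include "paddle/phi/kernels/funcs/radix_sort.h"

// Hypothetical call site: int32_t values are 4 bytes, a power of two, and
// size-aligned, so the wrapper routes them through OpaqueTypeRadix<4> and
// the pre-instantiated RadixSortPairsImpl<int64_t, 4>.
void SortByKeyExample(const phi::GPUContext& dev_ctx,
                      const int64_t* d_keys_in,
                      int64_t* d_keys_out,
                      const int32_t* d_vals_in,
                      int32_t* d_vals_out,
                      int64_t n) {
  phi::funcs::RadixSortPairs<int64_t, int32_t>(
      dev_ctx, d_keys_in, d_keys_out, d_vals_in, d_vals_out, n);
}
```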