191 changes: 89 additions & 102 deletions paddle/phi/kernels/gpu/argsort_kernel.cu
@@ -96,52 +96,57 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
}
}
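// Note (illustrative, not from the original file): FillIndex writes the
// column index into every row, so for num_rows = 2 and num_cols = 3 the
// indices buffer becomes
//   [[0, 1, 2],
//    [0, 1, 2]],
// i.e. one 0..num_cols-1 ramp per segment, which is the initial value array
// fed to the segmented key/value sorts below.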

#define CUB_ARGSORT_WRAPPER(func, ...) \
{ \
size_t temp_storage_bytes = 0; \
PADDLE_ENFORCE_GPU_SUCCESS( \
func(nullptr, temp_storage_bytes, __VA_ARGS__)); \
DenseTensor temp_storage; \
int64_t temp_size = static_cast<int64_t>(temp_storage_bytes); \
PADDLE_ENFORCE_GT( \
temp_size, \
0, \
common::errors::InvalidArgument( \
"Argsort temp storage size is %d, but should be greater than 0.", \
temp_size)); \
temp_storage.Resize({temp_size}); \
ctx.template Alloc<uint8_t>(&temp_storage); \
PADDLE_ENFORCE_GPU_SUCCESS( \
func(temp_storage.data<uint8_t>(), temp_storage_bytes, __VA_ARGS__)); \
}

#define PREDICATE_CUB_ARGSORT(predicate, if_func, else_func, ...) \
if (predicate) \
CUB_ARGSORT_WRAPPER(if_func, __VA_ARGS__) \
else \
CUB_ARGSORT_WRAPPER(else_func, __VA_ARGS__)
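
// For reference, a minimal sketch of the two-pass CUB idiom the wrapper
// expands to (argument names are illustrative; `func` stands for either
// SortPairs or SortPairsDescending). The first call with a null workspace
// pointer only fills in temp_storage_bytes; the second call performs the
// actual segmented sort with the allocated workspace:
//
//   size_t temp_storage_bytes = 0;
//   PADDLE_ENFORCE_GPU_SUCCESS(
//       func(nullptr, temp_storage_bytes, keys_in, keys_out,
//            values_in, values_out, num_items, num_segments,
//            offsets_begin, offsets_end, 0, sizeof(T) * 8, stream));
//   DenseTensor temp_storage;
//   temp_storage.Resize({static_cast<int64_t>(temp_storage_bytes)});
//   ctx.template Alloc<uint8_t>(&temp_storage);
//   PADDLE_ENFORCE_GPU_SUCCESS(
//       func(temp_storage.data<uint8_t>(), temp_storage_bytes, keys_in,
//            keys_out, values_in, values_out, num_items, num_segments,
//            offsets_begin, offsets_end, 0, sizeof(T) * 8, stream));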

// Sort according to the `descending` flag: true sorts in descending order,
// false (the default) sorts in ascending order.
template <typename T, typename IndType>
void ArgFullSort(const phi::GPUContext& ctx,
const DenseTensor* input,
DenseTensor* output,
DenseTensor* indices,
const IndType num_rows,
const IndType num_cols,
const int64_t num_rows,
const int64_t num_cols,
const bool descending) {
auto cu_stream = ctx.stream();
DenseTensor input_indices;
const std::vector<IndType> dims = {num_rows, num_cols};
auto dim = common::make_ddim(dims);
input_indices.Resize(dim);
ctx.template Alloc<IndType>(&input_indices);
size_t temp_storage_bytes = -1;

auto ComputeBlockSize = [](IndType col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
return 128;
};
const int block_size = ComputeBlockSize(num_cols);
const int64_t maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];

int block_size = ComputeBlockSize(num_cols);
int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
// actually, int num_rows < max_grid_size
int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
// Init a index array
FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
input_indices.data<IndType>(), num_rows, num_cols);

T* sorted_out_ptr;
IndType* sorted_indices_ptr;
const T* inp = input->data<T>();
T* out = ctx.template Alloc<T>(output);
IndType* ind = ctx.template Alloc<IndType>(indices);
sorted_out_ptr = out;
sorted_indices_ptr = ind;
IndType* sorted_indices_ptr = indices->data<IndType>();

// create iter for counting input
cub::CountingInputIterator<IndType> counting_iter(0);
@@ -151,78 +156,54 @@ void ArgFullSort(const phi::GPUContext& ctx,
cub::CountingInputIterator<IndType>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
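// Note: segment_offsets_t is evaluated lazily; segment_offsets_t[i] yields
// i * num_cols, so the pair (segment_offsets_t, segment_offsets_t + 1)
// describes the begin/end offsets of equal-length segments without
// materializing an offsets array in device memory.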

gpuError_t err;
if (descending) {
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
nullptr,
temp_storage_bytes,
inp,
sorted_out_ptr,
input_indices.data<IndType>(),
sorted_indices_ptr,
num_cols * num_rows,
num_rows,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
cu_stream);
} else {
err =
cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
temp_storage_bytes,
inp,
sorted_out_ptr,
input_indices.data<IndType>(),
sorted_indices_ptr,
num_cols * num_rows,
num_rows,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
cu_stream);
}
PADDLE_ENFORCE_GPU_SUCCESS(err);

DenseTensor temp_storage;
int64_t temp_size = temp_storage_bytes;
temp_storage.Resize({temp_size});
ctx.template Alloc<uint8_t>(&temp_storage);

if (descending) {
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
temp_storage.data<uint8_t>(),
temp_storage_bytes,
inp,
sorted_out_ptr,
input_indices.data<IndType>(),
sorted_indices_ptr,
num_cols * num_rows,
num_rows,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
cu_stream);
} else {
err =
cub::DeviceSegmentedRadixSort::SortPairs(temp_storage.data<uint8_t>(),
temp_storage_bytes,
inp,
sorted_out_ptr,
input_indices.data<IndType>(),
sorted_indices_ptr,
num_cols * num_rows,
num_rows,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
cu_stream);
}
// num_rows is the total number of segments to be sorted
constexpr int64_t max_elements = 1 << 30;
const int64_t total_elements = num_cols * num_rows;
const int64_t segment_size = num_cols;
const int64_t element_per_call = std::min(max_elements, total_elements);
// make sure batch size is a multiple of segment_size
const int64_t batch_size = (element_per_call / segment_size) * segment_size;
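// Worked example with illustrative sizes: for 8192 rows of 2^20 columns,
// total_elements = 2^33 and segment_size = 2^20; element_per_call is capped
// at max_elements = 2^30, so batch_size = (2^30 / 2^20) * 2^20 = 2^30, i.e.
// 1024 segments per call, and the loop below issues 8 segmented sorts
// instead of a single sort over all 8192 segments.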
int64_t offset = 0;
DenseTensor input_indices;

T* sorted_out_ptr = output->data<T>();
IndType* ind_ptr = nullptr;

PADDLE_ENFORCE_GPU_SUCCESS(err);
while (offset < total_elements) {
const int64_t n_elements = std::min(batch_size, total_elements - offset);
const int64_t n_segments = n_elements / segment_size;

// Allocate temporary storage for the input indices, with shape
// [num_segments = n_elements / segment_size, segment_size]. It is
// allocated once for the first (largest) batch, reused by later batches
// to avoid repeated allocation and deallocation, and freed when the sort
// is done to save memory.
if (input_indices.initialized()) {
ind_ptr = input_indices.data<IndType>();
} else {
input_indices.Resize({n_segments, segment_size});
ind_ptr = ctx.template Alloc<IndType>(&input_indices);
}
const int64_t grid_size = std::min(n_segments, maxGridDimX);
// Init an index array
FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
ind_ptr, n_segments, segment_size);

PREDICATE_CUB_ARGSORT(descending,
cub::DeviceSegmentedRadixSort::SortPairsDescending,
cub::DeviceSegmentedRadixSort::SortPairs,
inp + offset,
sorted_out_ptr,
ind_ptr,
sorted_indices_ptr + offset,
n_elements,
n_segments,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
cu_stream);
offset += n_elements;
}
}

template <typename T, typename Context>
@@ -247,10 +228,10 @@ void ArgsortKernel(const Context& dev_ctx,
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const T* in_data = input.data<T>();
auto size = input.numel();
T* out_data = dev_ctx.template Alloc<T>(output);
int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);

if (rank == 0) {
dev_ctx.template Alloc<T>(output);
dev_ctx.template Alloc<int64_t>(indices);
phi::Copy<Context>(dev_ctx, input, dev_ctx.GetPlace(), false, output);
phi::funcs::set_constant(dev_ctx, indices, static_cast<int64_t>(0));
return;
@@ -261,6 +242,8 @@
// Compared to the following 'Special case for full sort', ascending sort is
// 34 times faster and descending sort is 31 times faster.
if (size == in_dims[axis]) {
T* out_data = dev_ctx.template Alloc<T>(output);
int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);
#ifdef PADDLE_WITH_CUDA
const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#else
@@ -297,6 +280,8 @@ void ArgsortKernel(const Context& dev_ctx,
const int64_t input_height =
common::product(common::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
dev_ctx.template Alloc<int64_t>(indices);
dev_ctx.template Alloc<T>(output);
ArgFullSort<T, int64_t>(dev_ctx,
&input,
output,
@@ -338,7 +323,6 @@ void ArgsortKernel(const Context& dev_ctx,
// temp indices for sorting
tmp_indices.Resize(trans_dims);
dev_ctx.template Alloc<int64_t>(&tmp_indices);
dev_ctx.template Alloc<int64_t>(indices);

ArgFullSort<T, int64_t>(dev_ctx,
&trans_inp,
@@ -347,10 +331,13 @@ void ArgsortKernel(const Context& dev_ctx,
input_height,
input_width,
descending);

TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
// Delay allocating the final outputs until just before the transpose back,
// to avoid holding too much memory at once.
dev_ctx.template Alloc<T>(output);
dev_ctx.template Alloc<int64_t>(indices);
// transpose back
TransposeKernel<T, Context>(dev_ctx, tmp_out, trans, output);
TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
return;
}
}