diff --git a/paddle/phi/kernels/gpu/argsort_kernel.cu b/paddle/phi/kernels/gpu/argsort_kernel.cu
index f4b73076eeb2e7..dc10070d484f5a 100644
--- a/paddle/phi/kernels/gpu/argsort_kernel.cu
+++ b/paddle/phi/kernels/gpu/argsort_kernel.cu
@@ -96,6 +96,31 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
   }
 }
 
+#define CUB_ARGSORT_WRAPPER(func, ...)                                        \
+  {                                                                           \
+    size_t temp_storage_bytes = 0;                                            \
+    PADDLE_ENFORCE_GPU_SUCCESS(                                               \
+        func(nullptr, temp_storage_bytes, __VA_ARGS__));                      \
+    DenseTensor temp_storage;                                                 \
+    int64_t temp_size = static_cast<int64_t>(temp_storage_bytes);             \
+    PADDLE_ENFORCE_GT(                                                        \
+        temp_size,                                                            \
+        0,                                                                    \
+        common::errors::InvalidArgument(                                      \
+            "Argsort temp storage size is %d, but should be greater than 0.", \
+            temp_size));                                                      \
+    temp_storage.Resize({temp_size});                                         \
+    ctx.template Alloc<uint8_t>(&temp_storage);                               \
+    PADDLE_ENFORCE_GPU_SUCCESS(                                               \
+        func(temp_storage.data<uint8_t>(), temp_storage_bytes, __VA_ARGS__)); \
+  }
+
+#define PREDICATE_CUB_ARGSORT(predicate, if_func, else_func, ...) \
+  if (predicate)                                                  \
+    CUB_ARGSORT_WRAPPER(if_func, __VA_ARGS__)                     \
+  else                                                            \
+    CUB_ARGSORT_WRAPPER(else_func, __VA_ARGS__)
+
 // Sort by flag descending, True: descending. False: Ascending.
 // Default is false.
 template <typename T, typename IndType>
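Illustrative sketch, not part of the patch: CUB_ARGSORT_WRAPPER above wraps the standard two-phase CUB calling convention, where the first call with a null workspace pointer only reports the required temp-storage size and the second, identical call does the actual work. A minimal standalone version of that idiom (the function name SortKeysTwoPhase is made up, and cudaMalloc stands in for the DenseTensor-backed workspace the macro uses) could look like this:

// Minimal sketch (not Paddle code) of the two-phase CUB idiom.
#include <cub/cub.cuh>

cudaError_t SortKeysTwoPhase(const float* d_in,
                             float* d_out,
                             int num_items,
                             cudaStream_t stream) {
  size_t temp_storage_bytes = 0;
  // Phase 1: null workspace pointer -> only temp_storage_bytes is written.
  cudaError_t err = cub::DeviceRadixSort::SortKeys(
      nullptr, temp_storage_bytes, d_in, d_out, num_items,
      0, sizeof(float) * 8, stream);
  if (err != cudaSuccess) return err;

  void* d_temp_storage = nullptr;
  err = cudaMalloc(&d_temp_storage, temp_storage_bytes);
  if (err != cudaSuccess) return err;

  // Phase 2: identical argument list, now with a real workspace.
  err = cub::DeviceRadixSort::SortKeys(
      d_temp_storage, temp_storage_bytes, d_in, d_out, num_items,
      0, sizeof(float) * 8, stream);
  cudaFree(d_temp_storage);
  return err;
}

The macro additionally asserts that the reported size is positive and allocates the workspace through the device context rather than raw cudaMalloc.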
@@ -103,17 +128,10 @@ void ArgFullSort(const phi::GPUContext& ctx,
                  const DenseTensor* input,
                  DenseTensor* output,
                  DenseTensor* indices,
-                 const IndType num_rows,
-                 const IndType num_cols,
+                 const int64_t num_rows,
+                 const int64_t num_cols,
                  const bool descending) {
   auto cu_stream = ctx.stream();
-  DenseTensor input_indices;
-  const std::vector<IndType> dims = {num_rows, num_cols};
-  auto dim = common::make_ddim(dims);
-  input_indices.Resize(dim);
-  ctx.template Alloc<IndType>(&input_indices);
-  size_t temp_storage_bytes = -1;
-
   auto ComputeBlockSize = [](IndType col) {
     if (col > 512)
       return 1024;
@@ -121,27 +139,14 @@ void ArgFullSort(const phi::GPUContext& ctx,
       return 512;
     else if (col > 128 && col <= 256)
       return 256;
-    else if (col > 64 && col <= 128)
-      return 128;
     else
-      return 64;
+      return 128;
   };
+  const int block_size = ComputeBlockSize(num_cols);
+  const int64_t maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
 
-  int block_size = ComputeBlockSize(num_cols);
-  int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  // actually, int num_rows < max_grid_size
-  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
-  // Init a index array
-  FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
-      input_indices.data<IndType>(), num_rows, num_cols);
-
-  T* sorted_out_ptr;
-  IndType* sorted_indices_ptr;
   const T* inp = input->data<T>();
-  T* out = ctx.template Alloc<T>(output);
-  IndType* ind = ctx.template Alloc<IndType>(indices);
-  sorted_out_ptr = out;
-  sorted_indices_ptr = ind;
+  IndType* sorted_indices_ptr = indices->data<IndType>();
 
   // create iter for counting input
   cub::CountingInputIterator<IndType> counting_iter(0);
@@ -151,78 +156,54 @@ void ArgFullSort(const phi::GPUContext& ctx,
                               cub::CountingInputIterator<IndType>>
       segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
 
-  gpuError_t err;
-  if (descending) {
-    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
-        nullptr,
-        temp_storage_bytes,
-        inp,
-        sorted_out_ptr,
-        input_indices.data<IndType>(),
-        sorted_indices_ptr,
-        num_cols * num_rows,
-        num_rows,
-        segment_offsets_t,
-        segment_offsets_t + 1,
-        0,
-        sizeof(T) * 8,
-        cu_stream);
-  } else {
-    err =
-        cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
-                                                 temp_storage_bytes,
-                                                 inp,
-                                                 sorted_out_ptr,
-                                                 input_indices.data<IndType>(),
-                                                 sorted_indices_ptr,
-                                                 num_cols * num_rows,
-                                                 num_rows,
-                                                 segment_offsets_t,
-                                                 segment_offsets_t + 1,
-                                                 0,
-                                                 sizeof(T) * 8,
-                                                 cu_stream);
-  }
-  PADDLE_ENFORCE_GPU_SUCCESS(err);
-
-  DenseTensor temp_storage;
-  int64_t temp_size = temp_storage_bytes;
-  temp_storage.Resize({temp_size});
-  ctx.template Alloc<uint8_t>(&temp_storage);
-
-  if (descending) {
-    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
-        temp_storage.data<uint8_t>(),
-        temp_storage_bytes,
-        inp,
-        sorted_out_ptr,
-        input_indices.data<IndType>(),
-        sorted_indices_ptr,
-        num_cols * num_rows,
-        num_rows,
-        segment_offsets_t,
-        segment_offsets_t + 1,
-        0,
-        sizeof(T) * 8,
-        cu_stream);
-  } else {
-    err =
-        cub::DeviceSegmentedRadixSort::SortPairs(temp_storage.data<uint8_t>(),
-                                                 temp_storage_bytes,
-                                                 inp,
-                                                 sorted_out_ptr,
-                                                 input_indices.data<IndType>(),
-                                                 sorted_indices_ptr,
-                                                 num_cols * num_rows,
-                                                 num_rows,
-                                                 segment_offsets_t,
-                                                 segment_offsets_t + 1,
-                                                 0,
-                                                 sizeof(T) * 8,
-                                                 cu_stream);
-  }
+  // num_rows is the total segments to be sorted
+  constexpr int64_t max_elements = 1 << 30;
+  const int64_t total_elements = num_cols * num_rows;
+  const int64_t segment_size = num_cols;
+  const int64_t element_per_call = std::min(max_elements, total_elements);
+  // make sure batch size is the multiple of segment_size
+  const int64_t batch_size = (element_per_call / segment_size) * segment_size;
+  int64_t offset = 0;
+  DenseTensor input_indices;
+
+  T* sorted_out_ptr = output->data<T>();
+  IndType* ind_ptr = nullptr;
 
-  PADDLE_ENFORCE_GPU_SUCCESS(err);
+  while (offset < total_elements) {
+    const int64_t n_elements = std::min(batch_size, total_elements - offset);
+    const int64_t n_segments = n_elements / segment_size;
+
+    // allocate a temporary storage for input indices, with shape:
+    // [num_segments = n_elements / segment_size, segment_size]
+    // will be de-allocated once the sort is done, to save memory and
+    // avoid repeated allocation and deallocation
+    if (input_indices.initialized()) {
+      ind_ptr = input_indices.data<IndType>();
+    } else {
+      input_indices.Resize({n_segments, segment_size});
+      ind_ptr = ctx.template Alloc<IndType>(&input_indices);
+    }
+    const int64_t grid_size = std::min(n_segments, maxGridDimX);
+    // Init a index array
+    FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
+        ind_ptr, n_segments, segment_size);
+
+    PREDICATE_CUB_ARGSORT(descending,
+                          cub::DeviceSegmentedRadixSort::SortPairsDescending,
+                          cub::DeviceSegmentedRadixSort::SortPairs,
+                          inp + offset,
+                          sorted_out_ptr + offset,
+                          ind_ptr,
+                          sorted_indices_ptr + offset,
+                          n_elements,
+                          n_segments,
+                          segment_offsets_t,
+                          segment_offsets_t + 1,
+                          0,
+                          sizeof(T) * 8,
+                          cu_stream);
+    offset += n_elements;
+  }
 }
 
 template <typename T, typename Context>
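Illustrative sketch, not part of the patch: the loop above splits one large segmented sort into chunks of at most max_elements (2^30) keys, rounded down to a whole number of segments so that no row is ever split across two CUB calls; presumably this keeps the per-call item count within the 32-bit range CUB's segmented sort expects and bounds the size of the temporary index buffer. The host-only snippet below, with made-up sizes, traces the same arithmetic:

// Sketch of the chunking arithmetic only (no sorting); sizes are hypothetical.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t num_rows = 3000;     // segments to sort
  const int64_t num_cols = 1 << 20;  // segment_size
  constexpr int64_t max_elements = 1 << 30;

  const int64_t total_elements = num_cols * num_rows;
  const int64_t segment_size = num_cols;
  const int64_t element_per_call = std::min(max_elements, total_elements);
  // Round down to a whole number of segments so no row is split across calls.
  const int64_t batch_size = (element_per_call / segment_size) * segment_size;

  int64_t offset = 0;
  while (offset < total_elements) {
    const int64_t n_elements = std::min(batch_size, total_elements - offset);
    const int64_t n_segments = n_elements / segment_size;
    std::printf("offset=%lld  n_elements=%lld  n_segments=%lld\n",
                static_cast<long long>(offset),
                static_cast<long long>(n_elements),
                static_cast<long long>(n_segments));
    offset += n_elements;
  }
  return 0;
}

For these sizes the total of roughly 3.1e9 elements is processed in three calls: two full chunks of 1024 segments and a final chunk of 952 segments, each passing at most 2^30 items to CUB.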
@@ -247,10 +228,10 @@ void ArgsortKernel(const Context& dev_ctx,
   axis = (axis < 0) ? (in_dims.size() + axis) : axis;
   const T* in_data = input.data<T>();
   auto size = input.numel();
-  T* out_data = dev_ctx.template Alloc<T>(output);
-  int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);
 
   if (rank == 0) {
+    dev_ctx.template Alloc<T>(output);
+    dev_ctx.template Alloc<int64_t>(indices);
     phi::Copy(dev_ctx, input, dev_ctx.GetPlace(), false, output);
     phi::funcs::set_constant(dev_ctx, indices, static_cast<int64_t>(0));
     return;
@@ -261,6 +242,8 @@ void ArgsortKernel(const Context& dev_ctx,
   // Compared to the following 'Special case for full sort', ascending sort is
   // 34 times faster and descending sort is 31 times faster.
   if (size == in_dims[axis]) {
+    T* out_data = dev_ctx.template Alloc<T>(output);
+    int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);
 #ifdef PADDLE_WITH_CUDA
     const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
 #else
@@ -297,6 +280,8 @@ void ArgsortKernel(const Context& dev_ctx,
     const int64_t input_height =
         common::product(common::slice_ddim(in_dims, 0, in_dims.size() - 1));
     const int64_t input_width = in_dims[in_dims.size() - 1];
+    dev_ctx.template Alloc<int64_t>(indices);
+    dev_ctx.template Alloc<T>(output);
     ArgFullSort<T, int64_t>(dev_ctx,
                             &input,
                             output,
@@ -338,7 +323,6 @@ void ArgsortKernel(const Context& dev_ctx,
     // temp indices for sorting
     tmp_indices.Resize(trans_dims);
     dev_ctx.template Alloc<int64_t>(&tmp_indices);
-    dev_ctx.template Alloc<int64_t>(indices);
 
     ArgFullSort<T, int64_t>(dev_ctx,
                             &trans_inp,
@@ -347,10 +331,13 @@ void ArgsortKernel(const Context& dev_ctx,
                             input_height,
                             input_width,
                             descending);
-
-    TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
+    // delay output allocation until after transpose, to avoid
+    // allocating too much memory
+    dev_ctx.template Alloc<T>(output);
+    dev_ctx.template Alloc<int64_t>(indices);
     // transpose back
     TransposeKernel<T, Context>(dev_ctx, tmp_out, trans, output);
+    TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
     return;
   }
 }
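Illustrative sketch, not part of the patch: pulling the pieces together, a self-contained row-wise argsort with cub::DeviceSegmentedRadixSort follows the same shape as ArgFullSort: indices 0..num_cols-1 serve as the sort payload and the segment offsets come from a counting iterator run through a small functor. RowOffset below is a hypothetical stand-in for the kernel's SegmentOffsetIter, the 2 x 4 input data is made up, and error handling is omitted:

// Standalone sketch (not Paddle code): row-wise argsort of a 2 x 4 matrix.
#include <cub/cub.cuh>
#include <cstdio>
#include <vector>

// Maps segment index i to its begin offset i * cols.
struct RowOffset {
  int cols;
  __host__ __device__ int operator()(int i) const { return i * cols; }
};

int main() {
  const int rows = 2, cols = 4, n = rows * cols;
  std::vector<float> h_keys = {3, 1, 4, 1, 5, 9, 2, 6};
  std::vector<int> h_idx(n);
  for (int i = 0; i < n; ++i) h_idx[i] = i % cols;  // 0..cols-1 in every row

  float *d_keys_in, *d_keys_out;
  int *d_idx_in, *d_idx_out;
  cudaMalloc(&d_keys_in, n * sizeof(float));
  cudaMalloc(&d_keys_out, n * sizeof(float));
  cudaMalloc(&d_idx_in, n * sizeof(int));
  cudaMalloc(&d_idx_out, n * sizeof(int));
  cudaMemcpy(d_keys_in, h_keys.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_idx_in, h_idx.data(), n * sizeof(int), cudaMemcpyHostToDevice);

  // Segment i spans [i * cols, (i + 1) * cols), expressed through iterators
  // so no offset array has to live in device memory.
  cub::CountingInputIterator<int> counting(0);
  cub::TransformInputIterator<int, RowOffset, cub::CountingInputIterator<int>>
      offsets(counting, RowOffset{cols});

  // Two-phase call: query workspace size, allocate, then sort.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceSegmentedRadixSort::SortPairs(nullptr, temp_bytes,
                                           d_keys_in, d_keys_out,
                                           d_idx_in, d_idx_out,
                                           n, rows, offsets, offsets + 1);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceSegmentedRadixSort::SortPairs(d_temp, temp_bytes,
                                           d_keys_in, d_keys_out,
                                           d_idx_in, d_idx_out,
                                           n, rows, offsets, offsets + 1);

  std::vector<int> h_out(n);
  cudaMemcpy(h_out.data(), d_idx_out, n * sizeof(int), cudaMemcpyDeviceToHost);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) std::printf("%d ", h_out[r * cols + c]);
    std::printf("\n");
  }
  return 0;
}

For row 0 (keys 3, 1, 4, 1) the printed indices are 1 3 0 2, i.e. the argsort of that row; CUB's radix sort is stable, so the tied keys keep their original relative order.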