Commit 6ef8d19

[PHI] Fix argsort big tensor bug (#72712)
* [PHI] Fixed argsort big tensor bug
* [PHI] Fixed shape mismatch problem.
1 parent fce2670 commit 6ef8d19
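
Background for the diff below: the old ArgFullSort issued a single cub::DeviceSegmentedRadixSort call over num_cols * num_rows elements, which appears to break once that product no longer fits the 32-bit element count CUB expects (the "big tensor" case). The new code sorts in batches of at most 2^30 elements, with each batch rounded down to a whole number of rows so no segment is split across calls. A minimal host-side sketch of that batching arithmetic, assuming num_cols <= 2^30 as in the kernel's use (PlanBatches is an illustrative helper, not part of the commit; the variable names mirror the diff):

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Illustrative only: plan the {offset, n_elements} batches that the new
// ArgFullSort loop iterates over. Each batch stays under max_elements and
// covers a whole number of rows (segments). Assumes num_cols <= max_elements.
std::vector<std::pair<int64_t, int64_t>> PlanBatches(int64_t num_rows,
                                                     int64_t num_cols) {
  constexpr int64_t max_elements = 1LL << 30;  // per-call cap, as in the diff
  const int64_t total_elements = num_rows * num_cols;
  const int64_t segment_size = num_cols;
  const int64_t element_per_call = std::min(max_elements, total_elements);
  // round down so a batch never splits a row across two cub calls
  const int64_t batch_size = (element_per_call / segment_size) * segment_size;

  std::vector<std::pair<int64_t, int64_t>> batches;
  for (int64_t offset = 0; offset < total_elements;) {
    const int64_t n_elements = std::min(batch_size, total_elements - offset);
    batches.emplace_back(offset, n_elements);
    offset += n_elements;
  }
  return batches;
}

For example, with num_rows = 8 and num_cols = 2^28 (2^31 elements in total), PlanBatches returns two batches of 2^30 elements each, so each cub call stays within range.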

File tree

1 file changed: +89 −102 lines changed

paddle/phi/kernels/gpu/argsort_kernel.cu

Lines changed: 89 additions & 102 deletions
@@ -96,52 +96,57 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
   }
 }
 
+#define CUB_ARGSORT_WRAPPER(func, ...) \
+  { \
+    size_t temp_storage_bytes = 0; \
+    PADDLE_ENFORCE_GPU_SUCCESS( \
+        func(nullptr, temp_storage_bytes, __VA_ARGS__)); \
+    DenseTensor temp_storage; \
+    int64_t temp_size = static_cast<int64_t>(temp_storage_bytes); \
+    PADDLE_ENFORCE_GT( \
+        temp_size, \
+        0, \
+        common::errors::InvalidArgument( \
+            "Argsort temp storage size is %d, but should be greater than 0.", \
+            temp_size)); \
+    temp_storage.Resize({temp_size}); \
+    ctx.template Alloc<uint8_t>(&temp_storage); \
+    PADDLE_ENFORCE_GPU_SUCCESS( \
+        func(temp_storage.data<uint8_t>(), temp_storage_bytes, __VA_ARGS__)); \
+  }
+
+#define PREDICATE_CUB_ARGSORT(predicate, if_func, else_func, ...) \
+  if (predicate) \
+    CUB_ARGSORT_WRAPPER(if_func, __VA_ARGS__) \
+  else \
+    CUB_ARGSORT_WRAPPER(else_func, __VA_ARGS__)
+
 // Sort by flag descending, True: descending. False: Ascending.
 // Default is false.
 template <typename T, typename IndType>
 void ArgFullSort(const phi::GPUContext& ctx,
                  const DenseTensor* input,
                  DenseTensor* output,
                  DenseTensor* indices,
-                 const IndType num_rows,
-                 const IndType num_cols,
+                 const int64_t num_rows,
+                 const int64_t num_cols,
                  const bool descending) {
   auto cu_stream = ctx.stream();
-  DenseTensor input_indices;
-  const std::vector<IndType> dims = {num_rows, num_cols};
-  auto dim = common::make_ddim(dims);
-  input_indices.Resize(dim);
-  ctx.template Alloc<IndType>(&input_indices);
-  size_t temp_storage_bytes = -1;
-
   auto ComputeBlockSize = [](IndType col) {
     if (col > 512)
       return 1024;
     else if (col > 256 && col <= 512)
       return 512;
     else if (col > 128 && col <= 256)
       return 256;
-    else if (col > 64 && col <= 128)
-      return 128;
     else
-      return 64;
+      return 128;
   };
+  const int block_size = ComputeBlockSize(num_cols);
+  const int64_t maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
 
-  int block_size = ComputeBlockSize(num_cols);
-  int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  // actually, int num_rows < max_grid_size
-  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
-  // Init a index array
-  FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
-      input_indices.data<IndType>(), num_rows, num_cols);
-
-  T* sorted_out_ptr;
-  IndType* sorted_indices_ptr;
   const T* inp = input->data<T>();
-  T* out = ctx.template Alloc<T>(output);
-  IndType* ind = ctx.template Alloc<IndType>(indices);
-  sorted_out_ptr = out;
-  sorted_indices_ptr = ind;
+  IndType* sorted_indices_ptr = indices->data<IndType>();
 
   // create iter for counting input
   cub::CountingInputIterator<IndType> counting_iter(0);
@@ -151,78 +156,54 @@ void ArgFullSort(const phi::GPUContext& ctx,
                               cub::CountingInputIterator<IndType>>
       segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
 
-  gpuError_t err;
-  if (descending) {
-    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
-        nullptr,
-        temp_storage_bytes,
-        inp,
-        sorted_out_ptr,
-        input_indices.data<IndType>(),
-        sorted_indices_ptr,
-        num_cols * num_rows,
-        num_rows,
-        segment_offsets_t,
-        segment_offsets_t + 1,
-        0,
-        sizeof(T) * 8,
-        cu_stream);
-  } else {
-    err =
-        cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
-                                                 temp_storage_bytes,
-                                                 inp,
-                                                 sorted_out_ptr,
-                                                 input_indices.data<IndType>(),
-                                                 sorted_indices_ptr,
-                                                 num_cols * num_rows,
-                                                 num_rows,
-                                                 segment_offsets_t,
-                                                 segment_offsets_t + 1,
-                                                 0,
-                                                 sizeof(T) * 8,
-                                                 cu_stream);
-  }
-  PADDLE_ENFORCE_GPU_SUCCESS(err);
-
-  DenseTensor temp_storage;
-  int64_t temp_size = temp_storage_bytes;
-  temp_storage.Resize({temp_size});
-  ctx.template Alloc<uint8_t>(&temp_storage);
-
-  if (descending) {
-    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
-        temp_storage.data<uint8_t>(),
-        temp_storage_bytes,
-        inp,
-        sorted_out_ptr,
-        input_indices.data<IndType>(),
-        sorted_indices_ptr,
-        num_cols * num_rows,
-        num_rows,
-        segment_offsets_t,
-        segment_offsets_t + 1,
-        0,
-        sizeof(T) * 8,
-        cu_stream);
-  } else {
-    err =
-        cub::DeviceSegmentedRadixSort::SortPairs(temp_storage.data<uint8_t>(),
-                                                 temp_storage_bytes,
-                                                 inp,
-                                                 sorted_out_ptr,
-                                                 input_indices.data<IndType>(),
-                                                 sorted_indices_ptr,
-                                                 num_cols * num_rows,
-                                                 num_rows,
-                                                 segment_offsets_t,
-                                                 segment_offsets_t + 1,
-                                                 0,
-                                                 sizeof(T) * 8,
-                                                 cu_stream);
-  }
+  // num_rows is the total segments to be sorted
+  constexpr int64_t max_elements = 1 << 30;
+  const int64_t total_elements = num_cols * num_rows;
+  const int64_t segment_size = num_cols;
+  const int64_t element_per_call = std::min(max_elements, total_elements);
+  // make sure batch size is the multiple of segment_size
+  const int64_t batch_size = (element_per_call / segment_size) * segment_size;
+  int64_t offset = 0;
+  DenseTensor input_indices;
+
+  T* sorted_out_ptr = output->data<T>();
+  IndType* ind_ptr = nullptr;
 
-  PADDLE_ENFORCE_GPU_SUCCESS(err);
+  while (offset < total_elements) {
+    const int64_t n_elements = std::min(batch_size, total_elements - offset);
+    const int64_t n_segments = n_elements / segment_size;
+
+    // allocate a temporary storage for input indices, with shape:
+    // [num_segments = n_elements / segment_size, segment_size]
+    // will be de-allocated once the sort is done, to save memory and
+    // avoid repeated allocation and deallocation
+    if (input_indices.initialized()) {
+      ind_ptr = input_indices.data<IndType>();
+    } else {
+      input_indices.Resize({n_segments, segment_size});
+      ind_ptr = ctx.template Alloc<IndType>(&input_indices);
+    }
+    const int64_t grid_size = std::min(n_segments, maxGridDimX);
+    // Init a index array
+    FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
+        ind_ptr, n_segments, segment_size);
+
+    PREDICATE_CUB_ARGSORT(descending,
+                          cub::DeviceSegmentedRadixSort::SortPairsDescending,
+                          cub::DeviceSegmentedRadixSort::SortPairs,
+                          inp + offset,
+                          sorted_out_ptr + offset,
+                          ind_ptr,
+                          sorted_indices_ptr + offset,
+                          n_elements,
+                          n_segments,
+                          segment_offsets_t,
+                          segment_offsets_t + 1,
+                          0,
+                          sizeof(T) * 8,
+                          cu_stream);
+    offset += n_elements;
+  }
 }
 
 template <typename T, typename Context>
@@ -247,10 +228,10 @@ void ArgsortKernel(const Context& dev_ctx,
   axis = (axis < 0) ? (in_dims.size() + axis) : axis;
   const T* in_data = input.data<T>();
   auto size = input.numel();
-  T* out_data = dev_ctx.template Alloc<T>(output);
-  int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);
 
   if (rank == 0) {
+    dev_ctx.template Alloc<T>(output);
+    dev_ctx.template Alloc<int64_t>(indices);
     phi::Copy<Context>(dev_ctx, input, dev_ctx.GetPlace(), false, output);
     phi::funcs::set_constant(dev_ctx, indices, static_cast<int64_t>(0));
     return;
@@ -261,6 +242,8 @@ void ArgsortKernel(const Context& dev_ctx,
   // Compared to the following 'Special case for full sort', ascending sort is
   // 34 times faster and descending sort is 31 times faster.
   if (size == in_dims[axis]) {
+    T* out_data = dev_ctx.template Alloc<T>(output);
+    int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);
 #ifdef PADDLE_WITH_CUDA
     const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
 #else
@@ -297,6 +280,8 @@ void ArgsortKernel(const Context& dev_ctx,
     const int64_t input_height =
         common::product(common::slice_ddim(in_dims, 0, in_dims.size() - 1));
     const int64_t input_width = in_dims[in_dims.size() - 1];
+    dev_ctx.template Alloc<int64_t>(indices);
+    dev_ctx.template Alloc<T>(output);
     ArgFullSort<T, int64_t>(dev_ctx,
                             &input,
                             output,
@@ -338,7 +323,6 @@ void ArgsortKernel(const Context& dev_ctx,
     // temp indices for sorting
     tmp_indices.Resize(trans_dims);
     dev_ctx.template Alloc<int64_t>(&tmp_indices);
-    dev_ctx.template Alloc<int64_t>(indices);
 
     ArgFullSort<T, int64_t>(dev_ctx,
                             &trans_inp,
@@ -347,10 +331,13 @@ void ArgsortKernel(const Context& dev_ctx,
                             input_height,
                             input_width,
                             descending);
-
-    TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
+    // delay output allocation until after transpose, to avoid
+    // allocating too much memory
+    dev_ctx.template Alloc<T>(output);
+    dev_ctx.template Alloc<int64_t>(indices);
     // transpose back
     TransposeKernel<T, Context>(dev_ctx, tmp_out, trans, output);
+    TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
     return;
   }
 }
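
For reference, the CUB_ARGSORT_WRAPPER macro added above follows CUB's standard two-pass protocol: the sort function is first called with a null workspace pointer so it reports the scratch size it needs, the scratch buffer is then allocated, and the same call is repeated to perform the sort. Below is a standalone sketch of that pattern using plain CUDA allocations instead of Paddle's DenseTensor (the function name, the float/int64_t key/value types, and the raw cudaMalloc are illustrative choices, not what the kernel actually does):

#include <cub/cub.cuh>

#include <cstdint>

// Illustrative two-phase CUB call, mirroring what CUB_ARGSORT_WRAPPER expands
// to: (1) query temp_storage_bytes with a null workspace, (2) allocate it,
// (3) run the real segmented sort. Offsets follow the begin / begin + 1 idiom
// used in the kernel above; d_offsets must provide num_segments + 1 entries.
template <typename OffsetIterT>
cudaError_t SegmentedArgsortOnce(const float* d_keys_in,
                                 float* d_keys_out,
                                 const int64_t* d_vals_in,
                                 int64_t* d_vals_out,
                                 int num_items,
                                 int num_segments,
                                 OffsetIterT d_offsets,
                                 cudaStream_t stream) {
  size_t temp_storage_bytes = 0;
  cudaError_t err = cub::DeviceSegmentedRadixSort::SortPairs(
      nullptr, temp_storage_bytes, d_keys_in, d_keys_out, d_vals_in,
      d_vals_out, num_items, num_segments, d_offsets, d_offsets + 1, 0,
      sizeof(float) * 8, stream);
  if (err != cudaSuccess) return err;

  void* d_temp_storage = nullptr;
  err = cudaMalloc(&d_temp_storage, temp_storage_bytes);
  if (err != cudaSuccess) return err;

  err = cub::DeviceSegmentedRadixSort::SortPairs(
      d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, d_vals_in,
      d_vals_out, num_items, num_segments, d_offsets, d_offsets + 1, 0,
      sizeof(float) * 8, stream);
  cudaFree(d_temp_storage);
  return err;
}

The macro in the diff differs in that the scratch buffer comes from ctx.template Alloc<uint8_t> on a DenseTensor, so it is tracked by Paddle's allocator, and a PADDLE_ENFORCE_GT guards against a zero-sized workspace request.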
