@@ -96,52 +96,57 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
   }
 }
 
+#define CUB_ARGSORT_WRAPPER(func, ...) \
+  { \
+    size_t temp_storage_bytes = 0; \
+    PADDLE_ENFORCE_GPU_SUCCESS( \
+        func(nullptr, temp_storage_bytes, __VA_ARGS__)); \
+    DenseTensor temp_storage; \
+    int64_t temp_size = static_cast<int64_t>(temp_storage_bytes); \
+    PADDLE_ENFORCE_GT( \
+        temp_size, \
+        0, \
+        common::errors::InvalidArgument( \
+            "Argsort temp storage size is %d, but should be greater than 0.", \
+            temp_size)); \
+    temp_storage.Resize({temp_size}); \
+    ctx.template Alloc<uint8_t>(&temp_storage); \
+    PADDLE_ENFORCE_GPU_SUCCESS( \
+        func(temp_storage.data<uint8_t>(), temp_storage_bytes, __VA_ARGS__)); \
+  }
+
+#define PREDICATE_CUB_ARGSORT(predicate, if_func, else_func, ...) \
+  if (predicate) \
+    CUB_ARGSORT_WRAPPER(if_func, __VA_ARGS__) \
+  else \
+    CUB_ARGSORT_WRAPPER(else_func, __VA_ARGS__)
+
 // Sort by flag descending, True: descending. False: Ascending.
 // Default is false.
 template <typename T, typename IndType>
 void ArgFullSort(const phi::GPUContext& ctx,
                  const DenseTensor* input,
                  DenseTensor* output,
                  DenseTensor* indices,
-                 const IndType num_rows,
-                 const IndType num_cols,
+                 const int64_t num_rows,
+                 const int64_t num_cols,
                  const bool descending) {
   auto cu_stream = ctx.stream();
-  DenseTensor input_indices;
-  const std::vector<IndType> dims = {num_rows, num_cols};
-  auto dim = common::make_ddim(dims);
-  input_indices.Resize(dim);
-  ctx.template Alloc<IndType>(&input_indices);
-  size_t temp_storage_bytes = -1;
-
   auto ComputeBlockSize = [](IndType col) {
     if (col > 512)
       return 1024;
     else if (col > 256 && col <= 512)
       return 512;
     else if (col > 128 && col <= 256)
       return 256;
-    else if (col > 64 && col <= 128)
-      return 128;
     else
-      return 64;
+      return 128;
   };
+  const int block_size = ComputeBlockSize(num_cols);
+  const int64_t maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
 
-  int block_size = ComputeBlockSize(num_cols);
-  int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  // actually, int num_rows < max_grid_size
-  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
-  // Init a index array
-  FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
-      input_indices.data<IndType>(), num_rows, num_cols);
-
-  T* sorted_out_ptr;
-  IndType* sorted_indices_ptr;
   const T* inp = input->data<T>();
-  T* out = ctx.template Alloc<T>(output);
-  IndType* ind = ctx.template Alloc<IndType>(indices);
-  sorted_out_ptr = out;
-  sorted_indices_ptr = ind;
+  IndType* sorted_indices_ptr = indices->data<IndType>();
 
   // create iter for counting input
   cub::CountingInputIterator<IndType> counting_iter(0);
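The two macros added above fold CUB's usual two-call idiom into one place: the first call with a null temp-storage pointer only reports how many scratch bytes the sort needs, and the second call performs the sort. A minimal standalone sketch of that idiom, using plain `cudaMalloc` instead of a `DenseTensor` and hypothetical pointer names:

```cpp
#include <cub/cub.cuh>

// Sketch only: segmented key/value radix sort via CUB's two-call pattern.
// d_offsets has num_segments + 1 entries marking segment boundaries.
void SegmentedSortSketch(const float* d_keys_in, float* d_keys_out,
                         const int* d_vals_in, int* d_vals_out,
                         int num_items, int num_segments,
                         const int* d_offsets, cudaStream_t stream) {
  size_t temp_bytes = 0;
  // First call: temp storage pointer is nullptr, so CUB only writes temp_bytes.
  cub::DeviceSegmentedRadixSort::SortPairs(
      nullptr, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out,
      num_items, num_segments, d_offsets, d_offsets + 1,
      0, sizeof(float) * 8, stream);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  // Second call: same arguments, now with real scratch space; this sorts.
  cub::DeviceSegmentedRadixSort::SortPairs(
      d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out,
      num_items, num_segments, d_offsets, d_offsets + 1,
      0, sizeof(float) * 8, stream);
  cudaFree(d_temp);
}
```

`PREDICATE_CUB_ARGSORT` then simply chooses `SortPairsDescending` or `SortPairs` at runtime and forwards the same argument list to `CUB_ARGSORT_WRAPPER`.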
@@ -151,78 +156,54 @@ void ArgFullSort(const phi::GPUContext& ctx,
                               cub::CountingInputIterator<IndType>>
       segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
 
-  gpuError_t err;
-  if (descending) {
-    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
-        nullptr,
-        temp_storage_bytes,
-        inp,
-        sorted_out_ptr,
-        input_indices.data<IndType>(),
-        sorted_indices_ptr,
-        num_cols * num_rows,
-        num_rows,
-        segment_offsets_t,
-        segment_offsets_t + 1,
-        0,
-        sizeof(T) * 8,
-        cu_stream);
-  } else {
-    err =
-        cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
-                                                 temp_storage_bytes,
-                                                 inp,
-                                                 sorted_out_ptr,
-                                                 input_indices.data<IndType>(),
-                                                 sorted_indices_ptr,
-                                                 num_cols * num_rows,
-                                                 num_rows,
-                                                 segment_offsets_t,
-                                                 segment_offsets_t + 1,
-                                                 0,
-                                                 sizeof(T) * 8,
-                                                 cu_stream);
-  }
-  PADDLE_ENFORCE_GPU_SUCCESS(err);
-
-  DenseTensor temp_storage;
-  int64_t temp_size = temp_storage_bytes;
-  temp_storage.Resize({temp_size});
-  ctx.template Alloc<uint8_t>(&temp_storage);
-
-  if (descending) {
-    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
-        temp_storage.data<uint8_t>(),
-        temp_storage_bytes,
-        inp,
-        sorted_out_ptr,
-        input_indices.data<IndType>(),
-        sorted_indices_ptr,
-        num_cols * num_rows,
-        num_rows,
-        segment_offsets_t,
-        segment_offsets_t + 1,
-        0,
-        sizeof(T) * 8,
-        cu_stream);
-  } else {
-    err =
-        cub::DeviceSegmentedRadixSort::SortPairs(temp_storage.data<uint8_t>(),
-                                                 temp_storage_bytes,
-                                                 inp,
-                                                 sorted_out_ptr,
-                                                 input_indices.data<IndType>(),
-                                                 sorted_indices_ptr,
-                                                 num_cols * num_rows,
-                                                 num_rows,
-                                                 segment_offsets_t,
-                                                 segment_offsets_t + 1,
-                                                 0,
-                                                 sizeof(T) * 8,
-                                                 cu_stream);
-  }
+  // num_rows is the total number of segments to be sorted
+  constexpr int64_t max_elements = 1 << 30;
+  const int64_t total_elements = num_cols * num_rows;
+  const int64_t segment_size = num_cols;
+  const int64_t element_per_call = std::min(max_elements, total_elements);
+  // make sure batch_size is a multiple of segment_size
+  const int64_t batch_size = (element_per_call / segment_size) * segment_size;
+  int64_t offset = 0;
+  DenseTensor input_indices;
+
+  T* sorted_out_ptr = output->data<T>();
+  IndType* ind_ptr = nullptr;
 
-  PADDLE_ENFORCE_GPU_SUCCESS(err);
+  while (offset < total_elements) {
+    const int64_t n_elements = std::min(batch_size, total_elements - offset);
+    const int64_t n_segments = n_elements / segment_size;
+
+    // Allocate temporary storage for the input indices with shape
+    // [n_segments = n_elements / segment_size, segment_size]. The buffer is
+    // reused across batches to avoid repeated allocation and deallocation,
+    // and is released automatically once the sort is done.
+    if (input_indices.initialized()) {
+      ind_ptr = input_indices.data<IndType>();
+    } else {
+      input_indices.Resize({n_segments, segment_size});
+      ind_ptr = ctx.template Alloc<IndType>(&input_indices);
+    }
+    const int64_t grid_size = std::min(n_segments, maxGridDimX);
+    // Init the index array
+    FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
+        ind_ptr, n_segments, segment_size);
+
+    PREDICATE_CUB_ARGSORT(descending,
+                          cub::DeviceSegmentedRadixSort::SortPairsDescending,
+                          cub::DeviceSegmentedRadixSort::SortPairs,
+                          inp + offset,
+                          sorted_out_ptr + offset,
+                          ind_ptr,
+                          sorted_indices_ptr + offset,
+                          n_elements,
+                          n_segments,
+                          segment_offsets_t,
+                          segment_offsets_t + 1,
+                          0,
+                          sizeof(T) * 8,
+                          cu_stream);
+    offset += n_elements;
+  }
 }
 
 template <typename T, typename Context>
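The loop in `ArgFullSort` above caps each CUB call at roughly 2^30 elements and rounds the batch down to a whole number of rows, so no segment is ever split across calls; since the classic `cub::DeviceSegmentedRadixSort` API takes `int num_items`, this also keeps each call's item count within 32-bit range once `num_cols * num_rows` grows large. A host-only sketch of the same arithmetic with hypothetical sizes (no GPU work, just the offsets each call would see):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical shape: 5,000,000 rows of 1,000 columns = 5e9 elements.
  const int64_t num_rows = 5000000, num_cols = 1000;
  constexpr int64_t max_elements = int64_t{1} << 30;
  const int64_t total = num_rows * num_cols;
  const int64_t segment_size = num_cols;
  const int64_t per_call = std::min(max_elements, total);
  // Round down to a multiple of the row length so rows stay intact.
  const int64_t batch = (per_call / segment_size) * segment_size;
  for (int64_t offset = 0; offset < total;) {
    const int64_t n = std::min(batch, total - offset);
    std::printf("offset=%lld elements=%lld segments=%lld\n",
                static_cast<long long>(offset), static_cast<long long>(n),
                static_cast<long long>(n / segment_size));
    offset += n;
  }
  return 0;
}
```

With these sizes the sort runs as five CUB calls, four full batches plus a smaller tail, each starting on a row boundary.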
@@ -247,10 +228,10 @@ void ArgsortKernel(const Context& dev_ctx,
   axis = (axis < 0) ? (in_dims.size() + axis) : axis;
   const T* in_data = input.data<T>();
   auto size = input.numel();
-  T* out_data = dev_ctx.template Alloc<T>(output);
-  int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);
 
   if (rank == 0) {
+    dev_ctx.template Alloc<T>(output);
+    dev_ctx.template Alloc<int64_t>(indices);
     phi::Copy<Context>(dev_ctx, input, dev_ctx.GetPlace(), false, output);
     phi::funcs::set_constant(dev_ctx, indices, static_cast<int64_t>(0));
     return;
@@ -261,6 +242,8 @@ void ArgsortKernel(const Context& dev_ctx,
   // Compared to the following 'Special case for full sort', ascending sort is
   // 34 times faster and descending sort is 31 times faster.
   if (size == in_dims[axis]) {
+    T* out_data = dev_ctx.template Alloc<T>(output);
+    int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);
 #ifdef PADDLE_WITH_CUDA
     const auto& exec_policy = thrust::cuda::par.on(dev_ctx.stream());
 #else
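The thrust-based 1-D fast path that the comment above refers to sits outside this hunk, so only its allocations are visible here. As a rough orientation, and not the kernel's actual code, a key/value full sort of this shape might look like the following standalone sketch (function and variable names are hypothetical):

```cpp
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

// Sketch: sort `keys` in place and carry the original positions along in `ids`.
void FullArgsortSketch(thrust::device_vector<float>& keys,
                       thrust::device_vector<int64_t>& ids,
                       bool descending) {
  thrust::sequence(ids.begin(), ids.end());  // ids = 0, 1, 2, ...
  if (descending) {
    thrust::sort_by_key(keys.begin(), keys.end(), ids.begin(),
                        thrust::greater<float>());
  } else {
    thrust::sort_by_key(keys.begin(), keys.end(), ids.begin());
  }
}
```

Because the fast path writes straight into `out_data` and `ids_data`, the diff moves their allocation into this branch rather than allocating them unconditionally at the top of the kernel.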
@@ -297,6 +280,8 @@ void ArgsortKernel(const Context& dev_ctx,
     const int64_t input_height =
         common::product(common::slice_ddim(in_dims, 0, in_dims.size() - 1));
     const int64_t input_width = in_dims[in_dims.size() - 1];
+    dev_ctx.template Alloc<int64_t>(indices);
+    dev_ctx.template Alloc<T>(output);
     ArgFullSort<T, int64_t>(dev_ctx,
                             &input,
                             output,
@@ -338,7 +323,6 @@ void ArgsortKernel(const Context& dev_ctx,
     // temp indices for sorting
     tmp_indices.Resize(trans_dims);
     dev_ctx.template Alloc<int64_t>(&tmp_indices);
-    dev_ctx.template Alloc<int64_t>(indices);
 
     ArgFullSort<T, int64_t>(dev_ctx,
                             &trans_inp,
@@ -347,10 +331,13 @@ void ArgsortKernel(const Context& dev_ctx,
                             input_height,
                             input_width,
                             descending);
-
-    TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
+    // Delay allocating the final output and indices until after the sort on
+    // the transposed copies, to keep peak memory usage down.
+    dev_ctx.template Alloc<T>(output);
+    dev_ctx.template Alloc<int64_t>(indices);
     // transpose back
     TransposeKernel<T, Context>(dev_ctx, tmp_out, trans, output);
+    TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
     return;
   }
 }