@@ -36,53 +36,23 @@ namespace cub = hipcub;
3636
3737namespace phi {
3838
// Copies `valid_item` elements in reverse order through shared memory:
// for every thread index t with t < valid_item,
//   odata[dst_base - t] = idata[src_base + t],
// staged via a BLOCK_SIZE-element shared buffer so that both the global
// load and the global store are contiguous (coalesced) across the block.
// Out-of-range shared slots are zero-filled before the barrier so no
// uninitialized shared memory is ever read.
// NOTE(review): assumes the kernel is launched with blockDim.x == BLOCK_SIZE —
// each thread writes staging[threadIdx.x]; confirm at the call sites.
template <typename T, int BLOCK_SIZE>
__device__ void BlockReverse(
    const T* idata, T* odata, int src_base, int dst_base, int valid_item) {
  __shared__ T staging[BLOCK_SIZE];

  const int tid = threadIdx.x;
  const int mirror = BLOCK_SIZE - tid - 1;  // slot this thread mirrors

  // Load back-to-front: staging[tid] receives idata[src_base + mirror]
  // when that source index is valid, zero otherwise.
  T value = static_cast<T>(0);
  if (mirror < valid_item) {
    value = idata[src_base + mirror];
  }
  staging[tid] = value;

  // All staging slots must be written before any thread reads a mirrored one.
  __syncthreads();

  // Store front-to-back from the mirrored slot, which now holds
  // idata[src_base + tid] (or the zero padding).
  if (tid < valid_item) {
    odata[dst_base - tid] = staging[mirror];
  }
}
// Reverses each length-`reverse_size` row of `matrix_data` into
// `reverse_data`. The flattened input is treated as `grid_size` consecutive
// rows; blocks grid-stride over rows and the threads of a block
// block-stride over the elements of a row.
//
// Parameters:
//   matrix_data  - source buffer, grid_size * reverse_size elements.
//   reverse_data - destination buffer, same extent; must not alias
//                  matrix_data (element i of a row is read while element
//                  reverse_size - 1 - i is being written).
//   grid_size    - number of rows (outer_size * inner_size at the call site).
//   reverse_size - elements per row.
template <typename T>
__global__ void MatrixRowReverse(const T* matrix_data,
                                 T* reverse_data,
                                 int64_t grid_size,
                                 int64_t reverse_size) {
  for (int64_t bx = blockIdx.x; bx < grid_size; bx += gridDim.x) {
    const int64_t row_base = bx * reverse_size;
    // Stride by blockDim.x rather than a hard-coded 1024: the original
    // silently skipped elements whenever blockDim.x != item_per_block.
    // Behavior is identical for the current <<<grid, 1024>>> launches.
    for (int64_t i = threadIdx.x; i < reverse_size; i += blockDim.x) {
      reverse_data[row_base + (reverse_size - 1 - i)] = matrix_data[row_base + i];
    }
  }
}
8858
@@ -112,24 +82,30 @@ __global__ void MatrixTranspose(T* odata,
11282 size_t width) {
11383 __shared__ T tile[TILE_DIM][TILE_DIM + 1 ];
11484
115- int x = blockIdx . x * TILE_DIM + threadIdx . x ;
116- int y = blockIdx . y * TILE_DIM + threadIdx . y ;
117- for ( int j = 0 ; j < TILE_DIM; j += BLOCK_ROWS) {
118- if (x < width && (y + j) < height) {
119- tile[ threadIdx . y + j][ threadIdx . x ] = idata[(y + j) * width + x];
120- } else {
121- tile[ threadIdx . y + j][ threadIdx . x ] = 0 ;
122- }
123- }
85+ int64_t wblocks = (width + TILE_DIM - 1 ) / TILE_DIM ;
86+ int64_t hblocks = (height + TILE_DIM - 1 ) / TILE_DIM ;
87+
88+ int64_t block_i = blockIdx . x ;
89+ for (; block_i < wblocks * hblocks; block_i += gridDim . x ) {
90+ int64_t block_y = block_i / wblocks;
91+ int64_t block_x = block_i % wblocks ;
92+ int64_t x = block_x * TILE_DIM + threadIdx . x ;
93+ int64_t y = block_y * TILE_DIM + threadIdx . y ;
12494
125- __syncthreads ();
95+ for (int j = 0 ; j < TILE_DIM; j += BLOCK_ROWS) {
96+ if (x < width && (y + j) < height) {
97+ tile[threadIdx .y + j][threadIdx .x ] = idata[(y + j) * width + x];
98+ }
99+ }
100+ __syncthreads ();
126101
127- x = blockIdx . y * TILE_DIM + threadIdx .x ; // transpose block offset
128- y = blockIdx . x * TILE_DIM + threadIdx .y ;
102+ x = block_y * TILE_DIM + threadIdx .x ; // transpose block offset
103+ y = block_x * TILE_DIM + threadIdx .y ;
129104
130- for (int j = 0 ; j < TILE_DIM; j += BLOCK_ROWS) {
131- if (x < height && (y + j) < width) {
132- odata[(y + j) * height + x] = tile[threadIdx .x ][threadIdx .y + j];
105+ for (int j = 0 ; j < TILE_DIM; j += BLOCK_ROWS) {
106+ if (x < height && (y + j) < width) {
107+ odata[(y + j) * height + x] = tile[threadIdx .x ][threadIdx .y + j];
108+ }
133109 }
134110 }
135111}
@@ -172,9 +148,8 @@ struct Identity<T, ComplexSum> {
172148template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
173149__global__ void BlockScanKernel (T* d_out,
174150 const T* d_in,
175- int inner_size,
176- int outer_size,
177- int scan_size,
151+ int64_t grid_size,
152+ int64_t scan_size,
178153 bool exclusive,
179154 Op op) {
180155 using MT = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -196,38 +171,40 @@ __global__ void BlockScanKernel(T* d_out,
196171 typename BlockScanT::TempStorage scan;
197172 } temp_storage;
198173
199- int bx = blockIdx .x ;
200- BlockPrefixCallbackOp<MT, Op> prefix_op (Identity<MT, Op>::value, op);
201-
202174 // Obtain this block's segment of consecutive keys (blocked across threads)
203- int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
204- for (int block_offset = 0 ; block_offset < scan_size;
205- block_offset += BLOCK_THREADS * ITEMS_PER_THREAD) {
206- int valid_item = (scan_size - block_offset > item_per_block)
207- ? item_per_block
208- : (scan_size - block_offset);
209- if (scan_size < item_per_block) {
210- valid_item = scan_size;
175+ int64_t item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
176+
177+ for (int64_t bx = blockIdx .x ; bx < grid_size; bx += gridDim .x ) {
178+ BlockPrefixCallbackOp<MT, Op> prefix_op (Identity<MT, Op>::value, op);
179+
180+ for (int64_t block_offset = 0 ; block_offset < scan_size;
181+ block_offset += item_per_block) {
182+ int64_t valid_item = (scan_size - block_offset > item_per_block)
183+ ? item_per_block
184+ : (scan_size - block_offset);
185+ if (scan_size < item_per_block) {
186+ valid_item = scan_size;
187+ }
188+
189+ int64_t offset = bx * scan_size + block_offset;
190+
191+ MT thread_keys[ITEMS_PER_THREAD];
192+ BlockLoadT (temp_storage.load )
193+ .Load (d_in + offset, thread_keys, valid_item, 0 );
194+
195+ __syncthreads ();
196+ if (exclusive) {
197+ BlockScanT (temp_storage.scan )
198+ .ExclusiveScan (thread_keys, thread_keys, op, prefix_op);
199+ } else {
200+ BlockScanT (temp_storage.scan )
201+ .InclusiveScan (thread_keys, thread_keys, op, prefix_op);
202+ }
203+ __syncthreads ();
204+
205+ BlockStoreT (temp_storage.store )
206+ .Store (d_out + offset, thread_keys, valid_item);
211207 }
212-
213- int offset = block_offset + bx * scan_size;
214-
215- MT thread_keys[ITEMS_PER_THREAD];
216- BlockLoadT (temp_storage.load )
217- .Load (d_in + offset, thread_keys, valid_item, 0 );
218-
219- __syncthreads ();
220- if (exclusive) {
221- BlockScanT (temp_storage.scan )
222- .ExclusiveScan (thread_keys, thread_keys, op, prefix_op);
223- } else {
224- BlockScanT (temp_storage.scan )
225- .InclusiveScan (thread_keys, thread_keys, op, prefix_op);
226- }
227- __syncthreads ();
228-
229- BlockStoreT (temp_storage.store )
230- .Store (d_out + offset, thread_keys, valid_item);
231208 }
232209}
233210
@@ -347,14 +324,24 @@ void ScanKernel(const Context& dev_ctx,
347324 int scan_size = out_dims[axis];
348325 bool transpose = (axis != out_dims.size () - 1 );
349326
350- int tile_size = 32 ;
351- dim3 blocks (32 , 8 );
352- dim3 transpose_grids ((width + tile_size - 1 ) / tile_size,
353- (height + tile_size - 1 ) / tile_size);
354327 DenseTensor tmp_tensor;
355328 tmp_tensor.Resize (out_dims);
356329 auto * tmp_data = dev_ctx.template Alloc <T>(&tmp_tensor);
357330
331+ auto swap_ptr = [](T*& ptr1, T*& ptr2) {
332+ T* tmp = ptr2;
333+ ptr2 = ptr1;
334+ ptr1 = tmp;
335+ };
336+
337+ int64_t max_grid_x = dev_ctx.GetCUDAMaxGridDimSize ()[0 ];
338+
339+ // Do pre-process transpose
340+ int tile_size = 32 ;
341+ dim3 blocks (32 , 8 );
342+ int64_t transpose_grids = ((width + tile_size - 1 ) / tile_size) *
343+ ((height + tile_size - 1 ) / tile_size);
344+ transpose_grids = std::min (transpose_grids, max_grid_x);
358345 T* next_in_data = out_data;
359346 T* next_out_data = tmp_data;
360347 if (transpose) {
@@ -363,53 +350,42 @@ void ScanKernel(const Context& dev_ctx,
363350 next_in_data = out_data;
364351 next_out_data = tmp_data;
365352 }
366- auto swap_ptr = [](T*& ptr1, T*& ptr2) {
367- T* tmp = ptr2;
368- ptr2 = ptr1;
369- ptr1 = tmp;
370- };
371- int outer_size = height / scan_size;
372- int inner_size = width;
373- // Consider the size of shared memory, here block size is 128
374- dim3 scan_grid (outer_size, inner_size);
375- dim3 reverse_grid = scan_grid;
353+
354+ // Do pre-process reverse
355+ int64_t outer_size = height / scan_size;
356+ int64_t inner_size = width;
357+ int64_t grid_size = outer_size * inner_size;
358+ int64_t scan_grid = std::min (grid_size, max_grid_x);
376359 if (reverse) {
377360 if (transpose) {
378- reverse_grid.x = scan_grid.y ;
379- reverse_grid.y = scan_grid.x ;
380- MatrixRowReverse<T><<<reverse_grid, 1024 , 0 , dev_ctx.stream()>>> (
381- next_in_data, next_out_data, scan_size, outer_size, inner_size);
361+ MatrixRowReverse<T><<<scan_grid, 1024 , 0 , dev_ctx.stream()>>> (
362+ next_in_data, next_out_data, grid_size, scan_size);
382363 if (!transpose) next_in_data = tmp_data;
383364 swap_ptr (next_in_data, next_out_data);
384365 } else {
385- MatrixRowReverse<T><<<reverse_grid , 1024 , 0 , dev_ctx.stream()>>> (
386- in_data, out_data, scan_size, outer_size, inner_size );
366+ MatrixRowReverse<T><<<scan_grid , 1024 , 0 , dev_ctx.stream()>>> (
367+ in_data, out_data, grid_size, scan_size );
387368 }
388369 }
389- int64_t grid_size = outer_size * inner_size;
370+
371+ // Do scan
390372 if (!transpose && !reverse) {
391- BlockScanKernel<T, 128 , 4 , Op><<<grid_size , 128 , 0 , dev_ctx.stream()>>> (
392- out_data, in_data, outer_size, inner_size , scan_size, exclusive, op);
373+ BlockScanKernel<T, 128 , 4 , Op><<<scan_grid , 128 , 0 , dev_ctx.stream()>>> (
374+ out_data, in_data, grid_size , scan_size, exclusive, op);
393375
394376 } else {
395- BlockScanKernel<T, 128 , 4 , Op>
396- <<<grid_size, 128 , 0 , dev_ctx.stream()>>> (next_out_data,
397- next_in_data,
398- outer_size,
399- inner_size,
400- scan_size,
401- exclusive,
402- op);
377+ BlockScanKernel<T, 128 , 4 , Op><<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
378+ next_out_data, next_in_data, grid_size, scan_size, exclusive, op);
403379 }
404380 swap_ptr (next_in_data, next_out_data);
381+
382+ // Do post-process reverse and transpose
405383 if (reverse) {
406- MatrixRowReverse<T><<<reverse_grid , 1024 , 0 , dev_ctx.stream()>>> (
407- next_in_data, next_out_data, scan_size, outer_size, inner_size );
384+ MatrixRowReverse<T><<<scan_grid , 1024 , 0 , dev_ctx.stream()>>> (
385+ next_in_data, next_out_data, grid_size, scan_size );
408386 swap_ptr (next_in_data, next_out_data);
409387 }
410388 if (transpose) {
411- transpose_grids.x = (height + tile_size - 1 ) / tile_size;
412- transpose_grids.y = (width + tile_size - 1 ) / tile_size;
413389 MatrixTranspose<T, 32 , 8 ><<<transpose_grids, blocks, 0 , dev_ctx.stream()>>> (
414390 next_out_data, next_in_data, width, height);
415391 }
0 commit comments