@@ -360,19 +360,26 @@ struct ReduceConfig {
360360 constexpr int max_num_threads = detail::kMaxThread ;
361361
362362 // set block size.
363- // 1. if reduce_lastdim == true, block is 1-D, no need reduction in block y;
364- // 2. if reduce_lastdim == false, block is 2-D, if it is necessary,
365- // it should reduce in block y.
363+ // 1. If reduce_lastdim == true, all threads that share the same threadIdx.y
364+ // cooperate on the reduction for one output.
365+ // The number of outputs per block is blockDim.y;
366+ // 2. If reduce_lastdim == false, each threadIdx.x processes a
367+ // different reduction and produces its own output. If
368+ // necessary, the partial results are also reduced across block y.
369+ // The number of outputs per block is blockDim.x;
370+ int block_x, block_y;
366371 int grid_num, reduce_num_per_thread;
367372 if (reduce_lastdim) {
368- block_dim->x = detail::GetBlockDim (reduce_num);
369- block_dim->y = 1 ;
370- grid_num = left_num;
371- reduce_num_per_thread =
372- detail::AlignUp (reduce_num, block_dim->x * block_dim->y );
373+ block_x = detail::GetBlockDim (reduce_num);
374+ block_y = detail::GetBlockDim (left_num);
375+ block_dim->x = block_x;
376+ block_dim->y =
377+ std::min (block_y, static_cast <int >(max_num_threads / block_dim->x ));
378+ grid_num = detail::AlignUp (left_num, block_dim->y );
379+ reduce_num_per_thread = detail::AlignUp (reduce_num, block_dim->x );
373380 } else {
374- int block_x = detail::GetBlockDim (left_num);
375- int block_y = detail::GetBlockDim (reduce_num);
381+ block_x = detail::GetBlockDim (left_num);
382+ block_y = detail::GetBlockDim (reduce_num);
376383 block_dim->x = std::min (block_x, 32 );
377384 block_dim->y =
378385 std::min (block_y, static_cast <int >(max_num_threads / block_dim->x ));
@@ -467,7 +474,7 @@ struct ReduceConfig {
467474 grid_dim.x = (left_num + block_dim.x - 1 ) / block_dim.x ;
468475 grid_dim.y = 1 ;
469476 }
470- } else if (reduce_type == ReduceType:: kReduceAny ) {
477+ } else {
471478 SetBlockDimForReduceAny (&block_dim, &grid_dim);
472479 }
473480
@@ -524,18 +531,20 @@ static __device__ T WarpReduce(T val, ReduceOp reducer) {
524531template <typename T, typename ReduceOp>
525532static __device__ T BlockXReduce (T val, ReduceOp reducer) {
526533 using detail::kWarpSize ;
527- __shared__ T shared[kWarpSize ];
534+ __shared__ T shared[2 * kWarpSize ];
528535 int block_dim_x = blockDim.x ;
529536 if (blockDim.x > kWarpSize ) {
530537 block_dim_x = blockDim.x / kWarpSize ;
531538 int lane = threadIdx.x % kWarpSize ;
532- int wid = threadIdx.x / kWarpSize ;
539+ int tid = threadIdx.y * blockDim.x + threadIdx.x ;
540+ int wid = tid / kWarpSize ;
541+ int bid = threadIdx.y ;
533542 val = WarpReduce (val, reducer);
534543 if (lane == 0 ) {
535544 shared[wid] = val;
536545 }
537546 __syncthreads ();
538- val = shared[lane];
547+ val = shared[bid * block_dim_x + lane];
539548 }
540549
541550 unsigned mask = 0u ;
@@ -562,29 +571,6 @@ static __device__ T BlockYReduce(T val, ReduceOp reducer) {
562571 return val;
563572}
564573
565- // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
566- // function will be used
567- // blockId.x -> left_num, threadId.x -> reduce_num
568- template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
569- __device__ void ReduceLastDim (const Tx* x, Ty* y, ReduceOp reducer,
570- TransformOp transformer, Ty init,
571- int reduce_num) {
572- int idx_x = blockIdx.x * reduce_num;
573- int idx_y = threadIdx.x ;
574- Ty reduce_var = init;
575- for (int idx_y = threadIdx.x ; idx_y < reduce_num; idx_y += blockDim.x ) {
576- reduce_var =
577- reducer (reduce_var, static_cast <Ty>(transformer (x[idx_x + idx_y])));
578- }
579- __syncthreads ();
580-
581- reduce_var = BlockXReduce (reduce_var, reducer);
582-
583- if (threadIdx.x == 0 ) {
584- y[blockIdx.x ] = reduce_var;
585- }
586- }
587-
588574// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
589575// function will be used
590576// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
@@ -613,27 +599,29 @@ __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
613599 }
614600}
615601
602+ // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
616603// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
617604// function will be used
618- template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
605+ template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
606+ typename ReduceIndexCal, typename LeftIndexCal>
619607__device__ void ReduceAny (const Tx* x, Ty* y, ReduceOp reducer,
620608 TransformOp transformer, Ty init, int reduce_num,
621609 int left_num, bool reduce_lastdim,
622- const IndexCalculator& reduce_index_calculator,
623- const IndexCalculator& left_index_calculator) {
610+ ReduceIndexCal reduce_index_calculator,
611+ LeftIndexCal left_index_calculator) {
624612 int input_idx, left_idx, stride;
625613 // the last dim gets involved in reduction
626614 if (reduce_lastdim) {
627615 input_idx = blockIdx.y * blockDim.x + threadIdx.x ;
628- left_idx = blockIdx.x ;
616+ left_idx = blockIdx.x * blockDim. y + threadIdx. y ;
629617 stride = gridDim.y * blockDim.x ;
630618 } else {
631619 input_idx = blockIdx.y * blockDim.y + threadIdx.y ;
632620 left_idx = blockIdx.x * blockDim.x + threadIdx.x ;
633621 stride = gridDim.y * blockDim.y ;
634622 }
635623 // calculate the offset, means the addr where each thread really start.
636- int input_offset = left_index_calculator. Get (left_idx);
624+ int input_offset = left_index_calculator (left_idx);
637625 const Tx* input = x + input_offset;
638626 Ty reduce_var = init;
639627
@@ -646,7 +634,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
646634#pragma unroll
647635 for (int i = 0 ; i < REDUCE_VEC_SIZE; ++i) {
648636 int reduce_idx = input_idx + i * stride;
649- int idx_x = reduce_index_calculator. Get (reduce_idx);
637+ int idx_x = reduce_index_calculator (reduce_idx);
650638 input_reg[i] = input[idx_x];
651639 }
652640#pragma unroll
@@ -664,7 +652,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
664652 break ;
665653 }
666654 int reduce_idx = input_idx;
667- int idx_x = reduce_index_calculator. Get (reduce_idx);
655+ int idx_x = reduce_index_calculator (reduce_idx);
668656 input_reg[i] = input[idx_x];
669657 input_idx += stride;
670658 }
@@ -680,16 +668,16 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
680668 }
681669
682670 // 2. reduce in block y
683- if (blockDim.y > 1 ) {
671+ if (!reduce_lastdim && blockDim.y > 1 ) {
684672 reduce_var = BlockYReduce (reduce_var, reducer);
685673 }
686674 __syncthreads ();
687675
688676 if (reduce_lastdim) {
689677 // 3. reduce in block x
690678 reduce_var = BlockXReduce (reduce_var, reducer);
691- if (threadIdx.x == 0 ) {
692- y[blockIdx.x + blockIdx. y * gridDim. x ] = reduce_var;
679+ if (left_idx < left_num && threadIdx.x == 0 ) {
680+ y[blockIdx.y * left_num + left_idx ] = reduce_var;
693681 }
694682 } else {
695683 if (left_idx < left_num && threadIdx.y == 0 ) {
@@ -707,8 +695,10 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
707695 const IndexCalculator& reduce_index_calculator,
708696 const IndexCalculator& left_index_calculator) {
709697 if (reduce_type == ReduceType::kReduceLastDim ) {
710- ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
711- init, reduce_num);
698+ ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
699+ x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
700+ [&](int idx) { return idx; },
701+ [&](int idx) { return idx * reduce_num; });
712702
713703 // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
714704 } else if (reduce_type == ReduceType::kReduceHigherDim ) {
@@ -719,7 +709,8 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
719709 } else {
720710 ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
721711 x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
722- reduce_index_calculator, left_index_calculator);
712+ [&](int idx) { return reduce_index_calculator.Get (idx); },
713+ [&](int idx) { return left_index_calculator.Get (idx); });
723714 }
724715}
725716
0 commit comments