Skip to content

Commit c7cc5ac

Browse files
authored
Unify the block/grid strategy and implementation of ReduceLastDim and ReduceAny (#34436)
1 parent 80f7f7e commit c7cc5ac

File tree

1 file changed

+41
-50
lines changed

1 file changed

+41
-50
lines changed

paddle/fluid/operators/reduce_ops/reduce_op.cu.h

Lines changed: 41 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -360,19 +360,26 @@ struct ReduceConfig {
360360
constexpr int max_num_threads = detail::kMaxThread;
361361

362362
// set block size.
363-
// 1. if reduce_lastdim == true, block is 1-D, no need reduction in block y;
364-
// 2. if reduce_lastdim == false, block is 2-D, if it is necessary,
365-
// it should reduce in block y.
363+
// 1. If reduce_lastdim == true, all the threads with the same threadIdx.y
364+
// will process the reduction for one output.
365+
// The number of outputs for one block is blockDim.y;
366+
// 2. If reduce_lastdim == false, different threadIdx.x will process
367+
// different reductions and get the output separately. If it is
368+
// necessary, it should reduce in block y.
369+
// The number of outputs for one block is blockDim.x;
370+
int block_x, block_y;
366371
int grid_num, reduce_num_per_thread;
367372
if (reduce_lastdim) {
368-
block_dim->x = detail::GetBlockDim(reduce_num);
369-
block_dim->y = 1;
370-
grid_num = left_num;
371-
reduce_num_per_thread =
372-
detail::AlignUp(reduce_num, block_dim->x * block_dim->y);
373+
block_x = detail::GetBlockDim(reduce_num);
374+
block_y = detail::GetBlockDim(left_num);
375+
block_dim->x = block_x;
376+
block_dim->y =
377+
std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
378+
grid_num = detail::AlignUp(left_num, block_dim->y);
379+
reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->x);
373380
} else {
374-
int block_x = detail::GetBlockDim(left_num);
375-
int block_y = detail::GetBlockDim(reduce_num);
381+
block_x = detail::GetBlockDim(left_num);
382+
block_y = detail::GetBlockDim(reduce_num);
376383
block_dim->x = std::min(block_x, 32);
377384
block_dim->y =
378385
std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
@@ -467,7 +474,7 @@ struct ReduceConfig {
467474
grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
468475
grid_dim.y = 1;
469476
}
470-
} else if (reduce_type == ReduceType::kReduceAny) {
477+
} else {
471478
SetBlockDimForReduceAny(&block_dim, &grid_dim);
472479
}
473480

@@ -524,18 +531,20 @@ static __device__ T WarpReduce(T val, ReduceOp reducer) {
524531
template <typename T, typename ReduceOp>
525532
static __device__ T BlockXReduce(T val, ReduceOp reducer) {
526533
using detail::kWarpSize;
527-
__shared__ T shared[kWarpSize];
534+
__shared__ T shared[2 * kWarpSize];
528535
int block_dim_x = blockDim.x;
529536
if (blockDim.x > kWarpSize) {
530537
block_dim_x = blockDim.x / kWarpSize;
531538
int lane = threadIdx.x % kWarpSize;
532-
int wid = threadIdx.x / kWarpSize;
539+
int tid = threadIdx.y * blockDim.x + threadIdx.x;
540+
int wid = tid / kWarpSize;
541+
int bid = threadIdx.y;
533542
val = WarpReduce(val, reducer);
534543
if (lane == 0) {
535544
shared[wid] = val;
536545
}
537546
__syncthreads();
538-
val = shared[lane];
547+
val = shared[bid * block_dim_x + lane];
539548
}
540549

541550
unsigned mask = 0u;
@@ -562,29 +571,6 @@ static __device__ T BlockYReduce(T val, ReduceOp reducer) {
562571
return val;
563572
}
564573

565-
// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
566-
// function will be used
567-
// blockId.x -> left_num, threadId.x -> reduce_num
568-
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
569-
__device__ void ReduceLastDim(const Tx* x, Ty* y, ReduceOp reducer,
570-
TransformOp transformer, Ty init,
571-
int reduce_num) {
572-
int idx_x = blockIdx.x * reduce_num;
573-
int idx_y = threadIdx.x;
574-
Ty reduce_var = init;
575-
for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += blockDim.x) {
576-
reduce_var =
577-
reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
578-
}
579-
__syncthreads();
580-
581-
reduce_var = BlockXReduce(reduce_var, reducer);
582-
583-
if (threadIdx.x == 0) {
584-
y[blockIdx.x] = reduce_var;
585-
}
586-
}
587-
588574
// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
589575
// function will be used
590576
// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
@@ -613,27 +599,29 @@ __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
613599
}
614600
}
615601

602+
// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
616603
// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
617604
// function will be used
618-
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
605+
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
606+
typename ReduceIndexCal, typename LeftIndexCal>
619607
__device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
620608
TransformOp transformer, Ty init, int reduce_num,
621609
int left_num, bool reduce_lastdim,
622-
const IndexCalculator& reduce_index_calculator,
623-
const IndexCalculator& left_index_calculator) {
610+
ReduceIndexCal reduce_index_calculator,
611+
LeftIndexCal left_index_calculator) {
624612
int input_idx, left_idx, stride;
625613
// the last dim gets involved in reduction
626614
if (reduce_lastdim) {
627615
input_idx = blockIdx.y * blockDim.x + threadIdx.x;
628-
left_idx = blockIdx.x;
616+
left_idx = blockIdx.x * blockDim.y + threadIdx.y;
629617
stride = gridDim.y * blockDim.x;
630618
} else {
631619
input_idx = blockIdx.y * blockDim.y + threadIdx.y;
632620
left_idx = blockIdx.x * blockDim.x + threadIdx.x;
633621
stride = gridDim.y * blockDim.y;
634622
}
635623
// calculate the offset, means the addr where each thread really start.
636-
int input_offset = left_index_calculator.Get(left_idx);
624+
int input_offset = left_index_calculator(left_idx);
637625
const Tx* input = x + input_offset;
638626
Ty reduce_var = init;
639627

@@ -646,7 +634,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
646634
#pragma unroll
647635
for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
648636
int reduce_idx = input_idx + i * stride;
649-
int idx_x = reduce_index_calculator.Get(reduce_idx);
637+
int idx_x = reduce_index_calculator(reduce_idx);
650638
input_reg[i] = input[idx_x];
651639
}
652640
#pragma unroll
@@ -664,7 +652,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
664652
break;
665653
}
666654
int reduce_idx = input_idx;
667-
int idx_x = reduce_index_calculator.Get(reduce_idx);
655+
int idx_x = reduce_index_calculator(reduce_idx);
668656
input_reg[i] = input[idx_x];
669657
input_idx += stride;
670658
}
@@ -680,16 +668,16 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
680668
}
681669

682670
// 2. reduce in block y
683-
if (blockDim.y > 1) {
671+
if (!reduce_lastdim && blockDim.y > 1) {
684672
reduce_var = BlockYReduce(reduce_var, reducer);
685673
}
686674
__syncthreads();
687675

688676
if (reduce_lastdim) {
689677
// 3. reduce in block x
690678
reduce_var = BlockXReduce(reduce_var, reducer);
691-
if (threadIdx.x == 0) {
692-
y[blockIdx.x + blockIdx.y * gridDim.x] = reduce_var;
679+
if (left_idx < left_num && threadIdx.x == 0) {
680+
y[blockIdx.y * left_num + left_idx] = reduce_var;
693681
}
694682
} else {
695683
if (left_idx < left_num && threadIdx.y == 0) {
@@ -707,8 +695,10 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
707695
const IndexCalculator& reduce_index_calculator,
708696
const IndexCalculator& left_index_calculator) {
709697
if (reduce_type == ReduceType::kReduceLastDim) {
710-
ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
711-
init, reduce_num);
698+
ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
699+
x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
700+
[&](int idx) { return idx; },
701+
[&](int idx) { return idx * reduce_num; });
712702

713703
// reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
714704
} else if (reduce_type == ReduceType::kReduceHigherDim) {
@@ -719,7 +709,8 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
719709
} else {
720710
ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
721711
x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
722-
reduce_index_calculator, left_index_calculator);
712+
[&](int idx) { return reduce_index_calculator.Get(idx); },
713+
[&](int idx) { return left_index_calculator.Get(idx); });
723714
}
724715
}
725716

0 commit comments

Comments
 (0)