Skip to content

Commit e43ce03

Browse files
committed
fix
1 parent 5dc7b79 commit e43ce03

File tree

1 file changed

+43
-74
lines changed

1 file changed

+43
-74
lines changed

paddle/fluid/operators/reduce_ops/reduce_op.cu.h

Lines changed: 43 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,10 @@ static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
8787

8888
// Per-platform reduction limits.
// HIP (AMD) wavefronts are 64 lanes wide; CUDA warps are 32.
// kMaxThread caps the block size used by the reduce kernels below.
#ifdef __HIPCC__
constexpr int kMaxThread = 256;
constexpr int warp_size = 64;
#else
constexpr int kMaxThread = 128;
constexpr int warp_size = 32;
#endif
9395

9496
// get blockDim for reduceLastDim and reduceAny
@@ -393,63 +395,31 @@ struct ReduceConfig {
393395
dim3 grid;
394396
};
395397

396-
// version 1
397-
// template <typename T, typename ReduceOp>
398-
// __device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
399-
// constexpr int warp_size = 32;
400-
// unsigned mask = 0u;
401-
// CREATE_SHFL_MASK(mask, true);
402-
// for (int stride = warp_size / 2; stride > 0; stride >>= 1) {
403-
// T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
404-
// val = reducer(val, temp);
405-
// }
406-
// return val;
407-
// }
408-
409-
// template <typename T, typename ReduceOp>
410-
// __device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
411-
// __shared__ T shared[32];
412-
// int block_dim_x = blockDim.x;
413-
// if (blockDim.x > warpSize) {
414-
// block_dim_x = blockDim.x / warpSize;
415-
// int lane = threadIdx.x % warpSize;
416-
// int wid = threadIdx.x / warpSize;
417-
// val = WarpReduce(val, reducer);
418-
// if (lane == 0) {
419-
// shared[wid] = val;
420-
// }
421-
// __syncthreads();
422-
// if (wid == 0) {
423-
// val = shared[lane];
424-
// }
425-
// }
426-
// __syncthreads();
427-
428-
// unsigned mask = 0u;
429-
// CREATE_SHFL_MASK(mask, true);
430-
// for (int stride = 1; stride < block_dim_x; stride <<= 1) {
431-
// T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
432-
// val = reducer(val, temp);
433-
// }
434-
// return val;
435-
// }
436-
437-
// version 2
// Reduces `val` across all lanes of the calling warp (CUDA: 32 lanes) or
// wavefront (HIP: 64 lanes) via shuffle-down exchanges; after the final
// step lane 0 holds the fully combined value. `reducer` is a binary
// functor combining two partial results.
template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  // Butterfly pattern: halve the lane offset each step
  // (warp_size/2 -> ... -> 1) so every lane folds in its partner's partial.
  for (int offset = detail::warp_size / 2; offset > 0; offset >>= 1) {
    T other = paddle::platform::CudaShuffleDownSync(mask, val, offset);
    val = reducer(val, other);
  }
  return val;
}
408+
438409
template <typename T, typename ReduceOp>
439410
__device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
440-
__shared__ T shared[detail::kMaxThread];
411+
__shared__ T shared[detail::warp_size];
441412
int block_dim_x = blockDim.x;
442413
if (blockDim.x > warpSize) {
443-
block_dim_x = warpSize;
444-
shared[threadIdx.x] = val;
445-
for (int stride = blockDim.x / 2; stride >= warpSize; stride >>= 1) {
446-
__syncthreads();
447-
if (threadIdx.x < stride) {
448-
T temp = shared[threadIdx.x + stride];
449-
val = reducer(val, temp);
450-
shared[threadIdx.x] = val;
451-
}
414+
block_dim_x = blockDim.x / warpSize;
415+
int lane = threadIdx.x % warpSize;
416+
int wid = threadIdx.x / warpSize;
417+
val = WarpReduce(val, reducer);
418+
if (lane == 0) {
419+
shared[wid] = val;
452420
}
421+
__syncthreads();
422+
val = shared[lane];
453423
}
454424
__syncthreads();
455425

@@ -616,37 +586,36 @@ __global__ void ReduceKernelFunction(
616586
left_strides);
617587
}
618588

619-
template <typename Tx, typename Ty, typename ReduceOp, int kRank,
620-
int kReduceRank>
589+
// Launches ReduceKernelFunction on `stream` with compile-time ranks
// Rank/ReduceRank taken from the runtime shape metadata in `config`.
// If the first pass leaves per-block partial results in
// config.output_data (config.should_reduce_again), a second
// higher-dimension pass folds those partials into y_data.
template <typename Tx, typename Ty, typename ReduceOp, int Rank, int ReduceRank>
static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
                               const ReduceOp& reducer, Ty init,
                               gpuStream_t stream, ReduceConfig<Ty> config) {
  using TransformOp = typename ReduceOp::Transformer;

  // Fixed-size copies of the shape/stride metadata; identical for both
  // passes, so build them once and pass by value to each launch.
  auto x_strides = detail::VectorToArray<int, Rank>(config.x_strides);
  auto reduce_dim = detail::VectorToArray<int, ReduceRank>(config.reduce_dim);
  auto reduce_strides =
      detail::VectorToArray<int, ReduceRank>(config.reduce_strides);
  auto left_dim =
      detail::VectorToArray<int, Rank - ReduceRank>(config.left_dim);
  auto left_strides =
      detail::VectorToArray<int, Rank - ReduceRank>(config.left_strides);

  ReduceKernelFunction<Tx, Ty, ReduceOp, TransformOp, Rank,
                       ReduceRank><<<config.grid, config.block, 0, stream>>>(
      x_data, config.output_data, reducer, TransformOp(config.reduce_num), init,
      config.reduce_num, config.left_num, config.blocking_size,
      config.reduce_type, x_strides, reduce_dim, reduce_strides, left_dim,
      left_strides);

  if (config.should_reduce_again) {
    dim3 block(config.block.x, 1, 1);
    dim3 grid(config.grid.x, 1, config.grid.z);

    // Second pass: reduce the grid.y partials produced above. The input is
    // already transformed, so IdentityFunctor is used as the transformer.
    ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<Ty>, Rank,
                         ReduceRank><<<grid, block, 0, stream>>>(
        config.output_data, y_data, reducer,
        detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
        config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
        x_strides, reduce_dim, reduce_strides, left_dim, left_strides);
  }
}
652621

@@ -659,15 +628,15 @@ static void ReduceKernelImpl(const Tx* x_data, Ty* y_data,
659628

660629
#define CUB_RANK_CASE(i, ...) \
661630
case i: { \
662-
constexpr auto kRank = i; \
631+
constexpr auto Rank = i; \
663632
switch (reduce_rank) { __VA_ARGS__; } \
664633
} break
665634

666-
#define CUB_REDUCE_RANK_CASE(i, ...) \
667-
case i: { \
668-
constexpr auto kReduceRank = i; \
669-
LaunchReduceKernel<Tx, Ty, ReduceOp, kRank, kReduceRank>( \
670-
x_data, y_data, reducer, init, stream, config); \
635+
#define CUB_REDUCE_RANK_CASE(i, ...) \
636+
case i: { \
637+
constexpr auto ReduceRank = i; \
638+
LaunchReduceKernel<Tx, Ty, ReduceOp, Rank, ReduceRank>( \
639+
x_data, y_data, reducer, init, stream, config); \
671640
} break
672641

673642
detail::CheckReduceRank(reduce_rank, rank);

0 commit comments

Comments
 (0)