Merged

21 commits
7d58b91  Merge pull request #1 from PaddlePaddle/develop (AnnaTrainingG, Mar 25, 2021)
1021e08  Merge pull request #2 from PaddlePaddle/develop (AnnaTrainingG, Mar 29, 2021)
43f53fe  Merge pull request #3 from PaddlePaddle/develop (AnnaTrainingG, Apr 19, 2021)
d25ab26  Merge pull request #4 from PaddlePaddle/develop (AnnaTrainingG, May 7, 2021)
8c8717f  Merge pull request #5 from PaddlePaddle/develop (AnnaTrainingG, May 25, 2021)
9ddf5e8  Merge pull request #6 from PaddlePaddle/develop (AnnaTrainingG, May 26, 2021)
b0cbcca  Merge pull request #9 from PaddlePaddle/develop (AnnaTrainingG, Jun 1, 2021)
cdecaf0  Merge pull request #14 from PaddlePaddle/develop (AnnaTrainingG, Jun 11, 2021)
0da14c9  Merge pull request #16 from PaddlePaddle/develop (AnnaTrainingG, Jun 15, 2021)
ca95763  Merge pull request #17 from PaddlePaddle/develop (AnnaTrainingG, Jun 22, 2021)
25ba21c  Merge pull request #18 from PaddlePaddle/develop (AnnaTrainingG, Jul 5, 2021)
3ce9983  Merge pull request #19 from PaddlePaddle/develop (AnnaTrainingG, Jul 6, 2021)
61842ed  Merge pull request #20 from PaddlePaddle/develop (AnnaTrainingG, Jul 12, 2021)
0e2c73b  Merge pull request #21 from PaddlePaddle/develop (AnnaTrainingG, Jul 28, 2021)
c1e59cf  Merge pull request #22 from PaddlePaddle/develop (AnnaTrainingG, Aug 2, 2021)
3a54149  Merge pull request #23 from PaddlePaddle/develop (AnnaTrainingG, Aug 4, 2021)
3dd8c9d  merge (AnnaTrainingG, Aug 4, 2021)
ba1d2fa  update (AnnaTrainingG, Aug 4, 2021)
3e10bd2  update (AnnaTrainingG, Aug 9, 2021)
66ec9bf  update (AnnaTrainingG, Aug 10, 2021)
59f0df2  update (AnnaTrainingG, Aug 10, 2021)
1 change: 0 additions & 1 deletion paddle/fluid/operators/reduce_ops/reduce_all_op.cu
@@ -15,7 +15,6 @@
#include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"

// reduce_prod
REGISTER_OP_CUDA_KERNEL(
reduce_all,
ops::ReduceCudaKernel<bool, paddle::operators::CustomLogicalAnd>);
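
For context: ReduceCudaKernel dispatches through a functor from reduce_functor_op.h that bundles three things: an initial value, the binary reduction, and a per-element Transformer. The exact definition of CustomLogicalAnd is not shown in this diff, so the following is a hedged sketch of the assumed interface, not a quote of the real code:

template <typename T>
struct CustomLogicalAnd {
  // Assumed: no pre-transform is needed, so the Transformer is an identity.
  using Transformer = detail::IdentityFunctor<T>;

  // true is the identity element of logical AND.
  inline T initial() { return static_cast<T>(true); }

  __device__ T operator()(const T& a, const T& b) const { return a && b; }
};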
1 change: 0 additions & 1 deletion paddle/fluid/operators/reduce_ops/reduce_any_op.cu
@@ -16,7 +16,6 @@
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"

// reduce_prod
REGISTER_OP_CUDA_KERNEL(
reduce_any,
ops::ReduceCudaKernel<bool, paddle::operators::CustomLogicalOr>);
59 changes: 6 additions & 53 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
@@ -13,58 +13,11 @@
 // limitations under the License.
 
 #include <vector>
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 
-namespace paddle {
-namespace operators {
-
-template <typename T>
-struct DivideFunctor {
-  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
-
-  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
-
- private:
-  T n_inv;
-};
-
-template <typename T>
-class ReduceMeanKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    bool reduce_all = context.Attr<bool>("reduce_all");
-    auto* input = context.Input<Tensor>("X");
-    auto* output = context.Output<Tensor>("Out");
-
-    auto dims = context.Attr<std::vector<int>>("dim");
-    bool keep_dim = context.Attr<bool>("keep_dim");
-
-    std::vector<int> reduce_dims;
-    if (reduce_all) {
-      reduce_dims.resize(input->dims().size());
-      for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
-    } else {
-      for (auto e : dims) {
-        reduce_dims.push_back(e >= 0 ? e : e + input->dims().size());
-      }
-    }
-
-    int reduce_num = 1;
-    for (int i = 0; i < reduce_dims.size(); ++i) {
-      reduce_num *= input->dims()[reduce_dims[i]];
-    }
-
-    auto stream = context.cuda_device_context().stream();
-    TensorReduce<T, T, cub::Sum, DivideFunctor<T>>(
-        *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
-        DivideFunctor<T>(reduce_num), stream);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<bool>,
-                        ops::ReduceMeanKernel<float>,
-                        ops::ReduceMeanKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    reduce_mean, ops::ReduceCudaKernel<bool, paddle::operators::CustomMean>,
+    ops::ReduceCudaKernel<float, paddle::operators::CustomMean>,
+    ops::ReduceCudaKernel<double, paddle::operators::CustomMean>);
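
Note that the DivideFunctor removed above does not lose its job: a mean is just a sum whose Transformer pre-scales each element by 1/reduce_num, so the logic moves into the CustomMean functor. A sketch of how reduce_functor_op.h presumably composes the two (an assumption, since that header is not part of this diff):

template <typename T>
struct DivideFunctor {
  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}

  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }

 private:
  T n_inv;
};

template <typename T>
struct CustomMean {
  // Scale each element by 1 / reduce_num first, then reduce with a plain sum.
  using Transformer = DivideFunctor<T>;

  inline T initial() { return static_cast<T>(0.0f); }

  __device__ T operator()(const T& a, const T& b) const { return a + b; }
};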
116 changes: 56 additions & 60 deletions paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -33,6 +33,7 @@ namespace cub = hipcub;
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"

@@ -145,7 +146,6 @@ using Tensor = framework::Tensor;
 constexpr int kMaxRank = framework::DDim::kMaxRank;
 
 enum ReduceType {
-  kReduceAll = 0x00,        // when reduce_rank == x_rank
   kReduceLastDim = 0x01,    // when reduce_dim[0] == x_dim.size() - 1;
   kReduceHigherDim = 0x02,  // ReduceFirstDim or reduceSecondDim
   kReduceAny = 0x03,        // when reduce_dim.size() > 1
@@ -338,15 +338,11 @@ struct ReduceConfig {
   void SetReduceType() {
     int rank = x_dim.size();
     int reduce_rank = reduce_dim.size();
-    bool is_large_enough = (reduce_num > REDUCE_SPLIT_BOUNDARY / 2) ||
-                           (left_num > REDUCE_SPLIT_BOUNDARY);
-
-    if (rank == reduce_rank) {
-      reduce_type = static_cast<int>(ReduceType::kReduceAll);
-    } else if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
+    bool is_last_dim =
+        (rank == 2) && (reduce_rank == 1) && (reduce_dim[0] == 1);
+    if (rank == reduce_rank || is_last_dim) {
       reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
-    } else if (reduce_rank == 1 &&
-               ((rank == 2 && is_large_enough) || rank != 2)) {
+    } else if (reduce_rank == 1) {
       // ReduceFirstDim and reduceSecondDim
       reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
     } else {
@@ -576,39 +572,40 @@ static __device__ T BlockYReduce(T val, ReduceOp reducer) {
 // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
 // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
 // else grid.z = 1, grid.y = ny / block_size, grid.x = nx / 32
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
-                                TransformOp transformer, Ty init,
+                                TransformOp transformer, MPType init,
                                 int reduce_num, int left_num, int block_size) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   int idy = blockIdx.y * block_size;
 
-  Ty reduce_var = init;
+  MPType reduce_var = init;
 
   if (idx < left_num) {
     int loop = reduce_num - idy;
     loop = loop > block_size ? block_size : loop;
 
     for (int iy = 0; iy < loop; iy++) {
       int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num;
-      reduce_var = reducer(reduce_var, static_cast<Ty>(transformer(x[id])));
+      reduce_var = reducer(reduce_var, static_cast<MPType>(transformer(x[id])));
     }
 
     y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] =
-        reduce_var;
+        static_cast<Ty>(reduce_var);
   }
 }
 
 // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
 // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
 // function will be used
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          typename ReduceIndexCal, typename LeftIndexCal>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
-                          TransformOp transformer, Ty init, int reduce_num,
+                          TransformOp transformer, MPType init, int reduce_num,
                           int left_num, bool reduce_lastdim,
-                          ReduceIndexCal reduce_index_calculator,
-                          LeftIndexCal left_index_calculator) {
+                          const IndexCalculator& reduce_index_calculator,
+                          const IndexCalculator& left_index_calculator) {
   int input_idx, left_idx, stride;
   // the last dim gets involved in reduction
   if (reduce_lastdim) {
@@ -621,9 +618,9 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
     stride = gridDim.y * blockDim.y;
   }
   // calculate the offset, i.e. the address where each thread actually starts
-  int input_offset = left_index_calculator(left_idx);
+  int input_offset = left_index_calculator.Get(left_idx);
   const Tx* input = x + input_offset;
-  Ty reduce_var = init;
+  MPType reduce_var = init;
 
   // 1. reduce for each thread
   if (left_idx < left_num) {
@@ -634,12 +631,13 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
 #pragma unroll
     for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
       int reduce_idx = input_idx + i * stride;
-      int idx_x = reduce_index_calculator(reduce_idx);
+      int idx_x = reduce_index_calculator.Get(reduce_idx);
       input_reg[i] = input[idx_x];
     }
 #pragma unroll
     for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
-      reduce_var = reducer(reduce_var, transformer(input_reg[i]));
+      reduce_var =
+          reducer(reduce_var, static_cast<MPType>(transformer(input_reg[i])));
     }
     input_idx += REDUCE_VEC_SIZE * stride;
   }
@@ -652,7 +650,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
         break;
       }
       int reduce_idx = input_idx;
-      int idx_x = reduce_index_calculator(reduce_idx);
+      int idx_x = reduce_index_calculator.Get(reduce_idx);
       input_reg[i] = input[idx_x];
       input_idx += stride;
     }
@@ -662,7 +660,8 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
       if (input_idx >= reduce_num) {
         break;
       }
-      reduce_var = reducer(reduce_var, transformer(input_reg[i]));
+      reduce_var =
+          reducer(reduce_var, static_cast<MPType>(transformer(input_reg[i])));
       input_idx += stride;
     }
   }
@@ -677,71 +676,64 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
     // 3. reduce in block x
     reduce_var = BlockXReduce(reduce_var, reducer);
     if (left_idx < left_num && threadIdx.x == 0) {
-      y[blockIdx.y * left_num + left_idx] = reduce_var;
+      y[blockIdx.y * left_num + left_idx] = static_cast<Ty>(reduce_var);
     }
   } else {
    if (left_idx < left_num && threadIdx.y == 0) {
-      y[blockIdx.y * left_num + left_idx] = reduce_var;
+      y[blockIdx.y * left_num + left_idx] = static_cast<Ty>(reduce_var);
    }
   }
 }
 
 // module function called from the __global__ kernel function below
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
-                             TransformOp transformer, Ty init, int reduce_num,
-                             int left_num, int blocking_size, int reduce_type,
-                             bool reduce_lastdim,
+                             TransformOp transformer, MPType init,
+                             int reduce_num, int left_num, int blocking_size,
+                             int reduce_type, bool reduce_lastdim,
                              const IndexCalculator& reduce_index_calculator,
                              const IndexCalculator& left_index_calculator) {
-  if (reduce_type == ReduceType::kReduceLastDim) {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
+  if (reduce_type == ReduceType::kReduceLastDim ||
+      reduce_type == ReduceType::kReduceAny) {
+    ReduceAny<Tx, Ty, MPType, ReduceOp, TransformOp>(
         x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
-        [&](int idx) { return idx; },
-        [&](int idx) { return idx * reduce_num; });
-
+        reduce_index_calculator, left_index_calculator);
   // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
   } else if (reduce_type == ReduceType::kReduceHigherDim) {
-    ReduceHigherDim<Tx, Ty, ReduceOp, TransformOp>(
+    ReduceHigherDim<Tx, Ty, MPType, ReduceOp, TransformOp>(
        x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
-
-  // reduce_rank >= 2
-  } else {
-    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
-        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
-        [&](int idx) { return reduce_index_calculator.Get(idx); },
-        [&](int idx) { return left_index_calculator.Get(idx); });
   }
 }
 
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
+          typename TransformOp>
 __global__ void ReduceKernelFunction(const Tx* x, Ty* y, ReduceOp reducer,
-                                     TransformOp transformer, Ty init,
+                                     TransformOp transformer, MPType init,
                                      int reduce_num, int left_num,
                                      int blocking_size, int reduce_type,
                                      bool reduce_lastdim,
                                      IndexCalculator reduce_index_calculator,
                                      IndexCalculator left_index_calculator) {
-  ReduceModule<Tx, Ty, ReduceOp, TransformOp>(
+  ReduceModule<Tx, Ty, MPType, ReduceOp, TransformOp>(
      x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
      reduce_type, reduce_lastdim, reduce_index_calculator,
      left_index_calculator);
 }
 
-template <typename Tx, typename Ty, typename ReduceOp>
+template <typename Tx, typename Ty, typename MPType, typename ReduceOp>
 static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
-                               const ReduceOp& reducer, Ty init,
+                               const ReduceOp& reducer, MPType init,
                                gpuStream_t stream, ReduceConfig<Ty> config) {
   using TransformOp = typename ReduceOp::Transformer;
 
   int reduce_rank = config.reduce_strides.size();
   int left_rank = config.left_strides.size();
   auto reduce_index_calculator = IndexCalculator(
       reduce_rank, config.reduce_dim, config.reduce_strides, config.x_strides);
   auto left_index_calculator = IndexCalculator(
       left_rank, config.left_dim, config.left_strides, config.x_strides);
 
-  ReduceKernelFunction<Tx, Ty, ReduceOp,
+  ReduceKernelFunction<Tx, Ty, MPType, ReduceOp,
                        TransformOp><<<config.grid, config.block, 0, stream>>>(
      x_data, config.output_data, reducer, TransformOp(config.reduce_num), init,
      config.reduce_num, config.left_num, config.blocking_size,
@@ -759,10 +751,11 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
     grid = dim3(config.grid.x, 1, config.grid.z);
   }
 
-  ReduceKernelFunction<Ty, Ty, ReduceOp, detail::IdentityFunctor<
-      Ty>><<<grid, block, 0, stream>>>(
+  ReduceKernelFunction<
+      Ty, Ty, MPType, ReduceOp,
+      detail::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
      config.output_data, y_data, reducer,
-      detail::IdentityFunctor<Ty>(config.grid.y), init, config.grid.y,
+      detail::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y,
      config.left_num, config.grid.y, ReduceType::kReduceHigherDim,
      config.reduce_lastdim, reduce_index_calculator, left_index_calculator);
 }
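
The second kernel launch above folds the config.grid.y partial results left by the first pass. Its transformer is detail::IdentityFunctor, which after this change takes two type parameters so that the cast into the accumulator type happens inside the transform. A sketch of the assumed shape (the actual definition sits elsewhere in this header and is not quoted in this diff):

namespace detail {

template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
  // The int argument only mirrors the common Transformer constructor shape.
  HOSTDEVICE explicit inline IdentityFunctor(int n) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x);
  }
};

}  // namespace detail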
@@ -793,11 +786,12 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
   }
 
   config.SetOutputData(y_data, x.place(), &tmp);
-
-  using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
-  auto reducer = ReduceOp<Tx, Ty>();
-  // launch CUB::Reduce
-  if (config.reduce_type == static_cast<int>(ReduceType::kReduceAll)) {
+  bool use_cub_reduce = (config.left_num == 1) &&
+                        (!std::is_same<Tx, paddle::platform::float16>::value);
+  if (use_cub_reduce) {
+    // launch CUB::Reduce
+    using TransformOp = typename ReduceOp<Tx, Ty>::Transformer;
+    auto reducer = ReduceOp<Tx, Ty>();
     cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
         x_data, TransformOp(config.reduce_num));
     size_t temp_storage_bytes = 0;
@@ -815,7 +809,9 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
     return;
   }
 
-  LaunchReduceKernel<Tx, Ty, ReduceOp<Tx, Ty>>(
+  using MPType = typename details::MPTypeTrait<Ty>::Type;
+  auto reducer = ReduceOp<Tx, MPType>();
+  LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>>(
      x_data, y_data, reducer, reducer.initial(), stream, config);
 
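The MPType plumbing is the point of the whole change: float16 tensors are read as float16 but accumulated in a wider type, which avoids the precision loss of summing many half-precision values. This is also why the CUB fast path above now skips float16 inputs. A sketch of what details::MPTypeTrait from fp16_type_traits.h presumably provides (an assumption; the trait itself is outside this diff):

template <typename T>
struct MPTypeTrait {
  using Type = T;  // by default, accumulate in the input type itself
};

template <>
struct MPTypeTrait<platform::float16> {
  using Type = float;  // half-precision inputs accumulate in float
};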

Expand Down
Loading