Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
bf318b8
add a fusion op: fused_residual_dropout_bias
zkh2016 Aug 23, 2021
507117a
simplify the code, andd opt reduce sum
zkh2016 Aug 23, 2021
462caa1
resolve review comments and add comments to the code
zkh2016 Aug 24, 2021
93e0638
fused_dropout: optimize code structure to facilitate reuse
zkh2016 Aug 24, 2021
e2808ff
Merge branch 'PaddlePaddle:develop' into develop
zkh2016 Aug 25, 2021
036b430
optimize code structure to facilitate reuse
zkh2016 Aug 25, 2021
4d33b98
modify the code according to the review comments
zkh2016 Aug 30, 2021
bd44d04
replace cudaMemcpy with TensorFromVector and TensorToVector in Dropou…
zkh2016 Aug 30, 2021
d2beab7
set dropout attr 'is_test':false
zkh2016 Aug 31, 2021
5d2bbc8
optimize the code according to the review comments
zkh2016 Sep 2, 2021
934fcac
use static_cast
zkh2016 Sep 2, 2021
44610ea
fix the blocks for large shape
zkh2016 Sep 8, 2021
3133d33
Merge remote-tracking branch 'upstream/develop' into develop
zkh2016 Sep 8, 2021
1a83adb
merge upstream, and used new AlignedVector
zkh2016 Sep 8, 2021
4dba815
add a fusion op: fused_dropout_act_bias
zkh2016 Sep 8, 2021
f848739
remove unused code
zkh2016 Sep 9, 2021
6d30340
Merge branch 'develop' into fused_dropout_act_bias
zkh2016 Sep 9, 2021
b8a9861
redefine activation functor
zkh2016 Sep 9, 2021
fd01daa
implement the same gelu as the baseline for FFN
zkh2016 Sep 9, 2021
cabb9d2
add #define _USE_MATH_DEFINES for windows
zkh2016 Sep 10, 2021
3cfdff8
modify the code according to the review comment
zkh2016 Sep 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/operators/fused/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,6 @@ if (WITH_GPU OR WITH_ROCM)
# only support CUDA
if(NOT WITH_ROCM)
nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory)
nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory)
endif()
endif()
317 changes: 317 additions & 0 deletions paddle/fluid/operators/fused/fused_dropout_act_bias.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,317 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加些注释说明函数的功能吧

#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include "paddle/fluid/operators/fused/fused_dropout_common.h"
#include "paddle/fluid/operators/math/functors.h"

namespace paddle {
namespace operators {

/**
*@brief the gelu functor
*/
template <typename T>
struct GeluFunctor {
inline __host__ __device__ T operator()(const T x) const {
using U = LayerNormParamType<T>;
const U casted_x = static_cast<U>(x);
const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
return static_cast<T>(out);
}
};

/**
*@brief the gelu grad functor
*/
template <typename T>
struct GeluGradFunctor {
inline __host__ __device__ T UseOut(const T x) const {
using U = LayerNormParamType<T>;
auto casted_x = static_cast<U>(x);

auto first =
static_cast<U>(0.5) *
(static_cast<U>(1) + erf(casted_x * static_cast<U>(M_SQRT1_2)));

auto second = static_cast<U>(0.5 * M_2_SQRTPI * M_SQRT1_2) * casted_x *
exp(-static_cast<U>(0.5) * casted_x * casted_x);
return static_cast<T>((first + second));
}
};

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上面的激活函数,Relu和Gelu在math下面都有,可以直接复用吗,因为math下面实现的接口已经很统一了,复用的话这里应该就不需要再封装一遍?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, gelu的实现参考gelu_op的,和math下的稍有不同。可以直接传math下的functor。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不同点在哪?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ref: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/gelu_op.h#L96
这个主要是参考gelu_op的实现,采用了两种计算方式,
一种近似计算和math的方式一样: gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
另一种:gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))

/**
* @brief dst = dropout(activation(src + bias));
* the src, mask and dst shape is (rows, cols)
* the bias shape is (1, cols)
*/
template <typename T, typename MaskType, int VecSize, typename Functor>
__global__ void FusedDropoutActBias(
Functor act, const uint64_t seed, const uint64_t rows, const uint64_t cols,
const int increment, const float dropout_prob,
const bool is_upscale_in_train, const bool is_test,
const T *__restrict__ src, const T *__restrict__ bias, T *dst,
MaskType *mask) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;

curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);

T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
if (!is_upscale_in_train) {
factor = static_cast<T>(1.0);
}
if (is_test) {
factor = static_cast<T>(1.0f - dropout_prob);
if (is_upscale_in_train) {
factor = static_cast<T>(1.0f);
}
}

using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
using MaskStoreT = platform::AlignedVector<MaskType, VecSize>;

for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
for (int i = col_id * VecSize; i < cols;
i += blockDim.x * gridDim.x * VecSize) {
LoadT src_vec;
LoadT bias_vec;
// vectorize load data from global
platform::Load<T, VecSize>(&src[r * cols + i], &src_vec);

if (bias) {
platform::Load<T, VecSize>(&bias[i], &bias_vec);
} else {
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
bias_vec[ii] = static_cast<T>(0);
}
}

MaskStoreT mask_vec;
if (!is_test) {
float rand[VecSize];
RandVec<VecSize>(&state, rand);
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
mask_vec[ii] = static_cast<MaskType>(rand[ii] >= dropout_prob);
}
} else {
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
mask_vec[ii] = static_cast<MaskType>(1);
}
}

StoreT dest_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
const T tmp = src_vec[ii] + bias_vec[ii];
const T act_out = act(tmp);
dest_vec[ii] = act_out * static_cast<T>(mask_vec[ii]) * factor;
}
// store result to global
platform::Store<T, VecSize>(dest_vec, &dst[r * cols + i]);
if (!is_test) {
platform::Store<MaskType, VecSize>(mask_vec, &mask[r * cols + i]);
}
}
}
}

/**
* @brief dst = dropout(activation(src + bias));
*/
template <typename T, typename MaskType, typename Functor>
void LaunchDropoutActBias(Functor act_functor, const uint64_t seed,
const uint32_t rows, const uint32_t cols,
const int increment, const float dropout_prob,
const bool is_upscale_in_train, const bool is_test,
const T *src, const T *bias, T *dst,
MaskType *mask_data,
const platform::CUDADeviceContext &ctx) {
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
SetZero<T>(ctx, dst, rows * cols);
SetZero<MaskType>(ctx, mask_data, rows * cols);
return;
}

const int VecSize = MAX_CACHE_BYTES / sizeof(T);
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
if (cols % VecSize == 0) {
FusedDropoutActBias<T, MaskType, VecSize, Functor><<<
config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor, seed, rows, cols, increment, dropout_prob,
is_upscale_in_train, is_test, src, bias, dst, mask_data);
} else {
FusedDropoutActBias<T, MaskType, 1, Functor><<<
config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor, seed, rows, cols, increment, dropout_prob,
is_upscale_in_train, is_test, src, bias, dst, mask_data);
}
}

/*
* @brief calculate the grad of no bias
*/
template <typename T, typename MaskType, int VecSize, typename Functor>
__global__ void FusedDropoutActGrad(Functor act_grad, const T *dout,
const MaskType *mask, const T *src,
const T factor, const int64_t size, T *dx) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;

using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
LoadT dout_vec;
LoadT src_vec;
MaskLoadT mask_vec;

platform::Load<T, VecSize>(&dout[i], &dout_vec);
platform::Load<MaskType, VecSize>(&mask[i], &mask_vec);
platform::Load<T, VecSize>(&src[i], &src_vec);

StoreT dx_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是不是可以调用Kernel Primitives API函数?

T args[2];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当前这种写法没有必要定义T args[2];

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已在下一个PR中修改

args[0] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
args[1] = src_vec[ii];
dx_vec[ii] = args[0] * act_grad.UseOut(args[1]);
}
platform::Store<T, VecSize>(dx_vec, &dx[i]);
}
}

/**
* blocks(128 * 8)
* 1. calculate the dx and reduce total rows to 128 rows
* 2. save 128*8 temporary sum in 8*128 shared memory
* 3. reduce the sum of 128 cols data by 8*VecSize warps
*/
template <typename T, typename MaskType, int BlockSizeX, int BlockSizeY,
int VecSize, typename Functor>
__global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout,
const MaskType *mask, const T *src,
const T *bias, const T factor,
const int64_t rows, const int64_t cols,
T *dx, T *dbias) {
int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x;

using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
T tmp_sum[VecSize] = {static_cast<T>(0)};
// calculate the dx and temporary sum
if (col_id * VecSize < cols) {
for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
int index = row_id * cols + col_id * VecSize;
LoadT dout_vec;
LoadT src_vec;
LoadT bias_vec;
MaskLoadT mask_vec;

platform::Load<T, VecSize>(&dout[index], &dout_vec);
platform::Load<T, VecSize>(&src[index], &src_vec);
platform::Load<MaskType, VecSize>(&mask[index], &mask_vec);
platform::Load<T, VecSize>(&bias[col_id * VecSize], &bias_vec);

StoreT dx_vec;
#pragma unroll
for (int i = 0; i < VecSize; i++) {
T val;
T args[2];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当前这种写法没有必要定义T args[2];

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已在下一个PR中修改

args[0] = dout_vec[i] * static_cast<T>(mask_vec[i]) * factor;
args[1] = src_vec[i] + bias_vec[i];
val = args[0] * act_grad.UseOut(args[1]);
dx_vec[i] = val;
tmp_sum[i] += val;
}
platform::Store<T, VecSize>(dx_vec, &dx[index]);
}
}

CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(tmp_sum, dbias, cols);
}

/**
* @brief to launch kernel FusedResidualDropoutBiasGradVec
*/
template <typename T, typename MaskType, typename Functor>
void LaunchDropoutActBiasGrad(Functor act_functor, const T *dout,
const MaskType *mask, const T *src, const T *bias,
const float dropout_prob,
const bool is_upscale_in_train,
const uint32_t rows, const uint32_t cols, T *dx,
T *dbias,
const platform::CUDADeviceContext &ctx) {
const T zero = static_cast<T>(0.0);
auto factor = dropout_prob == static_cast<float>(1.0f)
? zero
: static_cast<T>(1.0 / (1.0 - dropout_prob));
if (!is_upscale_in_train) {
factor = static_cast<T>(1.0f);
}

const int VecSize = MAX_CACHE_BYTES / sizeof(T);
int real_vec_size = cols % VecSize == 0 ? VecSize : 1;

if (dbias != nullptr) {
const auto threads = 8;
const auto blocks =
std::max(static_cast<uint32_t>(1),
(cols / real_vec_size + threads - 1) / threads);
dim3 block_dim(threads, 128, 1);
dim3 grid_dim(blocks, 1, 1);
if (cols % VecSize == 0) {
FusedDropoutActBiasGrad<
T, MaskType, 8, 128, VecSize,
Functor><<<grid_dim, block_dim, 0, ctx.stream()>>>(
act_functor, dout, mask, src, bias, factor, rows, cols, dx, dbias);
} else {
FusedDropoutActBiasGrad<
T, MaskType, 8, 128, 1,
Functor><<<grid_dim, block_dim, 0, ctx.stream()>>>(
act_functor, dout, mask, src, bias, factor, rows, cols, dx, dbias);
}
} else {
const uint64_t n = rows * cols;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size);
if (n % VecSize == 0) {
FusedDropoutActGrad<T, MaskType, VecSize, Functor><<<
config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor, dout, mask, src, factor, n, dx);
} else {
FusedDropoutActGrad<T, MaskType, 1, Functor><<<
config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor, dout, mask, src, factor, n, dx);
}
}
}

} // namespace operators
} // namespace paddle
Loading