Skip to content

Commit cf8bf03

Browse files
authored
add a fusion op: fused_residual_dropout_bias (#34963)
1 parent eb1fbf1 commit cf8bf03

File tree

5 files changed

+871
-0
lines changed

5 files changed

+871
-0
lines changed

paddle/fluid/operators/fused/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,9 @@ if (WITH_GPU OR WITH_ROCM)
7171
op_library(fused_bn_add_activation_op)
7272
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_bn_add_activation);\n")
7373
endif()
74+
# fused_dropout
75+
# only support CUDA
76+
if(NOT WITH_ROCM)
77+
nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory)
78+
endif()
7479
endif()
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <cooperative_groups.h>
18+
#include <cuda.h>
19+
#include <curand_kernel.h>
20+
21+
#include "paddle/fluid/memory/memory.h"
22+
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
23+
#include "paddle/fluid/platform/aligned_vector.h"
24+
#include "paddle/fluid/platform/cuda_device_function.h"
25+
#include "paddle/fluid/platform/device_context.h"
26+
#include "paddle/fluid/platform/float16.h"
27+
#include "paddle/fluid/platform/gpu_launch_config.h"
28+
29+
namespace paddle {
30+
namespace operators {
31+
32+
#define CACHE_LINE 128
33+
#define MAX_CACHE_BYTES (CACHE_LINE / CHAR_BIT)
34+
35+
/**
36+
* get the threads for fused_residual_dropout_bias:
37+
* 1D blocks: blockDim.x = cols
38+
* 2D grids: gridDim.y = rows
39+
*/
40+
inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids(
41+
const platform::CUDADeviceContext &ctx, const uint32_t rows,
42+
const uint32_t cols, const int VecSize) {
43+
const uint32_t tmp_cols = cols / VecSize;
44+
int threads = std::max(
45+
static_cast<uint32_t>(32),
46+
std::min(tmp_cols, static_cast<uint32_t>(ctx.GetMaxThreadsPerBlock())));
47+
const auto blocks_x =
48+
std::max(static_cast<uint32_t>(1), (tmp_cols + threads - 1) / threads);
49+
const auto blocks_y = std::max(static_cast<uint32_t>(1), rows);
50+
platform::GpuLaunchConfig config;
51+
config.block_per_grid.x = blocks_x;
52+
config.block_per_grid.y = blocks_y;
53+
config.thread_per_block.x = threads;
54+
return config;
55+
}
56+
57+
__forceinline__ __device__ void Rand1(curandStatePhilox4_32_10_t *state,
58+
float *data) {
59+
data[0] = curand_uniform(state);
60+
}
61+
62+
__forceinline__ __device__ void Rand2(curandStatePhilox4_32_10_t *state,
63+
float *data) {
64+
data[0] = curand_uniform(state);
65+
data[1] = curand_uniform(state);
66+
}
67+
68+
__forceinline__ __device__ void Rand4(curandStatePhilox4_32_10_t *state,
69+
float *data) {
70+
float4 rand4 = curand_uniform4(state);
71+
data[0] = rand4.x;
72+
data[1] = rand4.y;
73+
data[2] = rand4.w;
74+
data[3] = rand4.z;
75+
}
76+
77+
__forceinline__ __device__ void Rand8(curandStatePhilox4_32_10_t *state,
78+
float *data) {
79+
Rand4(state, data);
80+
Rand4(state, data + 4);
81+
}
82+
83+
__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state,
84+
float *data, const int VecSize) {
85+
if (VecSize == 1) {
86+
Rand1(state, data);
87+
} else if (VecSize == 2) {
88+
Rand2(state, data);
89+
} else if (VecSize == 4) {
90+
Rand4(state, data);
91+
} else if (VecSize == 8) {
92+
Rand8(state, data);
93+
} else {
94+
return;
95+
}
96+
}
97+
98+
} // namespace operators
99+
} // namespace paddle
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <random>
18+
#include <vector>
19+
20+
#include "gtest/gtest.h"
21+
#include "paddle/fluid/framework/op_registry.h"
22+
#include "paddle/fluid/framework/operator.h"
23+
#include "paddle/fluid/framework/program_desc.h"
24+
#include "paddle/fluid/framework/tensor_util.h"
25+
#include "paddle/fluid/memory/memory.h"
26+
#include "paddle/fluid/operators/math/math_function.h"
27+
#include "paddle/fluid/string/printf.h"
28+
29+
namespace framework = paddle::framework;
30+
namespace platform = paddle::platform;
31+
namespace memory = paddle::memory;
32+
33+
USE_OP(dropout);
34+
35+
/**
36+
* @brief call paddle dropout op
37+
*/
38+
template <typename T>
39+
void Dropout(const std::vector<T> &x, const framework::DDim &x_dim,
40+
std::vector<T> *out, std::vector<uint8_t> *mask,
41+
const platform::CUDADeviceContext &ctx, uint64_t seed,
42+
float dropout_prob, bool is_upscale_in_train, bool is_test) {
43+
framework::Scope scope;
44+
auto var_x = scope.Var("X");
45+
auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
46+
framework::TensorFromVector(x, ctx, tensor_x);
47+
tensor_x->Resize(x_dim);
48+
49+
auto var_out = scope.Var("Out");
50+
auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
51+
52+
auto var_mask = scope.Var("Mask");
53+
auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();
54+
55+
framework::AttributeMap attrs;
56+
attrs.insert({"fix_seed", 1});
57+
attrs.insert({"seed", static_cast<int>(seed)});
58+
attrs.insert({"dropout_prob", dropout_prob});
59+
if (is_upscale_in_train) {
60+
attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
61+
}
62+
63+
if (is_test) {
64+
attrs.insert({"is_test", true});
65+
}
66+
67+
auto op = framework::OpRegistry::CreateOp(
68+
"dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs);
69+
op->Run(scope, ctx.GetPlace());
70+
71+
framework::TensorToVector<T>(*tensor_out, ctx, out);
72+
if (!is_test) {
73+
framework::TensorToVector<uint8_t>(*tensor_mask, ctx, mask);
74+
}
75+
ctx.Wait();
76+
}
77+
78+
/**
79+
* @brief call paddle dropout_grad op
80+
*/
81+
template <typename T>
82+
void DropoutGrad(std::vector<T> *dx, const framework::DDim &x_dim,
83+
const std::vector<T> &dout, const std::vector<uint8_t> &mask,
84+
const platform::CUDADeviceContext &ctx, float dropout_prob,
85+
bool is_upscale_in_train) {
86+
framework::Scope scope;
87+
const size_t n = x_dim[0] * x_dim[1];
88+
auto var_out = scope.Var("DOut");
89+
auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
90+
framework::TensorFromVector(dout, ctx, tensor_out);
91+
tensor_out->Resize(x_dim);
92+
93+
auto var_mask = scope.Var("Mask");
94+
auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();
95+
framework::TensorFromVector(mask, ctx, tensor_mask);
96+
tensor_mask->Resize(x_dim);
97+
98+
auto var_dx = scope.Var("DX");
99+
auto tensor_dx = var_dx->GetMutable<framework::LoDTensor>();
100+
101+
framework::AttributeMap attrs;
102+
attrs.insert({"dropout_prob", dropout_prob});
103+
attrs.insert({"is_test", false});
104+
if (is_upscale_in_train) {
105+
attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
106+
} else {
107+
attrs.insert({"dropout_implementation", std::string("downgrade_in_infer")});
108+
}
109+
110+
auto op = framework::OpRegistry::CreateOp(
111+
"dropout_grad", {{"Out@GRAD", {"DOut"}}, {"Mask", {"Mask"}}},
112+
{{"X@GRAD", {"DX"}}}, attrs);
113+
op->Run(scope, ctx.GetPlace());
114+
115+
framework::TensorToVector(*tensor_dx, ctx, dx);
116+
ctx.Wait();
117+
}

0 commit comments

Comments
 (0)