Commit 93e0638

fused_dropout: optimize code structure to facilitate reuse

1 parent 462caa1 commit 93e0638

5 files changed: +136 additions, -114 deletions

paddle/fluid/operators/fused/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -74,6 +74,6 @@ if (WITH_GPU OR WITH_ROCM)
   # fused_dropout
   # only support CUDA
   if(NOT WITH_ROCM)
-    nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry elementwise_add_op dropout_op device_context generator)
+    nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator)
   endif()
 endif()

paddle/fluid/operators/fused/fused_dropout.h

Lines changed: 12 additions & 0 deletions

@@ -66,5 +66,17 @@ struct alignas(sizeof(T) * VecSize) AlignedVector {
   T val[VecSize];
 };
 
+// reduce sum by a warp
+template <typename U>
+static __forceinline__ __device__ U WarpReduceSum(U val) {
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+  const int warpSize = 32;
+  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
+    val += paddle::platform::CudaShuffleDownSync(mask, val, offset);
+  }
+  return val;
+}
+
 }  // namespace operators
 }  // namespace paddle
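
Note: WarpReduceSum moves into this shared header so other fused kernels can reuse it. For readers outside the Paddle tree, here is a minimal standalone sketch of the same shuffle-based reduction, written with the raw CUDA intrinsic __shfl_down_sync in place of Paddle's CREATE_SHFL_MASK / CudaShuffleDownSync wrappers; the kernel and the atomicAdd-per-warp finish are illustrative, not part of this commit:

// Standalone warp-sum sketch; assumes all 32 lanes of the warp are active.
template <typename U>
__forceinline__ __device__ U WarpReduceSumSketch(U val) {
  for (int offset = 16; offset > 0; offset /= 2) {
    // add the value held by the lane `offset` positions above this one
    val += __shfl_down_sync(0xffffffffu, val, offset);
  }
  return val;  // lane 0 ends up holding the warp-wide sum
}

// Illustrative use: grid-stride partial sums, then one atomicAdd per warp.
__global__ void SumKernelSketch(const float *in, float *out, int n) {
  float v = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    v += in[i];
  }
  v = WarpReduceSumSketch(v);
  if ((threadIdx.x & 31) == 0) atomicAdd(out, v);
}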
paddle/fluid/operators/fused/fused_dropout_test.h

Lines changed: 121 additions & 0 deletions

@@ -0,0 +1,121 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <random>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/string/printf.h"
+
+namespace framework = paddle::framework;
+namespace platform = paddle::platform;
+
+USE_OP(dropout);
+
+/**
+ * @brief call paddle dropout op
+ */
+template <typename T>
+void Dropout(const T *x, const framework::DDim &x_dim, T *out,
+             std::vector<uint8_t> *mask, const platform::CUDADeviceContext &ctx,
+             uint64_t seed, float dropout_prob, bool is_upscale_in_train,
+             bool is_test) {
+  framework::Scope scope;
+  auto var_x = scope.Var("X");
+  auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
+  tensor_x->Resize(x_dim);
+  tensor_x->mutable_data<T>(ctx.GetPlace());
+  cudaMemcpy(tensor_x->data<T>(), x, x_dim[0] * x_dim[1] * sizeof(T),
+             cudaMemcpyHostToDevice);
+
+  auto var_out = scope.Var("Out");
+  auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
+
+  auto var_mask = scope.Var("Mask");
+  auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();
+
+  framework::AttributeMap attrs;
+  attrs.insert({"fix_seed", 1});
+  attrs.insert({"seed", static_cast<int>(seed)});
+  attrs.insert({"dropout_prob", dropout_prob});
+  if (is_upscale_in_train) {
+    attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
+  }
+  if (is_test) {
+    attrs.insert({"is_test", 1});
+  }
+
+  auto op = framework::OpRegistry::CreateOp(
+      "dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs);
+  op->Run(scope, ctx.GetPlace());
+  cudaMemcpy(out, tensor_out->data<T>(), x_dim[0] * x_dim[1] * sizeof(T),
+             cudaMemcpyDeviceToHost);
+  if (!is_test) {
+    cudaMemcpy((*mask).data(), tensor_mask->data<uint8_t>(),
+               x_dim[0] * x_dim[1] * sizeof(uint8_t), cudaMemcpyDeviceToHost);
+  }
+  ctx.Wait();
+}
+
+/**
+ * @brief call paddle dropout_grad op
+ */
+template <typename T>
+void DropoutGrad(T *dx, const framework::DDim &x_dim, const T *dout,
+                 const uint8_t *mask, const platform::CUDADeviceContext &ctx,
+                 float dropout_prob, bool is_upscale_in_train) {
+  framework::Scope scope;
+  const size_t n = x_dim[0] * x_dim[1];
+  auto var_out = scope.Var("DOut");
+  auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
+  tensor_out->Resize(x_dim);
+  tensor_out->mutable_data<T>(ctx.GetPlace());
+  cudaMemcpy(tensor_out->data<T>(), dout, n * sizeof(T),
+             cudaMemcpyHostToDevice);
+
+  auto var_mask = scope.Var("Mask");
+  auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();
+  tensor_mask->Resize(x_dim);
+  tensor_mask->mutable_data<uint8_t>(ctx.GetPlace());
+  cudaMemcpy(tensor_mask->data<uint8_t>(), mask, n * sizeof(uint8_t),
+             cudaMemcpyHostToDevice);
+
+  auto var_dx = scope.Var("DX");
+  auto tensor_dx = var_dx->GetMutable<framework::LoDTensor>();
+
+  framework::AttributeMap attrs;
+  attrs.insert({"dropout_prob", dropout_prob});
+  attrs.insert({"is_test", 0});
+  if (is_upscale_in_train) {
+    attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
+  } else {
+    attrs.insert({"dropout_implementation", std::string("downgrade_in_infer")});
+  }
+
+  auto op = framework::OpRegistry::CreateOp(
+      "dropout_grad", {{"Out@GRAD", {"DOut"}}, {"Mask", {"Mask"}}},
+      {{"X@GRAD", {"DX"}}}, attrs);
+  op->Run(scope, ctx.GetPlace());
+
+  cudaMemcpy(dx, tensor_dx->data<T>(), x_dim[0] * x_dim[1] * sizeof(T),
+             cudaMemcpyDeviceToHost);
+  ctx.Wait();
+}
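
Note: these helpers wrap the reference dropout / dropout_grad ops so every fused test can compare against the same baseline. A minimal sketch of how a test might drive the shared Dropout helper; the wrapper function and the sizes/seed below are illustrative, not part of this commit (the real fixture, TestFusedResidualDropoutBias, lives in fused_residual_dropout_bias_test.cu):

// Hypothetical caller of the shared Dropout(...) helper declared above.
template <typename T>
void RunBaselineDropout(const platform::CUDADeviceContext &ctx) {
  const int rows = 16, cols = 16;
  auto dim = framework::make_ddim({rows, cols});
  std::vector<T> x(rows * cols, static_cast<T>(1));  // host input, all ones
  std::vector<T> out(rows * cols);                   // host output
  std::vector<uint8_t> mask(rows * cols);            // kept/dropped flags
  Dropout<T>(x.data(), dim, out.data(), &mask, ctx, /*seed=*/0,
             /*dropout_prob=*/0.5f, /*is_upscale_in_train=*/true,
             /*is_test=*/false);
  // With upscale_in_train, kept elements become 1/(1-p); dropped ones are 0.
}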

paddle/fluid/operators/fused/fused_residual_dropout_bias.h

Lines changed: 1 addition & 14 deletions

@@ -118,9 +118,8 @@ __global__ void FusedResidualDropoutBiasIsTest(const size_t rows,
 
   using LoadT = AlignedVector<T, VecSize>;
 
-  const int tmp_cols = cols / VecSize * VecSize;
   for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
-    for (int i = col_id * VecSize; i < tmp_cols;
+    for (int i = col_id * VecSize; i < cols;
          i += blockDim.x * gridDim.x * VecSize) {
       T src_vec[VecSize];
       T residual_vec[VecSize];
@@ -249,17 +248,6 @@ __global__ void FusedResidualDropoutGradVec(const T *dout, const MaskType *mask,
   }
 }
 
-template <typename U>
-static __forceinline__ __device__ U WarpReduceSum(U val) {
-  unsigned mask = 0u;
-  CREATE_SHFL_MASK(mask, true);
-  const int warpSize = 32;
-  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
-    val += paddle::platform::CudaShuffleDownSync(mask, val, offset);
-  }
-  return val;
-}
-
 /**
  * blocks(128 * 8)
  * 1. calculate the dx and reduce total rows to 128 rows
@@ -285,7 +273,6 @@ __global__ void FusedResidualDropoutBiasGradVec(
   T dx_vec[VecSize];
   LoadT *out_value = reinterpret_cast<LoadT *>(&out_vec);
   MaskLoadT *mask_value = reinterpret_cast<MaskLoadT *>(&mask_vec);
-  LoadT *dx_value = reinterpret_cast<LoadT *>(&dx_vec);
   *out_value = *reinterpret_cast<const LoadT *>(&dout[index]);
   *mask_value = *reinterpret_cast<const MaskLoadT *>(&mask[index]);
 
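Note: dropping tmp_cols means the vectorized loop now runs straight to cols, so callers must guarantee cols is a multiple of VecSize. A minimal sketch of the AlignedVector load/modify/store idiom these kernels rely on; the struct and kernel names here are illustrative, not part of this commit:

// Vectorized access pattern; assumes cols % VecSize == 0, so no scalar
// tail loop is required.
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) AlignedVecSketch {
  T val[VecSize];
};

template <typename T, int VecSize>
__global__ void ScaleSketch(const T *src, T *dst, int cols, T factor) {
  using LoadT = AlignedVecSketch<T, VecSize>;
  for (int i = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; i < cols;
       i += blockDim.x * gridDim.x * VecSize) {
    LoadT v = *reinterpret_cast<const LoadT *>(&src[i]);  // one wide load
#pragma unroll
    for (int j = 0; j < VecSize; j++) v.val[j] *= factor;
    *reinterpret_cast<LoadT *>(&dst[i]) = v;  // one wide store
  }
}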

paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu

Lines changed: 1 addition & 99 deletions

@@ -17,20 +17,12 @@ limitations under the License. */
 #include <random>
 #include <vector>
 
-#include "gtest/gtest.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/fused/fused_dropout_test.h"
 #include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
-#include "paddle/fluid/operators/math/math_function.h"
-#include "paddle/fluid/string/printf.h"
 
 namespace framework = paddle::framework;
 namespace platform = paddle::platform;
 
-USE_OP(dropout);
-
 /**
  * @brief the unittest of fused_residual_dropout_bias
  * 1. random input data
@@ -39,96 +31,6 @@ USE_OP(dropout);
  * 4. compare the base result and fused result
  */
 
-/**
- * @brief call paddle dropout op
- */
-template <typename T>
-void Dropout(const T *x, const framework::DDim &x_dim, T *out,
-             std::vector<uint8_t> *mask, const platform::CUDADeviceContext &ctx,
-             uint64_t seed, float dropout_prob, bool is_upscale_in_train,
-             bool is_test) {
-  framework::Scope scope;
-  auto var_x = scope.Var("X");
-  auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
-  tensor_x->Resize(x_dim);
-  tensor_x->mutable_data<T>(ctx.GetPlace());
-  cudaMemcpy(tensor_x->data<T>(), x, x_dim[0] * x_dim[1] * sizeof(T),
-             cudaMemcpyHostToDevice);
-
-  auto var_out = scope.Var("Out");
-  auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
-
-  auto var_mask = scope.Var("Mask");
-  auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();
-
-  framework::AttributeMap attrs;
-  attrs.insert({"fix_seed", 1});
-  attrs.insert({"seed", static_cast<int>(seed)});
-  attrs.insert({"dropout_prob", dropout_prob});
-  if (is_upscale_in_train) {
-    attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
-  }
-  if (is_test) {
-    attrs.insert({"is_test", 1});
-  }
-
-  auto op = framework::OpRegistry::CreateOp(
-      "dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs);
-  op->Run(scope, ctx.GetPlace());
-  cudaMemcpy(out, tensor_out->data<T>(), x_dim[0] * x_dim[1] * sizeof(T),
-             cudaMemcpyDeviceToHost);
-  if (!is_test) {
-    cudaMemcpy((*mask).data(), tensor_mask->data<uint8_t>(),
-               x_dim[0] * x_dim[1] * sizeof(uint8_t), cudaMemcpyDeviceToHost);
-  }
-  ctx.Wait();
-}
-
-/**
- * @brief call paddle dropout_grad op
- */
-template <typename T>
-void DropoutGrad(T *dx, const framework::DDim &x_dim, const T *dout,
-                 const uint8_t *mask, const platform::CUDADeviceContext &ctx,
-                 float dropout_prob, bool is_upscale_in_train) {
-  framework::Scope scope;
-  const size_t n = x_dim[0] * x_dim[1];
-  auto var_out = scope.Var("DOut");
-  auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
-  tensor_out->Resize(x_dim);
-  tensor_out->mutable_data<T>(ctx.GetPlace());
-  cudaMemcpy(tensor_out->data<T>(), dout, n * sizeof(T),
-             cudaMemcpyHostToDevice);
-
-  auto var_mask = scope.Var("Mask");
-  auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();
-  tensor_mask->Resize(x_dim);
-  tensor_mask->mutable_data<uint8_t>(ctx.GetPlace());
-  cudaMemcpy(tensor_mask->data<uint8_t>(), mask, n * sizeof(uint8_t),
-             cudaMemcpyHostToDevice);
-
-  auto var_dx = scope.Var("DX");
-  auto tensor_dx = var_dx->GetMutable<framework::LoDTensor>();
-
-  framework::AttributeMap attrs;
-  attrs.insert({"dropout_prob", dropout_prob});
-  attrs.insert({"is_test", 0});
-  if (is_upscale_in_train) {
-    attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
-  } else {
-    attrs.insert({"dropout_implementation", std::string("downgrade_in_infer")});
-  }
-
-  auto op = framework::OpRegistry::CreateOp(
-      "dropout_grad", {{"Out@GRAD", {"DOut"}}, {"Mask", {"Mask"}}},
-      {{"X@GRAD", {"DX"}}}, attrs);
-  op->Run(scope, ctx.GetPlace());
-
-  cudaMemcpy(dx, tensor_dx->data<T>(), x_dim[0] * x_dim[1] * sizeof(T),
-             cudaMemcpyDeviceToHost);
-  ctx.Wait();
-}
-
 template <typename T>
 struct TestFusedResidualDropoutBias {
   uint32_t _rows;
