Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/fluid/operators/fused/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,9 @@ if (WITH_GPU OR WITH_ROCM)
op_library(fused_bn_add_activation_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_bn_add_activation);\n")
endif()
# fused_dropout
# only support CUDA
if(NOT WITH_ROCM)
nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory)
endif()
endif()
99 changes: 99 additions & 0 deletions paddle/fluid/operators/fused/fused_dropout_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cooperative_groups.h>
#include <cuda.h>
#include <curand_kernel.h>

#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_launch_config.h"

namespace paddle {
namespace operators {

#define CACHE_LINE 128
#define MAX_CACHE_BYTES (CACHE_LINE / CHAR_BIT)

/**
* get the threads for fused_residual_dropout_bias:
* 1D blocks: blockDim.x = cols
* 2D grids: gridDim.y = rows
*/
inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids(
const platform::CUDADeviceContext &ctx, const uint32_t rows,
const uint32_t cols, const int VecSize) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

变量名命名用xxx_xxx方式:VecSize -> vec_size

const uint32_t tmp_cols = cols / VecSize;
int threads = std::max(
static_cast<uint32_t>(32),
std::min(tmp_cols, static_cast<uint32_t>(ctx.GetMaxThreadsPerBlock())));
const auto blocks_x =
std::max(static_cast<uint32_t>(1), (tmp_cols + threads - 1) / threads);
const auto blocks_y = std::max(static_cast<uint32_t>(1), rows);
platform::GpuLaunchConfig config;
config.block_per_grid.x = blocks_x;
config.block_per_grid.y = blocks_y;
config.thread_per_block.x = threads;
return config;
}

__forceinline__ __device__ void Rand1(curandStatePhilox4_32_10_t *state,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我觉得写成模板函数、再特化的方式会好一些。

float *data) {
data[0] = curand_uniform(state);
}

__forceinline__ __device__ void Rand2(curandStatePhilox4_32_10_t *state,
float *data) {
data[0] = curand_uniform(state);
data[1] = curand_uniform(state);
}

__forceinline__ __device__ void Rand4(curandStatePhilox4_32_10_t *state,
float *data) {
float4 rand4 = curand_uniform4(state);
data[0] = rand4.x;
data[1] = rand4.y;
data[2] = rand4.w;
data[3] = rand4.z;
}

__forceinline__ __device__ void Rand8(curandStatePhilox4_32_10_t *state,
float *data) {
Rand4(state, data);
Rand4(state, data + 4);
}

__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state,
float *data, const int VecSize) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上。

if (VecSize == 1) {
Rand1(state, data);
} else if (VecSize == 2) {
Rand2(state, data);
} else if (VecSize == 4) {
Rand4(state, data);
} else if (VecSize == 8) {
Rand8(state, data);
} else {
return;
}
}

} // namespace operators
} // namespace paddle
117 changes: 117 additions & 0 deletions paddle/fluid/operators/fused/fused_dropout_test.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <random>
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace memory = paddle::memory;

USE_OP(dropout);

/**
* @brief call paddle dropout op
*/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个test是测试原来的dropout op吗?我看到还有一个fuse_dropout_op的test。所以没太明白这个单测的用意?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个头文件是给几个dropout相关的单测共用的,里面call了下dropout_op,作为对比的base版本。

template <typename T>
void Dropout(const std::vector<T> &x, const framework::DDim &x_dim,
std::vector<T> *out, std::vector<uint8_t> *mask,
const platform::CUDADeviceContext &ctx, uint64_t seed,
float dropout_prob, bool is_upscale_in_train, bool is_test) {
framework::Scope scope;
auto var_x = scope.Var("X");
auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(x, ctx, tensor_x);
tensor_x->Resize(x_dim);

auto var_out = scope.Var("Out");
auto tensor_out = var_out->GetMutable<framework::LoDTensor>();

auto var_mask = scope.Var("Mask");
auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();

framework::AttributeMap attrs;
attrs.insert({"fix_seed", 1});
attrs.insert({"seed", static_cast<int>(seed)});
attrs.insert({"dropout_prob", dropout_prob});
if (is_upscale_in_train) {
attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
}

if (is_test) {
attrs.insert({"is_test", true});
}

auto op = framework::OpRegistry::CreateOp(
"dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs);
op->Run(scope, ctx.GetPlace());

framework::TensorToVector<T>(*tensor_out, ctx, out);
if (!is_test) {
framework::TensorToVector<uint8_t>(*tensor_mask, ctx, mask);
}
ctx.Wait();
}

/**
* @brief call paddle dropout_grad op
*/
template <typename T>
void DropoutGrad(std::vector<T> *dx, const framework::DDim &x_dim,
const std::vector<T> &dout, const std::vector<uint8_t> &mask,
const platform::CUDADeviceContext &ctx, float dropout_prob,
bool is_upscale_in_train) {
framework::Scope scope;
const size_t n = x_dim[0] * x_dim[1];
auto var_out = scope.Var("DOut");
auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(dout, ctx, tensor_out);
tensor_out->Resize(x_dim);

auto var_mask = scope.Var("Mask");
auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(mask, ctx, tensor_mask);
tensor_mask->Resize(x_dim);

auto var_dx = scope.Var("DX");
auto tensor_dx = var_dx->GetMutable<framework::LoDTensor>();

framework::AttributeMap attrs;
attrs.insert({"dropout_prob", dropout_prob});
attrs.insert({"is_test", false});
if (is_upscale_in_train) {
attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
} else {
attrs.insert({"dropout_implementation", std::string("downgrade_in_infer")});
}

auto op = framework::OpRegistry::CreateOp(
"dropout_grad", {{"Out@GRAD", {"DOut"}}, {"Mask", {"Mask"}}},
{{"X@GRAD", {"DX"}}}, attrs);
op->Run(scope, ctx.GetPlace());

framework::TensorToVector(*tensor_dx, ctx, dx);
ctx.Wait();
}
Loading