Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/fluid/operators/fused/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,9 @@ if (WITH_GPU OR WITH_ROCM)
op_library(fused_bn_add_activation_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_bn_add_activation);\n")
endif()
# fused_dropout
# only support CUDA
if(NOT WITH_ROCM)
nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator)
endif()
endif()
70 changes: 70 additions & 0 deletions paddle/fluid/operators/fused/fused_dropout.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fused_dropout.h文件名不合适。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改成fused_dropout_common.h了

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cooperative_groups.h>
#include <cuda.h>
#include <curand_kernel.h>

#include <iostream>
#include <memory>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这2个头文件没有用到吧?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

/**
* get 1D threads and blocks
*/
template <int VecSize = 4>
inline std::pair<uint32_t, uint32_t> Get1DThreadsAndBlocks(
const platform::CUDADeviceContext &ctx, const uint64_t n) {
const uint64_t tmp_n = n / VecSize;
int threads = std::max(
(uint64_t)32, std::min(tmp_n, (uint64_t)ctx.GetMaxThreadsPerBlock()));
int blocks = std::max((uint64_t)1, (tmp_n + threads - 1) / threads);
return std::pair<uint32_t, uint32_t>{threads, blocks};
}

/**
* get the threads for fused_residual_dropout_bias:
* 1D blocks: blockDim.x = cols
* 2D grids: gridDim.y = rows
*/
template <int VecSize = 4>
inline std::pair<dim3, dim3> Get1DBlocksAnd2DGrids(
const platform::CUDADeviceContext &ctx, const uint32_t rows,
const uint32_t cols) {
const uint32_t tmp_cols = cols / VecSize;
int threads = std::max(
(uint32_t)32, std::min(tmp_cols, (uint32_t)ctx.GetMaxThreadsPerBlock()));
int blocks_x = std::max((uint32_t)1, (tmp_cols + threads - 1) / threads);
int blocks_y = std::max((uint32_t)1, rows);
dim3 block_dim(threads, 1, 1);
dim3 grid_dim(blocks_x, blocks_y, 1);
return std::pair<dim3, dim3>{block_dim, grid_dim};
}

// aligned vector generates vectorized load/store on CUDA
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) AlignedVector {
T val[VecSize];
};

} // namespace operators
} // namespace paddle
121 changes: 121 additions & 0 deletions paddle/fluid/operators/fused/fused_dropout_test.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <random>
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;

USE_OP(dropout);

/**
* @brief call paddle dropout op
*/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个test是测试原来的dropout op吗?我看到还有一个fuse_dropout_op的test。所以没太明白这个单测的用意?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个头文件是给几个dropout相关的单测共用的,里面call了下dropout_op,作为对比的base版本。

template <typename T>
void Dropout(const T *x, const framework::DDim &x_dim, T *out,
std::vector<uint8_t> *mask, const platform::CUDADeviceContext &ctx,
uint64_t seed, float dropout_prob, bool is_upscale_in_train,
bool is_test) {
framework::Scope scope;
auto var_x = scope.Var("X");
auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
tensor_x->Resize(x_dim);
tensor_x->mutable_data<T>(ctx.GetPlace());
cudaMemcpy(tensor_x->data<T>(), x, x_dim[0] * x_dim[1] * sizeof(T),
cudaMemcpyHostToDevice);

auto var_out = scope.Var("Out");
auto tensor_out = var_out->GetMutable<framework::LoDTensor>();

auto var_mask = scope.Var("Mask");
auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();

framework::AttributeMap attrs;
attrs.insert({"fix_seed", 1});
attrs.insert({"seed", static_cast<int>(seed)});
attrs.insert({"dropout_prob", dropout_prob});
if (is_upscale_in_train) {
attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
}
if (is_test) {
attrs.insert({"is_test", 1});
}

auto op = framework::OpRegistry::CreateOp(
"dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs);
op->Run(scope, ctx.GetPlace());
cudaMemcpy(out, tensor_out->data<T>(), x_dim[0] * x_dim[1] * sizeof(T),
cudaMemcpyDeviceToHost);
if (!is_test) {
cudaMemcpy((*mask).data(), tensor_mask->data<uint8_t>(),
x_dim[0] * x_dim[1] * sizeof(uint8_t), cudaMemcpyDeviceToHost);
}
ctx.Wait();
}

/**
* @brief call paddle dropout_grad op
*/
template <typename T>
void DropoutGrad(T *dx, const framework::DDim &x_dim, const T *dout,
const uint8_t *mask, const platform::CUDADeviceContext &ctx,
float dropout_prob, bool is_upscale_in_train) {
framework::Scope scope;
const size_t n = x_dim[0] * x_dim[1];
auto var_out = scope.Var("DOut");
auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
tensor_out->Resize(x_dim);
tensor_out->mutable_data<T>(ctx.GetPlace());
cudaMemcpy(tensor_out->data<T>(), dout, n * sizeof(T),
cudaMemcpyHostToDevice);

auto var_mask = scope.Var("Mask");
auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();
tensor_mask->Resize(x_dim);
tensor_mask->mutable_data<uint8_t>(ctx.GetPlace());
cudaMemcpy(tensor_mask->data<uint8_t>(), mask, n * sizeof(uint8_t),
cudaMemcpyHostToDevice);

auto var_dx = scope.Var("DX");
auto tensor_dx = var_dx->GetMutable<framework::LoDTensor>();

framework::AttributeMap attrs;
attrs.insert({"dropout_prob", dropout_prob});
attrs.insert({"is_test", 0});
if (is_upscale_in_train) {
attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
} else {
attrs.insert({"dropout_implementation", std::string("downgrade_in_infer")});
}

auto op = framework::OpRegistry::CreateOp(
"dropout_grad", {{"Out@GRAD", {"DOut"}}, {"Mask", {"Mask"}}},
{{"X@GRAD", {"DX"}}}, attrs);
op->Run(scope, ctx.GetPlace());

cudaMemcpy(dx, tensor_dx->data<T>(), x_dim[0] * x_dim[1] * sizeof(T),
cudaMemcpyDeviceToHost);
ctx.Wait();
}
Loading