2 changes: 1 addition & 1 deletion paddle/fluid/operators/activation_op.h
@@ -545,7 +545,7 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    dx.device(d) = static_cast<T>(0) / out;
+    dx.device(d) = static_cast<T>(0) * out;
   }

   static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; }
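A note on the ZeroGradFunctor fix: with the old expression, any zero in `out` produced `0 / 0 = NaN` in the gradient, while the new one always yields zero but still references `out`, keeping the op's data dependency intact. A minimal NumPy sketch of the difference (my illustration, not code from this patch):

```python
import numpy as np

out = np.array([0.0, 2.0, -3.0], dtype=np.float32)

with np.errstate(divide='ignore', invalid='ignore'):
    old_dx = np.float32(0) / out  # -> [nan, 0., -0.]: 0/0 is NaN
new_dx = np.float32(0) * out      # -> [0., 0., -0.]: always well-defined

print(old_dx, new_dx)
```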
6 changes: 6 additions & 0 deletions paddle/fluid/operators/conv_cudnn_op.cu.cc
@@ -222,6 +222,9 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
           dev_ctx);
       void* cudnn_workspace_ptr =
           static_cast<void*>(cudnn_workspace.data<int8_t>());
+      VLOG(2) << "Cudnn workspace size fwd: "
+              << static_cast<double>(workspace_size_in_bytes) / (1 << 20)
+              << " MB";
       // ------------------- cudnn conv forward ---------------------
       ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
       for (int i = 0; i < groups; i++) {
@@ -473,6 +476,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
               {static_cast<int64_t>(workspace_size_in_bytes)}),
           dev_ctx);
       cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
+      VLOG(2) << "Cudnn workspace size bwd: "
+              << static_cast<double>(workspace_size_in_bytes) / (1 << 20)
+              << " MB";
     }

     // ------------------- cudnn conv backward data ---------------------
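The added VLOG lines are purely diagnostic: they convert the requested cuDNN workspace size from bytes to mebibytes by dividing by `1 << 20` (2^20 = 1,048,576). A trivial sketch of the same arithmetic, with a hypothetical workspace size:

```python
workspace_size_in_bytes = 64 * (1 << 20)          # hypothetical 64 MiB request
print(workspace_size_in_bytes / (1 << 20), "MB")  # -> 64.0 MB, as logged
```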
8 changes: 8 additions & 0 deletions paddle/fluid/operators/dropout_op.cc
@@ -117,6 +117,14 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
     ctx->ShareLoD(framework::GradVarName("Out"),
                   /*->*/ framework::GradVarName("X"));
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.GetPlace());
+  }
 };

 class DropoutGradOpDescMaker : public framework::SingleGradOpDescMaker {
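Presumably this override is needed because the grad op's inputs now include the uint8 `Mask` (see the kernel changes below): pinning the kernel type to the dtype of `Out@GRAD` keeps the default type inference from picking up the mask's integer dtype. That rationale is my reading of the change, not stated in the diff.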
39 changes: 25 additions & 14 deletions paddle/fluid/operators/dropout_op.cu
@@ -22,10 +22,10 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename T>
+template <typename T, typename MaskType>
 __global__ void RandomGenerator(const size_t n, const int seed,
                                 const float dropout_prob, const T* src,
-                                T* mask_data, T* dst,
+                                MaskType* mask_data, T* dst,
                                 bool is_upscale_in_train) {
   thrust::minstd_rand rng;
   rng.seed(seed);
@@ -34,7 +34,7 @@ __global__ void RandomGenerator(const size_t n, const int seed,
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   int step_size = 0;

-  T mask;
+  MaskType mask;
   T dest;
   for (; idx < n; idx += blockDim.x * gridDim.x) {
     T s = src[idx];
@@ -45,15 +45,16 @@ __global__ void RandomGenerator(const size_t n, const int seed,
       rng.discard(step_size);
     }
     if (dist(rng) < dropout_prob) {
-      mask = static_cast<T>(0);
+      mask = 0;
+      dest = 0;
     } else {
+      mask = 1;
       if (is_upscale_in_train) {
-        mask = static_cast<T>(1.0f / (1.0f - dropout_prob));
+        dest = s / static_cast<T>(1.0f - dropout_prob);
       } else {
-        mask = static_cast<T>(1);
+        dest = s;
       }
     }
-    dest = s * mask;
     mask_data[idx] = mask;
     dst[idx] = dest;
   }
@@ -71,30 +72,40 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
     y->mutable_data<T>(context.GetPlace());
     float dropout_prob = context.Attr<float>("dropout_prob");

-    auto dropout_implementation =
+    auto& dropout_implementation =
         context.Attr<std::string>("dropout_implementation");
+    bool upscale_in_train = (dropout_implementation == "upscale_in_train");

     auto& place = *context.template device_context<Place>().eigen_device();
     if (!context.Attr<bool>("is_test")) {
+      int64_t x_numel = x->numel();
+      auto stream = context.cuda_device_context().stream();
+
       auto* mask = context.Output<Tensor>("Mask");
-      auto* mask_data = mask->mutable_data<T>(context.GetPlace());
+      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
       size_t size = framework::product(mask->dims());
       auto* x_data = x->data<T>();
       auto* y_data = y->mutable_data<T>(context.GetPlace());
+      if (dropout_prob == 1.0f) {
+        PADDLE_ENFORCE(cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
+        PADDLE_ENFORCE(cudaMemsetAsync(mask_data, 0,
+                                       x_numel * sizeof(*mask_data), stream));
+        return;
+      }

       std::random_device rnd;
       int seed =
           context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();

       int threads = 512;
-      int grid = (x->numel() + threads - 1) / threads;
-      RandomGenerator<
-          T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
+      int grid = (x_numel + threads - 1) / threads;
+      RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
           size, seed, dropout_prob, x_data, mask_data, y_data,
-          (dropout_implementation == "upscale_in_train"));
+          upscale_in_train);
     } else {
       auto X = EigenMatrix<T>::Reshape(*x, 1);
       auto Y = EigenMatrix<T>::Reshape(*y, 1);
-      if (dropout_implementation == "upscale_in_train") {
+      if (upscale_in_train) {
         Y.device(place) = X;
       } else {
         Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
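Net effect of the CUDA changes: the mask is stored as plain 0/1 `uint8` values, the `upscale_in_train` scaling moves from the mask into the output, and `dropout_prob == 1.0` is short-circuited with async memsets before any kernel launch. A NumPy sketch of the new forward semantics, under my reading of the kernel (not PR code):

```python
import numpy as np

def dropout_forward(x, dropout_prob, upscale_in_train, rng):
    """Train-time dropout as the updated kernel computes it."""
    if dropout_prob == 1.0:
        # Mirrors the cudaMemsetAsync short-circuit: everything zeroed.
        return np.zeros_like(x), np.zeros(x.shape, dtype=np.uint8)
    mask = (rng.random(x.shape) >= dropout_prob).astype(np.uint8)
    if upscale_in_train:
        # Kept values are scaled up so inference can be an identity op.
        y = x * mask / (1.0 - dropout_prob)
    else:
        y = x * mask
    return y.astype(x.dtype), mask

rng = np.random.default_rng(seed=42)
y, mask = dropout_forward(np.ones((2, 4), np.float32), 0.5, True, rng)
```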
39 changes: 30 additions & 9 deletions paddle/fluid/operators/dropout_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once

+#include <cstring>
 #include <random>
 #include <string>

@@ -37,11 +38,20 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
     auto* y_data = y->mutable_data<T>(context.GetPlace());
     float dropout_prob = context.Attr<float>("dropout_prob");

-    auto dropout_implementation =
+    auto& dropout_implementation =
         context.Attr<std::string>("dropout_implementation");
+    bool upscale_in_train = (dropout_implementation == "upscale_in_train");
     if (!context.Attr<bool>("is_test")) {
       auto* mask = context.Output<Tensor>("Mask");
-      auto* mask_data = mask->mutable_data<T>(context.GetPlace());
+      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
+      size_t size = framework::product(mask->dims());
+
+      // Special case when dropout_prob is 1.0
+      if (dropout_prob == 1.0f) {
+        std::memset(y_data, 0, size * sizeof(*y_data));        // NOLINT
+        std::memset(mask_data, 0, size * sizeof(*mask_data));  // NOLINT
+        return;
+      }

       // NOTE: fixed seed should only be used in unittest or for debug.
       // Guarantee to use random seed in training.
@@ -53,17 +63,15 @@ class CPUDropoutKernel : public framework::OpKernel<T> {

       std::uniform_real_distribution<float> dist(0, 1);

-      size_t size = framework::product(mask->dims());
       for (size_t i = 0; i < size; ++i) {
         if (dist(engine) < dropout_prob) {
           mask_data[i] = 0;
           y_data[i] = 0;
         } else {
-          if (dropout_implementation == "upscale_in_train") {
-            mask_data[i] = 1.0f / static_cast<T>(1.0f - dropout_prob);
+          mask_data[i] = 1;
+          if (upscale_in_train) {
             y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
           } else {
-            mask_data[i] = 1;
             y_data[i] = x_data[i];
           }
         }
@@ -73,7 +81,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
       auto Y = EigenMatrix<T>::Reshape(*y, 1);
       auto& place =
           *context.template device_context<DeviceContext>().eigen_device();
-      if (dropout_implementation == "upscale_in_train") {
+      if (upscale_in_train) {
         Y.device(place) = X;
       } else {
         Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
@@ -94,13 +102,26 @@ class DropoutGradKernel : public framework::OpKernel<T> {
     auto* mask = context.Input<Tensor>("Mask");
     grad_x->mutable_data<T>(context.GetPlace());

-    auto M = EigenMatrix<T>::Reshape(*mask, 1);
+    auto M = EigenMatrix<uint8_t>::Reshape(*mask, 1);
     auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
     auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);

     auto& place =
         *context.template device_context<DeviceContext>().eigen_device();
-    dX.device(place) = dY * M;
+
+    auto& dropout_implementation =
+        context.Attr<std::string>("dropout_implementation");
+    if (dropout_implementation == "upscale_in_train") {
+      float dropout_prob = context.Attr<float>("dropout_prob");
+      if (dropout_prob == 1.0f) {
+        dX.device(place) = static_cast<T>(0) * dY;
+      } else {
+        dX.device(place) =
+            dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
+      }
+    } else {
+      dX.device(place) = dY * M.cast<T>();
+    }
   }
 };

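The grad kernel has to change in step: since the mask no longer bakes in the 1/(1-p) factor, `upscale_in_train` must re-apply it in the backward pass, and p == 1 is special-cased to a zero gradient. A NumPy sketch of the updated rule (my paraphrase):

```python
import numpy as np

def dropout_backward(dy, mask, dropout_prob, upscale_in_train):
    if upscale_in_train:
        if dropout_prob == 1.0:
            return np.zeros_like(dy)              # dX = 0 * dY
        # dX = dY * M / (1 - p), matching the forward upscaling
        return dy * mask / (1.0 - dropout_prob)
    return (dy * mask).astype(dy.dtype)           # dX = dY * M
```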
2 changes: 1 addition & 1 deletion python/paddle/fluid/layers/nn.py
@@ -1390,7 +1390,7 @@ def dropout(x,
     helper = LayerHelper('dropout', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
     mask = helper.create_variable_for_type_inference(
-        dtype=x.dtype, stop_gradient=True)
+        dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)

     if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
         seed = helper.main_program.random_seed
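On the Python side the only change is that the mask variable is now declared UINT8 to match the kernels. A hedged usage sketch against the fluid API of this era (layer and attribute names come from the diff; the exact call signature is my assumption):

```python
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[32, 64], dtype='float32')
# 'upscale_in_train' rescales kept activations by 1/(1-p) at train
# time; the op's internal Mask variable is now uint8.
out = fluid.layers.dropout(x, dropout_prob=0.5,
                           dropout_implementation='upscale_in_train')
```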
10 changes: 5 additions & 5 deletions python/paddle/fluid/tests/unittests/test_dropout_op.py
@@ -27,7 +27,7 @@ def setUp(self):
         self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
         self.outputs = {
             'Out': self.inputs['X'],
-            'Mask': np.ones((32, 64)).astype('float32')
+            'Mask': np.ones((32, 64)).astype('uint8')
         }

     def test_check_output(self):
@@ -44,7 +44,7 @@ def setUp(self):
         self.attrs = {'dropout_prob': 1.0, 'fix_seed': True, 'is_test': False}
         self.outputs = {
             'Out': np.zeros((32, 64)).astype('float32'),
-            'Mask': np.zeros((32, 64)).astype('float32')
+            'Mask': np.zeros((32, 64)).astype('uint8')
         }

@@ -55,7 +55,7 @@ def setUp(self):
         self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
         self.outputs = {
             'Out': self.inputs['X'],
-            'Mask': np.ones((32, 64, 2)).astype('float32')
+            'Mask': np.ones((32, 64, 2)).astype('uint8')
         }

@@ -97,7 +97,7 @@ def setUp(self):
         }
         self.outputs = {
             'Out': np.zeros((32, 64)).astype('float32'),
-            'Mask': np.zeros((32, 64)).astype('float32')
+            'Mask': np.zeros((32, 64)).astype('uint8')
         }

@@ -113,7 +113,7 @@ def setUp(self):
         }
         self.outputs = {
             'Out': self.inputs['X'],
-            'Mask': np.ones((32, 64, 2)).astype('float32')
+            'Mask': np.ones((32, 64, 2)).astype('uint8')
         }