181 changes: 181 additions & 0 deletions paddle/fluid/operators/uniform_random_inplace_op.cc
@@ -0,0 +1,181 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

class UniformRandomInplaceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddComment(R"DOC(
This operator fills the input tensor in place with random values
sampled from a uniform distribution over the range [min, max).
)DOC");
AddInput("X", "The input tensor.");
AddOutput("Out", "The output tensor of uniform random op");
AddAttr<float>("min", "Minimum value of uniform random. [default -1.0].")
.SetDefault(-1.0f);
AddAttr<float>("max", "Maximun value of uniform random. [default 1.0].")
.SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed used for generating samples. "
"If seed is 0, it will use the seed of the global default "
"generator (which can be set by paddle.seed). "
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time. [default 0].")
.SetDefault(0);
AddAttr<int>("diag_num",
"The number of diag elements. Note that if "
"diag_num is 0, it means without diag init.[default 0].")
.SetDefault(0);
AddAttr<int>("diag_step", "The step between two diag element.[default 0].")
.SetDefault(0);
AddAttr<float>("diag_val", "The value of diag element. [default 1.0].")
.SetDefault(1.0f);
}
};

class UniformRandomInplaceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "UniformRandomInplaceOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"UniformRandomInplaceOp");
PADDLE_ENFORCE_LT(
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max"),
platform::errors::InvalidArgument(
"The uniform_random's min must less then max. But received min = "
"%f great than or equal max = %f.",
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max")));
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_num"), 0,
platform::errors::InvalidArgument(
"The uniform_random's diag_num must greater than or "
"equal 0. But recevied diag_num (%d) < 0.",
ctx->Attrs().Get<int>("diag_num")));
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_step"), 0,
platform::errors::InvalidArgument(
"The uniform_random's diag_step must greater than or "
"equal 0. But recevied diag_step (%d) < 0.",
ctx->Attrs().Get<int>("diag_step")));
auto xdim = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", xdim);
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};

template <typename T>
class CPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto out_var = ctx.OutputVar("Out");
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
T *data = tensor->mutable_data<T>(ctx.GetPlace());
int64_t size = tensor->numel();
std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max")));
auto engine = paddle::framework::GetCPURandomEngine(
static_cast<unsigned int>(ctx.Attr<int>("seed")));
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
}
};

class UniformRandomInplaceOpVarTypeInference
: public framework::VarTypeInference {
public:
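// Intentionally a no-op: no variable type changes are needed for the
// in-place output.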
void operator()(framework::InferVarTypeContext *ctx) const override {}
};

class UniformRandomInplaceGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out_Grad", "UniformRandomInplaceGradOp");
auto x_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};

template <typename T>
class UniformRandomInplaceGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType(this->ForwardOpType() + "_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};

template <typename T>
class CPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
if (dx) {
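// The forward output is freshly sampled noise, independent of X, so the
// gradient w.r.t. X is defined as all zeros.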
auto *data = dx->mutable_data<T>(ctx.GetPlace());
std::fill(data, data + dx->numel(), T(0));
}
}
};

} // namespace operators
} // namespace paddle
DECLARE_INPLACE_OP_INFERER(UniformRandomInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(UniformRandomInplaceGradInplaceInferer,
{paddle::framework::GradVarName("Out"),
paddle::framework::GradVarName("X")});

REGISTER_OPERATOR(uniform_random_inplace,
paddle::operators::UniformRandomInplaceOp,
paddle::operators::UniformRandomInplaceOpMaker,
paddle::operators::UniformRandomInplaceGradOpMaker<
paddle::framework::OpDesc>,
paddle::operators::UniformRandomInplaceGradOpMaker<
paddle::imperative::OpBase>,
paddle::operators::UniformRandomInplaceOpVarTypeInference,
UniformRandomInplaceInferer);
REGISTER_OPERATOR(uniform_random_inplace_grad,
paddle::operators::UniformRandomInplaceGradOp,
UniformRandomInplaceGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
uniform_random_inplace,
paddle::operators::CPUUniformRandomInplaceKernel<float>,
paddle::operators::CPUUniformRandomInplaceKernel<double>);
Contributor:
Please refer to other in-place operators' implementations. Under normal
circumstances the code can be shared with the non-inplace files, so there
should be no need to add uniform_random_inplace_op.cc,
uniform_random_inplace_op.cu and uniform_random_inplace_op_xpu.cc.

Contributor Author:
Because the parameters of paddle.uniform and paddle.Tensor.uniform_ differ,
the non-inplace uniform kernel cannot be reused directly.

REGISTER_OP_CPU_KERNEL(
uniform_random_inplace_grad,
paddle::operators::CPUUniformRandomInplaceGradKernel<float>,
paddle::operators::CPUUniformRandomInplaceGradKernel<double>);
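
For reference, here is a minimal standalone sketch (plain C++17, no Paddle; the function name fill_uniform and all constants are illustrative) of the sampling loop CPUUniformRandomInplaceKernel performs. Note that, as written above, the CPU kernel applies only min and max; the diag_* attributes take effect only in the CUDA functors in the next file.

#include <iostream>
#include <random>
#include <vector>

// Illustrative stand-in for the kernel body: fill `data` in place from
// U[min, max) using one shared engine, as the CPU kernel above does.
void fill_uniform(std::vector<float>& data, float min, float max,
                  unsigned int seed) {
  std::mt19937_64 engine(seed);  // stands in for GetCPURandomEngine(seed)
  std::uniform_real_distribution<float> dist(min, max);
  for (auto& v : data) {
    v = dist(engine);
  }
}

int main() {
  std::vector<float> t(6);
  fill_uniform(t, -1.0f, 1.0f, /*seed=*/42);
  for (float v : t) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}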
171 changes: 171 additions & 0 deletions paddle/fluid/operators/uniform_random_inplace_op.cu
@@ -0,0 +1,171 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>

#include <algorithm>
#include <random>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

template <typename T>
struct UniformGenerator {
T min_, max_;
unsigned int seed_;
unsigned int diag_num_;
unsigned int diag_step_;
T diag_val_;
__host__ __device__ UniformGenerator(T min, T max, int seed, int diag_num,
int diag_step, T diag_val)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val) {}

__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n);
T out = dist(rng);
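// Diagonal init: overwrite elements 0, diag_step_ + 1, 2 * (diag_step_ + 1),
// ... with diag_val_, for the first diag_num_ such positions.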
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};

template <typename T>
struct UniformGeneratorOffset {
T min_, max_;
unsigned int seed_;
unsigned int diag_num_;
unsigned int diag_step_;
T diag_val_;
int offset_;
__host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
int diag_num, int diag_step,
T diag_val, int offset)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val),
offset_(offset) {}

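// Identical to UniformGenerator, except the engine additionally discards a
// fixed offset so successive op invocations continue the global sequence
// instead of repeating it.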
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n + offset_);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};

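// Single-block fill: launched as <<<1, kBlockDim>>> below, so each thread
// strides through the buffer by blockDim.x.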
template <typename T>
__global__ void fill_value(int64_t size, T* data, float value) {
for (int idx = threadIdx.x; idx < size; idx += blockDim.x) {
data[idx] = static_cast<T>(value);
}
}

// Eigen::Tensor::random can segfault on GPU, so uniform sampling is
// implemented with std::random and thrust::random (thrust ships with CUDA),
// following the approach of uniform_random_op.cu.
template <typename T>
class GPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto out_var = ctx.OutputVar("Out");
auto* tensor = out_var->GetMutable<framework::LoDTensor>();
T* data = tensor->mutable_data<T>(ctx.GetPlace());
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}

T min = static_cast<T>(ctx.Attr<float>("min"));
T max = static_cast<T>(ctx.Attr<float>("max"));
unsigned int diag_num =
static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
unsigned int diag_step =
static_cast<unsigned int>(ctx.Attr<int>("diag_step"));
T diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = tensor->numel();
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
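// If Python has seeded the default CUDA generator and no explicit seed was
// given, take (seed, offset) from it and advance the offset so repeated
// calls produce fresh values; otherwise fall back to the stateless
// per-seed path.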
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
diag_step, diag_val, gen_offset));
} else {
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
}
}
};

template <typename T>
class GPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
#endif
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* data = dx->mutable_data<T>(ctx.GetPlace());

auto size = dx->numel();
int64_t kBlockDim = std::min(size, kMaxBlockDim);
auto stream = ctx.cuda_device_context().stream();
fill_value<T><<<1, kBlockDim, 0, stream>>>(size, data, static_cast<float>(0));
}
};

} // namespace operators
} // namespace paddle

REGISTER_OP_CUDA_KERNEL(
uniform_random_inplace,
paddle::operators::GPUUniformRandomInplaceKernel<float>,
paddle::operators::GPUUniformRandomInplaceKernel<double>);
REGISTER_OP_CUDA_KERNEL(
uniform_random_inplace_grad,
paddle::operators::GPUUniformRandomInplaceGradKernel<float>,
paddle::operators::GPUUniformRandomInplaceGradKernel<double>);
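
To make the parallelization concrete, here is a host-only sketch (plain C++17, no CUDA; uniform_at is an illustrative name) of the stateless per-index pattern the thrust functors above use: each element re-seeds an engine and discards n draws, so its value depends only on (seed, n) and not on evaluation order. This is what lets thrust::transform evaluate elements in parallel, in any order, and still produce a reproducible tensor.

#include <iostream>
#include <random>

float uniform_at(unsigned int seed, unsigned int n, float min, float max,
                 unsigned int diag_num, unsigned int diag_step,
                 float diag_val) {
  std::minstd_rand rng(seed);  // analogous to thrust::minstd_rand
  rng.discard(n);              // jump to this element's draw
  std::uniform_real_distribution<float> dist(min, max);
  float out = dist(rng);
  // Same diagonal rule as UniformGenerator::operator().
  if (n % (diag_step + 1) == 0 && diag_num > n / (diag_step + 1)) {
    out = diag_val;
  }
  return out;
}

int main() {
  // Any iteration order yields the same values per index.
  for (unsigned int n = 0; n < 8; ++n) {
    std::cout << uniform_at(123, n, -1.0f, 1.0f, /*diag_num=*/2,
                            /*diag_step=*/3, /*diag_val=*/1.0f)
              << ' ';
  }
  std::cout << '\n';
  return 0;
}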