-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add new inplace api Tensor.uniform_, test=develop #33934
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,181 @@ | ||
| /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/fluid/framework/generator.h" | ||
| #include "paddle/fluid/framework/op_registry.h" | ||
| #include "paddle/fluid/framework/operator.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| class UniformRandomInplaceOpMaker : public framework::OpProtoAndCheckerMaker { | ||
| public: | ||
| void Make() override { | ||
| AddComment(R"DOC( | ||
| This operator fills self tensor with random values sampled from a | ||
| uniform distribution. The random result is in a range of [min, max). | ||
| )DOC"); | ||
| AddInput("X", "The input tensor."); | ||
| AddOutput("Out", "The output tensor of uniform random op"); | ||
| AddAttr<float>("min", "Minimum value of uniform random. [default -1.0].") | ||
| .SetDefault(-1.0f); | ||
| AddAttr<float>("max", "Maximun value of uniform random. [default 1.0].") | ||
| .SetDefault(1.0f); | ||
| AddAttr<int>("seed", | ||
| "Random seed used for generating samples. " | ||
| "If seed is 0, it will use the seed of the global default " | ||
| "generator (which can be set by paddle.seed). " | ||
| "Note that if seed is not 0, this operator will always " | ||
| "generate the same random numbers every time. [default 0].") | ||
| .SetDefault(0); | ||
| AddAttr<int>("diag_num", | ||
| "The number of diag elements. Note that if " | ||
| "diag_num is 0, it means without diag init.[default 0].") | ||
| .SetDefault(0); | ||
| AddAttr<int>("diag_step", "The step between two diag element.[default 0].") | ||
| .SetDefault(0); | ||
| AddAttr<float>("diag_val", "The value of diag element. [default 1.0].") | ||
| .SetDefault(1.0f); | ||
| } | ||
| }; | ||
|
|
||
| class UniformRandomInplaceOp : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
| void InferShape(framework::InferShapeContext *ctx) const override { | ||
| OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "UniformRandomInplaceOp"); | ||
| OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", | ||
| "UniformRandomInplaceOp"); | ||
| PADDLE_ENFORCE_LT( | ||
| ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max"), | ||
| platform::errors::InvalidArgument( | ||
| "The uniform_random's min must less then max. But received min = " | ||
| "%f great than or equal max = %f.", | ||
| ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max"))); | ||
| PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_num"), 0, | ||
| platform::errors::InvalidArgument( | ||
| "The uniform_random's diag_num must greater than or " | ||
| "equal 0. But recevied diag_num (%d) < 0.", | ||
| ctx->Attrs().Get<int>("diag_num"))); | ||
| PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_step"), 0, | ||
| platform::errors::InvalidArgument( | ||
| "The uniform_random's diag_step must greater than or " | ||
| "equal 0. But recevied diag_step (%d) < 0.", | ||
| ctx->Attrs().Get<int>("diag_step"))); | ||
| auto xdim = ctx->GetInputDim("X"); | ||
| ctx->SetOutputDim("Out", xdim); | ||
| } | ||
|
|
||
| protected: | ||
| framework::OpKernelType GetExpectedKernelType( | ||
| const framework::ExecutionContext &ctx) const override { | ||
| return framework::OpKernelType( | ||
| OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); | ||
| } | ||
| }; | ||
|
|
||
| template <typename T> | ||
| class CPUUniformRandomInplaceKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext &ctx) const override { | ||
| auto out_var = ctx.OutputVar("Out"); | ||
| auto *tensor = out_var->GetMutable<framework::LoDTensor>(); | ||
| T *data = tensor->mutable_data<T>(ctx.GetPlace()); | ||
| int64_t size = tensor->numel(); | ||
| std::uniform_real_distribution<T> dist( | ||
| static_cast<T>(ctx.Attr<float>("min")), | ||
| static_cast<T>(ctx.Attr<float>("max"))); | ||
| auto engine = paddle::framework::GetCPURandomEngine( | ||
| static_cast<unsigned int>(ctx.Attr<int>("seed"))); | ||
| for (int64_t i = 0; i < size; ++i) { | ||
| data[i] = dist(*engine); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| class UniformRandomInplaceOpVarTypeInference | ||
| : public framework::VarTypeInference { | ||
| public: | ||
| void operator()(framework::InferVarTypeContext *ctx) const override {} | ||
| }; | ||
|
|
||
| class UniformRandomInplaceGradOp : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
| void InferShape(framework::InferShapeContext *ctx) const override { | ||
| OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", | ||
| "Out_Grad", "UniformRandomInplaceGradOp"); | ||
| auto x_dims = ctx->GetInputDim(framework::GradVarName("Out")); | ||
| auto x_grad_name = framework::GradVarName("X"); | ||
| if (ctx->HasOutput(x_grad_name)) { | ||
| ctx->SetOutputDim(x_grad_name, x_dims); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| template <typename T> | ||
| class UniformRandomInplaceGradOpMaker : public framework::SingleGradOpMaker<T> { | ||
| public: | ||
| using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
|
||
| protected: | ||
| void Apply(GradOpPtr<T> retv) const override { | ||
| retv->SetType(this->ForwardOpType() + "_grad"); | ||
| retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); | ||
| retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); | ||
| retv->SetAttrMap(this->Attrs()); | ||
| } | ||
| }; | ||
|
|
||
| template <typename T> | ||
| class CPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const paddle::framework::ExecutionContext &ctx) const override { | ||
| auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X")); | ||
| if (dx) { | ||
| auto *data = dx->mutable_data<T>(ctx.GetPlace()); | ||
| std::fill(data, data + dx->numel(), T(0)); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
| DECLARE_INPLACE_OP_INFERER(UniformRandomInplaceInferer, {"X", "Out"}); | ||
| DECLARE_INPLACE_OP_INFERER(UniformRandomInplaceGradInplaceInferer, | ||
| {paddle::framework::GradVarName("Out"), | ||
| paddle::framework::GradVarName("X")}); | ||
|
|
||
| REGISTER_OPERATOR(uniform_random_inplace, | ||
| paddle::operators::UniformRandomInplaceOp, | ||
| paddle::operators::UniformRandomInplaceOpMaker, | ||
| paddle::operators::UniformRandomInplaceGradOpMaker< | ||
| paddle::framework::OpDesc>, | ||
| paddle::operators::UniformRandomInplaceGradOpMaker< | ||
| paddle::imperative::OpBase>, | ||
| paddle::operators::UniformRandomInplaceOpVarTypeInference, | ||
| UniformRandomInplaceInferer); | ||
| REGISTER_OPERATOR(uniform_random_inplace_grad, | ||
| paddle::operators::UniformRandomInplaceGradOp, | ||
| UniformRandomInplaceGradInplaceInferer); | ||
| REGISTER_OP_CPU_KERNEL( | ||
| uniform_random_inplace, | ||
| paddle::operators::CPUUniformRandomInplaceKernel<float>, | ||
| paddle::operators::CPUUniformRandomInplaceKernel<double>); | ||
| REGISTER_OP_CPU_KERNEL( | ||
| uniform_random_inplace_grad, | ||
| paddle::operators::CPUUniformRandomInplaceGradKernel<float>, | ||
| paddle::operators::CPUUniformRandomInplaceGradKernel<double>); | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,171 @@ | ||
| /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include <thrust/device_vector.h> | ||
| #include <thrust/host_vector.h> | ||
| #include <thrust/random.h> | ||
| #include <thrust/transform.h> | ||
| #include "paddle/fluid/framework/generator.h" | ||
| #include "paddle/fluid/framework/op_registry.h" | ||
| #include "paddle/fluid/framework/operator.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| template <typename T> | ||
| struct UniformGenerator { | ||
| T min_, max_; | ||
| unsigned int seed_; | ||
| T diag_val_; | ||
| unsigned int diag_num_; | ||
| unsigned int diag_step_; | ||
| __host__ __device__ UniformGenerator(T min, T max, int seed, int diag_num, | ||
| int diag_step, T diag_val) | ||
| : min_(min), | ||
| max_(max), | ||
| seed_(seed), | ||
| diag_num_(diag_num), | ||
| diag_step_(diag_step), | ||
| diag_val_(diag_val) {} | ||
|
|
||
| __host__ __device__ T operator()(const unsigned int n) const { | ||
| thrust::minstd_rand rng; | ||
| rng.seed(seed_); | ||
| thrust::uniform_real_distribution<T> dist(min_, max_); | ||
| rng.discard(n); | ||
| T out = dist(rng); | ||
| unsigned int remainder = n % (diag_step_ + 1); | ||
| if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) { | ||
| out = diag_val_; | ||
| } | ||
| return out; | ||
| } | ||
| }; | ||
|
|
||
| template <typename T> | ||
| struct UniformGeneratorOffset { | ||
| T min_, max_; | ||
| unsigned int seed_; | ||
| T diag_val_; | ||
| unsigned int diag_num_; | ||
| unsigned int diag_step_; | ||
| int offset_; | ||
| __host__ __device__ UniformGeneratorOffset(T min, T max, int seed, | ||
| int diag_num, int diag_step, | ||
| T diag_val, int offset) | ||
| : min_(min), | ||
| max_(max), | ||
| seed_(seed), | ||
| diag_num_(diag_num), | ||
| diag_step_(diag_step), | ||
| diag_val_(diag_val), | ||
| offset_(offset) {} | ||
|
|
||
| __host__ __device__ T operator()(const unsigned int n) const { | ||
| thrust::minstd_rand rng; | ||
| rng.seed(seed_); | ||
| thrust::uniform_real_distribution<T> dist(min_, max_); | ||
| rng.discard(n + offset_); | ||
| T out = dist(rng); | ||
| unsigned int remainder = n % (diag_step_ + 1); | ||
| if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) { | ||
| out = diag_val_; | ||
| } | ||
| return out; | ||
| } | ||
| }; | ||
|
|
||
| template <typename T> | ||
| __global__ void fill_value(int64_t size, T* data, float value) { | ||
| for (int idx = threadIdx.x; idx < size; idx += blockDim.x) { | ||
| data[idx] = static_cast<T>(value); | ||
| } | ||
| } | ||
|
|
||
| // It seems that Eigen::Tensor::random in GPU will SEGFAULT. | ||
| // Use std::random and thrust::random(thrust is a std library in CUDA) to | ||
| // implement uniform random as uniform_random_op.cu. | ||
| template <typename T> | ||
| class GPUUniformRandomInplaceKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& ctx) const override { | ||
| auto out_var = ctx.OutputVar("Out"); | ||
| auto* tensor = out_var->GetMutable<framework::LoDTensor>(); | ||
| T* data = tensor->mutable_data<T>(ctx.GetPlace()); | ||
| unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed")); | ||
| bool seed_flag = false; | ||
| if (seed == 0) { | ||
| std::random_device rd; | ||
| seed = rd(); | ||
| seed_flag = true; | ||
| } | ||
|
|
||
| T min = static_cast<T>(ctx.Attr<float>("min")); | ||
| T max = static_cast<T>(ctx.Attr<float>("max")); | ||
| unsigned int diag_num = | ||
| static_cast<unsigned int>(ctx.Attr<int>("diag_num")); | ||
| unsigned int diag_step = | ||
| static_cast<unsigned int>(ctx.Attr<int>("diag_step")); | ||
| T diag_val = static_cast<T>(ctx.Attr<float>("diag_val")); | ||
| thrust::counting_iterator<unsigned int> index_sequence_begin(0); | ||
| int64_t size = tensor->numel(); | ||
| int device_id = | ||
| BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId(); | ||
| auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); | ||
| if (gen_cuda->GetIsInitPy() && seed_flag) { | ||
| auto seed_offset = gen_cuda->IncrementOffset(1); | ||
| int gen_offset = size * seed_offset.second; | ||
| thrust::transform( | ||
| index_sequence_begin, index_sequence_begin + size, | ||
| thrust::device_ptr<T>(data), | ||
| UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num, | ||
| diag_step, diag_val, gen_offset)); | ||
| } else { | ||
| thrust::transform( | ||
| index_sequence_begin, index_sequence_begin + size, | ||
| thrust::device_ptr<T>(data), | ||
| UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val)); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| template <typename T> | ||
| class GPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& ctx) const override { | ||
| #ifdef __HIPCC__ | ||
| const int64_t kMaxBlockDim = 256; | ||
| #else | ||
| const int64_t kMaxBlockDim = 512; | ||
| #endif | ||
| auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X")); | ||
| auto* data = dx->mutable_data<T>(ctx.GetPlace()); | ||
|
|
||
| auto size = dx->numel(); | ||
| int64_t kBlockDim = std::min(size, kMaxBlockDim); | ||
| fill_value<T><<<1, kBlockDim, 0>>>(size, data, static_cast<float>(0)); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| REGISTER_OP_CUDA_KERNEL( | ||
| uniform_random_inplace, | ||
| paddle::operators::GPUUniformRandomInplaceKernel<float>, | ||
| paddle::operators::GPUUniformRandomInplaceKernel<double>); | ||
| REGISTER_OP_CUDA_KERNEL( | ||
| uniform_random_inplace_grad, | ||
| paddle::operators::GPUUniformRandomInplaceGradKernel<float>, | ||
| paddle::operators::GPUUniformRandomInplaceGradKernel<double>); |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please refer to other inplace operator's implementation, under normal situation, it can be shared with files of non-inplace, so no need to add uniform_random_inplace_op.cc, uniform_random_inplace_op.cu and uniform_random_inplace_op_xpu.cc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because the parameters between paddle.uniform and paddle.Tensor.uniform_ are different, non-inplace version Uniform Kernel can not be directly reused.