Merged
132 changes: 93 additions & 39 deletions paddle/operators/adam_op.h
@@ -13,59 +13,113 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include <math.h> // for sqrt in CPU and CUDA
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
#include "paddle/platform/for_range.h"

namespace paddle {
namespace operators {

template <typename T>
struct AdamFunctor {
T beta1_;
T beta2_;
T epsilon_;

const T* beta1_pow_;
const T* beta2_pow_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
T* moment2_out_;
const T* lr_;
const T* grad_;
const T* param_;
T* param_out_;

AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
T* mom2_out, const T* lr, const T* grad, const T* param,
T* param_out)
[Review comment] Contributor: The output variable should be the end of the function.

[Review comment] Collaborator (Author): Yes. That will fit Google C++ style.

: beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta2_pow_(beta2_pow),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
moment2_out_(mom2_out),
lr_(lr),
grad_(grad),
param_(param),
param_out_(param_out) {}

inline HOSTDEVICE void operator()(size_t i) const {
// Merge all memory access together.
T g = grad_[i];
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T lr = *lr_;
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
T p = param_[i];

// Calculation
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));

// Write back to global memory
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
param_out_[i] = p;
}
};
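
For readers cross-checking operator() against the Adam paper, the per-element update it implements is the standard one, with the bias correction folded into the learning rate and, exactly as in the code, epsilon added outside the square root. Here beta1_pow and beta2_pow are the Beta1Pow/Beta2Pow inputs (i.e. beta1^t, beta2^t) and alpha is the LearningRate input:

\[
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t,\\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2,\\
\hat{\alpha}_t &= \alpha \, \frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}},\\
\theta_t &= \theta_{t-1} - \hat{\alpha}_t \, \frac{m_t}{\sqrt{v_t}+\epsilon}.
\end{aligned}
\]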

template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto moment1_out_tensor = ctx.Output<framework::Tensor>("Moment1Out");
auto moment2_out_tensor = ctx.Output<framework::Tensor>("Moment2Out");

param_out_tensor->mutable_data<T>(ctx.GetPlace());
moment1_out_tensor->mutable_data<T>(ctx.GetPlace());
moment2_out_tensor->mutable_data<T>(ctx.GetPlace());
using paddle::framework::LoDTensor;
using paddle::operators::detail::Ref;

T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto& param = Ref(ctx.Input<LoDTensor>("Param"), "Must set Param");
auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
auto& mom1 = Ref(ctx.Input<LoDTensor>("Moment1"), "Must set Moment1");
auto& mom2 = Ref(ctx.Input<LoDTensor>("Moment2"), "Must set Moment2");
auto& lr =
Ref(ctx.Input<LoDTensor>("LearningRate"), "Must set LearningRate");

auto& beta1_pow =
Ref(ctx.Input<LoDTensor>("Beta1Pow"), "Must set Beta1Pow");
auto& beta2_pow =
Ref(ctx.Input<LoDTensor>("Beta2Pow"), "Must set Beta2Pow");

auto& param_out =
Ref(ctx.Output<LoDTensor>("ParamOut"), "Must set ParamOut");
auto& mom1_out =
Ref(ctx.Output<LoDTensor>("Moment1Out"), "Must set Moment1Out");
auto& mom2_out =
Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment2Out");

auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
auto grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Grad"));
auto moment1 = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment1"));
auto moment2 = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment2"));
auto lr = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("LearningRate"));
auto beta1_pow = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Beta1Pow"));
auto beta2_pow = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Beta2Pow"));
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment1_out = framework::EigenVector<T>::Flatten(*moment1_out_tensor);
auto moment2_out = framework::EigenVector<T>::Flatten(*moment2_out_tensor);
auto* place = ctx.template device_context<DeviceContext>().eigen_device();

moment1_out.device(*place) = beta1 * moment1 + (1 - beta1) * grad;
moment2_out.device(*place) = beta2 * moment2 + (1 - beta2) * grad.square();

// All of these are tensors of 1 element
auto lr_t = lr * (1 - beta2_pow).sqrt() / (1 - beta1_pow);
// Eigen does not support automatic broadcast
// Get dimensions of moment vector to broadcast lr_t
Eigen::DSizes<int, 1> m_dsize(moment1_out_tensor->numel());
param_out.device(*place) =
param -
lr_t.broadcast(m_dsize) *
(moment1_out / (moment2_out.sqrt() + epsilon));
AdamFunctor<T> functor(beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(),
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad.template data<T>(),
param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()));
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()), param.numel());
for_range(functor);
}
};
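
The kernel now performs the whole update through the per-element functor driven by ForRange, instead of the former Eigen expressions, so all reads of grad, moments, and param for one index happen in a single pass. As a standalone host-side sketch (not part of this PR; the sizes and values below are made up purely for illustration), the same arithmetic on a tiny parameter vector looks like this:

```cpp
#include <cmath>
#include <cstdio>

// Illustrative only: the same per-element Adam step that AdamFunctor
// performs, applied to a 3-element parameter vector at step t = 1.
int main() {
  const float beta1 = 0.9f, beta2 = 0.999f, epsilon = 1e-8f, lr = 1e-3f;
  const float beta1_pow = beta1, beta2_pow = beta2;  // beta^t at t = 1
  float param[3] = {1.f, 2.f, 3.f};
  float grad[3] = {0.1f, -0.2f, 0.3f};
  float mom1[3] = {0.f, 0.f, 0.f};
  float mom2[3] = {0.f, 0.f, 0.f};

  // Bias correction folded into the learning rate, as in AdamFunctor.
  const float lr_t = lr * std::sqrt(1 - beta2_pow) / (1 - beta1_pow);
  for (int i = 0; i < 3; ++i) {
    mom1[i] = beta1 * mom1[i] + (1 - beta1) * grad[i];
    mom2[i] = beta2 * mom2[i] + (1 - beta2) * grad[i] * grad[i];
    param[i] -= lr_t * (mom1[i] / (std::sqrt(mom2[i]) + epsilon));
    std::printf("param[%d] = %f\n", i, param[i]);
  }
  return 0;
}
```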

85 changes: 85 additions & 0 deletions paddle/platform/for_range.h
@@ -0,0 +1,85 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/platform/device_context.h"

namespace paddle {
namespace platform {

template <typename DeviceContext>
struct ForRange {
ForRange(const DeviceContext& dev_ctx, size_t limit);

template <typename Function>
void operator()(Function func) const;
};

template <>
struct ForRange<CPUDeviceContext> {
ForRange(const CPUDeviceContext& dev_ctx, size_t limit) : limit_(limit) {}

template <typename Function>
void operator()(Function func) const {
for (size_t i = 0; i < limit_; ++i) {
func(i);
}
}

size_t limit_;
};
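
A minimal usage sketch of the CPU specialization (hypothetical, and it assumes CPUDeviceContext is default-constructible as declared in paddle/platform/device_context.h): the callable simply receives every index in [0, limit).

```cpp
#include <vector>
#include "paddle/platform/for_range.h"

// Hypothetical example: scale a buffer in place via ForRange on CPU.
void ScaleInPlace(std::vector<float>* data, float factor) {
  paddle::platform::CPUDeviceContext dev_ctx;  // assumed default-constructible
  paddle::platform::ForRange<paddle::platform::CPUDeviceContext> for_range(
      dev_ctx, data->size());
  float* p = data->data();
  for_range([p, factor](size_t i) { p[i] *= factor; });
}
```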

#ifdef __NVCC__
template <typename Function>
__global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
size_t idx = static_cast<size_t>(threadIdx.x);
func(idx);
}

template <typename Function>
__global__ static void ForRangeElemwiseOp(Function func, int limit) {
size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
if (idx < limit) {
func(idx);
}
}

template <>
struct ForRange<CUDADeviceContext> {
ForRange(const CUDADeviceContext& dev_ctx, size_t limit)
: dev_ctx_(dev_ctx), limit_(static_cast<int>(limit)) {}

template <typename Function>
inline void operator()(Function func) const {
constexpr size_t num_threads = 1024;
int block_size = limit_ <= num_threads ? limit_ : num_threads;
int grid_size = (limit_ + num_threads - 1) / num_threads;

if (grid_size == 1) {
ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
func);
} else {
ForRangeElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
func, limit_);
}
}

const CUDADeviceContext& dev_ctx_;
int limit_;
};

#endif

} // namespace platform
} // namespace paddle
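
For the CUDA specialization, the launch configuration is plain ceil-division over a 1024-thread block. A host-only sketch of the same arithmetic (no CUDA required), shown for a few sizes:

```cpp
#include <cstdio>

// Mirrors the grid/block computation in ForRange<CUDADeviceContext>:
// at most 1024 threads per block, and enough blocks to cover `limit`.
int main() {
  const int kNumThreads = 1024;
  const int limits[] = {1, 1000, 1024, 1025, 3000};
  for (int limit : limits) {
    int block_size = limit <= kNumThreads ? limit : kNumThreads;
    int grid_size = (limit + kNumThreads - 1) / kNumThreads;
    std::printf("limit=%d -> grid=%d, block=%d\n", limit, grid_size, block_size);
  }
  return 0;
}
```

For limit = 1025 this gives two blocks of 1024 threads, which is why the general kernel guards with `if (idx < limit)`, while the single-block kernel can skip the check because block_size equals limit there.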