Merged
4 changes: 4 additions & 0 deletions paddle/fluid/operators/optimizers/adam_op.cc
@@ -236,6 +236,10 @@ class AdamWOpMaker : public AdamOpMaker {
public:
void Make() {
AdamOpMaker::Make();
AddAttr<float>("lr_ratio",
Contributor

Why add this argument to Adam? Is it because AdamW and Adam share the same .cc file?

In that case, AdamW should have its own .cc file.

Contributor Author

AdamWOpMaker inherits from AdamOpMaker, and both ops use the same InferShape function of AdamOp.
In this case, 'lr_ratio' has no effect on Adam.

"(float, default 1.0) "
"layerwise learning rate decay")
.SetDefault(1.0f);
AddAttr<float>("coeff",
"(float, default 0.01) "
"coeff of the weight decay")
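To make the exchange above concrete: the Adam kernels never read `lr_ratio`, so the shared attribute is inert for plain Adam, while inside the AdamW kernels it only rescales the learning rate, and the default of 1.0 leaves the update unchanged. A minimal, illustrative Python sketch of that scaling (not Paddle API):

```python
# Illustrative only: how the AdamW kernels consume the new attribute.
# Plain Adam never reads "lr_ratio", so it has no effect there.
def effective_lr(lr, lr_ratio=1.0):
    return lr * lr_ratio  # mirrors `MT lr = *lr_ * lr_ratio;` in adamw_op.cu

print(effective_lr(0.01))        # 0.01  -> default ratio of 1.0, unchanged
print(effective_lr(0.01, 0.5))   # 0.005 -> layerwise-decayed parameter
```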
68 changes: 34 additions & 34 deletions paddle/fluid/operators/optimizers/adamw_op.cu
@@ -20,17 +20,17 @@ namespace operators {

template <typename T, typename MT>
__global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
MT beta1_pow_, MT beta2_pow_, const MT* moment1,
MT* moment1_out, const MT* moment2,
MT* moment2_out, const MT* lr_, const T* grad,
const T* param, T* param_out,
const MT* master_param, MT* master_param_out,
int ndim) {
MT lr = *lr_;
MT lr_ratio, MT beta1_pow_, MT beta2_pow_,
const MT* moment1, MT* moment1_out,
const MT* moment2, MT* moment2_out,
const MT* lr_, const T* grad, const T* param,
T* param_out, const MT* master_param,
MT* master_param_out, int ndim) {
MT lr = *lr_ * lr_ratio;
MT lr_orig = lr;
MT beta1_pow = beta1_pow_;
MT beta2_pow = beta2_pow_;

MT wd = static_cast<MT>(1.0) - coeff * lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);

@@ -43,9 +43,9 @@ __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
MT mom2 = moment2[id];
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p = wd * p -
lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
p -= lr_orig * coeff * p;
p -= lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

moment1_out[id] = mom1;
moment2_out[id] = mom2;
@@ -57,18 +57,16 @@ __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
}

template <typename T, typename MT>
__global__ void AdamWKernelMEM(MT beta1, MT beta2, MT epsilon, MT coeff,
const MT* beta1_pow_, const MT* beta2_pow_,
const MT* moment1, MT* moment1_out,
const MT* moment2, MT* moment2_out,
const MT* lr_, const T* grad, const T* param,
T* param_out, const MT* master_param,
MT* master_param_out, int ndim) {
MT lr = *lr_;
__global__ void AdamWKernelMEM(
MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT* beta1_pow_,
const MT* beta2_pow_, const MT* moment1, MT* moment1_out, const MT* moment2,
MT* moment2_out, const MT* lr_, const T* grad, const T* param, T* param_out,
const MT* master_param, MT* master_param_out, int ndim) {
MT lr = *lr_ * lr_ratio;
MT lr_orig = lr;
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;

MT wd = static_cast<MT>(1.0) - coeff * lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);

@@ -81,9 +79,9 @@ __global__ void AdamWKernelMEM(MT beta1, MT beta2, MT epsilon, MT coeff,
MT mom2 = static_cast<MT>(moment2[id]);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p = wd * p -
lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
p -= lr_orig * coeff * p;
p -= lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

moment1_out[id] = mom1;
moment2_out[id] = mom2;
@@ -103,16 +101,16 @@ __global__ void UpdateAdamWBetaPow(T beta1, T beta2, const T* beta1_pow_,

template <typename T, typename MT>
__global__ void SparseAdamWCUDAKernelREG(
MT beta1, MT beta2, MT epsilon, MT coeff, const MT beta1_pow,
MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT beta1_pow,
const MT beta2_pow, const MT* mom1_, MT* mom1_out_, const MT* mom2_,
MT* mom2_out_, const MT* lr_, const T* grad_, const T* param_,
T* param_out_, const MT* master_param, MT* master_param_out,
const int64_t* rows_, int64_t row_numel, int64_t row_count, bool lazy_mode,
int ndim) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
MT lr = *lr_;
MT lr = *lr_ * lr_ratio;
MT lr_orig = lr;

MT wd = static_cast<MT>(1.0) - coeff * lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);

@@ -130,9 +128,9 @@ __global__ void SparseAdamWCUDAKernelREG(
: static_cast<MT>(0);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p = wd * p -
lr * (mom1 / (sqrt(mom2) +
epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
p -= lr_orig * coeff * p;
p -= lr * (mom1 / (sqrt(mom2) +
epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

// Write back to global memory
mom1_out_[id] = mom1;
@@ -165,7 +163,9 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
float coeff = ctx.Attr<float>("coeff");

MPDType coeff = static_cast<MPDType>(ctx.Attr<float>("coeff"));
MPDType lr_ratio = static_cast<MPDType>(ctx.Attr<float>("lr_ratio"));

auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
@@ -301,7 +301,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
beta2_pow->place() == platform::CPUPlace()) {
// Compute with betapow in REG
AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, coeff, *beta1_pow->data<MPDType>(),
beta1, beta2, epsilon, coeff, lr_ratio, *beta1_pow->data<MPDType>(),
*beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
@@ -318,7 +318,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
}
} else {
AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, coeff, beta1_pow->data<MPDType>(),
beta1, beta2, epsilon, coeff, lr_ratio, beta1_pow->data<MPDType>(),
beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
@@ -377,7 +377,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {

SparseAdamWCUDAKernelREG<
T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, coeff, *beta1_pow->data<MPDType>(),
beta1, beta2, epsilon, coeff, lr_ratio, *beta1_pow->data<MPDType>(),
*beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
@@ -395,7 +395,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
}
} else {
SparseAdamWFunctor<T, GPUAdamW, MPDType> functor(
beta1, beta2, epsilon, coeff, beta1_pow->data<MPDType>(),
beta1, beta2, epsilon, coeff, lr_ratio, beta1_pow->data<MPDType>(),
beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
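For reference, the per-element update that the AdamW CUDA kernels above now implement can be sketched in NumPy as follows; this is an illustrative mirror of the kernel math, not code from the PR:

```python
import numpy as np

def adamw_step(p, g, mom1, mom2, lr, lr_ratio, coeff,
               beta1, beta2, epsilon, beta1_pow, beta2_pow):
    """Illustrative NumPy mirror of AdamWKernelREG/MEM after this change."""
    lr = lr * lr_ratio                      # layerwise-scaled learning rate
    lr_orig = lr                            # kept for decoupled weight decay
    lr *= np.sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow)  # bias correction

    mom1 = beta1 * mom1 + (1.0 - beta1) * g
    mom2 = beta2 * mom2 + (1.0 - beta2) * g * g

    p = p - lr_orig * coeff * p             # decoupled weight decay
    p = p - lr * (mom1 / (np.sqrt(mom2) + epsilon * np.sqrt(1.0 - beta2_pow)))
    return p, mom1, mom2
```

With lr_ratio = 1.0 and coeff = 0.0 this reduces to a plain Adam step.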
25 changes: 15 additions & 10 deletions paddle/fluid/operators/optimizers/adamw_op.h
@@ -32,12 +32,13 @@ template <typename T>
class AdamWFunctor<T, CPUAdamW> {
private:
const T coeff_;
const T lr_ratio_;
const T* lr_;
T* param_;

public:
AdamWFunctor(const T coeff, const T* lr, T* param)
: coeff_(coeff), lr_(lr), param_(param) {}
AdamWFunctor(const T coeff, const T lr_ratio, const T* lr, T* param)
: coeff_(coeff), lr_ratio_(lr_ratio), lr_(lr), param_(param) {}

inline HOSTDEVICE void operator()(size_t numel) const {
Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param{
@@ -46,7 +47,7 @@ class AdamWFunctor<T, CPUAdamW> {
T lr = *lr_;

// Calculation
param = param * (1 - lr * coeff_);
param -= lr * lr_ratio_ * coeff_ * param;
}
};

@@ -60,6 +61,7 @@ class SparseAdamWFunctor<T, GPUAdamW, MT> {
MT beta2_;
MT epsilon_;
MT coeff_;
MT lr_ratio_;

const MT* beta1_pow_;
const MT* beta2_pow_;
@@ -80,7 +82,7 @@ class SparseAdamWFunctor<T, GPUAdamW, MT> {
bool lazy_mode_;

public:
SparseAdamWFunctor(MT beta1, MT beta2, MT epsilon, MT coeff,
SparseAdamWFunctor(MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio,
const MT* beta1_pow, const MT* beta2_pow, const MT* mom1,
MT* mom1_out, const MT* mom2, MT* mom2_out, const MT* lr,
const T* grad, const T* param, T* param_out,
@@ -91,6 +93,7 @@ class SparseAdamWFunctor<T, GPUAdamW, MT> {
beta2_(beta2),
epsilon_(epsilon),
coeff_(coeff),
lr_ratio_(lr_ratio),
beta1_pow_(beta1_pow),
beta2_pow_(beta2_pow),
moment1_(mom1),
@@ -112,21 +115,21 @@ class SparseAdamWFunctor<T, GPUAdamW, MT> {
// The following code is the same as dense
MT mom1 = moment1_[i];
MT mom2 = moment2_[i];
MT lr = *lr_;
MT lr = *lr_ * lr_ratio_;
MT lr_orig = lr;
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;
MT p = master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);

// Calculation
MT wd = static_cast<MT>(1.0) - coeff_ * lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);

mom1 = beta1_ * mom1 + (static_cast<MT>(1.0) - beta1_) * g;
mom2 = beta2_ * mom2 + (static_cast<MT>(1.0) - beta2_) * g * g;
p = wd * p -
lr * (mom1 /
(sqrt(mom2) + epsilon_ * sqrt(static_cast<MT>(1.0) - beta2_pow)));
p -= lr_orig * coeff_ * p;
p -= lr * (mom1 / (sqrt(mom2) +
epsilon_ * sqrt(static_cast<MT>(1.0) - beta2_pow)));

// Write back to global memory
moment1_out_[i] = mom1;
@@ -187,6 +190,7 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
}

T coeff = static_cast<T>(ctx.Attr<float>("coeff"));
T lr_ratio = static_cast<T>(ctx.Attr<float>("lr_ratio"));
auto* lr = ctx.Input<LoDTensor>("LearningRate");

LoDTensor* param;
@@ -198,7 +202,8 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
param = const_cast<LoDTensor*>(ctx.Input<LoDTensor>("Param"));
}

AdamWFunctor<T, CPUAdamW> functor(coeff, lr->data<T>(), param->data<T>());
AdamWFunctor<T, CPUAdamW> functor(coeff, lr_ratio, lr->data<T>(),
param->data<T>());
functor(param->numel());

AdamOpKernel<DeviceContext, T>::Compute(ctx);
87 changes: 87 additions & 0 deletions python/paddle/fluid/tests/unittests/test_adamw_op.py
@@ -16,6 +16,7 @@
import paddle
import numpy as np
import paddle.fluid as fluid
from functools import partial


class TestAdamWOp(unittest.TestCase):
@@ -148,5 +149,91 @@ def test_adamw_op_dygraph(self):
adam.clear_gradients()


def simple_lr_setting(param, decay_rate, n_layers):
if "fc_0" in param.name or "linear_1" in param.name:
depth = int(param.name.split("_")[2]) + 1
elif "fc_1" in param.name or "linear_2" in param.name:
depth = int(param.name.split("_")[2]) + 2
else:
depth = 0

return decay_rate**(n_layers + 2 - depth)


class TestAdamWOpLayerwiseLR(TestAdamWOp):
def test_adamw_op_dygraph(self):
paddle.disable_static()
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value)
linear1 = paddle.nn.Linear(13, 8)
linear2 = paddle.nn.Linear(8, 5)

simple_lr_fun = partial(simple_lr_setting, decay_rate=0.8, n_layers=2)

adam = paddle.optimizer.AdamW(
learning_rate=0.01,
parameters=[{
'params': linear1.parameters()
}, {
'params': linear2.parameters(),
}],
apply_decay_param_fun=lambda name: True,
weight_decay=0.01,
lr_ratio=simple_lr_fun)

for _ in range(2):
a1 = linear1(a)
out = linear2(a1)
out.backward()
adam.step()
adam.clear_gradients()

def test_adamw_op(self):
paddle.enable_static()
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() \
else fluid.CPUPlace()
train_prog = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(train_prog, startup):
with fluid.unique_name.guard():
x = fluid.data(name='x', shape=[None, 10], dtype='float32')
y = fluid.data(name='y', shape=[None, 1], dtype='float32')

fc1 = fluid.layers.fc(input=x, size=32, act=None)
prediction = fluid.layers.fc(input=fc1, size=1, act=None)
cost = fluid.layers.square_error_cost(input=prediction, label=y)
avg_cost = fluid.layers.mean(cost)

simple_lr_fun = partial(
simple_lr_setting, decay_rate=0.8, n_layers=2)

beta1 = fluid.layers.create_global_var(
shape=[1], value=0.85, dtype='float32', persistable=True)
beta2 = fluid.layers.create_global_var(
shape=[1], value=0.95, dtype='float32', persistable=True)
betas = [beta1, beta2]
opt = paddle.optimizer.AdamW(
learning_rate=1e-5,
beta1=beta1,
beta2=beta2,
weight_decay=0.01,
epsilon=1e-8,
lr_ratio=simple_lr_fun)
opt.minimize(avg_cost)

exe = fluid.Executor(place)
exe.run(startup)
for _ in range(2):
inputs = np.random.random(size=[8, 10]).astype('float32')
outputs = np.random.random(size=[8, 1]).astype('float32')
rets = exe.run(train_prog,
feed={"x": inputs,
"y": outputs},
fetch_list=[avg_cost])
assert rets[0] is not None

paddle.disable_static()


if __name__ == "__main__":
unittest.main()
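As a side note on the test helper: simple_lr_setting returns decay_rate ** (n_layers + 2 - depth), so with the values used here (decay_rate=0.8, n_layers=2) parameters that fall into the else branch (depth 0) get the smallest ratio, and deeper-indexed parameters get ratios closer to 1. A quick, illustrative evaluation of that formula (not part of the test suite):

```python
# decay_rate ** (n_layers + 2 - depth) with decay_rate=0.8, n_layers=2
for depth in range(4):
    print(depth, round(0.8 ** (2 + 2 - depth), 4))
# 0 0.4096   (parameters matching neither name pattern)
# 1 0.512
# 2 0.64
# 3 0.8
```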