
Commit 4d805e6

multi precision for lars op and lars optimizer (#33280)
1 parent fc5b3a9 commit 4d805e6

File tree: 6 files changed, +271 / -55 lines

paddle/fluid/operators/optimizers/lars_momentum_op.cc

File mode changed from 100755 to 100644
Lines changed: 14 additions & 0 deletions
@@ -34,13 +34,18 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("LearningRate",
              "(LoDTensor, default LoDTensor<float>) "
              "Input learning rate");
+    AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
 
     AddOutput("ParamOut",
               "(LoDTensor) This output is updated parameter. "
               "It shared memory with Input(Param).");
     AddOutput("VelocityOut",
               "(LoDTensor) This output is updated velocity. "
               "It shared memory with Input(Velocity).");
+    AddOutput("MasterParamOut",
+              "The updated FP32 master weight for AMP. "
+              "It shared memory with Input(MasterParam).")
+        .AsDispensable();
 
     AddAttr<float>("mu", "(float) Momentum coefficient");
     AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
@@ -51,6 +56,15 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<float>("epsilon",
                    "(float, default 0.0) epsilon to avoid Division by Zero.")
         .SetDefault(0.0);
+    AddAttr<bool>("multi_precision",
+                  "(bool, default false) "
+                  "Whether to use multi-precision during weight updating.")
+        .SetDefault(false);
+    AddAttr<float>(
+        "rescale_grad",
+        "(float, default 1.0) Multiply the gradient with `rescale_grad`"
+        "before updating. Often choose to be `1.0/batch_size`.")
+        .SetDefault(1.0f);
 
     AddComment(R"DOC(
 Lars Momentum Optimizer.
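For reference, the update that these new inputs and attributes describe can be written out in NumPy. This is a minimal sketch only, not the operator implementation; the function name lars_momentum_step and its arguments are illustrative.

import numpy as np

def lars_momentum_step(param, grad, velocity, lr, mu, lars_coeff=0.001,
                       lars_weight_decay=0.0005, epsilon=0.0,
                       rescale_grad=1.0, master_param=None):
    # With multi_precision, the math runs on the FP32 master copy and the
    # FP16 param is refreshed from it (MasterParam / MasterParamOut).
    p = master_param if master_param is not None else param.astype(np.float32)
    g = grad.astype(np.float32) * rescale_grad

    # LARS trust ratio: scale the learning rate by ||p|| / (||g|| + wd*||p||).
    p_norm = np.sqrt((p * p).sum())
    g_norm = np.sqrt((g * g).sum())
    local_lr = lr
    if lars_weight_decay > 0 and p_norm > 0 and g_norm > 0:
        local_lr = lr * lars_coeff * p_norm / (
            g_norm + lars_weight_decay * p_norm + epsilon)

    v_new = mu * velocity + local_lr * (g + lars_weight_decay * p)
    p_new = p - v_new
    # Returns: param in its original dtype, velocity, and the master weight.
    return p_new.astype(param.dtype), v_new, p_new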

paddle/fluid/operators/optimizers/lars_momentum_op.cu

Lines changed: 89 additions & 30 deletions
@@ -13,55 +13,105 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
 
 namespace paddle {
 namespace operators {
 
 template <typename T>
-__global__ void MomentumLarsKernel(const T* p, const T* g, const T* v,
-                                   const T* learning_rate, const T mu,
-                                   const int64_t num, const T lars_coeff,
-                                   const T lars_weight_decay, const T* p_norm,
-                                   const T* g_norm, T* p_out, T* v_out,
-                                   const T epsilon) {
-  T lr = learning_rate[0];
-  T local_lr = learning_rate[0];
+using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
+
+template <typename T, typename MT>
+__global__ void MomentumLarsKernel(
+    const T* p, const T* g, const MT* v,
+    const MultiPrecisionType<T>* learning_rate, const MT mu, const int64_t num,
+    const MT lars_coeff, const MT lars_weight_decay,
+    const MultiPrecisionType<T>* p_norm, const MultiPrecisionType<T>* g_norm,
+    T* p_out, MT* v_out, const MT epsilon, const MT* master_p,
+    MT* master_p_out, const MultiPrecisionType<T> rescale_grad) {
+  const MT lr = static_cast<MT>(learning_rate[0]);
+  MT local_lr = lr;
+  const MT p_n = static_cast<MT>(p_norm[0]);
+  const MT g_n = static_cast<MT>(g_norm[0]);
+
+  if (lars_weight_decay > static_cast<MT>(0) && p_n > static_cast<MT>(0) &&
+      g_n > static_cast<MT>(0)) {
+    local_lr =
+        lr * lars_coeff * p_n / (g_n + lars_weight_decay * p_n + epsilon);
+  }
   CUDA_KERNEL_LOOP(i, num) {
-    if (lars_weight_decay > 0 && p_norm[0] > 0 && g_norm[0] > 0) {
-      local_lr = lr * lars_coeff * p_norm[0] /
-                 (g_norm[0] + lars_weight_decay * p_norm[0] + epsilon);
-    }
+    MT grad = static_cast<MT>(g[i]) * static_cast<MT>(rescale_grad);
+    MT param = master_p ? master_p[i] : static_cast<MT>(p[i]);
+
+    MT v_new = v[i] * mu + local_lr * (grad + lars_weight_decay * param);
+    MT p_new = param - v_new;
 
-    T v_new = v[i] * mu + local_lr * (g[i] + lars_weight_decay * p[i]);
     v_out[i] = v_new;
-    p_out[i] = p[i] - v_new;
+    p_out[i] = static_cast<T>(p_new);
+    if (master_p_out) master_p_out[i] = p_new;
   }
 }
 
 template <typename DeviceContext, typename T>
 class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
+  using MPDType = MultiPrecisionType<T>;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    const bool multi_precision = ctx.Attr<bool>("multi_precision");
+    if (multi_precision) {
+      InnerCompute<MPDType>(ctx, multi_precision);
+    } else {
+      InnerCompute<T>(ctx, multi_precision);
+    }
+  }
+
+ private:
+  template <typename MT>
+  void InnerCompute(const framework::ExecutionContext& ctx,
+                    const bool multi_precision) const {
     auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
     auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
     auto param = ctx.Input<framework::LoDTensor>("Param");
     auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
     auto grad = ctx.Input<framework::LoDTensor>("Grad");
     auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
 
+    const framework::Tensor* master_param = nullptr;
+    framework::Tensor* master_param_out = nullptr;
+    if (multi_precision) {
+      bool has_master =
+          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
+      PADDLE_ENFORCE_EQ(has_master, true,
+                        platform::errors::InvalidArgument(
+                            "The Input(MasterParam) and Output(MasterParamOut) "
+                            "should not be null when "
+                            "the attr `multi_precision` is true"));
+      master_param = ctx.Input<framework::Tensor>("MasterParam");
+      master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
+    }
+
+    const MT* master_p = multi_precision ? master_param->data<MT>() : nullptr;
+    MT* master_p_out = multi_precision
+                           ? master_param_out->mutable_data<MT>(ctx.GetPlace())
+                           : nullptr;
+
     T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
-    T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
+    MT* v_out = velocity_out->mutable_data<MT>(ctx.GetPlace());
 
-    T mu = static_cast<T>(ctx.Attr<float>("mu"));
-    T lars_coeff = ctx.Attr<float>("lars_coeff");
-    T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
-    T epsilon = ctx.Attr<float>("epsilon");
+    MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
+    MT lars_coeff = static_cast<MT>(ctx.Attr<float>("lars_coeff"));
+    MT lars_weight_decay =
+        static_cast<MT>(ctx.Attr<float>("lars_weight_decay"));
+    MT epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
+    MPDType rescale_grad =
+        static_cast<MPDType>(ctx.Attr<float>("rescale_grad"));
 
     auto* p = param->data<T>();
-    auto* v = velocity->data<T>();
     auto* g = grad->data<T>();
-    auto* lr = learning_rate->data<T>();
+    auto* v = velocity->data<MT>();
+    auto* lr = learning_rate->data<MPDType>();
 
     int block = 512;
     int grid = (param->numel() + block - 1) / block;
@@ -72,17 +122,24 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
     framework::Tensor p_norm_t, g_norm_t;
     p_norm_t.Resize({1});
     g_norm_t.Resize({1});
-    auto* p_norm_data = p_norm_t.mutable_data<T>(ctx.GetPlace());
-    auto* g_norm_data = g_norm_t.mutable_data<T>(ctx.GetPlace());
-    auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
-    auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
+    auto* p_norm_data = p_norm_t.mutable_data<MPDType>(ctx.GetPlace());
+    auto* g_norm_data = g_norm_t.mutable_data<MPDType>(ctx.GetPlace());
+    auto ep_norm = framework::EigenScalar<MPDType>::From(p_norm_t);
+    auto eg_norm = framework::EigenScalar<MPDType>::From(g_norm_t);
 
     auto* place = ctx.template device_context<DeviceContext>().eigen_device();
-    ep_norm.device(*place) = eigen_p.square().sum().sqrt();
-    eg_norm.device(*place) = eigen_g.square().sum().sqrt();
-    MomentumLarsKernel<<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
+
+    // Eigen does not support an FP16 l2-norm, so cast to MPDType first.
+    ep_norm.device(*place) =
+        eigen_p.template cast<MPDType>().square().sum().sqrt();
+    eg_norm.device(*place) =
+        (eigen_g.template cast<MPDType>() * rescale_grad).square().sum().sqrt();
+
+    MomentumLarsKernel<
+        T, MT><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
         p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay,
-        p_norm_data, g_norm_data, p_out, v_out, epsilon);
+        p_norm_data, g_norm_data, p_out, v_out, epsilon, master_p,
+        master_p_out, rescale_grad);
   }
 };
 
@@ -93,4 +150,6 @@ namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     lars_momentum,
     ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                  paddle::platform::float16>);

paddle/fluid/operators/optimizers/momentum_op.h

Lines changed: 3 additions & 0 deletions
@@ -135,6 +135,9 @@ class MomentumOp : public framework::OperatorWithKernel {
 
     ctx->SetOutputDim("ParamOut", param_dim);
     ctx->SetOutputDim("VelocityOut", param_dim);
+    if (ctx->HasOutput("MasterParamOut")) {
+      ctx->SetOutputDim("MasterParamOut", param_dim);
+    }
   }
 
   framework::OpKernelType GetExpectedKernelType(

python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py

Lines changed: 15 additions & 7 deletions
@@ -73,7 +73,7 @@ def layer_warp(block_func, input, ch_in, ch_out, count, stride):
     return pool
 
 
-def train(use_pure_fp16=True, use_nesterov=False, use_adam=False):
+def train(use_pure_fp16=True, use_nesterov=False, optimizer=""):
     classdim = 10
     data_shape = [3, 32, 32]
     BATCH_SIZE = 32
@@ -96,12 +96,17 @@ def train(use_pure_fp16=True, use_nesterov=False, use_adam=False):
     # Test program
     test_program = train_program.clone(for_test=True)
 
-    if use_adam:
+    if optimizer == "Adam":
         optimizer = paddle.optimizer.AdamW(
             learning_rate=0.001,
             epsilon=1e-8,
             weight_decay=0.0,
             multi_precision=True)
+    elif optimizer == "Lars":
+        optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
+            learning_rate=0.001,
+            momentum=0.9,
+            multi_precision=use_pure_fp16)
     else:
         optimizer = paddle.optimizer.Momentum(
             learning_rate=0.001,
@@ -169,9 +174,11 @@ def test_resnet_pure_fp16(self):
         if not fluid.core.is_compiled_with_cuda():
             return
 
-        def do_test(use_nesterov=False, use_adam=False):
-            if use_adam:
+        def do_test(use_nesterov=False, optimizer=""):
+            if optimizer == "Adam":
                 suffix = "use Adam"
+            elif optimizer == "Lars":
+                suffix = "use Lars"
             else:
                 suffix = "with Nesterov" if use_nesterov else "without Nesterov"
             with self.scope_prog_guard():
@@ -180,14 +187,14 @@ def do_test(use_nesterov=False, use_adam=False):
                 train_loss_fp16, test_loss_fp16 = train(
                     use_pure_fp16=True,
                     use_nesterov=use_nesterov,
-                    use_adam=use_adam)
+                    optimizer=optimizer)
             with self.scope_prog_guard():
                 print("-----------------FP32 Train {}-----------------".format(
                     suffix))
                 train_loss_fp32, test_loss_fp32 = train(
                     use_pure_fp16=False,
                     use_nesterov=use_nesterov,
-                    use_adam=use_adam)
+                    optimizer=optimizer)
 
                 self.assertTrue(
                     np.allclose(
@@ -208,7 +215,8 @@ def do_test(use_nesterov=False, use_adam=False):
 
         do_test(use_nesterov=False)
         do_test(use_nesterov=True)
-        do_test(use_adam=True)
+        do_test(optimizer="Adam")
+        do_test(optimizer="Lars")
 
     @contextlib.contextmanager
     def scope_prog_guard(self):
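Outside the test harness, the new flag is used the same way. Below is a minimal static-graph sketch assuming the fluid APIs used elsewhere in this test file; the toy FP32 network and variable names are illustrative only. The test itself trains in pure FP16, and multi_precision only takes effect for FP16 parameters, for which the optimizer keeps an FP32 master copy fed to the lars_momentum op as MasterParam/MasterParamOut.

import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name="x", shape=[None, 8], dtype="float32")
    y = fluid.data(name="y", shape=[None, 1], dtype="float32")
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square(pred - y))

    # New in this commit: multi_precision support in LarsMomentumOptimizer.
    optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
        learning_rate=0.001, momentum=0.9, multi_precision=True)
    optimizer.minimize(loss)

place = (fluid.CUDAPlace(0)
         if fluid.core.is_compiled_with_cuda() else fluid.CPUPlace())
exe = fluid.Executor(place)
exe.run(startup_prog)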
