Skip to content

Commit fb4d568

Browse files
authored
Support EMA in Paddle2.x and Fleet (#35673)
* Support EMA in Paddle2.x and Fleet * update * update * update * modify ut of ema * modify docs * modify bugs * update * update * update * modify ut
1 parent 177bf52 commit fb4d568

File tree

5 files changed

+151
-53
lines changed

5 files changed

+151
-53
lines changed

python/paddle/fluid/optimizer.py

Lines changed: 49 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3959,62 +3959,59 @@ class ExponentialMovingAverage(object):
39593959
39603960
39613961
Args:
3962-
decay (float, optional): The exponential decay rate, usually close to 1, such as
3963-
0.999, 0.9999, ... . Default 0.999.
3964-
thres_steps (Variable|None): If not `None`, schedule the decay rate.
3965-
Default None.
3966-
name (str|None): For detailed information, please refer to
3967-
:ref:`api_guide_Name`. Usually name is no need to set and None by
3968-
default.
3962+
decay (float, optional): The exponential decay rate, usually close to 1, such as 0.999, 0.9999, ... . Default 0.999.
3963+
thres_steps (Variable|None, optional): If not `None`, schedule the decay rate. Default None.
3964+
name (str|None, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually name is no need to set and None by default.
39693965
39703966
39713967
Examples:
39723968
3973-
.. code-block:: python
3974-
3975-
import numpy
3976-
import paddle
3977-
import paddle.fluid as fluid
3978-
3979-
data = fluid.data(name='x', shape=[-1, 5], dtype='float32')
3980-
hidden = fluid.layers.fc(input=data, size=10)
3981-
cost = fluid.layers.mean(hidden)
3982-
3983-
test_program = fluid.default_main_program().clone(for_test=True)
3984-
3985-
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
3986-
optimizer.minimize(cost)
3987-
3988-
global_steps = fluid.layers.autoincreased_step_counter()
3989-
ema = fluid.optimizer.ExponentialMovingAverage(0.999, thres_steps=global_steps)
3990-
ema.update()
3991-
3992-
place = fluid.CPUPlace()
3993-
exe = fluid.Executor(place)
3994-
exe.run(fluid.default_startup_program())
3995-
3996-
for pass_id in range(3):
3997-
for batch_id in range(6):
3998-
data = numpy.random.random(size=(10, 5)).astype('float32')
3999-
exe.run(program=fluid.default_main_program(),
4000-
feed={'x': data},
4001-
fetch_list=[cost.name])
4002-
4003-
# usage 1
4004-
with ema.apply(exe):
4005-
data = numpy.random.random(size=(10, 5)).astype('float32')
4006-
exe.run(program=test_program,
4007-
feed={'x': data},
4008-
fetch_list=[hidden.name])
4009-
4010-
4011-
# usage 2
4012-
with ema.apply(exe, need_restore=False):
4013-
data = numpy.random.random(size=(10, 5)).astype('float32')
4014-
exe.run(program=test_program,
4015-
feed={'x': data},
4016-
fetch_list=[hidden.name])
4017-
ema.restore(exe)
3969+
.. code-block:: python
3970+
3971+
import numpy
3972+
import paddle
3973+
import paddle.static as static
3974+
from paddle.static import ExponentialMovingAverage
3975+
3976+
paddle.enable_static()
3977+
3978+
data = static.data(name='x', shape=[-1, 5], dtype='float32')
3979+
hidden = static.nn.fc(x=data, size=10)
3980+
cost = paddle.mean(hidden)
3981+
3982+
test_program = static.default_main_program().clone(for_test=True)
3983+
optimizer = paddle.optimizer.Adam(learning_rate=0.001)
3984+
optimizer.minimize(cost)
3985+
3986+
ema = ExponentialMovingAverage(0.999)
3987+
ema.update()
3988+
3989+
place = paddle.CPUPlace()
3990+
exe = static.Executor(place)
3991+
exe.run(static.default_startup_program())
3992+
3993+
for pass_id in range(3):
3994+
for batch_id in range(6):
3995+
data = numpy.random.random(size=(10, 5)).astype('float32')
3996+
exe.run(program=static.default_main_program(),
3997+
feed={'x': data},
3998+
fetch_list=[cost.name])
3999+
4000+
# usage 1
4001+
with ema.apply(exe):
4002+
data = numpy.random.random(size=(10, 5)).astype('float32')
4003+
exe.run(program=test_program,
4004+
feed={'x': data},
4005+
fetch_list=[hidden.name])
4006+
4007+
# usage 2
4008+
with ema.apply(exe, need_restore=False):
4009+
data = numpy.random.random(size=(10, 5)).astype('float32')
4010+
exe.run(program=test_program,
4011+
feed={'x': data},
4012+
fetch_list=[hidden.name])
4013+
ema.restore(exe)
4014+
40184015
"""
40194016

40204017
def __init__(self, decay=0.999, thres_steps=None, name=None):
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import unittest
18+
import numpy as np
19+
import paddle
20+
import paddle.utils as utils
21+
import paddle.static as static
22+
23+
24+
def gen_data():
25+
return np.random.random(size=(10, 5)).astype('float32')
26+
27+
28+
class TestFleetStaticEMA(unittest.TestCase):
29+
def setUp(self):
30+
self._places = [paddle.CPUPlace()]
31+
if paddle.device.is_compiled_with_cuda():
32+
self._places.append(paddle.CUDAPlace(0))
33+
self._ema_decay = 0.999
34+
self._param_name = "fc.weight"
35+
self._train_program = static.Program()
36+
self._startup_prog = static.Program()
37+
38+
strategy = paddle.distributed.fleet.DistributedStrategy()
39+
strategy.without_graph_optimization = True
40+
paddle.distributed.fleet.init(is_collective=True, strategy=strategy)
41+
42+
with static.program_guard(self._train_program, self._startup_prog):
43+
with utils.unique_name.guard():
44+
data = static.data(name='x', shape=[-1, 5], dtype='float32')
45+
hidden = static.nn.fc(x=data,
46+
size=10,
47+
weight_attr=self._param_name)
48+
cost = paddle.mean(hidden)
49+
50+
self._test_program = static.default_main_program().clone(
51+
for_test=True)
52+
53+
optimizer = paddle.optimizer.Adam(learning_rate=0.001)
54+
optimizer = paddle.distributed.fleet.distributed_optimizer(
55+
optimizer, strategy)
56+
optimizer.minimize(cost)
57+
58+
self._ema = static.ExponentialMovingAverage(self._ema_decay)
59+
self._ema.update()
60+
61+
def train(self, place, restore):
62+
exe = static.Executor(place)
63+
exe.run(self._startup_prog)
64+
65+
params = []
66+
for pass_id in range(2):
67+
for batch_id in range(3):
68+
exe.run(program=self._train_program, feed={'x': gen_data()})
69+
tmp_param = np.array(static.global_scope().find_var(
70+
self._param_name).get_tensor())
71+
params.append(tmp_param)
72+
73+
with self._ema.apply(exe, restore):
74+
final_ema = np.array(static.global_scope().find_var(
75+
self._param_name).get_tensor())
76+
exe.run(program=self._test_program, feed={'x': gen_data()})
77+
if not restore:
78+
self._ema.restore(exe)
79+
80+
return params, final_ema
81+
82+
def test_check_ema(self):
83+
for place in self._places:
84+
for restore in (True, False):
85+
params, final_ema = self.train(place, restore)
86+
manu_ema = np.zeros_like(final_ema)
87+
if len(params) > 0:
88+
for param in params:
89+
manu_ema = self._ema_decay * manu_ema + (
90+
1 - self._ema_decay) * param
91+
manu_ema = manu_ema / (1.0 - self._ema_decay**len(params))
92+
self.assertTrue(np.allclose(manu_ema, final_ema))
93+
94+
95+
if __name__ == '__main__':
96+
paddle.enable_static()
97+
unittest.main()

python/paddle/static/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from ..fluid.layers.nn import py_func # noqa: F401
4949
from ..fluid.parallel_executor import ParallelExecutor # noqa: F401
5050
from ..fluid.param_attr import WeightNormParamAttr # noqa: F401
51+
from ..fluid.optimizer import ExponentialMovingAverage # noqa: F401
5152
from ..fluid.io import save # noqa: F401
5253
from ..fluid.io import load # noqa: F401
5354
from ..fluid.io import load_program_state # noqa: F401
@@ -76,6 +77,7 @@
7677
'ParallelExecutor',
7778
'program_guard',
7879
'WeightNormParamAttr',
80+
'ExponentialMovingAverage',
7981
'default_main_program',
8082
'default_startup_program',
8183
'Program',

tools/parallel_UT_rule.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@
528528
'test_trunc_op', 'test_bernoulli_op', 'test_custom_relu_model',
529529
'test_backward', 'test_conv3d_transpose_part2_op', 'test_complex_transpose',
530530
'test_memory_reuse_exclude_feed_var', 'test_polygon_box_transform',
531-
'math_function_gpu_test', 'test_program_prune_backward',
531+
'math_function_gpu_test', 'test_program_prune_backward', 'test_ema_fleet',
532532
'test_fleet_amp_init', 'test_normalize', 'test_correlation',
533533
'test_conv_elementwise_add2_act_fuse_pass',
534534
'test_imperative_container_layerlist', 'test_dequantize_abs_max_op',
@@ -1324,6 +1324,7 @@
13241324
'test_slice_op',
13251325
'test_cond',
13261326
'test_ema',
1327+
'test_ema_fleet',
13271328
'test_nan_inf',
13281329
'test_isinstance',
13291330
'test_box_clip_op',

tools/static_mode_white_list.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@
173173
'test_elementwise_nn_grad',
174174
'test_elementwise_pow_op',
175175
'test_ema',
176+
'test_ema_fleet',
176177
'test_embedding_id_stop_gradient',
177178
'test_empty_like_op',
178179
'test_entry_attr',

0 commit comments

Comments
 (0)