1 change: 1 addition & 0 deletions paddle/fluid/API.spec
@@ -424,6 +424,7 @@ paddle.fluid.contrib.HDFSClient.upload (ArgSpec(args=['self', 'hdfs_path', 'loca
paddle.fluid.contrib.multi_download (ArgSpec(args=['client', 'hdfs_path', 'local_path', 'trainer_id', 'trainers', 'multi_processes'], varargs=None, keywords=None, defaults=(5,)), ('document', '100927be598ed8f9eaa1f3ef1b23568a'))
paddle.fluid.contrib.multi_upload (ArgSpec(args=['client', 'hdfs_path', 'local_path', 'multi_processes', 'overwrite', 'sync'], varargs=None, keywords=None, defaults=(5, False, True)), ('document', '183f34c83d30dbe16e09e8716c41958a'))
paddle.fluid.contrib.extend_with_decoupled_weight_decay (ArgSpec(args=['base_optimizer'], varargs=None, keywords=None, defaults=None), ('document', 'a1095dfd4ec725747f662d69cd7659d4'))
paddle.fluid.contrib.mixed_precision.decorate (ArgSpec(args=['optimizer', 'init_loss_scaling', 'use_dynamic_loss_scaling'], varargs=None, keywords=None, defaults=(1.0, False)), ('document', '67e9bf14f345b38da169beb1ebb276eb'))
paddle.fluid.transpiler.DistributeTranspiler.__init__ (ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program (ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None), ('document', '292ab72977afbe58e6a3bde175452680'))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs (ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None), ('document', '78f4949aedf317666a89ca74b3748ba8'))
3 changes: 3 additions & 0 deletions python/paddle/fluid/contrib/__init__.py
@@ -34,6 +34,8 @@
from .extend_optimizer import *
from . import model_stat
from .model_stat import *
from . import mixed_precision
from .mixed_precision import *

__all__ = []
__all__ += decoder.__all__
@@ -45,3 +47,4 @@
__all__ += slim.__all__
__all__ += utils.__all__
__all__ += extend_optimizer.__all__
__all__ += ['mixed_precision']
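With these import lines and the extra `__all__` entry, the new package becomes reachable from `fluid.contrib`. A minimal sketch of the resulting import path (assuming a Paddle build that already contains this change):

import paddle.fluid as fluid

# 'mixed_precision' is the name added to fluid.contrib.__all__, so the
# entry point exposed to users is the submodule attribute:
decorate_fn = fluid.contrib.mixed_precision.decorate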
19 changes: 19 additions & 0 deletions python/paddle/fluid/contrib/mixed_precision/__init__.py
@@ -0,0 +1,19 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from . import decorator
from .decorator import *

__all__ = decorator.__all__
157 changes: 157 additions & 0 deletions python/paddle/fluid/contrib/mixed_precision/decorator.py
@@ -0,0 +1,157 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ... import default_main_program
from ... import default_startup_program
from ... import layers
from ... import unique_name
from . import fp16_utils
from .fp16_utils import create_master_params_grads, master_param_to_train_param

__all__ = ["decorate"]


class OptimizerWithMixedPrecison(object):
    """
    Optimizer with mixed-precision (MP) training. This is a wrapper of a common
    optimizer, plus the support of mixed-precision training. An object of this
    class has almost the same behavior as the common optimizer, with the
    methods `minimize()`, `backward()`, `apply_gradients()` implemented.
    Additionally, it enables MP training automatically, i.e., the creation
    and maintenance of master parameters, scaling of the loss, etc.

    Args:
        optimizer (Optimizer): A common Optimizer object.
Review comment (Contributor): -> A Optimizer object.
        init_loss_scaling (float): The initial loss scaling factor.
        use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
    """

    def __init__(self, optimizer, init_loss_scaling, use_dynamic_loss_scaling):
Review comment (Contributor): Need comments for input arguments. Same as below.
Review reply (Contributor Author): Done
        self._optimizer = optimizer
Review comment (Contributor): Please check the type of optimizer.
        self._param_grads = None
        self._train_program = default_main_program()
        self._startup_prog = default_startup_program()
        self._loss_scaling = init_loss_scaling
        self._use_dynamic_loss_scaling = use_dynamic_loss_scaling

        # Ensure the data type of learning rate vars is float32 (same as the
        # master parameter dtype)
        if isinstance(optimizer._learning_rate, float):
Review comment (Contributor): What will happen when the optimizer._learning_rate is not float?
            optimizer._learning_rate_map[default_main_program()] = \
                layers.create_global_var(
                    name=unique_name.generate("learning_rate"),
                    shape=[1],
                    value=float(optimizer._learning_rate),
                    dtype='float32',
                    persistable=True)

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor.
        """
        return self._loss_scaling

    def backward(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        """
        Backward propagation or auto differentiation for gradients' computation.

        Args:
            loss (Variable): The loss Variable to minimize.
            startup_program (Program|None): The startup Program for initializing
                parameters in `parameter_list`.
            parameter_list (list|None): A list of Variables to update.
            no_grad_set (set|None): A set of Variables that should be ignored.
            callbacks (list|None): A list of callables to run when appending
                backward operator for one parameter.

        Returns:
            A list of (param, grad), which is a tuple of a parameter and its
            gradient respectively, and the scaled loss.
Review comment (Contributor): Please explain the master_params?
"""
scaled_loss = loss * self._loss_scaling
self._param_grads = self._optimizer.backward(
scaled_loss, startup_program, parameter_list, no_grad_set,
callbacks)
master_params_grads = create_master_params_grads(
self._param_grads, self._train_program, self._startup_prog,
self._loss_scaling)

return master_params_grads, scaled_loss

def apply_gradients(self, master_params_grads):
"""
Update master parameters by their gradients, and cast to parameters
in float16.

Args:
master_params_grads (list): A list of master params and grads.

Returns:
A list of optimize operators.
"""
optimize_ops = self._optimizer.apply_gradients(master_params_grads)
master_param_to_train_param(master_params_grads, self._param_grads,
self._train_program)
return optimize_ops

def minimize(self, loss):
"""
Perform optimization by minimizing the given loss.

Args:
loss (Variable): The loss Variable.

Returns:
The scaled loss by scaling factor, the list of optimize ops, and a
list of master parameters and gradients.
"""
master_params_grads, scaled_loss = self.backward(loss)
optimize_ops = self.apply_gradients(master_params_grads)

return scaled_loss, optimize_ops, master_params_grads


def decorate(optimizer, init_loss_scaling=1.0, use_dynamic_loss_scaling=False):
Review comment (Contributor): How about extend_optimizer_with_mixed_precison?
"""
Decorate the given optimizer to adapt to the mixed-precision training.

Args:
optimizer(Optimizer): A common Optimizer.
init_loss_scaling(float): The initial loss scaling factor.
use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
Review comment (Contributor): Please give the default value.

    Returns:
        An optimizer acting like a normal one but with mixed-precision training
        enabled.

    Examples:
        .. code-block:: python

            loss = network()
            optimizer = fluid.optimizer.Adam(learning_rate=0.001)

            mp_optimizer = fluid.contrib.mixed_precision.decorate(
                optimizer=optimizer, init_loss_scaling=8.0)

            scaled_loss, _, _ = mp_optimizer.minimize(loss)
    """

    mp_optimizer = OptimizerWithMixedPrecison(optimizer, init_loss_scaling,
                                              use_dynamic_loss_scaling)

    return mp_optimizer
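The `decorate()` docstring above shows the basic call; below is a slightly fuller sketch of how the returned wrapper is meant to be driven in a training script. The network builder `build_fp16_net()` and the feed/fetch details are hypothetical placeholders, not part of this PR:

import paddle.fluid as fluid

# build_fp16_net() is a hypothetical helper that builds a network whose
# parameters are created in float16 and returns the (unscaled) loss Variable.
loss = build_fp16_net()

optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
mp_optimizer = fluid.contrib.mixed_precision.decorate(
    optimizer=optimizer, init_loss_scaling=128.0)

# minimize() returns the *scaled* loss first, then the optimize ops and the
# (master_param, master_grad) pairs.
scaled_loss, optimize_ops, master_params_grads = mp_optimizer.minimize(loss)

exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())

# Fetch the scaled loss and divide by the current scaling factor to log the
# true loss value (feeding of input data is omitted here for brevity).
fetched_loss, = exe.run(fluid.default_main_program(), fetch_list=[scaled_loss])
true_loss = fetched_loss / mp_optimizer.get_loss_scaling()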
125 changes: 125 additions & 0 deletions python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
@@ -0,0 +1,125 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from ... import core
from ... import layers
from ... import framework


def append_cast_op(i, o, prog):
"""
Append a cast op in a given Program to cast input `i` to data type `o.dtype`.

Args:
i (Variable): The input Variable.
o (Variable): The output Variable.
prog (Program): The Program to append cast op.
"""
prog.global_block().append_op(
type="cast",
inputs={"X": i},
outputs={"Out": o},
attrs={"in_dtype": i.dtype,
"out_dtype": o.dtype})


def copy_to_master_param(p, block):
Review comment (Contributor): Need comments for all these APIs.
Review reply (Contributor Author): Done
"""
New a master parameter for the input parameter, and they two share the same
attributes except the data type.

Args:
p(Parameter): The input parameter in float16.
block(Program): The block in which the parameter is.
"""
v = block.vars.get(p.name, None)
if v is None:
raise ValueError("no param name %s found!" % p.name)
new_p = framework.Parameter(
block=block,
shape=v.shape,
dtype=core.VarDesc.VarType.FP32,
Review comment (Contributor): Why must the dtype be FP32 here? Is it possible that the dtype is fp64?
Review reply (Contributor Author): In principle, the master parameter can be float64. But there are some hard-coded implementations, and fp64 support does not seem to be that straightforward. So we are going to support only float32 for now, because it is more commonly used. We may revisit fp64 some day in the future.

        type=v.type,
        lod_level=v.lod_level,
        stop_gradient=p.stop_gradient,
        trainable=p.trainable,
        optimize_attr=p.optimize_attr,
        regularizer=p.regularizer,
        gradient_clip_attr=p.gradient_clip_attr,
        error_clip=p.error_clip,
        name=v.name + ".master")
    return new_p


def create_master_params_grads(params_grads, main_prog, startup_prog,
                               loss_scaling):
"""
Create master parameters and gradients in float32 from params and grads
in float16.

Args:
params_grads (list): A list of tuple (parameter, gradient) in float32.
main_prog (Program): The main program for training.
startup_prog (Program): The startup program to initialize all parameters.
loss_scaling (float): The factor to scale loss and gradients.

Returns:
A list of master parameters and gradients.
"""
master_params_grads = []
with main_prog._backward_role_guard():
for p, g in params_grads:
# create master parameters
master_param = copy_to_master_param(p, main_prog.global_block())
startup_master_param = startup_prog.global_block()._clone_variable(
master_param)
startup_p = startup_prog.global_block().var(p.name)
# fp16 -> fp32
append_cast_op(startup_p, startup_master_param, startup_prog)
# cast fp16 gradients to fp32 before apply gradients
if g.name.find("batch_norm") > -1:
Review comment (@gongweibao, Apr 22, 2019): Can we add an attribute to these operator descs instead of hard-coding the name check?
Review reply (Contributor Author): As a next step, we can optimize such hard-coded checks.
Review comment (Contributor): Agreed with @gongweibao, maybe you should get the op that generates g.

                if loss_scaling > 1:
                    scaled_g = g / float(loss_scaling)
                else:
                    scaled_g = g
                master_params_grads.append([p, scaled_g])
                continue
            master_grad = layers.cast(x=g, dtype="float32")
            if loss_scaling > 1:
                master_grad = master_grad / float(loss_scaling)
            master_params_grads.append([master_param, master_grad])

    return master_params_grads


def master_param_to_train_param(master_params_grads, params_grads, main_prog):
"""
Convert master master parameters and gradients in float32 to parameters and
gradients in float16 for forward computation.

Args:
master_params_grads (list): A list of master parameters and gradients in
float32.
params_grads (list): A list of parameters and gradients in float16.
main_prog (list): The main program for execution.
"""
for idx, m_p_g in enumerate(master_params_grads):
train_p, _ = params_grads[idx]
if train_p.name.find("batch_norm") > -1:
continue
with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
# fp32 -> fp16
append_cast_op(m_p_g[0], train_p, main_prog)
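For context on why `backward()` in decorator.py multiplies the loss and `create_master_params_grads()` later divides each fp32 master gradient by `loss_scaling`: very small fp16 gradients underflow to zero, and scaling the loss shifts all gradients into fp16's representable range before they are cast back and un-scaled in fp32. A small self-contained numpy illustration (not part of the PR; the concrete values are only for demonstration):

import numpy as np

# A gradient this small underflows to zero when stored in float16.
tiny_grad = np.float32(1e-8)
print(np.float16(tiny_grad))        # 0.0

# Scaling the loss by a factor such as 32768 scales every gradient by the
# same factor, which moves it into float16's normal range.
loss_scaling = 32768.0
scaled = np.float16(tiny_grad * loss_scaling)
print(scaled)                       # ~3.277e-04, representable in fp16

# create_master_params_grads() casts the gradient to float32 and divides by
# the same loss_scaling before the optimizer applies it, recovering the
# original magnitude.
recovered = np.float32(scaled) / loss_scaling
print(recovered)                    # ~1e-08, close to tiny_grad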