Skip to content

Init mixed precision training interface#16856

Merged
kuke merged 9 commits intoPaddlePaddle:developfrom
kuke:mixed_precision_init
Apr 25, 2019
Merged

Init mixed precision training interface#16856
kuke merged 9 commits intoPaddlePaddle:developfrom
kuke:mixed_precision_init

Conversation

@kuke
Copy link
Contributor

@kuke kuke commented Apr 15, 2019

Simple Usage:

  1. Cast specific inputs to float16 (Q: Should we do it automatically?)
  imgs = fluid.layers.cast(images, "float16")
  1. Decorate optimizer and minimize scaled loss
 optimizer = fluid.optimizer.Adam(learning_rate=0.001)
	
 mp_optimizer = fluid.contrib.mixed_precision.decorate(
	            optimizer=optimizer, init_loss_scaling=8.0)
	
 scaled_loss, _, _ = mp_optimizer.minimize(avg_cost)

See test_image_classification_fp16.py for details.

@CLAassistant
Copy link

CLAassistant commented Apr 15, 2019

CLA assistant check
All committers have signed the CLA.

@kuke kuke requested review from chengduoZH and typhoonzero April 15, 2019 04:20
@kuke kuke force-pushed the mixed_precision_init branch from 409a90b to 4636605 Compare April 15, 2019 04:23
@kuke kuke requested a review from qingqing01 April 15, 2019 04:24
test=develop


class OptimizerWithMixedPrecison(object):
def __init__(self, optimizer, init_loss_scaling, use_dynamic_loss_scaling):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need comments for input arguments. Same as below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(
['truncated_gaussian_random', 'tmp'])),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use var.name as part of prefix for unique_name.generate a new name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

inputs={"X": out_var},
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wether the other initializer needs to do this FP32 weight creating and casting?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. They all need. And all initializers have been made to support float16.

})


def copy_to_master_param(p, block):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need comments for all tehse APIs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@kuke kuke requested review from chenwhql and gongweibao and removed request for chenwhql April 15, 2019 06:15
master_params_grads = []
tmp_role = main_prog._current_role
OpRole = core.op_proto_and_checker_maker.OpRole
main_prog._current_role = OpRole.Backward
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use _backward_role_guard here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

new_p = framework.Parameter(
block=block,
shape=v.shape,
dtype=core.VarDesc.VarType.FP32,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the dtype must be FP32 here? Is it possible that the dtype is fp64?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle, the master parameter can be float64. But there are some hard-coded implementations, and the fp64 support seems not to be that straightforward. So we are going to only support float32 temporarily because it is more common used. Maybe we can go back to consider fp64 some day in the future.

startup_master_param = startup_prog.global_block()._clone_variable(
master_param)
startup_p = startup_prog.global_block().var(p.name)
cast_fp16_to_fp32(startup_p, startup_master_param, startup_prog)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should check that p.type is not fp32.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unified the two cast functions.

startup_master_param = startup_prog.global_block()._clone_variable(
master_param)
startup_p = startup_prog.global_block().var(p.name)
cast_fp16_to_fp32(startup_p, startup_master_param, startup_prog)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should check that p.type is not fp32 here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same as above

@kuke kuke force-pushed the mixed_precision_init branch from 15b00d6 to 9456cef Compare April 21, 2019 09:14
@kuke kuke force-pushed the mixed_precision_init branch from 5f2faba to c2fa295 Compare April 21, 2019 17:26
test=develop
@kuke kuke force-pushed the mixed_precision_init branch from 60bdfb7 to b2d80ea Compare April 21, 2019 18:02
# fp16 -> fp32
append_cast_op(startup_p, startup_master_param, startup_prog)
# cast fp16 gradients to fp32 before apply gradients
if g.name.find("batch_norm") > -1:
Copy link
Contributor

@gongweibao gongweibao Apr 22, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add attribute to these operator desc instead of hard code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Next step, wo can optimize such hard code

paddle.fluid.contrib.multi_download (ArgSpec(args=['client', 'hdfs_path', 'local_path', 'trainer_id', 'trainers', 'multi_processes'], varargs=None, keywords=None, defaults=(5,)), ('document', '100927be598ed8f9eaa1f3ef1b23568a'))
paddle.fluid.contrib.multi_upload (ArgSpec(args=['client', 'hdfs_path', 'local_path', 'multi_processes', 'overwrite', 'sync'], varargs=None, keywords=None, defaults=(5, False, True)), ('document', '183f34c83d30dbe16e09e8716c41958a'))
paddle.fluid.contrib.extend_with_decoupled_weight_decay (ArgSpec(args=['base_optimizer'], varargs=None, keywords=None, defaults=None), ('document', 'a1095dfd4ec725747f662d69cd7659d4'))
paddle.fluid.contrib.decorate (ArgSpec(args=['optimizer', 'init_loss_scaling', 'use_dynamic_loss_scaling'], varargs=None, keywords=None, defaults=(1.0, False)), ('document', '089f0c8d7c03bd3d0edc3ac83dbe41fd'))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name(decorate) is not explicit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed


class OptimizerWithMixedPrecison(object):
"""
Optimizer class with mixed-precision training.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should introduce the implementation of OptimizerWithMixedPrecison detailly and give an example here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added more details

@@ -0,0 +1,301 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is duplicated?
Can add arguments to the exists unittest instead of a new one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'Cause this interface is in contrib, we'd better use a seperate unit test file

out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(
['constant_init', var.name, 'tmp'])),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where are these vars used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the startup program

@kuke kuke force-pushed the mixed_precision_init branch from 1598ea7 to e3b4499 Compare April 25, 2019 03:36
and maintenance of master parameters, scaling of loss, etc.

Args:
optimizer (Optimizer): A common Optimizer object.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-> A Optimizer object.


# Ensure the data type of learning rate vars is float32 (same as the
# master parameter dtype)
if isinstance(optimizer._learning_rate, float):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will happen when the optimizer._learning_rate is not float?

"""

def __init__(self, optimizer, init_loss_scaling, use_dynamic_loss_scaling):
self._optimizer = optimizer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check the type of optimizer.

# fp16 -> fp32
append_cast_op(startup_p, startup_master_param, startup_prog)
# cast fp16 gradients to fp32 before apply gradients
if g.name.find("batch_norm") > -1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed with @gongweibao, maybe you should get the op that generates g.

return scaled_loss, optimize_ops, master_params_grads


def decorate(optimizer, init_loss_scaling=1.0, use_dynamic_loss_scaling=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about extend_optimizer_with_mixed_precison?

Args:
optimizer(Optimizer): A common Optimizer.
init_loss_scaling(float): The initial loss scaling factor.
use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please give the default value.


Returns:
A list of (param, grad), which is a tuple of a parameter and its
gradient respectively, and the scaled loss.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please explain the master_params?

Copy link
Contributor

@chengduoZH chengduoZH left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code should be polished.

Copy link
Contributor

@chengduoZH chengduoZH left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code should be polished.

Copy link
Collaborator

@shanyi15 shanyi15 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

approve for under contrib menu

@kuke kuke merged commit beda782 into PaddlePaddle:develop Apr 25, 2019
sneaxiy pushed a commit to sneaxiy/Paddle that referenced this pull request Apr 28, 2019
# The first commit's message is:
remove ut test_dist_word2vec in mac ci, will fix it in private, test=develop (PaddlePaddle#17066)

# This is the 2nd commit message:

Fleet unify distributed training (PaddlePaddle#16791)

* implement distributed transpiler with fleet
# This is the 3rd commit message:

ParallelDyGraph with GPU collective mode (PaddlePaddle#16827)

implement dygraph.parallel.DataParallel to hook reduce op.

# This is the 4th commit message:

Init mixed precision training interface (PaddlePaddle#16856)

* Init mixed precision training interface

* Add fp16 test script

test=develop

* All initializers support float16

test=develop

* Code cleanup & add more code annotations

test=develop

* Update API spec

test=develop

* Add usage example in doc

test=develop

# This is the 5th commit message:

fix reference_count_pass,test=develop (PaddlePaddle#17060)

test=develop
# This is the 6th commit message:

Speedup roi_perspective_transform op by caching the information of linear interpolation in forward (PaddlePaddle#17090)

* Cache the information of linear interpolation in forward and use it in backward.
test=develop

* Fix cuda kernel.
test=develop

# This is the 7th commit message:

remove unnecessary prepare_data (PaddlePaddle#17080)

test=develop
# This is the 8th commit message:

fix interpolate cu. test=develop (PaddlePaddle#17101)

# This is the 9th commit message:

test=develop, double backward leaky_relu (PaddlePaddle#17067)

backward of backward: leaky_relu
# This is the 10th commit message:

fix fuse optimizer ops (PaddlePaddle#17102)

test=develop
# This is the 11th commit message:

truncated_gaussian_random supported in distributed training, test=develop (PaddlePaddle#17091)

# This is the 12th commit message:

 Detailed coordinate description for yolov3 loss (PaddlePaddle#17007)

* Detailed coordinate description for yolov3 loss

test=develop

* modified api.spec

test=develop

* modified loss name

* fix api.spec

test=develop

* polish description

test=develop

* modified api.spec

test=develop

# This is the 13th commit message:

fix test_weight_decay (PaddlePaddle#17109)

test=develop
# This is the 14th commit message:

Path flag (PaddlePaddle#17105)

* fix python/paddle/fluid/__init__.py detecting problems
sneaxiy added a commit that referenced this pull request Apr 28, 2019
* refine_dropout_mem,test=develop

* # This is a combination of 14 commits.
# The first commit's message is:
remove ut test_dist_word2vec in mac ci, will fix it in private, test=develop (#17066)

# This is the 2nd commit message:

Fleet unify distributed training (#16791)

* implement distributed transpiler with fleet
# This is the 3rd commit message:

ParallelDyGraph with GPU collective mode (#16827)

implement dygraph.parallel.DataParallel to hook reduce op.

# This is the 4th commit message:

Init mixed precision training interface (#16856)

* Init mixed precision training interface

* Add fp16 test script

test=develop

* All initializers support float16

test=develop

* Code cleanup & add more code annotations

test=develop

* Update API spec

test=develop

* Add usage example in doc

test=develop

# This is the 5th commit message:

fix reference_count_pass,test=develop (#17060)

test=develop
# This is the 6th commit message:

Speedup roi_perspective_transform op by caching the information of linear interpolation in forward (#17090)

* Cache the information of linear interpolation in forward and use it in backward.
test=develop

* Fix cuda kernel.
test=develop

# This is the 7th commit message:

remove unnecessary prepare_data (#17080)

test=develop
# This is the 8th commit message:

fix interpolate cu. test=develop (#17101)

# This is the 9th commit message:

test=develop, double backward leaky_relu (#17067)

backward of backward: leaky_relu
# This is the 10th commit message:

fix fuse optimizer ops (#17102)

test=develop
# This is the 11th commit message:

truncated_gaussian_random supported in distributed training, test=develop (#17091)

# This is the 12th commit message:

 Detailed coordinate description for yolov3 loss (#17007)

* Detailed coordinate description for yolov3 loss

test=develop

* modified api.spec

test=develop

* modified loss name

* fix api.spec

test=develop

* polish description

test=develop

* modified api.spec

test=develop

# This is the 13th commit message:

fix test_weight_decay (#17109)

test=develop
# This is the 14th commit message:

Path flag (#17105)

* fix python/paddle/fluid/__init__.py detecting problems
@gongweibao gongweibao added the AMP label Feb 10, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants