Skip to content
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
98c720f
add pure fp16 major function in auto_cast & tracer
zhangbo9674 Sep 7, 2021
228c855
support master weight in dygraph for pure fp16
zhangbo9674 Sep 7, 2021
a7f00a1
check mix dtype of fp16&fp32 for check_finite_and_unscale op
zhangbo9674 Sep 7, 2021
1dce0cc
change pure fp16 funtion name
zhangbo9674 Sep 7, 2021
422ced8
refine some bug in auto_cast
zhangbo9674 Sep 7, 2021
1366dad
refine auto_cast interface logic
zhangbo9674 Sep 7, 2021
5cdc012
add param _casted_by_pure_fp16 for class Layer
zhangbo9674 Sep 8, 2021
9e5399b
support state_dict hook for save model by user appointed dtype in pur…
zhangbo9674 Sep 10, 2021
91af1e9
refine pure_fp16_decorator as decorator
zhangbo9674 Sep 12, 2021
6cb2108
merge paddle develop
zhangbo9674 Sep 13, 2021
00cfbef
add unittest
zhangbo9674 Sep 13, 2021
64e2af6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhangbo9674 Sep 13, 2021
ac2342b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhangbo9674 Sep 13, 2021
6df1354
add comment
zhangbo9674 Sep 13, 2021
528da76
add comment
zhangbo9674 Sep 13, 2021
ae6d0a4
support recompute
zhangbo9674 Sep 13, 2021
806018b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhangbo9674 Sep 13, 2021
d1c277a
add comment for auto_cast and decorator
zhangbo9674 Sep 13, 2021
c2b472b
support to_static_state_dict for paddle.jit.save
zhangbo9674 Sep 14, 2021
f9f75a4
unlimite models num and optimizers num
zhangbo9674 Sep 14, 2021
fa9c9d9
add lookup_table in black_list
zhangbo9674 Sep 14, 2021
cafea36
fix momentum and layer state_dict
zhangbo9674 Sep 14, 2021
cd545e6
merget upstream develop
zhangbo9674 Sep 14, 2021
1502f55
fix bug in layer state_dict
zhangbo9674 Sep 15, 2021
fb9a785
fix bug in layer state_dict_helper
zhangbo9674 Sep 15, 2021
704e7f6
refine unittest
zhangbo9674 Sep 15, 2021
5b56d84
refine test_momentun_op
zhangbo9674 Sep 15, 2021
042a953
refine interface and some code
zhangbo9674 Sep 16, 2021
b6e4a99
refine amp_decorator interface
zhangbo9674 Sep 16, 2021
4524bab
refine pure fp16 interface
zhangbo9674 Sep 17, 2021
e1118cb
refine master weight interface
zhangbo9674 Sep 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion paddle/fluid/imperative/amp_auto_cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ static inline std::shared_ptr<imperative::VarBase> CastToType(
imperative::NameVarBaseMap outs = {{"Out", {out}}};

{
AutoCastGuard guard(tracer, false);
AutoCastGuard guard(tracer, 0);
tracer->TraceOp("cast", ins, outs, std::move(attrs));
}

Expand Down Expand Up @@ -225,5 +225,30 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
return new_ins;
}

NameVarBaseMap CastPureFp16Inputs(const std::string& op_type,
const NameVarBaseMap& ins) {
NameVarBaseMap new_ins(ins);
auto dst_type = framework::proto::VarType::FP16;
if (AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(op_type) ||
AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
dst_type = framework::proto::VarType::FP32;
}
for (auto& pair : new_ins) {
if ((op_type == "batch_norm" || op_type == "layer_norm" ||
op_type == "sync_batch_norm") &&
pair.first != "X") {
continue;
}
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type);
for (auto& var : pair.second) {
var = (dst_type == framework::proto::VarType::FP32 ? CastToFP32(var)
: CastToFP16(var));
}
}
return new_ins;
}

} // namespace imperative
} // namespace paddle
16 changes: 10 additions & 6 deletions paddle/fluid/imperative/amp_auto_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,31 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
// NOTE(zhiqiu): AutoCastGuard is used for RAII.
class AutoCastGuard {
public:
AutoCastGuard(std::shared_ptr<Tracer> tracer, bool guard_mode)
AutoCastGuard(std::shared_ptr<Tracer> tracer, int guard_level)
: tracer_(tracer) {
pre_mode_ = tracer_->IsAutoCastEnabled();
if (pre_mode_ != guard_mode) {
tracer_->SetEnableAutoCast(guard_mode);
pre_amp_level_ = tracer_->AMPLevel();

if (pre_amp_level_ != guard_level) {
tracer_->SetAMPLevel(guard_level);
}
}

~AutoCastGuard() { tracer_->SetEnableAutoCast(pre_mode_); }
~AutoCastGuard() { tracer_->SetAMPLevel(pre_amp_level_); }

// forbid copy and operator=
AutoCastGuard(const AutoCastGuard& guard) = delete;
AutoCastGuard& operator=(const AutoCastGuard& guard) = delete;

private:
std::shared_ptr<Tracer> tracer_;
bool pre_mode_;
int pre_amp_level_;
};

NameVarBaseMap AutoCastInputs(const std::string& op_type,
const NameVarBaseMap& ins);

NameVarBaseMap CastPureFp16Inputs(const std::string& op_type,
const NameVarBaseMap& ins);

} // namespace imperative
} // namespace paddle
5 changes: 4 additions & 1 deletion paddle/fluid/imperative/tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,12 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
: attr_checker->GetDefaultAttrMap();

NameVarBaseMap new_ins = ins;
if (enable_autocast_) {
if (amp_level_ == 1) {
VLOG(5) << "Auto mixed precision run operator: " << type;
new_ins = AutoCastInputs(type, ins);
} else if (amp_level_ == 2) {
VLOG(5) << "Pure fp16 run operator: " << type;
new_ins = CastPureFp16Inputs(type, ins);
}

try {
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/imperative/tracer.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ class Tracer {

void SetHasGrad(bool has_grad) { has_grad_ = has_grad; }

void SetEnableAutoCast(bool enabled) { enable_autocast_ = enabled; }
void SetAMPLevel(int level) { amp_level_ = level; }

bool IsAutoCastEnabled() const { return enable_autocast_; }
int AMPLevel() const { return amp_level_; }

paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists(
const platform::Place& place);
Expand All @@ -118,9 +118,9 @@ class Tracer {
bool enable_program_desc_tracing_{false};
std::unique_ptr<UniqueNameGenerator> generator_;
platform::Place expected_place_;
bool enable_autocast_{false};
GarbageCollectorMap gcs_;
static thread_local bool has_grad_;
int amp_level_{0};
};

// To access static variable current_tracer
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1815,8 +1815,8 @@ void BindImperative(py::module *m_ptr) {
.def_property("_enable_program_desc_tracing",
&imperative::Tracer::IsProgramDescTracingEnabled,
&imperative::Tracer::SetEnableProgramDescTracing)
.def_property("_enable_autocast", &imperative::Tracer::IsAutoCastEnabled,
&imperative::Tracer::SetEnableAutoCast)
.def_property("_amp_level", &imperative::Tracer::AMPLevel,
&imperative::Tracer::SetAMPLevel)
.def_property("_has_grad", &imperative::Tracer::HasGrad,
&imperative::Tracer::SetHasGrad)
.def_property(
Expand Down
18 changes: 13 additions & 5 deletions paddle/fluid/pybind/op_function_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,15 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"moving_average_abs_max_scale", {"X", "InAccum", "InState"}},
{"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}},
{"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}},
{"momentum", {"Param", "Grad", "Velocity", "LearningRate"}},
{"momentum", {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"}},
{"sparse_momentum", {"Param", "Grad", "Velocity", "Index", "LearningRate"}},
Copy link
Contributor

@GuoxiaWang GuoxiaWang Sep 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下次提交的时候,帮忙把 sparse_momentum 也加上 MasterParam 吧,谢谢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前框架中还没有找到使用sparse_momentum的优化器,以及动态图调用sparse_momentum的地方,所以pure fp16的pr中暂时先不加入了。

{"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}},
{"run_program", {"X", "Params"}},
{"matrix_rank", {"X", "TolTensor"}}};
{"matrix_rank", {"X", "TolTensor"}},
{"adam",
{"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
"Beta2Pow", "MasterParam"}},
};

// NOTE(zhiqiu): Like op_ins_map.
// Commonly, the outputs in auto-generated OP function are determined by the
Expand Down Expand Up @@ -97,12 +101,15 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"Out", "OutScale", "OutAccum", "OutState"}},
{"multiclass_nms3", {"Out", "NmsRoisNum"}},
{"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
{"momentum", {"ParamOut", "VelocityOut"}},
{"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
{"sparse_momentum", {"ParamOut", "VelocityOut"}},
{"rnn", {"DropoutState", "Reserve", "Out", "State"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"run_program", {"DOut"}},
{"adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
};

// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
Expand All @@ -119,13 +126,14 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"sgd", {"ParamOut"}},
{"adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"adamw",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"average_accumulates",
{"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
"out_old_num_accumulates", "out_num_updates"}},
{"momentum", {"ParamOut", "VelocityOut"}},
{"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
{"sparse_momentum", {"ParamOut", "VelocityOut"}},
{"batch_norm", {"MeanOut", "VarianceOut"}},
{"sync_batch_norm", {"MeanOut", "VarianceOut"}},
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/amp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@

from .auto_cast import auto_cast # noqa: F401
from .grad_scaler import GradScaler # noqa: F401
from .auto_cast import decorator # noqa: F401

__all__ = ['auto_cast', 'GradScaler']
__all__ = ['auto_cast', 'GradScaler', 'decorator']
73 changes: 70 additions & 3 deletions python/paddle/amp/auto_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,22 @@
# limitations under the License.

from paddle.fluid.dygraph.amp import amp_guard
from paddle.fluid.dygraph.amp import amp_decorator

__all__ = []


def auto_cast(enable=True, custom_white_list=None, custom_black_list=None):
def auto_cast(enable=True,
custom_white_list=None,
custom_black_list=None,
level='O1'):
"""
Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
If enabled, the input data type (float32 or float16) of each operator is decided
by autocast algorithm for better performance.

Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in
imperative mode.
imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode.

Args:
enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
Expand All @@ -34,6 +38,8 @@ def auto_cast(enable=True, custom_white_list=None, custom_black_list=None):
custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16
calculation and are considered numerically-dangerous and whose effects may also be
observed in downstream ops. These ops will not be converted to fp16.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list;
O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp)

Examples:

Expand Down Expand Up @@ -61,6 +67,67 @@ def auto_cast(enable=True, custom_white_list=None, custom_black_list=None):
with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}):
c = a + b
print(c.dtype) # FP16

with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}, level='O2'):
d = a + b
print(d.dtype) # FP16

"""
return amp_guard(enable, custom_white_list, custom_black_list, level)


def decorator(models,
optimizers=None,
level='O1',
master_weight=None,
save_dtype=None):
"""
Decorator models and optimizers for auto-mixed-precision. When level is O1(amp), the decorator will do nothing.
When level is O2(pure fp16), the decorator will cast all parameters of models to FP16, except BatchNorm and LayerNorm.

Commonly, it is used together with `auto_cast` to achieve Pure fp16 in imperative mode.

Args:
models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None.
optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing;
O2 represent Pure fp16, the decorator will cast all parameters of models to FP16, except BatchNorm and LayerNorm. Default is O1(amp)
master_weight(None|bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, it will keep origin Optimizer multi-precision strategy. Default is None.
save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, float32, float64 or None.
The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.

Examples:

.. code-block:: python

# required: gpu
# Demo1: single model and optimizer:
import paddle

model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
optimzier = paddle.optimizer.SGD(parameters=model.parameters())

model, optimizer = paddle.amp.decorator(models=model, optimizers=optimzier, level='O2')

data = paddle.rand([10, 3, 32, 32])

with paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
output = model(data)
print(output.dtype) # FP16

# required: gpu
# Demo2: multi models and optimizers:
model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters())

models, optimizers = paddle.amp.decorator(models=[model, model2], optimizers=[optimzier, optimizer2], level='O2')

data = paddle.rand([10, 3, 32, 32])

with paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
output = models[0](data)
output2 = models[1](data)
print(output.dtype) # FP16
print(output2.dtype) # FP16
"""
return amp_guard(enable, custom_white_list, custom_black_list)
return amp_decorator(models, optimizers, level, master_weight, save_dtype)
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,11 @@ def forward(ctx, run_function, all_outputs, *args):

# TODO support AMP
tracer = framework._dygraph_tracer()
ctx.is_fw_autocast = tracer._enable_autocast
if tracer._amp_level == 0:
ctx.is_fw_autocast = False
else:
ctx.is_fw_autocast = True
ctx.amp_mode = 'O1'
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

with paddle.no_grad():
Expand Down Expand Up @@ -258,7 +262,8 @@ def backward(ctx, *args):
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list):
custom_black_list=ctx.amp_black_list,
level=ctx.amp_mode):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)

Expand Down
12 changes: 9 additions & 3 deletions python/paddle/distributed/fleet/utils/recompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ def forward(ctx, run_function, preserve_rng_state, *args):

# TODO support AMP
tracer = framework._dygraph_tracer()
ctx.is_fw_autocast = tracer._enable_autocast
if tracer._amp_level == 0:
ctx.is_fw_autocast = False
else:
ctx.is_fw_autocast = True
ctx.amp_mode = 'O1'
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

with paddle.no_grad():
Expand Down Expand Up @@ -128,14 +132,16 @@ def backward(ctx, *args):
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list):
custom_black_list=ctx.amp_black_list,
level=ctx.amp_mode):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
else:
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list):
custom_black_list=ctx.amp_black_list,
level=ctx.amp_mode):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)

Expand Down
20 changes: 11 additions & 9 deletions python/paddle/fluid/contrib/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,19 +203,21 @@ def _append_optimize_op(self, block, param_and_grad):
param_and_grad[0])
lr = self._create_param_lr(param_and_grad)

if framework.in_dygraph_mode():
_, _ = _C_ops.momentum(
param_and_grad[0], param_and_grad[1], velocity_acc, lr,
param_and_grad[0], velocity_acc, 'mu', self._momentum,
'use_nesterov', self._use_nesterov, 'regularization_method',
self._regularization_method, 'regularization_coeff',
self._regularization_coeff)
return None

find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = (self._master_weights[param_and_grad[0].name]
if find_master else None)

if framework.in_dygraph_mode():
_, _, _ = _C_ops.momentum(
param_and_grad[0], param_and_grad[1], velocity_acc, lr,
master_weight, param_and_grad[0], velocity_acc, master_weight,
'mu', self._momentum, 'use_nesterov', self._use_nesterov,
'regularization_method', self._regularization_method,
'regularization_coeff', self._regularization_coeff,
'multi_precision', find_master)
return None

attrs = {
"mu": self._momentum,
"use_nesterov": self._use_nesterov,
Expand Down
Loading