
Commit e335e98

zhangbo9674 authored and AnnaTrainingG committed
[AMP] Support pure fp16 training mode for dygraph (PaddlePaddle#35521)
* add pure fp16 major function in auto_cast & tracer
* support master weight in dygraph for pure fp16
* check mixed dtype of fp16 & fp32 for check_finite_and_unscale op
* change pure fp16 function name
* fix some bugs in auto_cast
* refine auto_cast interface logic
* add param _casted_by_pure_fp16 for class Layer
* support state_dict hook to save the model in a user-appointed dtype in pure_fp16_decorator
* refine pure_fp16_decorator as decorator
* add unittest
* add comments
* support recompute
* add comments for auto_cast and decorator
* support to_static_state_dict for paddle.jit.save
* remove the limit on the number of models and optimizers
* add lookup_table to black_list
* fix momentum and layer state_dict
* fix bug in layer state_dict
* fix bug in layer state_dict_helper
* refine unittest
* refine test_momentum_op
* refine interface and some code
* refine amp_decorator interface
* refine pure fp16 interface
* refine master weight interface
1 parent 8c58c76 commit e335e98

File tree

21 files changed: +1069 -192 lines changed


paddle/fluid/imperative/amp_auto_cast.cc

Lines changed: 26 additions & 1 deletion
@@ -117,7 +117,7 @@ static inline std::shared_ptr<imperative::VarBase> CastToType(
   imperative::NameVarBaseMap outs = {{"Out", {out}}};
 
   {
-    AutoCastGuard guard(tracer, false);
+    AutoCastGuard guard(tracer, 0);
     tracer->TraceOp("cast", ins, outs, std::move(attrs));
   }
 
@@ -225,5 +225,30 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
   return new_ins;
 }
 
+NameVarBaseMap CastPureFp16Inputs(const std::string& op_type,
+                                  const NameVarBaseMap& ins) {
+  NameVarBaseMap new_ins(ins);
+  auto dst_type = framework::proto::VarType::FP16;
+  if (AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(op_type) ||
+      AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
+    dst_type = framework::proto::VarType::FP32;
+  }
+  for (auto& pair : new_ins) {
+    if ((op_type == "batch_norm" || op_type == "layer_norm" ||
+         op_type == "sync_batch_norm") &&
+        pair.first != "X") {
+      continue;
+    }
+    VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
+            << GetDtypeStr(*pair.second.cbegin()) << " to "
+            << framework::DataTypeToString(dst_type);
+    for (auto& var : pair.second) {
+      var = (dst_type == framework::proto::VarType::FP32 ? CastToFP32(var)
+                                                         : CastToFP16(var));
+    }
+  }
+  return new_ins;
+}
+
 }  // namespace imperative
 }  // namespace paddle

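The new CastPureFp16Inputs is the core of the O2 path: every input is cast to fp16, unless the op is in the unsupported-fp16 set or the block list (then fp32), and for the norm ops only the "X" input is touched. Below is a minimal Python sketch of that rule, not the Paddle source; `unsupported_fp16_ops`, `black_list` and the `cast` callable are stand-ins for illustration.

NORM_OPS = {"batch_norm", "layer_norm", "sync_batch_norm"}

def cast_pure_fp16_inputs(op_type, ins, unsupported_fp16_ops, black_list, cast):
    # Everything goes to fp16 unless the op cannot run in fp16 or is black-listed.
    dst = "float32" if (op_type in unsupported_fp16_ops or op_type in black_list) else "float16"
    new_ins = {}
    for slot, variables in ins.items():
        # For the norm ops, only the data input "X" is cast; other inputs stay as they are.
        if op_type in NORM_OPS and slot != "X":
            new_ins[slot] = list(variables)
            continue
        new_ins[slot] = [cast(var, dst) for var in variables]
    return new_ins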
paddle/fluid/imperative/amp_auto_cast.h

Lines changed: 10 additions & 6 deletions
@@ -63,27 +63,31 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
 // NOTE(zhiqiu): AutoCastGuard is used for RAII.
 class AutoCastGuard {
  public:
-  AutoCastGuard(std::shared_ptr<Tracer> tracer, bool guard_mode)
+  AutoCastGuard(std::shared_ptr<Tracer> tracer, int guard_level)
       : tracer_(tracer) {
-    pre_mode_ = tracer_->IsAutoCastEnabled();
-    if (pre_mode_ != guard_mode) {
-      tracer_->SetEnableAutoCast(guard_mode);
+    pre_amp_level_ = tracer_->AMPLevel();
+
+    if (pre_amp_level_ != guard_level) {
+      tracer_->SetAMPLevel(guard_level);
     }
   }
 
-  ~AutoCastGuard() { tracer_->SetEnableAutoCast(pre_mode_); }
+  ~AutoCastGuard() { tracer_->SetAMPLevel(pre_amp_level_); }
 
   // forbid copy and operator=
   AutoCastGuard(const AutoCastGuard& guard) = delete;
   AutoCastGuard& operator=(const AutoCastGuard& guard) = delete;
 
  private:
   std::shared_ptr<Tracer> tracer_;
-  bool pre_mode_;
+  int pre_amp_level_;
 };
 
 NameVarBaseMap AutoCastInputs(const std::string& op_type,
                               const NameVarBaseMap& ins);
 
+NameVarBaseMap CastPureFp16Inputs(const std::string& op_type,
+                                  const NameVarBaseMap& ins);
+
 }  // namespace imperative
 }  // namespace paddle

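AutoCastGuard now snapshots and restores an integer AMP level instead of a boolean flag, which is how the internal cast op in amp_auto_cast.cc can be traced with the level temporarily forced to 0 without disturbing the caller's mode. A rough Python analogue of the RAII behaviour; the `tracer` object with an `amp_level` attribute is a stand-in for the C++ Tracer, not a Paddle API.

from contextlib import contextmanager

@contextmanager
def auto_cast_guard(tracer, guard_level):
    pre_level = tracer.amp_level          # snapshot, like the constructor
    if pre_level != guard_level:
        tracer.amp_level = guard_level
    try:
        yield
    finally:
        tracer.amp_level = pre_level      # restore, like the destructor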
paddle/fluid/imperative/tracer.cc

Lines changed: 4 additions & 1 deletion
@@ -176,9 +176,12 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
                     : attr_checker->GetDefaultAttrMap();
 
   NameVarBaseMap new_ins = ins;
-  if (enable_autocast_) {
+  if (amp_level_ == 1) {
     VLOG(5) << "Auto mixed precision run operator: " << type;
     new_ins = AutoCastInputs(type, ins);
+  } else if (amp_level_ == 2) {
+    VLOG(5) << "Pure fp16 run operator: " << type;
+    new_ins = CastPureFp16Inputs(type, ins);
   }
 
   try {

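TraceOp now dispatches on the integer level: 1 selects the white/black-list pass, 2 selects the pure fp16 pass, and 0 leaves the inputs untouched. An illustrative stand-alone sketch of that dispatch; the two casting callables are passed in, and in the C++ code they correspond to AutoCastInputs and CastPureFp16Inputs.

def prepare_inputs(op_type, ins, amp_level, auto_cast_inputs, cast_pure_fp16_inputs):
    if amp_level == 1:        # O1: decide per op from the white/black lists
        return auto_cast_inputs(op_type, ins)
    if amp_level == 2:        # O2: pure fp16 except the listed exceptions
        return cast_pure_fp16_inputs(op_type, ins)
    return ins                # 0: AMP disabled, trace inputs as-is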
paddle/fluid/imperative/tracer.h

Lines changed: 3 additions & 3 deletions
@@ -105,9 +105,9 @@ class Tracer {
 
   void SetHasGrad(bool has_grad) { has_grad_ = has_grad; }
 
-  void SetEnableAutoCast(bool enabled) { enable_autocast_ = enabled; }
+  void SetAMPLevel(int level) { amp_level_ = level; }
 
-  bool IsAutoCastEnabled() const { return enable_autocast_; }
+  int AMPLevel() const { return amp_level_; }
 
   paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists(
       const platform::Place& place);
@@ -118,9 +118,9 @@ class Tracer {
   bool enable_program_desc_tracing_{false};
   std::unique_ptr<UniqueNameGenerator> generator_;
   platform::Place expected_place_;
-  bool enable_autocast_{false};
   GarbageCollectorMap gcs_;
   static thread_local bool has_grad_;
+  int amp_level_{0};
 };
 
 // To access static variable current_tracer

paddle/fluid/pybind/imperative.cc

Lines changed: 2 additions & 2 deletions
@@ -1947,8 +1947,8 @@ void BindImperative(py::module *m_ptr) {
       .def_property("_enable_program_desc_tracing",
                     &imperative::Tracer::IsProgramDescTracingEnabled,
                     &imperative::Tracer::SetEnableProgramDescTracing)
-      .def_property("_enable_autocast", &imperative::Tracer::IsAutoCastEnabled,
-                    &imperative::Tracer::SetEnableAutoCast)
+      .def_property("_amp_level", &imperative::Tracer::AMPLevel,
+                    &imperative::Tracer::SetAMPLevel)
       .def_property("_has_grad", &imperative::Tracer::HasGrad,
                     &imperative::Tracer::SetHasGrad)
       .def_property(

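A quick way to inspect the renamed binding. `_amp_level` is an internal property, so the exact values are an implementation detail, but with this patch (on a CUDA device) 0 means no AMP context, 1 the O1 context and 2 the O2 context.

import paddle
from paddle.fluid import framework

print(framework._dygraph_tracer()._amp_level)        # 0: no AMP context active
with paddle.amp.auto_cast(level='O1'):
    print(framework._dygraph_tracer()._amp_level)    # 1: list-driven mixed precision
with paddle.amp.auto_cast(level='O2'):
    print(framework._dygraph_tracer()._amp_level)    # 2: pure fp16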
paddle/fluid/pybind/op_function_generator.cc

Lines changed: 13 additions & 5 deletions
@@ -63,11 +63,15 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
     {"moving_average_abs_max_scale", {"X", "InAccum", "InState"}},
     {"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}},
     {"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}},
-    {"momentum", {"Param", "Grad", "Velocity", "LearningRate"}},
+    {"momentum", {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"}},
     {"sparse_momentum", {"Param", "Grad", "Velocity", "Index", "LearningRate"}},
     {"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}},
     {"run_program", {"X", "Params"}},
-    {"matrix_rank", {"X", "TolTensor"}}};
+    {"matrix_rank", {"X", "TolTensor"}},
+    {"adam",
+     {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
+      "Beta2Pow", "MasterParam"}},
+};
 
 // NOTE(zhiqiu): Like op_ins_map.
 // Commonly, the outputs in auto-generated OP function are determined by the
@@ -97,12 +101,15 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
      {"Out", "OutScale", "OutAccum", "OutState"}},
     {"multiclass_nms3", {"Out", "NmsRoisNum"}},
     {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
-    {"momentum", {"ParamOut", "VelocityOut"}},
+    {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
     {"sparse_momentum", {"ParamOut", "VelocityOut"}},
     {"rnn", {"DropoutState", "Reserve", "Out", "State"}},
     {"lamb",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
     {"run_program", {"DOut"}},
+    {"adam",
+     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
+      "MasterParamOut"}},
 };
 
 // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
@@ -119,13 +126,14 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
 std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"sgd", {"ParamOut"}},
     {"adam",
-     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
+     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
+      "MasterParamOut"}},
     {"adamw",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
     {"average_accumulates",
     {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
      "out_old_num_accumulates", "out_num_updates"}},
-    {"momentum", {"ParamOut", "VelocityOut"}},
+    {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}},
    {"sparse_momentum", {"ParamOut", "VelocityOut"}},
     {"batch_norm", {"MeanOut", "VarianceOut"}},
     {"sync_batch_norm", {"MeanOut", "VarianceOut"}},

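MasterParam/MasterParamOut are wired through here because, with multi-precision enabled, momentum and adam keep an fp32 master copy of every fp16 parameter and update that copy. A plain-NumPy illustration of the idea, not the actual kernel; the helper name and values are made up for the example.

import numpy as np

def momentum_step_with_master(param_fp16, master_fp32, grad_fp16, velocity,
                              lr=0.01, mu=0.9):
    # Do the arithmetic in fp32 against the master copy, then write back fp16.
    velocity = mu * velocity + grad_fp16.astype(np.float32)
    master_fp32 = master_fp32 - lr * velocity
    return master_fp32.astype(np.float16), master_fp32, velocity

p = np.ones(4, dtype=np.float16)
m = p.astype(np.float32)                  # fp32 master weight
v = np.zeros(4, dtype=np.float32)
g = np.full(4, 1e-4, dtype=np.float16)    # an update this small is lost in fp16
p, m, v = momentum_step_with_master(p, m, g, v)
print(p[0], m[0])                         # fp16 copy stays 1.0, master moves to 0.999999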
python/paddle/amp/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -14,5 +14,6 @@
 
 from .auto_cast import auto_cast  # noqa: F401
 from .grad_scaler import GradScaler  # noqa: F401
+from .auto_cast import decorate  # noqa: F401
 
-__all__ = ['auto_cast', 'GradScaler']
+__all__ = ['auto_cast', 'GradScaler', 'decorate']

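With the export above, the new decorator is reachable from the public namespace in either form:

import paddle
from paddle.amp import decorate   # same callable as paddle.amp.decorate
assert decorate is paddle.amp.decorate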
python/paddle/amp/auto_cast.py

Lines changed: 70 additions & 3 deletions
@@ -13,18 +13,22 @@
 # limitations under the License.
 
 from paddle.fluid.dygraph.amp import amp_guard
+from paddle.fluid.dygraph.amp import amp_decorate
 
 __all__ = []
 
 
-def auto_cast(enable=True, custom_white_list=None, custom_black_list=None):
+def auto_cast(enable=True,
+              custom_white_list=None,
+              custom_black_list=None,
+              level='O1'):
     """
     Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
     If enabled, the input data type (float32 or float16) of each operator is decided
     by autocast algorithm for better performance.
 
     Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in
-    imperative mode.
+    imperative mode. It is also used together with `decorate` to achieve pure fp16 in imperative mode.
 
     Args:
         enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
@@ -34,6 +38,8 @@ def auto_cast(enable=True, custom_white_list=None, custom_black_list=None):
         custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16
             calculation and are considered numerically-dangerous and whose effects may also be
             observed in downstream ops. These ops will not be converted to fp16.
+        level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represents mixed precision, where the input data type of each operator is cast according to the white_list and black_list;
+            O2 represents pure fp16, where the parameters and inputs of all operators are cast to fp16, except for operators in the black_list, operators without fp16 kernels, and batch norm. Default is O1 (amp).
 
     Examples:
 
@@ -61,6 +67,67 @@ def auto_cast(enable=True, custom_white_list=None, custom_black_list=None):
         with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}):
             c = a + b
             print(c.dtype) # FP16
+
+        with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}, level='O2'):
+            d = a + b
+            print(d.dtype) # FP16
+
     """
-    return amp_guard(enable, custom_white_list, custom_black_list)
+    return amp_guard(enable, custom_white_list, custom_black_list, level)
+
+
+def decorate(models,
+             optimizers=None,
+             level='O1',
+             master_weight=None,
+             save_dtype=None):
+    """
+    Decorate models and optimizers for auto-mixed-precision. When level is O1 (amp), decorate does nothing.
+    When level is O2 (pure fp16), decorate casts all parameters of the models to FP16, except for BatchNorm and LayerNorm.
+
+    Commonly, it is used together with `auto_cast` to achieve pure fp16 in imperative mode.
+
+    Args:
+        models(Layer|list of Layer): The models defined by the user; either a single model or a list of models.
+        optimizers(Optimizer|list of Optimizer, optional): The optimizers defined by the user; either a single optimizer or a list of optimizers. Default is None.
+        level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": at O1 (mixed precision) the decorator does nothing;
+            at O2 (pure fp16) the decorator casts all parameters of the models to FP16, except for BatchNorm and LayerNorm. Default is O1 (amp).
+        master_weight(bool, optional): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, the optimizer uses multi-precision at the O2 level. Default is None.
+        save_dtype(float, optional): The dtype in which model parameters are saved when using `paddle.save` or `paddle.jit.save`; it should be float16, float32, float64 or None.
+            save_dtype does not change the dtype of the model parameters, only the dtype of the saved state_dict. When save_dtype is None, the save dtype is the same as the model dtype. Default is None.
+
+    Examples:
+
+    .. code-block:: python
+
+        # required: gpu
+        # Demo1: single model and optimizer:
+        import paddle
+
+        model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
+        optimizer = paddle.optimizer.SGD(parameters=model.parameters())
+
+        model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2')
+
+        data = paddle.rand([10, 3, 32, 32])
+
+        with paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
+            output = model(data)
+            print(output.dtype) # FP16
+
+        # required: gpu
+        # Demo2: multi models and optimizers:
+        model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
+        optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters())
+
+        models, optimizers = paddle.amp.decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2')
+
+        data = paddle.rand([10, 3, 32, 32])
+
+        with paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
+            output = models[0](data)
+            output2 = models[1](data)
+            print(output.dtype) # FP16
+            print(output2.dtype) # FP16
+    """
+    return amp_decorate(models, optimizers, level, master_weight, save_dtype)

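The docstring demos above stop at the forward pass. Here is a hedged sketch of a full pure fp16 training step that this PR enables, combining decorate (O2 parameter casting plus master weights), auto_cast(level='O2') and GradScaler; it assumes a GPU build with fp16 kernels, and the layer sizes and hyper-parameters are arbitrary.

import paddle

model = paddle.nn.Linear(16, 16)
optimizer = paddle.optimizer.Momentum(learning_rate=0.01,
                                      parameters=model.parameters())
model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer,
                                       level='O2', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

for _ in range(3):
    data = paddle.rand([8, 16])
    with paddle.amp.auto_cast(level='O2'):
        loss = model(data).mean()
    scaled = scaler.scale(loss)          # scale the loss to avoid fp16 gradient underflow
    scaled.backward()
    scaler.minimize(optimizer, scaled)   # unscale grads; skip the step if inf/nan is found
    optimizer.clear_grad()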
python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py

Lines changed: 7 additions & 2 deletions
@@ -198,7 +198,11 @@ def forward(ctx, run_function, all_outputs, *args):
 
         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        ctx.is_fw_autocast = tracer._enable_autocast
+        if tracer._amp_level == 0:
+            ctx.is_fw_autocast = False
+        else:
+            ctx.is_fw_autocast = True
+        ctx.amp_mode = 'O1'
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
 
         with paddle.no_grad():
@@ -258,7 +262,8 @@ def backward(ctx, *args):
                 with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
                         custom_white_list=ctx.amp_white_list,
-                        custom_black_list=ctx.amp_black_list):
+                        custom_black_list=ctx.amp_black_list,
+                        level=ctx.amp_mode):
                     detached_inputs = detach_variable(tuple(inputs))
                     outputs = ctx.run_function(*detached_inputs)

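Both recompute paths (this pipeline-parallel one and the general one below) now follow the same pattern: derive a boolean "was AMP on" from the tracer's integer level in forward, stash the op lists, and re-enter auto_cast with the same settings when the forward is replayed in backward; the level is pinned to 'O1' for now, since pure fp16 replay is not handled here. A stand-alone sketch of that pattern; `ctx`, `tracer` and `run_forward` are stand-ins, not Paddle API.

import paddle

def save_amp_state(ctx, tracer):
    ctx.is_fw_autocast = tracer._amp_level != 0     # 0 means AMP was disabled
    ctx.amp_mode = 'O1'                             # O2 replay is not supported yet
    ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

def replay_forward(ctx, run_forward, *inputs):
    # Re-run the forward under the same AMP settings as the original pass.
    with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
                              custom_white_list=ctx.amp_white_list,
                              custom_black_list=ctx.amp_black_list,
                              level=ctx.amp_mode):
        return run_forward(*inputs)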
python/paddle/distributed/fleet/utils/recompute.py

Lines changed: 9 additions & 3 deletions
@@ -98,7 +98,11 @@ def forward(ctx, run_function, preserve_rng_state, *args):
 
         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        ctx.is_fw_autocast = tracer._enable_autocast
+        if tracer._amp_level == 0:
+            ctx.is_fw_autocast = False
+        else:
+            ctx.is_fw_autocast = True
+        ctx.amp_mode = 'O1'
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
 
         with paddle.no_grad():
@@ -128,14 +132,16 @@ def backward(ctx, *args):
                     with paddle.amp.auto_cast(
                             enable=ctx.is_fw_autocast,
                             custom_white_list=ctx.amp_white_list,
-                            custom_black_list=ctx.amp_black_list):
+                            custom_black_list=ctx.amp_black_list,
+                            level=ctx.amp_mode):
                         detached_inputs = detach_variable(tuple(inputs))
                         outputs = ctx.run_function(*detached_inputs)
             else:
                 with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
                         custom_white_list=ctx.amp_white_list,
-                        custom_black_list=ctx.amp_black_list):
+                        custom_black_list=ctx.amp_black_list,
+                        level=ctx.amp_mode):
                     detached_inputs = detach_variable(tuple(inputs))
                     outputs = ctx.run_function(*detached_inputs)

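A hedged usage sketch of recompute under AMP after this change (a GPU is required since recompute preserves the CUDA RNG state by default; the module shape and loss are arbitrary):

import paddle
from paddle.distributed.fleet.utils import recompute

block = paddle.nn.Sequential(paddle.nn.Linear(16, 16), paddle.nn.ReLU())
x = paddle.rand([4, 16])
x.stop_gradient = False

with paddle.amp.auto_cast(level='O1'):
    y = recompute(block, x)   # forward records is_fw_autocast and the op lists
y.mean().backward()           # forward is replayed here under the same auto_cast settings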