3 changes: 3 additions & 0 deletions python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py
@@ -49,6 +49,7 @@ def __init__(self,
self.bf16_list = copy.copy(bf16_list)
self.fp32_list = copy.copy(fp32_list)
self.gray_list = copy.copy(gray_list)
self.bf16_initializer_list = copy.copy(bf16_initializer_list)
self.unsupported_list = copy.copy(unsupported_list)
self.fp32_varnames = copy.copy(custom_fp32_varnames)
self._update_list()
@@ -79,6 +80,8 @@ def _update_list(self):
self.unsupported_list.add(op_name)


bf16_initializer_list = {'fill_constant', 'uniform_random'}

# always bf16
bf16_list = {'elementwise_add', }

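For orientation, the new set only declares which initializer op types are eligible for bf16 conversion; the conversion itself happens in amp_utils.py below. A minimal, self-contained sketch of the membership check (plain Python, hypothetical op types):

bf16_initializer_list = {'fill_constant', 'uniform_random'}
startup_op_types = ['fill_constant', 'gaussian_random', 'uniform_random']
# Only ops whose type is in the list are candidates for initializer casting.
candidates = [t for t in startup_op_types if t in bf16_initializer_list]
print(candidates)  # ['fill_constant', 'uniform_random']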
51 changes: 50 additions & 1 deletion python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py
@@ -232,7 +232,52 @@ def bf16_guard():
yield


def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True):
def are_post_ops_bf16(post_ops, keep_fp32_ops):
for post_op in post_ops:
for op in post_op:
if op.type in keep_fp32_ops:
return False
return True


def cast_initializers_to_bf16(startup_prog,
amp_lists,
block,
all_ops,
keep_fp32_ops,
to_bf16_var_names=None):
prepend_ops = startup_prog.global_block().ops
for op in prepend_ops:
if str(op.type) in amp_lists.bf16_initializer_list:
change_op = True
op_post_ops = []
op_out_vars = []
for out_name in op.output_names:
for out_var_name in op.output(out_name):
out_var = block.var(out_var_name)
post_op = find_true_post_op(all_ops, op, out_var_name, True)

if out_var is None or out_var.type not in _valid_types:
change_op = False
break
op_post_ops.append(post_op)
op_out_vars.append(out_var)

if change_op and are_post_ops_bf16(op_post_ops, keep_fp32_ops):
for out_var in op_out_vars:
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
if to_bf16_var_names is not None and out_var.name in to_bf16_var_names:
to_bf16_var_names.remove(out_var.name)
if op.has_attr('dtype') and op.attr(
'dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.BF16)


def cast_model_to_bf16(program,
startup_prog=None,
amp_lists=None,
use_bf16_guard=True):
"""
Traverse all ops in the whole model and set their inputs and outputs
to the bf16 data type. This function will do some special processing for
@@ -329,6 +374,10 @@ def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True):
if op.has_attr('mkldnn_data_type'):
op._set_attr('mkldnn_data_type', 'bfloat16')

if startup_prog is not None:
cast_initializers_to_bf16(startup_prog, amp_lists, global_block,
ops, keep_fp32_ops, to_bf16_var_names)

# process ops in keep_fp32_ops
op_var_rename_map = [
collections.OrderedDict() for _ in range(len(program.blocks))
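Illustrative usage of the new startup_prog argument, mirroring the updated tests later in this PR. The import paths are assumptions based on how the tests use fluid and the mixed_precision package; this is a sketch, not a complete training script:

import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision as amp  # assumed import path

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    t = fluid.data(name='t', shape=[3, 3], dtype='float32')
    ret = fluid.layers.elementwise_add(t, t)

# Passing startup_prog lets cast_model_to_bf16 also retype fill_constant /
# uniform_random initializers whose outputs feed only bf16 ops.
amp.bf16.cast_model_to_bf16(
    main_prog,
    startup_prog=startup_prog,
    amp_lists=amp.bf16.AutoMixedPrecisionListsBF16(),
    use_bf16_guard=False)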
11 changes: 7 additions & 4 deletions python/paddle/fluid/contrib/mixed_precision/bf16/decorator.py
@@ -94,7 +94,8 @@ def backward(self,

if self._use_pure_bf16:
self._to_bf16_var_names = cast_model_to_bf16(
self._train_program, self._amp_lists, self._use_bf16_guard)
self._train_program, startup_program, self._amp_lists,
self._use_bf16_guard)
else:
rewrite_program_bf16(self._train_program, self._amp_lists)

@@ -168,10 +169,12 @@ def run_example_code():
self._to_bf16_var_names)
if test_program is not None:
if self._use_pure_bf16:
cast_model_to_bf16(test_program, self._amp_lists,
self._use_bf16_guard)
cast_model_to_bf16(
test_program,
amp_lists=self._amp_lists,
use_bf16_guard=self._use_bf16_guard)
elif use_bf16_test:
rewrite_program_bf16(test_program, self._amp_lists)
rewrite_program_bf16(test_program, amp_lists=self._amp_lists)

def apply_gradients(self, params_grads):
"""
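For context, the decorator path now forwards the startup program so parameter initializers can be cast as well. Below is a hedged usage sketch modeled on the updated fit-a-line book test; decorate_bf16 and the import paths are assumed from the existing bf16 decorator module:

import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision as amp  # assumed import path

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 13], dtype='float32')
    y = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.reduce_mean(y)

    sgd = fluid.optimizer.SGD(learning_rate=0.001)
    sgd = amp.bf16.decorate_bf16(
        sgd,
        amp_lists=amp.bf16.AutoMixedPrecisionListsBF16(),
        use_bf16_guard=False,
        use_pure_bf16=True)
    # With pure bf16, minimize() should receive the startup program so the
    # decorator can pass it down to cast_model_to_bf16 (see the book tests).
    sgd.minimize(loss, startup_program=startup_prog)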
30 changes: 21 additions & 9 deletions python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
@@ -157,7 +157,8 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
return num_cast_ops

assert target_var.dtype == src_dtype, \
"The real dtype({}) is not equal to the src dtype({})".format(_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))
"The real dtype({}) is not equal to the src dtype({})".format(
_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))

cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype)
cast_var = block.vars.get(cast_name)
@@ -209,19 +210,30 @@ def find_true_prev_op(ops, cur_op, var_name):
return None


def find_true_post_op(ops, cur_op, var_name):
def find_true_post_op(ops, cur_op, var_name, search_all=False):
"""
If there are post ops, return them; if there is no post op,
return None instead.
Args:
ops (list): A list of ops.
cur_op (Operator): Current operator which has var_name variable.
var_name (string): Variable name.
search_all (bool): If True, search the whole "ops" list; use it when "cur_op" is not in "ops".
"""
post_op = []
for idx, op in enumerate(ops):
if op == cur_op:
break
if search_all:
# "cur_op" does not have to be in the "ops" list; e.g. "cur_op" may come
# from the startup_prog block while "ops" comes from the main_prog block.
# Setting idx to -1 starts the post-op search from the top of the list.
# When search_all is False, "cur_op" is assumed to be in "ops", so the
# search can start right after its index to save time.
idx = -1
else:
for idx, op in enumerate(ops):
if op == cur_op:
break

for i in range(idx + 1, len(ops)):
op = ops[i]
@@ -270,7 +282,7 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard):

if use_fp16_guard:
if op.has_attr("op_namescope") and \
(_fp16_guard_pattern in op.attr("op_namescope")):
# op in fp16 guard
return False
else:
@@ -496,8 +508,8 @@ def rewrite_program(main_prog, amp_lists):
black_op_set = set()
for op in ops:

# NOTE(zhiqiu): 'create_py_reader' and 'read' are used in the non-iterable DataLoader;
# we don't need to handle reader ops, and the input of 'create_py_reader' is not
# in block, which may result in errors.
# See GeneratorLoader._init_non_iterable() for details.
if op.type == 'create_py_reader' or op.type == 'read':
@@ -612,7 +624,7 @@ def update_role_var_grad(main_prog, params_grads):
raise ValueError("The cast op {0}'s output should not be "
"used by a non-optimize op, however, it "
"is used by {1}".format(op, post_ops[0]))
# add new op in the python and cpp at the same time
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op = framework.Operator(
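The following is a toy, Paddle-free illustration of the search_all semantics added above: when cur_op is not in the ops list (e.g. an initializer from the startup program), a scan that starts after cur_op finds nothing, while search_all=True scans the whole list. All names here are hypothetical:

class ToyOp:
    def __init__(self, op_type, inputs):
        self.type, self.inputs = op_type, inputs

def toy_find_post_ops(ops, cur_op, var_name, search_all=False):
    if search_all:
        idx = -1                      # scan the whole list
    else:
        idx = len(ops) - 1            # default when cur_op is absent
        for i, op in enumerate(ops):
            if op is cur_op:
                idx = i
                break
    return [op for op in ops[idx + 1:] if var_name in op.inputs]

init_op = ToyOp('fill_constant', [])        # lives in the startup block
main_ops = [ToyOp('abs', ['X']), ToyOp('scale', ['Y'])]

print([op.type for op in toy_find_post_ops(main_ops, init_op, 'X')])                   # []
print([op.type for op in toy_find_post_ops(main_ops, init_op, 'X', search_all=True)])  # ['abs']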
23 changes: 23 additions & 0 deletions python/paddle/fluid/contrib/tests/test_bf16_utils.py
@@ -139,6 +139,29 @@ def test_find_true_post_op(self):
res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y")
assert (res == [op2])

def test_find_true_post_op_with_search_all(self):
program = fluid.Program()
block = program.current_block()
startup_block = fluid.default_startup_program().global_block()

var1 = block.create_var(name="X", shape=[3], dtype='float32')
var2 = block.create_var(name="Y", shape=[3], dtype='float32')
initializer_op = startup_block._prepend_op(
type="fill_constant",
outputs={"Out": var1},
attrs={"shape": var1.shape,
"dtype": var1.dtype,
"value": 1.0})

op1 = block.append_op(
type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
result = amp.bf16.amp_utils.find_true_post_op(
block.ops, inititializer_op, "X", search_all=False)
assert (len(result) == 0)
result = amp.bf16.amp_utils.find_true_post_op(
block.ops, inititializer_op, "X", search_all=True)
assert (result == [op1])


if __name__ == '__main__':
unittest.main()
28 changes: 19 additions & 9 deletions python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py
@@ -53,19 +53,27 @@ def scope_prog_guard(self):
with fluid.program_guard(prog, startup_prog):
yield

def get_static_graph_result(self, feed, fetch_list, amp_fun,
with_lod=False):
def get_static_graph_result(self,
feed,
fetch_list,
amp_fun,
with_lod=False,
startup_prog=None):
exe = fluid.Executor(core.CPUPlace())
exe.run(fluid.default_startup_program())
exe.run(fluid.default_startup_program()
if startup_prog is None else startup_prog)
prog = fluid.default_main_program()
if amp_fun is not None:
amp_fun(prog)
if startup_prog is not None:
amp_fun(prog, startup_prog)
else:
amp_fun(prog)
return exe.run(prog,
feed=feed,
fetch_list=fetch_list,
return_numpy=(not with_lod))

def _graph_common(self, _amp_fun):
def _graph_common(self, _amp_fun, startup_prog=None):
size = 3
n = np.ones([size, size], dtype='float32') * 3.2
nn = np.ones([size, size], dtype='float32') * -2.7
@@ -122,7 +130,8 @@ def _graph_common(self, _amp_fun):
self.get_static_graph_result(
feed={'t': n, 'tt': nn},
fetch_list=[ret],
amp_fun=_amp_fun
amp_fun=_amp_fun,
startup_prog=startup_prog
)
self.assertTrue(
static_ret_bf16, np.ones(
@@ -132,16 +141,17 @@ def test_graph_rewrite(self):
self._graph_common(lambda prog: amp.bf16.rewrite_program_bf16(
prog,
amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_varnames={'elementwise_add_0.tmp_0'}),
custom_fp32_varnames={'elementwise_add_0.tmp_0'})
))

def test_graph_cast(self):
self._graph_common(lambda prog: amp.bf16.cast_model_to_bf16(
self._graph_common(lambda prog, startup_prog: amp.bf16.cast_model_to_bf16(
prog,
startup_prog,
amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_list={'elementwise_mul'}),
use_bf16_guard=True
))
), startup_prog=fluid.default_startup_program())


if __name__ == '__main__':
10 changes: 5 additions & 5 deletions python/paddle/fluid/layers/tensor.py
@@ -231,13 +231,13 @@ def cast(x, dtype):
out = core.ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
return out

check_variable_and_dtype(
x, 'x',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
'cast')
check_variable_and_dtype(x, 'x', [
'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8',
'uint16'
], 'cast')
check_dtype(dtype, 'dtype', [
'bool', 'float16', 'float32', 'float64', 'int8', 'int32', 'int64',
'uint8'
'uint8', 'uint16'
], 'cast')

helper = LayerHelper('cast', **locals())
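Why 'uint16' is added here: in fluid's Python-side dtype checks, bfloat16 tensors are carried under the uint16 dtype string, so cast() must accept it for the pure-bf16 path (assumption based on the bf16 work in this PR; verify against your Paddle build). A minimal graph-construction sketch:

import paddle.fluid as fluid

prog, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(prog, startup):
    x = fluid.data(name='x', shape=[3], dtype='float32')
    x_bf16 = fluid.layers.cast(x, 'uint16')      # fp32 -> bf16 storage dtype
    x_back = fluid.layers.cast(x_bf16, 'float32')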
3 changes: 2 additions & 1 deletion python/paddle/fluid/tests/book/test_fit_a_line.py
@@ -56,7 +56,8 @@ def train(use_cuda, save_dirname, is_local, use_bf16, pure_bf16):
amp_lists=amp.bf16.AutoMixedPrecisionListsBF16(),
use_bf16_guard=False,
use_pure_bf16=pure_bf16)
sgd_optimizer.minimize(avg_cost)
sgd_optimizer.minimize(
avg_cost, startup_program=fluid.default_startup_program())

BATCH_SIZE = 20

2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/book/test_word2vec_book.py
@@ -115,7 +115,7 @@ def __network__(words):
use_bf16_guard=False,
use_pure_bf16=pure_bf16)

sgd_optimizer.minimize(avg_cost)
sgd_optimizer.minimize(avg_cost, fluid.default_startup_program())

train_reader = paddle.batch(
paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)