Commit 3bcac6e

Author: wozna
Correct test and add comment
1 parent 70a25f4 commit 3bcac6e

2 files changed: 30 additions & 14 deletions

python/paddle/fluid/contrib/mixed_precision/fp16_utils.py

Lines changed: 13 additions & 5 deletions
@@ -157,7 +157,8 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
         return num_cast_ops
 
     assert target_var.dtype == src_dtype, \
-        "The real dtype({}) is not equal to the src dtype({})".format(_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))
+        "The real dtype({}) is not equal to the src dtype({})".format(
+            _dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))
 
     cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype)
     cast_var = block.vars.get(cast_name)
@@ -221,6 +222,13 @@ def find_true_post_op(ops, cur_op, var_name, search_all=False):
     """
     post_op = []
     if search_all:
+        """
+        "cur_op" does not have to be in the list of "ops". E.g. "cur_op" can come
+        from the startup_prog block while the "ops" list comes from the main_prog
+        block. By setting idx to -1, we start looking for post ops from the top
+        of the list. If search_all is False, "cur_op" is assumed to be in "ops",
+        so to reduce the search time we can start iterating from "cur_op"'s index.
+        """
         idx = -1
     else:
         for idx, op in enumerate(ops):
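
To make the new docstring concrete, here is a small self-contained sketch of the search strategy it describes. The dict-based op representation and the helper name are illustrative stand-ins for Paddle operators, not the real find_true_post_op:

# Simplified re-implementation of the search strategy described above;
# ops are modeled as plain dicts instead of Paddle operator objects.
def find_post_ops_simplified(ops, cur_op, var_name, search_all=False):
    if search_all:
        # cur_op may live in another block (e.g. the startup program),
        # so scan the whole list from the top.
        idx = -1
    else:
        # cur_op is assumed to be inside `ops`; starting after its own
        # index skips ops that cannot be post ops.
        idx = ops.index(cur_op)
    return [op for op in ops[idx + 1:] if var_name in op['inputs']]

# cur_op comes from a different block: only search_all=True finds the consumers.
startup_op = {'name': 'fill_constant', 'inputs': []}
main_ops = [{'name': 'mul', 'inputs': ['x', 'w']},
            {'name': 'cast', 'inputs': ['w']}]
print(find_post_ops_simplified(main_ops, startup_op, 'w', search_all=True))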
@@ -274,7 +282,7 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard):
 
     if use_fp16_guard:
         if op.has_attr("op_namescope") and \
-            (_fp16_guard_pattern in op.attr("op_namescope")):
+                (_fp16_guard_pattern in op.attr("op_namescope")):
             # op in fp16 guard
             return False
         else:
@@ -500,8 +508,8 @@ def rewrite_program(main_prog, amp_lists):
     black_op_set = set()
     for op in ops:
 
-        # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
-        # we don't need to handle reader op and the input of 'create_py_reader' is not
+        # NOTE(zhiqiu): 'create_py_reader' and 'read' are used in the non-iterable DataLoader;
+        # we don't need to handle reader ops, and the input of 'create_py_reader' is not
         # in block, which may result in errors.
         # See GeneratorLoader._init_non_iterable() for details.
         if op.type == 'create_py_reader' or op.type == 'read':
@@ -616,7 +624,7 @@ def update_role_var_grad(main_prog, params_grads):
                 raise ValueError("The cast op {0}'s output should not be"
                                  "used by a non-optimize op, however, it"
                                  "is used by {1}".format(op, post_ops[0]))
-        #add new op in the python and cpp at the same time
+        # add new op in the python and cpp at the same time
         new_op_desc = block.desc.append_op()
         new_op_desc.copy_from(op.desc)
         new_op = framework.Operator(

python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py

Lines changed: 17 additions & 9 deletions
@@ -53,20 +53,27 @@ def scope_prog_guard(self):
         with fluid.program_guard(prog, startup_prog):
             yield
 
-    def get_static_graph_result(self, feed, fetch_list, amp_fun,
-                                with_lod=False):
+    def get_static_graph_result(self,
+                                feed,
+                                fetch_list,
+                                amp_fun,
+                                with_lod=False,
+                                startup_prog=None):
         exe = fluid.Executor(core.CPUPlace())
-        exe.run(fluid.default_startup_program())
+        exe.run(fluid.default_startup_program()
+                if startup_prog is None else startup_prog)
         prog = fluid.default_main_program()
-        startup_prog = fluid.default_startup_program()
         if amp_fun is not None:
-            amp_fun(prog, startup_prog)
+            if startup_prog is not None:
+                amp_fun(prog, startup_prog)
+            else:
+                amp_fun(prog)
         return exe.run(prog,
                        feed=feed,
                        fetch_list=fetch_list,
                        return_numpy=(not with_lod))
 
-    def _graph_common(self, _amp_fun):
+    def _graph_common(self, _amp_fun, startup_prog=None):
         size = 3
         n = np.ones([size, size], dtype='float32') * 3.2
         nn = np.ones([size, size], dtype='float32') * -2.7
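
For context, the reworked helper now supports two call patterns. The sketch below shows both; the lambdas, feed names, and the my_bf16_pass helper are illustrative placeholders (only get_static_graph_result's new signature comes from this diff):

# Without startup_prog: the default startup program is run, and amp_fun
# receives only the main program, preserving the old behavior.
out = self.get_static_graph_result(
    feed={'t': n, 'tt': nn},
    fetch_list=[ret],
    amp_fun=lambda prog: amp.bf16.rewrite_program_bf16(prog))

# With startup_prog: the given program is executed instead of the default
# one and is also forwarded to amp_fun(prog, startup_prog), for passes
# that must rewrite startup ops too.
out = self.get_static_graph_result(
    feed={'t': n, 'tt': nn},
    fetch_list=[ret],
    amp_fun=lambda prog, sp: my_bf16_pass(prog, sp),  # hypothetical pass
    startup_prog=fluid.default_startup_program())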
@@ -123,7 +130,8 @@ def _graph_common(self, _amp_fun):
             self.get_static_graph_result(
                 feed={'t': n, 'tt': nn},
                 fetch_list=[ret],
-                amp_fun=_amp_fun
+                amp_fun=_amp_fun,
+                startup_prog=startup_prog
             )
         self.assertTrue(
             static_ret_bf16, np.ones(
@@ -133,7 +141,7 @@ def test_graph_rewrite(self):
         self._graph_common(lambda prog: amp.bf16.rewrite_program_bf16(
             prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
-                custom_fp32_varnames={'elementwise_add_0.tmp_0'}),
+                custom_fp32_varnames={'elementwise_add_0.tmp_0'})
         ))
 
     def test_graph_cast(self):
@@ -143,7 +151,7 @@ def test_graph_cast(self):
             amp.bf16.AutoMixedPrecisionListsBF16(
                 custom_fp32_list={'elementwise_mul'}),
             use_bf16_guard=True
-        ))
+        ), startup_prog=fluid.default_startup_program())
 
 
 if __name__ == '__main__':
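
Reassembled, the updated test_graph_cast call now threads a startup program through _graph_common so that amp_fun receives both programs. The cast pass behind the lambda sits outside the visible hunks, so it appears below as a placeholder name:

# Sketch of the full call after this commit (some_bf16_cast_pass is a
# placeholder; the real function name is not shown in this diff).
self._graph_common(
    lambda prog, startup_prog: some_bf16_cast_pass(
        prog,
        startup_prog,
        amp.bf16.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'elementwise_mul'}),
        use_bf16_guard=True),
    startup_prog=fluid.default_startup_program())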
