@@ -157,7 +157,8 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
157157 return num_cast_ops
158158
159159 assert target_var .dtype == src_dtype , \
160- "The real dtype({}) is not equal to the src dtype({})" .format (_dtype_to_str (target_var .dtype ), _dtype_to_str (src_dtype ))
160+ "The real dtype({}) is not equal to the src dtype({})" .format (
161+ _dtype_to_str (target_var .dtype ), _dtype_to_str (src_dtype ))
161162
162163 cast_name = target_var .name + '.cast_' + _dtype_to_str (dest_dtype )
163164 cast_var = block .vars .get (cast_name )
@@ -221,6 +222,13 @@ def find_true_post_op(ops, cur_op, var_name, search_all=False):
221222 """
222223 post_op = []
223224 if search_all :
225+ """
226+ \" cur_op\" do not have to be in list of \" ops\" . E.g. \" cur_op\" can come
227+ from startup_prog block and \" ops\" list from main_prog block.
228+ By setting idx to -1, we'll start looking for post-ops from the top of the list.
229+ If search_all is False, assume that \" cur_op\" is in \" ops\" list,
230+ so to reduce the time of search we can start iterating from \" cur_op\" idx.
231+ """
224232 idx = - 1
225233 else :
226234 for idx , op in enumerate (ops ):
@@ -274,7 +282,7 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard):
274282
275283 if use_fp16_guard :
276284 if op .has_attr ("op_namescope" ) and \
277- (_fp16_guard_pattern in op .attr ("op_namescope" )):
285+ (_fp16_guard_pattern in op .attr ("op_namescope" )):
278286 # op in fp16 guard
279287 return False
280288 else :
@@ -500,8 +508,8 @@ def rewrite_program(main_prog, amp_lists):
500508 black_op_set = set ()
501509 for op in ops :
502510
503- # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
504- # we don't need to handle reader op and the input of 'create_py_reader' is not
511+ # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
512+ # we don't need to handle reader op and the input of 'create_py_reader' is not
505513 # in block, which may result in errors.
506514 # See GeneratorLoader._init_non_iterable() for details.
507515 if op .type == 'create_py_reader' or op .type == 'read' :
@@ -616,7 +624,7 @@ def update_role_var_grad(main_prog, params_grads):
616624 raise ValueError ("The cast op {0}'s output should not be"
617625 "used by a non-optimize op, however, it"
618626 "is used by {1}" .format (op , post_ops [0 ]))
619- #add new op in the python and cpp at the same time
627+ # add new op in the python and cpp at the same time
620628 new_op_desc = block .desc .append_op ()
621629 new_op_desc .copy_from (op .desc )
622630 new_op = framework .Operator (
0 commit comments