@@ -157,7 +157,8 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
157157 return num_cast_ops
158158
159159 assert target_var .dtype == src_dtype , \
160- "The real dtype({}) is not equal to the src dtype({})" .format (_dtype_to_str (target_var .dtype ), _dtype_to_str (src_dtype ))
160+ "The real dtype({}) is not equal to the src dtype({})" .format (
161+ _dtype_to_str (target_var .dtype ), _dtype_to_str (src_dtype ))
161162
162163 cast_name = target_var .name + '.cast_' + _dtype_to_str (dest_dtype )
163164 cast_var = block .vars .get (cast_name )
@@ -209,19 +210,30 @@ def find_true_prev_op(ops, cur_op, var_name):
209210 return None
210211
211212
212- def find_true_post_op (ops , cur_op , var_name ):
213+ def find_true_post_op (ops , cur_op , var_name , search_all = False ):
213214 """
214215 if there are post ops, return them, if there is no post op,
215216 return None instead.
216217 Args:
217218 ops (list): A list of ops.
218219 cur_op (Operator): Current operator which has var_name variable.
219220 var_name (string): Variable name.
221+ search_all (bool): The type of operator search. Use if \" cur_op\" is not in the \" ops\" set.
220222 """
221223 post_op = []
222- for idx , op in enumerate (ops ):
223- if op == cur_op :
224- break
224+ if search_all :
225+ """
226+ \" cur_op\" do not have to be in list of \" ops\" . E.g. \" cur_op\" can come
227+ from startup_prog block and \" ops\" list from main_prog block.
228+ By setting idx to -1, we'll start looking for post-ops from the top of the list.
229+ If search_all is False, assume that \" cur_op\" is in \" ops\" list,
230+ so to reduce the time of search we can start iterating from \" cur_op\" idx.
231+ """
232+ idx = - 1
233+ else :
234+ for idx , op in enumerate (ops ):
235+ if op == cur_op :
236+ break
225237
226238 for i in range (idx + 1 , len (ops )):
227239 op = ops [i ]
@@ -270,7 +282,7 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard):
270282
271283 if use_fp16_guard :
272284 if op .has_attr ("op_namescope" ) and \
273- (_fp16_guard_pattern in op .attr ("op_namescope" )):
285+ (_fp16_guard_pattern in op .attr ("op_namescope" )):
274286 # op in fp16 guard
275287 return False
276288 else :
@@ -496,8 +508,8 @@ def rewrite_program(main_prog, amp_lists):
496508 black_op_set = set ()
497509 for op in ops :
498510
499- # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
500- # we don't need to handle reader op and the input of 'create_py_reader' is not
511+ # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
512+ # we don't need to handle reader op and the input of 'create_py_reader' is not
501513 # in block, which may result in errors.
502514 # See GeneratorLoader._init_non_iterable() for details.
503515 if op .type == 'create_py_reader' or op .type == 'read' :
@@ -612,7 +624,7 @@ def update_role_var_grad(main_prog, params_grads):
612624 raise ValueError ("The cast op {0}'s output should not be"
613625 "used by a non-optimize op, however, it"
614626 "is used by {1}" .format (op , post_ops [0 ]))
615- #add new op in the python and cpp at the same time
627+ # add new op in the python and cpp at the same time
616628 new_op_desc = block .desc .append_op ()
617629 new_op_desc .copy_from (op .desc )
618630 new_op = framework .Operator (
0 commit comments