@@ -2907,13 +2907,31 @@ def apply(self, graph):
29072907 graph , IrGraph
29082908 ), 'graph must be the instance of IrGraph.'
29092909 fake_quant_dequant_ops = []
2910+ remove_fake_quant_ops = []
2911+ observer_out_node_names = []
2912+ for op in graph .all_op_nodes ():
2913+ # collect observer node
2914+ if op .name () == "moving_average_abs_max_scale" :
2915+ observer_out_node_names .append (op .output ("Out" )[0 ])
29102916
29112917 for op in graph .all_op_nodes ():
29122918 if (
29132919 op .name () in _fake_quant_dequant_op_list
29142920 or op .name () == "moving_average_abs_max_scale"
29152921 ):
2916- fake_quant_dequant_ops .append (op )
2922+ var_name = op .input ("X" )[0 ]
2923+ if var_name in observer_out_node_names :
2924+ remove_fake_quant_ops .append (op )
2925+ else :
2926+ fake_quant_dequant_ops .append (op )
2927+
2928+ for _op in remove_fake_quant_ops :
2929+ x_node = graph ._find_node_by_name (_op .inputs , _op .input ("X" )[0 ])
2930+ out_node = graph ._find_node_by_name (
2931+ _op .outputs , _op .output ("Out" )[0 ]
2932+ )
2933+ for next_op_node in out_node .outputs :
2934+ graph .update_input_link (out_node , x_node , next_op_node )
29172935
29182936 for _op in fake_quant_dequant_ops :
29192937 self ._replace_op (graph , _op )
0 commit comments