Skip to content

Commit a85dedf

Browse files
authored
Delete duplicate quant nodes in QAT (#48751)
1 parent 2a31c9d commit a85dedf

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

python/paddle/fluid/contrib/slim/quantization/quantization_pass.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2907,13 +2907,31 @@ def apply(self, graph):
29072907
graph, IrGraph
29082908
), 'graph must be the instance of IrGraph.'
29092909
fake_quant_dequant_ops = []
2910+
remove_fake_quant_ops = []
2911+
observer_out_node_names = []
2912+
for op in graph.all_op_nodes():
2913+
# collect observer node
2914+
if op.name() == "moving_average_abs_max_scale":
2915+
observer_out_node_names.append(op.output("Out")[0])
29102916

29112917
for op in graph.all_op_nodes():
29122918
if (
29132919
op.name() in _fake_quant_dequant_op_list
29142920
or op.name() == "moving_average_abs_max_scale"
29152921
):
2916-
fake_quant_dequant_ops.append(op)
2922+
var_name = op.input("X")[0]
2923+
if var_name in observer_out_node_names:
2924+
remove_fake_quant_ops.append(op)
2925+
else:
2926+
fake_quant_dequant_ops.append(op)
2927+
2928+
for _op in remove_fake_quant_ops:
2929+
x_node = graph._find_node_by_name(_op.inputs, _op.input("X")[0])
2930+
out_node = graph._find_node_by_name(
2931+
_op.outputs, _op.output("Out")[0]
2932+
)
2933+
for next_op_node in out_node.outputs:
2934+
graph.update_input_link(out_node, x_node, next_op_node)
29172935

29182936
for _op in fake_quant_dequant_ops:
29192937
self._replace_op(graph, _op)

0 commit comments

Comments
 (0)