156 changes: 88 additions & 68 deletions python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -14,15 +14,10 @@

import collections
import numpy as np
import six
from ..... import compat as cpt
from .... import core
from .... import Executor
from ....framework import IrGraph
from ....framework import IrNode
from ....framework import Program
from ....initializer import Constant
from ....initializer import NumpyArrayInitializer
from .... import unique_name

__all__ = [
@@ -107,7 +102,6 @@ def __init__(self,
self._window_size = window_size
self._moving_rate = moving_rate

self._need_initialized = collections.OrderedDict()
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._quantizable_grad_ops = [
@@ -127,14 +121,17 @@ def apply(self, graph):
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
self._need_initialized.clear()
#sequential_execution = core.get_pass('sequential_execution_pass')
#sequential_execution.apply(graph.graph)
[Review comment — Contributor] Please remove unused comments.

[Reply — Contributor Author] It may still be needed; testing is in progress. If it turns out to be unnecessary, it will be removed later.

self._is_test = graph.is_test()
# mark the variables which have been dequantized.
dequantized_vars = collections.OrderedDict()
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]

def _transform_forward(graph, op):
for var_node in op.inputs:
if var_node.name() not in op.input_arg_names():
continue
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
else:
@@ -168,6 +165,8 @@ def _transform_forward(graph, op):
def _transform_backward(graph, op):
no_dequanted_input_vars = True
for var_node in op.inputs:
if var_node.name() not in op.input_arg_names():
continue
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
graph.update_input_link(var_node, dequant_var_node, op)
@@ -188,25 +187,7 @@ def _transform_backward(graph, op):
for op in ops:
if op.name() in self._quantizable_grad_ops:
_transform_backward(graph, op)

if len(self._need_initialized) > 0:
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
init_program = Program()
for var_desc, initializer in six.iteritems(self._need_initialized):
var = init_program.global_block().create_var(
name=var_desc.name(),
shape=var_desc.shape(),
dtype=var_desc.dtype(),
type=var_desc.type(),
lod_level=var_desc.lod_level(),
persistable=var_desc.persistable())
initializer(var, init_program.global_block())
exe = Executor(self._place)
exe.run(program=init_program, scope=self._scope)

graph.resolve_hazard()
return graph

def _create_global_step(self, graph):
@@ -222,8 +203,9 @@ def _create_global_step(self, graph):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=core.VarDesc.VarType.INT64)
self._need_initialized[global_step_in.var()] = \
Constant(value=0, force_cpu=True)
self._init_var_node(
global_step_in, np.zeros(
[1], dtype='int64'))
global_step_out = graph.create_var_node_from_desc(
global_step_in.var())
# The attribute of `op_role` is needed by ParallelExecutor.
@@ -300,7 +282,9 @@ def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=var_node.dtype())
self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))

scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
inputs = {'X': var_node, 'InScale': scale_in_node}
@@ -313,7 +297,11 @@ def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[self._window_size],
var_dtype=var_node.dtype())
self._need_initialized[scales_node.var()] = Constant(value=0)
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(
scales_node, np.zeros(
[self._window_size], dtype=data_type))
inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node
attrs = {
@@ -353,7 +341,9 @@ def _insert_quant_moving_average_abs_max_op(self, graph, var_node,
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=var_node.dtype())
self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))

scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
ins = {'X': var_node, 'InScale': scale_in_node}
@@ -364,13 +354,15 @@ def _insert_quant_moving_average_abs_max_op(self, graph, var_node,
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(),
shape=[1])
self._need_initialized[state_in_node.var()] = Constant(value=1)
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(state_in_node, np.ones([1], dtype=data_type))
accum_in_node = graph.create_persistable_node(
name=unique_name.generate('accum'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(),
shape=[1])
self._need_initialized[accum_in_node.var()] = Constant(value=1)
self._init_var_node(accum_in_node, np.ones([1], dtype=data_type))
state_out_node = graph.create_var_node_from_desc(state_in_node.var(
))
accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
@@ -490,6 +482,16 @@ def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes,
graph.link_to(dequant_op_node, dequant_var_node)
return dequant_var_node

def _init_var_node(self, var_node, value):
assert isinstance(
value, np.ndarray), 'The type of value should be numpy array.'
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
tensor = self._scope.var(var_node.name()).get_tensor()
tensor.set(value, self._place)
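
For context on the new helper: rather than building a one-off init Program and running it through an Executor (the approach this PR removes above), _init_var_node writes the value straight into the variable's tensor in the scope. A minimal standalone sketch of that mechanism, assuming a fresh scope and CPU place for illustration:

```python
import numpy as np
from paddle.fluid import core

# Assumed setup for illustration; inside the pass these come from
# self._scope and self._place.
scope = core.Scope()
place = core.CPUPlace()

# Create (or fetch) the variable in the scope and set its tensor
# directly, mirroring _init_var_node above.
tensor = scope.var('quant_scale_in').get_tensor()
tensor.set(np.array([0.001], dtype='float32'), place)
print(np.array(tensor))  # [0.001]
```
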

def _quantized_var_name(self, var_name):
"""
Return quantized variable name for the input `var_name`.
Expand Down Expand Up @@ -592,7 +594,8 @@ def apply(self, graph):
self._weight_bits)
self._restore_var(input_arg_name, quantized_param_v)
else:
scale_v = graph.var_node(op_node.output('OutScale')[0])
scale_v = self._to_node(op_node.outputs,
op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v

ops = graph.all_op_nodes()
@@ -613,32 +616,35 @@
for op_node in ops:
# insert dequant_op after fc/conv; need to rename inputs of the following ops
for var_node in op_node.inputs:
name = var_node.name()
if name in self._op_output_rename_map:
old_in = graph.var_node(name)
new_in = self._op_output_rename_map[name]
if var_node.node in self._op_output_rename_map:
old_in = var_node
new_in = self._op_output_rename_map[var_node.node]
graph.update_input_link(old_in, new_in, op_node)

# remove the unused var node in the graph
self._remove_unused_var_nodes(graph)
graph.resolve_hazard()
return graph

def _remove_fake_quant_and_dequant_op(self, graph, op_node):
k = op_node.output('Out')[0]
v = op_node.input('X')[0]
if v not in self._op_input_rename_map:
self._op_input_rename_map[k] = v
k = self._to_node(op_node.outputs, op_node.output('Out')[0])
v = self._to_node(op_node.inputs, op_node.input('X')[0])
if v.node not in self._op_input_rename_map:
self._op_input_rename_map[k.node] = v
else:
self._op_input_rename_map[k] = self._op_input_rename_map[v]
self._op_input_rename_map[k.node] = self._op_input_rename_map[
v.node]
graph.safe_remove_nodes(op_node)
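
Note how the rename map is updated above so that lookups never have to chain: when the removed op's input was itself produced by a previously removed op, the final surviving node is stored directly. A dict-only sketch of the same invariant (hypothetical standalone code, not part of the pass):

```python
# rename_map[out] always points at the ultimate surviving node.
rename_map = {}

def record_removed_op(in_node, out_node):
    # Reuse in_node's target if it was itself renamed, so every
    # entry resolves with a single lookup.
    rename_map[out_node] = rename_map.get(in_node, in_node)

record_removed_op('a', 'b')  # b -> a
record_removed_op('b', 'c')  # c -> a (not c -> b)
assert rename_map == {'b': 'a', 'c': 'a'}
```
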

def _insert_post_channel_dequant_op(self, graph, op_node):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
for var_node in op_node.inputs:
name = var_node.name()
if name in self._op_input_rename_map:
old_in = graph.var_node(name)
new_in = graph.var_node(self._op_input_rename_map[name])
if name not in op_node.input_arg_names():
continue
if var_node.node in self._op_input_rename_map:
old_in = var_node
new_in = self._op_input_rename_map[var_node.node]
new_in.clear_outputs()
graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name)
@@ -653,28 +659,20 @@ def _insert_post_channel_dequant_op(self, graph, op_node):
assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name]

if len(op_node.outputs) != 1:
if len(op_node.output_arg_names()) != 1:
raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name()))

output_var_node = op_node.outputs[0]
output_var_node = self._to_node(op_node.outputs,
op_node.output_arg_names()[0])
weight_scale_node = graph.create_persistable_node(
name=unique_name.generate('channel_scale'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[channel_scale.shape[0]],
var_dtype=output_var_node.dtype())
init_program = Program()
weight_scale_var = init_program.global_block().create_var(
name=weight_scale_node.name(),
shape=weight_scale_node.shape(),
dtype=weight_scale_node.dtype(),
type=weight_scale_node.type(),
lod_level=weight_scale_node.var().lod_level(),
persistable=weight_scale_node.persistable())
initializer = NumpyArrayInitializer(value=channel_scale)
initializer(weight_scale_var, init_program.global_block())
exe = Executor(self._place)
exe.run(program=init_program, scope=self._scope)
data_type = 'float64' if output_var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(weight_scale_node, channel_scale.astype(data_type))
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(),
@@ -695,16 +693,18 @@ def _insert_post_channel_dequant_op(self, graph, op_node):
graph.link_to(scale_var_node, dequant_op_node)
graph.link_to(weight_scale_node, dequant_op_node)
graph.link_to(dequant_op_node, dequant_var_node)
self._op_output_rename_map[output_var_node.name()] = dequant_var_node
self._op_output_rename_map[output_var_node.node] = dequant_var_node
return dequant_var_node

def _insert_post_dequant_op(self, graph, op_node):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
for var_node in op_node.inputs:
name = var_node.name()
if name in self._op_input_rename_map:
old_in = graph.var_node(name)
new_in = graph.var_node(self._op_input_rename_map[name])
if name not in op_node.input_arg_names():
continue
if var_node.node in self._op_input_rename_map:
old_in = var_node
new_in = self._op_input_rename_map[var_node.node]
new_in.clear_outputs()
graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name)
@@ -720,11 +720,12 @@
assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name]

if len(op_node.outputs) != 1:
if len(op_node.output_arg_names()) != 1:
raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name()))

output_var_node = op_node.outputs[0]
output_var_node = self._to_node(op_node.outputs,
op_node.output_arg_names()[0])
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(),
@@ -742,9 +743,27 @@
graph.link_to(output_var_node, dequant_op_node)
graph.link_to(scale_var_node, dequant_op_node)
graph.link_to(dequant_op_node, dequant_var_node)
self._op_output_rename_map[output_var_node.name()] = dequant_var_node
self._op_output_rename_map[output_var_node.node] = dequant_var_node
return dequant_var_node

def _init_var_node(self, var_node, value):
[Review comment — Contributor] How is this different from the method at line 485?

[Reply — Contributor Author] They are not in the same class, but both need this functionality.


assert isinstance(
value, np.ndarray), 'The type of value should be numpy array.'
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
tensor = self._scope.var(var_node.name()).get_tensor()
tensor.set(value, self._place)
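
As the thread above notes, this helper is duplicated verbatim in both passes because they share no base class. One way to remove the duplication (a hedged suggestion, not what this PR does) would be a module-level function that both classes call with their own scope and place:

```python
def _init_var_node(var_node, value, scope, place):
    # Module-level variant both passes could share (sketch only).
    assert isinstance(
        value, np.ndarray), 'The type of value should be numpy array.'
    assert scope is not None and place is not None, \
        'scope and place must be set before initializing variables.'
    tensor = scope.var(var_node.name()).get_tensor()
    tensor.set(value, place)
```
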

def _to_node(self, nodes, node_name):
[Review comment — Contributor] This method's name is a bit unintuitive... did you mean name_to_node? Or find_node_by_name?

[Reply — Contributor Author] It is a helper function for internal use; this felt a bit more concise.


target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the given set."
return target_node
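
Every call site in this diff follows the same pattern: resolve the IrNode among an op's input or output nodes that carries a given argument name from the op desc, for example:

```python
# As used in the apply() hunk above:
scale_v = self._to_node(op_node.outputs, op_node.output('OutScale')[0])
# And likewise for inputs:
v = self._to_node(op_node.inputs, op_node.input('X')[0])
```
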

def _load_var(self, name):
return np.array(self._scope.find_var(name).get_tensor())

@@ -848,6 +867,7 @@ def apply(self, graph):

# remove the unused var node in the graph
self._remove_unused_var_nodes(graph)
graph.resolve_hazard()
return graph

def _convert_to_int8(self, graph, var_node):
@@ -930,5 +950,5 @@ def apply(self, graph):
for output_node in op_node.outputs:
graph.link_to(dequant_node, output_node)
graph.safe_remove_nodes(op_node)

graph.resolve_hazard()
return graph
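
Taken together, the passes in this file operate on an IrGraph. A hedged end-to-end sketch of applying the transform pass (class and argument names from this module as of this PR's era; the program setup is an assumption for illustration):

```python
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

# Assumed setup: any built main program; a real model would contain
# conv2d/depthwise_conv2d/mul ops for the pass to quantize.
main_program = fluid.default_main_program()
place = fluid.CPUPlace()
scope = fluid.global_scope()

graph = IrGraph(fluid.core.Graph(main_program.desc), for_test=False)
transform_pass = QuantizationTransformPass(
    scope=scope,
    place=place,
    activation_quantize_type='moving_average_abs_max')
graph = transform_pass.apply(graph)
```
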