Fix quantization hang bugs. #16456
Changes from all commits
```diff
@@ -14,15 +14,10 @@
 import collections
 import numpy as np
-import six
 from ..... import compat as cpt
 from .... import core
-from .... import Executor
 from ....framework import IrGraph
 from ....framework import IrNode
-from ....framework import Program
-from ....initializer import Constant
-from ....initializer import NumpyArrayInitializer
 from .... import unique_name

 __all__ = [
```
```diff
@@ -107,7 +102,6 @@ def __init__(self,
         self._window_size = window_size
         self._moving_rate = moving_rate

-        self._need_initialized = collections.OrderedDict()
         self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
         self._conv_ops = ['conv2d', 'depthwise_conv2d']
         self._quantizable_grad_ops = [
```
```diff
@@ -127,14 +121,17 @@ def apply(self, graph):
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        self._need_initialized.clear()
+        #sequential_execution = core.get_pass('sequential_execution_pass')
+        #sequential_execution.apply(graph.graph)
         self._is_test = graph.is_test()
         # marked the variable which has been dequantized.
         dequantized_vars = collections.OrderedDict()
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]

         def _transform_forward(graph, op):
             for var_node in op.inputs:
+                if var_node.name() not in op.input_arg_names():
+                    continue
                 if var_node.name() in dequantized_vars:
                     dequant_var_node = dequantized_vars[var_node.name()]
                 else:
```
```diff
@@ -168,6 +165,8 @@ def _transform_forward(graph, op):
         def _transform_backward(graph, op):
             no_dequanted_input_vars = True
             for var_node in op.inputs:
+                if var_node.name() not in op.input_arg_names():
+                    continue
                 if var_node.name() in dequantized_vars:
                     dequant_var_node = dequantized_vars[var_node.name()]
                     graph.update_input_link(var_node, dequant_var_node, op)
```
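Both hunks add the same guard: in an `IrGraph`, `op.inputs` holds every incoming variable node of the op node, which after graph rewrites is not necessarily the same set as the operator's declared input arguments. Below is a minimal stand-alone sketch of the pattern; the `Var` and `Op` stubs are hypothetical stand-ins for Paddle's real IR node types, not Paddle APIs:

```python
# Sketch of the input-filtering guard added above, using stub classes.
class Var:
    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name


class Op:
    def __init__(self, inputs, input_arg_names):
        self.inputs = inputs                  # every in-edge var node in the graph
        self._input_arg_names = input_arg_names

    def input_arg_names(self):
        return self._input_arg_names          # only the op's declared input arguments


def quantizable_inputs(op):
    # Skip in-edge nodes (e.g. dependency-only edges) that are not
    # declared input arguments of the operator.
    return [v for v in op.inputs if v.name() in op.input_arg_names()]


op = Op([Var('x'), Var('dep_var')], input_arg_names=['x'])
assert [v.name() for v in quantizable_inputs(op)] == ['x']
```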
```diff
@@ -188,25 +187,7 @@ def _transform_backward(graph, op):
         for op in ops:
             if op.name() in self._quantizable_grad_ops:
                 _transform_backward(graph, op)
-
-        if len(self._need_initialized) > 0:
-            assert self._scope is not None, \
-                'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
-            assert self._place is not None, \
-                'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
-            init_program = Program()
-            for var_desc, initializer in six.iteritems(self._need_initialized):
-                var = init_program.global_block().create_var(
-                    name=var_desc.name(),
-                    shape=var_desc.shape(),
-                    dtype=var_desc.dtype(),
-                    type=var_desc.type(),
-                    lod_level=var_desc.lod_level(),
-                    persistable=var_desc.persistable())
-                initializer(var, init_program.global_block())
-            exe = Executor(self._place)
-            exe.run(program=init_program, scope=self._scope)
+        graph.resolve_hazard()
         return graph

     def _create_global_step(self, graph):
```
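The deleted block collected initializers in `self._need_initialized`, built a throwaway `Program`, and ran it with an `Executor` at the end of `apply`. The PR replaces this with immediate per-variable initialization (see `_init_var_node` in the later hunks), which writes a numpy array straight into the variable's tensor in the scope. A hedged sketch of the new pattern in isolation, assuming the `paddle.fluid` 1.x API that the diff targets; the variable name is illustrative:

```python
import numpy as np
import paddle.fluid as fluid


def init_var(scope, place, name, value):
    # Write a numpy array directly into the variable's tensor in `scope`,
    # instead of building an init Program and running it with an Executor.
    tensor = scope.var(name).get_tensor()
    tensor.set(value, place)


# Usage sketch: initialize a quantization scale variable to 0.001 on CPU.
scope = fluid.global_scope()
place = fluid.CPUPlace()
init_var(scope, place, 'quant_scale', np.array([0.001], dtype='float32'))
```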
```diff
@@ -222,8 +203,9 @@ def _create_global_step(self, graph):
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[1],
                 var_dtype=core.VarDesc.VarType.INT64)
-            self._need_initialized[global_step_in.var()] = \
-                Constant(value=0, force_cpu=True)
+            self._init_var_node(
+                global_step_in, np.zeros(
+                    [1], dtype='int64'))
             global_step_out = graph.create_var_node_from_desc(
                 global_step_in.var())
             # The attribute of `op_role` is needed by ParallelExecutor.
```
```diff
@@ -300,7 +282,9 @@ def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=var_node.dtype())
-        self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))

         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
         inputs = {'X': var_node, 'InScale': scale_in_node}
```
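The two-line dtype expression introduced here reappears before almost every `_init_var_node` call in the diff; the rule is simply that FP64 variables get a `float64` initial value and everything else gets `float32`. If one wanted to factor it out, a helper could look like the following (a refactoring sketch, not part of the PR):

```python
from paddle.fluid import core


def np_dtype_for(var_node):
    # Choose the numpy dtype for an initial-value array from the var
    # node's dtype: float64 for FP64 variables, float32 otherwise.
    if var_node.dtype() == core.VarDesc.VarType.FP64:
        return 'float64'
    return 'float32'
```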
```diff
@@ -313,7 +297,11 @@ def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[self._window_size],
                 var_dtype=var_node.dtype())
-            self._need_initialized[scales_node.var()] = Constant(value=0)
+            data_type = 'float64' if var_node.dtype(
+            ) == core.VarDesc.VarType.FP64 else 'float32'
+            self._init_var_node(
+                scales_node, np.zeros(
+                    [self._window_size], dtype=data_type))
             inputs['Iter'] = self._global_step
             outputs['OutScales'] = scales_node
         attrs = {
```
```diff
@@ -353,7 +341,9 @@ def _insert_quant_moving_average_abs_max_op(self, graph, var_node,
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=var_node.dtype())
-        self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))

         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
         ins = {'X': var_node, 'InScale': scale_in_node}
```
```diff
@@ -364,13 +354,15 @@ def _insert_quant_moving_average_abs_max_op(self, graph, var_node,
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 var_dtype=var_node.dtype(),
                 shape=[1])
-            self._need_initialized[state_in_node.var()] = Constant(value=1)
+            data_type = 'float64' if var_node.dtype(
+            ) == core.VarDesc.VarType.FP64 else 'float32'
+            self._init_var_node(state_in_node, np.ones([1], dtype=data_type))
             accum_in_node = graph.create_persistable_node(
                 name=unique_name.generate('accum'),
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 var_dtype=var_node.dtype(),
                 shape=[1])
-            self._need_initialized[accum_in_node.var()] = Constant(value=1)
+            self._init_var_node(accum_in_node, np.ones([1], dtype=data_type))
             state_out_node = graph.create_var_node_from_desc(state_in_node.var(
             ))
             accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
```
```diff
@@ -490,6 +482,16 @@ def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes,
         graph.link_to(dequant_op_node, dequant_var_node)
         return dequant_var_node

+    def _init_var_node(self, var_node, value):
+        assert isinstance(
+            value, np.ndarray), 'The type of value should be numpy array.'
+        assert self._scope is not None, \
+            'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
+        assert self._place is not None, \
+            'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
+        tensor = self._scope.var(var_node.name()).get_tensor()
+        tensor.set(value, self._place)
+
     def _quantized_var_name(self, var_name):
         """
         Return quantized variable name for the input `var_name`.
```
```diff
@@ -592,7 +594,8 @@ def apply(self, graph):
                                                 self._weight_bits)
                     self._restore_var(input_arg_name, quantized_param_v)
                 else:
-                    scale_v = graph.var_node(op_node.output('OutScale')[0])
+                    scale_v = self._to_node(op_node.outputs,
+                                            op_node.output('OutScale')[0])
                     self._var_scale_map[input_arg_name] = scale_v

         ops = graph.all_op_nodes()
```
```diff
@@ -613,32 +616,35 @@ def apply(self, graph):
         for op_node in ops:
             # insert dequant_op after fc/conv, need to rename inputs of the followed ops
             for var_node in op_node.inputs:
-                name = var_node.name()
-                if name in self._op_output_rename_map:
-                    old_in = graph.var_node(name)
-                    new_in = self._op_output_rename_map[name]
+                if var_node.node in self._op_output_rename_map:
+                    old_in = var_node
+                    new_in = self._op_output_rename_map[var_node.node]
                     graph.update_input_link(old_in, new_in, op_node)

         # remove the unused var node in the graph
         self._remove_unused_var_nodes(graph)
+        graph.resolve_hazard()
         return graph

     def _remove_fake_quant_and_dequant_op(self, graph, op_node):
-        k = op_node.output('Out')[0]
-        v = op_node.input('X')[0]
-        if v not in self._op_input_rename_map:
-            self._op_input_rename_map[k] = v
+        k = self._to_node(op_node.outputs, op_node.output('Out')[0])
+        v = self._to_node(op_node.inputs, op_node.input('X')[0])
+        if v.node not in self._op_input_rename_map:
+            self._op_input_rename_map[k.node] = v
         else:
-            self._op_input_rename_map[k] = self._op_input_rename_map[v]
+            self._op_input_rename_map[k.node] = self._op_input_rename_map[
+                v.node]
         graph.safe_remove_nodes(op_node)

     def _insert_post_channel_dequant_op(self, graph, op_node):
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         for var_node in op_node.inputs:
             name = var_node.name()
-            if name in self._op_input_rename_map:
-                old_in = graph.var_node(name)
-                new_in = graph.var_node(self._op_input_rename_map[name])
+            if name not in op_node.input_arg_names():
+                continue
+            if var_node.node in self._op_input_rename_map:
+                old_in = var_node
+                new_in = self._op_input_rename_map[var_node.node]
+                new_in.clear_outputs()
                 graph.update_input_link(old_in, new_in, op_node)
             original_var_name = self._original_var_name(name)
```
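The key change in this hunk is that `_op_input_rename_map` and `_op_output_rename_map` are now keyed by the underlying `node` object instead of the variable name, and nodes are taken straight from `op_node.inputs` rather than looked up with `graph.var_node(name)`. After fake-quant/dequant ops are inserted, several graph nodes can share one variable name, so a name lookup may hit the wrong node; keying by node removes the ambiguity. A stand-alone sketch of the failure mode, with a hypothetical stub in place of Paddle's `IrNode`:

```python
# Why name keys are ambiguous in an IR graph: two distinct nodes may
# share one variable name. `IrNodeStub` is a hypothetical stand-in.
class IrNodeStub:
    def __init__(self, name, node_id):
        self._name = name
        self.node = node_id   # stands in for the unique underlying graph node

    def name(self):
        return self._name


a = IrNodeStub('conv2d.tmp_0', node_id=1)
b = IrNodeStub('conv2d.tmp_0', node_id=2)  # same name, different node

by_name = {a.name(): 'dequant_a'}
by_name[b.name()] = 'dequant_b'            # silently overwrites a's entry
assert len(by_name) == 1

by_node = {a.node: 'dequant_a', b.node: 'dequant_b'}
assert len(by_node) == 2                   # each node keeps its own entry
```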
```diff
@@ -653,28 +659,20 @@ def _insert_post_channel_dequant_op(self, graph, op_node):
         assert isinstance(scale_v, IrNode)
         scale_var_node = self._var_scale_map[original_var_name]

-        if len(op_node.outputs) != 1:
+        if len(op_node.output_arg_names()) != 1:
             raise ValueError("Only support one output, but op %s has"
                              " more than one output." % (op_node.name()))

-        output_var_node = op_node.outputs[0]
+        output_var_node = self._to_node(op_node.outputs,
+                                        op_node.output_arg_names()[0])
         weight_scale_node = graph.create_persistable_node(
             name=unique_name.generate('channel_scale'),
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[channel_scale.shape[0]],
             var_dtype=output_var_node.dtype())
-        init_program = Program()
-        weight_scale_var = init_program.global_block().create_var(
-            name=weight_scale_node.name(),
-            shape=weight_scale_node.shape(),
-            dtype=weight_scale_node.dtype(),
-            type=weight_scale_node.type(),
-            lod_level=weight_scale_node.var().lod_level(),
-            persistable=weight_scale_node.persistable())
-        initializer = NumpyArrayInitializer(value=channel_scale)
-        initializer(weight_scale_var, init_program.global_block())
-        exe = Executor(self._place)
-        exe.run(program=init_program, scope=self._scope)
+        data_type = 'float64' if output_var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(weight_scale_node, channel_scale.astype(data_type))
         dequant_var_node = graph.create_var_node(
             name=self._dequantized_var_name(output_var_node.name()),
             var_type=output_var_node.type(),
```
```diff
@@ -695,16 +693,18 @@ def _insert_post_channel_dequant_op(self, graph, op_node):
         graph.link_to(scale_var_node, dequant_op_node)
         graph.link_to(weight_scale_node, dequant_op_node)
         graph.link_to(dequant_op_node, dequant_var_node)
-        self._op_output_rename_map[output_var_node.name()] = dequant_var_node
+        self._op_output_rename_map[output_var_node.node] = dequant_var_node
        return dequant_var_node

     def _insert_post_dequant_op(self, graph, op_node):
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         for var_node in op_node.inputs:
             name = var_node.name()
-            if name in self._op_input_rename_map:
-                old_in = graph.var_node(name)
-                new_in = graph.var_node(self._op_input_rename_map[name])
+            if name not in op_node.input_arg_names():
+                continue
+            if var_node.node in self._op_input_rename_map:
+                old_in = var_node
+                new_in = self._op_input_rename_map[var_node.node]
+                new_in.clear_outputs()
                 graph.update_input_link(old_in, new_in, op_node)
             original_var_name = self._original_var_name(name)
```
```diff
@@ -720,11 +720,12 @@ def _insert_post_dequant_op(self, graph, op_node):
         assert isinstance(scale_v, IrNode)
         scale_var_node = self._var_scale_map[original_var_name]

-        if len(op_node.outputs) != 1:
+        if len(op_node.output_arg_names()) != 1:
             raise ValueError("Only support one output, but op %s has"
                              " more than one output." % (op_node.name()))

-        output_var_node = op_node.outputs[0]
+        output_var_node = self._to_node(op_node.outputs,
+                                        op_node.output_arg_names()[0])
         dequant_var_node = graph.create_var_node(
             name=self._dequantized_var_name(output_var_node.name()),
             var_type=output_var_node.type(),
```
```diff
@@ -742,9 +743,27 @@ def _insert_post_dequant_op(self, graph, op_node):
         graph.link_to(output_var_node, dequant_op_node)
         graph.link_to(scale_var_node, dequant_op_node)
         graph.link_to(dequant_op_node, dequant_var_node)
-        self._op_output_rename_map[output_var_node.name()] = dequant_var_node
+        self._op_output_rename_map[output_var_node.node] = dequant_var_node
         return dequant_var_node

+    def _init_var_node(self, var_node, value):
```

> **Contributor:** How is this different from the method at line 485?
>
> **Contributor (author):** They are not in the same class, but both classes need this functionality.

```diff
+        assert isinstance(
+            value, np.ndarray), 'The type of value should be numpy array.'
+        assert self._scope is not None, \
+            'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
+        assert self._place is not None, \
+            'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
+        tensor = self._scope.var(var_node.name()).get_tensor()
+        tensor.set(value, self._place)
+
+    def _to_node(self, nodes, node_name):
```

> **Contributor:** The name of this method is a bit unintuitive. Did you mean `name_to_node`, or perhaps `find_node_by_name`?
>
> **Contributor (author):** It is a helper function for internal use, so this felt a bit more concise.

```diff
+        target_node = None
+        for n in nodes:
+            if n.name() == node_name:
+                target_node = n
+        assert target_node is not None, "Cannot find the target node in the given set."
+        return target_node

     def _load_var(self, name):
         return np.array(self._scope.find_var(name).get_tensor())
```
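Following up on the naming discussion above: `_to_node` scans the entire list and keeps the last match. An equivalent helper under the reviewer's suggested name, returning on the first hit, might look like this (a sketch, not the PR's code):

```python
def find_node_by_name(nodes, node_name):
    # Return the first node whose name matches; assumes a name identifies
    # a unique node within the given collection.
    for n in nodes:
        if n.name() == node_name:
            return n
    raise ValueError("Cannot find node '%s' in the given set." % node_name)
```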
```diff
@@ -848,6 +867,7 @@ def apply(self, graph):

         # remove the unused var node in the graph
         self._remove_unused_var_nodes(graph)
+        graph.resolve_hazard()
         return graph

     def _convert_to_int8(self, graph, var_node):
```
```diff
@@ -930,5 +950,5 @@ def apply(self, graph):
                 for output_node in op_node.outputs:
                     graph.link_to(dequant_node, output_node)
                 graph.safe_remove_nodes(op_node)
-
+        graph.resolve_hazard()
         return graph
```
> **Reviewer:** Please remove unused comments (the commented-out `sequential_execution` lines above).
>
> **Author:** They may still be needed; testing is in progress. If they turn out to be unnecessary, I will remove them later.