@@ -62,9 +62,8 @@ def __init__(self,
         self._ops_to_quantize = _ops_to_quantize
         self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
             [-1])
-        self._scale_immutable_ops = [
-            'transpose2', 'reshape2', 'pool2d', 'scale'
-        ]
+        self._scale_immutable_ops = ['transpose2', 'reshape2', 'pool2d']
+        self._scale_ops = ['scale']
         self._conv_ops = ['conv2d', 'depthwise_conv2d']
         self._pool_ops = ['pool2d']
         self._mul_ops = ['mul']
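
Note: 'scale' is moved out of _scale_immutable_ops into its own _scale_ops list because a scale op changes the dynamic range of its tensor, so its input and output quantization scales generally differ, unlike transpose2/reshape2/pool2d, whose scales can simply be reused. A minimal standalone sketch of that difference (illustration only; the reciprocal-of-max scale convention is an assumption, not taken from this pass):

    import numpy as np

    x = np.array([-2.0, 0.5, 1.5])
    y = 3.0 * x                      # what a 'scale' op with scale=3.0 computes (bias ignored)
    # reshape/transpose leave max(|x|) unchanged, so the same scale can be propagated;
    # a scale op multiplies it, so the scale must be adjusted rather than copied.
    print(1.0 / np.abs(x).max())     # 0.5
    print(1.0 / np.abs(y).max())     # ~0.1667
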
@@ -87,8 +86,8 @@ def apply(self, graph):
         self._reset_pass_idx_and_group('int8')
         graph = self._label_skip_quantized_op(graph)
         graph = self._gather_weight_thresholds_from_fake(graph)
-        graph = self._gather_output_scales_from_attr(graph)
         graph = self._gather_input_scales_from_fake(graph)
+        graph = self._gather_output_scales_from_attr(graph)
         graph = self._remove_fake_ops(graph)
         graph = self._dequantize_weights(graph)
         graph = self._optimize_fp32_graph(graph)
@@ -160,12 +159,16 @@ def _label_skip_quantized_op(self, graph):
                     op_node.op()._set_attr("skip_quant", True)
         return graph
 
-    def _gather_input_scales_from_fake(self, graph):
-        def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
-            scales = self._var_quant_scales
-            for var_name in var_names:
+    def _add_scale_for_vars(self, var_names, use_unsigned_int, lod_tensor):
+        """
+        Save quantization scales for variables. Do not overwrite.
+        """
+        scales = self._var_quant_scales
+        for var_name in var_names:
+            if var_name not in scales:
                 scales[var_name] = (use_unsigned_int, lod_tensor)
 
+    def _gather_input_scales_from_fake(self, graph):
         # fake_quantize_dequantize_abs_max doesn't have scale value
         fake_ops = ['fake_quantize_dequantize_moving_average_abs_max']
         fake_ops.extend(self._fake_quantize_types)
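
With _add_scale_for_vars promoted to a method that never overwrites an existing entry, the first pass to register a scale for a variable wins; together with the reordering in apply() above, scales gathered from fake quantize ops take precedence over those later derived from out_threshold attributes. A minimal dict-based mimic of that behaviour (standalone sketch, not the real pass; the variable name is hypothetical):

    scales = {}

    def add_scale_for_vars(var_names, use_unsigned_int, tensor):
        # first registration wins; later calls for the same variable are ignored
        for name in var_names:
            if name not in scales:
                scales[name] = (use_unsigned_int, tensor)

    add_scale_for_vars(["conv2d_0.tmp_0"], False, "scale_from_fake_op")
    add_scale_for_vars(["conv2d_0.tmp_0"], False, "scale_from_out_threshold")
    print(scales["conv2d_0.tmp_0"])  # (False, 'scale_from_fake_op')
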
@@ -185,8 +188,8 @@ def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
                 scale[scale == np.Inf] = 0.0
                 lod_tensor = self._convert_scale2tensor(scale)
                 use_unsigned_int = False
-                _add_scale_for_vars([input_name, output_name], use_unsigned_int,
-                                    lod_tensor)
+                self._add_scale_for_vars([input_name, output_name],
+                                         use_unsigned_int, lod_tensor)
 
         return graph
 
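For context, the scales stored by this function come from recorded saturation thresholds, and the Inf clamp in the context lines above suggests they are stored as reciprocals, with a zero threshold mapped to a zero scale. A short standalone illustration of that conversion (the reciprocal convention and the scalar thresholds are assumptions; the LoDTensor conversion via _convert_scale2tensor is omitted):

    import numpy as np

    thresholds = np.array([2.0, 0.0], dtype=np.float64)  # hypothetical recorded thresholds
    with np.errstate(divide='ignore'):
        scales = 1.0 / thresholds                         # reciprocal-of-threshold convention (assumed)
    scales[scales == np.inf] = 0.0                        # zero threshold -> scale 0, i.e. "no scale"
    print(scales)                                         # [0.5 0. ]
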
@@ -219,8 +222,8 @@ def _gather_output_scales_from_attr(self, graph):
                 use_unsigned_int = False
                 for output_name in op.op().outputs():
                     for out_var_name in op.op().output(output_name):
-                        self._var_quant_scales[out_var_name] = (
-                            use_unsigned_int, scale_lod_tensor)
+                        self._add_scale_for_vars(
+                            [out_var_name], use_unsigned_int, scale_lod_tensor)
 
         return graph
 
@@ -239,24 +242,21 @@ def _update_scales(graph):
                     output_name = op.output("Out")[0]
                     tensor_names = [input_name, output_name]
 
-                    # Scale is not quantized, so if it doesn't have any scales
-                    # to propagate, its tensors won't be added to the waiting list.
-                    if all(name not in self._var_quant_scales for name in tensor_names) \
-                        and op.name() != 'scale':
+                    if all(name not in self._var_quant_scales
+                           for name in tensor_names):
                         waiting_for_scale.update(tensor_names)
                         continue
-
-                    if input_name in self._var_quant_scales:
+                    elif input_name in self._var_quant_scales:
                         self._var_quant_scales[
                             output_name] = self._var_quant_scales[input_name]
                     elif output_name in self._var_quant_scales:
-                        if op.name() == 'scale':
-                            _update_scale_op_in_scale(op, input_name,
-                                                      output_name)
-                        else:
-                            self._var_quant_scales[
-                                input_name] = self._var_quant_scales[
-                                    output_name]
+                        self._var_quant_scales[
+                            input_name] = self._var_quant_scales[output_name]
+                elif op.name() in self._scale_ops:
+                    input_name = op.input("X")[0]
+                    output_name = op.output("Out")[0]
+                    if output_name in self._var_quant_scales:
+                        _update_scale_op_in_scale(op, input_name, output_name)
             return waiting_for_scale
 
         waiting_for_scale = _update_scales(graph)
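
The new elif op.name() in self._scale_ops branch stops treating scale ops as scale-immutable: their input scale is only derived when the output scale is already known, via _update_scale_op_in_scale (defined earlier in this file, not shown in the diff). A hedged sketch of what that back-propagation amounts to, assuming the op computes out = scale_attr * x and scales are stored as reciprocals of the tensor range:

    import numpy as np

    def update_scale_op_in_scale(var_scales, scale_attr, input_name, output_name):
        # stand-in for _update_scale_op_in_scale (assumed behaviour):
        # max(|out|) = scale_attr * max(|in|), so with reciprocal scales
        # input_scale = output_scale * scale_attr
        unsigned, out_scale = var_scales[output_name]
        var_scales[input_name] = (unsigned, out_scale * scale_attr)

    var_scales = {"scale_0.out": (False, np.array([0.25]))}  # hypothetical names and values
    update_scale_op_in_scale(var_scales, 2.0, "scale_0.in", "scale_0.out")
    print(var_scales["scale_0.in"])  # (False, array([0.5]))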