@@ -62,9 +62,8 @@ def __init__(self,
6262 self ._ops_to_quantize = _ops_to_quantize
6363 self ._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set (
6464 [- 1 ])
65- self ._scale_immutable_ops = [
66- 'transpose2' , 'reshape2' , 'pool2d' , 'scale'
67- ]
65+ self ._scale_immutable_ops = ['transpose2' , 'reshape2' , 'pool2d' ]
66+ self ._scale_ops = ['scale' ]
6867 self ._conv_ops = ['conv2d' , 'depthwise_conv2d' ]
6968 self ._pool_ops = ['pool2d' ]
7069 self ._mul_ops = ['mul' ]
@@ -87,8 +86,8 @@ def apply(self, graph):
8786 self ._reset_pass_idx_and_group ('int8' )
8887 graph = self ._label_skip_quantized_op (graph )
8988 graph = self ._gather_weight_thresholds_from_fake (graph )
90- graph = self ._gather_output_scales_from_attr (graph )
9189 graph = self ._gather_input_scales_from_fake (graph )
90+ graph = self ._gather_output_scales_from_attr (graph )
9291 graph = self ._remove_fake_ops (graph )
9392 graph = self ._dequantize_weights (graph )
9493 graph = self ._optimize_fp32_graph (graph )
@@ -160,12 +159,16 @@ def _label_skip_quantized_op(self, graph):
160159 op_node .op ()._set_attr ("skip_quant" , True )
161160 return graph
162161
163- def _gather_input_scales_from_fake (self , graph ):
164- def _add_scale_for_vars (var_names , use_unsigned_int , lod_tensor ):
165- scales = self ._var_quant_scales
166- for var_name in var_names :
162+ def _add_scale_for_vars (self , var_names , use_unsigned_int , lod_tensor ):
163+ """
164+ Save quantization scales for variables. Do not overwrite.
165+ """
166+ scales = self ._var_quant_scales
167+ for var_name in var_names :
168+ if var_name not in scales :
167169 scales [var_name ] = (use_unsigned_int , lod_tensor )
168170
171+ def _gather_input_scales_from_fake (self , graph ):
169172 # fake_quantize_dequantize_abs_max doesn't have scale value
170173 fake_ops = ['fake_quantize_dequantize_moving_average_abs_max' ]
171174 fake_ops .extend (self ._fake_quantize_types )
@@ -185,8 +188,8 @@ def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
185188 scale [scale == np .Inf ] = 0.0
186189 lod_tensor = self ._convert_scale2tensor (scale )
187190 use_unsigned_int = False
188- _add_scale_for_vars ([input_name , output_name ], use_unsigned_int ,
189- lod_tensor )
191+ self . _add_scale_for_vars ([input_name , output_name ],
192+ use_unsigned_int , lod_tensor )
190193
191194 return graph
192195
@@ -219,8 +222,8 @@ def _gather_output_scales_from_attr(self, graph):
219222 use_unsigned_int = False
220223 for output_name in op .op ().outputs ():
221224 for out_var_name in op .op ().output (output_name ):
222- self ._var_quant_scales [ out_var_name ] = (
223- use_unsigned_int , scale_lod_tensor )
225+ self ._add_scale_for_vars (
226+ [ out_var_name ], use_unsigned_int , scale_lod_tensor )
224227
225228 return graph
226229
@@ -239,24 +242,21 @@ def _update_scales(graph):
239242 output_name = op .output ("Out" )[0 ]
240243 tensor_names = [input_name , output_name ]
241244
242- # Scale is not quantized, so if it doesn't have any scales
243- # to propagate, its tensors won't be added to the waiting list.
244- if all (name not in self ._var_quant_scales for name in tensor_names ) \
245- and op .name () != 'scale' :
245+ if all (name not in self ._var_quant_scales
246+ for name in tensor_names ):
246247 waiting_for_scale .update (tensor_names )
247248 continue
248-
249- if input_name in self ._var_quant_scales :
249+ elif input_name in self ._var_quant_scales :
250250 self ._var_quant_scales [
251251 output_name ] = self ._var_quant_scales [input_name ]
252252 elif output_name in self ._var_quant_scales :
253- if op . name () == 'scale' :
254- _update_scale_op_in_scale ( op , input_name ,
255- output_name )
256- else :
257- self . _var_quant_scales [
258- input_name ] = self ._var_quant_scales [
259- output_name ]
253+ self . _var_quant_scales [
254+ input_name ] = self . _var_quant_scales [ output_name ]
255+ elif op . name () in self . _scale_ops :
256+ input_name = op . input ( "X" )[ 0 ]
257+ output_name = op . output ( "Out" )[ 0 ]
258+ if output_name in self ._var_quant_scales :
259+ _update_scale_op_in_scale ( op , input_name , output_name )
260260 return waiting_for_scale
261261
262262 waiting_for_scale = _update_scales (graph )