 import numpy as np
 from .... import core
 from ....framework import Program, Operator, Variable, program_guard
+from ....executor import global_scope
 from .... import unique_name
 from ....layer_helper import LayerHelper
 from ....param_attr import ParamAttr
@@ -34,7 +35,9 @@ def __init__(self,
                  weight_quantize_type='abs_max',
                  activation_quantize_type='moving_average_abs_max',
                  quantizable_op_type=[
-                     'conv2d', 'depthwise_conv2d', 'mul', 'matmul'
+                     'conv2d',
+                     'depthwise_conv2d',
+                     'mul',
                  ],
                  skip_pattern=['skip_quant']):
         """
@@ -64,6 +67,9 @@ def __init__(self,
         self._activation_quantize_type = activation_quantize_type
         self._weight_quantize_type = weight_quantize_type
 
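+        # only conv2d, depthwise_conv2d and mul are supported by this pass for now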
+        for op_type in quantizable_op_type:
+            assert op_type in ['conv2d', 'depthwise_conv2d', 'mul'], \
+                "Quantize op should be ['conv2d', 'depthwise_conv2d', 'mul']"
         self._quantizable_ops = quantizable_op_type
         self._quantizable_grad_ops = [
             '%s_grad' % (op) for op in self._quantizable_ops
@@ -110,32 +116,41 @@ def apply(self, program, startup_program, is_test=False):
                         (not self._is_skip_quant(op)):
                         self._transform_backward(block, op, var_rename_map)
 
-    def _is_skip_quant(self, op):
+    def convert(self, test_program, scope=None):
         """
-        Analyse whether the op should skip quantization or not.
+        Convert the test program.
+        Args:
+            test_program(Program): the test program to be converted.
+            scope(fluid.Scope, optional): the scope used to load and save
+                variables. If scope is None, global_scope() is used.
         """
-        user_skipped = False
-        if isinstance(self._skip_pattern, list):
-            user_skipped = op.has_attr("op_namescope") and \
-                any(pattern in op.attr("op_namescope") \
-                    for pattern in self._skip_pattern)
-        elif isinstance(self._skip_pattern, str):
-            user_skipped = op.has_attr("op_namescope") and \
-                op.attr("op_namescope").find(
-                    self._skip_pattern) != -1
-        return user_skipped
+        scope = global_scope() if scope is None else scope
+
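+        # collect all moving_average_abs_max_scale ops inserted into the test program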
+        target_ops = []
+        for block in test_program.blocks:
+            for op in block.ops:
+                if op.type == "moving_average_abs_max_scale":
+                    target_ops.append(op)
+
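+        # the calibrated scale of each target op lives in the scope under the
+        # name of its OutScale output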
+        for op in target_ops:
+            out_scale_name = op.output("OutScale")
+            # TODO: save the out threshold in the target ops
+            # print(out_scale_name)
+            # print(self._load_variable_data(scope, out_scale_name[0]))
 
     def _transform_forward(self, block, op, var_rename_map, is_test):
         """
         Insert fake quant op before the target ops.
         """
         op._set_attr("quantization_type", "qat_with_weight")
-        idx = block.ops.index(op)
-        block_id = block.idx
 
+        # insert fake quant op before the quantized op
         for in_name in op.input_arg_names:
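+            # the index is looked up inside the loop, because inserting a fake
+            # quant op shifts the position of the quantized op in the block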
+            block_id = block.idx
+            idx = block.ops.index(op)
+
             if in_name in var_rename_map[block_id]:
-                new_var_name = var_rename_map[block_id][in_name]
+                new_in_name = var_rename_map[block_id][in_name]
             else:
                 in_var = block.var(in_name)
                 if in_var.dtype != core.VarDesc.VarType.FP32:
@@ -161,13 +176,62 @@ def _transform_forward(self, block, op, var_rename_map, is_test):
                     quant_type)
                     continue
 
-                var_rename_map[block_id][in_name] = new_var.name
-                op._rename_input(in_name, new_var.name)
+                new_in_name = new_var.name
+                var_rename_map[block_id][in_name] = new_in_name
+
+            op._rename_input(in_name, new_in_name)
+
+        # insert the out scale op after the quantized op
+        for out_name in op.output_arg_names:
+            next_ops = self._find_next_ops(block, out_name)
+
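+            # look up the op index again, since earlier iterations may have
+            # inserted new ops and shifted it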
+            idx = block.ops.index(op)
+            out_var = block.var(out_name)
+            new_out_var = self._insert_ma_abs_max_scale_op(
+                block, idx + 1, out_var, is_test, True)
+
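+            # redirect the forward consumers to the scaled output; grad ops
+            # keep reading the original variable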
+            for next_op in next_ops:
+                if "_grad" not in next_op.type:
+                    next_op._rename_input(out_name, new_out_var.name)
+
+    def _find_next_ops(self, block, var_name):
+        """
+        Find all ops that take the given variable as an input.
+        """
+        res_ops = []
+        for op in block.ops:
+            if var_name in op.input_arg_names:
+                res_ops.append(op)
+        return res_ops
+
+    def _load_variable_data(self, scope, var_name):
+        """
+        Load the variable value from the scope.
+        """
+        var_node = scope.find_var(var_name)
+        assert var_node is not None, \
+            "Cannot find " + var_name + " in scope."
+        return np.array(var_node.get_tensor())
+
+    def _is_skip_quant(self, op):
+        """
+        Analyse whether the op should skip quantization or not.
+        """
+        user_skipped = False
+        if isinstance(self._skip_pattern, list):
+            user_skipped = op.has_attr("op_namescope") and \
+                any(pattern in op.attr("op_namescope") \
+                    for pattern in self._skip_pattern)
+        elif isinstance(self._skip_pattern, str):
+            user_skipped = op.has_attr("op_namescope") and \
+                op.attr("op_namescope").find(
+                    self._skip_pattern) != -1
+        return user_skipped
 
     def _transform_backward(self, block, op, var_rename_map):
         """
         Update the backward ops of the target ops.
-        Note: Skip rename the output of the grad ops.
+        Note: for the grad ops, only rename the inputs; skip renaming the outputs.
         """
         block_id = block.idx
         no_dequanted_input_vars = True
@@ -192,7 +256,7 @@ def _insert_abs_max_fq_op(self, block, idx, in_var, quant_bits):
         scale_var = self._helper.create_parameter(
             attr=ParamAttr(
                 name="{}.quant_dequant.scale".format(in_var.name),
-                initializer=Constant(0.001),
+                initializer=Constant(0.),
                 trainable=False),
             shape=[1],
             dtype=in_var.dtype)
@@ -222,7 +286,7 @@ def _insert_ma_abs_max_fq_op(self, block, idx, in_var, quant_bits, is_test):
         scale_var = self._helper.create_parameter(
             attr=ParamAttr(
                 name="{}.quant_dequant.scale".format(in_var.name),
-                initializer=Constant(0.001),
+                initializer=Constant(0.),
                 trainable=False),
             shape=[1],
             dtype=in_var.dtype)
@@ -232,7 +296,7 @@ def _insert_ma_abs_max_fq_op(self, block, idx, in_var, quant_bits, is_test):
         state_var = self._helper.create_parameter(
             attr=ParamAttr(
                 name="{}.quant_dequant.state".format(in_var.name),
-                initializer=Constant(1),
+                initializer=Constant(0),
                 trainable=False),
             shape=[1],
             dtype=in_var.dtype)
@@ -241,7 +305,7 @@ def _insert_ma_abs_max_fq_op(self, block, idx, in_var, quant_bits, is_test):
         accum_var = self._helper.create_parameter(
             attr=ParamAttr(
                 name="{}.quant_dequant.accum".format(in_var.name),
-                initializer=Constant(1),
+                initializer=Constant(0),
                 trainable=False),
             shape=[1],
             dtype=in_var.dtype)
@@ -297,3 +361,68 @@ def _insert_pc_abs_max_fq_op(self, block, idx, in_var, quant_bits, ch_axis):
             inputs=inputs,
             outputs=outputs)
         return quant_dequant_var
+
+    def _insert_ma_abs_max_scale_op(self,
+                                    block,
+                                    idx,
+                                    in_var,
+                                    is_test,
+                                    has_out_var=False):
+        """
+        Insert moving average abs max scale op.
+        """
+        scale_var = self._helper.create_parameter(
+            attr=ParamAttr(
+                name="{}.outscale.scale".format(in_var.name),
+                initializer=Constant(0.),
+                trainable=False),
+            shape=[1],
+            dtype=in_var.dtype)
+        scale_var.stop_gradient = True
+
+        attrs = {'moving_rate': self._moving_rate, 'is_test': is_test}
+        inputs = {'X': in_var}
+        outputs = {'OutScale': scale_var}
+
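+        # in training mode, the op also maintains state/accum tensors to
+        # update the moving average of the scale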
+        if not is_test:
+            state_var = self._helper.create_parameter(
+                attr=ParamAttr(
+                    name="{}.outscale.state".format(in_var.name),
+                    initializer=Constant(0),
+                    trainable=False),
+                shape=[1],
+                dtype=in_var.dtype)
+            state_var.stop_gradient = True
+
+            accum_var = self._helper.create_parameter(
+                attr=ParamAttr(
+                    name="{}.outscale.accum".format(in_var.name),
+                    initializer=Constant(0),
+                    trainable=False),
+                shape=[1],
+                dtype=in_var.dtype)
+            accum_var.stop_gradient = True
+
+            inputs['InState'] = state_var
+            inputs['InAccum'] = accum_var
+            outputs['OutState'] = state_var
+            outputs['OutAccum'] = accum_var
+
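+        # optionally create an output variable so that the ops after this one
+        # can be rewired to consume it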
+        if has_out_var:
+            out_var = block.create_var(
+                type=in_var.type,
+                name="{}.tmp".format(in_var.name),
+                shape=in_var.shape,
+                dtype=in_var.dtype)
+
+            outputs['Out'] = out_var
+
+        block._insert_op(
+            idx,
+            type='moving_average_abs_max_scale',
+            attrs=attrs,
+            inputs=inputs,
+            outputs=outputs)
+
+        if has_out_var:
+            return out_var