Commit c5776cf

Add moving_average_abs_max_scale op
1 parent cc213f7 commit c5776cf
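
This commit wires the moving_average_abs_max_scale op into the v2 quantize transpiler so that each quantized op's output gets a running estimate of its absolute-max value. As a rough sketch of the running update behind the op's InState/InAccum/OutScale slots and the moving_rate attribute (plain NumPy, illustrative only; the real kernel is the Paddle C++ op of the same name):

import numpy as np

def moving_average_abs_max_scale(x, state, accum, moving_rate=0.9):
    # One training-step update of the running abs-max scale (sketch).
    abs_max = float(np.abs(x).max())
    state = moving_rate * state + 1.0      # OutState: discounted step count
    accum = moving_rate * accum + abs_max  # OutAccum: discounted abs-max sum
    scale = accum / state                  # OutScale: moving-average abs-max
    return scale, state, accum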

File tree

1 file changed: +152 -23 lines
python/paddle/fluid/contrib/slim/quantization/quantize_transpiler_v2.py

Lines changed: 152 additions & 23 deletions
@@ -17,6 +17,7 @@
 import numpy as np
 from .... import core
 from ....framework import Program, Operator, Variable, program_guard
+from ....executor import global_scope
 from .... import unique_name
 from ....layer_helper import LayerHelper
 from ....param_attr import ParamAttr
@@ -34,7 +35,9 @@ def __init__(self,
                  weight_quantize_type='abs_max',
                  activation_quantize_type='moving_average_abs_max',
                  quantizable_op_type=[
-                     'conv2d', 'depthwise_conv2d', 'mul', 'matmul'
+                     'conv2d',
+                     'depthwise_conv2d',
+                     'mul',
                  ],
                  skip_pattern=['skip_quant']):
         """
@@ -64,6 +67,9 @@ def __init__(self,
         self._activation_quantize_type = activation_quantize_type
         self._weight_quantize_type = weight_quantize_type
 
+        for op_type in quantizable_op_type:
+            assert op_type in ['conv2d', 'depthwise_conv2d', 'mul'], \
+                "Quantize op should be ['conv2d', 'depthwise_conv2d', 'mul']"
         self._quantizable_ops = quantizable_op_type
         self._quantizable_grad_ops = [
             '%s_grad' % (op) for op in self._quantizable_ops
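
A side effect of the tightened default list and the new assertion above: op types outside ['conv2d', 'depthwise_conv2d', 'mul'] are now rejected at construction time. Illustrative only (the QuantizeTranspilerV2 class name and import path are assumed from the file name, not shown in this diff):

from paddle.fluid.contrib.slim.quantization.quantize_transpiler_v2 import \
    QuantizeTranspilerV2

QuantizeTranspilerV2(quantizable_op_type=['conv2d', 'mul'])  # accepted
QuantizeTranspilerV2(quantizable_op_type=['matmul'])         # AssertionError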
@@ -110,32 +116,41 @@ def apply(self, program, startup_program, is_test=False):
                     (not self._is_skip_quant(op)):
                     self._transform_backward(block, op, var_rename_map)
 
-    def _is_skip_quant(self, op):
+    def convert(self, test_program, scope=None):
         """
-        Analyse whether the op should skip quantization or not.
+        Convert the test program.
+        Args:
+            test_program(Program): the test program to be converted.
+            scope(fluid.Scope, optional): The scope of the program, use it to load
+                and save variables. If scope=None, get scope by global_scope().
         """
-        user_skipped = False
-        if isinstance(self._skip_pattern, list):
-            user_skipped = op.has_attr("op_namescope") and \
-                           any(pattern in op.attr("op_namescope") \
-                               for pattern in self._skip_pattern)
-        elif isinstance(self._skip_pattern, str):
-            user_skipped = op.has_attr("op_namescope") and \
-                           op.attr("op_namescope").find(
-                               self._skip_pattern) != -1
-        return user_skipped
+        scope = global_scope() if scope == None else scope
+
+        target_ops = []
+        for block in test_program.blocks:
+            for op in block.ops:
+                if op.type == "moving_average_abs_max_scale":
+                    target_ops.append(op)
+
+        for op in target_ops:
+            out_scale_name = op.output("OutScale")
+            # TODO: save the out threshold in the target ops
+            #print(out_scale_name)
+            #print(self._load_variable_data(scope, out_scale_name[0]))
 
     def _transform_forward(self, block, op, var_rename_map, is_test):
         """
         Insert fake quant op before the target ops.
         """
         op._set_attr("quantization_type", "qat_with_weight")
-        idx = block.ops.index(op)
-        block_id = block.idx
 
+        # insert fake quant op before the quantized op
         for in_name in op.input_arg_names:
+            block_id = block.idx
+            idx = block.ops.index(op)
+
             if in_name in var_rename_map[block_id]:
-                new_var_name = var_rename_map[block_id][in_name]
+                new_in_name = var_rename_map[block_id][in_name]
             else:
                 in_var = block.var(in_name)
                 if in_var.dtype != core.VarDesc.VarType.FP32:
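
The new convert() pairs with apply() in a train-then-export flow. A rough, untested sketch of that flow (the QuantizeTranspilerV2 class name, the toy network, and the exact call ordering are assumptions, not part of this commit):

import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization.quantize_transpiler_v2 import \
    QuantizeTranspilerV2

paddle.enable_static()

train_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
    image = fluid.data(name='image', shape=[None, 1, 28, 28], dtype='float32')
    conv = fluid.layers.conv2d(image, num_filters=8, filter_size=3)
    loss = fluid.layers.reduce_mean(fluid.layers.fc(conv, size=10))
    test_prog = train_prog.clone(for_test=True)  # clone before minimize
    fluid.optimizer.SGD(learning_rate=1e-3).minimize(loss)

transpiler = QuantizeTranspilerV2()
transpiler.apply(train_prog, startup_prog, is_test=False)
transpiler.apply(test_prog, startup_prog, is_test=True)

# ... run an executor over train_prog so the moving-average scales settle ...

transpiler.convert(test_prog)  # scope defaults to global_scope()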
@@ -161,13 +176,62 @@ def _transform_forward(self, block, op, var_rename_map, is_test):
                         quant_type)
                     continue
 
-                var_rename_map[block_id][in_name] = new_var.name
-            op._rename_input(in_name, new_var.name)
+                new_in_name = new_var.name
+                var_rename_map[block_id][in_name] = new_in_name
+
+            op._rename_input(in_name, new_in_name)
+
+        # insert out scale op followed the quantized op
+        for out_name in op.output_arg_names:
+            next_ops = self._find_next_ops(block, out_name)
+
+            idx = block.ops.index(op)
+            out_var = block.var(out_name)
+            new_out_var = self._insert_ma_abs_max_scale_op(
+                block, idx + 1, out_var, is_test, True)
+
+            for next_op in next_ops:
+                if "_grad" not in next_op.type:
+                    next_op._rename_input(out_name, new_out_var.name)
+
+    def _find_next_ops(self, block, var_name):
+        """
+        Find all followed ops for the input variable.
+        """
+        res_ops = []
+        for op in block.ops:
+            if var_name in op.input_arg_names:
+                res_ops.append(op)
+        return res_ops
+
+    def _load_variable_data(self, scope, var_name):
+        '''
+        Load variable value from scope
+        '''
+        var_node = scope.find_var(var_name)
+        assert var_node is not None, \
+            "Cannot find " + var_name + " in scope."
+        return np.array(var_node.get_tensor())
+
+    def _is_skip_quant(self, op):
+        """
+        Analyse whether the op should skip quantization or not.
+        """
+        user_skipped = False
+        if isinstance(self._skip_pattern, list):
+            user_skipped = op.has_attr("op_namescope") and \
+                           any(pattern in op.attr("op_namescope") \
+                               for pattern in self._skip_pattern)
+        elif isinstance(self._skip_pattern, str):
+            user_skipped = op.has_attr("op_namescope") and \
+                           op.attr("op_namescope").find(
+                               self._skip_pattern) != -1
+        return user_skipped
 
     def _transform_backward(self, block, op, var_rename_map):
         """
         Update the backword of the target ops.
-        Note: Skip rename the output of the grad ops.
+        Note: for the grad ops, only rename the input, skip rename the output.
         """
         block_id = block.idx
         no_dequanted_input_vars = True
@@ -192,7 +256,7 @@ def _insert_abs_max_fq_op(self, block, idx, in_var, quant_bits):
         scale_var = self._helper.create_parameter(
             attr=ParamAttr(
                 name="{}.quant_dequant.scale".format(in_var.name),
-                initializer=Constant(0.001),
+                initializer=Constant(0.),
                 trainable=False),
             shape=[1],
             dtype=in_var.dtype)
@@ -222,7 +286,7 @@ def _insert_ma_abs_max_fq_op(self, block, idx, in_var, quant_bits, is_test):
         scale_var = self._helper.create_parameter(
             attr=ParamAttr(
                 name="{}.quant_dequant.scale".format(in_var.name),
-                initializer=Constant(0.001),
+                initializer=Constant(0.),
                 trainable=False),
             shape=[1],
             dtype=in_var.dtype)
@@ -232,7 +296,7 @@ def _insert_ma_abs_max_fq_op(self, block, idx, in_var, quant_bits, is_test):
             state_var = self._helper.create_parameter(
                 attr=ParamAttr(
                     name="{}.quant_dequant.state".format(in_var.name),
-                    initializer=Constant(1),
+                    initializer=Constant(0),
                     trainable=False),
                 shape=[1],
                 dtype=in_var.dtype)
@@ -241,7 +305,7 @@ def _insert_ma_abs_max_fq_op(self, block, idx, in_var, quant_bits, is_test):
             accum_var = self._helper.create_parameter(
                 attr=ParamAttr(
                     name="{}.quant_dequant.accum".format(in_var.name),
-                    initializer=Constant(1),
+                    initializer=Constant(0),
                     trainable=False),
                 shape=[1],
                 dtype=in_var.dtype)
@@ -297,3 +361,68 @@ def _insert_pc_abs_max_fq_op(self, block, idx, in_var, quant_bits, ch_axis):
             inputs=inputs,
             outputs=outputs)
         return quant_dequant_var
+
+    def _insert_ma_abs_max_scale_op(self,
+                                    block,
+                                    idx,
+                                    in_var,
+                                    is_test,
+                                    has_out_var=False):
+        """
+        Insert moving average abs max scale op.
+        """
+        scale_var = self._helper.create_parameter(
+            attr=ParamAttr(
+                name="{}.outscale.scale".format(in_var.name),
+                initializer=Constant(0.),
+                trainable=False),
+            shape=[1],
+            dtype=in_var.dtype)
+        scale_var.stop_gradient = True
+
+        attrs = {'moving_rate': self._moving_rate, 'is_test': is_test}
+        inputs = {'X': in_var}
+        outputs = {'OutScale': scale_var}
+
+        if not is_test:
+            state_var = self._helper.create_parameter(
+                attr=ParamAttr(
+                    name="{}.outscale.state".format(in_var.name),
+                    initializer=Constant(0),
+                    trainable=False),
+                shape=[1],
+                dtype=in_var.dtype)
+            state_var.stop_gradient = True
+
+            accum_var = self._helper.create_parameter(
+                attr=ParamAttr(
+                    name="{}.outscale.accum".format(in_var.name),
+                    initializer=Constant(0),
+                    trainable=False),
+                shape=[1],
+                dtype=in_var.dtype)
+            accum_var.stop_gradient = True
+
+            inputs['InState'] = state_var
+            inputs['InAccum'] = accum_var
+            outputs['OutState'] = state_var
+            outputs['OutAccum'] = accum_var
+
+        if has_out_var:
+            out_var = block.create_var(
+                type=in_var.type,
+                name="{}.tmp".format(in_var.name),
+                shape=in_var.shape,
+                dtype=in_var.dtype)
+
+            outputs['Out'] = out_var
+
+        block._insert_op(
+            idx,
+            type='moving_average_abs_max_scale',
+            attrs=attrs,
+            inputs=inputs,
+            outputs=outputs)
+
+        if has_out_var:
+            return out_var
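
Since convert() only locates the moving_average_abs_max_scale ops for now (exporting the thresholds is still a TODO in this commit), the recorded scale can be read back directly from the scope, mirroring _load_variable_data and the "{name}.outscale.scale" parameter created above. A hypothetical helper, not part of this commit:

import numpy as np
from paddle.fluid.executor import global_scope

def fetch_out_scale(var_name, scope=None):
    # Read the moving-average abs-max scale recorded for `var_name`.
    scope = global_scope() if scope is None else scope
    scale_var = scope.find_var("{}.outscale.scale".format(var_name))
    assert scale_var is not None, "Cannot find the out scale for " + var_name
    return float(np.array(scale_var.get_tensor()))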
