
Commit 403aa3d

Convert the test program
1 parent c5776cf commit 403aa3d


2 files changed: +50 -33 lines changed

python/paddle/fluid/contrib/slim/quantization/quantize_transpiler_v2.py

Lines changed: 42 additions & 29 deletions
@@ -28,6 +28,27 @@
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


+def find_next_ops(block, var_name):
+    """
+    Find all followed ops for the input variable.
+    """
+    res_ops = []
+    for op in block.ops:
+        if var_name in op.input_arg_names:
+            res_ops.append(op)
+    return res_ops
+
+
+def load_variable_data(scope, var_name):
+    '''
+    Load variable value from scope
+    '''
+    var_node = scope.find_var(var_name)
+    assert var_node is not None, \
+        "Cannot find " + var_name + " in scope."
+    return np.array(var_node.get_tensor())
+
+
 class QuantizeTranspilerV2(object):
     def __init__(self,
                  weight_bits=8,
@@ -118,25 +139,36 @@ def apply(self, program, startup_program, is_test=False):

     def convert(self, test_program, scope=None):
         """
-        Convert the test program.
+        Convert the test program.
+        Get the out scale from the moving_average_abs_max_scale op and save the
+        out scale into the quantized op.
         Args:
             test_program(Program): the test program to be converted.
             scope(fluid.Scope, optional): The scope of the program, use it to load
                 and save variables. If scope=None, get scope by global_scope().
         """
         scope = global_scope() if scope == None else scope

-        target_ops = []
         for block in test_program.blocks:
             for op in block.ops:
-                if op.type == "moving_average_abs_max_scale":
-                    target_ops.append(op)
+                if op.has_attr("quantization_type") \
+                        and op.attr("quantization_type") == "qat_with_weight":
+                    # quant op -> var1 -> fake op -> var2
+                    assert len(op.output_arg_names) == 1
+                    var1_name = op.output_arg_names[0]
+
+                    fake_ops = find_next_ops(block, var1_name)
+                    assert len(fake_ops) == 1
+                    fake_op = fake_ops[0]
+                    assert fake_op.type == "moving_average_abs_max_scale"
+
+                    out_scale_name = fake_op.output("OutScale")
+                    out_threshold = load_variable_data(scope, out_scale_name[0])
+                    op._set_attr("out_threshold", float(out_threshold))

-        for op in target_ops:
-            out_scale_name = op.output("OutScale")
-            # TODO: save the out threshold in the target ops
-            #print(out_scale_name)
-            #print(self._load_variable_data(scope, out_scale_name[0]))
+                    var2_name = fake_op.output("Out")[0]
+                    op._rename_output(var1_name, var2_name)
+                    fake_op._rename_output(var2_name, var1_name)

     def _transform_forward(self, block, op, var_rename_map, is_test):
         """
@@ -183,7 +215,7 @@ def _transform_forward(self, block, op, var_rename_map, is_test):

         # insert out scale op followed the quantized op
         for out_name in op.output_arg_names:
-            next_ops = self._find_next_ops(block, out_name)
+            next_ops = find_next_ops(block, out_name)

             idx = block.ops.index(op)
             out_var = block.var(out_name)
@@ -194,25 +226,6 @@ def _transform_forward(self, block, op, var_rename_map, is_test):
                 if "_grad" not in next_op.type:
                     next_op._rename_input(out_name, new_out_var.name)

-    def _find_next_ops(self, block, var_name):
-        """
-        Find all followed ops for the input variable.
-        """
-        res_ops = []
-        for op in block.ops:
-            if var_name in op.input_arg_names:
-                res_ops.append(op)
-        return res_ops
-
-    def _load_variable_data(self, scope, var_name):
-        '''
-        Load variable value from scope
-        '''
-        var_node = scope.find_var(var_name)
-        assert var_node is not None, \
-            "Cannot find " + var_name + " in scope."
-        return np.array(var_node.get_tensor())
-
     def _is_skip_quant(self, op):
         """
         Analyse whether the op should skip quantization or not.
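
In short, the rewritten convert pass walks each block, finds quantized ops tagged with quantization_type == "qat_with_weight", reads the scale accumulated by the following moving_average_abs_max_scale op, stores it on the quantized op as an out_threshold attribute, and then swaps the two output names so the quantized op writes straight to the fake op's former output. The sketch below reproduces only the rename swap with plain Python objects; the Op class and the variable names are illustrative stand-ins, not Paddle APIs.

# A minimal, self-contained sketch of the output swap performed in convert():
# quant_op -> var1 -> fake_op -> var2 becomes quant_op -> var2, so downstream
# ops keep reading var2 while the scale op drops out of the inference path.
# The Op class below is a hypothetical stand-in, not a Paddle API.

class Op:
    def __init__(self, type_, inputs, outputs):
        self.type = type_
        self.inputs = list(inputs)    # input variable names
        self.outputs = list(outputs)  # output variable names
        self.attrs = {}

    def rename_output(self, old_name, new_name):
        self.outputs = [new_name if n == old_name else n for n in self.outputs]

quant_op = Op("conv2d", inputs=["x", "w"], outputs=["var1"])
fake_op = Op("moving_average_abs_max_scale", inputs=["var1"], outputs=["var2"])

# 1. Save the observed scale on the quantized op (dummy value here).
quant_op.attrs["out_threshold"] = 0.25

# 2. Swap the output names, mirroring op._rename_output in the commit.
quant_op.rename_output("var1", "var2")
fake_op.rename_output("var2", "var1")

assert quant_op.outputs == ["var2"]  # consumers of var2 now see the conv output
assert fake_op.outputs == ["var1"]   # the scale op no longer feeds the graph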

python/paddle/fluid/contrib/slim/tests/test_quantize_transpiler_v2.py

Lines changed: 8 additions & 4 deletions
@@ -79,6 +79,7 @@ def build_program(main, startup, is_test):
         random.seed(0)
         np.random.seed(0)

+        # 1 Define program
         train_program = fluid.Program()
         startup_program = fluid.Program()
         test_program = fluid.Program()
@@ -93,13 +94,14 @@ def build_program(main, startup, is_test):
         test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
         test_graph.draw('.', 'test_program_1')

+        # 2 Apply quantization
         qt = QuantizeTranspilerV2(
             activation_quantize_type=activation_quant_type,
-            weight_quantize_type=weight_quant_type,
-            quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'])
-        qt.apply(train_program, startup_program, False)
-        qt.apply(test_program, startup_program, True)
+            weight_quantize_type=weight_quant_type)
+        qt.apply(train_program, startup_program, is_test=False)
+        qt.apply(test_program, startup_program, is_test=True)

+        # 3 Train
         place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
         exe = fluid.Executor(place)
         scope = fluid.Scope()
@@ -135,6 +137,8 @@ def build_program(main, startup, is_test):

         print('{}: {}'.format('loss', np.mean(loss_v)))

+        # 4 Convert
+        qt.convert(test_program, scope)
         if not for_ci:
             with fluid.scope_guard(scope):
                 fluid.io.save_inference_model('./infer_model',
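
Taken together, the updated test now exercises the full flow: define the programs, apply the transpiler to both the train and test programs, train, then convert the test program before saving the inference model. Below is a condensed sketch of that flow; the tiny fc network, the random feed data, and the quantize-type strings are placeholders chosen for illustration, and only the QuantizeTranspilerV2 calls (apply, convert) mirror what this commit exercises.

# A condensed, hedged sketch of the quantization-aware-training workflow the
# updated test follows. Network, feed data, and quantize-type strings are
# illustrative assumptions, not taken from the commit.
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization.quantize_transpiler_v2 import (
    QuantizeTranspilerV2)

# 1 Define program
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
    image = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    fc = fluid.layers.fc(input=image, size=10, act='softmax')
    loss = fluid.layers.mean(fluid.layers.cross_entropy(input=fc, label=label))
    test_program = train_program.clone(for_test=True)
    fluid.optimizer.SGD(learning_rate=0.001).minimize(loss)

# 2 Apply quantization to both the train and the test program
qt = QuantizeTranspilerV2(
    activation_quantize_type='moving_average_abs_max',
    weight_quantize_type='abs_max')
qt.apply(train_program, startup_program, is_test=False)
qt.apply(test_program, startup_program, is_test=True)

# 3 Train for a step so the moving-average scales get updated
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
with fluid.scope_guard(scope):
    exe.run(startup_program)
    feed = {
        'image': np.random.random((8, 784)).astype('float32'),
        'label': np.random.randint(0, 10, (8, 1)).astype('int64'),
    }
    loss_v = exe.run(train_program, feed=feed, fetch_list=[loss])
    print('loss: {}'.format(np.mean(loss_v)))

# 4 Convert: write out_threshold attrs into the test program, then save it
qt.convert(test_program, scope)
with fluid.scope_guard(scope):
    target = test_program.global_block().var(fc.name)
    fluid.io.save_inference_model('./infer_model', ['image'], [target], exe,
                                  test_program)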
