    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


+def find_next_ops(block, var_name):
+    """
+    Find all ops in the block that take the given variable as an input.
+    """
+    res_ops = []
+    for op in block.ops:
+        if var_name in op.input_arg_names:
+            res_ops.append(op)
+    return res_ops
+
+
+def load_variable_data(scope, var_name):
+    '''
+    Load the value of a variable from the given scope as a numpy array.
+    '''
+    var_node = scope.find_var(var_name)
+    assert var_node is not None, \
+        "Cannot find " + var_name + " in scope."
+    return np.array(var_node.get_tensor())
+
+
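A rough usage sketch for the two helpers above (not part of this diff; the toy program and every name in it are illustrative, and it assumes the helpers are in scope alongside the fluid static-graph API):

import numpy as np
import paddle.fluid as fluid

prog = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(prog, startup):
    x = fluid.data(name="x", shape=[None, 4], dtype="float32")
    out = fluid.layers.fc(input=x, size=2)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)  # initialize the fc parameters in the global scope

block = prog.global_block()
# Ops that consume "x" as an input (here, the mul op created by the fc layer).
print([op.type for op in find_next_ops(block, "x")])
# Read an initialized parameter out of the global scope as a numpy array.
w_name = block.all_parameters()[0].name
print(load_variable_data(fluid.global_scope(), w_name).shape)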
class QuantizeTranspilerV2(object):
    def __init__(self,
                 weight_bits=8,
@@ -118,25 +139,36 @@ def apply(self, program, startup_program, is_test=False):

    def convert(self, test_program, scope=None):
        """
-        Convert the test program.
+        Convert the test program.
+        Get the out scale from the moving_average_abs_max_scale op and save the
+        out scale into the quantized op.
        Args:
            test_program(Program): the test program to be converted.
            scope(fluid.Scope, optional): The scope of the program, use it to load
                and save variables. If scope=None, get scope by global_scope().
        """
        scope = global_scope() if scope == None else scope

-        target_ops = []
        for block in test_program.blocks:
            for op in block.ops:
-                if op.type == "moving_average_abs_max_scale":
-                    target_ops.append(op)
+                if op.has_attr("quantization_type") \
+                        and op.attr("quantization_type") == "qat_with_weight":
+                    # quant op -> var1 -> fake op -> var2
+                    assert len(op.output_arg_names) == 1
+                    var1_name = op.output_arg_names[0]
+
+                    fake_ops = find_next_ops(block, var1_name)
+                    assert len(fake_ops) == 1
+                    fake_op = fake_ops[0]
+                    assert fake_op.type == "moving_average_abs_max_scale"
+
+                    out_scale_name = fake_op.output("OutScale")
+                    out_threshold = load_variable_data(scope, out_scale_name[0])
+                    op._set_attr("out_threshold", float(out_threshold))
-        for op in target_ops:
-            out_scale_name = op.output("OutScale")
-            # TODO: save the out threshold in the target ops
-            #print(out_scale_name)
-            #print(self._load_variable_data(scope, out_scale_name[0]))
+                    var2_name = fake_op.output("Out")[0]
+                    op._rename_output(var1_name, var2_name)
+                    fake_op._rename_output(var2_name, var1_name)

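As I read the hunk above, the two _rename_output calls detach the scale op from the inference path once its value has been copied onto the quantized op. A schematic plus a hedged sketch of the intended call sequence, based only on signatures visible in this diff (the program variables are placeholders and training is elided):

# Before convert():  quant_op --var1--> moving_average_abs_max_scale --var2--> next_op
# After convert():   quant_op --var2--> next_op   (quant_op now carries "out_threshold")
#                    the scale op is left reading and writing var1, off the main path.

transpiler = QuantizeTranspilerV2(weight_bits=8)
transpiler.apply(train_program, startup_program, is_test=False)
# ... train / calibrate so the moving-average scale variables are populated ...
transpiler.convert(test_program)  # saves each scale as a float "out_threshold" attr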
    def _transform_forward(self, block, op, var_rename_map, is_test):
        """
@@ -183,7 +215,7 @@ def _transform_forward(self, block, op, var_rename_map, is_test):

        # insert out scale op followed the quantized op
        for out_name in op.output_arg_names:
-            next_ops = self._find_next_ops(block, out_name)
+            next_ops = find_next_ops(block, out_name)

            idx = block.ops.index(op)
            out_var = block.var(out_name)
@@ -194,25 +226,6 @@ def _transform_forward(self, block, op, var_rename_map, is_test):
                if "_grad" not in next_op.type:
                    next_op._rename_input(out_name, new_out_var.name)

-    def _find_next_ops(self, block, var_name):
-        """
-        Find all followed ops for the input variable.
-        """
-        res_ops = []
-        for op in block.ops:
-            if var_name in op.input_arg_names:
-                res_ops.append(op)
-        return res_ops
-
-    def _load_variable_data(self, scope, var_name):
-        '''
-        Load variable value from scope
-        '''
-        var_node = scope.find_var(var_name)
-        assert var_node is not None, \
-            "Cannot find " + var_name + " in scope."
-        return np.array(var_node.get_tensor())
-
    def _is_skip_quant(self, op):
        """
        Analyse whether the op should skip quantization or not.