@@ -92,7 +92,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
     int range = ((1 << (bit_length - 1)) - 1);
     std::vector<float> weight_scale;
     std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();
-
     auto* any_op2_desc = any_op2->Op();
     auto var_map = any_op2_desc->Inputs();
     std::string arg_name = "";
@@ -106,43 +105,52 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
     PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument(
                                                "can not find the input %s.",
                                                quant_dequant_op_out_name));
-    any_op2_desc->SetAttr("enable_int8", true);
+    // any_op2_desc->SetAttr("enable_int8", true);
     any_op2_desc->SetAttr("bit_length", bit_length);
+
     // modify the any_op2's inputs
-    any_op2_desc->Flush();
     auto dequant_type = quant_dequant_op->Op()->Type();
-    auto quantized_op_type = any_op2_desc->Type();
+
     // get weight tensor
     auto* weight_tensor =
         scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
     auto w_dims = weight_tensor->dims();
+
     float* quantized_weight_data =
         weight_tensor->mutable_data<float>(platform::CPUPlace());
 
     // Get weight scale
     if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
-      auto scales_name = quant_dequant_op->Op()->Output("OutScale");
+      int quant_axis =
+          BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
+      PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
+                        platform::errors::InvalidArgument(
+                            "'quant_axis' should be 0 or 1, but "
+                            "the received is %d",
+                            quant_axis));
+
+      // To Do @Wangzheee: use "OutScale" to quantdequant
+      /* auto scales_name = quant_dequant_op->Op()->Output("OutScale");
       PADDLE_ENFORCE_EQ(scales_name.size(), 1,
                         platform::errors::InvalidArgument(
                             "Scales size in channel-wise quant dequantize op "
                             "should be 1, got %d.",
                             scales_name.size()));
       const LoDTensor& channel_scale_tensor =
-          scope->GetVar(scales_name[0])->Get<LoDTensor>();
+          scope->FindVar(scales_name[0])->Get<LoDTensor>();
       PADDLE_ENFORCE(
           paddle::platform::is_cpu_place(channel_scale_tensor.place()),
           platform::errors::InvalidArgument(
               "Channel scale tensor's place should be CPU."));
       // compute the channel wise abs max of the weight tensor
-      int quant_axis =
-          BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
 
-      PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
-                        platform::errors::InvalidArgument(
-                            "'quant_axis' should be 0 or 1, but "
-                            "the received is %d",
-                            quant_axis));
+      const float* channel_scale_data = channel_scale_tensor.data<float>();
+      for (int i = 0; i < channel_scale_tensor.numel(); i++) {
+        weight_scale.push_back(channel_scale_data[i]);
+      }*/
 
+      // Implement channel_wise_quantize_dequantize_abs_max quantization
+      // algorithm
       const int64_t channel = w_dims[quant_axis];
       weight_scale.resize(channel, 0);
       if (quant_axis == 0) {
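
For reference, a minimal standalone sketch of the channel-wise abs-max computation that the new branch above implements (the helper name ChannelWiseScales and the row-major [C, K] layout with quant_axis == 0 are illustrative assumptions, not code from the pass; the pass also handles quant_axis == 1 and performs the division by range in a later loop):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Per-channel abs-max scales, stored directly as abs_max / range,
    // the convention this commit switches to.
    std::vector<float> ChannelWiseScales(const float* w, int64_t channel,
                                         int64_t inner, int bit_length) {
      const float range = static_cast<float>((1 << (bit_length - 1)) - 1);
      std::vector<float> scales(channel, 0.f);
      for (int64_t c = 0; c < channel; ++c) {
        for (int64_t k = 0; k < inner; ++k) {
          scales[c] = std::max(scales[c], std::abs(w[c * inner + k]));
        }
        scales[c] /= range;  // scale is the quantization step size
      }
      return scales;
    }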
@@ -171,11 +179,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
         PADDLE_ENFORCE_NE(weight_scale[i], 0,
                           platform::errors::InvalidArgument(
                               "Weight scale should be nonzero, but get zero."));
-        weight_scale[i] = range / weight_scale[i];
+        weight_scale[i] = weight_scale[i] / range;
       }
     } else {
-      auto scale_name = quant_dequant_op_outscale->Name();
-      // compute the abs max of the weight tensor
+      // Implement quantize_dequantize_abs_max quantization algorithm
       float abs_max_weight = 0.;
       for (int j = 0; j < weight_tensor->numel(); j++) {
         abs_max_weight =
@@ -184,113 +191,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
       PADDLE_ENFORCE_NE(abs_max_weight, 0,
                         platform::errors::InvalidArgument(
                             "Weight scale should be nonzero, but get zero"));
-      weight_scale.push_back((range * range) / abs_max_weight / range);
+      weight_scale.push_back(abs_max_weight / range);
     }
 
     nodes2rm.insert(quant_dequant_op_outscale);
-
-    // perform quantize dequantize operations
-    // If quantized op is not channel wise, weight scale size = 1;
-    // If quantized op is conv2d, weight scale size = weight dims[0]
-    // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
-    if (dequant_type == "fake_quantize_dequantize_abs_max") {
-      PADDLE_ENFORCE_EQ(
-          weight_scale.size(), 1,
-          platform::errors::InvalidArgument(
-              "%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
-              "requires weight scale size = 1, but got %d.",
-              quantized_op_type, weight_scale.size()));
-      for (int j = 0; j < weight_tensor->numel(); j++) {
-        // quantized
-        quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];
-        quantized_weight_data[j] = std::round(quantized_weight_data[j]);
-        // dequantized
-        quantized_weight_data[j] /= weight_scale[0];
-      }
-    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
-               quantized_op_type == "fc") {
-      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
-        PADDLE_ENFORCE_EQ(
-            weight_scale.size(), static_cast<size_t>(w_dims[1]),
-            platform::errors::InvalidArgument(
-                "mul op weight dequantized by "
-                "[fake_channel_wise_quantize_dequantize_abs_max] requires "
-                "weight scale "
-                "size = 2nd dim of mul's weight, which is %zu, but got %zu.",
-                static_cast<size_t>(w_dims[1]), weight_scale.size()));
-        for (int j = 0; j < weight_tensor->numel(); j++) {
-          // quantized
-          PADDLE_ENFORCE_NE(
-              weight_scale[j % w_dims[1]], 0,
-              platform::errors::InvalidArgument(
-                  "fc op weight scale should be nonzero, but get zero"));
-          quantized_weight_data[j] =
-              quantized_weight_data[j] * weight_scale[j % w_dims[1]];
-          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
-          // dequantized
-          quantized_weight_data[j] /= weight_scale[j % w_dims[1]];
-        }
-      } else {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "Unsupported quantized op type: %s", quantized_op_type));
-      }
-    } else if (quantized_op_type == "conv2d" ||
-               quantized_op_type == "depthwise_conv2d") {
-      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
-        PADDLE_ENFORCE_EQ(
-            weight_scale.size(), static_cast<size_t>(w_dims[0]),
-            platform::errors::InvalidArgument(
-                "conv2d op requires weight scale size = channel size of the "
-                "weight, which is %zu, but got %zu.",
-                static_cast<size_t>(w_dims[0]), weight_scale.size()));
-        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
-        for (int j = 0; j < weight_tensor->numel(); j++) {
-          // quantized
-          PADDLE_ENFORCE_NE(
-              weight_scale[j / inner_size], 0,
-              platform::errors::InvalidArgument(
-                  "conv2d op weight scale should be nonzero, but get zero"));
-          quantized_weight_data[j] =
-              quantized_weight_data[j] * weight_scale[j / inner_size];
-          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
-          // dequantized
-          quantized_weight_data[j] /= weight_scale[j / inner_size];
-        }
-      } else {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "Unsupported quantized op type: %s", quantized_op_type));
-      }
-    } else if (quantized_op_type == "conv2d_transpose") {
-      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
-        PADDLE_ENFORCE_EQ(
-            weight_scale.size(), static_cast<size_t>(w_dims[0]),
-            platform::errors::InvalidArgument(
-                "conv2d_transpose op requires weight scale size = channel size "
-                "of the "
-                "weight, which is %zu, but got %zu.",
-                static_cast<size_t>(w_dims[1]), weight_scale.size()));
-        int inner_size = w_dims[2] * w_dims[3];
-        for (int j = 0; j < weight_tensor->numel(); j++) {
-          // quantized
-          PADDLE_ENFORCE_NE(weight_scale[(j / inner_size) % w_dims[1]], 0,
-                            platform::errors::InvalidArgument(
-                                "conv2d_transpose op weight scale should be "
-                                "nonzero, but get zero"));
-          quantized_weight_data[j] = quantized_weight_data[j] *
-                                     weight_scale[(j / inner_size) % w_dims[1]];
-          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
-          // dequantized
-          quantized_weight_data[j] /=
-              weight_scale[(j / inner_size) % w_dims[1]];
-        }
-      } else {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "Unsupported quantized op type: %s", quantized_op_type));
-      }
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Unsupported quantized op type: %s", quantized_op_type));
-    }
     nodes2rm.insert(quant_dequant_op_out);
 
     // link weight in quant_dequant_op_x to any_op2
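
A note on the scale handling changed above: the pass previously stored range / abs_max (a multiply-to-quantize factor) and now stores abs_max / range (the quantization step size). A hedged side-by-side sketch of the two conventions (the helper names are illustrative only); both round trips produce the same value:

    #include <cmath>

    float QuantDequantOld(float w, float abs_max, float range) {
      const float old_scale = range / abs_max;       // pre-patch storage
      return std::round(w * old_scale) / old_scale;  // quantize, dequantize
    }

    float QuantDequantNew(float w, float abs_max, float range) {
      const float new_scale = abs_max / range;       // post-patch storage
      return std::round(w / new_scale) * new_scale;  // identical result
    }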
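The large deleted block performed the fake quant-dequant of the weights in place, selecting the scale per element according to the consuming op's weight layout. A rough sketch of that removed logic, generalized over the scale-index rule (the helper FakeQuantDequantInPlace is hypothetical; the index formulas are taken from the deleted code: 0 for the per-tensor case, j % w_dims[1] for mul/matmul/fc, j / inner_size for conv2d, (j / inner_size) % w_dims[1] for conv2d_transpose, with scales stored under the old range / abs_max convention):

    #include <cmath>
    #include <cstdint>
    #include <functional>
    #include <vector>

    void FakeQuantDequantInPlace(float* w, int64_t numel,
                                 const std::vector<float>& scales,
                                 const std::function<int64_t(int64_t)>& pick) {
      for (int64_t j = 0; j < numel; ++j) {
        const float s = scales[pick(j)];  // per-element scale choice
        w[j] = std::round(w[j] * s) / s;  // quantize with s, then undo
      }
    }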