@@ -31,6 +31,7 @@ QuantDequantFusePass::QuantDequantFusePass() {
3131 .End ()
3232 .AddInput (" Iter" )
3333 .IsTensor ()
34+ .IsOptional ()
3435 .End ()
3536 .AddOutput (" Out" )
3637 .IsTensor ()
@@ -40,6 +41,7 @@ QuantDequantFusePass::QuantDequantFusePass() {
4041 .End ()
4142 .AddOutput (" OutScales" )
4243 .IsTensor ()
44+ .IsOptional ()
4345 .End ()
4446 .AddAttr (" window_size" )
4547 .IsType <int >()
@@ -167,6 +169,26 @@ QuantDequantFusePass::QuantDequantFusePass() {
167169 .AddAttr (" y_num_col_dims" )
168170 .IsNumEQ (1 )
169171 .End ();
172+ AddOpCompat (OpCompat (" matmul" ))
173+ .AddInput (" X" )
174+ .IsTensor ()
175+ .End ()
176+ .AddInput (" Y" )
177+ .IsTensor ()
178+ .End ()
179+ .AddOutput (" Out" )
180+ .IsTensor ()
181+ .End ()
182+ .AddAttr (" alpha" )
183+ .IsNumGE (0 .99f )
184+ .IsNumLE (1 .01f )
185+ .End ()
186+ .AddAttr (" transpose_X" )
187+ .IsBoolEQ (false )
188+ .End ()
189+ .AddAttr (" transpose_Y" )
190+ .IsBoolEQ (false )
191+ .End ();
170192 AddOpCompat (OpCompat (" fc" ))
171193 .AddInput (" Input" )
172194 .IsTensor ()
@@ -291,7 +313,7 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
291313 quantized_op_type == " fc" ||
292314 quantized_op_type == " conv2d_transpose" ) {
293315 op_desc->SetAttr (" Input_scale" , scale_value);
294- } else if (quantized_op_type == " mul" ) {
316+ } else if (quantized_op_type == " mul" || quantized_op_type == " matmul" ) {
295317 op_desc->SetAttr (" X_scale" , scale_value);
296318 } else {
297319 PADDLE_THROW (platform::errors::Unimplemented (
@@ -323,7 +345,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
323345 quantized_op_type == " conv2d_transpose" ) {
324346 weight_name = " Filter" ;
325347 input_name = " Input" ;
326- } else if (quantized_op_type == " mul" ) {
348+ } else if (quantized_op_type == " mul" || quantized_op_type == " matmul" ) {
327349 weight_name = " Y" ;
328350 input_name = " X" ;
329351 } else if (quantized_op_type == " fc" ) {
@@ -332,7 +354,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
332354 } else {
333355 PADDLE_THROW (platform::errors::Unimplemented (
334356 " QuantDequantFuse: We only support conv2d, conv2d_fusion, "
335- " conv2d_transpose, fc, mul for "
357+ " conv2d_transpose, fc, mul, matmul for "
336358 " now." ));
337359 }
338360 const std::string pattern_name = " dequant_fuse" ;
@@ -410,12 +432,13 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
410432 // If quantized op is fc, weight scale size = 1;
411433 // If quantized op is conv2d, weight scale size = weight dims[0]
412434 // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
413- if (quantized_op_type == " mul" || quantized_op_type == " fc" ) {
435+ if (quantized_op_type == " mul" || quantized_op_type == " matmul" ||
436+ quantized_op_type == " fc" ) {
414437 if (dequant_type == " fake_dequantize_max_abs" ) {
415438 PADDLE_ENFORCE_EQ (
416439 weight_scale.size (), 1 ,
417440 platform::errors::InvalidArgument (
418- " mul op weight dequantized by [fake_dequantize_max_abs] "
441+ " mul/matmul op weight dequantized by [fake_dequantize_max_abs] "
419442 " requires weight scale size = 1, but got %d." ,
420443 weight_scale.size ()));
421444 for (int j = 0 ; j < weight_tensor->numel (); j++) {
@@ -426,9 +449,10 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
426449 PADDLE_ENFORCE_EQ (
427450 weight_scale.size (), static_cast <size_t >(w_dims[1 ]),
428451 platform::errors::InvalidArgument (
429- " mul op weight dequantized by "
452+ " mul/matmul op weight dequantized by "
430453 " [fake_channel_wise_dequantize_max_abs] requires weight scale "
431- " size = 2nd dim of mul's weight, which is %d, but got %d." ,
454+ " size = 2nd dim of mul/matmul's weight, which is %d, but got "
455+ " %d." ,
432456 static_cast <size_t >(w_dims[1 ]), weight_scale.size ()));
433457 for (int j = 0 ; j < weight_tensor->numel (); j++) {
434458 quantized_weight_data[j] *= weight_scale[j % w_dims[1 ]];
@@ -493,7 +517,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
493517 } else if (quantized_op_type == " fc" ) {
494518 new_op_desc.SetInput (" Input" , {new_input});
495519 new_op_desc.SetOutput (" Out" , {new_output});
496- } else if (quantized_op_type == " mul" ) {
520+ } else if (quantized_op_type == " mul" || quantized_op_type == " matmul" ) {
497521 new_op_desc.SetInput (" X" , {new_input});
498522 new_op_desc.SetOutput (" Out" , {new_output});
499523 }
@@ -520,7 +544,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
520544 std::unordered_set<std::string> quant_types = {
521545 " fake_quantize_range_abs_max" , " fake_quantize_moving_average_abs_max" };
522546 std::unordered_set<std::string> quantized_op_types = {
523- " conv2d" , " mul" , " depthwise_conv2d" , " fc" , " conv2d_transpose" };
547 " conv2d" , " mul" , " matmul" , " depthwise_conv2d" , " fc" , " conv2d_transpose" };
524548 auto * scope = param_scope ();
525549
526550 for (auto & quant_type : quant_types) {
0 commit comments