Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ QuantDequantFusePass::QuantDequantFusePass() {
.End()
.AddInput("Iter")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
Expand All @@ -40,6 +41,7 @@ QuantDequantFusePass::QuantDequantFusePass() {
.End()
.AddOutput("OutScales")
.IsTensor()
.IsOptional()
.End()
.AddAttr("window_size")
.IsType<int>()
Expand Down Expand Up @@ -167,6 +169,26 @@ QuantDequantFusePass::QuantDequantFusePass() {
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
Expand Down Expand Up @@ -291,7 +313,7 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
quantized_op_type == "fc" ||
quantized_op_type == "conv2d_transpose") {
op_desc->SetAttr("Input_scale", scale_value);
} else if (quantized_op_type == "mul") {
} else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
op_desc->SetAttr("X_scale", scale_value);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
Expand Down Expand Up @@ -323,7 +345,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
quantized_op_type == "conv2d_transpose") {
weight_name = "Filter";
input_name = "Input";
} else if (quantized_op_type == "mul") {
} else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
weight_name = "Y";
input_name = "X";
} else if (quantized_op_type == "fc") {
Expand All @@ -332,7 +354,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"QuantDequantFuse: We only support conv2d, conv2d_fusion, "
"conv2d_transpose, fc, mul for "
"conv2d_transpose, fc, mul, matmul for "
"now."));
}
const std::string pattern_name = "dequant_fuse";
Expand Down Expand Up @@ -410,12 +432,13 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
// If quantized op is fc, weight scale size = 1;
// If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1]
if (quantized_op_type == "mul" || quantized_op_type == "fc") {
if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "fc") {
if (dequant_type == "fake_dequantize_max_abs") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), 1,
platform::errors::InvalidArgument(
"mul op weight dequantized by [fake_dequantize_max_abs] "
"mul/matmul op weight dequantized by [fake_dequantize_max_abs] "
"requires weight scale size = 1, but got %d.",
weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
Expand All @@ -426,9 +449,10 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"mul op weight dequantized by "
"mul/matmul op weight dequantized by "
"[fake_channel_wise_dequantize_max_abs] requires weight scale "
"size = 2nd dim of mul's weight, which is %d, but got %d.",
"size = 2nd dim of mul/matmul's weight, which is %d, but got "
"%d.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
quantized_weight_data[j] *= weight_scale[j % w_dims[1]];
Expand Down Expand Up @@ -493,7 +517,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
} else if (quantized_op_type == "fc") {
new_op_desc.SetInput("Input", {new_input});
new_op_desc.SetOutput("Out", {new_output});
} else if (quantized_op_type == "mul") {
} else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
new_op_desc.SetInput("X", {new_input});
new_op_desc.SetOutput("Out", {new_output});
}
Expand All @@ -520,7 +544,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = {
"conv2d", "mul", "depthwise_conv2d", "fc", "conv2d_transpose"};
"conv2d", "mul", "matmul", "depthwise_conv2d", "fc", "conv2d_transpose"};
auto* scope = param_scope();

for (auto& quant_type : quant_types) {
Expand Down