Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,24 @@ PDNode *patterns::ConvRequant::operator()() {
return requant_out;
}

PDNode *patterns::ConvDequant::operator()() {
  // Operator nodes: a conv2d immediately followed by a dequantize op.
  auto conv = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
  auto dequant =
      pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");

  // Variable nodes: the conv2d output is consumed by dequantize; the
  // dequantize output is the final output of this pattern.
  auto conv_result = pattern->NewNode(conv_out_repr())
                         ->assert_is_op_output("conv2d", "Output");
  auto dequant_result = pattern->NewNode(dequant_out_repr())
                            ->AsOutput()
                            ->assert_is_op_output("dequantize", "Output");

  // Wire the chain: conv -> conv_result -> dequant -> dequant_result.
  conv->LinksTo({conv_result});
  dequant->LinksFrom({conv_result}).LinksTo({dequant_result});

  return dequant_result;
}

PDNode *patterns::PriorBox::operator()() {
auto prior_box_op =
pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");
Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,23 @@ struct ConvRequant : public PatternBase {
PATTERN_DECL_NODE(requant_out);
};

// Conv + Dequant
// named nodes:
// conv_op, conv_out
// dequant_op, dequant_out
struct ConvDequant : public PatternBase {
  ConvDequant(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "conv_dequant") {}

  // Builds the conv2d -> dequantize pattern and returns the pattern's
  // output node (the dequantize output variable).
  PDNode* operator()();

  // conv2d operator node and its output variable node.
  PATTERN_DECL_NODE(conv_op);
  PATTERN_DECL_NODE(conv_out);

  // dequantize operator node and its output variable node.
  PATTERN_DECL_NODE(dequant_op);
  PATTERN_DECL_NODE(dequant_out);
};

// PriorBox operator
// operator: prior_box_op
// inputs: prior_box_input, prior_box_image
Expand Down
33 changes: 33 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,38 @@ void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const {
found_requant_squash_count);
}

void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
  // Look for conv2d -> dequantize pairs in the graph.
  GraphPatternDetector gpd;
  patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(),
                                             "conv_dequant"};
  conv_dequant_pattern();

  int squash_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "squash conv-dequant ops pair";

    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, conv_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, conv_dequant_pattern);

    // Only squash when dequantize is the sole consumer of the conv2d
    // output; other consumers would still need the quantized tensor.
    if (conv_out->outputs.size() != 1) return;

    // Make conv2d produce fp32 directly, rewire it to the dequantize
    // output, and drop the now-dead intermediate var and dequantize op.
    conv_op->Op()->SetAttr("force_fp32_output", true);
    conv_op->Op()->SetOutput("Output",
                             std::vector<std::string>({dequant_out->Name()}));
    IR_NODE_LINK_TO(conv_op, dequant_out);
    GraphSafeRemoveNodes(graph, {conv_out, dequant_op});
    squash_count++;
  };
  gpd(graph, handler);
  AddStatis(squash_count);
  PrettyLogDetail("--- squashed %d dequant with convs", squash_count);
}

void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("cpu_quantize_squash_pass", graph);
Expand All @@ -168,6 +200,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
FindNodesToKeep(graph, &nodes_keep_counter);
DequantQuantSquash(graph, &nodes_keep_counter);
ConvRequantSquash(graph);
ConvDequantSquash(graph);
}

} // namespace ir
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
void ConvRequantSquash(Graph* graph) const;

/*
* Squash conv2d with dequant when dequant is the only op after conv2d
*/
void ConvDequantSquash(Graph* graph) const;

const std::string name_scope_{"squash"};
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,36 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
return prog;
}

// Builds: a->Conv1->b, b->Dequant1(Scale1)->c, c->Concat1->d.
ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
                                              float scale) {
  ProgramDesc prog;
  // Declare every shared test variable in block 0.
  for (auto& name : variable_names) {
    prog.MutableBlock(0)->Var(name);
  }
  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
  SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn);
  return prog;
}

// Builds: a->Conv1->b, b->Dequant1(Scale1)->c, and b->Conv2->d, so the
// conv2d output "b" has two consumers.
ProgramDesc BuildConvDequantConvProgramDesc(bool use_mkldnn, float scale_out,
                                            float scale) {
  ProgramDesc prog;
  // Declare every shared test variable in block 0.
  for (auto& name : variable_names) {
    prog.MutableBlock(0)->Var(name);
  }
  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
  SetOp(&prog, "conv2d", "Conv2", {"b"}, {"d"}, use_mkldnn);
  return prog;
}

void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
const char* var_name) {
auto x = scope->Var(var_name);
Expand Down Expand Up @@ -217,6 +247,7 @@ void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name,
void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in,
float scale_out) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

PrepareGraph(&graph, prog);
RegisterPass(&graph);

Expand All @@ -238,6 +269,7 @@ TEST(CpuQuantizeSquashPass, equal_scales) {
auto use_mkldnn = true;
// Remove 4 nodes: Dequant, Quant, e, f
auto remove_nodes = 4;

CountNodeTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale, scale),
remove_nodes);
Expand All @@ -253,6 +285,7 @@ TEST(CpuQuantizeSquashPass, unequal_scales) {
auto use_mkldnn = true;
// Remove 4 nodes: Dequant, Quant, e, d
auto remove_nodes = 4;

CountNodeTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
remove_nodes);
Expand Down Expand Up @@ -280,6 +313,7 @@ TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) {
// Remove 3 nodes: Quant1, c, Quant2,
// Insert 1 node: Requant
auto remove_nodes = 2;

CountNodeTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out, scale,
scale, scale2),
remove_nodes);
Expand Down Expand Up @@ -322,6 +356,7 @@ TEST(CpuQuantizeSquashPass,
// Remove 3 nodes: Dequant1, c, Quant
// Insert 1 node: Requant
auto remove_nodes = 2;

CountNodeTest(
BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2),
remove_nodes);
Expand All @@ -345,6 +380,27 @@ TEST(CpuQuantizeSquashPass, more_than_one_conv_out_outputs) {
remove_nodes);
}

// Resulting graph: a->Conv1->c->Concat (dequantize squashed into conv2d).
TEST(CpuQuantizeSquashPass, conv_dequant_only_one_output) {
  const auto use_mkldnn = true;
  const auto scale_out = 1.0f;
  const auto scale = 1.2345f;
  // Dequant1 and its input var c disappear: 2 nodes removed.
  const auto remove_nodes = 2;
  CountNodeTest(BuildConvDequantConcatProgramDesc(use_mkldnn, scale_out, scale),
                remove_nodes);
}

// The conv2d output feeds two ops, so the squash must not trigger.
TEST(CpuQuantizeSquashPass, conv_dequant_more_than_one_op_after_conv) {
  const auto use_mkldnn = true;
  const auto scale_out = 1.0f;
  const auto scale = 1.2345f;
  // Graph is left untouched: no nodes removed.
  const auto remove_nodes = 0;
  CountNodeTest(BuildConvDequantConvProgramDesc(use_mkldnn, scale_out, scale),
                remove_nodes);
}

} // namespace ir
} // namespace framework
} // namespace paddle
Expand Down