Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,24 @@ PDNode *patterns::ConvRequant::operator()() {
return requant_out;
}

PDNode *patterns::ConvDequant::operator()() {
  // Pattern: conv2d -> (conv output var) -> dequantize -> (dequantize output).
  auto conv = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
  auto conv_output = pattern->NewNode(conv_out_repr())
                         ->assert_is_op_output("conv2d", "Output");

  auto dequant =
      pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
  auto dequant_output = pattern->NewNode(dequant_out_repr())
                            ->AsOutput()
                            ->assert_is_op_output("dequantize", "Output");

  // Wire the chain: conv2d produces conv_output, which feeds dequantize.
  conv->LinksTo({conv_output});
  dequant->LinksFrom({conv_output}).LinksTo({dequant_output});

  return dequant_output;
}

PDNode *patterns::PriorBox::operator()() {
auto prior_box_op =
pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");
Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,23 @@ struct ConvRequant : public PatternBase {
PATTERN_DECL_NODE(requant_out);
};

// Conv + Dequant pattern: a conv2d op whose output tensor is consumed by a
// dequantize op.
// named nodes:
// conv_op, conv_out
// dequant_op, dequant_out
struct ConvDequant : public PatternBase {
  ConvDequant(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "conv_dequant") {}

  // Builds the pattern inside the owned PDPattern and returns the terminal
  // (dequantize output) node.
  PDNode* operator()();

  // Declarations of accessors for the nodes matched by this pattern.
  PATTERN_DECL_NODE(conv_op);
  PATTERN_DECL_NODE(conv_out);

  PATTERN_DECL_NODE(dequant_op);
  PATTERN_DECL_NODE(dequant_out);
};

// PriorBox operator
// operator: prior_box_op
// inputs: prior_box_input, prior_box_image
Expand Down
33 changes: 33 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,38 @@ void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const {
found_requant_squash_count);
}

// Fold a dequantize op into the preceding conv2d by forcing the conv to
// produce fp32 output directly, removing the dequantize and the intermediate
// tensor from the graph.
void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(),
                                             "conv_dequant"};
  conv_dequant_pattern();

  int squash_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "squash conv-dequant ops pair";

    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, conv_dequant_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, conv_dequant_pattern);

    // Squash only when the dequantize op is the sole consumer of the conv2d
    // output; any other consumer still needs the original conv output tensor.
    if (conv_out->outputs.size() != 1) return;

    // Redirect conv2d to write fp32 straight into the dequantize output var,
    // then drop the now-dead intermediate var and the dequantize op.
    conv_op->Op()->SetAttr("force_fp32_output", true);
    conv_op->Op()->SetOutput("Output",
                             std::vector<std::string>({dequant_out->Name()}));
    IR_NODE_LINK_TO(conv_op, dequant_out);
    GraphSafeRemoveNodes(graph, {conv_out, dequant_op});
    ++squash_count;
  };
  gpd(graph, handler);
  AddStatis(squash_count);
  PrettyLogDetail("--- squashed %d dequant with convs", squash_count);
}

void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("cpu_quantize_squash_pass", graph);
Expand All @@ -168,6 +200,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
FindNodesToKeep(graph, &nodes_keep_counter);
DequantQuantSquash(graph, &nodes_keep_counter);
ConvRequantSquash(graph);
ConvDequantSquash(graph);
}

} // namespace ir
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
void ConvRequantSquash(Graph* graph) const;

/*
 * Squash conv2d with dequant when dequant is the only op after conv2d
 */
void ConvDequantSquash(Graph* graph) const;

const std::string name_scope_{"squash"};
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
// d->Dequant(scale1)->e
// e->Quant(scale2)->f
// (f,w2,b2)->Conv2->i

Comment thread
wojtuss marked this conversation as resolved.
Outdated
ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
float scale1, float scale2) {
ProgramDesc prog;
Expand Down Expand Up @@ -85,6 +86,7 @@ static const std::initializer_list<std::string> variable_names{
// c->Quant1(scale2)->d and d->Conv2->e
// c->Conv3->f
// c->Quant2(scale3)->g and g->Conv4->h

Comment thread
wojtuss marked this conversation as resolved.
Outdated
ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out,
float scale1, float scale2,
float scale3) {
Expand Down Expand Up @@ -161,6 +163,36 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
return prog;
}

// Builds: a->Conv1->b, b->Dequant1(scale)->c, c->Concat1->d
// The dequantize is the only consumer of the conv output.
ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
                                              float scale) {
  ProgramDesc prog;
  for (const auto& name : variable_names) {
    prog.MutableBlock(0)->Var(name);
  }

  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
  SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn);
  return prog;
}

// Builds: a->Conv1->b, b->Dequant1(scale)->c, b->Conv2->d
// The conv output "b" has two consumers (Dequant1 and Conv2).
ProgramDesc BuildConvDequantConvProgramDesc(bool use_mkldnn, float scale_out,
                                            float scale) {
  ProgramDesc prog;
  for (const auto& name : variable_names) {
    prog.MutableBlock(0)->Var(name);
  }

  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
  SetOp(&prog, "conv2d", "Conv2", {"b"}, {"d"}, use_mkldnn);
  return prog;
}

void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
const char* var_name) {
auto x = scope->Var(var_name);
Expand Down Expand Up @@ -192,6 +224,7 @@ void CountNodeTest(const ProgramDesc& prog, int removed_nodes_num) {

int original_nodes_num = graph->Nodes().size();
RegisterPass(&graph);

Comment thread
wojtuss marked this conversation as resolved.
Outdated
int current_nodes_num = graph->Nodes().size();

EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num);
Expand All @@ -217,6 +250,7 @@ void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name,
void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in,
float scale_out) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

PrepareGraph(&graph, prog);
RegisterPass(&graph);

Expand All @@ -238,6 +272,7 @@ TEST(CpuQuantizeSquashPass, equal_scales) {
auto use_mkldnn = true;
// Remove 4 nodes: Dequant, Quant, e, f
auto remove_nodes = 4;

CountNodeTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale, scale),
remove_nodes);
Expand All @@ -253,6 +288,7 @@ TEST(CpuQuantizeSquashPass, unequal_scales) {
auto use_mkldnn = true;
// Remove 4 nodes: Dequant, Quant, e, d
auto remove_nodes = 4;

CountNodeTest(
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
remove_nodes);
Expand Down Expand Up @@ -280,6 +316,7 @@ TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) {
// Remove 3 nodes: Quant1, c, Quant2,
// Insert 1 node: Requant
auto remove_nodes = 2;

CountNodeTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out, scale,
scale, scale2),
remove_nodes);
Expand Down Expand Up @@ -322,6 +359,7 @@ TEST(CpuQuantizeSquashPass,
// Remove 3 nodes: Dequant1, c, Quant
// Insert 1 node: Requant
auto remove_nodes = 2;

CountNodeTest(
BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2),
remove_nodes);
Expand All @@ -345,6 +383,27 @@ TEST(CpuQuantizeSquashPass, more_than_one_conv_out_outputs) {
remove_nodes);
}

// After squashing: a->Conv1->c->Concat (conv writes fp32 directly into c).
TEST(CpuQuantizeSquashPass, conv_dequant_only_one_output) {
  const auto use_mkldnn = true;
  const auto scale_out = 1.0f;
  const auto scale = 1.2345f;
  // Expect 2 nodes removed: the Dequant1 op and the intermediate var c.
  const auto remove_nodes = 2;
  CountNodeTest(BuildConvDequantConcatProgramDesc(use_mkldnn, scale_out, scale),
                remove_nodes);
}

TEST(CpuQuantizeSquashPass, conv_dequant_more_than_one_op_after_conv) {
  const auto use_mkldnn = true;
  const auto scale_out = 1.0f;
  const auto scale = 1.2345f;
  // The conv output feeds both Dequant1 and Conv2, so the pass must not
  // squash anything and the node count stays the same.
  const auto remove_nodes = 0;
  CountNodeTest(BuildConvDequantConvProgramDesc(use_mkldnn, scale_out, scale),
                remove_nodes);
}

} // namespace ir
} // namespace framework
} // namespace paddle
Expand Down