Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,23 @@ PDNode *patterns::ConvConcatReLU::operator()() {
return relu_out;
}

// Builds the conv2d -> requantize pattern:
//   conv2d --> conv_out --> requantize --> requant_out
// Returns the pattern node for the requantize output, which is the only node
// marked with AsOutput().
PDNode *patterns::ConvRequant::operator()() {
  // Op type names shared by the op nodes and their output-variable nodes.
  constexpr char kConvType[] = "conv2d";
  constexpr char kRequantType[] = "requantize";

  // Operator nodes.
  PDNode *conv_node =
      pattern->NewNode(conv_op_repr())->assert_is_op(kConvType);
  PDNode *requant_node =
      pattern->NewNode(requant_op_repr())->assert_is_op(kRequantType);

  // Variable nodes: the "Output" slot of each operator.
  PDNode *conv_result = pattern->NewNode(conv_out_repr())
                            ->assert_is_op_output(kConvType, "Output");
  PDNode *requant_result = pattern->NewNode(requant_out_repr())
                               ->AsOutput()
                               ->assert_is_op_output(kRequantType, "Output");

  // Wire the chain in topological order.
  conv_node->LinksTo({conv_result});
  requant_node->LinksFrom({conv_result}).LinksTo({requant_result});

  return requant_result;
}

PDNode *patterns::PriorBox::operator()() {
auto prior_box_op =
pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");
Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,23 @@ struct ConvConcatReLU : public PatternBase {
PATTERN_DECL_NODE(relu_out);
};

// Conv + Requant
// Pattern describing a conv2d op whose "Output" variable is consumed by a
// requantize op (see ConvRequant::operator()() for the exact wiring).
// named nodes:
// conv_op, conv_out
// requant_op, requant_out
struct ConvRequant : public PatternBase {
// Registers the pattern under `name_scope` with the fixed pattern name
// "conv_requant"; both arguments are forwarded to PatternBase.
ConvRequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_requant") {}

// Adds the pattern's nodes and links to the PDPattern and returns the node
// for the requantize output.
PDNode* operator()();

// conv2d operator and its "Output" variable.
// NOTE(review): PATTERN_DECL_NODE presumably declares the <name>_repr() /
// <name>_n() accessors used by the pass — confirm against the macro.
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out);

// requantize operator and its "Output" variable.
PATTERN_DECL_NODE(requant_op);
PATTERN_DECL_NODE(requant_out);
};

// PriorBox operator
// operator: prior_box_op
// inputs: prior_box_input, prior_box_image
Expand Down
50 changes: 43 additions & 7 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ void CPUQuantizeSquashPass::FindNodesToKeep(
AddStatis(found_count);
}

void CPUQuantizeSquashPass::Squash(
void CPUQuantizeSquashPass::DequantQuantSquash(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const {
GraphPatternDetector gpd;
patterns::DequantQuantAny squash_pattern{gpd.mutable_pattern(), "squash"};
squash_pattern();

int found_squash_count = 0;
int found_dequant_quant_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash requantize-quantize ops pair";
Expand Down Expand Up @@ -96,7 +96,7 @@ void CPUQuantizeSquashPass::Squash(

IR_NODE_LINK_TO(dequant_in, next_op);

found_squash_count++;
found_dequant_quant_count++;
} else {
// squash dequantize-quantize to requantize op
OpDesc desc;
Expand All @@ -116,13 +116,48 @@ void CPUQuantizeSquashPass::Squash(
IR_NODE_LINK_TO(dequant_in, requant_op);
IR_NODE_LINK_TO(requant_op, quant_out);

found_squash_count++;
found_dequant_quant_count++;
}
};
gpd(graph, handler);
AddStatis(found_squash_count);
AddStatis(found_dequant_quant_count);
PrettyLogDetail("--- squashed %d dequantize-quantize pairs",
found_squash_count);
found_dequant_quant_count);
}

// Folds a requantize op into the conv2d op that feeds it.
//
// For every conv2d -> conv_out -> requantize -> requant_out match where
// conv_out has exactly one entry in its outputs list, the conv2d op takes over
// the requantize op's "Scale_out" attribute and writes directly to
// requant_out; conv_out and the requantize op are then removed from `graph`.
// The number of squashed pairs is reported via AddStatis and PrettyLogDetail.
void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const {
  GraphPatternDetector detector;
  patterns::ConvRequant pattern{detector.mutable_pattern(), "conv_requant"};
  pattern();

  int squashed_count = 0;
  auto squash_handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                            Graph* g) {
    VLOG(4) << "squash conv-requantize ops pair";

    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, pattern);

    // Leave the pair untouched unless conv_out has a single output link
    // (presumably meaning the requantize op is its only consumer).
    if (conv_out->outputs.size() != 1) {
      return;
    }

    // Move the requantize output scale onto the conv2d op.
    auto* conv_desc = conv_op->Op();
    const float new_scale_out =
        boost::get<float>(requant_op->Op()->GetAttr("Scale_out"));
    conv_desc->SetAttr("Scale_out", new_scale_out);

    // Redirect conv2d to write straight into the requantize output variable.
    const std::vector<std::string> fused_outputs{requant_out->Name()};
    conv_desc->SetOutput("Output", fused_outputs);
    IR_NODE_LINK_TO(conv_op, requant_out);

    // The intermediate variable and the requantize op are now dead.
    GraphSafeRemoveNodes(graph, {conv_out, requant_op});

    ++squashed_count;
  };

  detector(graph, squash_handler);
  AddStatis(squashed_count);
  PrettyLogDetail("--- squashed %d requantize with convs", squashed_count);
}

void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
Expand All @@ -131,7 +166,8 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {

std::unordered_map<const Node*, int> nodes_keep_counter;
FindNodesToKeep(graph, &nodes_keep_counter);
Squash(graph, &nodes_keep_counter);
DequantQuantSquash(graph, &nodes_keep_counter);
ConvRequantSquash(graph);
}

} // namespace ir
Expand Down
10 changes: 8 additions & 2 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,14 @@ class CPUQuantizeSquashPass : public FusePassBase {
/*
* Squash dequantize-quantize ops pairs into requantize or nothing
*/
void Squash(Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const;
void DequantQuantSquash(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const;

/*
* Squash requantize op into the preceding conv2d op by setting conv2d's
* Scale_out to the requantize op's Scale_out
*/
void ConvRequantSquash(Graph* graph) const;

const std::string name_scope_{"squash"};
};
Expand Down
Loading