Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pass_library(runtime_context_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference)
pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)

if(ANAKIN_FOUND)
pass_library(simplify_anakin_priorbox_detection_out_pass inference)
Expand Down
82 changes: 82 additions & 0 deletions paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "paddle/fluid/framework/ir/delete_quant_dequant_op_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"

namespace paddle {
namespace framework {
namespace ir {

#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(any_op_out); \
GET_IR_NODE(quant_dequant_op_inscale); \
GET_IR_NODE(quant_dequant_op); \
GET_IR_NODE(quant_dequant_op_outscale); \
GET_IR_NODE(quant_dequant_op_out); \
GET_IR_NODE(any_op2);

void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_quantdequant_op_pattern";
FusePassBase::Init(pattern_name, graph);

GraphPatternDetector gpd;

patterns::DeleteQuantDequantOpPattern pattern(gpd.mutable_pattern(),
pattern_name);
pattern();

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
IR_NODE_LINK_TO(any_op_out, any_op2);
std::string any_op_out_name = any_op_out->Var()->Name();
std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();

auto* any_op2_desc = any_op2->Op();
// auto input_args_names = any_op2_desc->InputArgumentNames();
auto var_map = any_op2_desc->Inputs();

for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
quant_dequant_op_out_name) != name_m.second.end()) {
std::vector<std::string> new_inputs;
for (auto& i_n : name_m.second) {
if (i_n != quant_dequant_op_out_name) {
new_inputs.push_back(i_n);
}
}
new_inputs.push_back(any_op_out_name);
any_op2_desc->SetInput(name_m.first, new_inputs);
any_op2_desc->Flush();
}
}
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph,
{quant_dequant_op, quant_dequant_op_out,
quant_dequant_op_inscale, quant_dequant_op_outscale});
};

gpd(graph, handler);
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(delete_quant_dequant_op_pass,
paddle::framework::ir::DeleteQuantDequantOpPass);
34 changes: 34 additions & 0 deletions paddle/fluid/framework/ir/delete_quant_dequant_op_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

class DeleteQuantDequantOpPass : public FusePassBase {
public:
virtual ~DeleteQuantDequantOpPass() {}

protected:
void ApplyImpl(ir::Graph* graph) const override;
};

} // namespace ir
} // namespace framework
} // namespace paddle
5 changes: 5 additions & 0 deletions paddle/fluid/framework/ir/fc_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8"));
desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale"));
desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale"));
if (base_op_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", base_op_desc->GetAttr("out_scale"));
auto elementwise_desc = elementwise_add->Op();
if (elementwise_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
}

desc.SetType("fc");
Expand Down
85 changes: 73 additions & 12 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1641,25 +1641,36 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const std::string &op_type,
const std::string &weight_name,
int times,
const std::string &quant_type) {
const int kNumFields = 5;
const std::string &quant_type,
const std::string &dequant_type) {
int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
const int kDequantOpWeightScaleOffset = 5;

// the quant op always be one.
auto quant_op_in_scale = pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input(quant_type, "InScale")
->AsInput();
auto quant_op =
pattern->NewNode(GetNodeName("quant_op"))->assert_is_op(quant_type);

auto quant_op_out_scale =
pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_input("fake_dequantize_max_abs", "Scale")
->AsIntermediate();
PDNode *quant_op_out_scale = nullptr;
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
kNumFields += 1;
quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_nth_input(dequant_type, "Scales", 1)
->AsIntermediate();
} else {
quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_input(dequant_type, "Scale")
->AsIntermediate();
}

auto quant_op_out = pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output(quant_type, "Out")
Expand All @@ -1680,16 +1691,25 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op_out") + std::to_string(i))
->assert_is_op_output(op_type)
->assert_is_op_input("fake_dequantize_max_abs", "X")
->assert_is_op_input(dequant_type, "X")
->AsIntermediate());

nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op") + std::to_string(i))
->assert_is_op("fake_dequantize_max_abs"));
->assert_is_op(dequant_type));

nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op_out") + std::to_string(i))
->assert_is_op_output("fake_dequantize_max_abs", "Out")
->assert_is_op_output(dequant_type, "Out")
->AsOutput());

if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
nodes.push_back(pattern
->NewNode(GetNodeName("dequant_channel_scale") +
std::to_string(i))
->assert_is_op_nth_input(dequant_type, "Scales", 0)
->AsInput());
}
}

quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
Expand All @@ -1699,8 +1719,14 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
{quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOffset]});
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale,
nodes[i * kNumFields + kDequantOpWeightScaleOffset]});
} else {
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
}
nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kDequantOpOffset]});
}
Expand Down Expand Up @@ -1737,6 +1763,41 @@ void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
reshape2_out->LinksFrom({reshape2_op});
}

void patterns::DeleteQuantDequantOpPattern::operator()() {
auto any_op_out =
pattern->NewNode(any_op_out_repr())
->assert_is_op_input(
"fake_quantize_dequantize_moving_average_abs_max", "X")
->AsInput();

auto quant_dequant_op_inscale =
pattern->NewNode(quant_dequant_op_inscale_repr())
->assert_is_op_input(
"fake_quantize_dequantize_moving_average_abs_max", "InScale")
->AsInput();
auto quant_dequant_op =
pattern->NewNode(quant_dequant_op_repr())
->assert_is_op("fake_quantize_dequantize_moving_average_abs_max");

auto quant_dequant_out =
pattern->NewNode(quant_dequant_op_out_repr())
->assert_is_op_output(
"fake_quantize_dequantize_moving_average_abs_max", "Out")
->AsIntermediate();

auto quant_dequant_op_outscale =
pattern->NewNode(quant_dequant_op_outscale_repr())
->assert_is_op_output(
"fake_quantize_dequantize_moving_average_abs_max", "OutScale")
->AsOutput();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();

quant_dequant_op->LinksFrom({any_op_out, quant_dequant_op_inscale});
quant_dequant_op_outscale->LinksFrom({quant_dequant_op});
quant_dequant_out->LinksFrom({quant_dequant_op});
any_op2->LinksFrom({quant_dequant_out});
}

} // namespace ir
} // namespace framework
} // namespace paddle
17 changes: 16 additions & 1 deletion paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,8 @@ struct QuantDequantOpFuse : public PatternBase {

void operator()(PDNode* quant_op_input, const std::string& op_name,
const std::string& weight_name, int times,
const std::string& quant_type);
const std::string& quant_type,
const std::string& dequant_type);

std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);
Expand All @@ -907,6 +908,20 @@ struct ShuffleChannelPattern : public PatternBase {
PATTERN_DECL_NODE(reshape2_out);
};

struct DeleteQuantDequantOpPattern : public PatternBase {
DeleteQuantDequantOpPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "delete_quantdequant_op_pattern") {}

void operator()();

PATTERN_DECL_NODE(any_op_out);
PATTERN_DECL_NODE(quant_dequant_op_inscale);
PATTERN_DECL_NODE(quant_dequant_op);
PATTERN_DECL_NODE(quant_dequant_op_outscale);
PATTERN_DECL_NODE(quant_dequant_op_out);
PATTERN_DECL_NODE(any_op2);
};

} // namespace patterns

// Link two ir::Nodes from each other.
Expand Down
Loading