Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,6 @@ endif()
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
cc_test(test_multi_gru_fuse_pass SRCS mkldnn/multi_gru_fuse_pass_tester.cc DEPS multi_gru_fuse_pass)
cc_test(test_multi_gru_seq_fuse_pass SRCS mkldnn/multi_gru_seq_fuse_pass_tester.cc DEPS multi_gru_seq_fuse_pass)
set(TEST_FC_RNN_PASS_DEPS fc_gru_fuse_pass fc_lstm_fuse_pass mkldnn_placement_pass)
cc_test(test_fc_rnn_mkldnn_fuse_pass SRCS mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc DEPS ${TEST_FC_RNN_PASS_DEPS})
endif ()
14 changes: 10 additions & 4 deletions paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
gru_pattern(fc_out);

// Create New OpDesc
auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
Node* bias, Node* hidden, Node* fc_bias) {
auto gru_creator = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
Node* bias, Node* hidden, Node* fc_bias,
const bool use_mkldnn) {
OpDesc op_desc;
op_desc.SetType("fusion_gru");

Expand All @@ -67,6 +68,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
gru->Op()->GetAttrIfExists<bool>("origin_mode"));
// TODO(TJ): This should be a option for infer
op_desc.SetAttr("use_seq", true);
op_desc.SetAttr("use_mkldnn", use_mkldnn);
op_desc.SetAttr("activation", gru->Op()->GetAttr("activation"));
op_desc.SetAttr("gate_activation", gru->Op()->GetAttr("gate_activation"));

Expand Down Expand Up @@ -149,21 +151,25 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
LOG(INFO) << "fc_gru_fuse_pass not supported when origin_mode=True.";
return;
}
const bool use_mkldnn =
mul->Op()->GetAttrIfExists<bool>("use_mkldnn") &&
gru->Op()->GetAttrIfExists<std::string>("activation") == "tanh" &&
gru->Op()->GetAttrIfExists<std::string>("gate_activation") == "sigmoid";

if (with_fc_bias) {
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);

gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias);
gru_creator(gru, x_n, w, Weight, Bias, Hidden, fc_bias, use_mkldnn);
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes(
{mul, gru, elementwise_add, fc_out, mul_out, BatchGate,
BatchResetHiddenPrev, BatchHidden});
GraphSafeRemoveNodes(graph, marked_nodes);
} else {
gru_creater(gru, x_n, w, Weight, Bias, Hidden, nullptr);
gru_creator(gru, x_n, w, Weight, Bias, Hidden, nullptr, use_mkldnn);
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes(
{mul, gru, BatchGate, BatchResetHiddenPrev, BatchHidden});
Expand Down
71 changes: 5 additions & 66 deletions paddle/fluid/framework/ir/fc_gru_fuse_pass_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,77 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"

#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/fc_gru_fuse_pass_tester.h"

namespace paddle {
namespace framework {
namespace ir {

void AddVarToScope(Scope* param_scope, const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}

Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "gru_fc_w", {});
AddVarToScope(param_scope, "gru_fc_b", {});
AddVarToScope(param_scope, "gru_w", {});
AddVarToScope(param_scope, "gru_b", {});
AddVarToScope(param_scope, "gru_batch_gate_0", {});
AddVarToScope(param_scope, "gru_batch_reset_hidden_prev_0", {});
AddVarToScope(param_scope, "gru_batch_hidden_0", {});
AddVarToScope(param_scope, "gru_hidden_0", {});
AddVarToScope(param_scope, "gru_batch_gate_1", {});
AddVarToScope(param_scope, "gru_batch_reset_hidden_prev_1", {});
AddVarToScope(param_scope, "gru_batch_hidden_1", {});
AddVarToScope(param_scope, "gru_hidden_1", {});
return param_scope;
}

TEST(FCFusePass, basic) {
// inputs operator output
// --------------------------------------------------------
// (a, gru_fc_w) mul -> fc_0_tmp_0
// (fc_0_tmp_0, gru_fc_b) elementwise_add -> fc_0_tmp_1
// (fc_0_tmp_1,gru_w,gru_b gru -> gru_out_0

// (b, gru_fc_w) mul -> fc_1_tmp_0
// (fc_1_tmp_0, gru_fc_b) elementwise_add -> fc_1_tmp_1
// (fc_1_tmp_1,gru_w,gru_b) gru -> gru_out_1
Layers layers;
auto* a = layers.data("a");
auto* b = layers.data("b");
auto* fc_w = layers.data("gru_fc_w", {}, true);
auto* fc_b = layers.data("gru_fc_b", {}, true);
auto* gru_w = layers.data("gru_w", {}, true);
auto* gru_b = layers.data("gru_b", {}, true);
auto* fc_0_tmp0 = layers.mul(a, fc_w);
auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
auto* gru_batch_gate_0 = layers.data("gru_batch_gate_0", {}, false);
auto* gru_batch_reset_hidden_prev_0 =
layers.data("gru_batch_reset_hidden_prev_0", {}, false);
auto* gru_batch_hidden_0 = layers.data("gru_batch_hidden_0", {}, false);
auto* gru_hidden_0 = layers.data("gru_hidden_0", {}, false);
layers.gru(fc_0_tmp1, gru_w, gru_b, gru_batch_gate_0,
gru_batch_reset_hidden_prev_0, gru_batch_hidden_0, gru_hidden_0);

auto* fc_1_tmp0 = layers.mul(b, fc_w);
auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
auto* gru_batch_gate_1 = layers.data("gru_batch_gate_1", {}, false);
auto* gru_batch_reset_hidden_prev_1 =
layers.data("gru_batch_reset_hidden_prev_1", {}, false);
auto* gru_batch_hidden_1 = layers.data("gru_batch_hidden_1", {}, false);
auto* gru_hidden_1 = layers.data("gru_hidden_1", {}, false);
layers.gru(fc_1_tmp1, gru_w, gru_b, gru_batch_gate_1,
gru_batch_reset_hidden_prev_1, gru_batch_hidden_1, gru_hidden_1);

std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
namespace fc_gru_test {
TEST(FcGruFusePass, basic) {
std::unique_ptr<ir::Graph> graph = PrepareGraph();
auto pass = PassRegistry::Instance().Get("fc_gru_fuse_pass");
pass->Set("use_gpu", new bool(true));
graph->Set("__param_scope__", CreateParamScope());
Expand All @@ -109,6 +47,7 @@ TEST(FCFusePass, basic) {
"expectations after fuse"));
}

} // namespace fc_gru_test
} // namespace ir
} // namespace framework
} // namespace paddle
Expand Down
93 changes: 93 additions & 0 deletions paddle/fluid/framework/ir/fc_gru_fuse_pass_tester.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"

#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

namespace fc_gru_test {
/// Creates variable `name` in `param_scope` and backs it with an
/// uninitialized float LoDTensor resized to `dims`, allocated on CPU.
/// `inline` is required: this function is defined in a header
/// (#pragma once does not prevent ODR violations across translation
/// units), and the header may be included by several tester .cc files.
inline void AddVarToScope(Scope* param_scope, const std::string& name,
                          const DDim& dims) {
  auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
  tensor->Resize(dims);
  tensor->mutable_data<float>(platform::CPUPlace());
}

/// Builds the parameter scope required by the fc_gru fuse pass tests.
/// Every tensor is an empty-shaped placeholder: the pass only needs the
/// variables to exist in the scope.
/// Ownership: returns a raw owning pointer; the caller
/// (graph->Set("__param_scope__", ...)) takes ownership and frees it.
/// `inline` avoids multiple-definition link errors when this header is
/// included from more than one translation unit.
inline Scope* CreateParamScope() {
  auto* param_scope = new Scope();
  for (const auto& name :
       {"gru_fc_w", "gru_fc_b", "gru_w", "gru_b", "gru_batch_gate_0",
        "gru_batch_reset_hidden_prev_0", "gru_batch_hidden_0", "gru_hidden_0",
        "gru_batch_gate_1", "gru_batch_reset_hidden_prev_1",
        "gru_batch_hidden_1", "gru_hidden_1"}) {
    AddVarToScope(param_scope, name, {});
  }
  return param_scope;
}

/// Builds the test graph shared by the fc_gru fuse pass testers:
/// two parallel (mul -> elementwise_add -> gru) chains that share the
/// FC and GRU weights, so the pass is exercised on two fusable
/// subgraphs at once.
///
///             inputs                   operator    output
/// --------------------------------------------------------
/// (a, gru_fc_w)                        mul      -> fc_0_tmp_0
/// (fc_0_tmp_0, gru_fc_b)       elementwise_add  -> fc_0_tmp_1
/// (fc_0_tmp_1, gru_w, gru_b)           gru      -> gru_out_0
///
/// (b, gru_fc_w)                        mul      -> fc_1_tmp_0
/// (fc_1_tmp_0, gru_fc_b)       elementwise_add  -> fc_1_tmp_1
/// (fc_1_tmp_1, gru_w, gru_b)           gru      -> gru_out_1
///
/// `inline` is required for a function defined in a header.
inline std::unique_ptr<ir::Graph> PrepareGraph() {
  Layers layers;
  auto* a = layers.data("a");
  auto* b = layers.data("b");
  // Persistable weights/bias (shapes left empty; tensors are filled by
  // CreateParamScope).
  auto* fc_w = layers.data("gru_fc_w", {}, true);
  auto* fc_b = layers.data("gru_fc_b", {}, true);
  auto* gru_w = layers.data("gru_w", {}, true);
  auto* gru_b = layers.data("gru_b", {}, true);

  // First FC + GRU chain.
  auto* fc_0_tmp0 = layers.mul(a, fc_w);
  auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
  auto* gru_batch_gate_0 = layers.data("gru_batch_gate_0", {}, false);
  auto* gru_batch_reset_hidden_prev_0 =
      layers.data("gru_batch_reset_hidden_prev_0", {}, false);
  auto* gru_batch_hidden_0 = layers.data("gru_batch_hidden_0", {}, false);
  auto* gru_hidden_0 = layers.data("gru_hidden_0", {}, false);
  layers.gru(fc_0_tmp1, gru_w, gru_b, gru_batch_gate_0,
             gru_batch_reset_hidden_prev_0, gru_batch_hidden_0, gru_hidden_0);

  // Second FC + GRU chain, reusing the same weights.
  auto* fc_1_tmp0 = layers.mul(b, fc_w);
  auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
  auto* gru_batch_gate_1 = layers.data("gru_batch_gate_1", {}, false);
  auto* gru_batch_reset_hidden_prev_1 =
      layers.data("gru_batch_reset_hidden_prev_1", {}, false);
  auto* gru_batch_hidden_1 = layers.data("gru_batch_hidden_1", {}, false);
  auto* gru_hidden_1 = layers.data("gru_hidden_1", {}, false);
  layers.gru(fc_1_tmp1, gru_w, gru_b, gru_batch_gate_1,
             gru_batch_reset_hidden_prev_1, gru_batch_hidden_1, gru_hidden_1);

  // Return by value: NRVO/implicit move applies; `return std::move(graph)`
  // would be a pessimizing move (-Wpessimizing-move).
  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
  return graph;
}
} // namespace fc_gru_test
} // namespace ir
} // namespace framework
} // namespace paddle
15 changes: 12 additions & 3 deletions paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
// Create New OpDesc
auto lstm_creator = [&](Node* lstm, Node* input, Node* weight_x,
Node* weight_h, Node* bias, Node* hidden, Node* cell,
Node* xx, Node* fc_bias) {
Node* xx, Node* fc_bias, const bool use_mkldnn) {
OpDesc op_desc;
op_desc.SetType("fusion_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
Expand Down Expand Up @@ -88,6 +88,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc.SetOutput("XX", {xx->Name()});
op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
op_desc.SetAttr("use_mkldnn", use_mkldnn);
// TODO(TJ): get from attr
op_desc.SetAttr("use_seq", true);

Expand Down Expand Up @@ -148,21 +149,29 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
const bool use_mkldnn =
mul->Op()->GetAttrIfExists<bool>("use_mkldnn") &&
lstm->Op()->GetAttrIfExists<std::string>("gate_activation") ==
"sigmoid" &&
lstm->Op()->GetAttrIfExists<std::string>("cell_activation") == "tanh" &&
lstm->Op()->GetAttrIfExists<std::string>("candidate_activation") ==
"tanh";

if (with_fc_bias) {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
fc_bias);
fc_bias, use_mkldnn);
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes(
{mul, lstm, elementwise_add, mul_out, BatchGate, BatchCellPreAct});
GraphSafeRemoveNodes(graph, marked_nodes);
} else {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
nullptr);
nullptr, use_mkldnn);
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes(
{mul, lstm, BatchGate, BatchCellPreAct});
Expand Down
71 changes: 5 additions & 66 deletions paddle/fluid/framework/ir/fc_lstm_fuse_pass_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,77 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"

#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass_tester.h"

namespace paddle {
namespace framework {
namespace ir {

void AddVarToScope(Scope* param_scope, const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}

Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "lstm_fc_w", {});
AddVarToScope(param_scope, "lstm_fc_b", {});
AddVarToScope(param_scope, "lstm_w", {});
AddVarToScope(param_scope, "lstm_b", {});
AddVarToScope(param_scope, "lstm_cell_0", {});
AddVarToScope(param_scope, "lstm_batch_gate_0", {});
AddVarToScope(param_scope, "lstm_batch_cell_pre_gate_0", {});
AddVarToScope(param_scope, "lstm_hidden_0", {});
AddVarToScope(param_scope, "lstm_cell_1", {});
AddVarToScope(param_scope, "lstm_batch_gate_1", {});
AddVarToScope(param_scope, "lstm_batch_cell_pre_gate_1", {});
AddVarToScope(param_scope, "lstm_hidden_1", {});
return param_scope;
}

TEST(FCLSTMFusePass, basic) {
// inputs operator output
// --------------------------------------------------------
// (a, lstm_fc_w) mul -> fc_0_tmp_0
// (fc_0_tmp_0, lstm_fc_b) elementwise_add -> fc_0_tmp_1
// fc_0_tmp_1,lstm_w,lstm_b lstm -> lstm_out_0

// (b, lstm_fc_w) mul -> fc_1_tmp_0
// (fc_1_tmp_0, lstm_fc_b) elementwise_add -> fc_1_tmp_1
// (fc_1_tmp_1,lstm_w,lstm_b) lstm -> lstm_out_1
Layers layers;
auto* a = layers.data("a");
auto* b = layers.data("b");
auto* fc_w = layers.data("lstm_fc_w", {}, true);
auto* fc_b = layers.data("lstm_fc_b", {}, true);
auto* lstm_w = layers.data("lstm_w", {}, true);
auto* lstm_b = layers.data("lstm_b", {}, true);
auto* fc_0_tmp0 = layers.mul(a, fc_w);
auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
auto* lstm_cell_0 = layers.data("lstm_cell_0", {}, false);
auto* lstm_batch_gate_0 = layers.data("lstm_batch_gate_0", {}, false);
auto* lstm_batch_cell_pre_gate_0 =
layers.data("lstm_batch_cell_pre_gate_0", {}, false);
auto* lstm_hidden_0 = layers.data("lstm_hidden_0", {}, false);
layers.lstm(fc_0_tmp1, lstm_w, lstm_b, lstm_cell_0, lstm_batch_gate_0,
lstm_hidden_0, lstm_batch_cell_pre_gate_0);
namespace fc_lstm_test {

auto* fc_1_tmp0 = layers.mul(b, fc_w);
auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
auto* lstm_cell_1 = layers.data("lstm_cell_1", {}, false);
auto* lstm_batch_gate_1 = layers.data("lstm_batch_gate_1", {}, false);
auto* lstm_batch_cell_pre_gate_1 =
layers.data("lstm_batch_cell_pre_gate_1", {}, false);
auto* lstm_hidden_1 = layers.data("lstm_hidden_1", {}, false);
layers.lstm(fc_1_tmp1, lstm_w, lstm_b, lstm_cell_1, lstm_batch_gate_1,
lstm_hidden_1, lstm_batch_cell_pre_gate_1);

std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
TEST(FcLstmFusePass, basic) {
std::unique_ptr<ir::Graph> graph = PrepareGraph();
auto pass = PassRegistry::Instance().Get("fc_lstm_fuse_pass");
pass->Set("use_gpu", new bool(false));
graph->Set("__param_scope__", CreateParamScope());
Expand All @@ -108,7 +47,7 @@ TEST(FCLSTMFusePass, basic) {
"The number of fusion_lstm nodes does "
"not meet expectations after fuse"));
}

} // namespace fc_lstm_test
} // namespace ir
} // namespace framework
} // namespace paddle
Expand Down
Loading