Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions lite/core/optimizer/mir/fusion/fc_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,24 @@ namespace lite {
namespace mir {

void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
  // Run the FC fuser once per supported "mul-like" op so that
  // mul/matmul/matmul_v2 + elementwise_add (+ relu) all fold into fc.
  const std::vector<std::string> mul_types{"mul", "matmul", "matmul_v2"};
  for (const auto& op_type : mul_types) {  // const ref: avoid copying strings
#if defined(LITE_WITH_X86) || defined(LITE_WITH_CUDA)
#ifdef LITE_WITH_MLU
    // MLU backend: fuse without the trailing relu activation.
    fusion::FcFuser fuser(op_type, false);
    fuser(graph.get());
#else
    // X86/CUDA without MLU: also fold a following relu into the fc op.
    fusion::FcFuser fuser(op_type, true);
    fuser(graph.get());
#endif
#endif
    // Always run the no-relu variant so plain mul+add pairs are fused too.
    fusion::FcFuser fuser2(op_type, false);
    fuser2(graph.get());
#ifdef LITE_WITH_FPGA
    // FPGA backend additionally fuses the relu-activated pattern.
    fusion::FcFuser fpga_fuser(op_type, true);
    fpga_fuser(graph.get());
#endif
  }
}

} // namespace mir
Expand Down
26 changes: 20 additions & 6 deletions lite/core/optimizer/mir/fusion/fc_fuser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,25 @@ void FcFuser::BuildPattern() {
size_t b_rank = b_shape.size();
return b_rank == 2 || b_rank == 1;
};
auto input_attr_teller = [](const Node* node) -> bool {
auto op_desc = *const_cast<Node*>(node)->stmt()->op_info();
bool trans_x = op_desc.GetAttr<bool>("trans_x");
bool trans_y = op_desc.GetAttr<bool>("trans_y") return trans_x == false &&
trans_y == false;
};

// create nodes.
auto* x = VarNode("x")->assert_is_op_input("mul", "X");
auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
auto* x = VarNode("x")->assert_is_op_input(op_type_, "X");
auto* W = VarNode("W")->assert_is_op_input(op_type_, "Y");
auto* b = VarNode("b")->assert_is_persistable_var();
auto* mul = OpNode("mul", "mul")->assert_node_satisfied(inputs_teller0);
auto* mul = OpNode("mul", op_type_)->assert_node_satisfied(inputs_teller0);
auto* mul_out = VarNode("mul_out");
auto* add =
OpNode("add", "elementwise_add")->assert_node_satisfied(inputs_teller1);
auto* Out = VarNode("Out");
if (op_type_ == "matmul" || op_type_ == "matmul_v2") {
mul = OpNode("mul", op_type_)->assert_node_satisfied(input_attr_teller);
}

// create topology.
std::vector<PMNode*> mul_inputs{W, x};
Expand Down Expand Up @@ -129,9 +138,14 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
op_desc.SetInput("W", {matched.at("W")->arg()->name});
op_desc.SetInput("Bias", {matched.at("b")->arg()->name});
op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
op_desc.SetAttr(
"in_num_col_dims",
matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
if (op_type_ == "mul") {
op_desc.SetAttr(
"in_num_col_dims",
matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
} else {
op_desc.SetAttr("in_num_col_dims", 1);
}

if (with_relu_) {
op_desc.SetAttr("activation_type", std::string{"relu"});
}
Expand Down
4 changes: 3 additions & 1 deletion lite/core/optimizer/mir/fusion/fc_fuser.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ namespace fusion {

// Fuses <op_type> (mul/matmul/matmul_v2) + elementwise_add [+ relu] into a
// single fc op.
class FcFuser : public FuseBase {
 public:
  // op_type: the mul-like producer op to match ("mul", "matmul", "matmul_v2").
  // with_relu: when true, also folds a trailing relu into the fc activation.
  explicit FcFuser(std::string op_type, bool with_relu)
      : op_type_(std::move(op_type)), with_relu_(with_relu) {}
  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
  std::string op_type_;  // which mul-like op this fuser instance matches
  bool with_relu_;       // whether a following relu is part of the pattern
};

Expand Down
19 changes: 19 additions & 0 deletions lite/tests/unittest_py/pass/test_fc_fuse_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def is_program_valid(self,

def sample_program_configs(self, draw):
has_relu = draw(st.sampled_from([True, False]))
op_type = draw(st.sampled_from(["mul", "matmul", "matmul_v2"]))
mul_x_in_shape = draw(
st.lists(
st.integers(
Expand Down Expand Up @@ -105,6 +106,24 @@ def sample_program_configs(self, draw):
"x_num_col_dims": x_num_col_dims_data,
"y_num_col_dims": 1
})
inputs_data = {
"mul_x_data": TensorConfig(shape=mul_x_in_shape),
"mul_y_data": TensorConfig(shape=[x1, y1])
}
if op_type == "matmul" or op_type == "matmul_v2":
mul_op = OpConfig(
type=op_type,
inputs={"X": ["mul_x_data"],
"Y": ["mul_y_data"]},
outputs={"Out": ["mul_output_data"]},
attrs={"trans_x": False,
"trans_y": False})
inputs_data = {
"mul_x_data": TensorConfig(
shape=[draw(st.integers(
min_value=2, max_value=100)), x1]),
"mul_y_data": TensorConfig(shape=[x1, y1])
}

elementwise_add_op = OpConfig(
type="elementwise_add",
Expand Down