diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc index d76c093c79c258..5b208b62b491a8 100644 --- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc @@ -181,7 +181,7 @@ int FCGRUFusePass::BuildFusion(Graph* graph, Node* bias, Node* hidden, Node* fc_bias, - const bool use_mkldnn) { + const bool use_onednn) { OpDesc op_desc; op_desc.SetType("fusion_gru"); @@ -200,7 +200,7 @@ int FCGRUFusePass::BuildFusion(Graph* graph, gru->Op()->GetAttrIfExists("origin_mode")); // TODO(TJ): This should be a option for infer op_desc.SetAttr("use_seq", true); - op_desc.SetAttr("use_mkldnn", use_mkldnn); + op_desc.SetAttr("use_onednn", use_onednn); op_desc.SetAttr("activation", gru->Op()->GetAttr("activation")); op_desc.SetAttr("gate_activation", gru->Op()->GetAttr("gate_activation")); @@ -290,8 +290,9 @@ int FCGRUFusePass::BuildFusion(Graph* graph, LOG(INFO) << "fc_gru_fuse_pass not supported when origin_mode=True."; return; } - const bool use_mkldnn = - (mul->Op()->GetAttrIfExists("use_mkldnn") && + const bool use_onednn = + ((mul->Op()->GetAttrIfExists("use_mkldnn") || + mul->Op()->GetAttrIfExists("use_onednn")) && gru->Op()->GetAttrIfExists("activation") == "tanh" && gru->Op()->GetAttrIfExists("gate_activation") == "sigmoid"); @@ -302,7 +303,7 @@ int FCGRUFusePass::BuildFusion(Graph* graph, GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern); - gru_creator(gru, x_n, w, Weight, Bias, Hidden, fc_bias, use_mkldnn); + gru_creator(gru, x_n, w, Weight, Bias, Hidden, fc_bias, use_onednn); // Remove unneeded nodes. std::unordered_set marked_nodes({mul, gru, @@ -314,7 +315,7 @@ int FCGRUFusePass::BuildFusion(Graph* graph, BatchHidden}); GraphSafeRemoveNodes(graph, marked_nodes); } else { - gru_creator(gru, x_n, w, Weight, Bias, Hidden, nullptr, use_mkldnn); + gru_creator(gru, x_n, w, Weight, Bias, Hidden, nullptr, use_onednn); // Remove unneeded nodes. std::unordered_set marked_nodes( {mul, gru, BatchGate, BatchResetHiddenPrev, BatchHidden}); diff --git a/test/ir/inference/test_onednn_operator_reshape2_fuse_pass.py b/test/ir/inference/test_onednn_operator_reshape2_fuse_pass.py index abd8f90f099632..251ac7a506fe15 100644 --- a/test/ir/inference/test_onednn_operator_reshape2_fuse_pass.py +++ b/test/ir/inference/test_onednn_operator_reshape2_fuse_pass.py @@ -45,7 +45,7 @@ def generate_input(shape): }, attrs={ "axis": axis, - "use_mkldnn": True, + "use_onednn": True, }, ) diff --git a/test/ir/inference/test_onednn_operator_unsqueeze2_fuse_pass.py b/test/ir/inference/test_onednn_operator_unsqueeze2_fuse_pass.py index f35c355eb0314f..eadd8379d783cd 100644 --- a/test/ir/inference/test_onednn_operator_unsqueeze2_fuse_pass.py +++ b/test/ir/inference/test_onednn_operator_unsqueeze2_fuse_pass.py @@ -43,7 +43,7 @@ def generate_input(shape): }, attrs={ "axis": transpose_axis, - "use_mkldnn": True, + "use_onednn": True, }, ) @@ -102,7 +102,7 @@ def generate_input(shape): type='elementwise_mul', inputs={'X': ['eltwise_X'], 'Y': ['eltwise_Y']}, outputs={'Out': ['eltwise_output']}, - attrs={"use_mkldnn": True}, + attrs={"use_onednn": True}, ) unsqueeze2_op = OpConfig( diff --git a/test/ir/inference/test_onednn_squeeze2_transpose2_fuse_pass.py b/test/ir/inference/test_onednn_squeeze2_transpose2_fuse_pass.py index 3b6f86d7d027dc..23fe42c69a0a60 100644 --- a/test/ir/inference/test_onednn_squeeze2_transpose2_fuse_pass.py +++ b/test/ir/inference/test_onednn_squeeze2_transpose2_fuse_pass.py @@ -42,7 +42,7 @@ def generate_input(shape): }, attrs={ "axes": [2], - "use_mkldnn": True, + "use_onednn": True, }, ) @@ -57,7 +57,7 @@ def generate_input(shape): }, attrs={ "axis": transpose_axis, - "use_mkldnn": True, + "use_onednn": True, }, )