Skip to content

Commit 5cd2282

Browse files
authored
fc_gru_fuse_pass modify use_mkldnn [fluid_ops] (#74680)
* Fix * Fix
1 parent 8e9a9c7 commit 5cd2282

File tree

4 files changed

+12
-11
lines changed

4 files changed

+12
-11
lines changed

paddle/fluid/framework/ir/fc_gru_fuse_pass.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ int FCGRUFusePass::BuildFusion(Graph* graph,
181181
Node* bias,
182182
Node* hidden,
183183
Node* fc_bias,
184-
const bool use_mkldnn) {
184+
const bool use_onednn) {
185185
OpDesc op_desc;
186186
op_desc.SetType("fusion_gru");
187187

@@ -200,7 +200,7 @@ int FCGRUFusePass::BuildFusion(Graph* graph,
200200
gru->Op()->GetAttrIfExists<bool>("origin_mode"));
201201
// TODO(TJ): This should be a option for infer
202202
op_desc.SetAttr("use_seq", true);
203-
op_desc.SetAttr("use_mkldnn", use_mkldnn);
203+
op_desc.SetAttr("use_onednn", use_onednn);
204204
op_desc.SetAttr("activation", gru->Op()->GetAttr("activation"));
205205
op_desc.SetAttr("gate_activation", gru->Op()->GetAttr("gate_activation"));
206206

@@ -290,8 +290,9 @@ int FCGRUFusePass::BuildFusion(Graph* graph,
290290
LOG(INFO) << "fc_gru_fuse_pass not supported when origin_mode=True.";
291291
return;
292292
}
293-
const bool use_mkldnn =
294-
(mul->Op()->GetAttrIfExists<bool>("use_mkldnn") &&
293+
const bool use_onednn =
294+
((mul->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
295+
mul->Op()->GetAttrIfExists<bool>("use_onednn")) &&
295296
gru->Op()->GetAttrIfExists<std::string>("activation") == "tanh" &&
296297
gru->Op()->GetAttrIfExists<std::string>("gate_activation") ==
297298
"sigmoid");
@@ -302,7 +303,7 @@ int FCGRUFusePass::BuildFusion(Graph* graph,
302303
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
303304
GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
304305

305-
gru_creator(gru, x_n, w, Weight, Bias, Hidden, fc_bias, use_mkldnn);
306+
gru_creator(gru, x_n, w, Weight, Bias, Hidden, fc_bias, use_onednn);
306307
// Remove unneeded nodes.
307308
std::unordered_set<const Node*> marked_nodes({mul,
308309
gru,
@@ -314,7 +315,7 @@ int FCGRUFusePass::BuildFusion(Graph* graph,
314315
BatchHidden});
315316
GraphSafeRemoveNodes(graph, marked_nodes);
316317
} else {
317-
gru_creator(gru, x_n, w, Weight, Bias, Hidden, nullptr, use_mkldnn);
318+
gru_creator(gru, x_n, w, Weight, Bias, Hidden, nullptr, use_onednn);
318319
// Remove unneeded nodes.
319320
std::unordered_set<const Node*> marked_nodes(
320321
{mul, gru, BatchGate, BatchResetHiddenPrev, BatchHidden});

test/ir/inference/test_onednn_operator_reshape2_fuse_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def generate_input(shape):
4545
},
4646
attrs={
4747
"axis": axis,
48-
"use_mkldnn": True,
48+
"use_onednn": True,
4949
},
5050
)
5151

test/ir/inference/test_onednn_operator_unsqueeze2_fuse_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def generate_input(shape):
4343
},
4444
attrs={
4545
"axis": transpose_axis,
46-
"use_mkldnn": True,
46+
"use_onednn": True,
4747
},
4848
)
4949

@@ -102,7 +102,7 @@ def generate_input(shape):
102102
type='elementwise_mul',
103103
inputs={'X': ['eltwise_X'], 'Y': ['eltwise_Y']},
104104
outputs={'Out': ['eltwise_output']},
105-
attrs={"use_mkldnn": True},
105+
attrs={"use_onednn": True},
106106
)
107107

108108
unsqueeze2_op = OpConfig(

test/ir/inference/test_onednn_squeeze2_transpose2_fuse_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def generate_input(shape):
4242
},
4343
attrs={
4444
"axes": [2],
45-
"use_mkldnn": True,
45+
"use_onednn": True,
4646
},
4747
)
4848

@@ -57,7 +57,7 @@ def generate_input(shape):
5757
},
5858
attrs={
5959
"axis": transpose_axis,
60-
"use_mkldnn": True,
60+
"use_onednn": True,
6161
},
6262
)
6363

0 commit comments

Comments
 (0)