@@ -181,7 +181,7 @@ int FCGRUFusePass::BuildFusion(Graph* graph,
181181 Node* bias,
182182 Node* hidden,
183183 Node* fc_bias,
184- const bool use_mkldnn ) {
184+ const bool use_onednn ) {
185185 OpDesc op_desc;
186186 op_desc.SetType (" fusion_gru" );
187187
@@ -200,7 +200,7 @@ int FCGRUFusePass::BuildFusion(Graph* graph,
200200 gru->Op ()->GetAttrIfExists <bool >(" origin_mode" ));
201201 // TODO(TJ): This should be a option for infer
202202 op_desc.SetAttr (" use_seq" , true );
203- op_desc.SetAttr (" use_mkldnn " , use_mkldnn );
203+ op_desc.SetAttr (" use_onednn " , use_onednn );
204204 op_desc.SetAttr (" activation" , gru->Op ()->GetAttr (" activation" ));
205205 op_desc.SetAttr (" gate_activation" , gru->Op ()->GetAttr (" gate_activation" ));
206206
@@ -290,8 +290,9 @@ int FCGRUFusePass::BuildFusion(Graph* graph,
290290 LOG (INFO) << " fc_gru_fuse_pass not supported when origin_mode=True." ;
291291 return ;
292292 }
293- const bool use_mkldnn =
294- (mul->Op ()->GetAttrIfExists <bool >(" use_mkldnn" ) &&
293+ const bool use_onednn =
294+ ((mul->Op ()->GetAttrIfExists <bool >(" use_mkldnn" ) ||
295+ mul->Op ()->GetAttrIfExists <bool >(" use_onednn" )) &&
295296 gru->Op ()->GetAttrIfExists <std::string>(" activation" ) == " tanh" &&
296297 gru->Op ()->GetAttrIfExists <std::string>(" gate_activation" ) ==
297298 " sigmoid" );
@@ -302,7 +303,7 @@ int FCGRUFusePass::BuildFusion(Graph* graph,
302303 GET_IR_NODE_FROM_SUBGRAPH (elementwise_add, elementwise_add, fc_pattern);
303304 GET_IR_NODE_FROM_SUBGRAPH (fc_out, elementwise_add_out, fc_pattern);
304305
305- gru_creator (gru, x_n, w, Weight, Bias, Hidden, fc_bias, use_mkldnn );
306+ gru_creator (gru, x_n, w, Weight, Bias, Hidden, fc_bias, use_onednn );
306307 // Remove unneeded nodes.
307308 std::unordered_set<const Node*> marked_nodes ({mul,
308309 gru,
@@ -314,7 +315,7 @@ int FCGRUFusePass::BuildFusion(Graph* graph,
314315 BatchHidden});
315316 GraphSafeRemoveNodes (graph, marked_nodes);
316317 } else {
317- gru_creator (gru, x_n, w, Weight, Bias, Hidden, nullptr , use_mkldnn );
318+ gru_creator (gru, x_n, w, Weight, Bias, Hidden, nullptr , use_onednn );
318319 // Remove unneeded nodes.
319320 std::unordered_set<const Node*> marked_nodes (
320321 {mul, gru, BatchGate, BatchResetHiddenPrev, BatchHidden});
0 commit comments