@@ -206,17 +206,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
206206 if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
207207
208208 // create mkldnn memory from input x tensor
209- mkldnn::memory::format input_format =
210- platform::MKLDNNFormatForSize (src_tz.size (), x->format ());
211209
212210 // keys for backward pass
213211 const std::string key = BatchNormMKLDNNHandler::GetHash (
214- src_tz, epsilon, flags, global_stats, input_format ,
212+ src_tz, epsilon, flags, global_stats, x-> format () ,
215213 ctx.op ().Output (" SavedMean" ));
216214 const std::string key_batch_norm_fwd_pd = key + " @bn_fwd_pd" ;
217215
218- auto user_src_md = platform::MKLDNNMemDesc (
219- {src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
216+ auto user_src_md = x->get_mkldnn_prim_desc ().desc ();
220217
221218 // create primitive descriptor for batch norm forward
222219 using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
@@ -230,8 +227,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
230227 BatchNormMKLDNNHandler handler (batch_norm_fwd_pd, dev_ctx, mkldnn_engine,
231228 key);
232229
233- auto src_memory =
234- handler. AcquireSrcMemory (user_src_md, to_void_cast (x_data));
230+ auto src_memory = handler. AcquireSrcMemory (x-> get_mkldnn_prim_desc (),
231+ to_void_cast (x_data));
235232
236233 // create mkldnn memory for weights(scale/shift)
237234 auto scaleshift_memory =
@@ -265,8 +262,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
265262 variance_memory, false );
266263 }
267264
268- y->set_layout (DataLayout::kMKLDNN );
269- y->set_format (platform::GetMKLDNNFormat (*dst_memory));
265+ y->set_mkldnn_prim_desc (dst_memory->get_primitive_desc ());
270266
271267 std::vector<mkldnn::primitive> pipeline;
272268 pipeline.push_back (*batch_norm_p);
@@ -336,24 +332,21 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
336332
337333 using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
338334
339- mkldnn::memory::format dst_format =
340- platform::MKLDNNFormatForSize (src_tz.size (), diff_y->format ());
341-
342335 mkldnn::memory::format input_format =
343336 platform::MKLDNNFormatForSize (src_tz.size (), x->format ());
344337
345338 unsigned flags = mkldnn::use_scale_shift;
346339
347340 // keys from forward pass
348341 const std::string key = BatchNormMKLDNNHandler::GetHash (
349- src_tz, epsilon, flags, false , input_format ,
342+ src_tz, epsilon, flags, false , x-> format () ,
350343 ctx.op ().Input (" SavedMean" ));
351344 const std::string key_batch_norm_fwd_pd = key + " @bn_fwd_pd" ;
352345
353346 // keys for primitives reuse
354347 const std::string key_with_hash =
355348 key + BatchNormMKLDNNHandler::GetHash (src_tz, epsilon, flags, false ,
356- input_format );
349+ x-> format () );
357350 const std::string key_batch_norm_bwd_p =
358351 key_with_hash + " @batch_norm_bwd_p" ;
359352 const std::string key_batch_norm_src_mem_p =
@@ -373,9 +366,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
373366
374367 primitive reorder_diff_dst;
375368 bool is_diff_dst_reordered = false ;
376- auto user_diff_dst_memory = memory (
377- {{{diff_dst_tz}, memory::data_type::f32 , dst_format}, mkldnn_engine},
378- to_void_cast (diff_y_data));
369+ auto user_diff_dst_memory =
370+ memory (diff_y->get_mkldnn_prim_desc (), to_void_cast (diff_y_data));
379371
380372 // MKLDNN requires a single piece of memory for scale and shift/bias data
381373 const size_t scaleshift_size = 2 * ic;
@@ -459,10 +451,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
459451 dev_ctx.SetBlob (key_batch_norm_diff_dst_mem_p, diff_dst_memory);
460452
461453 // set layout/format of output tensors
462- diff_x->set_layout (DataLayout::kMKLDNN );
463- diff_x->set_format ((memory::format)diff_src_memory->get_primitive_desc ()
464- .desc ()
465- .data .format );
454+ diff_x->set_mkldnn_prim_desc (diff_src_memory->get_primitive_desc ());
466455 } else {
467456 // primitives already exist
468457 UpdateMemoryData (dev_ctx, key_batch_norm_src_mem_p, to_void_cast (x_data));
@@ -487,10 +476,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
487476 }
488477
489478 // set layout/format of output tensors
490- diff_x->set_layout (DataLayout::kMKLDNN );
491- diff_x->set_format ((memory::format)diff_src_memory->get_primitive_desc ()
492- .desc ()
493- .data .format );
479+ diff_x->set_mkldnn_prim_desc (diff_src_memory->get_primitive_desc ());
494480 }
495481
496482 // execute optional reorder and batch_norm backward primitive
0 commit comments