@@ -35,167 +35,149 @@ using paddle::platform::MKLDNNDeviceContext;
 using platform::to_void_cast;
 
 template <typename T>
-class BatchNormMKLDNNHandler
-    : public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
-                                      mkldnn::batch_normalization_backward> {
+class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
+                                   T, mkldnn::batch_normalization_forward,
+                                   mkldnn::batch_normalization_backward> {
  public:
   BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
-                         const platform::MKLDNNDeviceContext &dev_ctx,
-                         const mkldnn::engine mkldnn_engine,
-                         platform::Place cpu_place, const Tensor *x,
-                         const bool global_stats, const bool test_mode,
-                         const std::string &unique_name)
-      : platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
-                                 mkldnn::batch_normalization_backward>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
-                                unique_name)) {
-    if (!this->isCached()) {
-      const float epsilon = ctx.Attr<float>("epsilon");
-      const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
-
-      std::vector<std::string> DataLayout_error_msg = {"kNHWC", "kNCHW",
-                                                       "kAnyLayout", "kMKLDNN"};
-      PADDLE_ENFORCE_EQ(
-          x->layout(), DataLayout::kMKLDNN,
-          platform::errors::InvalidArgument(
-              "Wrong layout set for X tensor. Expected layout is `kMKLDNN`, "
-              "But received %s.",
-              DataLayout_error_msg[static_cast<int>(DataLayout::kMKLDNN)]));
-      PADDLE_ENFORCE_NE(
-          x->format(), MKLDNNMemoryFormat::undef,
-          platform::errors::InvalidArgument("Wrong format set for X tensor"));
-
-      auto src_tz = paddle::framework::vectorize(x->dims());
-
-      // Flags are added by bitwise OR operation
-      auto flags = mkldnn::normalization_flags::use_scale_shift;  // 001
-      if (global_stats)
-        flags |= mkldnn::normalization_flags::use_global_stats;  // 010
-      if (fuse_with_relu && test_mode)
-        flags |= mkldnn::normalization_flags::fuse_norm_relu;  // 100
-
-      auto md = mkldnn::memory::desc(
-          src_tz, platform::MKLDNNGetDataType<T>(),
-          platform::MKLDNNFormatForSize(src_tz.size(), x->format()));
-
-      this->AcquireForwardPrimitiveDescriptor(
-          global_stats == true ? mkldnn::prop_kind::forward_scoring
-                               : mkldnn::prop_kind::forward_training,
-          md, epsilon, flags);
-    }
+                         const mkldnn::engine mkldnn_engine, const Tensor *x,
+                         const bool global_stats, const bool test_mode)
+      : platform::MKLDNNHandlerNoCachingT<T,
+                                          mkldnn::batch_normalization_forward,
+                                          mkldnn::batch_normalization_backward>(
+            mkldnn_engine, ctx.GetPlace()) {
+    const float epsilon = ctx.Attr<float>("epsilon");
+    const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
+
+    std::vector<std::string> DataLayout_error_msg = {"kNHWC", "kNCHW",
+                                                     "kAnyLayout", "kMKLDNN"};
+    PADDLE_ENFORCE_EQ(
+        x->layout(), DataLayout::kMKLDNN,
+        platform::errors::InvalidArgument(
+            "Wrong layout set for X tensor. Expected layout is `kMKLDNN`, "
+            "But received %s.",
+            DataLayout_error_msg[static_cast<int>(DataLayout::kMKLDNN)]));
+    PADDLE_ENFORCE_NE(
+        x->format(), MKLDNNMemoryFormat::undef,
+        platform::errors::InvalidArgument("Wrong format set for X tensor"));
+
+    auto src_tz = paddle::framework::vectorize(x->dims());
+
+    // Flags are added by bitwise OR operation
+    auto flags = mkldnn::normalization_flags::use_scale_shift;  // 001
+    if (global_stats)
+      flags |= mkldnn::normalization_flags::use_global_stats;  // 010
+    if (fuse_with_relu && test_mode)
+      flags |= mkldnn::normalization_flags::fuse_norm_relu;  // 100
+
+    auto md = mkldnn::memory::desc(
+        src_tz, platform::MKLDNNGetDataType<T>(),
+        platform::MKLDNNFormatForSize(src_tz.size(), x->format()));
+
+    this->AcquireForwardPrimitiveDescriptor(
+        global_stats == true ? mkldnn::prop_kind::forward_scoring
+                             : mkldnn::prop_kind::forward_training,
+        md, epsilon, flags);
   }
 
   BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
-                         const platform::MKLDNNDeviceContext &dev_ctx,
-                         platform::Place cpu_place, const Tensor *in_x,
-                         const Tensor *scale, const Tensor *out_grad,
-                         const std::string &unique_name)
-      : platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
-                                 mkldnn::batch_normalization_backward>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
-                                unique_name)) {
-    if (!this->isBwdCached()) {
-      PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
-                        platform::errors::InvalidArgument(
-                            "Wrong layout set for Input out_grad tensor"));
-      PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
-                        platform::errors::InvalidArgument(
-                            "Wrong format set for Input out_grad tensor"));
-
-      auto src_tz = paddle::framework::vectorize<int64_t>(in_x->dims());
-      auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
-      PADDLE_ENFORCE_EQ(
-          scale_tz.size(), 1,
-          platform::errors::InvalidArgument(
-              "Dims of scale tensor must be 1, but received scale's size is %d",
-              scale_tz.size()));
-
-      MKLDNNMemoryFormat diff_fmt =
-          platform::MKLDNNFormatForSize(src_tz.size(), out_grad->format());
-
-      MKLDNNMemoryFormat src_fmt =
-          platform::MKLDNNFormatForSize(src_tz.size(), in_x->format());
-
-      auto dims = framework::vectorize(in_x->dims());
-      auto diff_dst_md = mkldnn::memory::desc(
-          dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
-      auto src_md =
-          mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
-
-      const float epsilon = ctx.Attr<float>("epsilon");
-
-      this->AcquireForwardPrimitiveDescriptor(
-          mkldnn::prop_kind::forward_training, src_md, epsilon,
-          mkldnn::normalization_flags::use_scale_shift);
-      this->AcquireBackwardPrimitiveDescriptor(
-          mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon,
-          mkldnn::normalization_flags::use_scale_shift);
-    }
+                         const mkldnn::engine mkldnn_engine, const Tensor *in_x,
+                         const Tensor *scale, const Tensor *out_grad)
+      : platform::MKLDNNHandlerNoCachingT<T,
+                                          mkldnn::batch_normalization_forward,
+                                          mkldnn::batch_normalization_backward>(
+            mkldnn_engine, ctx.GetPlace()) {
+    PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
+                      platform::errors::InvalidArgument(
+                          "Wrong layout set for Input out_grad tensor"));
+    PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
+                      platform::errors::InvalidArgument(
+                          "Wrong format set for Input out_grad tensor"));
+
+    auto src_tz = paddle::framework::vectorize<int64_t>(in_x->dims());
+    auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
+    PADDLE_ENFORCE_EQ(
+        scale_tz.size(), 1,
+        platform::errors::InvalidArgument(
+            "Dims of scale tensor must be 1, but received scale's size is %d",
+            scale_tz.size()));
+
+    MKLDNNMemoryFormat diff_fmt =
+        platform::MKLDNNFormatForSize(src_tz.size(), out_grad->format());
+
+    MKLDNNMemoryFormat src_fmt =
+        platform::MKLDNNFormatForSize(src_tz.size(), in_x->format());
+
+    auto dims = framework::vectorize(in_x->dims());
+    auto diff_dst_md =
+        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
+    auto src_md =
+        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
+
+    const float epsilon = ctx.Attr<float>("epsilon");
+
+    this->AcquireForwardPrimitiveDescriptor(
+        mkldnn::prop_kind::forward_training, src_md, epsilon,
+        mkldnn::normalization_flags::use_scale_shift);
+    this->AcquireBackwardPrimitiveDescriptor(
+        mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon,
+        mkldnn::normalization_flags::use_scale_shift);
   }
 
   std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale,
-                                                          const Tensor *shift,
-                                                          const bool is_test) {
-    auto scaleshift_memory = this->AcquireMemory("@scaleshift_mem_p");
-    if (scaleshift_memory == nullptr || !is_test) {
-      auto scale_tz = paddle::framework::vectorize(scale->dims());
-      const unsigned int C = scale_tz[0];
-      PADDLE_ENFORCE_EQ(
-          scale_tz.size(), 1,
-          platform::errors::InvalidArgument(
-              "Dims of scale tensor must be 1, but received scale's size is %d",
-              scale_tz.size()));
-
-      auto mem_p = this->AcquireMemoryFromPrimitive(
-          this->fwd_pd_->weights_desc(), "@scaleshift_mem_p");
-
-      // MKLDNN requires a single piece of memory for scale and shift/bias data
-      auto mem_data_handle = reinterpret_cast<T *>(mem_p->get_data_handle());
-      std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
-      std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
-
-      return mem_p;
-    }
+                                                          const Tensor *shift) {
+    auto scale_tz = paddle::framework::vectorize(scale->dims());
+    const unsigned int C = scale_tz[0];
+    PADDLE_ENFORCE_EQ(
+        scale_tz.size(), 1,
+        platform::errors::InvalidArgument(
+            "Dims of scale tensor must be 1, but received scale's size is %d",
+            scale_tz.size()));
+
+    auto scaleshift_memory =
+        this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
+
+    // MKLDNN requires a single piece of memory for scale and shift/bias data
+    auto mem_data_handle =
+        reinterpret_cast<T *>(scaleshift_memory->get_data_handle());
+    std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
+    std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
     return scaleshift_memory;
   }
 
   std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory(
       T *diff_scaleshift_data) {
     return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
-                                            diff_scaleshift_data,
-                                            "@diff_scaleshift_mem_p");
+                                            diff_scaleshift_data);
   }
 
   std::shared_ptr<mkldnn::memory> AcquireMeanMemory(
       const framework::Tensor *mean) {
     const T *mean_data = mean->data<T>();
-    return this->AcquireMemoryFromPrimitive(
-        this->fwd_pd_->mean_desc(), to_void_cast<T>(mean_data), "@mean_mem_p");
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
+                                            to_void_cast<T>(mean_data));
   }
 
   std::shared_ptr<mkldnn::memory> AcquireMeanMemory(framework::Tensor *mean) {
     T *mean_data = mean->mutable_data<T>(this->place_,
                                          this->fwd_pd_->mean_desc().get_size());
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
-                                            mean_data, "@mean_mem_p");
+                                            mean_data);
   }
 
   std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
       const framework::Tensor *variance) {
     const T *variance_data = variance->data<T>();
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
-                                            to_void_cast<T>(variance_data),
-                                            "@variance_mem_p");
+                                            to_void_cast<T>(variance_data));
   }
 
   std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
       framework::Tensor *variance) {
     T *variance_data = variance->mutable_data<T>(
         this->place_, this->fwd_pd_->variance_desc().get_size());
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
-                                            variance_data, "@variance_mem_p");
+                                            variance_data);
   }
 };
 
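Note on the flag composition in the forward constructor above: the `// 001`, `// 010`, `// 100` comments mark the bit each `mkldnn::normalization_flags` value contributes, so the optional behaviors combine losslessly with bitwise OR. Below is a minimal, self-contained sketch of the same accumulation pattern; the `norm_flags` enum is a hypothetical stand-in with the same bit values (so it compiles without oneDNN), not the real oneDNN type.

#include <cstdint>
#include <iostream>

// Stand-in for mkldnn::normalization_flags, mirroring the 001/010/100 bits.
enum class norm_flags : uint32_t {
  use_scale_shift = 0x1U,   // 001
  use_global_stats = 0x2U,  // 010
  fuse_norm_relu = 0x4U,    // 100
};

inline norm_flags operator|(norm_flags a, norm_flags b) {
  return static_cast<norm_flags>(static_cast<uint32_t>(a) |
                                 static_cast<uint32_t>(b));
}
inline norm_flags &operator|=(norm_flags &a, norm_flags b) {
  return a = a | b;
}

int main() {
  const bool global_stats = true;
  const bool fuse_with_relu = false;
  const bool test_mode = true;

  // Same pattern as the handler constructor: start from use_scale_shift
  // and OR in the optional bits.
  auto flags = norm_flags::use_scale_shift;
  if (global_stats) flags |= norm_flags::use_global_stats;
  if (fuse_with_relu && test_mode) flags |= norm_flags::fuse_norm_relu;

  std::cout << "flags bits: " << static_cast<uint32_t>(flags) << "\n";
  return 0;
}

With these example inputs the combined value prints 3, i.e. binary 011 (use_scale_shift | use_global_stats), which is what the real constructor would pass to AcquireForwardPrimitiveDescriptor in the global-stats case.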
@@ -220,13 +202,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto *batch_mean = ctx.Output<Tensor>("SavedMean");
     auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
 
-    BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine,
-                                      ctx.GetPlace(), x, global_stats,
-                                      test_mode, ctx.OutputName("SavedMean"));
+    BatchNormMKLDNNHandler<T> handler(ctx, mkldnn_engine, x, global_stats,
+                                      test_mode);
 
     auto src_memory = handler.AcquireSrcMemory(x);
-    auto scaleshift_memory =
-        handler.AcquireScaleShiftMemory(scale, shift, is_test);
+    auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
     auto dst_memory = handler.AcquireDstMemory(y);
 
     auto batch_norm_p = handler.AcquireForwardPrimitive();
@@ -303,8 +283,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
     auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
 
-    BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), x, scale,
-                                      diff_y, ctx.InputName("SavedMean"));
+    BatchNormMKLDNNHandler<T> handler(ctx, mkldnn_engine, x, scale, diff_y);
 
     // MKLDNN requires a single piece of memory for scale and shift/bias data
     const unsigned int C = paddle::framework::vectorize(scale->dims())[0];
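The "single piece of memory" comments above and in AcquireScaleShiftMemory refer to oneDNN's packed weights layout for the use_scale_shift flag: scale occupies elements [0, C) and shift elements [C, 2*C) of one contiguous buffer, which is what the two `std::copy` calls in the handler produce. A minimal, self-contained sketch of that packing follows; the plain `std::vector`s are hypothetical stand-ins for the Paddle tensors and the oneDNN memory handle.

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  const size_t C = 4;  // number of channels
  std::vector<float> scale(C, 1.0f);  // stand-in for the Scale tensor
  std::vector<float> shift(C, 0.5f);  // stand-in for the Bias tensor

  // One contiguous 2*C buffer: scale first, then shift, as oneDNN expects.
  std::vector<float> scaleshift(2 * C);
  std::copy(scale.begin(), scale.end(), scaleshift.begin());
  std::copy(shift.begin(), shift.end(), scaleshift.begin() + C);

  for (float v : scaleshift) std::cout << v << ' ';
  std::cout << '\n';  // prints: 1 1 1 1 0.5 0.5 0.5 0.5
  return 0;
}

The backward path mirrors this: AcquireDiffScaleShiftMemory hands oneDNN a 2*C buffer from which the first C elements become the scale gradient and the next C the shift gradient.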
@@ -316,8 +295,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto mean_memory = handler.AcquireMeanMemory(batch_mean);
     auto variance_memory = handler.AcquireVarianceMemory(batch_variance);
     auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y);
-    auto scaleshift_memory =
-        handler.AcquireScaleShiftMemory(scale, shift, false);
+    auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
     auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x);
     auto diff_scaleshift_memory =
         handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());