@@ -21,94 +21,86 @@ using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;
 
 template <typename T>
-class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
-                                                         mkldnn::lrn_backward> {
+class LRNMKLDNNHandler
+    : public platform::MKLDNNHandlerNoCachingT<T, mkldnn::lrn_forward,
+                                               mkldnn::lrn_backward> {
  public:
   LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
-                   const MKLDNNDeviceContext& dev_ctx,
                    const mkldnn::engine mkldnn_engine,
-                   platform::Place cpu_place, const Tensor* input,
-                   const std::string& unique_name)
-
-      : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
-            dev_ctx, mkldnn_engine, cpu_place,
-            platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
-                                unique_name)) {
-    if (!this->isCached()) {
-      const int n = ctx.Attr<int>("n");
-      // MKL-DNN implements LRN in a caffe way:
-      // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
-      // Where the sum of squares is divided by the size of the normalization
-      // window; this is not the case for PaddlePaddle LRN.
-      // Hence we need to compensate for this difference by
-      // multiplying alpha by the size of the window (n).
-      const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
-      const float beta = ctx.Attr<float>("beta");
-      const float k = ctx.Attr<float>("k");
-      bool is_test = ctx.Attr<bool>("is_test");
-
-      auto dims = framework::vectorize(input->dims());
-
-      auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
-                                         input->format());
-
-      this->AcquireForwardPrimitiveDescriptor(
-          is_test ? mkldnn::prop_kind::forward_inference
-                  : mkldnn::prop_kind::forward_training,
-          mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
-    }
+                   platform::Place cpu_place, const Tensor* input)
+
+      : platform::MKLDNNHandlerNoCachingT<T, mkldnn::lrn_forward,
+                                          mkldnn::lrn_backward>(mkldnn_engine,
+                                                                cpu_place) {
+    const int n = ctx.Attr<int>("n");
+    // MKL-DNN implements LRN in a caffe way:
+    // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
+    // Where the sum of squares is divided by the size of the normalization
+    // window; this is not the case for PaddlePaddle LRN.
+    // Hence we need to compensate for this difference by
+    // multiplying alpha by the size of the window (n).
+    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
+    const float beta = ctx.Attr<float>("beta");
+    const float k = ctx.Attr<float>("k");
+    bool is_test = ctx.Attr<bool>("is_test");
+
+    auto dims = framework::vectorize(input->dims());
+
+    auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
+                                       input->format());
+
+    this->AcquireForwardPrimitiveDescriptor(
+        is_test ? mkldnn::prop_kind::forward_inference
+                : mkldnn::prop_kind::forward_training,
+        mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
   }
 
   LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
-                   const MKLDNNDeviceContext& dev_ctx,
+                   const mkldnn::engine mkldnn_engine,
                    platform::Place cpu_place, const Tensor* in_x,
-                   const Tensor* out_grad, Tensor* in_x_grad,
-                   const std::string& unique_name)
-      : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
-                                unique_name)) {
-    if (!this->isBwdCached()) {
-      PADDLE_ENFORCE_EQ(
-          ctx.Attr<bool>("is_test"), false,
-          platform::errors::PreconditionNotMet(
-              "is_test attribute should be set to False in training phase."));
-
-      const int n = ctx.Attr<int>("n");
-      const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
-      const float beta = ctx.Attr<float>("beta");
-      const float k = ctx.Attr<float>("k");
-
-      auto dims = framework::vectorize<int64_t>(in_x->dims());
-
-      auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
-                                         in_x->format());
-      auto diff_md = mkldnn::memory::desc(
-          dims, platform::MKLDNNGetDataType<T>(), out_grad->format());
-
-      this->AcquireForwardPrimitiveDescriptor(
-          mkldnn::prop_kind::forward_training,
-          mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
-
-      this->AcquireBackwardPrimitiveDescriptor(
-          mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha,
-          beta, k);
-    }
+                   const Tensor* out_grad, Tensor* in_x_grad)
+      : platform::MKLDNNHandlerNoCachingT<T, mkldnn::lrn_forward,
+                                          mkldnn::lrn_backward>(mkldnn_engine,
+                                                                cpu_place) {
+    PADDLE_ENFORCE_EQ(
+        ctx.Attr<bool>("is_test"), false,
+        platform::errors::PreconditionNotMet(
+            "is_test attribute should be set to False in training phase."));
+
+    const int n = ctx.Attr<int>("n");
+    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
+    const float beta = ctx.Attr<float>("beta");
+    const float k = ctx.Attr<float>("k");
+
+    auto dims = framework::vectorize<int64_t>(in_x->dims());
+
+    auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
+                                       in_x->format());
+    auto diff_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
+                                        out_grad->format());
+
+    this->AcquireForwardPrimitiveDescriptor(
+        mkldnn::prop_kind::forward_training,
+        mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
+
+    this->AcquireBackwardPrimitiveDescriptor(
+        mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha, beta,
+        k);
   }
 
   std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(Tensor* workspace) {
     T* ptr = workspace->mutable_data<T>(
         this->place_, this->fwd_pd_->workspace_desc().get_size());
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->workspace_desc(),
-                                            ptr, "@wrk_mem_p");
+                                            ptr);
   }
 
   std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory(
       const Tensor* workspace) {
     const T* workspace_data = workspace->data<T>();
     return this->AcquireMemoryFromPrimitive(
         this->fwd_pd_->workspace_desc(),
-        platform::to_void_cast<T>(workspace_data), "@bwd-wrk_mem_p");
+        platform::to_void_cast<T>(workspace_data));
   }
 };
 
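The constructor comment above compresses a formula detail worth spelling out: oneDNN follows Caffe and divides alpha by the window size n inside the LRN denominator, whereas PaddlePaddle applies alpha to the raw sum of squares, so the handler pre-multiplies alpha by n to cancel the division. The standalone sketch below (not part of the patch; the window size, alpha, and sample values are invented for illustration) checks that the two formulas agree after the compensation:

// Illustrative only: verifies that feeding oneDNN alpha * n reproduces
// PaddlePaddle's LRN denominator. All constants are made-up examples.
#include <cassert>
#include <cmath>

int main() {
  const int n = 5;            // normalization window size ("n" attribute)
  const float alpha = 1e-4f;  // PaddlePaddle's alpha
  const float beta = 0.75f;
  const float k = 1.0f;
  const float sum_sq = 3.2f;  // example sum of squares over the window
  const float src = 0.7f;     // example input value

  // PaddlePaddle LRN: the denominator uses alpha * sum_sq directly.
  const float paddle_out = src / std::pow(k + alpha * sum_sq, beta);

  // Caffe-style (oneDNN) LRN divides alpha by the window size, so passing in
  // alpha * n cancels that division and matches PaddlePaddle's result.
  const float compensated_alpha = alpha * static_cast<float>(n);
  const float onednn_out =
      src / std::pow(k + compensated_alpha / n * sum_sq, beta);

  assert(std::fabs(paddle_out - onednn_out) < 1e-6f);
  return 0;
}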
@@ -131,8 +123,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto out = ctx.Output<Tensor>("Out");
     auto mid = ctx.Output<Tensor>("MidOut");
 
-    LRNMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), x,
-                                ctx.OutputName("Out"));
+    LRNMKLDNNHandler<T> handler(ctx, mkldnn_engine, ctx.GetPlace(), x);
 
     auto src_memory = handler.AcquireSrcMemory(x);
     auto dst_memory = handler.AcquireDstMemory(out);
@@ -178,9 +169,10 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
 
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
 
-    LRNMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), in_x, out_grad,
-                                in_x_grad, ctx.InputName("Out"));
+    LRNMKLDNNHandler<T> handler(ctx, mkldnn_engine, ctx.GetPlace(), in_x,
+                                out_grad, in_x_grad);
 
     auto src_memory = handler.AcquireSrcMemory(in_x);
     auto workspace = handler.AcquireBackwardWorkspaceMemory(mid);
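For context on the overall change: the patch switches LRNMKLDNNHandler from the caching MKLDNNHandlerT base to MKLDNNHandlerNoCachingT, which is why the dev_ctx and unique_name parameters, the platform::CreateKey(...) call, and the isCached()/isBwdCached() guards all disappear. The contrast below is a minimal sketch under assumed, simplified types (Engine, Place, and PrimDesc are hypothetical stand-ins, not Paddle's or oneDNN's classes); it only illustrates why dropping the cross-call cache also drops the key-building constructor arguments:

#include <map>
#include <memory>
#include <string>

struct Engine {};    // stand-in for mkldnn::engine
struct Place {};     // stand-in for platform::Place
struct PrimDesc {};  // stand-in for a forward primitive descriptor

// Old-style handler: descriptors live in an external cache keyed by a string,
// so the constructor must receive everything needed to build that key.
class CachingHandler {
 public:
  CachingHandler(std::map<std::string, std::shared_ptr<PrimDesc>>* cache,
                 Engine engine, Place place, const std::string& key)
      : cache_(cache), engine_(engine), place_(place), key_(key) {}

  std::shared_ptr<PrimDesc> Acquire() {
    auto it = cache_->find(key_);  // the isCached() analogue
    if (it != cache_->end()) return it->second;
    auto pd = std::make_shared<PrimDesc>();  // build the descriptor once
    (*cache_)[key_] = pd;
    return pd;
  }

 private:
  std::map<std::string, std::shared_ptr<PrimDesc>>* cache_;
  Engine engine_;
  Place place_;
  std::string key_;
};

// No-caching handler: no cache means no key, so only the engine and the
// placement survive, mirroring the two-argument base constructor in the patch.
class NoCachingHandler {
 public:
  NoCachingHandler(Engine engine, Place place)
      : engine_(engine), place_(place), pd_(std::make_shared<PrimDesc>()) {}

  std::shared_ptr<PrimDesc> Acquire() { return pd_; }

 private:
  Engine engine_;
  Place place_;
  std::shared_ptr<PrimDesc> pd_;
};

int main() {
  std::map<std::string, std::shared_ptr<PrimDesc>> cache;
  CachingHandler cached(&cache, Engine{}, Place{}, "lrn_dims_outname");
  NoCachingHandler plain(Engine{}, Place{});
  return (cached.Acquire() && plain.Acquire()) ? 0 : 1;
}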