@@ -32,56 +32,69 @@ using platform::to_void_cast;
 
 template <typename T>
 class SoftmaxMKLDNNHandler
-    : public platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward,
-                                               mkldnn::softmax_backward> {
+    : public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
+                                      mkldnn::softmax_backward> {
  public:
-  SoftmaxMKLDNNHandler(const mkldnn::engine mkldnn_engine,
+  SoftmaxMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
+                       const mkldnn::engine mkldnn_engine,
                        platform::Place cpu_place, const Tensor* input,
-                       Tensor* output, const int axis)
-      : platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward,
-                                          mkldnn::softmax_backward>(
-            mkldnn_engine, cpu_place) {
-    PADDLE_ENFORCE_EQ(
-        input->dims(), output->dims(),
-        platform::errors::InvalidArgument(
-            "The shape of input and output tensor must be identical."));
-
-    auto softmax_tz = framework::vectorize(input->dims());
-    auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
-                           input->format());
-
-    this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
-                                            axis);
+                       Tensor* output, const int axis,
+                       const std::string uniq_name, bool is_inplaced)
+      : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
+                                 mkldnn::softmax_backward>(
+            dev_ctx, mkldnn_engine, cpu_place,
+            // Softmax may be inplace, in which case uniq_name is no longer
+            // unique, so the axis is mixed into the key as well
+            is_inplaced ? platform::CreateKey(
+                              dev_ctx, framework::vectorize(input->dims()),
+                              axis, uniq_name)
+                        : platform::CreateKey(
+                              dev_ctx, framework::vectorize(input->dims()),
+                              uniq_name)) {
+    if (!this->isCached()) {
+      PADDLE_ENFORCE_EQ(
+          input->dims(), output->dims(),
+          platform::errors::InvalidArgument(
+              "The shape of input and output tensor must be identical."));
+
+      auto softmax_tz = framework::vectorize(input->dims());
+      auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
+                             input->format());
+
+      this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
+                                              axis);
+    }
   }
 
   SoftmaxMKLDNNHandler(const framework::ExecutionContext& ctx,
-                       const mkldnn::engine mkldnn_engine,
+                       const MKLDNNDeviceContext& dev_ctx,
                        platform::Place cpu_place, const Tensor* out,
                        const Tensor* out_grad, Tensor* in_x_grad,
                        const std::string& unique_name)
-      : platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward,
-                                          mkldnn::softmax_backward>(
-            mkldnn_engine, cpu_place) {
-    PADDLE_ENFORCE_EQ(out_grad->dims(), in_x_grad->dims(),
-                      platform::errors::InvalidArgument(
-                          "The shape of softmax_grad's input "
-                          "and output must be identical, but shapes differ, "
-                          "out_grad: %s in_grad: %s",
-                          out_grad->dims(), in_x_grad->dims()));
-
-    auto dims = out_grad->dims();  // input and output share the same shape
-    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
-    auto softmax_tz = framework::vectorize<int64_t>(dims);
-
-    auto data_softmax_md = MKLDNNMemDesc(
-        softmax_tz, platform::MKLDNNGetDataType<T>(), out->format());
-    auto diff_softmax_md = MKLDNNMemDesc(
-        softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
-
-    this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
-                                            data_softmax_md, axis);
-    this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
-                                             axis);
+      : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
+                                 mkldnn::softmax_backward>(
+            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            platform::CreateKey(dev_ctx, framework::vectorize(out->dims()),
+                                unique_name)) {
+    if (!this->isBwdCached()) {
+      PADDLE_ENFORCE_EQ(
+          out_grad->dims(), in_x_grad->dims(),
+          platform::errors::InvalidArgument("The shape of softmax_grad's input "
+                                            "and output must be identical."));
+
+      auto dims = out_grad->dims();  // input and output share the same shape
+      const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
+      auto softmax_tz = framework::vectorize<int64_t>(dims);
+
+      auto data_softmax_md = MKLDNNMemDesc(
+          softmax_tz, platform::MKLDNNGetDataType<T>(), out->format());
+      auto diff_softmax_md = MKLDNNMemDesc(
+          softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
+
+      this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
+                                              data_softmax_md, axis);
+      this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
+                                               axis);
+    }
   }
 };
 
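The subtle part of the restored forward constructor is the cache key. The comment in the diff notes that an in-place softmax can make `uniq_name` non-unique, so the inplaced branch folds `axis` into the key to keep two softmax primitives with different axes apart. Below is a toy reconstruction of that two-branch scheme; `CreateKeySketch` is an illustrative assumption, not Paddle's actual `platform::CreateKey`, which also mixes in device and thread state:

```cpp
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for platform::CreateKey, for illustration only.
std::string CreateKeySketch(const std::vector<int64_t>& dims, int axis,
                            const std::string& uniq_name, bool is_inplaced) {
  std::ostringstream key;
  for (int64_t d : dims) key << d << '-';
  if (is_inplaced) key << axis << '-';  // disambiguate in-place primitives
  key << uniq_name;
  return key.str();
}

int main() {
  const std::vector<int64_t> dims{8, 16};
  // Same variable name, different axes: the keys must not collide.
  std::cout << CreateKeySketch(dims, 0, "softmax_out", true) << '\n';  // 8-16-0-softmax_out
  std::cout << CreateKeySketch(dims, 1, "softmax_out", true) << '\n';  // 8-16-1-softmax_out
}
```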
@@ -98,8 +111,9 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
 
     const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());
 
-    SoftmaxMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), input,
-                                    output, axis);
+    SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(),
+                                    input, output, axis, ctx.OutputName("Out"),
+                                    is_inplaced);
 
     auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
     // For Inplace src and dst are the same memory object
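The closing comment of this hunk is worth unpacking: for in-place execution the kernel hands oneDNN a single memory object as both source and destination. Here is a standalone sketch of that pattern against the plain oneDNN 2.x C++ API (the `mkldnn::` namespace in this file aliases `dnnl::`); it is an illustration under that version assumption, not the kernel's actual code:

```cpp
#include <iostream>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  dnnl::memory::desc md({1, 3}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::nc);
  dnnl::memory mem(md, eng);
  float* data = static_cast<float*>(mem.get_data_handle());
  data[0] = 1.f; data[1] = 2.f; data[2] = 3.f;

  dnnl::softmax_forward::primitive_desc pd(
      {dnnl::prop_kind::forward_inference, md, /*axis=*/1}, eng);

  // In-place execution: one memory object serves as SRC and DST, just as
  // AcquireSrcMemory/AcquireDstMemory hand back the same buffer above.
  dnnl::softmax_forward(pd).execute(
      strm, {{DNNL_ARG_SRC, mem}, {DNNL_ARG_DST, mem}});
  strm.wait();

  for (int i = 0; i < 3; ++i) std::cout << data[i] << ' ';  // ~0.09 0.24 0.67
  std::cout << '\n';
}
```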
@@ -135,12 +149,11 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
                       paddle::platform::errors::PreconditionNotMet(
                           "Operator DNNL SoftmaxGrad must use CPUPlace"));
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
     const Tensor* output = ctx.Input<Tensor>("Out");
     auto* out_grad = ctx.template Input<Tensor>(framework::GradVarName("Out"));
     auto* in_x_grad = ctx.template Output<Tensor>(framework::GradVarName("X"));
 
-    SoftmaxMKLDNNHandler<T> handler(ctx, mkldnn_engine, ctx.GetPlace(), output,
+    SoftmaxMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), output,
                                     out_grad, in_x_grad, ctx.InputName("Out"));
 
     auto dst_memory_p = handler.AcquireDstMemory(output);
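Note that the backward handler builds its key from `out->dims()` and `ctx.InputName("Out")`, the same ingredients the forward kernel passes via `ctx.OutputName("Out")`, which appears designed so both passes resolve to one cache entry; the `isBwdCached()` guard then skips descriptor construction when a previous run already populated it. A toy model of that keyed guard follows, with every name illustrative rather than Paddle's:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct PrimitiveDesc { int axis; };  // stand-in for a oneDNN primitive_desc

// Hypothetical miniature of the caching that distinguishes MKLDNNHandlerT
// from MKLDNNHandlerNoCachingT.
class CachedHandler {
 public:
  explicit CachedHandler(std::string key) : key_(std::move(key)) {}

  bool isCached() const { return cache().count(key_) != 0; }

  // Mirrors the `if (!this->isCached()) { ...Acquire... }` guard: the
  // descriptor is built once per key and reused by later handlers.
  std::shared_ptr<PrimitiveDesc> AcquireForward(int axis) {
    auto it = cache().find(key_);
    if (it != cache().end()) return it->second;
    auto pd = std::make_shared<PrimitiveDesc>(PrimitiveDesc{axis});
    cache().emplace(key_, pd);
    return pd;
  }

 private:
  static std::unordered_map<std::string, std::shared_ptr<PrimitiveDesc>>&
  cache() {
    static std::unordered_map<std::string, std::shared_ptr<PrimitiveDesc>> c;
    return c;
  }
  std::string key_;
};

int main() {
  CachedHandler fwd("8-16-softmax_out");
  fwd.AcquireForward(/*axis=*/1);         // first run builds and caches
  CachedHandler bwd("8-16-softmax_out");  // backward derives the same key
  std::cout << bwd.isCached() << '\n';    // prints 1: the entry is reused
}
```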