Skip to content

Commit c63f6b2

Browse files
committed
- MKL-DNN pooling updated to set_prim_desc
- MKLDNN ops revisited - disabled softmax modifications - disabled elementwise_add - reverted LRN modifications - reverted SUM primitive - Partial reviving of softmax - Enable softmax - Softmax changes - LRN is back - LRN partially disabled - LRN is back - LRN fix - compilation fixes - Sum fixed (hopefully) - Enabling (partially) elementwise_add - Fixes to elementwise_add - Lint fixes - quantize fix - compilation fix test=develop - Disabling pooling - Disabled quantize op test=develop
1 parent a4b4ecd commit c63f6b2

9 files changed

Lines changed: 48 additions & 86 deletions

File tree

paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
7777
} else {
7878
functor.RunMidWise(n, pre, post);
7979
}
80-
z->set_layout(DataLayout::kMKLDNN);
81-
z->set_format(x->format());
80+
z->set_mkldnn_prim_desc(x->get_mkldnn_prim_desc());
8281
} else {
8382
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
8483
x->format() != memory::format::format_undef,
@@ -116,7 +115,8 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
116115
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd);
117116

118117
// create mkldnn memory for dst
119-
memory dst_memory = memory(sum_pd.dst_primitive_desc(), z_data);
118+
auto dst_mem_pd = sum_pd.dst_primitive_desc();
119+
memory dst_memory = memory(dst_mem_pd, z_data);
120120

121121
std::vector<primitive::at> inputs;
122122
inputs.push_back(srcs[0]);
@@ -129,9 +129,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
129129
pipeline.push_back(sum_prim);
130130
stream(stream::kind::eager).submit(pipeline).wait();
131131

132-
z->set_layout(DataLayout::kMKLDNN);
133-
z->set_format(
134-
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
132+
z->set_mkldnn_prim_desc(dst_mem_pd);
135133
}
136134
}
137135
};
@@ -152,24 +150,19 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
152150
auto* out = dout;
153151
auto *x = dout, *y = dout;
154152

155-
auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
156-
in->set_layout(DataLayout::kMKLDNN);
157-
in->set_format(out->format());
158-
};
159-
160153
if (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()) {
161154
if (dx->dims() == dy->dims()) {
162155
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
163156
if (dx) {
164157
blas.VCOPY(dout->numel(), dout->data<T>(),
165158
dx->mutable_data<T>(ctx.GetPlace()));
166-
set_mkldnn_format(dx, dout);
159+
dx->set_mkldnn_prim_desc(dout->get_mkldnn_prim_desc());
167160
}
168161

169162
if (dy) {
170163
blas.VCOPY(dout->numel(), dout->data<T>(),
171164
dy->mutable_data<T>(ctx.GetPlace()));
172-
set_mkldnn_format(dy, dout);
165+
dy->set_mkldnn_prim_desc(dout->get_mkldnn_prim_desc());
173166
}
174167
}
175168
} else {

paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
9696

9797
std::vector<int> src_tz = framework::vectorize2int(x->dims());
9898

99-
auto src_format =
100-
src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();
99+
auto src_format = x->format();
101100

102101
const std::string key = gethash(src_tz, algorithm);
103102
const std::string key_src_data =
@@ -127,10 +126,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
127126

128127
if (p_fwd == nullptr) {
129128
// create mkldnn memory for input X
130-
auto src_md = platform::MKLDNNMemDesc(
131-
src_tz, platform::MKLDNNGetDataType<T>(), src_format);
132129
auto src_memory = std::shared_ptr<memory>(
133-
new memory({src_md, mkldnn_engine}, to_void_cast(x_data)));
130+
new memory(x->get_mkldnn_prim_desc(), to_void_cast(x_data)));
134131
// save src_memory to be referred in backward path
135132
dev_ctx.SetBlob(key_src_mem, src_memory);
136133

@@ -177,8 +174,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
177174
pipeline.push_back(*p_fwd);
178175
stream(stream::kind::eager).submit(pipeline).wait();
179176

180-
y->set_layout(DataLayout::kMKLDNN);
181-
y->set_format(GetMKLDNNFormat(*dst_memory));
177+
y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc());
182178
}
183179

184180
template <typename T>
@@ -196,9 +192,6 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
196192

197193
std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());
198194

199-
auto diff_y_format =
200-
diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();
201-
202195
const std::string key = gethash(diff_dst_tz, algorithm);
203196
const std::string key_src_data =
204197
key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
@@ -210,8 +203,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
210203
key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem";
211204
const std::string key_fwd_pd =
212205
key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd";
213-
const std::string key_with_layouts =
214-
key + std::to_string(*p_src_layout) + "-" + std::to_string(diff_y_format);
206+
const std::string key_with_layouts = key + std::to_string(*p_src_layout) +
207+
"-" + std::to_string(diff_y->format());
215208
const std::string key_diff_src_mem =
216209
key_with_layouts + "@eltwise_diff_src_mem";
217210
const std::string key_diff_dst_mem =
@@ -234,10 +227,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
234227

235228
if (p_grad == nullptr) {
236229
// create mkldnn memory for input diff_y
237-
auto diff_dst_md = platform::MKLDNNMemDesc(
238-
diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
239230
auto diff_dst_memory = std::shared_ptr<memory>(
240-
new memory({diff_dst_md, mkldnn_engine}, to_void_cast(diff_y_data)));
231+
new memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data)));
241232
dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory);
242233

243234
// retrieve eltwise primitive desc from device context
@@ -281,8 +272,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
281272
pipeline.push_back(*p_grad);
282273
stream(stream::kind::eager).submit(pipeline).wait();
283274

284-
diff_x->set_layout(DataLayout::kMKLDNN);
285-
diff_x->set_format(GetMKLDNNFormat(*diff_src_memory));
275+
diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
286276
}
287277

288278
template <typename T, mkldnn::algorithm algorithm>

paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -206,17 +206,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
206206
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
207207

208208
// create mkldnn memory from input x tensor
209-
mkldnn::memory::format input_format =
210-
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
211209

212210
// keys for backward pass
213211
const std::string key = BatchNormMKLDNNHandler::GetHash(
214-
src_tz, epsilon, flags, global_stats, input_format,
212+
src_tz, epsilon, flags, global_stats, x->format(),
215213
ctx.op().Output("SavedMean"));
216214
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
217215

218-
auto user_src_md = platform::MKLDNNMemDesc(
219-
{src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
216+
auto user_src_md = x->get_mkldnn_prim_desc().desc();
220217

221218
// create primitive descriptor for batch norm forward
222219
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
@@ -230,8 +227,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
230227
BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine,
231228
key);
232229

233-
auto src_memory =
234-
handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data));
230+
auto src_memory = handler.AcquireSrcMemory(x->get_mkldnn_prim_desc(),
231+
to_void_cast(x_data));
235232

236233
// crate mkldnn memory for weights(scale/shift)
237234
auto scaleshift_memory =
@@ -265,8 +262,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
265262
variance_memory, false);
266263
}
267264

268-
y->set_layout(DataLayout::kMKLDNN);
269-
y->set_format(platform::GetMKLDNNFormat(*dst_memory));
265+
y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc());
270266

271267
std::vector<mkldnn::primitive> pipeline;
272268
pipeline.push_back(*batch_norm_p);
@@ -336,24 +332,21 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
336332

337333
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
338334

339-
mkldnn::memory::format dst_format =
340-
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
341-
342335
mkldnn::memory::format input_format =
343336
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
344337

345338
unsigned flags = mkldnn::use_scale_shift;
346339

347340
// keys from forward pass
348341
const std::string key = BatchNormMKLDNNHandler::GetHash(
349-
src_tz, epsilon, flags, false, input_format,
342+
src_tz, epsilon, flags, false, x->format(),
350343
ctx.op().Input("SavedMean"));
351344
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
352345

353346
// keys for primitives reuse
354347
const std::string key_with_hash =
355348
key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false,
356-
input_format);
349+
x->format());
357350
const std::string key_batch_norm_bwd_p =
358351
key_with_hash + "@batch_norm_bwd_p";
359352
const std::string key_batch_norm_src_mem_p =
@@ -373,9 +366,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
373366

374367
primitive reorder_diff_dst;
375368
bool is_diff_dst_reordered = false;
376-
auto user_diff_dst_memory = memory(
377-
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
378-
to_void_cast(diff_y_data));
369+
auto user_diff_dst_memory =
370+
memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data));
379371

380372
// MKLDNN requires a single piece of memory for scale and shift/bias data
381373
const size_t scaleshift_size = 2 * ic;
@@ -459,10 +451,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
459451
dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory);
460452

461453
// set layout/format of output tensors
462-
diff_x->set_layout(DataLayout::kMKLDNN);
463-
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
464-
.desc()
465-
.data.format);
454+
diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
466455
} else {
467456
// primitives already exist
468457
UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data));
@@ -487,10 +476,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
487476
}
488477

489478
// set layout/format of output tensors
490-
diff_x->set_layout(DataLayout::kMKLDNN);
491-
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
492-
.desc()
493-
.data.format);
479+
diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
494480
}
495481

496482
// execute optional reorder and batch_norm backward primitive

paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,6 @@ static memory::primitive_desc CreateMemPrimDesc(const Tensor& input,
4747
return mem_prim_desc;
4848
}
4949

50-
static mkldnn::memory::format GetDstMemFormat(
51-
const concat::primitive_desc& concat_pd) {
52-
return (memory::format)concat_pd.dst_primitive_desc().desc().data.format;
53-
}
54-
5550
static platform::CPUPlace GetCpuPlace(
5651
const paddle::framework::ExecutionContext& ctx) {
5752
auto place = ctx.GetPlace();
@@ -139,8 +134,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
139134
auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place);
140135
stream(stream::kind::eager).submit({concat}).wait();
141136

142-
output->set_layout(DataLayout::kMKLDNN);
143-
output->set_format(GetDstMemFormat(concat_pd));
137+
output->set_mkldnn_prim_desc(concat_pd.dst_primitive_desc());
144138
}
145139
};
146140
} // namespace operators

paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
282282
pipeline.push_back(*conv_p);
283283
stream(stream::kind::eager).submit(pipeline).wait();
284284

285-
auto dst_mpd = dst_memory_p->get_primitive_desc();
286-
output->set_mkldnn_prim_desc(dst_mpd);
285+
output->set_mkldnn_prim_desc(dst_memory_p->get_primitive_desc());
287286
}
288287
void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
289288
const bool is_test = ctx.Attr<bool>("is_test");
@@ -972,8 +971,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
972971

973972
pipeline.push_back(*conv_bwd_data_p);
974973

975-
input_grad->set_layout(DataLayout::kMKLDNN);
976-
input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p));
974+
input_grad->set_mkldnn_prim_desc(diff_src_memory_p->get_primitive_desc());
977975
}
978976
stream(stream::kind::eager).submit(pipeline).wait();
979977
}

paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
221221
pipeline.push_back(*conv_p);
222222
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
223223

224-
output->set_layout(DataLayout::kMKLDNN);
225-
output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
224+
output->set_mkldnn_prim_desc(dst_memory_p->get_primitive_desc());
226225
}
227226

228227
private:

paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
8181
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
8282
e_mid = e_mid.constant(k);
8383

84-
auto dims = paddle::framework::vectorize2int(x->dims());
85-
86-
auto src_md = paddle::platform::MKLDNNMemDesc(
87-
dims, mkldnn::memory::data_type::f32, x->format());
84+
auto src_md = x->get_mkldnn_prim_desc().desc();
8885

8986
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
9087
mkldnn::lrn_across_channels,
@@ -94,7 +91,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
9491
beta,
9592
k};
9693

97-
auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
94+
auto src_memory_pd = x->get_mkldnn_prim_desc();
9895

9996
if (!is_test) {
10097
const std::string key = ctx.op().Output("Out");
@@ -111,30 +108,28 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
111108
src_memory->set_data_handle(
112109
static_cast<void*>(const_cast<T*>(input_data)));
113110

114-
auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
115-
static_cast<void*>(output_data));
111+
auto dst_memory_pd = forward_pd->dst_primitive_desc();
112+
auto dst_memory =
113+
mkldnn::memory(dst_memory_pd, static_cast<void*>(output_data));
116114
auto workspace_memory = insert_to_context<mkldnn::memory>(
117115
key_workspace_memory, dev_ctx,
118116
forward_pd->workspace_primitive_desc());
119117

120118
run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
121-
122-
out->set_layout(framework::DataLayout::kMKLDNN);
123-
out->set_format(platform::GetMKLDNNFormat(dst_memory));
119+
out->set_mkldnn_prim_desc(dst_memory_pd);
124120
} else {
125121
auto forward_pd =
126122
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
127123
auto src_memory = mkldnn::memory{
128124
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
129125
auto workspace_memory =
130126
mkldnn::memory{forward_pd.workspace_primitive_desc()};
127+
auto dst_memory_pd = forward_pd.dst_primitive_desc();
131128
auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
132129
static_cast<void*>(output_data));
133130

134131
run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
135-
136-
out->set_layout(framework::DataLayout::kMKLDNN);
137-
out->set_format(platform::GetMKLDNNFormat(dst_memory));
132+
out->set_mkldnn_prim_desc(dst_memory_pd);
138133
}
139134
}
140135
};

paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
158158
auto softmax_p =
159159
handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
160160

161+
// We cannot use softmax_dst_memory_p to get prim desc as
162+
// it contains flattened dims (2D) while output tensor can
163+
// have 2,3,4+ dims
164+
auto output_mem_pd = paddle::platform::create_prim_desc_from_dims(
165+
paddle::framework::vectorize2int(output->dims()),
166+
mkldnn::memory::format::blocked);
167+
output->set_mkldnn_prim_desc(output_mem_pd);
168+
161169
std::vector<primitive> pipeline{
162170
*(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
163171
stream(stream::kind::eager).submit(pipeline).wait();

paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,12 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
106106
memory::desc(dst_tz, memory::data_type::f32, memory::format::any);
107107

108108
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
109-
109+
auto dst_mem_pd = sum_pd.dst_primitive_desc();
110110
std::shared_ptr<memory> dst_mem;
111111
if (in_place) {
112-
dst_mem.reset(new memory(sum_pd.dst_primitive_desc()));
112+
dst_mem.reset(new memory(dst_mem_pd));
113113
} else {
114-
dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data));
114+
dst_mem.reset(new memory(dst_mem_pd, output_data));
115115
}
116116
std::vector<mkldnn::primitive::at> inputs;
117117
for (size_t i = 0; i < srcs_mem.size(); ++i) {
@@ -136,8 +136,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
136136
if (in_place) pipeline.push_back(reorder_prim);
137137
stream(stream::kind::eager).submit(pipeline).wait();
138138

139-
output->set_layout(DataLayout::kMKLDNN);
140-
output->set_format(output_format);
139+
output->set_mkldnn_prim_desc(dst_mem_pd);
141140
} else { // Fallback to naive version
142141
// TODO(@mozga-intel) Add MKLDNN SelectedRows & LoDTensorArray support
143142
SumKernel<CPUDeviceContext, T> reference_kernel;

0 commit comments

Comments (0)