Skip to content

Commit 673bf71

Browse files
authored
[oneDNN] Disable caching for interpolate and batch norm (#35030)
* Disabled interpolate oneDNN caching
* Compilation fix
* Draft of batch norm cache disabling
* Fixes to UT
1 parent a047c13 commit 673bf71

File tree

2 files changed

+125
-155
lines changed

2 files changed

+125
-155
lines changed

paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc

Lines changed: 110 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -35,167 +35,149 @@ using paddle::platform::MKLDNNDeviceContext;
3535
using platform::to_void_cast;
3636

3737
template <typename T>
38-
class BatchNormMKLDNNHandler
39-
: public platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
40-
mkldnn::batch_normalization_backward> {
38+
class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
39+
T, mkldnn::batch_normalization_forward,
40+
mkldnn::batch_normalization_backward> {
4141
public:
4242
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
43-
const platform::MKLDNNDeviceContext &dev_ctx,
44-
const mkldnn::engine mkldnn_engine,
45-
platform::Place cpu_place, const Tensor *x,
46-
const bool global_stats, const bool test_mode,
47-
const std::string &unique_name)
48-
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
49-
mkldnn::batch_normalization_backward>(
50-
dev_ctx, dev_ctx.GetEngine(), cpu_place,
51-
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
52-
unique_name)) {
53-
if (!this->isCached()) {
54-
const float epsilon = ctx.Attr<float>("epsilon");
55-
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
56-
57-
std::vector<std::string> DataLayout_error_msg = {"kNHWC", "kNCHW",
58-
"kAnyLayout", "kMKLDNN"};
59-
PADDLE_ENFORCE_EQ(
60-
x->layout(), DataLayout::kMKLDNN,
61-
platform::errors::InvalidArgument(
62-
"Wrong layout set for X tensor. Expected layout is `kMKLDNN`, "
63-
"But received %s.",
64-
DataLayout_error_msg[static_cast<int>(DataLayout::kMKLDNN)]));
65-
PADDLE_ENFORCE_NE(
66-
x->format(), MKLDNNMemoryFormat::undef,
67-
platform::errors::InvalidArgument("Wrong format set for X tensor"));
68-
69-
auto src_tz = paddle::framework::vectorize(x->dims());
70-
71-
// Flags are added by bitwise OR operation
72-
auto flags = mkldnn::normalization_flags::use_scale_shift; // 001
73-
if (global_stats)
74-
flags |= mkldnn::normalization_flags::use_global_stats; // 010
75-
if (fuse_with_relu && test_mode)
76-
flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100
77-
78-
auto md = mkldnn::memory::desc(
79-
src_tz, platform::MKLDNNGetDataType<T>(),
80-
platform::MKLDNNFormatForSize(src_tz.size(), x->format()));
81-
82-
this->AcquireForwardPrimitiveDescriptor(
83-
global_stats == true ? mkldnn::prop_kind::forward_scoring
84-
: mkldnn::prop_kind::forward_training,
85-
md, epsilon, flags);
86-
}
43+
const mkldnn::engine mkldnn_engine, const Tensor *x,
44+
const bool global_stats, const bool test_mode)
45+
: platform::MKLDNNHandlerNoCachingT<T,
46+
mkldnn::batch_normalization_forward,
47+
mkldnn::batch_normalization_backward>(
48+
mkldnn_engine, ctx.GetPlace()) {
49+
const float epsilon = ctx.Attr<float>("epsilon");
50+
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
51+
52+
std::vector<std::string> DataLayout_error_msg = {"kNHWC", "kNCHW",
53+
"kAnyLayout", "kMKLDNN"};
54+
PADDLE_ENFORCE_EQ(
55+
x->layout(), DataLayout::kMKLDNN,
56+
platform::errors::InvalidArgument(
57+
"Wrong layout set for X tensor. Expected layout is `kMKLDNN`, "
58+
"But received %s.",
59+
DataLayout_error_msg[static_cast<int>(DataLayout::kMKLDNN)]));
60+
PADDLE_ENFORCE_NE(
61+
x->format(), MKLDNNMemoryFormat::undef,
62+
platform::errors::InvalidArgument("Wrong format set for X tensor"));
63+
64+
auto src_tz = paddle::framework::vectorize(x->dims());
65+
66+
// Flags are added by bitwise OR operation
67+
auto flags = mkldnn::normalization_flags::use_scale_shift; // 001
68+
if (global_stats)
69+
flags |= mkldnn::normalization_flags::use_global_stats; // 010
70+
if (fuse_with_relu && test_mode)
71+
flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100
72+
73+
auto md = mkldnn::memory::desc(
74+
src_tz, platform::MKLDNNGetDataType<T>(),
75+
platform::MKLDNNFormatForSize(src_tz.size(), x->format()));
76+
77+
this->AcquireForwardPrimitiveDescriptor(
78+
global_stats == true ? mkldnn::prop_kind::forward_scoring
79+
: mkldnn::prop_kind::forward_training,
80+
md, epsilon, flags);
8781
}
8882

8983
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
90-
const platform::MKLDNNDeviceContext &dev_ctx,
91-
platform::Place cpu_place, const Tensor *in_x,
92-
const Tensor *scale, const Tensor *out_grad,
93-
const std::string &unique_name)
94-
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
95-
mkldnn::batch_normalization_backward>(
96-
dev_ctx, dev_ctx.GetEngine(), cpu_place,
97-
platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
98-
unique_name)) {
99-
if (!this->isBwdCached()) {
100-
PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
101-
platform::errors::InvalidArgument(
102-
"Wrong layout set for Input out_grad tensor"));
103-
PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
104-
platform::errors::InvalidArgument(
105-
"Wrong format set for Input out_grad tensor"));
106-
107-
auto src_tz = paddle::framework::vectorize<int64_t>(in_x->dims());
108-
auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
109-
PADDLE_ENFORCE_EQ(
110-
scale_tz.size(), 1,
111-
platform::errors::InvalidArgument(
112-
"Dims of scale tensor must be 1, but received scale's size is %d",
113-
scale_tz.size()));
114-
115-
MKLDNNMemoryFormat diff_fmt =
116-
platform::MKLDNNFormatForSize(src_tz.size(), out_grad->format());
117-
118-
MKLDNNMemoryFormat src_fmt =
119-
platform::MKLDNNFormatForSize(src_tz.size(), in_x->format());
120-
121-
auto dims = framework::vectorize(in_x->dims());
122-
auto diff_dst_md = mkldnn::memory::desc(
123-
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
124-
auto src_md =
125-
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
126-
127-
const float epsilon = ctx.Attr<float>("epsilon");
128-
129-
this->AcquireForwardPrimitiveDescriptor(
130-
mkldnn::prop_kind::forward_training, src_md, epsilon,
131-
mkldnn::normalization_flags::use_scale_shift);
132-
this->AcquireBackwardPrimitiveDescriptor(
133-
mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon,
134-
mkldnn::normalization_flags::use_scale_shift);
135-
}
84+
const mkldnn::engine mkldnn_engine, const Tensor *in_x,
85+
const Tensor *scale, const Tensor *out_grad)
86+
: platform::MKLDNNHandlerNoCachingT<T,
87+
mkldnn::batch_normalization_forward,
88+
mkldnn::batch_normalization_backward>(
89+
mkldnn_engine, ctx.GetPlace()) {
90+
PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
91+
platform::errors::InvalidArgument(
92+
"Wrong layout set for Input out_grad tensor"));
93+
PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
94+
platform::errors::InvalidArgument(
95+
"Wrong format set for Input out_grad tensor"));
96+
97+
auto src_tz = paddle::framework::vectorize<int64_t>(in_x->dims());
98+
auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
99+
PADDLE_ENFORCE_EQ(
100+
scale_tz.size(), 1,
101+
platform::errors::InvalidArgument(
102+
"Dims of scale tensor must be 1, but received scale's size is %d",
103+
scale_tz.size()));
104+
105+
MKLDNNMemoryFormat diff_fmt =
106+
platform::MKLDNNFormatForSize(src_tz.size(), out_grad->format());
107+
108+
MKLDNNMemoryFormat src_fmt =
109+
platform::MKLDNNFormatForSize(src_tz.size(), in_x->format());
110+
111+
auto dims = framework::vectorize(in_x->dims());
112+
auto diff_dst_md =
113+
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
114+
auto src_md =
115+
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
116+
117+
const float epsilon = ctx.Attr<float>("epsilon");
118+
119+
this->AcquireForwardPrimitiveDescriptor(
120+
mkldnn::prop_kind::forward_training, src_md, epsilon,
121+
mkldnn::normalization_flags::use_scale_shift);
122+
this->AcquireBackwardPrimitiveDescriptor(
123+
mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon,
124+
mkldnn::normalization_flags::use_scale_shift);
136125
}
137126

138127
std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale,
139-
const Tensor *shift,
140-
const bool is_test) {
141-
auto scaleshift_memory = this->AcquireMemory("@scaleshift_mem_p");
142-
if (scaleshift_memory == nullptr || !is_test) {
143-
auto scale_tz = paddle::framework::vectorize(scale->dims());
144-
const unsigned int C = scale_tz[0];
145-
PADDLE_ENFORCE_EQ(
146-
scale_tz.size(), 1,
147-
platform::errors::InvalidArgument(
148-
"Dims of scale tensor must be 1, but received scale's size is %d",
149-
scale_tz.size()));
150-
151-
auto mem_p = this->AcquireMemoryFromPrimitive(
152-
this->fwd_pd_->weights_desc(), "@scaleshift_mem_p");
153-
154-
// MKLDNN requires a single piece of memory for scale and shift/bias data
155-
auto mem_data_handle = reinterpret_cast<T *>(mem_p->get_data_handle());
156-
std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
157-
std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
158-
159-
return mem_p;
160-
}
128+
const Tensor *shift) {
129+
auto scale_tz = paddle::framework::vectorize(scale->dims());
130+
const unsigned int C = scale_tz[0];
131+
PADDLE_ENFORCE_EQ(
132+
scale_tz.size(), 1,
133+
platform::errors::InvalidArgument(
134+
"Dims of scale tensor must be 1, but received scale's size is %d",
135+
scale_tz.size()));
136+
137+
auto scaleshift_memory =
138+
this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
139+
140+
// MKLDNN requires a single piece of memory for scale and shift/bias data
141+
auto mem_data_handle =
142+
reinterpret_cast<T *>(scaleshift_memory->get_data_handle());
143+
std::copy(scale->data<T>(), scale->data<T>() + C, mem_data_handle);
144+
std::copy(shift->data<T>(), shift->data<T>() + C, mem_data_handle + C);
161145
return scaleshift_memory;
162146
}
163147

164148
std::shared_ptr<mkldnn::memory> AcquireDiffScaleShiftMemory(
165149
T *diff_scaleshift_data) {
166150
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
167-
diff_scaleshift_data,
168-
"@diff_scaleshift_mem_p");
151+
diff_scaleshift_data);
169152
}
170153

171154
std::shared_ptr<mkldnn::memory> AcquireMeanMemory(
172155
const framework::Tensor *mean) {
173156
const T *mean_data = mean->data<T>();
174-
return this->AcquireMemoryFromPrimitive(
175-
this->fwd_pd_->mean_desc(), to_void_cast<T>(mean_data), "@mean_mem_p");
157+
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
158+
to_void_cast<T>(mean_data));
176159
}
177160

178161
std::shared_ptr<mkldnn::memory> AcquireMeanMemory(framework::Tensor *mean) {
179162
T *mean_data = mean->mutable_data<T>(this->place_,
180163
this->fwd_pd_->mean_desc().get_size());
181164
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
182-
mean_data, "@mean_mem_p");
165+
mean_data);
183166
}
184167

185168
std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
186169
const framework::Tensor *variance) {
187170
const T *variance_data = variance->data<T>();
188171
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
189-
to_void_cast<T>(variance_data),
190-
"@variance_mem_p");
172+
to_void_cast<T>(variance_data));
191173
}
192174

193175
std::shared_ptr<mkldnn::memory> AcquireVarianceMemory(
194176
framework::Tensor *variance) {
195177
T *variance_data = variance->mutable_data<T>(
196178
this->place_, this->fwd_pd_->variance_desc().get_size());
197179
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
198-
variance_data, "@variance_mem_p");
180+
variance_data);
199181
}
200182
};
201183

@@ -220,13 +202,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
220202
auto *batch_mean = ctx.Output<Tensor>("SavedMean");
221203
auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
222204

223-
BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine,
224-
ctx.GetPlace(), x, global_stats,
225-
test_mode, ctx.OutputName("SavedMean"));
205+
BatchNormMKLDNNHandler<T> handler(ctx, mkldnn_engine, x, global_stats,
206+
test_mode);
226207

227208
auto src_memory = handler.AcquireSrcMemory(x);
228-
auto scaleshift_memory =
229-
handler.AcquireScaleShiftMemory(scale, shift, is_test);
209+
auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
230210
auto dst_memory = handler.AcquireDstMemory(y);
231211

232212
auto batch_norm_p = handler.AcquireForwardPrimitive();
@@ -303,8 +283,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
303283
auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
304284
auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
305285

306-
BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), x, scale,
307-
diff_y, ctx.InputName("SavedMean"));
286+
BatchNormMKLDNNHandler<T> handler(ctx, mkldnn_engine, x, scale, diff_y);
308287

309288
// MKLDNN requires a single piece of memory for scale and shift/bias data
310289
const unsigned int C = paddle::framework::vectorize(scale->dims())[0];
@@ -316,8 +295,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
316295
auto mean_memory = handler.AcquireMeanMemory(batch_mean);
317296
auto variance_memory = handler.AcquireVarianceMemory(batch_variance);
318297
auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y);
319-
auto scaleshift_memory =
320-
handler.AcquireScaleShiftMemory(scale, shift, false);
298+
auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift);
321299
auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x);
322300
auto diff_scaleshift_memory =
323301
handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data());

paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,27 +30,21 @@ using platform::to_void_cast;
3030

3131
template <typename T = float>
3232
class InterpolateMKLDNNHandler
33-
: public platform::MKLDNNHandlerT<T, dnnl::resampling_forward> {
33+
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::resampling_forward> {
3434
public:
3535
InterpolateMKLDNNHandler(const dnnl::algorithm algo,
36-
const platform::MKLDNNDeviceContext& dev_ctx,
3736
const dnnl::engine engine, platform::Place cpu_place,
38-
const Tensor* x, Tensor* z,
39-
const std::string& uniq_name)
40-
: platform::MKLDNNHandlerT<T, dnnl::resampling_forward>(
41-
dev_ctx, engine, cpu_place,
42-
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
43-
uniq_name)) {
44-
if (!this->isCached()) {
45-
const auto src_x_tz = framework::vectorize(x->dims());
46-
const auto dst_tz = framework::vectorize(z->dims());
47-
const auto src_md = dnnl::memory::desc(
48-
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
49-
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
50-
MKLDNNMemoryFormat::any);
51-
this->AcquireForwardPrimitiveDescriptor(
52-
dnnl::prop_kind::forward_inference, algo, src_md, dst_md);
53-
}
37+
const Tensor* x, Tensor* z)
38+
: platform::MKLDNNHandlerNoCachingT<T, dnnl::resampling_forward>(
39+
engine, cpu_place) {
40+
const auto src_x_tz = framework::vectorize(x->dims());
41+
const auto dst_tz = framework::vectorize(z->dims());
42+
const auto src_md = dnnl::memory::desc(
43+
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
44+
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
45+
MKLDNNMemoryFormat::any);
46+
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_inference,
47+
algo, src_md, dst_md);
5448
}
5549
};
5650

@@ -145,7 +139,6 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
145139
const auto& mkldnn_engine = dev_ctx.GetEngine();
146140

147141
const auto* x = ctx.Input<Tensor>("X");
148-
std::vector<float> scale_prior;
149142
auto* z = ctx.Output<Tensor>("Out");
150143

151144
auto interp_method = ctx.Attr<std::string>("interp_method");
@@ -155,11 +148,10 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
155148

156149
auto out_dims_vec = ComputeOutputShape(ctx);
157150
framework::DDim dim_out = framework::make_ddim(out_dims_vec);
158-
z->mutable_data<T>(dim_out, ctx.GetPlace());
151+
z->Resize(dim_out);
159152

160-
InterpolateMKLDNNHandler<T> handler(algo, dev_ctx, mkldnn_engine,
161-
ctx.GetPlace(), x, z,
162-
ctx.OutputName("Out"));
153+
InterpolateMKLDNNHandler<T> handler(algo, mkldnn_engine, ctx.GetPlace(), x,
154+
z);
163155

164156
auto src_memory_p = handler.AcquireSrcMemory(x);
165157
auto dst_memory_p = handler.AcquireDstMemory(z);

0 commit comments

Comments (0)