Skip to content

Commit 2149fd6

Browse files
piotrekobi authored and zhangbo9674 committed
Added elementwise_sub_mkldnn operator (PaddlePaddle#35662)
* Add elementwise_sub_mkldnn_op without grad * Add test to static_mode_white_list * Refactor code, change license years * Remove invalid grad implementation * Fix element_wise_sub_op test * Fix CI Approval error * Remove unnecessary EltwiseSubMKLDNNGradKernel class * Fix CI Approval 2 * Fix CI Approval 3 * Fix CI Approval Attempt #4 * Fix CI Approve Attempt #5 * Fix CI Approval Attempt #6 * Fix CI Approval Attempt #7 * Change test names containing add to sub * Fix old tests testing add instead of sub * Copy grad implementation from elementwise_add_mkldnn * CI test fix attempt * Revert "CI test fix attempt" This reverts commit c647cacf41e6a87c715385a185de5cbf65fc8900. * Fix CI attempt 2 * Fix elementwise_sub tests, temporary mkldnn broadcast test disable * Add working implementation of elementwise_sub grad * Fix build errors caused by pull * Fix format error * Fix format error 2 * Disable elementwise_sub_mkldnn test on GPU * Apply fix for paddle.fluid import * Revert changes of test_elementwise_sub and Fix mkldnn test * Revert "Apply fix for paddle.fluid import" This reverts commit fc3b122. * fix bug of module 'paddle' has no attribute 'fluid' for python3.6 (PaddlePaddle#35862) * Add changes suggested by reviewers * Change @unittest.skipIf... to @OpTestTool.skip_if_not_cpu_bf16() to satisfy Approval CI * Remove check_dygraph=False to satisfy CI Approval Co-authored-by: zhangbo9674 <[email protected]>
1 parent 23b23cd commit 2149fd6

File tree

4 files changed

+380
-5
lines changed

4 files changed

+380
-5
lines changed
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
2+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
17+
// Forward declarations of framework/platform types, so the names are declared
// in this translation unit without pulling in their defining headers here.
namespace paddle {
18+
namespace framework {
19+
class ExecutionContext;
20+
} // namespace framework
21+
namespace platform {
22+
class CPUDeviceContext;
23+
struct CPUPlace;
24+
} // namespace platform
25+
} // namespace paddle
26+
27+
namespace paddle {
28+
namespace operators {
29+
// Gradient kernel for elementwise subtraction (Out = X - Y) on the oneDNN
// (MKLDNN) backend; it is the kernel registered for elementwise_sub_grad below.
//
// Given dOut it fills the optional outputs:
//   * dX: dOut reordered into dX (no scaling attribute is attached).
//   * dY: dOut scaled by -1 when dY has the same dims as dOut; otherwise a
//         sum-reduction of dOut followed by a linear post-op with alpha = -1.
//
// T is the element type; this file instantiates it for float and bfloat16.
template <typename T>
30+
class EltwiseSubMKLDNNGradKernel : public ElemwiseGradKernel<T> {
31+
public:
32+
// Computes dX and/or dY from dOut. Either output may be absent (null), in
// which case that gradient is skipped.
void Compute(const framework::ExecutionContext& ctx) const override {
33+
// Run the shared base-class step first.
// NOTE(review): its exact effect is defined in ElemwiseGradKernel — not
// visible in this file.
ElemwiseGradKernel<T>::Compute(ctx);
34+
using Tensor = framework::Tensor;
35+
36+
auto& dev_ctx =
37+
ctx.template device_context<platform::MKLDNNDeviceContext>();
38+
const auto& onednn_engine = dev_ctx.GetEngine();
39+
40+
// dout is the incoming gradient; dx/dy are the gradients to produce.
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
41+
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
42+
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
43+
44+
// One reorder handler, built from dout's dims and data type, is shared by
// the dx path and the same-shape dy path.
auto tz = framework::vectorize<int64_t>(dout->dims());
45+
memory::data_type dout_type = framework::ToMKLDNNDataType(dout->type());
46+
platform::ReorderMKLDNNHandler handler(tz, dout->type(), dout_type,
47+
onednn_engine);
48+
49+
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
50+
// Source memory wraps dout's buffer; it is reused by every branch below.
auto reorder_src_memory_p = handler.AcquireSrcMemory(
51+
dout->format(), platform::to_void_cast(dout->data<T>()));
52+
53+
// dX: reorder dout into dx using dout's format.
if (dx) {
54+
auto reorder_dst_memory_p =
55+
handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace());
56+
auto reorder_p =
57+
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
58+
platform::RecordEvent record_reorder("int_reorder",
59+
platform::EventRole::kUniqueOp);
60+
61+
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
62+
// Block until the reorder submitted to the stream has finished.
astream.wait();
63+
64+
dx->set_layout(DataLayout::kMKLDNN);
65+
dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
66+
}
67+
68+
if (dy) {
69+
// Same dims as dout: a reorder with an output scale of -1, i.e.
// dy = -dout.
// Direct copy
70+
if (dout->dims() == dy->dims()) {
71+
auto reorder_dst_memory_p =
72+
handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace());
73+
74+
// Mask 0 applies the single scale value to the whole tensor.
dnnl::primitive_attr reorder_attr;
75+
std::vector<float> scales = {-1};
76+
reorder_attr.set_output_scales(0, scales);
77+
// Built directly (not via handler.AcquireReorder) so the scaling
// attribute can be attached to the reorder primitive.
auto reorder_p = std::make_shared<dnnl::reorder>(
78+
*(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr);
79+
platform::RecordEvent record_reorder("int_reorder",
80+
platform::EventRole::kUniqueOp);
81+
reorder_p->execute(astream, *reorder_src_memory_p,
82+
*reorder_dst_memory_p);
83+
astream.wait();
84+
85+
dy->set_layout(DataLayout::kMKLDNN);
86+
dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
87+
} else {
88+
// dy's dims differ from dout's: sum-reduce dout down to dy and negate
// the result through a post-op.
// Broadcasting
89+
90+
// eltwise_linear post-op with alpha = -1, beta = 0 negates the reduced
// values (post-op scale is 1).
dnnl::post_ops po;
91+
po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0);
92+
dnnl::primitive_attr attr;
93+
attr.set_post_ops(po);
94+
95+
// NOTE(review): CalculateBroadcastedDims is defined elsewhere;
// presumably it yields dy's dims expanded to dout's rank — confirm.
platform::ReductionMKLDNNHandler<T> handler_sum(
96+
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
97+
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy), attr);
98+
99+
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
100+
auto reduction_p = handler_sum.AcquireForwardPrimitive();
101+
102+
// The reduction reads straight from the dout source memory acquired
// above.
reduction_p->execute(astream, {
103+
{DNNL_ARG_SRC, *reorder_src_memory_p},
104+
{DNNL_ARG_DST, *dy_memory_p},
105+
});
106+
astream.wait();
107+
108+
dy->set_layout(DataLayout::kMKLDNN);
109+
// The format is taken from the destination descriptor reshaped to dy's
// own dims.
dy->set_format(
110+
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
111+
paddle::framework::vectorize<int64_t>(dy->dims()))));
112+
}
113+
}
114+
}
115+
};
116+
117+
} // namespace operators
118+
} // namespace paddle
119+
120+
namespace ops = paddle::operators;
121+
122+
// Registers the MKLDNN CPUPlace kernels for the elementwise_sub op:
// EltwiseMKLDNNKernel instantiated with dnnl::algorithm::binary_sub for float,
// bfloat16, int8_t and uint8_t.
REGISTER_OP_KERNEL(
123+
elementwise_sub, MKLDNN, paddle::platform::CPUPlace,
124+
ops::EltwiseMKLDNNKernel<float, dnnl::algorithm::binary_sub>,
125+
ops::EltwiseMKLDNNKernel<paddle::platform::bfloat16,
126+
dnnl::algorithm::binary_sub>,
127+
ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_sub>,
128+
ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_sub>)
129+
130+
// Registers the MKLDNN CPUPlace kernels for the elementwise_sub_grad op:
// EltwiseSubMKLDNNGradKernel for bfloat16 and float only (no int8_t/uint8_t
// instantiations are registered for the gradient).
REGISTER_OP_KERNEL(elementwise_sub_grad, MKLDNN, ::paddle::platform::CPUPlace,
131+
ops::EltwiseSubMKLDNNGradKernel<paddle::platform::bfloat16>,
132+
ops::EltwiseSubMKLDNNGradKernel<float>)

paddle/fluid/platform/mkldnn_reuse.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License. */
1919
#include <string>
2020
#include <utility>
2121
#include <vector>
22+
2223
#include "boost/optional.hpp"
2324
#include "paddle/fluid/framework/data_layout_transform.h"
2425
#include "paddle/fluid/framework/operator.h"
@@ -927,7 +928,6 @@ class BroadcastDataMKLDNNHandler
927928
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
928929
T_out* ptr = output->mutable_data<T_out>(
929930
this->place_, this->fwd_pd_->dst_desc().get_size());
930-
;
931931
memset(ptr, 0, this->fwd_pd_->dst_desc().get_size());
932932
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
933933
}
@@ -940,7 +940,8 @@ class ReductionMKLDNNHandler
940940
ReductionMKLDNNHandler(const dnnl::algorithm algo, const float p,
941941
const float eps, const mkldnn::engine engine,
942942
platform::Place cpu_place, const Tensor* x,
943-
const Tensor* y, std::vector<int64_t> y_tz)
943+
const Tensor* y, std::vector<int64_t> y_tz,
944+
const dnnl::primitive_attr& attr = NULL)
944945
: platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction>(engine,
945946
cpu_place) {
946947
PADDLE_ENFORCE_EQ(
@@ -957,7 +958,10 @@ class ReductionMKLDNNHandler
957958
const auto y_md =
958959
memory::desc(y_tz, platform::MKLDNNGetDataType<T>(), x->format());
959960

960-
this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps);
961+
if (attr)
962+
this->AcquireForwardPrimitiveDescriptor(attr, algo, x_md, y_md, p, eps);
963+
else
964+
this->AcquireForwardPrimitiveDescriptor(algo, x_md, y_md, p, eps);
961965
}
962966
};
963967

@@ -979,8 +983,9 @@ class ActivationMKLDNNHandler
979983
if (ctx.Type() == "scale") {
980984
bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
981985
auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
982-
alpha = (scale_tensor == nullptr) ? ctx.Attr<float>("scale")
983-
: (float)*(scale_tensor->data<T>());
986+
alpha = (scale_tensor == nullptr)
987+
? ctx.Attr<float>("scale")
988+
: static_cast<float>(*(scale_tensor->data<T>()));
984989
beta = ctx.Attr<float>("bias");
985990
// if bias_after_scale == true
986991
// out = scale*X + bias
@@ -1504,6 +1509,7 @@ static void SetDstMemoryQuantized(
15041509
T* output_data = output->mutable_data<T>(ctx.GetPlace());
15051510
const size_t dst_dims = dst_tz.size();
15061511
MKLDNNMemoryFormat dst_fmt;
1512+
15071513
PADDLE_ENFORCE_LE(dst_dims, 5, platform::errors::InvalidArgument(
15081514
"Dst memory for quantization can not have "
15091515
"dims > 5. But received dst_dims is %d.",

0 commit comments

Comments
 (0)