
Commit a2a45d8

Added cast op oneDNN kernel for bf16/fp32 data type casting (FWD/BWD) (#33056)
* added op cast functionality for fp32/bf16
* added newline
* added entries in static mode white list and unity build
* fixed failing tests
* changes after review
* added formatting
* upgraded tests file as reviewer suggested
* changes after review
* minor change
1 parent 009ff61 commit a2a45d8

7 files changed: +218 additions, −10 deletions

paddle/fluid/operators/cast_op.cc

Lines changed: 26 additions & 0 deletions
@@ -27,6 +27,9 @@ class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out", "The output tensor of cast op");
     AddAttr<int>("out_dtype", "output data type");
     AddAttr<int>("in_dtype", "input data type");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddComment(R"DOC(
 Cast Operator.

@@ -50,6 +53,7 @@ class CastOpGradMaker : public framework::SingleGradOpMaker<T> {
     grad->SetOutput("Out", this->InputGrad("X"));
     grad->SetAttr("out_dtype", this->GetAttr("in_dtype"));
     grad->SetAttr("in_dtype", this->GetAttr("out_dtype"));
+    grad->SetAttr("use_mkldnn", this->GetAttr("use_mkldnn"));
   }
 };

@@ -77,6 +81,28 @@ class CastOp : public framework::OperatorWithKernel {
     if (platform::is_cuda_pinned_place(tensor_place)) {
       return framework::OpKernelType(tensor->type(), ctx.device_context());
     }
+
+#ifdef PADDLE_WITH_MKLDNN
+    int in_dtype = ctx.Attr<int>("in_dtype");
+    int out_dtype = ctx.Attr<int>("out_dtype");
+
+    auto MKLDNNSupportsCast = [&]() -> bool {
+      int dtype_fp32 = static_cast<int>(framework::proto::VarType::FP32);
+      int dtype_bf16 = static_cast<int>(framework::proto::VarType::BF16);
+
+      if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) ||
+          (out_dtype != dtype_fp32 && out_dtype != dtype_bf16))
+        return false;
+
+      return true;
+    };
+
+    if (this->CanMKLDNNBeUsed(ctx, tensor->type()) && MKLDNNSupportsCast()) {
+      return framework::OpKernelType(tensor->type(), ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
     return framework::OpKernelType(tensor->type(), tensor_place);
   }
 };
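The dispatch above routes cast to oneDNN only when both endpoints are fp32 or bf16; any other dtype pair falls through to the default kernel. Note also that CastOpGradMaker swaps in_dtype and out_dtype (and now forwards use_mkldnn), so the backward pass is itself a cast in the opposite direction, served by the same kernel. A minimal Python mirror of the MKLDNNSupportsCast predicate (mkldnn_supports_cast is a hypothetical helper name, for illustration only):

import paddle.fluid.core as core

def mkldnn_supports_cast(in_dtype, out_dtype):
    # Mirrors the C++ lambda: both source and destination dtypes must be
    # fp32 or bf16 for the oneDNN reorder-based kernel to be eligible.
    supported = {int(core.VarDesc.VarType.FP32), int(core.VarDesc.VarType.BF16)}
    return in_dtype in supported and out_dtype in supported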
paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+
+namespace paddle {
+namespace operators {
+
+using paddle::framework::Tensor;
+
+template <typename T>
+class CastMKLDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    this->RunKernel(ctx);
+  }
+
+  void RunKernel(const framework::ExecutionContext& ctx) const {
+    const auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+
+    int in_dtype = ctx.Attr<int>("in_dtype");
+    int out_dtype = ctx.Attr<int>("out_dtype");
+
+    auto x_paddle_type = framework::proto::VarType::Type(in_dtype);
+    auto out_paddle_type = framework::proto::VarType::Type(out_dtype);
+
+    mkldnn::memory::data_type x_type =
+        framework::ToMKLDNNDataType(x_paddle_type);
+    mkldnn::memory::data_type out_type =
+        framework::ToMKLDNNDataType(out_paddle_type);
+
+    auto x_tz = framework::vectorize(x->dims());
+
+    std::string key =
+        platform::CreateKey(dev_ctx, x_tz, x->format(), x->format(), x_type);
+    platform::ReorderMKLDNNHandler reorder_handler(
+        x_tz, x_paddle_type, x_type, out_paddle_type, out_type, dev_ctx,
+        dev_ctx.GetEngine(), key);
+
+    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
+        x->format(), platform::to_void_cast(x->data<T>()));
+    auto reorder_dst_memory_p =
+        reorder_handler.AcquireDstMemory(out, x->format(), dev_ctx.GetPlace());
+    auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
+                                                    reorder_src_memory_p);
+
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
+    astream.wait();
+
+    out->set_layout(framework::DataLayout::kMKLDNN);
+    out->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_KERNEL(cast, MKLDNN, paddle::platform::CPUPlace,
+                   ops::CastMKLDNNKernel<float>,
+                   ops::CastMKLDNNKernel<paddle::platform::bfloat16>);
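The kernel body is a single oneDNN reorder whose source and destination memory descriptors differ only in data type, which is what performs the actual fp32/bf16 conversion. As a rough NumPy sketch of the underlying bit-level relationship (bf16 is the upper half of the fp32 bit pattern; truncation is shown here, whereas the real reorder rounds):

import numpy as np

def fp32_to_bf16_bits(x):
    # Keep the high 16 bits of each fp32 pattern, stored as uint16.
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_fp32(x):
    # Widening is exact: shift the 16 stored bits back into the fp32 high half.
    return (x.astype(np.uint32) << 16).view(np.float32)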

paddle/fluid/operators/unity_build_rule.cmake

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ register_unity_group(cc
     bmm_op.cc
     bpr_loss_op.cc
     cast_op.cc
+    mkldnn/cast_mkldnn_op.cc
     cholesky_op.cc
     chunk_eval_op.cc
     clip_by_norm_op.cc

paddle/fluid/platform/mkldnn_reuse.h

Lines changed: 23 additions & 6 deletions
@@ -926,7 +926,23 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
       : platform::MKLDNNHandler(dev_ctx, engine, base_key),
         dims_(dims),
         vtype_(vtype),
-        dtype_(dtype) {}
+        vtype_dst_(vtype),
+        dtype_(dtype),
+        dtype_dst_(dtype) {}
+
+  ReorderMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT
+                       framework::proto::VarType::Type vtype,
+                       mkldnn::memory::data_type dtype,
+                       framework::proto::VarType::Type vtype_dst,
+                       mkldnn::memory::data_type dtype_dst,
+                       const platform::MKLDNNDeviceContext& dev_ctx,
+                       mkldnn::engine engine, const std::string& base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
+        dims_(dims),
+        vtype_(vtype),
+        vtype_dst_(vtype_dst),
+        dtype_(dtype),
+        dtype_dst_(dtype_dst) {}

   std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
       const MKLDNNMemoryFormat& fmt, void* ptr) {
@@ -940,15 +956,16 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
     if (mem_p == nullptr) {
-      auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
-      auto dst_data = output->mutable_data(place, vtype_, dst_md.get_size());
+      auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_dst_, fmt);
+      auto dst_data =
+          output->mutable_data(place, vtype_dst_, dst_md.get_size());

       mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
       dev_ctx_.SetBlob(local_key, mem_p);
     } else {
       // Even if memory object exists , we may be using it for diffrent tensor
       auto dst_data =
-          output->mutable_data(place, vtype_, mem_p->get_desc().get_size());
+          output->mutable_data(place, vtype_dst_, mem_p->get_desc().get_size());
       mem_p->set_data_handle(dst_data);
     }
     return mem_p;
@@ -970,8 +987,8 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {

  private:
   std::vector<int64_t> dims_;
-  framework::proto::VarType::Type vtype_;
-  mkldnn::memory::data_type dtype_;
+  framework::proto::VarType::Type vtype_, vtype_dst_;
+  mkldnn::memory::data_type dtype_, dtype_dst_;
 };

 template <typename T>
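ReorderMKLDNNHandler now tracks separate source and destination type pairs; the original single-type constructor simply duplicates the source types, so existing callers behave exactly as before, while the cast kernel passes a distinct destination pair. In Python terms the design is roughly (illustrative sketch only):

class ReorderHandlerSketch:
    # Destination types default to the source types, matching what the
    # old single-type constructor now does.
    def __init__(self, dims, vtype, dtype, vtype_dst=None, dtype_dst=None):
        self.dims = dims
        self.vtype, self.dtype = vtype, dtype
        self.vtype_dst = vtype_dst if vtype_dst is not None else vtype
        self.dtype_dst = dtype_dst if dtype_dst is not None else dtype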
python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+
+import paddle
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from paddle.fluid import compiler, Program, program_guard
+from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
+
+
+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
+class TestCastBF16ToFP32MKLDNNOp(OpTest):
+    def init_data(self):
+        self.out = np.random.random(size=[10, 10]).astype("float32")
+        self.x = convert_float_to_uint16(self.out)
+
+    def setUp(self):
+        self.init_data()
+        self.inputs = {'X': self.x}
+        self.outputs = {'Out': self.out}
+        prepare_dtype = lambda x: int(core.VarDesc.VarType.BF16 if x.dtype != np.float32 else core.VarDesc.VarType.FP32)
+        self.attrs = {
+            'in_dtype': prepare_dtype(self.x),
+            'out_dtype': prepare_dtype(self.out),
+            'use_mkldnn': True
+        }
+        self.op_type = 'cast'
+
+    def test_check_output(self):
+        self.check_output(check_dygraph=False)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            core.CPUPlace(), ["X"],
+            "Out",
+            check_dygraph=False,
+            user_defined_grads=[self.inputs['X']],
+            user_defined_grad_outputs=[self.outputs['Out']])
+
+
+class TestCastFP32ToBF16MKLDNNOp(TestCastBF16ToFP32MKLDNNOp):
+    def init_data(self):
+        self.x = np.random.random(size=[2, 6]).astype("float32")
+        self.out = convert_float_to_uint16(self.x)
+
+
+class TestCastBF16ToBF16MKLDNNOp(TestCastBF16ToFP32MKLDNNOp):
+    def init_data(self):
+        self.x = np.random.random(size=[6, 13]).astype("uint16")
+        self.out = self.x
+
+
+class TestCastFP32ToFP32MKLDNNOp(TestCastBF16ToFP32MKLDNNOp):
+    def init_data(self):
+        self.x = np.random.random(size=[7, 15]).astype("float32")
+        self.out = self.x
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()
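The four classes cover every fp32/bf16 pairing. The gradient check leans on the fact that cast's backward is the upstream gradient cast straight back to the input dtype, so the expected grads are simply the already-converted input/output arrays. Assuming the file lands at the usual oneDNN test location shown above, it can be run standalone:

python python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py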

python/paddle/fluid/tests/unittests/op_test.py

Lines changed: 16 additions & 4 deletions
@@ -1191,8 +1191,12 @@ def find_actual(target_name, fetch_list):
                     np.float32, np.float64
             ]:
                 actual_t = convert_uint16_to_float(actual_t)
-                atol = 0.03
+                atol = max(atol, 0.03)

+            if expect_t.dtype == np.uint16 and actual_t.dtype == np.uint16:
+                expect_t = convert_uint16_to_float(expect_t)
+                actual_t = convert_uint16_to_float(actual_t)
+                atol = max(atol, 0.03)
             # NOTE(zhiqiu): np.allclose([], [1.]) returns True
             # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
             if expect_t.size == 0:
@@ -1501,13 +1505,21 @@ def check_grad_with_place(self,

         # comparison of bf16 results will happen as fp32
         # loop over list of grads and convert bf16 to fp32
-        fp32_grads = []
+        fp32_analytic_grads = []
         for grad in analytic_grads:
             if grad.dtype == np.uint16:
                 grad = convert_uint16_to_float(grad)
                 max_relative_error = 0.03
-            fp32_grads.append(grad)
-        analytic_grads = fp32_grads
+            fp32_analytic_grads.append(grad)
+        analytic_grads = fp32_analytic_grads
+
+        fp32_numeric_grads = []
+        for grad in numeric_grads:
+            if grad.dtype == np.uint16:
+                grad = convert_uint16_to_float(grad)
+                max_relative_error = 0.03
+            fp32_numeric_grads.append(grad)
+        numeric_grads = fp32_numeric_grads

         self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check,
                               max_relative_error,
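With this, the harness handles the case where expected and actual outputs are both bf16 (uint16 buffers): both sides are widened to fp32 and compared with an absolute tolerance of at least 0.03, and numeric grads now get the same bf16-to-fp32 widening that analytic grads already received. A self-contained sketch of that comparison, assuming the bit-pattern widening that convert_uint16_to_float performs:

import numpy as np

def bf16_allclose(expect_u16, actual_u16, atol=1e-5):
    # Widen bf16 bit patterns (uint16) to fp32, then compare with the
    # relaxed tolerance used for bf16 outputs.
    widen = lambda a: (a.astype(np.uint32) << 16).view(np.float32)
    return np.allclose(widen(expect_u16), widen(actual_u16),
                       atol=max(atol, 0.03))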

tools/static_mode_white_list.py

Lines changed: 1 addition & 0 deletions
@@ -589,6 +589,7 @@
     'test_matmul_op_with_head',
     'test_var_conv_2d',
     'test_batch_norm_mkldnn_op',
+    'test_cast_mkldnn_op',
     'test_concat_int8_mkldnn_op',
     'test_concat_bf16_mkldnn_op',
     'test_concat_mkldnn_op',
