Commit 38419f5
[NPU] Support NPU kernel cast op (PaddlePaddle#31635)
Co-authored-by: frankwhzhang <[email protected]>
1 parent 2d5cd1e commit 38419f5

File tree: 2 files changed (+166, -0 lines)
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_ASCEND_CL
#include <map>
#include <memory>
#include <string>

#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

// Mapping from Paddle variable dtypes to the ACL dtypes understood by the
// Ascend "Cast" operator.
static std::map<framework::proto::VarType::Type, aclDataType>
    DTYPE_2_ACL_DTYPE = {
        {framework::proto::VarType::BOOL, ACL_BOOL},
        {framework::proto::VarType::INT16, ACL_INT16},
        {framework::proto::VarType::INT32, ACL_INT32},
        {framework::proto::VarType::INT64, ACL_INT64},
        {framework::proto::VarType::FP16, ACL_FLOAT16},
        {framework::proto::VarType::FP32, ACL_FLOAT},
        {framework::proto::VarType::FP64, ACL_DOUBLE},
};

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class CastNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    int dtype = ctx.Attr<int>("out_dtype");

    auto* out = ctx.Output<Tensor>("Out");

    auto place = ctx.GetPlace();

    auto iter = DTYPE_2_ACL_DTYPE.find(
        static_cast<framework::proto::VarType::Type>(dtype));
    int aclDtype = iter->second;

    // Allocate the output buffer with the requested destination dtype.
    if (dtype == framework::proto::VarType::FP32) {
      out->mutable_data<float>(place);
    } else if (dtype == framework::proto::VarType::FP16) {
      out->mutable_data<paddle::platform::float16>(place);
    } else if (dtype == framework::proto::VarType::INT16) {
      out->mutable_data<int16_t>(place);
    } else if (dtype == framework::proto::VarType::INT32) {
      out->mutable_data<int32_t>(place);
    } else if (dtype == framework::proto::VarType::INT64) {
      out->mutable_data<int64_t>(place);
    } else if (dtype == framework::proto::VarType::FP64) {
      out->mutable_data<double>(place);
    } else if (dtype == framework::proto::VarType::BOOL) {
      out->mutable_data<bool>(place);
    }

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    // Dispatch the ACL "Cast" op on the NPU stream.
    auto runner = NpuOpRunner("Cast", {*x}, {*out},
                              {{"dst_type", static_cast<int32_t>(aclDtype)}});
    runner.Run(stream);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
    cast, ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int16_t>,
    ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int32_t>,
    ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
    ops::CastNPUKernel<paddle::platform::NPUDeviceContext, bool>,
    ops::CastNPUKernel<paddle::platform::NPUDeviceContext, double>,
    ops::CastNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::CastNPUKernel<paddle::platform::NPUDeviceContext,
                       paddle::platform::float16>);
#endif
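
The kernel above is reached through the regular "cast" op. Below is a minimal sketch of exercising it from Python in static-graph mode; it is not part of the commit and assumes a PaddlePaddle build with PADDLE_WITH_ASCEND_CL, the 2.x paddle.static API, and an NPU available at index 0.

# Minimal sketch: run paddle.cast on an NPU place (assumes an Ascend build).
import numpy as np
import paddle
import paddle.static as static

paddle.enable_static()
main_prog = static.Program()
startup_prog = static.Program()
with static.program_guard(main_prog, startup_prog):
    x = static.data(name='x', shape=[4, 4], dtype='float32')
    y = paddle.cast(x, 'float16')  # lowers to the "cast" op handled by CastNPUKernel

place = paddle.NPUPlace(0)  # hypothetical: requires an NPU device at index 0
exe = static.Executor(place)
exe.run(startup_prog)
out, = exe.run(main_prog,
               feed={'x': np.random.rand(4, 4).astype('float32')},
               fetch_list=[y])
print(out.dtype)  # float16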
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestCast1(OpTest):
    def setUp(self):
        self.set_npu()
        self.op_type = "cast"
        self.place = paddle.NPUPlace(0)

        ipt = np.random.random(size=[10, 10]) + 1
        self.inputs = {'X': ipt.astype('float32')}
        self.outputs = {'Out': ipt.astype('float16')}

        self.attrs = {
            'in_dtype': int(core.VarDesc.VarType.FP32),
            'out_dtype': int(core.VarDesc.VarType.FP16)
        }

    def set_npu(self):
        self.__class__.use_npu = True

    def test_check_output(self):
        self.check_output_with_place(self.place, check_dygraph=False)


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestCast2(OpTest):
    def setUp(self):
        self.set_npu()
        self.op_type = "cast"
        self.place = paddle.NPUPlace(0)

        ipt = np.random.random(size=[10, 10]) + 1
        self.inputs = {'X': ipt.astype('float16')}
        self.outputs = {'Out': ipt.astype('float32')}

        self.attrs = {
            'in_dtype': int(core.VarDesc.VarType.FP16),
            'out_dtype': int(core.VarDesc.VarType.FP32)
        }

    def set_npu(self):
        self.__class__.use_npu = True

    def test_check_output(self):
        self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)


if __name__ == '__main__':
    unittest.main()
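
The two tests above only cover the FP32/FP16 pair. A hypothetical extension of the same pattern, not part of this commit, could exercise one of the integer dtype pairs that the kernel registration also covers (here int32 to int64); it assumes the same module-level imports as the test file (unittest, numpy as np, paddle, core, OpTest).

# Hypothetical extra test case following the pattern above (not in the commit).
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestCastInt32ToInt64(OpTest):
    def setUp(self):
        self.__class__.use_npu = True
        self.op_type = "cast"
        self.place = paddle.NPUPlace(0)

        # Integer input so the int32 -> int64 cast is exact.
        ipt = np.random.randint(0, 100, size=[10, 10])
        self.inputs = {'X': ipt.astype('int32')}
        self.outputs = {'Out': ipt.astype('int64')}
        self.attrs = {
            'in_dtype': int(core.VarDesc.VarType.INT32),
            'out_dtype': int(core.VarDesc.VarType.INT64)
        }

    def test_check_output(self):
        self.check_output_with_place(self.place, check_dygraph=False)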
