Skip to content

Commit ed2641c

Browse files
[NPU] Support op kernel for Fill constant batch size like op (#34721)
* fix npu compile error, test=develop * add fill constant batch size like op npu, test=develop Co-authored-by: qili93 <[email protected]>
1 parent cfd49ac commit ed2641c

File tree

2 files changed

+231
-0
lines changed

2 files changed

+231
-0
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/fill_constant_op.h"
16+
#include "paddle/fluid/operators/npu_op_runner.h"
17+
#include "paddle/fluid/operators/utils.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
using Tensor = framework::Tensor;
23+
24+
template <typename DeviceContext, typename T>
25+
class FillConstantBatchSizeLikeOpNPUKernel : public framework::OpKernel<T> {
26+
public:
27+
void Compute(const framework::ExecutionContext &ctx) const override {
28+
auto data_type =
29+
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
30+
auto float_value = ctx.Attr<float>("value");
31+
auto str_value = ctx.Attr<std::string>("str_value");
32+
auto force_cpu = ctx.Attr<bool>("force_cpu");
33+
34+
auto *out = ctx.Output<Tensor>("Out");
35+
auto *input = ctx.Input<Tensor>("Input");
36+
if (&ctx.Attr<int>("input_dim_idx") == 0) {
37+
// set the correct batch size.
38+
auto odims = out->dims();
39+
int input_dim_idx = ctx.Attr<int>("input_dim_idx");
40+
int output_dim_idx = ctx.Attr<int>("output_dim_idx");
41+
odims[output_dim_idx] = input->dims()[input_dim_idx];
42+
out->mutable_data<T>(odims, ctx.GetPlace());
43+
}
44+
45+
T value;
46+
if (str_value.empty()) {
47+
value = static_cast<T>(float_value);
48+
} else {
49+
std::stringstream convert_stream(str_value);
50+
if (std::is_same<int64_t, T>::value) {
51+
int64_t tmp_value;
52+
convert_stream >> tmp_value;
53+
value = static_cast<T>(tmp_value);
54+
} else {
55+
double tmp_value;
56+
convert_stream >> tmp_value;
57+
value = static_cast<T>(tmp_value);
58+
}
59+
}
60+
61+
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
62+
auto &dev_ctx = *pool.Get(ctx.GetPlace());
63+
bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace();
64+
if (cpu_place) {
65+
math::SetConstant<platform::CPUDeviceContext, T> functor;
66+
out->mutable_data(platform::CPUPlace(), data_type);
67+
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
68+
out, static_cast<T>(value));
69+
} else {
70+
out->mutable_data(ctx.GetPlace(), data_type);
71+
Tensor tensor_tmp(data_type);
72+
tensor_tmp.mutable_data<T>({1}, ctx.GetPlace());
73+
FillNpuTensorWithConstant<T>(&tensor_tmp, value);
74+
75+
auto stream =
76+
ctx.template device_context<paddle::platform::NPUDeviceContext>()
77+
.stream();
78+
const auto &runner =
79+
NpuOpRunner("FillD", {tensor_tmp}, {*out},
80+
{{"dims", framework::vectorize(out->dims())}});
81+
runner.Run(stream);
82+
}
83+
}
84+
};
85+
} // namespace operators
86+
} // namespace paddle
87+
88+
namespace ops = paddle::operators;

// Register the NPU kernel for the supported element types:
// float32, int32 and float16.
REGISTER_OP_NPU_KERNEL(
    fill_constant_batch_size_like,
    ops::FillConstantBatchSizeLikeOpNPUKernel<
        paddle::platform::NPUDeviceContext, float>,
    ops::FillConstantBatchSizeLikeOpNPUKernel<
        paddle::platform::NPUDeviceContext, int>,
    ops::FillConstantBatchSizeLikeOpNPUKernel<
        paddle::platform::NPUDeviceContext, paddle::platform::float16>);
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import numpy as np
18+
import unittest
19+
import sys
20+
sys.path.append("..")
21+
from op_test import OpTest
22+
import paddle
23+
import paddle.fluid as fluid
24+
from paddle.fluid import core
25+
26+
# OpTest drives the op through a static program, so enable static-graph mode.
paddle.enable_static()
# NOTE(review): SEED appears unused in this file — confirm before removing.
SEED = 2021
28+
29+
30+
class TestFillConstantBatchSizeLike(OpTest):
    """Base test for fill_constant_batch_size_like on NPU.

    Subclasses customise the scenario by overriding the init_* hooks
    (shapes, fill value, dtype, force_cpu, and the dim indices).
    """

    def setUp(self):
        self.set_npu()
        self.place = paddle.NPUPlace(0)
        self.op_type = "fill_constant_batch_size_like"
        # Run every configuration hook in a fixed order; subclasses
        # override individual hooks to vary the scenario.
        for hook in (self.init_shape, self.init_value, self.init_dtype,
                     self.init_force_cpu, self.init_dim_idx):
            hook()

        reference_input = np.random.random(self.input_shape).astype("float32")
        self.inputs = {'Input': reference_input}
        self.attrs = dict(
            shape=self.shape,
            value=self.value,
            str_value=self.str_value,
            dtype=self.dtype,
            force_cpu=self.force_cpu,
            input_dim_idx=self.input_dim_idx,
            output_dim_idx=self.output_dim_idx)
        expected = np.full(self.output_shape, self.output_value,
                           self.output_dtype)
        self.outputs = {'Out': expected}

    def set_npu(self):
        # Mark the class as an NPU test for the OpTest machinery.
        self.__class__.use_npu = True

    def init_shape(self):
        # Default: batch dim 0 of a [4, 5] input replaces dim 0 of 'shape'.
        self.input_shape = [4, 5]
        self.shape = [123, 92]
        self.output_shape = (4, 92)

    def init_value(self):
        # Default: numeric 'value' attribute, no string override.
        self.value = 3.8
        self.str_value = ''
        self.output_value = 3.8

    def init_dtype(self):
        self.dtype = core.VarDesc.VarType.FP32
        self.output_dtype = np.float32

    def init_force_cpu(self):
        self.force_cpu = False

    def init_dim_idx(self):
        self.input_dim_idx = 0
        self.output_dim_idx = 0

    def test_check_output(self):
        self.check_output_with_place(self.place)
84+
85+
86+
class TestFillConstantBatchSizeLike2(TestFillConstantBatchSizeLike):
    """Rank-4 input: only the batch dim flows into the output shape."""

    def init_shape(self):
        self.input_shape, self.shape = [4, 5, 6, 7], [10, 123, 92]
        self.output_shape = (4, 123, 92)
92+
93+
94+
class TestFillConstantBatchSizeLike3(TestFillConstantBatchSizeLike):
    """A non-empty 'str_value' overrides the numeric 'value' attribute."""

    def init_value(self):
        self.value, self.str_value = 3.8, '4.5'
        self.output_value = 4.5
100+
101+
102+
class TestFillConstantBatchSizeLike6(TestFillConstantBatchSizeLike):
    """float16 output; use a looser tolerance for the half-precision fill."""

    def init_dtype(self):
        self.dtype = core.VarDesc.VarType.FP16
        self.output_dtype = np.float16

    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-2)
109+
110+
111+
class TestFillConstantBatchSizeLike7(TestFillConstantBatchSizeLike):
    """int32 output dtype."""

    def init_dtype(self):
        self.dtype, self.output_dtype = core.VarDesc.VarType.INT32, np.int32
115+
116+
117+
class TestFillConstantBatchSizeLike8(TestFillConstantBatchSizeLike):
    """force_cpu=True: the fill runs on the host instead of the NPU."""

    def init_force_cpu(self):
        self.force_cpu = True
120+
121+
122+
class TestFillConstantBatchSizeLike9(TestFillConstantBatchSizeLike):
    """The batch size is written to output dim 1 instead of dim 0."""

    def init_shape(self):
        self.input_shape, self.shape = [4, 5], [123, 92]
        self.output_shape = (123, 4)

    def init_dim_idx(self):
        self.input_dim_idx, self.output_dim_idx = 0, 1
131+
132+
133+
# Allow running this test file directly.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)