
Commit ec422ea

[NPU] add masked_select_op_npu (#35649)
1 parent 5fa9cf7 commit ec422ea

File tree

2 files changed: +361 −0 lines changed
Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/masked_select_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
namespace operators {

template <typename T>
class MaskedSelectedNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto input = ctx.Input<framework::Tensor>("X");
    auto mask = ctx.Input<framework::Tensor>("Mask");
    auto out = ctx.Output<framework::Tensor>("Y");

    auto input_dim = input->dims();
    auto mask_dim = mask->dims();
    PADDLE_ENFORCE_EQ(
        input_dim, mask_dim,
        platform::errors::InvalidArgument(
            "The dim size of input and mask in OP(masked_selected) "
            "must be equal, but got input dim:(%ld), mask dim: "
            "(%ld). Please check input "
            "value.",
            input_dim, mask_dim));

    auto& dev_ctx =
        ctx.template device_context<paddle::platform::NPUDeviceContext>();
    auto stream = dev_ctx.stream();

    Tensor mask_int32, out_size;
    std::vector<int32_t> out_size_vec;
    mask_int32.mutable_data<int32_t>(mask->dims(), ctx.GetPlace());
    out_size.mutable_data<int32_t>({1}, ctx.GetPlace());
    {
      const auto& cast_runner =
          NpuOpRunner("Cast", {*mask}, {mask_int32},
                      {{"dst_type", static_cast<int32_t>(ConvertToNpuDtype(
                                        framework::proto::VarType::INT32))}});
      cast_runner.Run(stream);

      mask_int32.Resize({mask_int32.numel()});
      NpuOpRunner sum_runner;
      sum_runner.SetType("ReduceSum");
      sum_runner.AddInput(mask_int32);
      sum_runner.AddInput(std::vector<int32_t>({0}));
      sum_runner.AddOutput(out_size);
      sum_runner.AddAttr("keep_dims", false);
      sum_runner.Run(stream);
      TensorToVector(out_size, dev_ctx, &out_size_vec);
    }

    out->Resize({out_size_vec[0]});
    out->mutable_data<T>(ctx.GetPlace());

    Tensor topkv2_out, indices;
    topkv2_out.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
    indices.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
    {
      NpuOpRunner topkv2_runner;
      topkv2_runner.SetType("TopKV2")
          .AddInput(mask_int32)
          .AddInput(out_size)
          .AddOutput(topkv2_out)
          .AddOutput(indices)
          .AddAttr("sorted", false)
          .AddAttr("dim", 0)
          .AddAttr("largest", true)
          .Run(stream);
      // TopKV2 may be unstable
      NpuOpRunner topkv2_runner2;
      topkv2_runner2.SetType("TopKV2")
          .AddInput(indices)
          .AddInput(out_size)
          .AddOutput(topkv2_out)
          .AddOutput(indices)
          .AddAttr("sorted", true)
          .AddAttr("dim", 0)
          .AddAttr("largest", false)
          .Run(stream);

      Tensor input_tmp;
      input_tmp.ShareDataWith(*input);
      input_tmp.Resize({input->numel()});
      const auto& gather_runner = NpuOpRunner(
          "GatherV2D", {input_tmp, topkv2_out}, {*out}, {{"axis", 0}});
      gather_runner.Run(stream);
    }
  }
};

template <typename T>
class MaskedSelectedGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto mask = ctx.Input<framework::Tensor>("Mask");
    auto y_grad = ctx.Input<framework::Tensor>(framework::GradVarName("Y"));
    auto x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));

    x_grad->mutable_data<T>(ctx.GetPlace());

    auto& dev_ctx =
        ctx.template device_context<paddle::platform::NPUDeviceContext>();
    auto stream = dev_ctx.stream();

    Tensor mask_int32, out_size;
    std::vector<int32_t> out_size_vec;
    mask_int32.mutable_data<int32_t>(mask->dims(), ctx.GetPlace());
    out_size.mutable_data<int32_t>({1}, ctx.GetPlace());
    {
      const auto& cast_runner =
          NpuOpRunner("Cast", {*mask}, {mask_int32},
                      {{"dst_type", static_cast<int32_t>(ConvertToNpuDtype(
                                        framework::proto::VarType::INT32))}});
      cast_runner.Run(stream);

      mask_int32.Resize({mask_int32.numel()});
      NpuOpRunner sum_runner;
      sum_runner.SetType("ReduceSum");
      sum_runner.AddInput(mask_int32);
      sum_runner.AddInput(std::vector<int32_t>({0}));
      sum_runner.AddOutput(out_size);
      sum_runner.AddAttr("keep_dims", false);
      sum_runner.Run(stream);
      TensorToVector(out_size, dev_ctx, &out_size_vec);
    }

    Tensor topkv2_out, indices;
    topkv2_out.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
    indices.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
    {
      NpuOpRunner topkv2_runner;
      topkv2_runner.SetType("TopKV2")
          .AddInput(mask_int32)
          .AddInput(out_size)
          .AddOutput(topkv2_out)
          .AddOutput(indices)
          .AddAttr("sorted", false)
          .AddAttr("dim", 0)
          .AddAttr("largest", true)
          .Run(stream);

      NpuOpRunner topkv2_runner2;
      topkv2_runner2.SetType("TopKV2")
          .AddInput(indices)
          .AddInput(out_size)
          .AddOutput(topkv2_out)
          .AddOutput(indices)
          .AddAttr("sorted", true)
          .AddAttr("dim", 0)
          .AddAttr("largest", false)
          .Run(stream);

      topkv2_out.Resize({out_size_vec[0], 1});
      x_grad->Resize({x_grad->numel()});
      NpuOpRunner scatter_runner;
      scatter_runner.SetType("ScatterNd");
      scatter_runner.AddInput(topkv2_out);
      scatter_runner.AddInput(*y_grad);
      scatter_runner.AddInput(
          std::vector<int32_t>({static_cast<int32_t>(x_grad->numel())}));
      scatter_runner.AddOutput(*x_grad);
      scatter_runner.Run(stream);
      x_grad->Resize(mask->dims());
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(masked_select,
                       ops::MaskedSelectedNPUKernel<plat::float16>,
                       ops::MaskedSelectedNPUKernel<float>,
                       ops::MaskedSelectedNPUKernel<int>,
                       ops::MaskedSelectedNPUKernel<int64_t>);
REGISTER_OP_NPU_KERNEL(masked_select_grad,
                       ops::MaskedSelectedGradNPUKernel<plat::float16>,
                       ops::MaskedSelectedGradNPUKernel<float>,
                       ops::MaskedSelectedGradNPUKernel<int>,
                       ops::MaskedSelectedGradNPUKernel<int64_t>);
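
Note on the kernel above: rather than calling a single dedicated masked-select operator, it composes one from CANN ops. Cast turns the bool mask into int32, ReduceSum counts the selected elements (k), two TopKV2 passes recover the flat positions of the ones in ascending order (the first pass is unsorted and may be unstable, hence the second), and GatherV2D / ScatterNd then gather the forward values or scatter the incoming gradient. A minimal NumPy sketch of the same index-based algorithm follows; it is an illustration only, not part of the commit, and the two helper names are made up.

    import numpy as np

    def masked_select_forward(x, mask):
        # ReduceSum over the int32 mask gives the number of selected elements (k).
        mask_int = mask.astype(np.int32).reshape(-1)
        k = int(mask_int.sum())
        # TopKV2 over the 0/1 mask yields the k flat positions where mask == 1;
        # the second TopKV2 pass sorts those positions in ascending order.
        idx = np.sort(np.argsort(-mask_int, kind="stable")[:k])
        # GatherV2D picks the selected elements from the flattened input.
        return x.reshape(-1)[idx], idx

    def masked_select_backward(y_grad, idx, mask_shape):
        # ScatterNd writes the incoming gradient back to the selected positions;
        # every other position of x_grad stays zero.
        x_grad = np.zeros(int(np.prod(mask_shape)), dtype=y_grad.dtype)
        x_grad[idx] = y_grad
        return x_grad.reshape(mask_shape)
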
Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid

paddle.enable_static()


def np_masked_select(x, mask):
    result = np.empty(shape=(0), dtype=x.dtype)
    for ele, ma in zip(np.nditer(x), np.nditer(mask)):
        if ma:
            result = np.append(result, ele)
    return result.flatten()


class TestMaskedSelectOp(OpTest):
    def set_npu(self):
        self.__class__.use_npu = True

    def setUp(self):
        self.set_npu()
        self.init()
        self.init_dtype()
        self.place = paddle.NPUPlace(0)
        self.op_type = "masked_select"
        x = np.random.random(self.shape).astype(self.dtype)
        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
        out = np_masked_select(x, mask)
        self.inputs = {'X': x, 'Mask': mask}
        self.outputs = {'Y': out}

    def test_check_output(self):
        self.check_output_with_place(self.place)

    def test_check_grad(self):
        self.check_grad_with_place(self.place, ['X'], 'Y')

    def init(self):
        self.shape = (50, 3)

    def init_dtype(self):
        self.dtype = np.float32


class TestMaskedSelectOp1(TestMaskedSelectOp):
    def init(self):
        self.shape = (6, 8, 9, 18)


class TestMaskedSelectOp2(TestMaskedSelectOp):
    def init(self):
        self.shape = (168, )


class TestMaskedSelectOpFp16(TestMaskedSelectOp):
    def init_dtype(self):
        self.dtype = np.float16

    def test_check_grad(self):
        x_grad = self.inputs['Mask'].astype(self.dtype)
        x_grad = x_grad * (1 / x_grad.sum())
        self.check_grad_with_place(
            self.place, ['X'], 'Y', user_defined_grads=[x_grad])


@skip_check_grad_ci(reason="get_numeric_gradient not support int32")
class TestMaskedSelectOpInt32(TestMaskedSelectOp):
    def init_dtype(self):
        self.dtype = np.int32

    def test_check_grad(self):
        pass


@skip_check_grad_ci(reason="get_numeric_gradient not support int64")
class TestMaskedSelectOpInt64(TestMaskedSelectOp):
    def init_dtype(self):
        self.dtype = np.int64

    def test_check_grad(self):
        pass


class TestMaskedSelectAPI(unittest.TestCase):
    def test_imperative_mode(self):
        paddle.disable_static(paddle.NPUPlace(0))
        shape = (88, 6, 8)
        np_x = np.random.random(shape).astype('float32')
        np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
        x = paddle.to_tensor(np_x)
        mask = paddle.to_tensor(np_mask)
        out = paddle.masked_select(x, mask)
        np_out = np_masked_select(np_x, np_mask)
        self.assertEqual(np.allclose(out.numpy(), np_out), True)
        paddle.enable_static()

    def test_static_mode(self):
        shape = [8, 9, 6]
        x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
        mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask')
        np_x = np.random.random(shape).astype('float32')
        np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))

        out = paddle.masked_select(x, mask)
        np_out = np_masked_select(np_x, np_mask)

        exe = paddle.static.Executor(place=paddle.NPUPlace(0))

        res = exe.run(paddle.static.default_main_program(),
                      feed={"x": np_x,
                            "mask": np_mask},
                      fetch_list=[out])
        self.assertEqual(np.allclose(res, np_out), True)


class TestMaskedSelectError(unittest.TestCase):
    def test_error(self):
        with paddle.static.program_guard(paddle.static.Program(),
                                         paddle.static.Program()):

            shape = [8, 9, 6]
            x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
            mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask')
            mask_float = paddle.fluid.data(
                shape=shape, dtype='float32', name='mask_float')
            np_x = np.random.random(shape).astype('float32')
            np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))

            def test_x_type():
                paddle.masked_select(np_x, mask)

            self.assertRaises(TypeError, test_x_type)

            def test_mask_type():
                paddle.masked_select(x, np_mask)

            self.assertRaises(TypeError, test_mask_type)

            def test_mask_dtype():
                paddle.masked_select(x, mask_float)

            self.assertRaises(TypeError, test_mask_dtype)


if __name__ == '__main__':
    unittest.main()
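
Note on the tests above: a minimal dynamic-graph usage sketch distilled from test_imperative_mode (illustration only; it assumes an Ascend NPU build of Paddle with device 0 available).

    import numpy as np
    import paddle

    paddle.disable_static(paddle.NPUPlace(0))  # run eagerly on NPU device 0
    x = paddle.to_tensor(np.arange(6, dtype="float32").reshape(2, 3))
    mask = paddle.to_tensor(np.array([[True, False, True], [False, True, False]]))
    out = paddle.masked_select(x, mask)        # 1-D tensor: [0., 2., 4.]
    paddle.enable_static()
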
