Skip to content

Commit 8cf4a5d

Browse files
committed
[NPU] add meshgrid, test=develop
1 parent e7dcdb7 commit 8cf4a5d

File tree

2 files changed

+300
-0
lines changed

2 files changed

+300
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the Licnse. */
14+
15+
#include "paddle/fluid/operators/meshgrid_op.h"
16+
#include "paddle/fluid/operators/npu_op_runner.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
template <typename DeviceContext, typename T>
22+
class MeshgridNPUKernel : public framework::OpKernel<T> {
23+
public:
24+
void Compute(const framework::ExecutionContext& context) const override {
25+
auto ins = context.MultiInput<framework::Tensor>("X");
26+
auto outs = context.MultiOutput<framework::Tensor>("Out");
27+
PADDLE_ENFORCE_EQ(
28+
(ins.size() > 1) && (ins.size() < 7), true,
29+
platform::errors::InvalidArgument(
30+
"Excepted Tensor numbers between 2 and 6, but only received d% .",
31+
ins.size()));
32+
33+
int64_t size = ins.size();
34+
std::vector<int64_t> shape(size);
35+
36+
for (int64_t i = 0; i < size; i++) {
37+
switch (ins[i]->dims().size()) {
38+
case 0:
39+
shape[i] = 1;
40+
break;
41+
case 1:
42+
shape[i] = ins[i]->dims()[0];
43+
break;
44+
default:
45+
PADDLE_THROW(platform::errors::InvalidArgument(
46+
"Expected scalar or 1D tensor in the tensor list but got tensor "
47+
"%d: ",
48+
i));
49+
}
50+
}
51+
52+
for (int64_t i = 0; i < size; i++) {
53+
std::vector<int64_t> view_shape(size, 1);
54+
view_shape[i] = shape[i];
55+
56+
framework::DDim out_dims_reshape = framework::make_ddim(view_shape);
57+
framework::Tensor reshape_ins_tensor(ins[i]->type());
58+
reshape_ins_tensor.ShareDataWith(*ins[i]);
59+
reshape_ins_tensor.Resize(out_dims_reshape);
60+
61+
framework::DDim out_dims = framework::make_ddim(shape);
62+
outs[i]->Resize(out_dims);
63+
outs[i]->mutable_data<T>(context.GetPlace());
64+
65+
auto stream =
66+
context.template device_context<paddle::platform::NPUDeviceContext>()
67+
.stream();
68+
const auto& runner = NpuOpRunner("BroadcastToD", {reshape_ins_tensor},
69+
{*(outs[i])}, {{"shape", shape}});
70+
runner.Run(stream);
71+
}
72+
}
73+
};
74+
75+
} // namespace operators
76+
} // namespace paddle
77+
78+
namespace ops = paddle::operators;
79+
namespace plat = paddle::platform;
80+
81+
REGISTER_OP_NPU_KERNEL(
82+
meshgrid, ops::MeshgridNPUKernel<plat::NPUDeviceContext, float>,
83+
ops::MeshgridNPUKernel<plat::NPUDeviceContext, plat::float16>,
84+
ops::MeshgridNPUKernel<plat::NPUDeviceContext, int32_t>);
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import unittest
18+
import numpy as np
19+
import sys
20+
sys.path.append("..")
21+
from op_test import OpTest, skip_check_grad_ci
22+
import paddle.fluid as fluid
23+
import paddle
24+
from paddle.fluid import compiler, Program, program_guard, core
25+
26+
paddle.enable_static()
27+
28+
29+
class TestMeshgridOp(OpTest):
30+
def setUp(self):
31+
self.set_npu()
32+
self.op_type = "meshgrid"
33+
self.dtype = self.get_dtype()
34+
ins, outs = self.init_test_data()
35+
self.inputs = {'X': [('x%d' % i, ins[i]) for i in range(len(ins))]}
36+
self.outputs = {
37+
'Out': [('out%d' % i, outs[i]) for i in range(len(outs))]
38+
}
39+
40+
def set_npu(self):
41+
self.__class__.use_npu = True
42+
self.place = paddle.NPUPlace(0)
43+
44+
def get_dtype(self):
45+
return "float32"
46+
47+
def test_check_output(self):
48+
self.check_output_with_place(self.place)
49+
50+
def test_check_grad(self):
51+
pass
52+
53+
def init_test_data(self):
54+
self.shape = self.get_x_shape()
55+
ins = []
56+
outs = []
57+
for i in range(len(self.shape)):
58+
ins.append(np.random.random((self.shape[i], )).astype(self.dtype))
59+
60+
for i in range(len(self.shape)):
61+
out_reshape = [1] * len(self.shape)
62+
out_reshape[i] = self.shape[i]
63+
out_temp = np.reshape(ins[i], out_reshape)
64+
outs.append(np.broadcast_to(out_temp, self.shape))
65+
return ins, outs
66+
67+
def get_x_shape(self):
68+
return [100, 200]
69+
70+
71+
@skip_check_grad_ci(
72+
reason="The backward test is not supported for float16 type on NPU.")
73+
class TestMeshgridOpFP16(TestMeshgridOp):
74+
def get_dtype(self):
75+
return "float16"
76+
77+
78+
class TestMeshgridOp2(TestMeshgridOp):
79+
def get_x_shape(self):
80+
return [100, 300]
81+
82+
83+
class TestMeshgridOp3(unittest.TestCase):
84+
def test_api(self):
85+
x = fluid.data(shape=[100], dtype='int32', name='x')
86+
y = fluid.data(shape=[200], dtype='int32', name='y')
87+
88+
input_1 = np.random.randint(0, 100, [100, ]).astype('int32')
89+
input_2 = np.random.randint(0, 100, [200, ]).astype('int32')
90+
91+
out_1 = np.reshape(input_1, [100, 1])
92+
out_1 = np.broadcast_to(out_1, [100, 200])
93+
out_2 = np.reshape(input_2, [1, 200])
94+
out_2 = np.broadcast_to(out_2, [100, 200])
95+
96+
exe = fluid.Executor(place=fluid.NPUPlace(0))
97+
grid_x, grid_y = paddle.tensor.meshgrid(x, y)
98+
res_1, res_2 = exe.run(fluid.default_main_program(),
99+
feed={'x': input_1,
100+
'y': input_2},
101+
fetch_list=[grid_x, grid_y])
102+
103+
self.assertTrue(np.allclose(res_1, out_1))
104+
self.assertTrue(np.allclose(res_2, out_2))
105+
106+
107+
class TestMeshgridOp4(unittest.TestCase):
108+
def test_list_input(self):
109+
x = fluid.data(shape=[100], dtype='int32', name='x')
110+
y = fluid.data(shape=[200], dtype='int32', name='y')
111+
112+
input_1 = np.random.randint(0, 100, [100, ]).astype('int32')
113+
input_2 = np.random.randint(0, 100, [200, ]).astype('int32')
114+
115+
out_1 = np.reshape(input_1, [100, 1])
116+
out_1 = np.broadcast_to(out_1, [100, 200])
117+
out_2 = np.reshape(input_2, [1, 200])
118+
out_2 = np.broadcast_to(out_2, [100, 200])
119+
120+
exe = fluid.Executor(place=fluid.NPUPlace(0))
121+
grid_x, grid_y = paddle.tensor.meshgrid([x, y])
122+
res_1, res_2 = exe.run(fluid.default_main_program(),
123+
feed={'x': input_1,
124+
'y': input_2},
125+
fetch_list=[grid_x, grid_y])
126+
127+
self.assertTrue(np.allclose(res_1, out_1))
128+
self.assertTrue(np.allclose(res_2, out_2))
129+
130+
131+
class TestMeshgridOp5(unittest.TestCase):
132+
def test_tuple_input(self):
133+
x = fluid.data(shape=[100], dtype='int32', name='x')
134+
y = fluid.data(shape=[200], dtype='int32', name='y')
135+
136+
input_1 = np.random.randint(0, 100, [100, ]).astype('int32')
137+
input_2 = np.random.randint(0, 100, [200, ]).astype('int32')
138+
139+
out_1 = np.reshape(input_1, [100, 1])
140+
out_1 = np.broadcast_to(out_1, [100, 200])
141+
out_2 = np.reshape(input_2, [1, 200])
142+
out_2 = np.broadcast_to(out_2, [100, 200])
143+
144+
exe = fluid.Executor(place=fluid.NPUPlace(0))
145+
grid_x, grid_y = paddle.tensor.meshgrid((x, y))
146+
res_1, res_2 = exe.run(fluid.default_main_program(),
147+
feed={'x': input_1,
148+
'y': input_2},
149+
fetch_list=[grid_x, grid_y])
150+
151+
self.assertTrue(np.allclose(res_1, out_1))
152+
self.assertTrue(np.allclose(res_2, out_2))
153+
154+
155+
class TestMeshgridOp6(unittest.TestCase):
156+
def test_api_with_dygraph(self):
157+
paddle.disable_static(paddle.NPUPlace(0))
158+
input_3 = np.random.randint(0, 100, [100, ]).astype('int32')
159+
input_4 = np.random.randint(0, 100, [200, ]).astype('int32')
160+
161+
out_3 = np.reshape(input_3, [100, 1])
162+
out_3 = np.broadcast_to(out_3, [100, 200])
163+
out_4 = np.reshape(input_4, [1, 200])
164+
out_4 = np.broadcast_to(out_4, [100, 200])
165+
166+
tensor_3 = paddle.to_tensor(input_3)
167+
tensor_4 = paddle.to_tensor(input_4)
168+
res_3, res_4 = paddle.tensor.meshgrid(tensor_3, tensor_4)
169+
170+
self.assertTrue(np.allclose(res_3.numpy(), out_3))
171+
self.assertTrue(np.allclose(res_4.numpy(), out_4))
172+
paddle.enable_static()
173+
174+
175+
class TestMeshgridOp7(unittest.TestCase):
176+
def test_api_with_dygraph_list_input(self):
177+
paddle.disable_static(paddle.NPUPlace(0))
178+
input_3 = np.random.randint(0, 100, [100, ]).astype('int32')
179+
input_4 = np.random.randint(0, 100, [200, ]).astype('int32')
180+
181+
out_3 = np.reshape(input_3, [100, 1])
182+
out_3 = np.broadcast_to(out_3, [100, 200])
183+
out_4 = np.reshape(input_4, [1, 200])
184+
out_4 = np.broadcast_to(out_4, [100, 200])
185+
186+
tensor_3 = paddle.to_tensor(input_3)
187+
tensor_4 = paddle.to_tensor(input_4)
188+
res_3, res_4 = paddle.meshgrid([tensor_3, tensor_4])
189+
190+
self.assertTrue(np.allclose(res_3.numpy(), out_3))
191+
self.assertTrue(np.allclose(res_4.numpy(), out_4))
192+
paddle.enable_static()
193+
194+
195+
class TestMeshgridOp8(unittest.TestCase):
196+
def test_api_with_dygraph_tuple_input(self):
197+
paddle.disable_static(paddle.NPUPlace(0))
198+
input_3 = np.random.randint(0, 100, [100, ]).astype('int32')
199+
input_4 = np.random.randint(0, 100, [200, ]).astype('int32')
200+
201+
out_3 = np.reshape(input_3, [100, 1])
202+
out_3 = np.broadcast_to(out_3, [100, 200])
203+
out_4 = np.reshape(input_4, [1, 200])
204+
out_4 = np.broadcast_to(out_4, [100, 200])
205+
206+
tensor_3 = paddle.to_tensor(input_3)
207+
tensor_4 = paddle.to_tensor(input_4)
208+
res_3, res_4 = paddle.tensor.meshgrid((tensor_3, tensor_4))
209+
210+
self.assertTrue(np.allclose(res_3.numpy(), out_3))
211+
self.assertTrue(np.allclose(res_4.numpy(), out_4))
212+
paddle.enable_static()
213+
214+
215+
if __name__ == '__main__':
216+
unittest.main()

0 commit comments

Comments
 (0)