
Commit b679e66

squeeze and unsqueeze op for ascend
1 parent 83f81eb

4 files changed: 246 additions, 0 deletions

Lines changed: 42 additions & 0 deletions (squeeze NPU kernel registration)
@@ -0,0 +1,42 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_ASCEND_CL
#include <memory>
#include <string>

#include "paddle/fluid/operators/squeeze_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(
    squeeze,
    ops::SqueezeKernel<plat::NPUDeviceContext, float>,
    ops::SqueezeKernel<plat::NPUDeviceContext, double>,
    ops::SqueezeKernel<plat::NPUDeviceContext, plat::float16>,
    ops::SqueezeKernel<plat::NPUDeviceContext, bool>,
    ops::SqueezeKernel<plat::NPUDeviceContext, int>,
    ops::SqueezeKernel<plat::NPUDeviceContext, uint8_t>,
    ops::SqueezeKernel<plat::NPUDeviceContext, int8_t>,
    ops::SqueezeKernel<plat::NPUDeviceContext, int64_t>);
REGISTER_OP_NPU_KERNEL(
    squeeze2,
    ops::SqueezeKernel<plat::NPUDeviceContext, float>,
    ops::SqueezeKernel<plat::NPUDeviceContext, double>,
    ops::SqueezeKernel<plat::NPUDeviceContext, plat::float16>,
    ops::SqueezeKernel<plat::NPUDeviceContext, bool>,
    ops::SqueezeKernel<plat::NPUDeviceContext, int>,
    ops::SqueezeKernel<plat::NPUDeviceContext, uint8_t>,
    ops::SqueezeKernel<plat::NPUDeviceContext, int8_t>,
    ops::SqueezeKernel<plat::NPUDeviceContext, int64_t>);
#endif
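Both squeeze and squeeze2 are bound here to the same templated ops::SqueezeKernel from squeeze_op.h, so on the NPU the op amounts to a shape change over unmodified data. As a point of reference for the shape rule that the tests below verify, here is a minimal standalone sketch; SqueezedShape is a hypothetical helper for illustration only, not code from this change. Dimensions listed in axes are dropped only when their extent is 1, and an empty axes drops every size-1 dimension.

// Hypothetical standalone illustration (not Paddle code): squeezing is purely
// a shape computation; the underlying data is untouched.
#include <cstdint>
#include <vector>

std::vector<int64_t> SqueezedShape(const std::vector<int64_t>& dims,
                                   const std::vector<int>& axes) {
  std::vector<bool> drop(dims.size(), false);
  if (axes.empty()) {
    // No axes given: drop every dimension of extent 1.
    for (size_t i = 0; i < dims.size(); ++i) drop[i] = (dims[i] == 1);
  } else {
    for (int a : axes) {
      int idx = a < 0 ? a + static_cast<int>(dims.size()) : a;
      if (dims[idx] == 1) drop[idx] = true;  // only size-1 axes are removed
    }
  }
  std::vector<int64_t> out;
  for (size_t i = 0; i < dims.size(); ++i) {
    if (!drop[i]) out.push_back(dims[i]);
  }
  return out;  // e.g. {1, 10, 1} with axes {2} -> {1, 10}
}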
Lines changed: 82 additions & 0 deletions (squeeze NPU unit test)
@@ -0,0 +1,82 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef _WIN32
#include <unistd.h>
#endif

#include <string>
#include <thread>  // NOLINT
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"

namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;

USE_OP(squeeze);
USE_OP_DEVICE_KERNEL(squeeze, NPU);

template <typename T>
void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
  // init
  auto x = scope->Var("X");
  auto tensor_x = x->GetMutable<f::LoDTensor>();

  int dim0 = 1;
  int dim1 = 10;
  int dim2 = 1;

  std::vector<T> init;
  for (int64_t i = 0; i < dim0 * dim1 * dim2; ++i) {
    init.push_back(static_cast<T>(0.1));
  }

  TensorFromVector(init, ctx, tensor_x);
  tensor_x->Resize({dim0, dim1, dim2});

  ctx.Wait();

  // run
  auto place = ctx.GetPlace();
  auto out = scope->Var("Out");
  auto tensor_out = out->GetMutable<f::LoDTensor>();

  std::vector<int> axis;
  axis.push_back(2);
  f::AttributeMap attrs = {{"axes", axis}};

  auto op =
      f::OpRegistry::CreateOp("squeeze", {{"X", {"X"}}},
                              {{"Out", {"Out"}}}, attrs);

  op->Run(*scope, place);
  ctx.Wait();

  EXPECT_EQ((uint32_t)tensor_out->dims().size(), uint32_t(2));
  EXPECT_EQ((uint32_t)tensor_out->dims()[0], uint32_t(dim0));
  EXPECT_EQ((uint32_t)tensor_out->dims()[1], uint32_t(dim1));

  ctx.Wait();
}

TEST(squeeze, NPU_fp32) {
  f::Scope scope;
  p::NPUDeviceContext ctx(p::NPUPlace(0));
  Compare<float>(&scope, ctx);
}
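The harness above squeezes a single trailing axis. A variant covering multiple axes in one call could follow the same pattern, as in the sketch below; CompareMultiAxis and the test name are hypothetical, the sketch is not part of this change, and it assumes the same includes and namespace aliases as the file above. With axes = {0, 2}, the {1, 10, 1} input should collapse to a rank-1 tensor of length 10.

// Hypothetical follow-up test (not part of this commit): squeeze both
// singleton dimensions of a {1, 10, 1} input in one call.
template <typename T>
void CompareMultiAxis(f::Scope* scope, const p::DeviceContext& ctx) {
  auto x = scope->Var("X");
  auto tensor_x = x->GetMutable<f::LoDTensor>();

  std::vector<T> init(10, static_cast<T>(0.1));
  TensorFromVector(init, ctx, tensor_x);
  tensor_x->Resize({1, 10, 1});
  ctx.Wait();

  auto place = ctx.GetPlace();
  auto out = scope->Var("Out");
  auto tensor_out = out->GetMutable<f::LoDTensor>();

  // Remove both singleton axes at once.
  f::AttributeMap attrs = {{"axes", std::vector<int>({0, 2})}};
  auto op = f::OpRegistry::CreateOp("squeeze", {{"X", {"X"}}},
                                    {{"Out", {"Out"}}}, attrs);
  op->Run(*scope, place);
  ctx.Wait();

  // {1, 10, 1} with axes {0, 2} should collapse to a rank-1 tensor of length 10.
  EXPECT_EQ((uint32_t)tensor_out->dims().size(), uint32_t(1));
  EXPECT_EQ((uint32_t)tensor_out->dims()[0], uint32_t(10));
}

TEST(squeeze, NPU_fp32_multi_axis) {
  f::Scope scope;
  p::NPUDeviceContext ctx(p::NPUPlace(0));
  CompareMultiAxis<float>(&scope, ctx);
}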
Lines changed: 40 additions & 0 deletions (unsqueeze NPU kernel registration)
@@ -0,0 +1,40 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_ASCEND_CL
#include <memory>
#include <string>

#include "paddle/fluid/operators/unsqueeze_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(
    unsqueeze,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, float>,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, double>,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, plat::float16>,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, bool>,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, int>,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, int8_t>,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, int64_t>);
REGISTER_OP_NPU_KERNEL(
    unsqueeze2,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, float>,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, double>,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, plat::float16>,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, bool>,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, int>,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, int8_t>,
    ops::UnsqueezeKernel<plat::NPUDeviceContext, int64_t>);
#endif
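As with squeeze, both unsqueeze and unsqueeze2 reuse the templated ops::UnsqueezeKernel; note that this list omits the uint8_t instantiation that the squeeze registration includes. The shape rule is the mirror image of squeeze, as in the standalone sketch below; UnsqueezedShape is a hypothetical helper for illustration only, not code from this change. Each entry in axes receives a new size-1 dimension in the output, and the input dimensions fill the remaining slots in order.

// Hypothetical standalone illustration (not Paddle code): unsqueeze inserts
// size-1 dimensions at the requested axes of the output shape.
#include <cstdint>
#include <vector>

std::vector<int64_t> UnsqueezedShape(const std::vector<int64_t>& dims,
                                     const std::vector<int>& axes) {
  size_t out_rank = dims.size() + axes.size();
  std::vector<int64_t> out(out_rank, 0);
  for (int a : axes) {
    int idx = a < 0 ? a + static_cast<int>(out_rank) : a;
    out[idx] = 1;  // mark the inserted singleton dimensions
  }
  size_t src = 0;
  for (size_t i = 0; i < out_rank; ++i) {
    if (out[i] == 0) out[i] = dims[src++];  // fill remaining slots in order
  }
  return out;  // e.g. {5, 10} with axes {1} -> {5, 1, 10}
}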
Lines changed: 82 additions & 0 deletions (unsqueeze NPU unit test)
@@ -0,0 +1,82 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef _WIN32
#include <unistd.h>
#endif

#include <string>
#include <thread>  // NOLINT
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"

namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;

USE_OP(unsqueeze);
USE_OP_DEVICE_KERNEL(unsqueeze, NPU);

template <typename T>
void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
  // init
  auto x = scope->Var("X");
  auto tensor_x = x->GetMutable<f::LoDTensor>();

  int dim0 = 5;
  int dim1 = 10;

  std::vector<T> init;
  for (int64_t i = 0; i < dim0 * dim1; ++i) {
    init.push_back(static_cast<T>(10.0));
  }

  TensorFromVector(init, ctx, tensor_x);
  tensor_x->Resize({dim0, dim1});

  ctx.Wait();

  // run
  auto place = ctx.GetPlace();
  auto out = scope->Var("Out");
  auto tensor_out = out->GetMutable<f::LoDTensor>();

  std::vector<int> axis;
  axis.push_back(1);
  f::AttributeMap attrs = {{"axes", axis}};

  auto op =
      f::OpRegistry::CreateOp("unsqueeze", {{"X", {"X"}}},
                              {{"Out", {"Out"}}}, attrs);

  op->Run(*scope, place);
  ctx.Wait();

  EXPECT_EQ((uint32_t)tensor_out->dims().size(), uint32_t(3));
  EXPECT_EQ((uint32_t)tensor_out->dims()[0], uint32_t(5));
  EXPECT_EQ((uint32_t)tensor_out->dims()[1], uint32_t(1));
  EXPECT_EQ((uint32_t)tensor_out->dims()[2], uint32_t(10));

  ctx.Wait();
}

TEST(unsqueeze, NPU_fp32) {
  f::Scope scope;
  p::NPUDeviceContext ctx(p::NPUPlace(0));
  Compare<float>(&scope, ctx);
}
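The test above inserts a middle dimension. A companion check for a leading dimension could follow the same pattern, as in the sketch below; CompareLeadingAxis and the test name are hypothetical, the sketch is not part of this change, and it assumes the same includes and namespace aliases as the file above. With axes = {0}, the {5, 10} input should become {1, 5, 10}.

// Hypothetical follow-up test (not part of this commit): insert a new
// leading dimension instead of a middle one.
template <typename T>
void CompareLeadingAxis(f::Scope* scope, const p::DeviceContext& ctx) {
  auto x = scope->Var("X");
  auto tensor_x = x->GetMutable<f::LoDTensor>();

  std::vector<T> init(5 * 10, static_cast<T>(10.0));
  TensorFromVector(init, ctx, tensor_x);
  tensor_x->Resize({5, 10});
  ctx.Wait();

  auto place = ctx.GetPlace();
  auto out = scope->Var("Out");
  auto tensor_out = out->GetMutable<f::LoDTensor>();

  // Insert the new dimension in front of the existing ones.
  f::AttributeMap attrs = {{"axes", std::vector<int>({0})}};
  auto op = f::OpRegistry::CreateOp("unsqueeze", {{"X", {"X"}}},
                                    {{"Out", {"Out"}}}, attrs);
  op->Run(*scope, place);
  ctx.Wait();

  // {5, 10} unsqueezed at axis 0 should become {1, 5, 10}.
  EXPECT_EQ((uint32_t)tensor_out->dims().size(), uint32_t(3));
  EXPECT_EQ((uint32_t)tensor_out->dims()[0], uint32_t(1));
  EXPECT_EQ((uint32_t)tensor_out->dims()[1], uint32_t(5));
  EXPECT_EQ((uint32_t)tensor_out->dims()[2], uint32_t(10));
}

TEST(unsqueeze, NPU_fp32_leading_axis) {
  f::Scope scope;
  p::NPUDeviceContext ctx(p::NPUPlace(0));
  CompareLeadingAxis<float>(&scope, ctx);
}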
