Skip to content

Commit a27cd1d

Browse files
committed
[xpu]:support equal int64 and transpose int64;test=develop
1 parent 8654144 commit a27cd1d

File tree

4 files changed

+105
-8
lines changed

4 files changed

+105
-8
lines changed

lite/kernels/x86/transpose_compute.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,25 @@ REGISTER_LITE_KERNEL(transpose2,
3434
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
3535
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))})
3636
.Finalize();
37+
38+
REGISTER_LITE_KERNEL(transpose,
39+
kX86,
40+
kFloat,
41+
kNCHW,
42+
paddle::lite::kernels::x86::TransposeCompute<int64_t>,
43+
int64)
44+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
45+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
46+
.Finalize();
47+
48+
REGISTER_LITE_KERNEL(transpose2,
49+
kX86,
50+
kFloat,
51+
kNCHW,
52+
paddle::lite::kernels::x86::Transpose2Compute<int64_t>,
53+
int64)
54+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
55+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
56+
.BindOutput("XShape",
57+
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
58+
.Finalize();

lite/kernels/xpu/compare_compute.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,18 @@ struct LessThanFunctor {
3434
}
3535
};
3636

37+
template <typename T>
38+
struct EqualFunctor {
39+
inline int operator()(xdnn::Context* ctx,
40+
const T* x,
41+
const T* y,
42+
bool* z,
43+
const std::vector<int>& xshape,
44+
const std::vector<int>& yshape) const {
45+
return xdnn::broadcast_equal<T>(ctx, x, y, z, xshape, yshape);
46+
}
47+
};
48+
3749
template <PrecisionType PType, class T, class Functor>
3850
void CompareCompute<PType, T, Functor>::Run() {
3951
auto& param = this->template Param<operators::CompareParam>();
@@ -152,3 +164,63 @@ REGISTER_LITE_KERNEL(less_than, kXPU, kFloat, kAny, less_than_int64, int64)
152164
DATALAYOUT(kAny))})
153165
.BindPaddleOpVersion("less_than", 1)
154166
.Finalize();
167+
168+
using equal_float = paddle::lite::kernels::xpu::CompareCompute<
169+
PRECISION(kFloat),
170+
float,
171+
paddle::lite::kernels::xpu::EqualFunctor<float>>;
172+
REGISTER_LITE_KERNEL(equal, kXPU, kFloat, kAny, equal_float, def)
173+
.BindInput("X",
174+
{LiteType::GetTensorTy(TARGET(kXPU),
175+
PRECISION(kFloat),
176+
DATALAYOUT(kAny))})
177+
.BindInput("Y",
178+
{LiteType::GetTensorTy(TARGET(kXPU),
179+
PRECISION(kFloat),
180+
DATALAYOUT(kAny))})
181+
.BindOutput("Out",
182+
{LiteType::GetTensorTy(TARGET(kXPU),
183+
PRECISION(kBool),
184+
DATALAYOUT(kAny))})
185+
.BindPaddleOpVersion("equal", 1)
186+
.Finalize();
187+
188+
using equal_int32 = paddle::lite::kernels::xpu::CompareCompute<
189+
PRECISION(kFloat),
190+
int,
191+
paddle::lite::kernels::xpu::EqualFunctor<int>>;
192+
REGISTER_LITE_KERNEL(equal, kXPU, kFloat, kAny, equal_int32, int32)
193+
.BindInput("X",
194+
{LiteType::GetTensorTy(TARGET(kXPU),
195+
PRECISION(kInt32),
196+
DATALAYOUT(kAny))})
197+
.BindInput("Y",
198+
{LiteType::GetTensorTy(TARGET(kXPU),
199+
PRECISION(kInt32),
200+
DATALAYOUT(kAny))})
201+
.BindOutput("Out",
202+
{LiteType::GetTensorTy(TARGET(kXPU),
203+
PRECISION(kBool),
204+
DATALAYOUT(kAny))})
205+
.BindPaddleOpVersion("equal", 1)
206+
.Finalize();
207+
208+
using euqal_int64 = paddle::lite::kernels::xpu::CompareCompute<
209+
PRECISION(kFloat),
210+
int64_t,
211+
paddle::lite::kernels::xpu::EqualFunctor<int64_t>>;
212+
REGISTER_LITE_KERNEL(equal, kXPU, kFloat, kAny, euqal_int64, int64)
213+
.BindInput("X",
214+
{LiteType::GetTensorTy(TARGET(kXPU),
215+
PRECISION(kInt64),
216+
DATALAYOUT(kAny))})
217+
.BindInput("Y",
218+
{LiteType::GetTensorTy(TARGET(kXPU),
219+
PRECISION(kInt64),
220+
DATALAYOUT(kAny))})
221+
.BindOutput("Out",
222+
{LiteType::GetTensorTy(TARGET(kXPU),
223+
PRECISION(kBool),
224+
DATALAYOUT(kAny))})
225+
.BindPaddleOpVersion("equal", 1)
226+
.Finalize();

lite/kernels/xpu/transpose_compute.cc

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ namespace lite {
2222
namespace kernels {
2323
namespace xpu {
2424

25-
void TransposeCompute::Run() {
25+
template <class T>
26+
void TransposeCompute<T>::Run() {
2627
auto& param = this->Param<param_t>();
2728
auto& ctx = this->ctx_->As<XPUContext>();
2829
auto x = param.x;
@@ -38,10 +39,10 @@ void TransposeCompute::Run() {
3839
for (int i = 0; i < ndims; ++i) {
3940
x_shape_host[i] = x_dims[i];
4041
}
41-
int r =
42-
xdnn::transpose<float>(ctx.GetRawContext(),
43-
x->data<float>(),
44-
param.output->mutable_data<float>(TARGET(kXPU)),
42+
43+
int r = xdnn::transpose<T>(ctx.GetRawContext(),
44+
x->data<T>(),
45+
param.output->mutable_data<T>(TARGET(kXPU)),
4546
x_shape_host,
4647
axis);
4748
CHECK_EQ(r, 0);
@@ -56,7 +57,7 @@ REGISTER_LITE_KERNEL(transpose,
5657
kXPU,
5758
kFloat,
5859
kNCHW,
59-
paddle::lite::kernels::xpu::TransposeCompute,
60+
paddle::lite::kernels::xpu::TransposeCompute<float>,
6061
def)
6162
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
6263
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
@@ -66,17 +67,18 @@ REGISTER_LITE_KERNEL(transpose2,
6667
kXPU,
6768
kFloat,
6869
kNCHW,
69-
paddle::lite::kernels::xpu::TransposeCompute,
70+
paddle::lite::kernels::xpu::TransposeCompute<float>,
7071
def)
7172
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
7273
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
7374
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
7475
.Finalize();
76+
7577
REGISTER_LITE_KERNEL(transpose2,
7678
kXPU,
7779
kFloat,
7880
kNCHW,
79-
paddle::lite::kernels::xpu::TransposeCompute,
81+
paddle::lite::kernels::xpu::TransposeCompute<int64_t>,
8082
def_int64)
8183
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
8284
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})

lite/kernels/xpu/transpose_compute.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace lite {
2121
namespace kernels {
2222
namespace xpu {
2323

24+
template <class T>
2425
class TransposeCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
2526
public:
2627
using param_t = operators::TransposeParam;

0 commit comments

Comments
 (0)