Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lite/kernels/host/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ add_kernel(share_data_compute_host Host extra SRCS share_data_compute.cc)
add_kernel(round_compute_host Host extra SRCS round_compute.cc)
add_kernel(temporal_shift_compute_host Host extra SRCS temporal_shift_compute.cc)
add_kernel(bitwise_compute_host Host extra SRCS bitwise_compute.cc)
add_kernel(empty_compute_host Host extra SRCS empty_compute.cc)


if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc)
Expand Down
16 changes: 16 additions & 0 deletions lite/kernels/host/activation_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,17 @@ void FloorCompute::Run() {
}
}

void CeilCompute::Run() {
auto& param = this->Param<param_t>();
CHECK(param.X);
auto x_dims = param.X->dims();
auto x_data = param.X->data<float>();
auto output_data = param.Out->mutable_data<float>();
for (int i = 0; i < x_dims.production(); i++) {
output_data[i] = std::ceil(x_data[i]);
}
}

void HardSigmoidCompute::Run() {
auto& param = this->Param<param_t>();
CHECK(param.X);
Expand Down Expand Up @@ -377,6 +388,11 @@ REGISTER_LITE_KERNEL(
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
REGISTER_LITE_KERNEL(
ceil, kHost, kFloat, kNCHW, paddle::lite::kernels::host::CeilCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
REGISTER_LITE_KERNEL(hard_sigmoid,
kHost,
kFloat,
Expand Down
9 changes: 9 additions & 0 deletions lite/kernels/host/activation_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ class FloorCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
virtual ~FloorCompute() = default;
};

class CeilCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;

void Run() override;

virtual ~CeilCompute() = default;
};

class HardSigmoidCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
Expand Down
60 changes: 60 additions & 0 deletions lite/kernels/host/empty_compute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/host/empty_compute.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

void EmptyCompute::Run() {
auto& param = *param_.get_mutable<param_t>();
auto output = param.Out;
auto output_dims = output->dims();
if (param.dtype == static_cast<int32_t>(lite::core::FluidType::BOOL)) {
output->set_precision(PRECISION(kBool));
output->template mutable_data<bool>();
} else if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
output->set_precision(PRECISION(kFloat));
output->template mutable_data<float>();
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT32)) {
output->set_precision(PRECISION(kInt32));
output->template mutable_data<int32_t>();
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT64)) {
output->set_precision(PRECISION(kInt64));
output->template mutable_data<int64_t>();
} else {
output->set_precision(PRECISION(kInt32));
output->template mutable_data<int32_t>();
}

return;
}

} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle

REGISTER_LITE_KERNEL(
empty, kHost, kAny, kNCHW, paddle::lite::kernels::host::EmptyCompute, def)
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
.BindInput("ShapeTensorList",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
.Finalize();
36 changes: 36 additions & 0 deletions lite/kernels/host/empty_compute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

class EmptyCompute : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
public:
using param_t = operators::EmptyParam;

void Run() override;

virtual ~EmptyCompute() = default;
};

} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
3 changes: 3 additions & 0 deletions lite/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ add_operator(pow_op extra SRCS pow_op.cc)
add_operator(sign_op extra SRCS sign_op.cc)
add_operator(rnn_op extra SRCS rnn_op.cc)

add_operator(empty_op extra SRCS empty_op.cc)

# 2.basic ops not used in basic models
add_operator(negative_op extra SRCS negative_op.cc)
add_operator(crop_op extra SRCS crop_op.cc)
Expand Down Expand Up @@ -212,6 +214,7 @@ add_operator(unique_with_counts_op extra SRCS unique_with_counts_op.cc)
add_operator(unique_op extra SRCS unique_op.cc)
add_operator(viterbi_decode extra SRCS viterbi_decode_op.cc)


# for content-dnn specific
add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc)
add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc)
Expand Down
1 change: 1 addition & 0 deletions lite/operators/activation_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(abs, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(ceil, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(sqrt, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp);
Expand Down
89 changes: 89 additions & 0 deletions lite/operators/empty_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/empty_op.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace operators {

bool EmptyOp::CheckShape() const {
CHECK_OR_FALSE(param_.Out);
return true;
}

bool EmptyOp::InferShapeImpl() const {
std::vector<int64_t> OutShape;
auto ShapeTensor = param_.ShapeTensor;
auto ShapeTensorList = param_.ShapeTensorList;
if (ShapeTensor != nullptr) {
auto ShapeTensorData = ShapeTensor->data<int>();
for (int i = 0; i < ShapeTensor->numel(); i++) {
OutShape.push_back(ShapeTensorData[i]);
}
} else if (!ShapeTensorList.empty()) {
for (size_t i = 0; i < ShapeTensorList.size(); i++) {
OutShape.push_back(ShapeTensorList[i]->data<int>()[0]);
}
} else if (!param_.shape.empty()) {
OutShape = param_.shape;
} else {
LOG(FATAL) << "no valid out_shape. Must set one of shape_tensor, or "
"shape_tensor_list, or shape.";
}

param_.Out->Resize(OutShape);
return true;
}

bool EmptyOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
if (opdesc.HasInput("ShapeTensor") && !opdesc.Input("ShapeTensor").empty()) {
param_.ShapeTensor =
scope->FindMutableTensor(opdesc.Input("ShapeTensor").front());
}
param_.ShapeTensorList.clear();
if (opdesc.HasInput("ShapeTensorList") &&
!opdesc.Input("ShapeTensorList").empty()) {
for (auto name : opdesc.Input("ShapeTensorList")) {
param_.ShapeTensorList.push_back(
GetMutableVar<lite::Tensor>(scope, name));
}
}
if (opdesc.HasAttr("shape")) {
auto type = opdesc.GetAttrType("shape");
if (type == OpAttrType::INTS) { // paddle1.0 shape type is ints
auto shape = opdesc.GetAttr<std::vector<int32_t>>("shape");
param_.shape.resize(shape.size());
for (int i = 0; i < shape.size(); i++) {
param_.shape[i] = shape[i];
}
} else {
param_.shape = opdesc.GetAttr<std::vector<int64_t>>("shape");
}
}
param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
CHECK(param_.Out) << "Output(Out) of EmptyOp should not be null.";
if (opdesc.HasAttr("dtype")) {
param_.dtype = opdesc.GetAttr<int>("dtype");
}

return true;
}

} // namespace operators
} // namespace lite
} // namespace paddle

REGISTER_LITE_OP(empty, paddle::lite::operators::EmptyOp);
44 changes: 44 additions & 0 deletions lite/operators/empty_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"

namespace paddle {
namespace lite {
namespace operators {

class EmptyOp : public OpLite {
public:
EmptyOp() {}
explicit EmptyOp(const std::string &op_type) : OpLite(op_type) {}

bool CheckShape() const override;

bool InferShapeImpl() const override;

bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;

void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "empty"; }

protected:
mutable EmptyParam param_;
};

} // namespace operators
} // namespace lite
} // namespace paddle
8 changes: 8 additions & 0 deletions lite/operators/op_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -2633,6 +2633,14 @@ struct TemporalShiftParam : ParamBase {
std::string data_format{"NCHW"};
};

struct EmptyParam : ParamBase {
lite::Tensor* ShapeTensor{nullptr};
std::vector<lite::Tensor*> ShapeTensorList{};
std::vector<int64_t> shape{};
int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
lite::Tensor* Out{};
};

struct ViterbiDecodeParam : ParamBase {
const lite::Tensor* input{};
const lite::Tensor* length{};
Expand Down
Loading