Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lite/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ add_operator(scatter extra SRCS scatter_op.cc)
add_operator(matrix_nms_op_lite extra SRCS matrix_nms_op.cc)
add_operator(sin_op extra SRCS sin_op.cc)
add_operator(cos_op extra SRCS cos_op.cc)
add_operator(tan_op extra SRCS tan_op.cc)
add_operator(cos_sim_op extra SRCS cos_sim_op.cc)
add_operator(asin_op extra SRCS asin_op.cc)
add_operator(acos_op extra SRCS acos_op.cc)
Expand Down
1 change: 1 addition & 0 deletions lite/operators/op_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -1998,6 +1998,7 @@ struct TrigonometricParam : ParamBase {

using SinParam = TrigonometricParam;
using CosParam = TrigonometricParam;
using TanParam = TrigonometricParam;
using AsinParam = TrigonometricParam;
using AcosParam = TrigonometricParam;
using AtanParam = TrigonometricParam;
Expand Down
49 changes: 49 additions & 0 deletions lite/operators/tan_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/tan_op.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace operators {

bool TanOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
return true;
}

bool TanOpLite::InferShape() {
lite::DDim x_dims = param_.X->dims();
param_.Out->Resize(x_dims);

return true;
}

bool TanOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front();

param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();

return true;
}

} // namespace operators
} // namespace lite
} // namespace paddle

REGISTER_LITE_OP(tan, paddle::lite::operators::TanOpLite);
47 changes: 47 additions & 0 deletions lite/operators/tan_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"

namespace paddle {
namespace lite {
namespace operators {

class TanOpLite : public OpLite {
public:
TanOpLite() {}
explicit TanOpLite(const std::string &op_type) : OpLite(op_type) {}

bool CheckShape() const override;

bool InferShape() override;

bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;

void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }

std::string DebugString() const override { return "tan"; }

private:
mutable TanParam param_;
};

} // namespace operators
} // namespace lite
} // namespace paddle
100 changes: 100 additions & 0 deletions lite/tests/unittest_py/op/test_tan_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
sys.path.append('../')

from auto_scan_test import AutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig, CxxConfig, TargetType, PrecisionType, DataLayoutType, Place
import unittest
from functools import partial
import numpy as np

import hypothesis
from hypothesis import given, settings, seed, example, assume
import hypothesis.strategies as st
import argparse


class TestTanOp(AutoScanTest):
def __init__(self, *args, **kwargs):
AutoScanTest.__init__(self, *args, **kwargs)
opencl_places = [
Place(TargetType.OpenCL, PrecisionType.FP16,
DataLayoutType.ImageDefault), Place(
TargetType.OpenCL, PrecisionType.FP16,
DataLayoutType.ImageFolder),
Place(TargetType.OpenCL, PrecisionType.FP32, DataLayoutType.NCHW),
Place(TargetType.OpenCL, PrecisionType.Any,
DataLayoutType.ImageDefault), Place(
TargetType.OpenCL, PrecisionType.Any,
DataLayoutType.ImageFolder),
Place(TargetType.OpenCL, PrecisionType.Any, DataLayoutType.NCHW),
Place(TargetType.Host, PrecisionType.FP32)
]
self.enable_testing_on_place(places=opencl_places)

def is_program_valid(self,
program_config: ProgramConfig,
predictor_config: CxxConfig) -> bool:
return True

def sample_program_configs(self, draw):
def generate_input(*args, **kwargs):
if kwargs["type"] == "int32":
return np.random.randint(kwargs["low"], kwargs["high"],
kwargs["shape"]).astype(np.int32)
elif kwargs["type"] == "int64":
return np.random.randint(kwargs["low"], kwargs["high"],
kwargs["shape"]).astype(np.int64)
elif kwargs["type"] == "float32":
return (kwargs["high"] - kwargs["low"]) * np.random.random(
kwargs["shape"]).astype(np.float32) + kwargs["low"]

in_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8), min_size=4, max_size=4))

tan_op = OpConfig(
type="tan",
inputs={"X": ["input_data"]},
outputs={"Out": ["output_data"]},
attrs={})
program_config = ProgramConfig(
ops=[tan_op],
weights={},
inputs={
"input_data": TensorConfig(data_gen=partial(
generate_input,
type="float32",
low=-0.9,
high=0.9,
shape=in_shape))
},
outputs=["output_data"])
return program_config

def sample_predictor_configs(self):
return self.get_predictor_configs(), ["tan"], (1e-5, 1e-5)

def add_ignore_pass_case(self):
pass

def test(self, *args, **kwargs):
self.run_and_statis(quant=False, max_examples=50)


if __name__ == "__main__":
unittest.main(argv=[''])