Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lite/backends/host/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ lite_cc_library(math_host SRCS
inverse.cc
reverse.cc
topk.cc
temporal_shift.cc
DEPS core)
86 changes: 86 additions & 0 deletions lite/backends/host/math/temporal_shift.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/host/math/temporal_shift.h"
#include <algorithm>

namespace paddle {
namespace lite {
namespace host {
namespace math {

template <>
void temporalshiftNCHW_func<float>(const float* input,
float* output,
const int ntchw,
const int tchw,
const int chw,
const int hw,
const int t,
const int c1,
const int c2) {
int src_it = 0;
for (int i = 0; i < ntchw; i++) {
int it = (i % tchw) / chw;
int ic = (i % chw) / hw;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[i] = 0;
} else {
output[i] = input[i + (src_it - it) * chw];
}
}
}

template <>
void temporalshiftNHWC_func<float>(const float* input,
float* output,
const int nthwc,
const int thwc,
const int hwc,
const int t,
const int c,
const int c1,
const int c2) {
int src_it = 0;
for (int i = 0; i < nthwc; i++) {
int it = (i % thwc) / hwc;
int ic = i % c;

if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}

if (src_it < 0 || src_it >= t) {
output[i] = 0;
} else {
output[i] = input[i + (src_it - it) * hwc];
}
}
}

} // namespace math
} // namespace host
} // namespace lite
} // namespace paddle
46 changes: 46 additions & 0 deletions lite/backends/host/math/temporal_shift.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

namespace paddle {
namespace lite {
namespace host {
namespace math {

template <typename InType>
void temporalshiftNCHW_func(const InType* input,
InType* output,
const int ntchw,
const int tchw,
const int chw,
const int hw,
const int t,
const int c1,
const int c2);

template <typename InType>
void temporalshiftNHWC_func(const InType* input,
InType* output,
const int nthwc,
const int thwc,
const int hwc,
const int t,
const int c,
const int c1,
const int c2);
} // namespace math
} // namespace host
} // namespace lite
} // namespace paddle
2 changes: 2 additions & 0 deletions lite/kernels/host/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ add_kernel(roll_compute Host extra SRCS roll_compute.cc)
add_kernel(set_value Host extra SRCS set_value_compute.cc)
add_kernel(share_data_compute_host Host extra SRCS share_data_compute.cc)
add_kernel(round_compute_host Host extra SRCS round_compute.cc)
add_kernel(temporal_shift_compute_host Host extra SRCS temporal_shift_compute.cc)


if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc)
Expand Down
93 changes: 93 additions & 0 deletions lite/kernels/host/temporal_shift_compute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/host/temporal_shift_compute.h"
#include <string>
#include "lite/backends/host/math/temporal_shift.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

template <>
void TemporalShiftCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = Param<operators::TemporalShiftParam>();
const lite::Tensor* input = param.X;
lite::Tensor* output = param.Out;
int t = param.seg_num;
float shift_ratio = param.shift_ratio;
DataLayoutType data_layout;
if (param.data_format == "NCHW") {
data_layout = DATALAYOUT(kNCHW);
} else if (param.data_format == "NHWC") {
data_layout = DATALAYOUT(kNHWC);
} else {
LOG(FATAL) << "Unknown datalayout";
}

auto input_dims = input->dims();
const int nt = input_dims[0];
const int c =
data_layout == DATALAYOUT(kNCHW) ? input_dims[1] : input_dims[3];
const int h =
data_layout == DATALAYOUT(kNCHW) ? input_dims[2] : input_dims[1];
const int w =
data_layout == DATALAYOUT(kNCHW) ? input_dims[3] : input_dims[2];

const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const int ntchw = nt * chw;

const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);

DDim out_dims;
if (data_layout == DATALAYOUT(kNCHW)) {
out_dims.ConstructFrom({nt, c, h, w});
} else {
out_dims.ConstructFrom({nt, h, w, c});
}

const float* input_data = input->data<float>();
output->Resize(out_dims);
float* output_data = output->mutable_data<float>();

if (data_layout == DATALAYOUT(kNCHW)) {
lite::host::math::temporalshiftNCHW_func(
input_data, output_data, ntchw, tchw, chw, hw, t, c1, c2);
} else {
lite::host::math::temporalshiftNHWC_func(
input_data, output_data, ntchw, tchw, chw, t, c, c1, c2);
}
return;
}

} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
typedef paddle::lite::kernels::host::TemporalShiftCompute<PRECISION(kFloat),
PRECISION(kFloat)>
TSfp32;

REGISTER_LITE_KERNEL(temporal_shift, kHost, kFloat, kNCHW, TSfp32, fp32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
.Finalize();
49 changes: 49 additions & 0 deletions lite/kernels/host/temporal_shift_compute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/kernel.h"
#include "lite/operators/temporal_shift_op.h"
#ifdef LITE_WITH_PROFILE
#include <string>
#include "lite/core/profile/profiler.h"
#endif

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

template <PrecisionType PType, PrecisionType OutType>
class TemporalShiftCompute : public KernelLite<TARGET(kHost), PType> {
public:
using param_t = operators::TemporalShiftParam;

void Run() override;

virtual ~TemporalShiftCompute() = default;

#ifdef LITE_WITH_PROFILE
virtual void SetProfileRuntimeKernelInfo(
paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
std::string kernel_func_name_{"Temporal Shift"};
#endif
};

} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
4 changes: 4 additions & 0 deletions lite/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,13 @@ add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_o
add_operator(search_fc_op basic SRCS search_fc_op.cc)
add_operator(lstm_op extra SRCS lstm_op.cc)
add_operator(topk_pooling_op extra SRCS topk_pooling_op.cc)

# for deformable-convNet
add_operator(deformable_conv_op extra SRCS deformable_conv_op.cc)

# for tsm model
add_operator(temporal_shift_op extra SRCS temporal_shift_op.cc)

# 4. training op
add_operator(mean_op extra SRCS mean_op.cc)

Expand Down
8 changes: 8 additions & 0 deletions lite/operators/op_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -2510,6 +2510,14 @@ struct FusionUnifiedDecodingParam : ParamBase {
int32_t min_length_{};
};

struct TemporalShiftParam : ParamBase {
const lite::Tensor* X{};
lite::Tensor* Out{};
int seg_num;
float shift_ratio{0.25f};
std::string data_format{"NCHW"};
};

} // namespace operators
} // namespace lite
} // namespace paddle
59 changes: 59 additions & 0 deletions lite/operators/temporal_shift_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/temporal_shift_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace operators {

bool TemporalShiftOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
// seg_num must > 0
CHECK_OR_FALSE(param_.seg_num > 0);
// shift_radio must in [0, 0.5]
CHECK_OR_FALSE(param_.shift_ratio >= 0.0f && param_.shift_ratio <= 0.5f);
CHECK(param_.data_format == "NCHW" || param_.data_format == "NHWC")
<< "Invilid data format.";
return true;
}

bool TemporalShiftOpLite::InferShapeImpl() const { return true; }

bool TemporalShiftOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
param_.X = scope->FindVar(op_desc.Input("X").front())->GetMutable<Tensor>();
param_.Out =
scope->FindVar(op_desc.Output("Out").front())->GetMutable<Tensor>();

if (op_desc.HasAttr("seg_num")) {
param_.seg_num = op_desc.GetAttr<int>("seg_num");
}
if (op_desc.HasAttr("shift_radio")) {
param_.shift_ratio = op_desc.GetAttr<float>("shift_ratio");
}
if (op_desc.HasAttr("data_format")) {
param_.data_format = op_desc.GetAttr<std::string>("data_format");
}
return true;
}

} // namespace operators
} // namespace lite
} // namespace paddle

REGISTER_LITE_OP(temporal_shift, paddle::lite::operators::TemporalShiftOpLite);
Loading