Skip to content

Commit 40e0dc3

Browse files
authored
[ARM]add temporal_shift op for tsm model (#10010)
1 parent e79c751 commit 40e0dc3

File tree

11 files changed

+651
-0
lines changed

11 files changed

+651
-0
lines changed

lite/backends/host/math/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ lite_cc_library(math_host SRCS
1414
inverse.cc
1515
reverse.cc
1616
topk.cc
17+
temporal_shift.cc
1718
DEPS core)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/backends/host/math/temporal_shift.h"
16+
#include <algorithm>
17+
18+
namespace paddle {
19+
namespace lite {
20+
namespace host {
21+
namespace math {
22+
23+
template <>
24+
void temporalshiftNCHW_func<float>(const float* input,
25+
float* output,
26+
const int ntchw,
27+
const int tchw,
28+
const int chw,
29+
const int hw,
30+
const int t,
31+
const int c1,
32+
const int c2) {
33+
int src_it = 0;
34+
for (int i = 0; i < ntchw; i++) {
35+
int it = (i % tchw) / chw;
36+
int ic = (i % chw) / hw;
37+
if (ic < c1) {
38+
src_it = it - 1;
39+
} else if (ic < c2) {
40+
src_it = it + 1;
41+
} else {
42+
src_it = it;
43+
}
44+
if (src_it < 0 || src_it >= t) {
45+
output[i] = 0;
46+
} else {
47+
output[i] = input[i + (src_it - it) * chw];
48+
}
49+
}
50+
}
51+
52+
template <>
53+
void temporalshiftNHWC_func<float>(const float* input,
54+
float* output,
55+
const int nthwc,
56+
const int thwc,
57+
const int hwc,
58+
const int t,
59+
const int c,
60+
const int c1,
61+
const int c2) {
62+
int src_it = 0;
63+
for (int i = 0; i < nthwc; i++) {
64+
int it = (i % thwc) / hwc;
65+
int ic = i % c;
66+
67+
if (ic < c1) {
68+
src_it = it - 1;
69+
} else if (ic < c2) {
70+
src_it = it + 1;
71+
} else {
72+
src_it = it;
73+
}
74+
75+
if (src_it < 0 || src_it >= t) {
76+
output[i] = 0;
77+
} else {
78+
output[i] = input[i + (src_it - it) * hwc];
79+
}
80+
}
81+
}
82+
83+
} // namespace math
84+
} // namespace host
85+
} // namespace lite
86+
} // namespace paddle
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
namespace paddle {
18+
namespace lite {
19+
namespace host {
20+
namespace math {
21+
22+
template <typename InType>
23+
void temporalshiftNCHW_func(const InType* input,
24+
InType* output,
25+
const int ntchw,
26+
const int tchw,
27+
const int chw,
28+
const int hw,
29+
const int t,
30+
const int c1,
31+
const int c2);
32+
33+
template <typename InType>
34+
void temporalshiftNHWC_func(const InType* input,
35+
InType* output,
36+
const int nthwc,
37+
const int thwc,
38+
const int hwc,
39+
const int t,
40+
const int c,
41+
const int c1,
42+
const int c2);
43+
} // namespace math
44+
} // namespace host
45+
} // namespace lite
46+
} // namespace paddle

lite/kernels/host/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ add_kernel(roll_compute Host extra SRCS roll_compute.cc)
121121
add_kernel(set_value Host extra SRCS set_value_compute.cc)
122122
add_kernel(share_data_compute_host Host extra SRCS share_data_compute.cc)
123123
add_kernel(round_compute_host Host extra SRCS round_compute.cc)
124+
add_kernel(temporal_shift_compute_host Host extra SRCS temporal_shift_compute.cc)
125+
124126

125127
if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
126128
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/host/temporal_shift_compute.h"
16+
#include <string>
17+
#include "lite/backends/host/math/temporal_shift.h"
18+
#include "lite/core/op_registry.h"
19+
#include "lite/core/tensor.h"
20+
#include "lite/core/type_system.h"
21+
22+
namespace paddle {
23+
namespace lite {
24+
namespace kernels {
25+
namespace host {
26+
27+
template <>
28+
void TemporalShiftCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
29+
auto& param = Param<operators::TemporalShiftParam>();
30+
const lite::Tensor* input = param.X;
31+
lite::Tensor* output = param.Out;
32+
int t = param.seg_num;
33+
float shift_ratio = param.shift_ratio;
34+
DataLayoutType data_layout;
35+
if (param.data_format == "NCHW") {
36+
data_layout = DATALAYOUT(kNCHW);
37+
} else if (param.data_format == "NHWC") {
38+
data_layout = DATALAYOUT(kNHWC);
39+
} else {
40+
LOG(FATAL) << "Unknown datalayout";
41+
}
42+
43+
auto input_dims = input->dims();
44+
const int nt = input_dims[0];
45+
const int c =
46+
data_layout == DATALAYOUT(kNCHW) ? input_dims[1] : input_dims[3];
47+
const int h =
48+
data_layout == DATALAYOUT(kNCHW) ? input_dims[2] : input_dims[1];
49+
const int w =
50+
data_layout == DATALAYOUT(kNCHW) ? input_dims[3] : input_dims[2];
51+
52+
const int hw = h * w;
53+
const int chw = c * hw;
54+
const int tchw = t * chw;
55+
const int ntchw = nt * chw;
56+
57+
const int c1 = static_cast<int>(c * shift_ratio);
58+
const int c2 = static_cast<int>(c * 2 * shift_ratio);
59+
60+
DDim out_dims;
61+
if (data_layout == DATALAYOUT(kNCHW)) {
62+
out_dims.ConstructFrom({nt, c, h, w});
63+
} else {
64+
out_dims.ConstructFrom({nt, h, w, c});
65+
}
66+
67+
const float* input_data = input->data<float>();
68+
output->Resize(out_dims);
69+
float* output_data = output->mutable_data<float>();
70+
71+
if (data_layout == DATALAYOUT(kNCHW)) {
72+
lite::host::math::temporalshiftNCHW_func(
73+
input_data, output_data, ntchw, tchw, chw, hw, t, c1, c2);
74+
} else {
75+
lite::host::math::temporalshiftNHWC_func(
76+
input_data, output_data, ntchw, tchw, chw, t, c, c1, c2);
77+
}
78+
return;
79+
}
80+
81+
} // namespace host
82+
} // namespace kernels
83+
} // namespace lite
84+
} // namespace paddle
85+
typedef paddle::lite::kernels::host::TemporalShiftCompute<PRECISION(kFloat),
86+
PRECISION(kFloat)>
87+
TSfp32;
88+
89+
REGISTER_LITE_KERNEL(temporal_shift, kHost, kFloat, kNCHW, TSfp32, fp32)
90+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
91+
.BindOutput("Out",
92+
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
93+
.Finalize();
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include "lite/core/kernel.h"
17+
#include "lite/operators/temporal_shift_op.h"
18+
#ifdef LITE_WITH_PROFILE
19+
#include <string>
20+
#include "lite/core/profile/profiler.h"
21+
#endif
22+
23+
namespace paddle {
24+
namespace lite {
25+
namespace kernels {
26+
namespace host {
27+
28+
template <PrecisionType PType, PrecisionType OutType>
29+
class TemporalShiftCompute : public KernelLite<TARGET(kHost), PType> {
30+
public:
31+
using param_t = operators::TemporalShiftParam;
32+
33+
void Run() override;
34+
35+
virtual ~TemporalShiftCompute() = default;
36+
37+
#ifdef LITE_WITH_PROFILE
38+
virtual void SetProfileRuntimeKernelInfo(
39+
paddle::lite::profile::OpCharacter* ch) {
40+
ch->kernel_func_name = kernel_func_name_;
41+
}
42+
std::string kernel_func_name_{"Temporal Shift"};
43+
#endif
44+
};
45+
46+
} // namespace host
47+
} // namespace kernels
48+
} // namespace lite
49+
} // namespace paddle

lite/operators/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,13 @@ add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_o
217217
add_operator(search_fc_op basic SRCS search_fc_op.cc)
218218
add_operator(lstm_op extra SRCS lstm_op.cc)
219219
add_operator(topk_pooling_op extra SRCS topk_pooling_op.cc)
220+
220221
# for deformable-convNet
221222
add_operator(deformable_conv_op extra SRCS deformable_conv_op.cc)
222223

224+
# for tsm model
225+
add_operator(temporal_shift_op extra SRCS temporal_shift_op.cc)
226+
223227
# 4. training op
224228
add_operator(mean_op extra SRCS mean_op.cc)
225229

lite/operators/op_params.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2510,6 +2510,14 @@ struct FusionUnifiedDecodingParam : ParamBase {
25102510
int32_t min_length_{};
25112511
};
25122512

2513+
struct TemporalShiftParam : ParamBase {
2514+
const lite::Tensor* X{};
2515+
lite::Tensor* Out{};
2516+
int seg_num;
2517+
float shift_ratio{0.25f};
2518+
std::string data_format{"NCHW"};
2519+
};
2520+
25132521
} // namespace operators
25142522
} // namespace lite
25152523
} // namespace paddle
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/operators/temporal_shift_op.h"
16+
#include "lite/core/op_lite.h"
17+
#include "lite/core/op_registry.h"
18+
19+
namespace paddle {
20+
namespace lite {
21+
namespace operators {
22+
23+
bool TemporalShiftOpLite::CheckShape() const {
24+
CHECK_OR_FALSE(param_.X);
25+
CHECK_OR_FALSE(param_.Out);
26+
// seg_num must > 0
27+
CHECK_OR_FALSE(param_.seg_num > 0);
28+
// shift_radio must in [0, 0.5]
29+
CHECK_OR_FALSE(param_.shift_ratio >= 0.0f && param_.shift_ratio <= 0.5f);
30+
CHECK(param_.data_format == "NCHW" || param_.data_format == "NHWC")
31+
<< "Invilid data format.";
32+
return true;
33+
}
34+
35+
bool TemporalShiftOpLite::InferShapeImpl() const { return true; }
36+
37+
bool TemporalShiftOpLite::AttachImpl(const cpp::OpDesc &op_desc,
38+
lite::Scope *scope) {
39+
param_.X = scope->FindVar(op_desc.Input("X").front())->GetMutable<Tensor>();
40+
param_.Out =
41+
scope->FindVar(op_desc.Output("Out").front())->GetMutable<Tensor>();
42+
43+
if (op_desc.HasAttr("seg_num")) {
44+
param_.seg_num = op_desc.GetAttr<int>("seg_num");
45+
}
46+
if (op_desc.HasAttr("shift_radio")) {
47+
param_.shift_ratio = op_desc.GetAttr<float>("shift_ratio");
48+
}
49+
if (op_desc.HasAttr("data_format")) {
50+
param_.data_format = op_desc.GetAttr<std::string>("data_format");
51+
}
52+
return true;
53+
}
54+
55+
} // namespace operators
56+
} // namespace lite
57+
} // namespace paddle
58+
59+
REGISTER_LITE_OP(temporal_shift, paddle::lite::operators::TemporalShiftOpLite);

0 commit comments

Comments
 (0)