Skip to content

Commit f13371e

Browse files
committed
test=develop
1 parent f4bfc95 commit f13371e

3 files changed

Lines changed: 217 additions & 16 deletions

File tree

lite/kernels/host/sequence_mask_compute.cc

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,32 +19,62 @@ namespace lite {
1919
namespace kernels {
2020
namespace host {
2121

22+
// Writes a 0/1 mask of shape [x_size, max_len] into y:
// row i has a 1 at position j iff static_cast<Tx>(j) < x[i]
// (i.e. ceil(x[i]) leading ones), and 0 everywhere else.
template <class Tx, class Ty>
void SequenceMask(const Tx* x, Ty* y, const int x_size, const int max_len) {
  // Zero the whole output once; each row then only writes its leading ones.
  // (The previous version also wrote a 0/1 into every cell in the inner
  // loop, so every element was stored twice.)
  memset(y, 0, sizeof(Ty) * x_size * max_len);
  for (int i = 0; i < x_size; i++) {
    // `static_cast<Tx>(j) < x[i]` is monotonically non-increasing in j,
    // so we can stop at the first j for which it fails.
    for (int j = 0; j < max_len && static_cast<Tx>(j) < x[i]; j++) {
      y[j] = static_cast<Ty>(1);
    }
    y += max_len;
  }
}
32+
2233
template <class T>
2334
void SequenceMaskCompute<T>::Run() {
2435
auto& param = this->template Param<param_t>();
2536
auto* x = param.X;
26-
auto* y = parm.Y;
27-
int maxlen = param.maxlen;
37+
auto* y = param.Y;
38+
int max_len = param.maxlen;
2839
auto* max_len_tensor = param.MaxLenTensor;
2940
if (max_len_tensor != nullptr) {
30-
maxlen = max_len_tensor->template data<int>()[0];
31-
CHECK_GT(maxlen, 0) << "Input(MaxLenTensor)'s value should be greater than "
32-
"0. But received maxlen: "
33-
<< maxlen;
41+
max_len = max_len_tensor->template data<int>()[0];
42+
CHECK_GT(max_len, 0)
43+
<< "Input(MaxLenTensor)'s value should be greater than "
44+
"0. But received maxlen: "
45+
<< max_len;
3446
}
3547

3648
auto* x_data = x->template data<T>();
37-
auto x_size = x->numel();
38-
if (maxlen < 0) {
39-
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_size));
49+
int x_size = static_cast<int>(x->numel());
50+
if (max_len < 0) {
51+
max_len = static_cast<int>(*std::max_element(x_data, x_data + x_size));
4052
}
4153

4254
auto y_shape = x->dims().Vectorize();
43-
y_shape.push_back(static_cast<int64_t>(maxlen));
55+
y_shape.push_back(static_cast<int64_t>(max_len));
4456
y->Resize(y_shape);
4557
y->set_lod(x->lod());
4658

4759
int out_type = param.out_dtype;
60+
switch (lite::core::FluidType(out_type)) {
61+
case lite::core::FluidType::FP32: {
62+
SequenceMask(x_data, y->template mutable_data<float>(), x_size, max_len);
63+
break;
64+
}
65+
case lite::core::FluidType::INT32: {
66+
SequenceMask(x_data, y->template mutable_data<int>(), x_size, max_len);
67+
break;
68+
}
69+
case lite::core::FluidType::INT64: {
70+
SequenceMask(
71+
x_data, y->template mutable_data<int64_t>(), x_size, max_len);
72+
break;
73+
}
74+
default:
75+
LOG(FATAL) << "unsupported out data type: " << out_type;
76+
break;
77+
}
4878
}
4979

5080
} // namespace host
@@ -61,8 +91,7 @@ REGISTER_LITE_KERNEL(sequence_mask,
6191
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
6292
.BindInput("MaxLenTensor",
6393
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
64-
.BindOutput("Output",
65-
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
94+
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
6695
.Finalize();
6796

6897
REGISTER_LITE_KERNEL(sequence_mask,
@@ -74,8 +103,7 @@ REGISTER_LITE_KERNEL(sequence_mask,
74103
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
75104
.BindInput("MaxLenTensor",
76105
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
77-
.BindOutput("Output",
78-
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
106+
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
79107
.Finalize();
80108

81109
REGISTER_LITE_KERNEL(sequence_mask,
@@ -87,6 +115,5 @@ REGISTER_LITE_KERNEL(sequence_mask,
87115
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
88116
.BindInput("MaxLenTensor",
89117
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
90-
.BindOutput("Output",
91-
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
118+
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
92119
.Finalize();

lite/tests/kernels/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ if((NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LITE_WITH_MLU) AND (LITE_WIT
6565
if(LITE_BUILD_EXTRA)
6666
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS ${test_kernel_deps})
6767
lite_cc_test(test_sequence_pad SRCS sequence_pad_test.cc DEPS ${test_kernel_deps})
68+
lite_cc_test(test_sequence_mask SRCS sequence_mask_test.cc DEPS ${test_kernel_deps})
6869
lite_cc_test(test_correlation SRCS correlation_test.cc DEPS ${test_kernel_deps})
6970
#lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${test_kernel_deps})
7071
lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${test_kernel_deps})
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <gtest/gtest.h>
16+
#include <cmath>
17+
#include "lite/api/paddle_use_kernels.h"
18+
#include "lite/api/paddle_use_ops.h"
19+
#include "lite/core/arena/framework.h"
20+
#include "lite/tests/utils/fill_data.h"
21+
22+
namespace paddle {
23+
namespace lite {
24+
25+
// Reference mask used by the baseline: row i of y (width max_len) gets
// ceil(x[i]) leading ones and zeros elsewhere.
template <class Tx, class Ty>
void SequenceMask(const Tx* x, Ty* y, const int x_size, const int max_len) {
  memset(y, 0, sizeof(Ty) * x_size * max_len);
  for (int i = 0; i < x_size; i++) {
    int step = static_cast<int>(std::ceil(static_cast<float>(x[i])));
    // Clamp to the row width so an x[i] larger than max_len cannot write
    // past the end of the row (the unclamped version is out-of-bounds).
    step = std::min(step, max_len);
    for (int j = 0; j < step; j++) {
      y[j] = static_cast<Ty>(1);
    }
    y += max_len;
  }
}
36+
37+
template <class T>
38+
class SequenceMaskTester : public arena::TestCase {
39+
protected:
40+
std::string x_ = "x";
41+
std::string max_len_tensor_;
42+
std::string y_ = "y";
43+
int max_len_{-1};
44+
int out_type_{5};
45+
DDim x_dims_{{2, 3, 4}};
46+
47+
public:
48+
SequenceMaskTester(const Place& place,
49+
const std::string& alias,
50+
const int max_len = 5,
51+
const int out_type = 5,
52+
const bool use_max_len_tensor = false)
53+
: TestCase(place, alias), max_len_(max_len), out_type_(out_type) {
54+
if (use_max_len_tensor) {
55+
max_len_tensor_ = std::string("max_len_tensor");
56+
}
57+
}
58+
59+
void RunBaseline(Scope* scope) override {
60+
auto* y = scope->NewTensor(y_);
61+
auto y_shape = x_dims_.Vectorize();
62+
y_shape.push_back(static_cast<int64_t>(max_len_));
63+
y->Resize(y_shape);
64+
65+
auto* x = scope->FindTensor(x_);
66+
auto* x_data = x->template data<T>();
67+
int x_size = static_cast<int>(x->numel());
68+
69+
switch (out_type_) {
70+
case 5: {
71+
SequenceMask(
72+
x_data, y->template mutable_data<float>(), x_size, max_len_);
73+
break;
74+
}
75+
case 2: {
76+
SequenceMask(x_data, y->template mutable_data<int>(), x_size, max_len_);
77+
break;
78+
}
79+
case 3: {
80+
SequenceMask(
81+
x_data, y->template mutable_data<int64_t>(), x_size, max_len_);
82+
break;
83+
}
84+
default:
85+
LOG(FATAL) << "unsupported out data type: " << out_type_;
86+
break;
87+
}
88+
}
89+
90+
void PrepareOpDesc(cpp::OpDesc* op_desc) {
91+
op_desc->SetType("sequence_mask");
92+
op_desc->SetInput("X", {x_});
93+
if (!max_len_tensor_.empty()) {
94+
op_desc->SetInput("MaxLenTensor", {max_len_tensor_});
95+
op_desc->SetAttr("maxlen", -1);
96+
} else {
97+
op_desc->SetAttr("maxlen", max_len_);
98+
}
99+
op_desc->SetOutput("Y", {y_});
100+
op_desc->SetAttr("out_dtype", out_type_);
101+
}
102+
103+
void PrepareData() override {
104+
std::vector<T> x_data(x_dims_.production());
105+
fill_data_rand<T>(x_data.data(), 0, 4, x_dims_.production());
106+
SetCommonTensor(x_, x_dims_, x_data.data());
107+
108+
if (!max_len_tensor_.empty()) {
109+
std::vector<int> max_len_tensor_data{max_len_};
110+
SetCommonTensor(max_len_tensor_, DDim{{1}}, max_len_tensor_data.data());
111+
}
112+
}
113+
};
114+
115+
template <class T>
116+
void TestSequenceMaskHelper(const Place place,
117+
const float abs_error,
118+
const int max_len = 5,
119+
const int out_type = 5,
120+
const bool use_max_len_tensor = false) {
121+
std::string alias;
122+
auto precision = lite_api::PrecisionTypeTrait<T>::Type();
123+
switch (precision) {
124+
case PRECISION(kFloat):
125+
alias = std::string("def");
126+
break;
127+
case PRECISION(kInt32):
128+
alias = std::string("int32");
129+
break;
130+
case PRECISION(kInt64):
131+
alias = std::string("int64");
132+
break;
133+
default:
134+
LOG(FATAL) << "unsupported input data type: "
135+
<< lite_api::PrecisionToStr(precision);
136+
break;
137+
}
138+
std::unique_ptr<arena::TestCase> tester(new SequenceMaskTester<T>(
139+
place, alias, max_len, out_type, use_max_len_tensor));
140+
arena::Arena arena(std::move(tester), place, abs_error);
141+
arena.TestPrecision();
142+
}
143+
144+
template <class T>
145+
void TestSequenceMask(const Place place, const float abs_error) {
146+
// test max_len
147+
for (int max_len : {6}) {
148+
TestSequenceMaskHelper<T>(place, abs_error, max_len);
149+
}
150+
// test out_type
151+
for (int out_type : {2, 3, 5}) {
152+
TestSequenceMaskHelper<T>(place, abs_error, 5, out_type);
153+
}
154+
// test max_len_tensor
155+
TestSequenceMaskHelper<T>(place, abs_error, 5, 5, true);
156+
}
157+
158+
// Entry point: runs the sweep for every supported input element type on
// the host kernel; skipped on targets without a host build.
TEST(sequence_mask, precision) {
  Place place;
  float abs_error = 1e-5;
#if defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
  place = TARGET(kHost);
#else
  return;
#endif

  TestSequenceMask<float>(place, abs_error);
  TestSequenceMask<int>(place, abs_error);
  TestSequenceMask<int64_t>(place, abs_error);
}
171+
172+
} // namespace lite
173+
} // namespace paddle

0 commit comments

Comments
 (0)