Skip to content

Commit 2c5feaa

Browse files
[arm_ut_py] fix reduce_max compute error and add reduce_* ut_py (#8649)
1 parent ac56e50 commit 2c5feaa

File tree

8 files changed

+286
-97
lines changed

8 files changed

+286
-97
lines changed

lite/backends/arm/math/reduce_max.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ inline void reduce_third_of_three(
8383
const T* src, T* dst, int first_in, int second_in, int third_in) {
8484
for (int i = 0; i < first_in; i++) {
8585
for (int j = 0; j < second_in; j++) {
86-
dst[i * second_in + j] = src[i * second_in * third_in + j * second_in];
86+
dst[i * second_in + j] = src[i * second_in * third_in + j * third_in];
8787
for (int k = 0; k < third_in; k++) {
8888
dst[i * second_in + j] =
8989
src[i * second_in * third_in + j * third_in + k] >

lite/backends/arm/math/reduce_max_min.cc

Lines changed: 9 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,78 +18,30 @@ namespace paddle {
1818
namespace lite {
1919
namespace arm {
2020
namespace math {
21-
2221
template <>
2322
void reduce_second_of_two<float>(const float* src,
2423
float* dst,
2524
int first_in,
2625
int second_in,
27-
MaxMinType max_min_selector) {
28-
// max_min_selector == true, do reduce max; else do reduce min
29-
for (int j = 0; j < second_in; j++) {
30-
dst[j * first_in] = src[j * first_in];
31-
for (int k = 1; k < first_in; k++) {
32-
dst[j * first_in] = (src[j * first_in + k] <= dst[j * first_in]) ^
33-
static_cast<bool>(max_min_selector)
34-
? src[j * first_in + k]
35-
: dst[j * first_in];
36-
}
37-
}
38-
}
39-
40-
template <>
41-
void reduce_first_of_two<float>(const float* src,
42-
float* dst,
43-
int first_in,
44-
int second_in,
45-
MaxMinType max_min_selector) {
46-
// max_min_selector == true, do reduce max; else do reduce min
47-
for (int j = 0; j < first_in; j++) {
48-
dst[j] = src[j];
49-
for (int k = 1; k < second_in; k++) {
50-
dst[j] = (src[j + k * first_in] <= dst[j]) ^
51-
static_cast<bool>(max_min_selector)
52-
? src[j + k * first_in]
53-
: dst[j];
54-
}
55-
}
56-
}
57-
26+
MaxMinType max_min_selector);
5827
template <>
5928
void reduce_second_of_two<int64_t>(const int64_t* src,
6029
int64_t* dst,
6130
int first_in,
6231
int second_in,
63-
MaxMinType max_min_selector) {
64-
// max_min_selector == true, do reduce max; else do reduce min
65-
for (int j = 0; j < second_in; j++) {
66-
dst[j * first_in] = src[j * first_in];
67-
for (int k = 1; k < first_in; k++) {
68-
dst[j * first_in] = (src[j * first_in + k] <= dst[j * first_in]) ^
69-
static_cast<bool>(max_min_selector)
70-
? src[j * first_in + k]
71-
: dst[j * first_in];
72-
}
73-
}
74-
}
75-
32+
MaxMinType max_min_selector);
33+
template <>
34+
void reduce_first_of_two<float>(const float* src,
35+
float* dst,
36+
int first_in,
37+
int second_in,
38+
MaxMinType max_min_selector);
7639
template <>
7740
void reduce_first_of_two<int64_t>(const int64_t* src,
7841
int64_t* dst,
7942
int first_in,
8043
int second_in,
81-
MaxMinType max_min_selector) {
82-
// max_min_selector == true, do reduce max; else do reduce min
83-
for (int j = 0; j < first_in; j++) {
84-
dst[j] = src[j];
85-
for (int k = 1; k < second_in; k++) {
86-
dst[j] = (src[j + k * first_in] <= dst[j]) ^
87-
static_cast<bool>(max_min_selector)
88-
? src[j + k * first_in]
89-
: dst[j];
90-
}
91-
}
92-
}
44+
MaxMinType max_min_selector);
9345

9446
} // namespace math
9547
} // namespace arm

lite/backends/arm/math/reduce_max_min.h

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,40 @@ inline void reduce_one_line_min(const DataType* src, DataType* dst, int size) {
4444
*dst = tmp;
4545
}
4646

47-
template <typename DataType>
48-
void reduce_first_of_two(const DataType* src,
49-
DataType* dst,
50-
int first_in,
51-
int second_in,
52-
MaxMinType compare_functor);
53-
5447
template <typename DataType>
5548
void reduce_second_of_two(const DataType* src,
5649
DataType* dst,
5750
int first_in,
5851
int second_in,
59-
MaxMinType max_min_selector);
52+
MaxMinType max_min_selector) {
53+
// max_min_selector == true, do reduce max; else do reduce min
54+
for (int j = 0; j < first_in; j++) {
55+
dst[j] = src[j * second_in];
56+
for (int k = 1; k < second_in; k++) {
57+
dst[j] = (src[j * second_in + k] <= dst[j]) ^
58+
static_cast<bool>(max_min_selector)
59+
? src[j * second_in + k]
60+
: dst[j];
61+
}
62+
}
63+
}
64+
template <typename DataType>
65+
void reduce_first_of_two(const DataType* src,
66+
DataType* dst,
67+
int first_in,
68+
int second_in,
69+
MaxMinType max_min_selector) {
70+
// max_min_selector == true, do reduce max; else do reduce min
71+
for (int j = 0; j < second_in; j++) {
72+
dst[j] = src[j];
73+
for (int k = 1; k < first_in; k++) {
74+
dst[j] = (src[k * second_in + j] <= dst[j]) ^
75+
static_cast<bool>(max_min_selector)
76+
? src[k * second_in + j]
77+
: dst[j];
78+
}
79+
}
80+
}
6081

6182
} // namespace math
6283
} // namespace arm

lite/tests/unittest_py/op/test_reduce_max_op.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,16 @@
3131
class TestReduceMaxOp(AutoScanTest):
3232
def __init__(self, *args, **kwargs):
3333
AutoScanTest.__init__(self, *args, **kwargs)
34+
self.enable_testing_on_place(
35+
TargetType.ARM,
36+
PrecisionType.FP32,
37+
DataLayoutType.NCHW,
38+
thread=[1, 4])
3439
self.enable_testing_on_place(
3540
TargetType.X86,
3641
PrecisionType.FP32,
3742
DataLayoutType.NCHW,
38-
thread=[1, 2])
43+
thread=[1, 4])
3944
opencl_places = [
4045
Place(TargetType.OpenCL, PrecisionType.FP16,
4146
DataLayoutType.ImageDefault), Place(
@@ -69,14 +74,22 @@ def sample_program_configs(self, draw):
6974
in_shape = draw(
7075
st.lists(
7176
st.integers(
72-
min_value=1, max_value=10), min_size=4, max_size=4))
77+
min_value=1, max_value=10), min_size=1, max_size=4))
7378
keep_dim = draw(st.booleans())
74-
axis = draw(st.integers(min_value=-1, max_value=3))
75-
assume(axis < len(in_shape))
79+
axis_list = [
80+
draw(st.integers(
81+
min_value=-1, max_value=len(in_shape) - 1))
82+
]
83+
84+
if len(in_shape) == 2:
85+
axis_list = draw(st.sampled_from([[0], [1]]))
86+
elif len(in_shape) == 3:
87+
axis_list = draw(st.sampled_from([[0], [1], [2]]))
88+
elif len(in_shape) == 4:
89+
axis_list = draw(
90+
st.sampled_from([[0], [1], [2], [3], [0, 1], [1, 2], [2, 3]]))
7691

77-
if isinstance(axis, int):
78-
axis = [axis]
79-
reduce_all_data = True if axis == None or axis == [] else False
92+
reduce_all_data = True if axis_list == None or axis_list == [] else False
8093

8194
def generate_input(*args, **kwargs):
8295
return np.random.random(in_shape).astype(np.float32)
@@ -86,7 +99,7 @@ def generate_input(*args, **kwargs):
8699
inputs={"X": ["input_data"], },
87100
outputs={"Out": ["output_data"], },
88101
attrs={
89-
"dim": axis,
102+
"dim": axis_list,
90103
"keep_dim": keep_dim,
91104
"reduce_all": reduce_all_data,
92105
})
@@ -114,7 +127,8 @@ def _teller2(program_config, predictor_config):
114127
axis = program_config.ops[0].attrs["dim"]
115128
keep_dim = program_config.ops[0].attrs["keep_dim"]
116129
if target_type == TargetType.Metal:
117-
if keep_dim == False or axis[0] != 1 or in_shape[0] != 1:
130+
if keep_dim == False or axis[0] != 1 or in_shape[
131+
0] != 1 or len(in_shape) < 4 or len(axis) > 1:
118132
return True
119133

120134
self.add_ignore_check_case(
@@ -124,7 +138,7 @@ def _teller2(program_config, predictor_config):
124138

125139
def test(self, *args, **kwargs):
126140
target_str = self.get_target()
127-
max_examples = 100
141+
max_examples = 300
128142
if target_str == "Metal":
129143
# Make sure to generate enough valid cases for Metal
130144
max_examples = 3000

lite/tests/unittest_py/op/test_reduce_mean_op.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ def __init__(self, *args, **kwargs):
3333
AutoScanTest.__init__(self, *args, **kwargs)
3434
self.enable_testing_on_place(TargetType.X86, PrecisionType.FP32,
3535
DataLayoutType.NCHW)
36+
self.enable_testing_on_place(
37+
TargetType.ARM,
38+
PrecisionType.FP32,
39+
DataLayoutType.NCHW,
40+
thread=[1, 4])
3641
opencl_places = [
3742
Place(TargetType.OpenCL, PrecisionType.FP16,
3843
DataLayoutType.ImageDefault), Place(
@@ -61,18 +66,10 @@ def sample_program_configs(self, draw):
6166
st.integers(
6267
min_value=1, max_value=10), min_size=4, max_size=4))
6368
keep_dim = draw(st.booleans())
64-
axis_type = draw(st.sampled_from(["int", "list"]))
65-
axis_int = draw(st.integers(min_value=-1, max_value=3))
6669
axis_list = draw(
67-
st.sampled_from([[2, 3], [1, 2], [0, 1], [1, 2, 3], [0, 1, 2]]))
70+
st.sampled_from([[0], [1], [2], [3], [0, 1], [1, 2], [2, 3]]))
6871

69-
if axis_type == "int":
70-
axis = axis_int
71-
else:
72-
axis = axis_list
73-
if isinstance(axis, int):
74-
axis = [axis]
75-
reduce_all_data = True if axis == None or axis == [] else False
72+
reduce_all_data = True if axis_list == None or axis_list == [] else False
7673

7774
def generate_input(*args, **kwargs):
7875
return np.random.random(in_shape).astype(np.float32)
@@ -82,7 +79,7 @@ def generate_input(*args, **kwargs):
8279
inputs={"X": ["input_data"], },
8380
outputs={"Out": ["output_data"], },
8481
attrs={
85-
"dim": axis,
82+
"dim": axis_list,
8683
"keep_dim": keep_dim,
8784
"reduce_all": reduce_all_data,
8885
})
@@ -99,7 +96,18 @@ def sample_predictor_configs(self):
9996
return self.get_predictor_configs(), ["reduce_mean"], (1e-2, 1e-2)
10097

10198
def add_ignore_pass_case(self):
102-
pass
99+
def _teller1(program_config, predictor_config):
100+
target_type = predictor_config.target()
101+
in_shape = list(program_config.inputs["input_data"].shape)
102+
axis = program_config.ops[0].attrs["dim"]
103+
if target_type == TargetType.OpenCL:
104+
if len(axis) == 1 and len(in_shape) == 4:
105+
return True
106+
107+
self.add_ignore_check_case(
108+
_teller1, IgnoreReasons.ACCURACY_ERROR,
109+
"The op output has diff in a specific case on opencl. We need to fix it as soon as possible."
110+
)
103111

104112
def test(self, *args, **kwargs):
105113
self.run_and_statis(quant=False, max_examples=100)

lite/tests/unittest_py/op/test_reduce_min_op.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,16 @@
3131
class TestReduceMinOp(AutoScanTest):
3232
def __init__(self, *args, **kwargs):
3333
AutoScanTest.__init__(self, *args, **kwargs)
34+
self.enable_testing_on_place(
35+
TargetType.ARM,
36+
PrecisionType.FP32,
37+
DataLayoutType.NCHW,
38+
thread=[1, 4])
3439
self.enable_testing_on_place(
3540
TargetType.X86,
3641
PrecisionType.FP32,
3742
DataLayoutType.NCHW,
38-
thread=[1, 2])
43+
thread=[1, 4])
3944

4045
def is_program_valid(self,
4146
program_config: ProgramConfig,
@@ -46,14 +51,22 @@ def sample_program_configs(self, draw):
4651
in_shape = draw(
4752
st.lists(
4853
st.integers(
49-
min_value=1, max_value=10), min_size=4, max_size=4))
54+
min_value=1, max_value=10), min_size=1, max_size=4))
5055
keep_dim = draw(st.booleans())
51-
axis = draw(st.integers(min_value=-1, max_value=3))
52-
assume(axis < len(in_shape))
56+
axis_list = [
57+
draw(st.integers(
58+
min_value=-1, max_value=len(in_shape) - 1))
59+
]
60+
61+
if len(in_shape) == 2:
62+
axis_list = draw(st.sampled_from([[0], [1]]))
63+
elif len(in_shape) == 3:
64+
axis_list = draw(st.sampled_from([[0], [1], [2]]))
65+
elif len(in_shape) == 4:
66+
axis_list = draw(
67+
st.sampled_from([[0], [1], [2], [3], [0, 1], [1, 2], [2, 3]]))
5368

54-
if isinstance(axis, int):
55-
axis = [axis]
56-
reduce_all_data = True if axis == None or axis == [] else False
69+
reduce_all_data = True if axis_list == None or axis_list == [] else False
5770

5871
def generate_input(*args, **kwargs):
5972
return np.random.random(in_shape).astype(np.float32)
@@ -63,7 +76,7 @@ def generate_input(*args, **kwargs):
6376
inputs={"X": ["input_data"], },
6477
outputs={"Out": ["output_data"], },
6578
attrs={
66-
"dim": axis,
79+
"dim": axis_list,
6780
"keep_dim": keep_dim,
6881
"reduce_all": reduce_all_data,
6982
})
@@ -83,7 +96,7 @@ def add_ignore_pass_case(self):
8396
pass
8497

8598
def test(self, *args, **kwargs):
86-
self.run_and_statis(quant=False, max_examples=25)
99+
self.run_and_statis(quant=False, max_examples=250)
87100

88101

89102
if __name__ == "__main__":

0 commit comments

Comments
 (0)