Skip to content

Commit 2a83dec

Browse files
authored
[ut bug fix opencl]opencl fix reshape/reshape2 bug (#8364)
1 parent cd756fd commit 2a83dec

3 files changed

Lines changed: 23 additions & 11 deletions

File tree

lite/backends/opencl/cl_kernel/image/reshape_kernel.cl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ __kernel void reshape(__read_only image2d_t input_image,
5050

5151
int in_n0 = count0 / in_Stride2;
5252
int in_n1 = count1 / in_Stride2;
53-
int in_n2 = count1 / in_Stride2;
54-
int in_n3 = count2 / in_Stride2;
53+
int in_n2 = count2 / in_Stride2;
54+
int in_n3 = count3 / in_Stride2;
5555

5656
count0 = count0 % in_Stride2;
5757
count1 = count1 % in_Stride2;
@@ -132,7 +132,7 @@ __kernel void reshape(__read_only image2d_t input_image,
132132
if (in_c2 % 4 == 0) {
133133
output.z = input2.x;
134134
} else if (in_c2 % 4 == 1) {
135-
output.z = input1.y;
135+
output.z = input2.y;
136136
} else if (in_c2 % 4 == 2) {
137137
output.z = input2.z;
138138
} else {

lite/tests/unittest_py/op/test_reshape2_op.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import numpy as np
2828
from functools import partial
29+
from functools import reduce
2930

3031

3132
class TestReshape2Op(AutoScanTest):
@@ -77,12 +78,17 @@ def sample_program_configs(self, draw):
7778
st.lists(
7879
st.integers(
7980
min_value=1, max_value=10), min_size=4, max_size=4))
81+
8082
attr_shape = draw(
8183
st.lists(
8284
st.integers(
83-
min_value=0, max_value=4),
84-
min_size=len(in_shape),
85+
min_value=1, max_value=max(in_shape)),
86+
min_size=1,
8587
max_size=len(in_shape)))
88+
assume(
89+
reduce(lambda x, y: x * y, attr_shape) == reduce(
90+
lambda x, y: x * y, in_shape))
91+
8692
with_shape = draw(st.sampled_from([True, False]))
8793

8894
def generate_input(*args, **kwargs):
@@ -95,7 +101,7 @@ def generate_input(*args, **kwargs):
95101
"Out": ["output_data"],
96102
"XShape": ["x_shape"],
97103
},
98-
attrs={"shape": in_shape, })
104+
attrs={"shape": attr_shape, })
99105
program_config = ProgramConfig(
100106
ops=[build_ops],
101107
weights={},
@@ -125,7 +131,7 @@ def _teller1(program_config, predictor_config):
125131
)
126132

127133
def test(self, *args, **kwargs):
128-
self.run_and_statis(quant=False, max_examples=25)
134+
self.run_and_statis(quant=False, max_examples=200)
129135

130136

131137
if __name__ == "__main__":

lite/tests/unittest_py/op/test_reshape_op.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import numpy as np
2828
from functools import partial
29+
from functools import reduce
2930

3031

3132
class TestReshapeOp(AutoScanTest):
@@ -66,12 +67,17 @@ def sample_program_configs(self, draw):
6667
st.lists(
6768
st.integers(
6869
min_value=1, max_value=10), min_size=4, max_size=4))
70+
6971
attr_shape = draw(
7072
st.lists(
7173
st.integers(
72-
min_value=0, max_value=4),
73-
min_size=len(in_shape),
74+
min_value=1, max_value=max(in_shape)),
75+
min_size=1,
7476
max_size=len(in_shape)))
77+
assume(
78+
reduce(lambda x, y: x * y, attr_shape) == reduce(
79+
lambda x, y: x * y, in_shape))
80+
7581
with_shape = draw(st.sampled_from([True, False]))
7682

7783
def generate_input(*args, **kwargs):
@@ -81,7 +87,7 @@ def generate_input(*args, **kwargs):
8187
type="reshape",
8288
inputs={"X": ["input_data"], },
8389
outputs={"Out": ["output_data"], },
84-
attrs={"shape": in_shape, })
90+
attrs={"shape": attr_shape, })
8591
program_config = ProgramConfig(
8692
ops=[build_ops],
8793
weights={},
@@ -105,7 +111,7 @@ def _teller1(program_config, predictor_config):
105111
)
106112

107113
def test(self, *args, **kwargs):
108-
self.run_and_statis(quant=False, max_examples=25)
114+
self.run_and_statis(quant=False, max_examples=200)
109115

110116

111117
if __name__ == "__main__":

0 commit comments

Comments
 (0)