2626
2727import numpy as np
2828from functools import partial
29+ from functools import reduce
2930
3031
3132class TestReshapeOp (AutoScanTest ):
@@ -66,12 +67,17 @@ def sample_program_configs(self, draw):
6667 st .lists (
6768 st .integers (
6869 min_value = 1 , max_value = 10 ), min_size = 4 , max_size = 4 ))
70+
6971 attr_shape = draw (
7072 st .lists (
7173 st .integers (
72- min_value = 0 , max_value = 4 ),
73- min_size = len ( in_shape ) ,
74+ min_value = 1 , max_value = max ( in_shape ) ),
75+ min_size = 1 ,
7476 max_size = len (in_shape )))
77+ assume (
78+ reduce (lambda x , y : x * y , attr_shape ) == reduce (
79+ lambda x , y : x * y , in_shape ))
80+
7581 with_shape = draw (st .sampled_from ([True , False ]))
7682
7783 def generate_input (* args , ** kwargs ):
@@ -81,7 +87,7 @@ def generate_input(*args, **kwargs):
8187 type = "reshape" ,
8288 inputs = {"X" : ["input_data" ], },
8389 outputs = {"Out" : ["output_data" ], },
84- attrs = {"shape" : in_shape , })
90+ attrs = {"shape" : attr_shape , })
8591 program_config = ProgramConfig (
8692 ops = [build_ops ],
8793 weights = {},
@@ -105,7 +111,7 @@ def _teller1(program_config, predictor_config):
105111 )
106112
107113 def test (self , * args , ** kwargs ):
108- self .run_and_statis (quant = False , max_examples = 25 )
114+ self .run_and_statis (quant = False , max_examples = 200 )
109115
110116
111117if __name__ == "__main__" :
0 commit comments