@@ -27,7 +27,7 @@ class TRTReduceMeanTest(InferencePassTest):
2727 def setUp (self ):
2828 with fluid .program_guard (self .main_program , self .startup_program ):
2929 data = fluid .data (
30- name = "data" , shape = [- 1 , 3 , 224 , 224 ], dtype = "float32" )
30+ name = "data" , shape = [- 1 , 3 , - 1 , - 1 ], dtype = "float32" )
3131 reduce_mean = fluid .layers .reduce_mean (
3232 data , dim = [2 , - 1 ], keep_dim = True )
3333 out = fluid .layers .batch_norm (reduce_mean , is_test = True )
@@ -40,7 +40,35 @@ def setUp(self):
4040 1 << 30 , 32 , 1 , AnalysisConfig .Precision .Float32 , False , False )
4141 self .fetch_list = [out ]
4242 self .dynamic_shape_params = TRTReduceMeanTest .DynamicShapeParam ({
43- 'data' : [1 , 3 , 224 , 224 ]
43+ 'data' : [1 , 3 , 64 , 64 ]
44+ }, {'data' : [3 , 3 , 224 , 224 ]}, {'data' : [3 , 3 , 224 , 224 ]}, False )
45+
46+ def test_check_output (self ):
47+ if core .is_compiled_with_cuda ():
48+ use_gpu = True
49+ self .check_output_with_option (use_gpu , flatten = True )
50+ self .assertTrue (
51+ PassVersionChecker .IsCompatible ('tensorrt_subgraph_pass' ))
52+
53+
54+ class TRTReduceMeanTestFP16 (InferencePassTest ):
55+ def setUp (self ):
56+ with fluid .program_guard (self .main_program , self .startup_program ):
57+ data = fluid .data (
58+ name = "data" , shape = [- 1 , 3 , - 1 , - 1 ], dtype = "float32" )
59+ reduce_mean = fluid .layers .reduce_mean (
60+ data , dim = [2 , - 1 ], keep_dim = True )
61+ out = fluid .layers .batch_norm (reduce_mean , is_test = True )
62+
63+ self .feeds = {
64+ "data" : np .random .random ([3 , 3 , 224 , 224 ]).astype ("float32" ),
65+ }
66+ self .enable_trt = True
67+ self .trt_parameters = TRTReduceMeanTestFP16 .TensorRTParam (
68+ 1 << 30 , 32 , 1 , AnalysisConfig .Precision .Half , False , False )
69+ self .fetch_list = [out ]
70+ self .dynamic_shape_params = TRTReduceMeanTestFP16 .DynamicShapeParam ({
71+ 'data' : [1 , 3 , 64 , 64 ]
4472 }, {'data' : [3 , 3 , 224 , 224 ]}, {'data' : [3 , 3 , 224 , 224 ]}, False )
4573
4674 def test_check_output (self ):
@@ -78,5 +106,102 @@ def test_check_output(self):
78106 PassVersionChecker .IsCompatible ('tensorrt_subgraph_pass' ))
79107
80108
109+ class TRTReduceMeanTestStatic (InferencePassTest ):
110+ def setUp (self ):
111+ with fluid .program_guard (self .main_program , self .startup_program ):
112+ data = fluid .data (
113+ name = "data" , shape = [3 , 3 , 224 , 224 ], dtype = "float32" )
114+ reduce_mean = fluid .layers .reduce_mean (
115+ data , dim = [2 , - 1 ], keep_dim = True )
116+ out = fluid .layers .batch_norm (reduce_mean , is_test = True )
117+
118+ self .feeds = {
119+ "data" : np .random .random ([3 , 3 , 224 , 224 ]).astype ("float32" ),
120+ }
121+ self .enable_trt = True
122+ self .trt_parameters = TRTReduceMeanTestStatic .TensorRTParam (
123+ 1 << 30 , 32 , 1 , AnalysisConfig .Precision .Float32 , False , False )
124+ self .fetch_list = [out ]
125+
126+ def test_check_output (self ):
127+ if core .is_compiled_with_cuda ():
128+ use_gpu = True
129+ self .check_output_with_option (use_gpu , flatten = True )
130+ self .assertTrue (
131+ PassVersionChecker .IsCompatible ('tensorrt_subgraph_pass' ))
132+
133+
134+ class TRTReduceMeanStaticAllTest (InferencePassTest ):
135+ def setUp (self ):
136+ with fluid .program_guard (self .main_program , self .startup_program ):
137+ data = fluid .data (
138+ name = "data" , shape = [4 , 3 , 224 , 224 ], dtype = "float32" )
139+ reduce_mean = fluid .layers .reduce_mean (data , keep_dim = True )
140+ out = fluid .layers .batch_norm (reduce_mean , is_test = True )
141+
142+ self .feeds = {
143+ "data" : np .random .random ([4 , 3 , 224 , 224 ]).astype ("float32" ),
144+ }
145+ self .enable_trt = True
146+ self .trt_parameters = TRTReduceMeanStaticAllTest .TensorRTParam (
147+ 1 << 30 , 32 , 1 , AnalysisConfig .Precision .Float32 , False , False )
148+ self .fetch_list = [out ]
149+
150+ def test_check_output (self ):
151+ if core .is_compiled_with_cuda ():
152+ use_gpu = True
153+ self .check_output_with_option (use_gpu , flatten = True )
154+ self .assertTrue (
155+ PassVersionChecker .IsCompatible ('tensorrt_subgraph_pass' ))
156+
157+
158+ class TRTReduceMeanStaticFP16 (InferencePassTest ):
159+ def setUp (self ):
160+ with fluid .program_guard (self .main_program , self .startup_program ):
161+ data = fluid .data (
162+ name = "data" , shape = [4 , 3 , 224 , 224 ], dtype = "float32" )
163+ reduce_mean = fluid .layers .reduce_mean (data , keep_dim = True )
164+ out = fluid .layers .batch_norm (reduce_mean , is_test = True )
165+
166+ self .feeds = {
167+ "data" : np .random .random ([4 , 3 , 224 , 224 ]).astype ("float32" ),
168+ }
169+ self .enable_trt = True
170+ self .trt_parameters = TRTReduceMeanStaticFP16 .TensorRTParam (
171+ 1 << 30 , 32 , 1 , AnalysisConfig .Precision .Half , False , False )
172+ self .fetch_list = [out ]
173+
174+ def test_check_output (self ):
175+ if core .is_compiled_with_cuda ():
176+ use_gpu = True
177+ self .check_output_with_option (use_gpu , flatten = True )
178+ self .assertTrue (
179+ PassVersionChecker .IsCompatible ('tensorrt_subgraph_pass' ))
180+
181+
182+ class TRTReduceMeanFP16Static (InferencePassTest ):
183+ def setUp (self ):
184+ with fluid .program_guard (self .main_program , self .startup_program ):
185+ data = fluid .data (
186+ name = "data" , shape = [4 , 3 , 224 , 224 ], dtype = "float32" )
187+ reduce_mean = fluid .layers .reduce_mean (data , keep_dim = True )
188+ out = fluid .layers .batch_norm (reduce_mean , is_test = True )
189+
190+ self .feeds = {
191+ "data" : np .random .random ([4 , 3 , 224 , 224 ]).astype ("float32" ),
192+ }
193+ self .enable_trt = True
194+ self .trt_parameters = TRTReduceMeanFP16Static .TensorRTParam (
195+ 1 << 30 , 32 , 1 , AnalysisConfig .Precision .Half , True , False )
196+ self .fetch_list = [out ]
197+
198+ def test_check_output (self ):
199+ if core .is_compiled_with_cuda ():
200+ use_gpu = True
201+ self .check_output_with_option (use_gpu , flatten = True )
202+ self .assertTrue (
203+ PassVersionChecker .IsCompatible ('tensorrt_subgraph_pass' ))
204+
205+
81206if __name__ == "__main__" :
82207 unittest .main ()
0 commit comments