1919
2020from paddle .fluid .framework import IrGraph
2121from paddle .fluid .framework import IrNode
22+ from paddle .fluid .tests .unittests .op_test import OpTestTool
2223from paddle .fluid import core
2324import paddle .fluid .layers as layers
2425from paddle .fluid .framework import Program , program_guard , default_startup_program
@@ -66,11 +67,12 @@ def false_func():
6667 # be destructed and the sub_graphs will be empty.
6768 return graph , all_sub_graphs
6869
69- def test_quant_sub_graphs (self ):
70+ def test_quant_sub_graphs (self , use_cuda = False ):
7071 graph , sub_graphs = self .build_graph_with_sub_graph ()
72+ place = fluid .CUDAPlace (0 ) if use_cuda else fluid .CPUPlace ()
7173 transform_pass = QuantizationTransformPass (
7274 scope = fluid .global_scope (),
73- place = fluid . CUDAPlace ( 0 ) ,
75+ place = place ,
7476 activation_quantize_type = 'abs_max' ,
7577 weight_quantize_type = 'range_abs_max' )
7678 Find_inserted_quant_op = False
@@ -81,6 +83,14 @@ def test_quant_sub_graphs(self):
8183 Find_inserted_quant_op = True
8284 self .assertTrue (Find_inserted_quant_op )
8385
86+ def test_quant_sub_graphs_cpu (self ):
87+ self .test_quant_sub_graphs (use_cuda = False )
88+
89+ @OpTestTool .skip_if (not paddle .is_compiled_with_cuda (),
90+ "Not GPU version paddle" )
91+ def test_quant_sub_graphs_gpu (self ):
92+ self .test_quant_sub_graphs (use_cuda = True )
93+
8494
8595if __name__ == '__main__' :
8696 unittest .main ()
0 commit comments