2222import multiprocessing
2323import numpy as np
2424
25+ import paddle
2526import paddle .fluid as fluid
2627from paddle .io import Dataset , BatchSampler , DataLoader
2728
29+ paddle .enable_static ()
30+
2831EPOCH_NUM = 3
2932BATCH_SIZE = 8
3033IMAGE_SIZE = 32
@@ -84,24 +87,31 @@ def simple_fc_net_static():
8487 return startup_prog , main_prog , image , label , loss
8588
8689
87- def prepare_places (with_data_parallel , with_cpu = False , with_gpu = True ):
90+ def prepare_places (with_data_parallel ,
91+ with_cpu = False ,
92+ with_gpu = False ,
93+ with_npu = False ):
8894 places = []
8995 if with_cpu :
9096 places .append ([fluid .CPUPlace ()])
9197 if with_data_parallel :
9298 places .append ([fluid .CPUPlace ()] * 2 )
9399
94- if with_gpu and fluid .core .is_compiled_with_cuda ():
100+ elif with_gpu and fluid .core .is_compiled_with_cuda ():
95101 tmp = fluid .cuda_places ()[:2 ]
96102 assert len (tmp ) > 0 , "no gpu detected"
97103 if with_data_parallel :
98104 places .append (tmp )
99105 places .append ([tmp [0 ]])
106+
107+ elif with_npu and paddle .is_compiled_with_npu ():
108+ places .append ([paddle .NPUPlace (0 )])
109+
100110 return places
101111
102112
103113class TestStaticDataLoader (unittest .TestCase ):
104- def run_main (self , num_workers , places ):
114+ def run_main (self , num_workers , places , use_pe = True ):
105115 scope = fluid .Scope ()
106116 with fluid .scope_guard (scope ):
107117 startup_prog , main_prog , image , label , loss = simple_fc_net_static ()
@@ -120,10 +130,13 @@ def run_main(self, num_workers, places):
120130 exe = fluid .Executor (place = places [0 ])
121131 exe .run (startup_prog )
122132
123- prog = fluid .CompiledProgram (main_prog )
124- if len (places ) > 1 :
125- prog = prog .with_data_parallel (
126- loss_name = loss .name , places = places )
133+ if use_pe :
134+ prog = fluid .CompiledProgram (main_prog )
135+ if len (places ) > 1 :
136+ prog = prog .with_data_parallel (
137+ loss_name = loss .name , places = places )
138+ else :
139+ prog = main_prog
127140
128141 step_list = []
129142 loss_list = []
@@ -157,19 +170,30 @@ def run_main(self, num_workers, places):
157170 print ("time cost" , ret ['time' ], 'step_list' , ret ['step' ])
158171 return ret
159172
160- def test_main (self ):
161- for p in prepare_places (True ):
162- results = []
173+ def _check_with_place (self , with_cpu = False , with_gpu = False , with_npu = False ):
174+ results = []
175+ for place in prepare_places (
176+ with_data_parallel = True ,
177+ with_cpu = with_cpu ,
178+ with_gpu = with_gpu ,
179+ with_npu = with_npu ):
163180 for num_workers in [0 , 2 ]:
164- print (self .__class__ .__name__ , p , num_workers )
181+ print (self .__class__ .__name__ , place , num_workers )
165182 sys .stdout .flush ()
166- ret = self .run_main (num_workers = num_workers , places = p )
183+ ret = self .run_main (
184+ num_workers = num_workers , places = place , use_pe = not with_npu )
185+
167186 results .append (ret )
168187 diff = np .max (
169188 np .abs (results [0 ]['loss' ] - results [1 ]['loss' ]) /
170189 np .abs (results [0 ]['loss' ]))
171190 self .assertLess (diff , 1e-2 )
172191
192+ def test_main (self ):
193+ self ._check_with_place (with_cpu = True )
194+ self ._check_with_place (with_gpu = True )
195+ self ._check_with_place (with_npu = True )
196+
173197
174198class TestStaticDataLoaderReturnList (unittest .TestCase ):
175199 def test_single_place (self ):
0 commit comments