@@ -59,8 +59,7 @@ def __iter__(self):
5959
6060class TestTensorDataset (unittest .TestCase ):
6161 def run_main (self , num_workers , places ):
62- paddle .static .default_startup_program ().random_seed = 1
63- paddle .static .default_main_program ().random_seed = 1
62+ paddle .seed (1 )
6463 place = paddle .CPUPlace ()
6564 with base .dygraph .guard (place ):
6665 input_np = np .random .random ([16 , 3 , 4 ]).astype ('float32' )
@@ -98,8 +97,7 @@ def test_main(self):
9897
9998class TestComposeDataset (unittest .TestCase ):
10099 def test_main (self ):
101- paddle .static .default_startup_program ().random_seed = 1
102- paddle .static .default_main_program ().random_seed = 1
100+ paddle .seed (1 )
103101
104102 dataset1 = RandomDataset (10 )
105103 dataset2 = RandomDataset (10 )
@@ -118,8 +116,7 @@ def test_main(self):
118116
119117class TestRandomSplitApi (unittest .TestCase ):
120118 def test_main (self ):
121- paddle .static .default_startup_program ().random_seed = 1
122- paddle .static .default_main_program ().random_seed = 1
119+ paddle .seed (1 )
123120
124121 dataset1 , dataset2 = paddle .io .random_split (range (5 ), [1 , 4 ])
125122
@@ -139,8 +136,7 @@ def test_main(self):
139136
140137class TestRandomSplitError (unittest .TestCase ):
141138 def test_errors (self ):
142- paddle .static .default_startup_program ().random_seed = 1
143- paddle .static .default_main_program ().random_seed = 1
139+ paddle .seed (1 )
144140
145141 self .assertRaises (ValueError , paddle .io .random_split , range (5 ), [3 , 8 ])
146142 self .assertRaises (ValueError , paddle .io .random_split , range (5 ), [8 ])
@@ -149,8 +145,7 @@ def test_errors(self):
149145
150146class TestSubsetDataset (unittest .TestCase ):
151147 def run_main (self , num_workers , places ):
152- paddle .static .default_startup_program ().random_seed = 1
153- paddle .static .default_main_program ().random_seed = 1
148+ paddle .seed (1 )
154149
155150 input_np = np .random .random ([5 , 3 , 4 ]).astype ('float32' )
156151 input = paddle .to_tensor (input_np )
@@ -201,8 +196,7 @@ def assert_basic(input, label):
201196 self .assertEqual (odd_list , elements_list )
202197
203198 def test_main (self ):
204- paddle .static .default_startup_program ().random_seed = 1
205- paddle .static .default_main_program ().random_seed = 1
199+ paddle .seed (1 )
206200
207201 places = [paddle .CPUPlace ()]
208202 if paddle .is_compiled_with_cuda ():
@@ -213,8 +207,7 @@ def test_main(self):
213207
214208class TestChainDataset (unittest .TestCase ):
215209 def run_main (self , num_workers , places ):
216- paddle .static .default_startup_program ().random_seed = 1
217- paddle .static .default_main_program ().random_seed = 1
210+ paddle .seed (1 )
218211
219212 dataset1 = RandomIterableDataset (10 )
220213 dataset2 = RandomIterableDataset (10 )
@@ -259,8 +252,7 @@ def __getitem__(self, idx):
259252
260253class TestNumpyMixTensorDataset (TestTensorDataset ):
261254 def run_main (self , num_workers , places ):
262- paddle .static .default_startup_program ().random_seed = 1
263- paddle .static .default_main_program ().random_seed = 1
255+ paddle .seed (1 )
264256 place = paddle .CPUPlace ()
265257 with base .dygraph .guard (place ):
266258 dataset = NumpyMixTensorDataset (16 )
@@ -304,8 +296,7 @@ def __getitem__(self, idx):
304296
305297class TestComplextDataset (unittest .TestCase ):
306298 def run_main (self , num_workers ):
307- paddle .static .default_startup_program ().random_seed = 1
308- paddle .static .default_main_program ().random_seed = 1
299+ paddle .seed (1 )
309300 place = paddle .CPUPlace ()
310301 with base .dygraph .guard (place ):
311302 dataset = ComplextDataset (16 )
@@ -360,8 +351,7 @@ def init_dataset(self):
360351 self .dataset = SingleFieldDataset (self .sample_num )
361352
362353 def run_main (self , num_workers ):
363- paddle .static .default_startup_program ().random_seed = 1
364- paddle .static .default_main_program ().random_seed = 1
354+ paddle .seed (1 )
365355 place = paddle .CPUPlace ()
366356 with base .dygraph .guard (place ):
367357 self .init_dataset ()
0 commit comments