Skip to content

Commit c7a3ee4

Browse files
authored
[Cleanup][C-14]Replace Program.random_seed (#61526)
1 parent 1291468 commit c7a3ee4

5 files changed

Lines changed: 16 additions & 32 deletions

test/legacy_test/test_multiprocess_dataloader_dataset.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@ def __iter__(self):
5959

6060
class TestTensorDataset(unittest.TestCase):
6161
def run_main(self, num_workers, places):
62-
paddle.static.default_startup_program().random_seed = 1
63-
paddle.static.default_main_program().random_seed = 1
62+
paddle.seed(1)
6463
place = paddle.CPUPlace()
6564
with base.dygraph.guard(place):
6665
input_np = np.random.random([16, 3, 4]).astype('float32')
@@ -98,8 +97,7 @@ def test_main(self):
9897

9998
class TestComposeDataset(unittest.TestCase):
10099
def test_main(self):
101-
paddle.static.default_startup_program().random_seed = 1
102-
paddle.static.default_main_program().random_seed = 1
100+
paddle.seed(1)
103101

104102
dataset1 = RandomDataset(10)
105103
dataset2 = RandomDataset(10)
@@ -118,8 +116,7 @@ def test_main(self):
118116

119117
class TestRandomSplitApi(unittest.TestCase):
120118
def test_main(self):
121-
paddle.static.default_startup_program().random_seed = 1
122-
paddle.static.default_main_program().random_seed = 1
119+
paddle.seed(1)
123120

124121
dataset1, dataset2 = paddle.io.random_split(range(5), [1, 4])
125122

@@ -139,8 +136,7 @@ def test_main(self):
139136

140137
class TestRandomSplitError(unittest.TestCase):
141138
def test_errors(self):
142-
paddle.static.default_startup_program().random_seed = 1
143-
paddle.static.default_main_program().random_seed = 1
139+
paddle.seed(1)
144140

145141
self.assertRaises(ValueError, paddle.io.random_split, range(5), [3, 8])
146142
self.assertRaises(ValueError, paddle.io.random_split, range(5), [8])
@@ -149,8 +145,7 @@ def test_errors(self):
149145

150146
class TestSubsetDataset(unittest.TestCase):
151147
def run_main(self, num_workers, places):
152-
paddle.static.default_startup_program().random_seed = 1
153-
paddle.static.default_main_program().random_seed = 1
148+
paddle.seed(1)
154149

155150
input_np = np.random.random([5, 3, 4]).astype('float32')
156151
input = paddle.to_tensor(input_np)
@@ -201,8 +196,7 @@ def assert_basic(input, label):
201196
self.assertEqual(odd_list, elements_list)
202197

203198
def test_main(self):
204-
paddle.static.default_startup_program().random_seed = 1
205-
paddle.static.default_main_program().random_seed = 1
199+
paddle.seed(1)
206200

207201
places = [paddle.CPUPlace()]
208202
if paddle.is_compiled_with_cuda():
@@ -213,8 +207,7 @@ def test_main(self):
213207

214208
class TestChainDataset(unittest.TestCase):
215209
def run_main(self, num_workers, places):
216-
paddle.static.default_startup_program().random_seed = 1
217-
paddle.static.default_main_program().random_seed = 1
210+
paddle.seed(1)
218211

219212
dataset1 = RandomIterableDataset(10)
220213
dataset2 = RandomIterableDataset(10)
@@ -259,8 +252,7 @@ def __getitem__(self, idx):
259252

260253
class TestNumpyMixTensorDataset(TestTensorDataset):
261254
def run_main(self, num_workers, places):
262-
paddle.static.default_startup_program().random_seed = 1
263-
paddle.static.default_main_program().random_seed = 1
255+
paddle.seed(1)
264256
place = paddle.CPUPlace()
265257
with base.dygraph.guard(place):
266258
dataset = NumpyMixTensorDataset(16)
@@ -304,8 +296,7 @@ def __getitem__(self, idx):
304296

305297
class TestComplextDataset(unittest.TestCase):
306298
def run_main(self, num_workers):
307-
paddle.static.default_startup_program().random_seed = 1
308-
paddle.static.default_main_program().random_seed = 1
299+
paddle.seed(1)
309300
place = paddle.CPUPlace()
310301
with base.dygraph.guard(place):
311302
dataset = ComplextDataset(16)
@@ -360,8 +351,7 @@ def init_dataset(self):
360351
self.dataset = SingleFieldDataset(self.sample_num)
361352

362353
def run_main(self, num_workers):
363-
paddle.static.default_startup_program().random_seed = 1
364-
paddle.static.default_main_program().random_seed = 1
354+
paddle.seed(1)
365355
place = paddle.CPUPlace()
366356
with base.dygraph.guard(place):
367357
self.init_dataset()

test/legacy_test/test_multiprocess_dataloader_dynamic.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,7 @@ def run_main(
9595
collate_fn,
9696
use_shared_memory,
9797
):
98-
base.default_startup_program().random_seed = 1
99-
base.default_main_program().random_seed = 1
98+
paddle.seed(1)
10099
with base.dygraph.guard(places[0]):
101100
fc_net = SimpleFCNet()
102101
optimizer = paddle.optimizer.Adam(parameters=fc_net.parameters())
@@ -176,8 +175,7 @@ def run_main(
176175
collate_fn,
177176
use_shared_memory,
178177
):
179-
base.default_startup_program().random_seed = 1
180-
base.default_main_program().random_seed = 1
178+
paddle.seed(1)
181179
with base.dygraph.guard(places[0]):
182180
fc_net = SimpleFCNet()
183181
optimizer = paddle.optimizer.Adam(parameters=fc_net.parameters())

test/legacy_test/test_multiprocess_dataloader_iterable_dataset_dynamic.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ def forward(self, image):
7777

7878
class TestDygraphDataLoader(unittest.TestCase):
7979
def run_main(self, num_workers, places, persistent_workers):
80-
base.default_startup_program().random_seed = 1
81-
base.default_main_program().random_seed = 1
80+
paddle.seed(1)
8281
with base.dygraph.guard(places[0]):
8382
fc_net = SimpleFCNet()
8483
optimizer = paddle.optimizer.Adam(parameters=fc_net.parameters())
@@ -146,8 +145,7 @@ def test_main(self):
146145

147146
class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
148147
def run_main(self, num_workers, places, persistent_workers):
149-
base.default_startup_program().random_seed = 1
150-
base.default_main_program().random_seed = 1
148+
paddle.seed(1)
151149
with base.dygraph.guard(places[0]):
152150
fc_net = SimpleFCNet()
153151
optimizer = paddle.optimizer.Adam(parameters=fc_net.parameters())

test/legacy_test/test_multiprocess_dataloader_iterable_dataset_static.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ def __iter__(self):
4747
def simple_fc_net_static():
4848
startup_prog = base.Program()
4949
main_prog = base.Program()
50-
startup_prog.random_seed = 1
51-
main_prog.random_seed = 1
50+
paddle.seed(1)
5251

5352
with base.unique_name.guard():
5453
with base.program_guard(main_prog, startup_prog):

test/legacy_test/test_multiprocess_dataloader_static.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ def __len__(self):
4747
def simple_fc_net_static():
4848
startup_prog = base.Program()
4949
main_prog = base.Program()
50-
startup_prog.random_seed = 1
51-
main_prog.random_seed = 1
50+
paddle.seed(1)
5251

5352
with base.unique_name.guard():
5453
with base.program_guard(main_prog, startup_prog):

0 commit comments

Comments
 (0)