Skip to content

Commit a7b13d3

Browse files
authored
Support test_imperative container_sequential and signal_handler with eager_guard (PaddlePaddle#38614)
1 parent 30be931 commit a7b13d3

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

python/paddle/fluid/tests/unittests/test_imperative_container_sequential.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717
import unittest
1818
import paddle.fluid as fluid
1919
import numpy as np
20+
from paddle.fluid.framework import _test_eager_guard
2021

2122

2223
class TestImperativeContainerSequential(unittest.TestCase):
23-
def test_sequential(self):
24+
def func_sequential(self):
2425
data = np.random.uniform(-1, 1, [5, 10]).astype('float32')
2526
with fluid.dygraph.guard():
2627
data = fluid.dygraph.to_variable(data)
@@ -55,7 +56,12 @@ def test_sequential(self):
5556
loss2 = fluid.layers.reduce_mean(res2)
5657
loss2.backward()
5758

58-
def test_sequential_list_params(self):
59+
def test_sequential(self):
60+
with _test_eager_guard():
61+
self.func_sequential()
62+
self.func_sequential()
63+
64+
def func_sequential_list_params(self):
5965
data = np.random.uniform(-1, 1, [5, 10]).astype('float32')
6066
with fluid.dygraph.guard():
6167
data = fluid.dygraph.to_variable(data)
@@ -90,6 +96,11 @@ def test_sequential_list_params(self):
9096
loss2 = fluid.layers.reduce_mean(res2)
9197
loss2.backward()
9298

99+
def test_sequential_list_params(self):
100+
with _test_eager_guard():
101+
self.func_sequential_list_params()
102+
self.func_sequential_list_params()
103+
93104

94105
if __name__ == '__main__':
95106
unittest.main()

python/paddle/fluid/tests/unittests/test_imperative_signal_handler.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import paddle.compat as cpt
2323
from paddle.fluid import core
24+
from paddle.fluid.framework import _test_eager_guard
2425

2526

2627
def set_child_signal_handler(self, child_pid):
@@ -37,8 +38,8 @@ def __handler__(signum, frame):
3738
signal.signal(signal.SIGCHLD, __handler__)
3839

3940

40-
class TestDygraphDataLoaderSingalHandler(unittest.TestCase):
41-
def test_child_process_exit_with_error(self):
41+
class DygraphDataLoaderSingalHandler(unittest.TestCase):
42+
def func_child_process_exit_with_error(self):
4243
def __test_process__():
4344
core._set_process_signal_handler()
4445
sys.exit(1)
@@ -65,7 +66,12 @@ def try_except_exit():
6566

6667
self.assertIsNotNone(exception)
6768

68-
def test_child_process_killed_by_sigsegv(self):
69+
def test_child_process_exit_with_error(self):
70+
with _test_eager_guard():
71+
self.func_child_process_exit_with_error()
72+
self.func_child_process_exit_with_error()
73+
74+
def func_child_process_killed_by_sigsegv(self):
6975
def __test_process__():
7076
core._set_process_signal_handler()
7177
os.kill(os.getpid(), signal.SIGSEGV)
@@ -93,7 +99,12 @@ def try_except_exit():
9399

94100
self.assertIsNotNone(exception)
95101

96-
def test_child_process_killed_by_sigbus(self):
102+
def test_child_process_killed_by_sigsegv(self):
103+
with _test_eager_guard():
104+
self.func_child_process_killed_by_sigsegv()
105+
self.func_child_process_killed_by_sigsegv()
106+
107+
def func_child_process_killed_by_sigbus(self):
97108
def __test_process__():
98109
core._set_process_signal_handler()
99110
os.kill(os.getpid(), signal.SIGBUS)
@@ -120,7 +131,12 @@ def try_except_exit():
120131

121132
self.assertIsNotNone(exception)
122133

123-
def test_child_process_killed_by_sigterm(self):
134+
def test_child_process_killed_by_sigbus(self):
135+
with _test_eager_guard():
136+
self.func_child_process_killed_by_sigbus()
137+
self.func_child_process_killed_by_sigbus()
138+
139+
def func_child_process_killed_by_sigterm(self):
124140
def __test_process__():
125141
core._set_process_signal_handler()
126142
time.sleep(10)
@@ -132,6 +148,11 @@ def __test_process__():
132148
set_child_signal_handler(id(self), test_process.pid)
133149
time.sleep(1)
134150

151+
def test_child_process_killed_by_sigterm(self):
152+
with _test_eager_guard():
153+
self.func_child_process_killed_by_sigterm()
154+
self.func_child_process_killed_by_sigterm()
155+
135156

136157
if __name__ == '__main__':
137158
unittest.main()

0 commit comments

Comments
 (0)