Skip to content

Commit 316357b

Browse files
authored
Merge pull request #118 from pyiron/raise_exceptions
Raise exception from threads
2 parents 3066b4f + ca312e4 commit 316357b

File tree

4 files changed

+44
-3
lines changed

4 files changed

+44
-3
lines changed

pympipool/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
)
88
from pympipool.external_interfaces.executor import Executor, PoolExecutor
99
from pympipool.external_interfaces.pool import Pool, MPISpawnPool
10+
from pympipool.external_interfaces.thread import RaisingThread
1011
from pympipool.shared_functions.external_interfaces import cancel_items_in_queue
1112

1213
from ._version import get_versions

pympipool/external_interfaces/executor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from abc import ABC
22
from concurrent.futures import Executor as FutureExecutor, Future
33
from queue import Queue
4-
from threading import Thread
54

5+
from pympipool.external_interfaces.thread import RaisingThread
66
from pympipool.shared_functions.external_interfaces import (
77
execute_parallel_tasks,
88
execute_serial_tasks,
@@ -110,7 +110,7 @@ def __init__(
110110
queue_adapter_kwargs=None,
111111
):
112112
super().__init__()
113-
self._process = Thread(
113+
self._process = RaisingThread(
114114
target=execute_parallel_tasks,
115115
kwargs={
116116
"future_queue": self._future_queue,
@@ -177,7 +177,7 @@ def __init__(
177177
queue_adapter_kwargs=None,
178178
):
179179
super().__init__()
180-
self._process = Thread(
180+
self._process = RaisingThread(
181181
target=execute_serial_tasks,
182182
kwargs={
183183
"future_queue": self._future_queue,
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from threading import Thread
2+
3+
4+
class RaisingThread(Thread):
5+
"""
6+
Based on https://stackoverflow.com/questions/2829329/catch-a-threads-exception-in-the-caller-thread
7+
"""
8+
9+
def __init__(
10+
self, group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None
11+
):
12+
super().__init__(
13+
group=group,
14+
target=target,
15+
name=name,
16+
args=args,
17+
kwargs=kwargs,
18+
daemon=daemon,
19+
)
20+
self._exception = None
21+
22+
def run(self):
23+
try:
24+
super().run()
25+
except Exception as e:
26+
self._exception = e
27+
28+
def join(self, timeout=None):
29+
super().join(timeout=timeout)
30+
if self._exception:
31+
raise self._exception

tests/test_worker.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ def mpi_funct(i):
2424
return i, size, rank
2525

2626

27+
def raise_error():
28+
raise RuntimeError
29+
30+
2731
class TestFuturePool(unittest.TestCase):
2832
def test_pool_serial(self):
2933
with Executor(cores=1) as p:
@@ -53,6 +57,11 @@ def test_pool_serial_map(self):
5357
output = p.map(calc, [1, 2, 3])
5458
self.assertEqual(list(output), [np.array(1), np.array(4), np.array(9)])
5559

60+
def test_executor_exception(self):
61+
with self.assertRaises(RuntimeError):
62+
with Executor(cores=1) as p:
63+
p.submit(raise_error)
64+
5665
def test_pool_multi_core(self):
5766
with Executor(cores=2) as p:
5867
output = p.submit(mpi_funct, i=2)

0 commit comments

Comments
 (0)