diff --git a/pympipool/__init__.py b/pympipool/__init__.py index be5a4750..fa9a3568 100644 --- a/pympipool/__init__.py +++ b/pympipool/__init__.py @@ -7,6 +7,7 @@ ) from pympipool.external_interfaces.executor import Executor, PoolExecutor from pympipool.external_interfaces.pool import Pool, MPISpawnPool +from pympipool.external_interfaces.thread import RaisingThread from pympipool.shared_functions.external_interfaces import cancel_items_in_queue from ._version import get_versions diff --git a/pympipool/external_interfaces/executor.py b/pympipool/external_interfaces/executor.py index 91639211..615065f8 100644 --- a/pympipool/external_interfaces/executor.py +++ b/pympipool/external_interfaces/executor.py @@ -1,8 +1,8 @@ from abc import ABC from concurrent.futures import Executor as FutureExecutor, Future from queue import Queue -from threading import Thread +from pympipool.external_interfaces.thread import RaisingThread from pympipool.shared_functions.external_interfaces import ( execute_parallel_tasks, execute_serial_tasks, @@ -110,7 +110,7 @@ def __init__( queue_adapter_kwargs=None, ): super().__init__() - self._process = Thread( + self._process = RaisingThread( target=execute_parallel_tasks, kwargs={ "future_queue": self._future_queue, @@ -177,7 +177,7 @@ def __init__( queue_adapter_kwargs=None, ): super().__init__() - self._process = Thread( + self._process = RaisingThread( target=execute_serial_tasks, kwargs={ "future_queue": self._future_queue, diff --git a/pympipool/external_interfaces/thread.py b/pympipool/external_interfaces/thread.py new file mode 100644 index 00000000..9f2c2c79 --- /dev/null +++ b/pympipool/external_interfaces/thread.py @@ -0,0 +1,31 @@ +from threading import Thread + + +class RaisingThread(Thread): + """ + Based on https://stackoverflow.com/questions/2829329/catch-a-threads-exception-in-the-caller-thread + """ + + def __init__( + self, group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None + ): + super().__init__( + group=group, + target=target, + name=name, + args=args, + kwargs=kwargs, + daemon=daemon, + ) + self._exception = None + + def run(self): + try: + super().run() + except Exception as e: + self._exception = e + + def join(self, timeout=None): + super().join(timeout=timeout) + if self._exception: + raise self._exception diff --git a/tests/test_worker.py b/tests/test_worker.py index 17afcced..ee6df90a 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -24,6 +24,10 @@ def mpi_funct(i): return i, size, rank +def raise_error(): + raise RuntimeError + + class TestFuturePool(unittest.TestCase): def test_pool_serial(self): with Executor(cores=1) as p: @@ -53,6 +57,11 @@ def test_pool_serial_map(self): output = p.map(calc, [1, 2, 3]) self.assertEqual(list(output), [np.array(1), np.array(4), np.array(9)]) + def test_executor_exception(self): + with self.assertRaises(RuntimeError): + with Executor(cores=1) as p: + p.submit(raise_error) + def test_pool_multi_core(self): with Executor(cores=2) as p: output = p.submit(mpi_funct, i=2)