|
1 | | -from concurrent.futures import Future, as_completed |
2 | | -from concurrent.futures.thread import ThreadPoolExecutor |
3 | | -from typing import Optional, List |
| 1 | +import multiprocessing.pool |
| 2 | +from multiprocessing.pool import ThreadPool |
| 3 | +from typing import List |
4 | 4 |
|
5 | 5 |
|
class SafeThreadPoolExecutor:
    """
    A thread pool executor that collects all AsyncResult objects and waits for their completion.

    Example Usage:

        with SafeThreadPoolExecutor(max_workers=len(duthosts)) as executor:
            for duthost in duthosts:
                executor.submit(example_func, duthost, localhost)

    Behavior Summary:
    1. On instantiation, starts `max_workers` threads via ThreadPool.
    2. Each thread runs the submitted function (e.g., `example_func(arg1, arg2)`) in parallel.
    3. When the `with` block scope ends, execution moves to `__exit__`, where it blocks on each
       `AsyncResult.get()` in turn to wait for all tasks to finish.
    4. If all threads succeed without raising, the pool is shut down cleanly.
    5. If any thread raises an exception, `.get()` re-raises that exception (with its original
       type) in the main thread. Non-`Exception` errors such as `SystemExit` are converted to
       `RuntimeError` so ThreadPool's internal `except Exception` handling can capture them.
    """

    def __init__(self, max_workers, *args, **kwargs):
        """
        Create a ThreadPool with `max_workers` threads and initialize an empty list to collect results.

        Args:
            max_workers: number of worker threads (maps to ThreadPool's `processes` parameter).
            *args, **kwargs: ignored (only here to match ThreadPoolExecutor signature).
        """
        self._pool = ThreadPool(processes=max_workers)
        self._results: List["multiprocessing.pool.ApplyResult"] = []

    def submit(self, fn, *args, **kwargs):
        """
        Schedule fn(*args, **kwargs) to run in a worker thread.

        Returns an ApplyResult object whose .get() will return the result or re-raise any
        exception from the worker with its original type.
        """
        def _wrapper(*fn_args, **fn_kwargs):
            try:
                return fn(*fn_args, **fn_kwargs)
            except Exception:
                # Ordinary exceptions are already captured by ThreadPool's worker loop
                # and re-raised by .get() with their original type intact — do NOT wrap
                # them, or callers lose the real exception class.
                raise
            except BaseException as be:
                # Non-Exception BaseExceptions (SystemExit, KeyboardInterrupt, ...) would
                # escape ThreadPool's "except Exception" handler and silently kill the
                # worker thread; convert them to RuntimeError so they surface via .get().
                # Chain the original so its traceback is preserved.
                raise RuntimeError("Thread worker aborted: " + repr(be)) from be

        async_res = self._pool.apply_async(_wrapper, args, kwargs)
        self._results.append(async_res)
        return async_res

    def shutdown(self, wait=True):
        """
        Stop accepting new tasks and optionally wait for running ones to finish.
        """
        # Prevent new tasks
        self._pool.close()
        if wait:
            # Wait for all tasks to finish
            self._pool.join()

    def __enter__(self):
        """
        Support the "with" statement.
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Wait for each submitted task to complete and surface exceptions.
        """
        try:
            for async_res in self._results:
                # .get() blocks until the task finishes and re-raises any exception
                # in the main thread.
                async_res.get()
        finally:
            # Always close + join, even when a task raised — otherwise the first
            # failing .get() would propagate past shutdown() and leak the pool's
            # worker threads.
            self.shutdown(wait=True)
        # Returning False ensures that any exception in the "with" statement is not suppressed.
        return False
0 commit comments