Skip to content

Commit aca4cdc

Browse files
pappacenalhoestq
andauthored
Avoid stuck map operation when subprocesses crashes (#5976)
* Adding simple test * Checking killed subprocs * Code format * Check pool of async results * Style * Using pool._pool instead of async_result pool * Avoid running on Windows a Linux specific test implementation Co-authored-by: Quentin Lhoest <[email protected]> --------- Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 5d15950 commit aca4cdc

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

src/datasets/utils/py_utils.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from queue import Empty
3333
from shutil import disk_usage
3434
from types import CodeType, FunctionType
35-
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
35+
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
3636
from urllib.parse import urlparse
3737

3838
import dill
@@ -1330,12 +1330,18 @@ def _write_generator_to_queue(queue: queue.Queue, func: Callable[..., Iterable[Y
13301330
return i
13311331

13321332

1333+
def _get_pool_pid(pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool]) -> Set[int]:
1334+
return {f.pid for f in pool._pool}
1335+
1336+
13331337
def iflatmap_unordered(
13341338
pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool],
13351339
func: Callable[..., Iterable[Y]],
13361340
*,
13371341
kwargs_iterable: Iterable[dict],
13381342
) -> Iterable[Y]:
1343+
initial_pool_pid = _get_pool_pid(pool)
1344+
pool_changed = False
13391345
manager_cls = Manager if isinstance(pool, multiprocessing.pool.Pool) else multiprocess.Manager
13401346
with manager_cls() as manager:
13411347
queue = manager.Queue()
@@ -1349,6 +1355,14 @@ def iflatmap_unordered(
13491355
except Empty:
13501356
if all(async_result.ready() for async_result in async_results) and queue.empty():
13511357
break
1358+
if _get_pool_pid(pool) != initial_pool_pid:
1359+
pool_changed = True
1360+
# One of the subprocesses has died. We should not wait forever.
1361+
raise RuntimeError(
1362+
"One of the subprocesses has abruptly died during map operation."
1363+
"To debug the error, disable multiprocessing."
1364+
)
13521365
finally:
1353-
# we get the result in case there's an error to raise
1354-
[async_result.get(timeout=0.05) for async_result in async_results]
1366+
if not pool_changed:
1367+
# we get the result in case there's an error to raise
1368+
[async_result.get(timeout=0.05) for async_result in async_results]

tests/test_arrow_dataset.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,6 +1632,26 @@ def __call__(self, example):
16321632
dset.map(ex_cnt)
16331633
self.assertEqual(ex_cnt.cnt, len(dset))
16341634

1635+
@require_not_windows
1636+
def test_map_crash_subprocess(self, in_memory):
1637+
# be sure that a crash in one of the subprocess will not
1638+
# hang dataset.map() call forever
1639+
1640+
def do_crash(row):
1641+
import os
1642+
1643+
os.kill(os.getpid(), 9)
1644+
return row
1645+
1646+
with tempfile.TemporaryDirectory() as tmp_dir:
1647+
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1648+
with pytest.raises(RuntimeError) as excinfo:
1649+
dset.map(do_crash, num_proc=2)
1650+
assert str(excinfo.value) == (
1651+
"One of the subprocesses has abruptly died during map operation."
1652+
"To debug the error, disable multiprocessing."
1653+
)
1654+
16351655
def test_filter(self, in_memory):
16361656
# keep only first five examples
16371657

0 commit comments

Comments
 (0)