Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/datasets/utils/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from queue import Empty
from shutil import disk_usage
from types import CodeType, FunctionType
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
from urllib.parse import urlparse

import dill
Expand Down Expand Up @@ -1330,12 +1330,18 @@ def _write_generator_to_queue(queue: queue.Queue, func: Callable[..., Iterable[Y
return i


def _get_pool_pid(pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool]) -> Set[int]:
return {f.pid for f in pool._pool}


def iflatmap_unordered(
pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool],
func: Callable[..., Iterable[Y]],
*,
kwargs_iterable: Iterable[dict],
) -> Iterable[Y]:
initial_pool_pid = _get_pool_pid(pool)
pool_changed = False
manager_cls = Manager if isinstance(pool, multiprocessing.pool.Pool) else multiprocess.Manager
with manager_cls() as manager:
queue = manager.Queue()
Expand All @@ -1349,6 +1355,14 @@ def iflatmap_unordered(
except Empty:
if all(async_result.ready() for async_result in async_results) and queue.empty():
break
if _get_pool_pid(pool) != initial_pool_pid:
pool_changed = True
# One of the subprocesses has died. We should not wait forever.
raise RuntimeError(
"One of the subprocesses has abruptly died during map operation."
"To debug the error, disable multiprocessing."
)
finally:
# we get the result in case there's an error to raise
[async_result.get(timeout=0.05) for async_result in async_results]
if not pool_changed:
# we get the result in case there's an error to raise
[async_result.get(timeout=0.05) for async_result in async_results]
20 changes: 20 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,6 +1632,26 @@ def __call__(self, example):
dset.map(ex_cnt)
self.assertEqual(ex_cnt.cnt, len(dset))

@require_not_windows
def test_map_crash_subprocess(self, in_memory):
# be sure that a crash in one of the subprocess will not
# hang dataset.map() call forever

def do_crash(row):
import os

os.kill(os.getpid(), 9)
return row

with tempfile.TemporaryDirectory() as tmp_dir:
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
with pytest.raises(RuntimeError) as excinfo:
dset.map(do_crash, num_proc=2)
assert str(excinfo.value) == (
"One of the subprocesses has abruptly died during map operation."
"To debug the error, disable multiprocessing."
)

def test_filter(self, in_memory):
# keep only first five examples

Expand Down