Commit 45a5ee0

Creating a process Pool subclass to track PID changes
1 parent 082d3df commit 45a5ee0

4 files changed (+34, -17 lines)
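
This commit replaces the module-level _get_pool_pid helper with a ProcessPool subclass of multiprocess.pool.Pool that records the PIDs of its worker processes when it is created. iflatmap_unordered can then ask the pool whether any worker has died and been restarted (the pool silently replaces dead workers), and raise an error instead of waiting forever on results that will never arrive.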

src/datasets/arrow_dataset.py (3 additions, 4 deletions)

@@ -58,7 +58,6 @@
 import pyarrow as pa
 import pyarrow.compute as pc
 from huggingface_hub import HfApi, HfFolder
-from multiprocess import Pool
 from requests import HTTPError
 
 from . import config
@@ -113,7 +112,7 @@
 from .utils.hub import hf_hub_url
 from .utils.info_utils import is_small_dataset
 from .utils.metadata import DatasetMetadata
-from .utils.py_utils import Literal, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
+from .utils.py_utils import Literal, ProcessPool, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
 from .utils.stratify import stratified_shuffle_split_generate_indices
 from .utils.tf_utils import dataset_to_tf, minimal_tf_collate_fn, multiprocess_dataset_to_tf
 from .utils.typing import ListLike, PathLike
@@ -1505,7 +1504,7 @@ def save_to_disk(
         shard_lengths = [None] * num_shards
         shard_sizes = [None] * num_shards
         if num_proc > 1:
-            with Pool(num_proc) as pool:
+            with ProcessPool(num_proc) as pool:
                 with pbar:
                     for job_id, done, content in iflatmap_unordered(
                         pool, Dataset._save_to_disk_single, kwargs_iterable=kwargs_per_job
@@ -3167,7 +3166,7 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
                 logger.info(
                     f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache."
                 )
-                with Pool(len(kwargs_per_job)) as pool:
+                with ProcessPool(len(kwargs_per_job)) as pool:
                     os.environ = prev_env
                     logger.info(f"Spawning {num_proc} processes")
                     with logging.tqdm(
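
Both call sites in arrow_dataset.py (the sharded save_to_disk path and the multiprocessed map path) now build a ProcessPool, so the iflatmap_unordered loops they drive can detect a restarted worker.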

src/datasets/builder.py (3 additions, 3 deletions)

@@ -33,7 +33,6 @@
 
 import fsspec
 import pyarrow as pa
-from multiprocess import Pool
 from tqdm.contrib.concurrent import thread_map
 
 from . import config, utils
@@ -69,6 +68,7 @@
 from .utils.filelock import FileLock
 from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits
 from .utils.py_utils import (
+    ProcessPool,
     classproperty,
     convert_file_size_to_int,
     has_sufficient_disk_space,
@@ -1532,7 +1532,7 @@ def _prepare_split(
         shards_per_job = [None] * num_jobs
         shard_lengths_per_job = [None] * num_jobs
 
-        with Pool(num_proc) as pool:
+        with ProcessPool(num_proc) as pool:
             with pbar:
                 for job_id, done, content in iflatmap_unordered(
                     pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
@@ -1791,7 +1791,7 @@ def _prepare_split(
         shards_per_job = [None] * num_jobs
         shard_lengths_per_job = [None] * num_jobs
 
-        with Pool(num_proc) as pool:
+        with ProcessPool(num_proc) as pool:
            with pbar:
                 for job_id, done, content in iflatmap_unordered(
                     pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
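
The two near-identical hunks appear to cover the _prepare_split implementations of the two builder flavors (generator-based and Arrow-based), which share the same worker loop.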

src/datasets/utils/py_utils.py (26 additions, 8 deletions)

@@ -1330,39 +1330,57 @@ def _write_generator_to_queue(queue: queue.Queue, func: Callable[..., Iterable[Y
     return i
 
 
-def _get_pool_pid(pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool]) -> Set[int]:
-    return {f.pid for f in pool._pool}
+class ProcessPool(multiprocess.pool.Pool):
+    """
+    A multiprocess.pool.Pool implementation that keeps track of the child processes' PIDs
+    and can detect whether a child process has been restarted.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self._last_pool_pids = self._get_current_pool_pids()
+        self._has_restarted_subprocess = False
+
+    def _get_current_pool_pids(self) -> Set[int]:
+        return {f.pid for f in self._pool}
+
+    def has_restarted_subprocess(self) -> bool:
+        if self._has_restarted_subprocess:
+            # If the pool ever restarted a subprocess,
+            # we don't check the PIDs again.
+            return True
+        current_pids = self._get_current_pool_pids()
+        self._has_restarted_subprocess = current_pids != self._last_pool_pids
+        self._last_pool_pids = current_pids
+        return self._has_restarted_subprocess
 
 
 def iflatmap_unordered(
-    pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool],
+    pool: ProcessPool,
     func: Callable[..., Iterable[Y]],
     *,
     kwargs_iterable: Iterable[dict],
 ) -> Iterable[Y]:
-    initial_pool_pid = _get_pool_pid(pool)
     manager_cls = Manager if isinstance(pool, multiprocessing.pool.Pool) else multiprocess.Manager
     with manager_cls() as manager:
        queue = manager.Queue()
        async_results = [
            pool.apply_async(_write_generator_to_queue, (queue, func, kwargs)) for kwargs in kwargs_iterable
        ]
-        subproc_killed = False
        try:
            while True:
                try:
                    yield queue.get(timeout=0.05)
                except Empty:
                    if all(async_result.ready() for async_result in async_results) and queue.empty():
                        break
-                    if _get_pool_pid(pool) != initial_pool_pid:
-                        subproc_killed = True
+                    if pool.has_restarted_subprocess():
                        # One of the subprocesses has died. We should not wait forever.
                        raise RuntimeError(
                            "One of the subprocesses has abruptly died during map operation."
                            "To debug the error, disable multiprocessing."
                        )
        finally:
-            if not subproc_killed:
+            if not pool.has_restarted_subprocess():
                # we get the result in case there's an error to raise
                [async_result.get(timeout=0.05) for async_result in async_results]
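
To illustrate the new API, here is a minimal usage sketch (not part of the commit). It assumes a POSIX system, uses the private _get_current_pool_pids helper for brevity, and gives the pool's maintenance thread a second to spawn the replacement worker; multiprocess.pool.Pool silently replaces dead workers, which is exactly the PID change ProcessPool detects.

import os
import signal
import time

from datasets.utils.py_utils import ProcessPool

if __name__ == "__main__":
    with ProcessPool(2) as pool:
        # A fresh pool: the current worker PIDs match the ones recorded in __init__.
        assert not pool.has_restarted_subprocess()

        # Kill one worker (POSIX-only). The pool's maintenance thread notices the
        # dead process and spawns a replacement with a new PID.
        os.kill(next(iter(pool._get_current_pool_pids())), signal.SIGKILL)
        time.sleep(1)  # give the maintenance thread time to repopulate the pool

        # The set of worker PIDs changed, so the check now returns True.
        assert pool.has_restarted_subprocess()

In iflatmap_unordered the same check runs every time the result queue comes up empty, so a job whose worker was killed now raises a RuntimeError instead of polling forever.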

tests/test_arrow_dataset.py (2 additions, 2 deletions)

@@ -1343,9 +1343,9 @@ def test_map_caching(self, in_memory):
         with self._caplog.at_level(WARNING):
             with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                 with patch(
-                    "datasets.arrow_dataset.Pool",
+                    "datasets.arrow_dataset.ProcessPool",
                     new_callable=PickableMagicMock,
-                    side_effect=datasets.arrow_dataset.Pool,
+                    side_effect=datasets.arrow_dataset.ProcessPool,
                ) as mock_pool:
                    with dset.map(lambda x: {"foo": "bar"}, num_proc=2) as dset_test1:
                        dset_test1_data_files = list(dset_test1.cache_files)
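
Since arrow_dataset now imports ProcessPool rather than Pool, the caching test patches the new name; otherwise the mock would never intercept the pool created by map(..., num_proc=2).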
