Skip to content

Commit e6a19f3

Browse files
committed
Remove ProcessPool subclass
1 parent 285603d commit e6a19f3

File tree

4 files changed

+10
-33
lines changed

4 files changed

+10
-33
lines changed

src/datasets/arrow_dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import pyarrow as pa
5959
import pyarrow.compute as pc
6060
from huggingface_hub import HfApi, HfFolder
61+
from multiprocess import Pool
6162
from requests import HTTPError
6263

6364
from . import config
@@ -112,7 +113,7 @@
112113
from .utils.hub import hf_hub_url
113114
from .utils.info_utils import is_small_dataset
114115
from .utils.metadata import DatasetMetadata
115-
from .utils.py_utils import Literal, ProcessPool, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
116+
from .utils.py_utils import Literal, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
116117
from .utils.stratify import stratified_shuffle_split_generate_indices
117118
from .utils.tf_utils import dataset_to_tf, minimal_tf_collate_fn, multiprocess_dataset_to_tf
118119
from .utils.typing import ListLike, PathLike
@@ -1504,7 +1505,7 @@ def save_to_disk(
15041505
shard_lengths = [None] * num_shards
15051506
shard_sizes = [None] * num_shards
15061507
if num_proc > 1:
1507-
with ProcessPool(num_proc) as pool:
1508+
with Pool(num_proc) as pool:
15081509
with pbar:
15091510
for job_id, done, content in iflatmap_unordered(
15101511
pool, Dataset._save_to_disk_single, kwargs_iterable=kwargs_per_job
@@ -3166,7 +3167,7 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
31663167
logger.info(
31673168
f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache."
31683169
)
3169-
with ProcessPool(len(kwargs_per_job)) as pool:
3170+
with Pool(len(kwargs_per_job)) as pool:
31703171
os.environ = prev_env
31713172
logger.info(f"Spawning {num_proc} processes")
31723173
with logging.tqdm(

src/datasets/builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
import fsspec
3535
import pyarrow as pa
36+
from multiprocess import Pool
3637
from tqdm.contrib.concurrent import thread_map
3738

3839
from . import config, utils
@@ -68,7 +69,6 @@
6869
from .utils.filelock import FileLock
6970
from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits
7071
from .utils.py_utils import (
71-
ProcessPool,
7272
classproperty,
7373
convert_file_size_to_int,
7474
has_sufficient_disk_space,
@@ -1543,7 +1543,7 @@ def _prepare_split(
15431543
shards_per_job = [None] * num_jobs
15441544
shard_lengths_per_job = [None] * num_jobs
15451545

1546-
with ProcessPool(num_proc) as pool:
1546+
with Pool(num_proc) as pool:
15471547
with pbar:
15481548
for job_id, done, content in iflatmap_unordered(
15491549
pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job
@@ -1802,7 +1802,7 @@ def _prepare_split(
18021802
shards_per_job = [None] * num_jobs
18031803
shard_lengths_per_job = [None] * num_jobs
18041804

1805-
with ProcessPool(num_proc) as pool:
1805+
with Pool(num_proc) as pool:
18061806
with pbar:
18071807
for job_id, done, content in iflatmap_unordered(
18081808
pool, self._prepare_split_single, kwargs_iterable=kwargs_per_job

src/datasets/utils/py_utils.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,33 +1330,9 @@ def _write_generator_to_queue(queue: queue.Queue, func: Callable[..., Iterable[Y
13301330
return i
13311331

13321332

1333-
class ProcessPool(multiprocess.pool.Pool):
1334-
"""
1335-
A multiprocess.pool.Pool implementation that keeps track of child processes' PIDs,
1336-
and can detect if a child process has been restarted.
1337-
"""
1338-
1339-
def __init__(self, *args, **kwargs) -> None:
1340-
super().__init__(*args, **kwargs)
1341-
self._last_pool_pids = self._get_current_pool_pids()
1342-
self._has_restarted_subprocess = False
1343-
1344-
def _get_current_pool_pids(self) -> Set[int]:
1345-
return {f.pid for f in self._pool}
1346-
1347-
def has_restarted_subprocess(self) -> bool:
1348-
if self._has_restarted_subprocess:
1349-
# If the pool ever restarted a subprocess,
1350-
# we don't check the PIDs again.
1351-
return True
1352-
current_pids = self._get_current_pool_pids()
1353-
self._has_restarted_subprocess = current_pids != self._last_pool_pids
1354-
self._last_pool_pids = current_pids
1355-
return self._has_restarted_subprocess
1356-
13571333

13581334
def iflatmap_unordered(
1359-
pool: ProcessPool,
1335+
pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool],
13601336
func: Callable[..., Iterable[Y]],
13611337
*,
13621338
kwargs_iterable: Iterable[dict],

tests/test_arrow_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,9 +1352,9 @@ def test_map_caching(self, in_memory):
13521352
with self._caplog.at_level(WARNING):
13531353
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
13541354
with patch(
1355-
"datasets.arrow_dataset.ProcessPool",
1355+
"datasets.arrow_dataset.Pool",
13561356
new_callable=PickableMagicMock,
1357-
side_effect=datasets.arrow_dataset.ProcessPool,
1357+
side_effect=datasets.arrow_dataset.Pool,
13581358
) as mock_pool:
13591359
with dset.map(lambda x: {"foo": "bar"}, num_proc=2) as dset_test1:
13601360
dset_test1_data_files = list(dset_test1.cache_files)

0 commit comments

Comments
 (0)