Skip to content

Commit d26abad

Browse files
authored
Fix parallel downloads for datasets without scripts (#6551)
* fix multiprocessing download
* comment
1 parent e23a59e commit d26abad

File tree

3 files changed

+44
-9
lines changed

3 files changed

+44
-9
lines changed

src/datasets/download/download_manager.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,14 @@
3030
from .. import config
3131
from ..utils import tqdm as hf_tqdm
3232
from ..utils.deprecation_utils import DeprecatedEnum, deprecated
33-
from ..utils.file_utils import cached_path, get_from_cache, hash_url_to_filename, is_relative_path, url_or_path_join
33+
from ..utils.file_utils import (
34+
cached_path,
35+
get_from_cache,
36+
hash_url_to_filename,
37+
is_relative_path,
38+
stack_multiprocessing_download_progress_bars,
39+
url_or_path_join,
40+
)
3441
from ..utils.info_utils import get_size_checksum_dict
3542
from ..utils.logging import get_logger
3643
from ..utils.py_utils import NestedDataStructure, map_nested, size_str
@@ -423,13 +430,14 @@ def download(self, url_or_urls):
423430
download_func = partial(self._download, download_config=download_config)
424431

425432
start_time = datetime.now()
426-
downloaded_path_or_paths = map_nested(
427-
download_func,
428-
url_or_urls,
429-
map_tuple=True,
430-
num_proc=download_config.num_proc,
431-
desc="Downloading data files",
432-
)
433+
with stack_multiprocessing_download_progress_bars():
434+
downloaded_path_or_paths = map_nested(
435+
download_func,
436+
url_or_urls,
437+
map_tuple=True,
438+
num_proc=download_config.num_proc,
439+
desc="Downloading data files",
440+
)
433441
duration = datetime.now() - start_time
434442
logger.info(f"Downloading took {duration.total_seconds() // 60} min")
435443
url_or_urls = NestedDataStructure(url_or_urls)

src/datasets/utils/file_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import copy
88
import io
99
import json
10+
import multiprocessing
1011
import os
1112
import posixpath
1213
import re
@@ -19,6 +20,7 @@
1920
from functools import partial
2021
from pathlib import Path
2122
from typing import Optional, TypeVar, Union
23+
from unittest.mock import patch
2224
from urllib.parse import urljoin, urlparse
2325

2426
import fsspec
@@ -319,6 +321,12 @@ def fsspec_head(url, storage_options=None):
319321
return fs.info(paths[0])
320322

321323

324+
def stack_multiprocessing_download_progress_bars():
325+
# Stack downloads progress bars automatically using HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS=1
326+
# We use environment variables since the download may happen in a subprocess
327+
return patch.dict(os.environ, {"HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS": "1"})
328+
329+
322330
class TqdmCallback(fsspec.callbacks.TqdmCallback):
323331
def __init__(self, tqdm_kwargs=None, *args, **kwargs):
324332
super().__init__(tqdm_kwargs, *args, **kwargs)
@@ -335,6 +343,10 @@ def fsspec_get(url, temp_file, storage_options=None, desc=None):
335343
"desc": desc or "Downloading",
336344
"unit": "B",
337345
"unit_scale": True,
346+
"position": multiprocessing.current_process()._identity[-1] # contains the ranks of subprocesses
347+
if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
348+
and multiprocessing.current_process()._identity
349+
else None,
338350
}
339351
)
340352
fs.get_file(paths[0], temp_file.name, callback=callback)
@@ -389,6 +401,10 @@ def http_get(
389401
total=total,
390402
initial=resume_size,
391403
desc=desc or "Downloading",
404+
position=multiprocessing.current_process()._identity[-1] # contains the ranks of subprocesses
405+
if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
406+
and multiprocessing.current_process()._identity
407+
else None,
392408
) as progress:
393409
for chunk in response.iter_content(chunk_size=1024):
394410
progress.update(len(chunk))

src/datasets/utils/py_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,18 @@ def map_nested(
462462

463463
if num_proc is None:
464464
num_proc = 1
465-
if num_proc != -1 and num_proc <= 1 or len(iterable) < parallel_min_length:
465+
if any(isinstance(v, types) and len(v) > len(iterable) for v in iterable):
466+
mapped = [
467+
map_nested(
468+
function=function,
469+
data_struct=obj,
470+
num_proc=num_proc,
471+
parallel_min_length=parallel_min_length,
472+
types=types,
473+
)
474+
for obj in iterable
475+
]
476+
elif num_proc != -1 and num_proc <= 1 or len(iterable) < parallel_min_length:
466477
mapped = [
467478
_single_map_nested((function, obj, types, None, True, None))
468479
for obj in hf_tqdm(iterable, disable=disable_tqdm, desc=desc)

0 commit comments

Comments
 (0)