Skip to content

Commit 0f1f27c

Browse files
authored
Multithreaded downloads (#6794)
* multithreaded downloads
* fix
* fix again
* fix tests
* fix
* fix 16 workers
* enable multithreading only for small files
* pin uv
* fix tests
* unpin uv
* add HF_DATASETS_MULTITHREADING_MAX_WORKERS
* rename _download_single
* minor
1 parent 91b07b9 commit 0f1f27c

File tree

8 files changed: +156 −23 lines changed

src/datasets/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,9 @@
182182
os.environ.get("HF_UPDATE_DOWNLOAD_COUNTS", "AUTO").upper() in ENV_VARS_TRUE_AND_AUTO_VALUES
183183
)
184184

185+
# For downloads and to check remote files metadata
186+
HF_DATASETS_MULTITHREADING_MAX_WORKERS = 16
187+
185188
# Remote dataset scripts support
186189
__HF_DATASETS_TRUST_REMOTE_CODE = os.environ.get("HF_DATASETS_TRUST_REMOTE_CODE", "1")
187190
HF_DATASETS_TRUST_REMOTE_CODE: Optional[bool] = (

src/datasets/data_files.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,9 +544,10 @@ def _get_single_origin_metadata(
544544

545545
def _get_origin_metadata(
546546
data_files: List[str],
547-
max_workers=64,
548547
download_config: Optional[DownloadConfig] = None,
548+
max_workers: Optional[int] = None,
549549
) -> Tuple[str]:
550+
max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS
550551
return thread_map(
551552
partial(_get_single_origin_metadata, download_config=download_config),
552553
data_files,

src/datasets/download/download_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class DownloadConfig:
5959
Key/value pairs to be passed on to the dataset file-system backend, if any.
6060
download_desc (`str`, *optional*):
6161
A description to be displayed alongside with the progress bar while downloading the files.
62+
disable_tqdm (`bool`, defaults to `False`):
63+
Whether to disable the individual files download progress bar
6264
"""
6365

6466
cache_dir: Optional[Union[str, Path]] = None
@@ -78,6 +80,7 @@ class DownloadConfig:
7880
ignore_url_params: bool = False
7981
storage_options: Dict[str, Any] = field(default_factory=dict)
8082
download_desc: Optional[str] = None
83+
disable_tqdm: bool = False
8184

8285
def __post_init__(self, use_auth_token):
8386
if use_auth_token != "deprecated":

src/datasets/download/download_manager.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import enum
1919
import io
20+
import multiprocessing
2021
import os
2122
import posixpath
2223
import tarfile
@@ -27,6 +28,10 @@
2728
from itertools import chain
2829
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
2930

31+
import fsspec
32+
from fsspec.core import url_to_fs
33+
from tqdm.contrib.concurrent import thread_map
34+
3035
from .. import config
3136
from ..utils import tqdm as hf_tqdm
3237
from ..utils.deprecation_utils import DeprecatedEnum, deprecated
@@ -39,7 +44,7 @@
3944
url_or_path_join,
4045
)
4146
from ..utils.info_utils import get_size_checksum_dict
42-
from ..utils.logging import get_logger
47+
from ..utils.logging import get_logger, tqdm
4348
from ..utils.py_utils import NestedDataStructure, map_nested, size_str
4449
from ..utils.track import TrackedIterable, tracked_str
4550
from .download_config import DownloadConfig
@@ -427,7 +432,7 @@ def download(self, url_or_urls):
427432
if download_config.download_desc is None:
428433
download_config.download_desc = "Downloading data"
429434

430-
download_func = partial(self._download, download_config=download_config)
435+
download_func = partial(self._download_batched, download_config=download_config)
431436

432437
start_time = datetime.now()
433438
with stack_multiprocessing_download_progress_bars():
@@ -437,6 +442,8 @@ def download(self, url_or_urls):
437442
map_tuple=True,
438443
num_proc=download_config.num_proc,
439444
desc="Downloading data files",
445+
batched=True,
446+
batch_size=-1,
440447
)
441448
duration = datetime.now() - start_time
442449
logger.info(f"Downloading took {duration.total_seconds() // 60} min")
@@ -451,7 +458,46 @@ def download(self, url_or_urls):
451458

452459
return downloaded_path_or_paths.data
453460

454-
def _download(self, url_or_filename: str, download_config: DownloadConfig) -> str:
461+
def _download_batched(
462+
self,
463+
url_or_filenames: List[str],
464+
download_config: DownloadConfig,
465+
) -> List[str]:
466+
if len(url_or_filenames) >= 16:
467+
download_config = download_config.copy()
468+
download_config.disable_tqdm = True
469+
download_func = partial(self._download_single, download_config=download_config)
470+
471+
fs: fsspec.AbstractFileSystem
472+
fs, path = url_to_fs(url_or_filenames[0], **download_config.storage_options)
473+
size = 0
474+
try:
475+
size = fs.info(path).get("size", 0)
476+
except Exception:
477+
pass
478+
max_workers = (
479+
config.HF_DATASETS_MULTITHREADING_MAX_WORKERS if size < (20 << 20) else 1
480+
) # enable multithreading if files are small
481+
482+
return thread_map(
483+
download_func,
484+
url_or_filenames,
485+
desc=download_config.download_desc or "Downloading",
486+
unit="files",
487+
position=multiprocessing.current_process()._identity[-1] # contains the ranks of subprocesses
488+
if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
489+
and multiprocessing.current_process()._identity
490+
else None,
491+
max_workers=max_workers,
492+
tqdm_class=tqdm,
493+
)
494+
else:
495+
return [
496+
self._download_single(url_or_filename, download_config=download_config)
497+
for url_or_filename in url_or_filenames
498+
]
499+
500+
def _download_single(self, url_or_filename: str, download_config: DownloadConfig) -> str:
455501
url_or_filename = str(url_or_filename)
456502
if is_relative_path(url_or_filename):
457503
# append the relative path to the base_path
@@ -539,7 +585,7 @@ def extract(self, path_or_paths, num_proc="deprecated"):
539585
)
540586
download_config = self.download_config.copy()
541587
download_config.extract_compressed_file = True
542-
extract_func = partial(self._download, download_config=download_config)
588+
extract_func = partial(self._download_single, download_config=download_config)
543589
extracted_paths = map_nested(
544590
extract_func,
545591
path_or_paths,

src/datasets/download/streaming_download_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,10 +1002,10 @@ def download(self, url_or_urls):
10021002
>>> downloaded_files = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz')
10031003
```
10041004
"""
1005-
url_or_urls = map_nested(self._download, url_or_urls, map_tuple=True)
1005+
url_or_urls = map_nested(self._download_single, url_or_urls, map_tuple=True)
10061006
return url_or_urls
10071007

1008-
def _download(self, urlpath: str) -> str:
1008+
def _download_single(self, urlpath: str) -> str:
10091009
urlpath = str(urlpath)
10101010
if is_relative_path(urlpath):
10111011
# append the relative path to the base_path

src/datasets/parallel/parallel.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class ParallelBackendConfig:
1414

1515

1616
@experimental
17-
def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
17+
def parallel_map(function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func):
1818
"""
1919
**Experimental.** Apply a function to iterable elements in parallel, where the implementation uses either
2020
multiprocessing.Pool or joblib for parallelization.
@@ -32,21 +32,25 @@ def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single
3232
"""
3333
if ParallelBackendConfig.backend_name is None:
3434
return _map_with_multiprocessing_pool(
35-
function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func
35+
function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
3636
)
3737

38-
return _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func)
38+
return _map_with_joblib(
39+
function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
40+
)
3941

4042

41-
def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
43+
def _map_with_multiprocessing_pool(
44+
function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
45+
):
4246
num_proc = num_proc if num_proc <= len(iterable) else len(iterable)
4347
split_kwds = [] # We organize the splits ourselve (contiguous splits)
4448
for index in range(num_proc):
4549
div = len(iterable) // num_proc
4650
mod = len(iterable) % num_proc
4751
start = div * index + min(index, mod)
4852
end = start + div + (1 if index < mod else 0)
49-
split_kwds.append((function, iterable[start:end], types, index, disable_tqdm, desc))
53+
split_kwds.append((function, iterable[start:end], batched, batch_size, types, index, disable_tqdm, desc))
5054

5155
if len(iterable) != sum(len(i[1]) for i in split_kwds):
5256
raise ValueError(
@@ -70,14 +74,17 @@ def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_
7074
return mapped
7175

7276

73-
def _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
77+
def _map_with_joblib(
78+
function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
79+
):
7480
# progress bar is not yet supported for _map_with_joblib, because tqdm couldn't accurately be applied to joblib,
7581
# and it requires monkey-patching joblib internal classes which is subject to change
7682
import joblib
7783

7884
with joblib.parallel_backend(ParallelBackendConfig.backend_name, n_jobs=num_proc):
7985
return joblib.Parallel()(
80-
joblib.delayed(single_map_nested_func)((function, obj, types, None, True, None)) for obj in iterable
86+
joblib.delayed(single_map_nested_func)((function, obj, batched, batch_size, types, None, True, None))
87+
for obj in iterable
8188
)
8289

8390

src/datasets/utils/file_utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def cached_path(
201201
ignore_url_params=download_config.ignore_url_params,
202202
storage_options=download_config.storage_options,
203203
download_desc=download_config.download_desc,
204+
disable_tqdm=download_config.disable_tqdm,
204205
)
205206
elif os.path.exists(url_or_filename):
206207
# File, and it exists.
@@ -335,7 +336,7 @@ def __init__(self, tqdm_kwargs=None, *args, **kwargs):
335336
super().__init__(tqdm_kwargs, *args, **kwargs)
336337

337338

338-
def fsspec_get(url, temp_file, storage_options=None, desc=None):
339+
def fsspec_get(url, temp_file, storage_options=None, desc=None, disable_tqdm=False):
339340
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
340341
fs, path = url_to_fs(url, **(storage_options or {}))
341342
callback = TqdmCallback(
@@ -347,6 +348,7 @@ def fsspec_get(url, temp_file, storage_options=None, desc=None):
347348
if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
348349
and multiprocessing.current_process()._identity
349350
else None,
351+
"disable": disable_tqdm,
350352
}
351353
)
352354
fs.get_file(path, temp_file.name, callback=callback)
@@ -373,7 +375,16 @@ def ftp_get(url, temp_file, timeout=10.0):
373375

374376

375377
def http_get(
376-
url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=100.0, max_retries=0, desc=None
378+
url,
379+
temp_file,
380+
proxies=None,
381+
resume_size=0,
382+
headers=None,
383+
cookies=None,
384+
timeout=100.0,
385+
max_retries=0,
386+
desc=None,
387+
disable_tqdm=False,
377388
) -> Optional[requests.Response]:
378389
headers = dict(headers) if headers is not None else {}
379390
headers["user-agent"] = get_datasets_user_agent(user_agent=headers.get("user-agent"))
@@ -405,6 +416,7 @@ def http_get(
405416
if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
406417
and multiprocessing.current_process()._identity
407418
else None,
419+
disable=disable_tqdm,
408420
) as progress:
409421
for chunk in response.iter_content(chunk_size=1024):
410422
progress.update(len(chunk))
@@ -464,6 +476,7 @@ def get_from_cache(
464476
ignore_url_params=False,
465477
storage_options=None,
466478
download_desc=None,
479+
disable_tqdm=False,
467480
) -> str:
468481
"""
469482
Given a URL, look for the corresponding file in the local cache.
@@ -629,7 +642,9 @@ def temp_file_manager(mode="w+b"):
629642
if scheme == "ftp":
630643
ftp_get(url, temp_file)
631644
elif scheme not in ("http", "https"):
632-
fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
645+
fsspec_get(
646+
url, temp_file, storage_options=storage_options, desc=download_desc, disable_tqdm=disable_tqdm
647+
)
633648
else:
634649
http_get(
635650
url,
@@ -640,6 +655,7 @@ def temp_file_manager(mode="w+b"):
640655
cookies=cookies,
641656
max_retries=max_retries,
642657
desc=download_desc,
658+
disable_tqdm=disable_tqdm,
643659
)
644660

645661
logger.info(f"storing {url} in cache at {cache_path}")

0 commit comments

Comments (0)