diff --git a/src/datasets/config.py b/src/datasets/config.py
index ce749aea4ff..9668dfbd91e 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -182,6 +182,9 @@
os.environ.get("HF_UPDATE_DOWNLOAD_COUNTS", "AUTO").upper() in ENV_VARS_TRUE_AND_AUTO_VALUES
)
+# For downloads and to check remote files metadata
+HF_DATASETS_MULTITHREADING_MAX_WORKERS = 16
+
# Remote dataset scripts support
__HF_DATASETS_TRUST_REMOTE_CODE = os.environ.get("HF_DATASETS_TRUST_REMOTE_CODE", "1")
HF_DATASETS_TRUST_REMOTE_CODE: Optional[bool] = (
diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py
index 4a918595408..60dc491fe06 100644
--- a/src/datasets/data_files.py
+++ b/src/datasets/data_files.py
@@ -544,9 +544,10 @@ def _get_single_origin_metadata(
def _get_origin_metadata(
data_files: List[str],
- max_workers=64,
download_config: Optional[DownloadConfig] = None,
+ max_workers: Optional[int] = None,
) -> Tuple[str]:
+ max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS
return thread_map(
partial(_get_single_origin_metadata, download_config=download_config),
data_files,
diff --git a/src/datasets/download/download_config.py b/src/datasets/download/download_config.py
index 8ba032f75ba..21019b903f1 100644
--- a/src/datasets/download/download_config.py
+++ b/src/datasets/download/download_config.py
@@ -59,6 +59,8 @@ class DownloadConfig:
Key/value pairs to be passed on to the dataset file-system backend, if any.
download_desc (`str`, *optional*):
A description to be displayed alongside with the progress bar while downloading the files.
+ disable_tqdm (`bool`, defaults to `False`):
+ Whether to disable the individual files download progress bar
"""
cache_dir: Optional[Union[str, Path]] = None
@@ -78,6 +80,7 @@ class DownloadConfig:
ignore_url_params: bool = False
storage_options: Dict[str, Any] = field(default_factory=dict)
download_desc: Optional[str] = None
+ disable_tqdm: bool = False
def __post_init__(self, use_auth_token):
if use_auth_token != "deprecated":
diff --git a/src/datasets/download/download_manager.py b/src/datasets/download/download_manager.py
index 6c838753b99..fe24fc5277e 100644
--- a/src/datasets/download/download_manager.py
+++ b/src/datasets/download/download_manager.py
@@ -17,6 +17,7 @@
import enum
import io
+import multiprocessing
import os
import posixpath
import tarfile
@@ -27,6 +28,10 @@
from itertools import chain
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
+import fsspec
+from fsspec.core import url_to_fs
+from tqdm.contrib.concurrent import thread_map
+
from .. import config
from ..utils import tqdm as hf_tqdm
from ..utils.deprecation_utils import DeprecatedEnum, deprecated
@@ -39,7 +44,7 @@
url_or_path_join,
)
from ..utils.info_utils import get_size_checksum_dict
-from ..utils.logging import get_logger
+from ..utils.logging import get_logger, tqdm
from ..utils.py_utils import NestedDataStructure, map_nested, size_str
from ..utils.track import TrackedIterable, tracked_str
from .download_config import DownloadConfig
@@ -427,7 +432,7 @@ def download(self, url_or_urls):
if download_config.download_desc is None:
download_config.download_desc = "Downloading data"
- download_func = partial(self._download, download_config=download_config)
+ download_func = partial(self._download_batched, download_config=download_config)
start_time = datetime.now()
with stack_multiprocessing_download_progress_bars():
@@ -437,6 +442,8 @@ def download(self, url_or_urls):
map_tuple=True,
num_proc=download_config.num_proc,
desc="Downloading data files",
+ batched=True,
+ batch_size=-1,
)
duration = datetime.now() - start_time
logger.info(f"Downloading took {duration.total_seconds() // 60} min")
@@ -451,7 +458,46 @@ def download(self, url_or_urls):
return downloaded_path_or_paths.data
- def _download(self, url_or_filename: str, download_config: DownloadConfig) -> str:
+ def _download_batched(
+ self,
+ url_or_filenames: List[str],
+ download_config: DownloadConfig,
+ ) -> List[str]:
+ if len(url_or_filenames) >= 16:
+ download_config = download_config.copy()
+ download_config.disable_tqdm = True
+ download_func = partial(self._download_single, download_config=download_config)
+
+ fs: fsspec.AbstractFileSystem
+ fs, path = url_to_fs(url_or_filenames[0], **download_config.storage_options)
+ size = 0
+ try:
+ size = fs.info(path).get("size", 0)
+ except Exception:
+ pass
+ max_workers = (
+ config.HF_DATASETS_MULTITHREADING_MAX_WORKERS if size < (20 << 20) else 1
+ ) # enable multithreading if files are small
+
+ return thread_map(
+ download_func,
+ url_or_filenames,
+ desc=download_config.download_desc or "Downloading",
+ unit="files",
+ position=multiprocessing.current_process()._identity[-1] # contains the ranks of subprocesses
+ if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
+ and multiprocessing.current_process()._identity
+ else None,
+ max_workers=max_workers,
+ tqdm_class=tqdm,
+ )
+ else:
+ return [
+ self._download_single(url_or_filename, download_config=download_config)
+ for url_or_filename in url_or_filenames
+ ]
+
+ def _download_single(self, url_or_filename: str, download_config: DownloadConfig) -> str:
url_or_filename = str(url_or_filename)
if is_relative_path(url_or_filename):
# append the relative path to the base_path
@@ -539,7 +585,7 @@ def extract(self, path_or_paths, num_proc="deprecated"):
)
download_config = self.download_config.copy()
download_config.extract_compressed_file = True
- extract_func = partial(self._download, download_config=download_config)
+ extract_func = partial(self._download_single, download_config=download_config)
extracted_paths = map_nested(
extract_func,
path_or_paths,
diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py
index 819150756fe..d7653c00da5 100644
--- a/src/datasets/download/streaming_download_manager.py
+++ b/src/datasets/download/streaming_download_manager.py
@@ -1002,10 +1002,10 @@ def download(self, url_or_urls):
>>> downloaded_files = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz')
```
"""
- url_or_urls = map_nested(self._download, url_or_urls, map_tuple=True)
+ url_or_urls = map_nested(self._download_single, url_or_urls, map_tuple=True)
return url_or_urls
- def _download(self, urlpath: str) -> str:
+ def _download_single(self, urlpath: str) -> str:
urlpath = str(urlpath)
if is_relative_path(urlpath):
# append the relative path to the base_path
diff --git a/src/datasets/parallel/parallel.py b/src/datasets/parallel/parallel.py
index 4e1a8546c58..5cad2c48ba2 100644
--- a/src/datasets/parallel/parallel.py
+++ b/src/datasets/parallel/parallel.py
@@ -14,7 +14,7 @@ class ParallelBackendConfig:
@experimental
-def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
+def parallel_map(function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func):
"""
**Experimental.** Apply a function to iterable elements in parallel, where the implementation uses either
multiprocessing.Pool or joblib for parallelization.
@@ -32,13 +32,17 @@ def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single
"""
if ParallelBackendConfig.backend_name is None:
return _map_with_multiprocessing_pool(
- function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func
+ function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
)
- return _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func)
+ return _map_with_joblib(
+ function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
+ )
-def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
+def _map_with_multiprocessing_pool(
+ function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
+):
num_proc = num_proc if num_proc <= len(iterable) else len(iterable)
split_kwds = [] # We organize the splits ourselve (contiguous splits)
for index in range(num_proc):
@@ -46,7 +50,7 @@ def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_
mod = len(iterable) % num_proc
start = div * index + min(index, mod)
end = start + div + (1 if index < mod else 0)
- split_kwds.append((function, iterable[start:end], types, index, disable_tqdm, desc))
+ split_kwds.append((function, iterable[start:end], batched, batch_size, types, index, disable_tqdm, desc))
if len(iterable) != sum(len(i[1]) for i in split_kwds):
raise ValueError(
@@ -70,14 +74,17 @@ def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_
return mapped
-def _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
+def _map_with_joblib(
+ function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
+):
# progress bar is not yet supported for _map_with_joblib, because tqdm couldn't accurately be applied to joblib,
# and it requires monkey-patching joblib internal classes which is subject to change
import joblib
with joblib.parallel_backend(ParallelBackendConfig.backend_name, n_jobs=num_proc):
return joblib.Parallel()(
- joblib.delayed(single_map_nested_func)((function, obj, types, None, True, None)) for obj in iterable
+ joblib.delayed(single_map_nested_func)((function, obj, batched, batch_size, types, None, True, None))
+ for obj in iterable
)
diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py
index 6384c521124..750779a2cf0 100644
--- a/src/datasets/utils/file_utils.py
+++ b/src/datasets/utils/file_utils.py
@@ -201,6 +201,7 @@ def cached_path(
ignore_url_params=download_config.ignore_url_params,
storage_options=download_config.storage_options,
download_desc=download_config.download_desc,
+ disable_tqdm=download_config.disable_tqdm,
)
elif os.path.exists(url_or_filename):
# File, and it exists.
@@ -335,7 +336,7 @@ def __init__(self, tqdm_kwargs=None, *args, **kwargs):
super().__init__(tqdm_kwargs, *args, **kwargs)
-def fsspec_get(url, temp_file, storage_options=None, desc=None):
+def fsspec_get(url, temp_file, storage_options=None, desc=None, disable_tqdm=False):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
fs, path = url_to_fs(url, **(storage_options or {}))
callback = TqdmCallback(
@@ -347,6 +348,7 @@ def fsspec_get(url, temp_file, storage_options=None, desc=None):
if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
and multiprocessing.current_process()._identity
else None,
+ "disable": disable_tqdm,
}
)
fs.get_file(path, temp_file.name, callback=callback)
@@ -373,7 +375,16 @@ def ftp_get(url, temp_file, timeout=10.0):
def http_get(
- url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=100.0, max_retries=0, desc=None
+ url,
+ temp_file,
+ proxies=None,
+ resume_size=0,
+ headers=None,
+ cookies=None,
+ timeout=100.0,
+ max_retries=0,
+ desc=None,
+ disable_tqdm=False,
) -> Optional[requests.Response]:
headers = dict(headers) if headers is not None else {}
headers["user-agent"] = get_datasets_user_agent(user_agent=headers.get("user-agent"))
@@ -405,6 +416,7 @@ def http_get(
if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
and multiprocessing.current_process()._identity
else None,
+ disable=disable_tqdm,
) as progress:
for chunk in response.iter_content(chunk_size=1024):
progress.update(len(chunk))
@@ -464,6 +476,7 @@ def get_from_cache(
ignore_url_params=False,
storage_options=None,
download_desc=None,
+ disable_tqdm=False,
) -> str:
"""
Given a URL, look for the corresponding file in the local cache.
@@ -629,7 +642,9 @@ def temp_file_manager(mode="w+b"):
if scheme == "ftp":
ftp_get(url, temp_file)
elif scheme not in ("http", "https"):
- fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
+ fsspec_get(
+ url, temp_file, storage_options=storage_options, desc=download_desc, disable_tqdm=disable_tqdm
+ )
else:
http_get(
url,
@@ -640,6 +655,7 @@ def temp_file_manager(mode="w+b"):
cookies=cookies,
max_retries=max_retries,
desc=download_desc,
+ disable_tqdm=disable_tqdm,
)
logger.info(f"storing {url} in cache at {cache_path}")
diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py
index 779d7dabc24..1304a971667 100644
--- a/src/datasets/utils/py_utils.py
+++ b/src/datasets/utils/py_utils.py
@@ -363,11 +363,18 @@ def __get__(self, obj, objtype=None):
def _single_map_nested(args):
"""Apply a function recursively to each element of a nested data struct."""
- function, data_struct, types, rank, disable_tqdm, desc = args
+ function, data_struct, batched, batch_size, types, rank, disable_tqdm, desc = args
# Singleton first to spare some computation
if not isinstance(data_struct, dict) and not isinstance(data_struct, types):
return function(data_struct)
+ if (
+ batched
+ and not isinstance(data_struct, dict)
+ and isinstance(data_struct, types)
+ and all(not isinstance(v, types) for v in data_struct)
+ ):
+ return [mapped_item for batch in iter_batched(data_struct, batch_size) for mapped_item in function(batch)]
# Reduce logging to keep things readable in multiprocessing with tqdm
if rank is not None and logging.get_verbosity() < logging.WARNING:
@@ -382,9 +389,11 @@ def _single_map_nested(args):
pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
with hf_tqdm(pbar_iterable, disable=disable_tqdm, position=rank, unit="obj", desc=pbar_desc) as pbar:
if isinstance(data_struct, dict):
- return {k: _single_map_nested((function, v, types, None, True, None)) for k, v in pbar}
+ return {
+ k: _single_map_nested((function, v, batched, batch_size, types, None, True, None)) for k, v in pbar
+ }
else:
- mapped = [_single_map_nested((function, v, types, None, True, None)) for v in pbar]
+ mapped = [_single_map_nested((function, v, batched, batch_size, types, None, True, None)) for v in pbar]
if isinstance(data_struct, list):
return mapped
elif isinstance(data_struct, tuple):
@@ -402,6 +411,8 @@ def map_nested(
map_numpy: bool = False,
num_proc: Optional[int] = None,
parallel_min_length: int = 2,
+ batched: bool = False,
+ batch_size: Optional[int] = 1000,
types: Optional[tuple] = None,
disable_tqdm: bool = True,
desc: Optional[str] = None,
@@ -432,9 +443,18 @@ def map_nested(
map_numpy (`bool, default `False`): Whether also apply `function` recursively to `numpy.array` elements (besides
`dict` values).
num_proc (`int`, *optional*): Number of processes.
+ The level in the data struct used for multiprocessing is the first level that has smaller sub-structs,
+ starting from the root.
parallel_min_length (`int`, default `2`): Minimum length of `data_struct` required for parallel
processing.
+ batched (`bool`, defaults to `False`):
+ Provide batch of items to `function`.
+
+ batch_size (`int`, *optional*, defaults to `1000`):
+ Number of items per batch provided to `function` if `batched=True`.
+ If `batch_size <= 0` or `batch_size == None`, provide the full iterable as a single batch to `function`.
+
types (`tuple`, *optional*): Additional types (besides `dict` values) to apply `function` recursively to their
elements.
disable_tqdm (`bool`, default `True`): Whether to disable the tqdm progressbar.
@@ -456,7 +476,12 @@ def map_nested(
# Singleton
if not isinstance(data_struct, dict) and not isinstance(data_struct, types):
- return function(data_struct)
+ if batched:
+ data_struct = [data_struct]
+ mapped = function(data_struct)
+ if batched:
+ mapped = mapped[0]
+ return mapped
iterable = list(data_struct.values()) if isinstance(data_struct, dict) else data_struct
@@ -469,15 +494,23 @@ def map_nested(
data_struct=obj,
num_proc=num_proc,
parallel_min_length=parallel_min_length,
+ batched=batched,
+ batch_size=batch_size,
types=types,
)
for obj in iterable
]
elif num_proc != -1 and num_proc <= 1 or len(iterable) < parallel_min_length:
+ if batched:
+ if batch_size is None or batch_size <= 0:
+ batch_size = max(len(iterable) // num_proc + int(len(iterable) % num_proc > 0), 1)
+ iterable = list(iter_batched(iterable, batch_size))
mapped = [
- _single_map_nested((function, obj, types, None, True, None))
+ _single_map_nested((function, obj, batched, batch_size, types, None, True, None))
for obj in hf_tqdm(iterable, disable=disable_tqdm, desc=desc)
]
+ if batched:
+ mapped = [mapped_item for mapped_batch in mapped for mapped_item in mapped_batch]
else:
with warnings.catch_warnings():
warnings.filterwarnings(
@@ -485,7 +518,15 @@ def map_nested(
message=".* is experimental and might be subject to breaking changes in the future\\.$",
category=UserWarning,
)
- mapped = parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, _single_map_nested)
+ if batched:
+ if batch_size is None or batch_size <= 0:
+ batch_size = len(iterable) // num_proc + int(len(iterable) % num_proc > 0)
+ iterable = list(iter_batched(iterable, batch_size))
+ mapped = parallel_map(
+ function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, _single_map_nested
+ )
+ if batched:
+ mapped = [mapped_item for mapped_batch in mapped for mapped_item in mapped_batch]
if isinstance(data_struct, dict):
return dict(zip(data_struct.keys(), mapped))
@@ -672,3 +713,19 @@ def iflatmap_unordered(
if not pool_changed:
# we get the result in case there's an error to raise
[async_result.get(timeout=0.05) for async_result in async_results]
+
+
+T = TypeVar("T")
+
+
+def iter_batched(iterable: Iterable[T], n: int) -> Iterable[List[T]]:
+ if n < 1:
+ raise ValueError(f"Invalid batch size {n}")
+ batch = []
+ for item in iterable:
+ batch.append(item)
+ if len(batch) == n:
+ yield batch
+ batch = []
+ if batch:
+ yield batch
diff --git a/tests/test_load.py b/tests/test_load.py
index 6b85984d25b..c59ce7d5e6c 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -842,10 +842,10 @@ def test_HubDatasetModuleFactoryWithParquetExport(self):
)
assert module_factory_result.builder_configs_parameters.builder_configs[0].data_files == {
"train": [
- "hf://datasets/hf-internal-testing/dataset_with_script@8f965694d611974ef8661618ada1b5aeb1072915/default/train/0000.parquet"
+ "hf://datasets/hf-internal-testing/dataset_with_script@0c97cd1168f31e683059ddcf0703e3f45d9007c4/default/train/0000.parquet"
],
"validation": [
- "hf://datasets/hf-internal-testing/dataset_with_script@8f965694d611974ef8661618ada1b5aeb1072915/default/validation/0000.parquet"
+ "hf://datasets/hf-internal-testing/dataset_with_script@0c97cd1168f31e683059ddcf0703e3f45d9007c4/default/validation/0000.parquet"
],
}