diff --git a/src/datasets/config.py b/src/datasets/config.py index ce749aea4ff..9668dfbd91e 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -182,6 +182,9 @@ os.environ.get("HF_UPDATE_DOWNLOAD_COUNTS", "AUTO").upper() in ENV_VARS_TRUE_AND_AUTO_VALUES ) +# For downloads and to check remote files metadata +HF_DATASETS_MULTITHREADING_MAX_WORKERS = 16 + # Remote dataset scripts support __HF_DATASETS_TRUST_REMOTE_CODE = os.environ.get("HF_DATASETS_TRUST_REMOTE_CODE", "1") HF_DATASETS_TRUST_REMOTE_CODE: Optional[bool] = ( diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index 4a918595408..60dc491fe06 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -544,9 +544,10 @@ def _get_single_origin_metadata( def _get_origin_metadata( data_files: List[str], - max_workers=64, download_config: Optional[DownloadConfig] = None, + max_workers: Optional[int] = None, ) -> Tuple[str]: + max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS return thread_map( partial(_get_single_origin_metadata, download_config=download_config), data_files, diff --git a/src/datasets/download/download_config.py b/src/datasets/download/download_config.py index 8ba032f75ba..21019b903f1 100644 --- a/src/datasets/download/download_config.py +++ b/src/datasets/download/download_config.py @@ -59,6 +59,8 @@ class DownloadConfig: Key/value pairs to be passed on to the dataset file-system backend, if any. download_desc (`str`, *optional*): A description to be displayed alongside with the progress bar while downloading the files. + disable_tqdm (`bool`, defaults to `False`): + Whether to disable the individual files download progress bar """ cache_dir: Optional[Union[str, Path]] = None @@ -78,6 +80,7 @@ class DownloadConfig: ignore_url_params: bool = False storage_options: Dict[str, Any] = field(default_factory=dict) download_desc: Optional[str] = None + disable_tqdm: bool = False def __post_init__(self, use_auth_token): if use_auth_token != "deprecated": diff --git a/src/datasets/download/download_manager.py b/src/datasets/download/download_manager.py index 6c838753b99..fe24fc5277e 100644 --- a/src/datasets/download/download_manager.py +++ b/src/datasets/download/download_manager.py @@ -17,6 +17,7 @@ import enum import io +import multiprocessing import os import posixpath import tarfile @@ -27,6 +28,10 @@ from itertools import chain from typing import Callable, Dict, Generator, List, Optional, Tuple, Union +import fsspec +from fsspec.core import url_to_fs +from tqdm.contrib.concurrent import thread_map + from .. import config from ..utils import tqdm as hf_tqdm from ..utils.deprecation_utils import DeprecatedEnum, deprecated @@ -39,7 +44,7 @@ url_or_path_join, ) from ..utils.info_utils import get_size_checksum_dict -from ..utils.logging import get_logger +from ..utils.logging import get_logger, tqdm from ..utils.py_utils import NestedDataStructure, map_nested, size_str from ..utils.track import TrackedIterable, tracked_str from .download_config import DownloadConfig @@ -427,7 +432,7 @@ def download(self, url_or_urls): if download_config.download_desc is None: download_config.download_desc = "Downloading data" - download_func = partial(self._download, download_config=download_config) + download_func = partial(self._download_batched, download_config=download_config) start_time = datetime.now() with stack_multiprocessing_download_progress_bars(): @@ -437,6 +442,8 @@ def download(self, url_or_urls): map_tuple=True, num_proc=download_config.num_proc, desc="Downloading data files", + batched=True, + batch_size=-1, ) duration = datetime.now() - start_time logger.info(f"Downloading took {duration.total_seconds() // 60} min") @@ -451,7 +458,46 @@ def download(self, url_or_urls): return downloaded_path_or_paths.data - def _download(self, url_or_filename: str, download_config: DownloadConfig) -> str: + def _download_batched( + self, + url_or_filenames: List[str], + download_config: DownloadConfig, + ) -> List[str]: + if len(url_or_filenames) >= 16: + download_config = download_config.copy() + download_config.disable_tqdm = True + download_func = partial(self._download_single, download_config=download_config) + + fs: fsspec.AbstractFileSystem + fs, path = url_to_fs(url_or_filenames[0], **download_config.storage_options) + size = 0 + try: + size = fs.info(path).get("size", 0) + except Exception: + pass + max_workers = ( + config.HF_DATASETS_MULTITHREADING_MAX_WORKERS if size < (20 << 20) else 1 + ) # enable multithreading if files are small + + return thread_map( + download_func, + url_or_filenames, + desc=download_config.download_desc or "Downloading", + unit="files", + position=multiprocessing.current_process()._identity[-1] # contains the ranks of subprocesses + if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1" + and multiprocessing.current_process()._identity + else None, + max_workers=max_workers, + tqdm_class=tqdm, + ) + else: + return [ + self._download_single(url_or_filename, download_config=download_config) + for url_or_filename in url_or_filenames + ] + + def _download_single(self, url_or_filename: str, download_config: DownloadConfig) -> str: url_or_filename = str(url_or_filename) if is_relative_path(url_or_filename): # append the relative path to the base_path @@ -539,7 +585,7 @@ def extract(self, path_or_paths, num_proc="deprecated"): ) download_config = self.download_config.copy() download_config.extract_compressed_file = True - extract_func = partial(self._download, download_config=download_config) + extract_func = partial(self._download_single, download_config=download_config) extracted_paths = map_nested( extract_func, path_or_paths, diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py index 819150756fe..d7653c00da5 100644 --- a/src/datasets/download/streaming_download_manager.py +++ b/src/datasets/download/streaming_download_manager.py @@ -1002,10 +1002,10 @@ def download(self, url_or_urls): >>> downloaded_files = dl_manager.download('https://storage.googleapis.com/seldon-datasets/sentence_polarity_v1/rt-polaritydata.tar.gz') ``` """ - url_or_urls = map_nested(self._download, url_or_urls, map_tuple=True) + url_or_urls = map_nested(self._download_single, url_or_urls, map_tuple=True) return url_or_urls - def _download(self, urlpath: str) -> str: + def _download_single(self, urlpath: str) -> str: urlpath = str(urlpath) if is_relative_path(urlpath): # append the relative path to the base_path diff --git a/src/datasets/parallel/parallel.py b/src/datasets/parallel/parallel.py index 4e1a8546c58..5cad2c48ba2 100644 --- a/src/datasets/parallel/parallel.py +++ b/src/datasets/parallel/parallel.py @@ -14,7 +14,7 @@ class ParallelBackendConfig: @experimental -def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func): +def parallel_map(function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func): """ **Experimental.** Apply a function to iterable elements in parallel, where the implementation uses either multiprocessing.Pool or joblib for parallelization. @@ -32,13 +32,17 @@ def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single """ if ParallelBackendConfig.backend_name is None: return _map_with_multiprocessing_pool( - function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func + function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func ) - return _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func) + return _map_with_joblib( + function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func + ) -def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func): +def _map_with_multiprocessing_pool( + function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func +): num_proc = num_proc if num_proc <= len(iterable) else len(iterable) split_kwds = [] # We organize the splits ourselve (contiguous splits) for index in range(num_proc): @@ -46,7 +50,7 @@ def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_ mod = len(iterable) % num_proc start = div * index + min(index, mod) end = start + div + (1 if index < mod else 0) - split_kwds.append((function, iterable[start:end], types, index, disable_tqdm, desc)) + split_kwds.append((function, iterable[start:end], batched, batch_size, types, index, disable_tqdm, desc)) if len(iterable) != sum(len(i[1]) for i in split_kwds): raise ValueError( @@ -70,14 +74,17 @@ def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_ return mapped -def _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func): +def _map_with_joblib( + function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func +): # progress bar is not yet supported for _map_with_joblib, because tqdm couldn't accurately be applied to joblib, # and it requires monkey-patching joblib internal classes which is subject to change import joblib with joblib.parallel_backend(ParallelBackendConfig.backend_name, n_jobs=num_proc): return joblib.Parallel()( - joblib.delayed(single_map_nested_func)((function, obj, types, None, True, None)) for obj in iterable + joblib.delayed(single_map_nested_func)((function, obj, batched, batch_size, types, None, True, None)) + for obj in iterable ) diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 6384c521124..750779a2cf0 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -201,6 +201,7 @@ def cached_path( ignore_url_params=download_config.ignore_url_params, storage_options=download_config.storage_options, download_desc=download_config.download_desc, + disable_tqdm=download_config.disable_tqdm, ) elif os.path.exists(url_or_filename): # File, and it exists. @@ -335,7 +336,7 @@ def __init__(self, tqdm_kwargs=None, *args, **kwargs): super().__init__(tqdm_kwargs, *args, **kwargs) -def fsspec_get(url, temp_file, storage_options=None, desc=None): +def fsspec_get(url, temp_file, storage_options=None, desc=None, disable_tqdm=False): _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") fs, path = url_to_fs(url, **(storage_options or {})) callback = TqdmCallback( @@ -347,6 +348,7 @@ def fsspec_get(url, temp_file, storage_options=None, desc=None): if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1" and multiprocessing.current_process()._identity else None, + "disable": disable_tqdm, } ) fs.get_file(path, temp_file.name, callback=callback) @@ -373,7 +375,16 @@ def ftp_get(url, temp_file, timeout=10.0): def http_get( - url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=100.0, max_retries=0, desc=None + url, + temp_file, + proxies=None, + resume_size=0, + headers=None, + cookies=None, + timeout=100.0, + max_retries=0, + desc=None, + disable_tqdm=False, ) -> Optional[requests.Response]: headers = dict(headers) if headers is not None else {} headers["user-agent"] = get_datasets_user_agent(user_agent=headers.get("user-agent")) @@ -405,6 +416,7 @@ def http_get( if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1" and multiprocessing.current_process()._identity else None, + disable=disable_tqdm, ) as progress: for chunk in response.iter_content(chunk_size=1024): progress.update(len(chunk)) @@ -464,6 +476,7 @@ def get_from_cache( ignore_url_params=False, storage_options=None, download_desc=None, + disable_tqdm=False, ) -> str: """ Given a URL, look for the corresponding file in the local cache. @@ -629,7 +642,9 @@ def temp_file_manager(mode="w+b"): if scheme == "ftp": ftp_get(url, temp_file) elif scheme not in ("http", "https"): - fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc) + fsspec_get( + url, temp_file, storage_options=storage_options, desc=download_desc, disable_tqdm=disable_tqdm + ) else: http_get( url, @@ -640,6 +655,7 @@ def temp_file_manager(mode="w+b"): cookies=cookies, max_retries=max_retries, desc=download_desc, + disable_tqdm=disable_tqdm, ) logger.info(f"storing {url} in cache at {cache_path}") diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index 779d7dabc24..1304a971667 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -363,11 +363,18 @@ def __get__(self, obj, objtype=None): def _single_map_nested(args): """Apply a function recursively to each element of a nested data struct.""" - function, data_struct, types, rank, disable_tqdm, desc = args + function, data_struct, batched, batch_size, types, rank, disable_tqdm, desc = args # Singleton first to spare some computation if not isinstance(data_struct, dict) and not isinstance(data_struct, types): return function(data_struct) + if ( + batched + and not isinstance(data_struct, dict) + and isinstance(data_struct, types) + and all(not isinstance(v, types) for v in data_struct) + ): + return [mapped_item for batch in iter_batched(data_struct, batch_size) for mapped_item in function(batch)] # Reduce logging to keep things readable in multiprocessing with tqdm if rank is not None and logging.get_verbosity() < logging.WARNING: @@ -382,9 +389,11 @@ def _single_map_nested(args): pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc with hf_tqdm(pbar_iterable, disable=disable_tqdm, position=rank, unit="obj", desc=pbar_desc) as pbar: if isinstance(data_struct, dict): - return {k: _single_map_nested((function, v, types, None, True, None)) for k, v in pbar} + return { + k: _single_map_nested((function, v, batched, batch_size, types, None, True, None)) for k, v in pbar + } else: - mapped = [_single_map_nested((function, v, types, None, True, None)) for v in pbar] + mapped = [_single_map_nested((function, v, batched, batch_size, types, None, True, None)) for v in pbar] if isinstance(data_struct, list): return mapped elif isinstance(data_struct, tuple): @@ -402,6 +411,8 @@ def map_nested( map_numpy: bool = False, num_proc: Optional[int] = None, parallel_min_length: int = 2, + batched: bool = False, + batch_size: Optional[int] = 1000, types: Optional[tuple] = None, disable_tqdm: bool = True, desc: Optional[str] = None, @@ -432,9 +443,18 @@ def map_nested( map_numpy (`bool, default `False`): Whether also apply `function` recursively to `numpy.array` elements (besides `dict` values). num_proc (`int`, *optional*): Number of processes. + The level in the data struct used for multiprocessing is the first level that has smaller sub-structs, + starting from the root. parallel_min_length (`int`, default `2`): Minimum length of `data_struct` required for parallel processing. + batched (`bool`, defaults to `False`): + Provide batch of items to `function`. + + batch_size (`int`, *optional*, defaults to `1000`): + Number of items per batch provided to `function` if `batched=True`. + If `batch_size <= 0` or `batch_size == None`, provide the full iterable as a single batch to `function`. + types (`tuple`, *optional*): Additional types (besides `dict` values) to apply `function` recursively to their elements. disable_tqdm (`bool`, default `True`): Whether to disable the tqdm progressbar. @@ -456,7 +476,12 @@ def map_nested( # Singleton if not isinstance(data_struct, dict) and not isinstance(data_struct, types): - return function(data_struct) + if batched: + data_struct = [data_struct] + mapped = function(data_struct) + if batched: + mapped = mapped[0] + return mapped iterable = list(data_struct.values()) if isinstance(data_struct, dict) else data_struct @@ -469,15 +494,23 @@ def map_nested( data_struct=obj, num_proc=num_proc, parallel_min_length=parallel_min_length, + batched=batched, + batch_size=batch_size, types=types, ) for obj in iterable ] elif num_proc != -1 and num_proc <= 1 or len(iterable) < parallel_min_length: + if batched: + if batch_size is None or batch_size <= 0: + batch_size = max(len(iterable) // num_proc + int(len(iterable) % num_proc > 0), 1) + iterable = list(iter_batched(iterable, batch_size)) mapped = [ - _single_map_nested((function, obj, types, None, True, None)) + _single_map_nested((function, obj, batched, batch_size, types, None, True, None)) for obj in hf_tqdm(iterable, disable=disable_tqdm, desc=desc) ] + if batched: + mapped = [mapped_item for mapped_batch in mapped for mapped_item in mapped_batch] else: with warnings.catch_warnings(): warnings.filterwarnings( @@ -485,7 +518,15 @@ def map_nested( message=".* is experimental and might be subject to breaking changes in the future\\.$", category=UserWarning, ) - mapped = parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, _single_map_nested) + if batched: + if batch_size is None or batch_size <= 0: + batch_size = len(iterable) // num_proc + int(len(iterable) % num_proc > 0) + iterable = list(iter_batched(iterable, batch_size)) + mapped = parallel_map( + function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, _single_map_nested + ) + if batched: + mapped = [mapped_item for mapped_batch in mapped for mapped_item in mapped_batch] if isinstance(data_struct, dict): return dict(zip(data_struct.keys(), mapped)) @@ -672,3 +713,19 @@ def iflatmap_unordered( if not pool_changed: # we get the result in case there's an error to raise [async_result.get(timeout=0.05) for async_result in async_results] + + +T = TypeVar("T") + + +def iter_batched(iterable: Iterable[T], n: int) -> Iterable[List[T]]: + if n < 1: + raise ValueError(f"Invalid batch size {n}") + batch = [] + for item in iterable: + batch.append(item) + if len(batch) == n: + yield batch + batch = [] + if batch: + yield batch diff --git a/tests/test_load.py b/tests/test_load.py index 6b85984d25b..c59ce7d5e6c 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -842,10 +842,10 @@ def test_HubDatasetModuleFactoryWithParquetExport(self): ) assert module_factory_result.builder_configs_parameters.builder_configs[0].data_files == { "train": [ - "hf://datasets/hf-internal-testing/dataset_with_script@8f965694d611974ef8661618ada1b5aeb1072915/default/train/0000.parquet" + "hf://datasets/hf-internal-testing/dataset_with_script@0c97cd1168f31e683059ddcf0703e3f45d9007c4/default/train/0000.parquet" ], "validation": [ - "hf://datasets/hf-internal-testing/dataset_with_script@8f965694d611974ef8661618ada1b5aeb1072915/default/validation/0000.parquet" + "hf://datasets/hf-internal-testing/dataset_with_script@0c97cd1168f31e683059ddcf0703e3f45d9007c4/default/validation/0000.parquet" ], }