Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
if: ${{ matrix.os == 'ubuntu-latest' }}
run: echo "installing pinned version of setuptools-scm to fix seqeval installation on 3.7" && pip install "setuptools-scm==6.4.2"
- name: Install uv
run: pip install --upgrade uv
run: pip install uv==0.1.29
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would remove the pin to be consistent with huggingface_hub and diffusers:

Suggested change
run: pip install uv==0.1.29

(we don't use uv's advanced/experimental features, so a breaking change here is unlikely)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had pinned it because 0.1.30 had bugs - I'll see if 0.1.31 has fixed them

Copy link
Collaborator

@mariosasko mariosasko Apr 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's been fixed in 0.1.31 (issue in uv: astral-sh/uv#2941) :)

- name: Install dependencies
run: |
uv pip install --system "datasets[tests,metrics-tests] @ ."
Expand Down Expand Up @@ -89,7 +89,7 @@ jobs:
- name: Upgrade pip
run: python -m pip install --upgrade pip
- name: Install uv
run: pip install --upgrade uv
run: pip install uv==0.1.29
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here:

Suggested change
run: pip install uv==0.1.29

- name: Install dependencies
run: uv pip install --system "datasets[tests] @ ."
- name: Test with pytest
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,8 @@ def _get_single_origin_metadata(

def _get_origin_metadata(
data_files: List[str],
max_workers=64,
download_config: Optional[DownloadConfig] = None,
max_workers: int = 16,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be a config variable (which we would also use in DownloadManager)

) -> Tuple[str]:
return thread_map(
partial(_get_single_origin_metadata, download_config=download_config),
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/download/download_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class DownloadConfig:
Key/value pairs to be passed on to the dataset file-system backend, if any.
download_desc (`str`, *optional*):
A description to be displayed alongside with the progress bar while downloading the files.
disable_tqdm (`bool`, defaults to `False`):
Whether to disable the individual file download progress bars.
"""

cache_dir: Optional[Union[str, Path]] = None
Expand All @@ -78,6 +80,7 @@ class DownloadConfig:
ignore_url_params: bool = False
storage_options: Dict[str, Any] = field(default_factory=dict)
download_desc: Optional[str] = None
disable_tqdm: bool = False

def __post_init__(self, use_auth_token):
if use_auth_token != "deprecated":
Expand Down
48 changes: 46 additions & 2 deletions src/datasets/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import enum
import io
import multiprocessing
import os
import posixpath
import tarfile
Expand All @@ -27,6 +28,10 @@
from itertools import chain
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union

import fsspec
from fsspec.core import url_to_fs
from tqdm.contrib.concurrent import thread_map

from .. import config
from ..utils import tqdm as hf_tqdm
from ..utils.deprecation_utils import DeprecatedEnum, deprecated
Expand All @@ -39,7 +44,7 @@
url_or_path_join,
)
from ..utils.info_utils import get_size_checksum_dict
from ..utils.logging import get_logger
from ..utils.logging import get_logger, tqdm
from ..utils.py_utils import NestedDataStructure, map_nested, size_str
from ..utils.track import TrackedIterable, tracked_str
from .download_config import DownloadConfig
Expand Down Expand Up @@ -427,7 +432,7 @@ def download(self, url_or_urls):
if download_config.download_desc is None:
download_config.download_desc = "Downloading data"

download_func = partial(self._download, download_config=download_config)
download_func = partial(self._download_batched, download_config=download_config)

start_time = datetime.now()
with stack_multiprocessing_download_progress_bars():
Expand All @@ -437,6 +442,8 @@ def download(self, url_or_urls):
map_tuple=True,
num_proc=download_config.num_proc,
desc="Downloading data files",
batched=True,
batch_size=-1,
)
duration = datetime.now() - start_time
logger.info(f"Downloading took {duration.total_seconds() // 60} min")
Expand All @@ -451,6 +458,43 @@ def download(self, url_or_urls):

return downloaded_path_or_paths.data

def _download_batched(
self,
url_or_filenames: List[str],
download_config: DownloadConfig,
) -> List[str]:
"""Download a batch of URLs or file paths and return their local cached paths.

For large batches (>= 16 entries) the per-file progress bars are disabled and the
batch is downloaded through a thread pool with a single batch-level progress bar.
Smaller batches fall back to plain sequential `self._download` calls.

Args:
    url_or_filenames: URLs or local paths to fetch; order is preserved in the result.
    download_config: Download configuration; copied (not mutated) in the threaded path.

Returns:
    List of local paths, one per input, in the same order.
"""
if len(url_or_filenames) >= 16:
# Work on a copy so the caller's config keeps its own tqdm setting.
download_config = download_config.copy()
# Silence the per-file bars; a single batch-level bar is shown instead.
download_config.disable_tqdm = True
download_func = partial(self._download, download_config=download_config)

fs: fsspec.AbstractFileSystem
# Probe the size of the first file only, as a proxy for the whole batch
# — assumes files in a batch have similar sizes (TODO confirm).
fs, path = url_to_fs(url_or_filenames[0], **download_config.storage_options)
size = 0
try:
size = fs.info(path).get("size", 0)
except Exception:
# Best effort: if the filesystem can't report a size, keep size == 0,
# which selects the multithreaded path below.
pass
max_workers = 16 if size < (20 << 20) else 1 # enable multithreading if files are small
# 20 << 20 is 20 MiB: below that threshold 16 threads are used; at or above, 1.

return thread_map(
download_func,
url_or_filenames,
desc=download_config.download_desc or "Downloading",
unit="files",
# When stacked multiprocessing progress bars are enabled, pin this bar's
# row to the current subprocess's rank; otherwise let tqdm choose.
position=multiprocessing.current_process()._identity[-1] # contains the ranks of subprocesses
if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
and multiprocessing.current_process()._identity
else None,
max_workers=max_workers,
tqdm_class=tqdm,
)
else:
# Small batch: sequential downloads, keeping the per-file progress bars.
return [
self._download(url_or_filename, download_config=download_config)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would rename this method to _download_single

for url_or_filename in url_or_filenames
]

def _download(self, url_or_filename: str, download_config: DownloadConfig) -> str:
url_or_filename = str(url_or_filename)
if is_relative_path(url_or_filename):
Expand Down
21 changes: 14 additions & 7 deletions src/datasets/parallel/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ParallelBackendConfig:


@experimental
def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
def parallel_map(function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func):
"""
**Experimental.** Apply a function to iterable elements in parallel, where the implementation uses either
multiprocessing.Pool or joblib for parallelization.
Expand All @@ -32,21 +32,25 @@ def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single
"""
if ParallelBackendConfig.backend_name is None:
return _map_with_multiprocessing_pool(
function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func
function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
)

return _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func)
return _map_with_joblib(
function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
)


def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
def _map_with_multiprocessing_pool(
function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
):
num_proc = num_proc if num_proc <= len(iterable) else len(iterable)
split_kwds = []  # We organize the splits ourselves (contiguous splits)
for index in range(num_proc):
div = len(iterable) // num_proc
mod = len(iterable) % num_proc
start = div * index + min(index, mod)
end = start + div + (1 if index < mod else 0)
split_kwds.append((function, iterable[start:end], types, index, disable_tqdm, desc))
split_kwds.append((function, iterable[start:end], batched, batch_size, types, index, disable_tqdm, desc))

if len(iterable) != sum(len(i[1]) for i in split_kwds):
raise ValueError(
Expand All @@ -70,14 +74,17 @@ def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_
return mapped


def _map_with_joblib(function, iterable, num_proc, types, disable_tqdm, desc, single_map_nested_func):
def _map_with_joblib(
function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
):
# progress bar is not yet supported for _map_with_joblib, because tqdm couldn't accurately be applied to joblib,
# and it requires monkey-patching joblib internal classes which is subject to change
import joblib

with joblib.parallel_backend(ParallelBackendConfig.backend_name, n_jobs=num_proc):
return joblib.Parallel()(
joblib.delayed(single_map_nested_func)((function, obj, types, None, True, None)) for obj in iterable
joblib.delayed(single_map_nested_func)((function, obj, batched, batch_size, types, None, True, None))
for obj in iterable
)


Expand Down
22 changes: 19 additions & 3 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def cached_path(
ignore_url_params=download_config.ignore_url_params,
storage_options=download_config.storage_options,
download_desc=download_config.download_desc,
disable_tqdm=download_config.disable_tqdm,
)
elif os.path.exists(url_or_filename):
# File, and it exists.
Expand Down Expand Up @@ -335,7 +336,7 @@ def __init__(self, tqdm_kwargs=None, *args, **kwargs):
super().__init__(tqdm_kwargs, *args, **kwargs)


def fsspec_get(url, temp_file, storage_options=None, desc=None):
def fsspec_get(url, temp_file, storage_options=None, desc=None, disable_tqdm=False):
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
fs, path = url_to_fs(url, **(storage_options or {}))
callback = TqdmCallback(
Expand All @@ -347,6 +348,7 @@ def fsspec_get(url, temp_file, storage_options=None, desc=None):
if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
and multiprocessing.current_process()._identity
else None,
"disable": disable_tqdm,
}
)
fs.get_file(path, temp_file.name, callback=callback)
Expand All @@ -373,7 +375,16 @@ def ftp_get(url, temp_file, timeout=10.0):


def http_get(
url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=100.0, max_retries=0, desc=None
url,
temp_file,
proxies=None,
resume_size=0,
headers=None,
cookies=None,
timeout=100.0,
max_retries=0,
desc=None,
disable_tqdm=False,
) -> Optional[requests.Response]:
headers = dict(headers) if headers is not None else {}
headers["user-agent"] = get_datasets_user_agent(user_agent=headers.get("user-agent"))
Expand Down Expand Up @@ -405,6 +416,7 @@ def http_get(
if os.environ.get("HF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS") == "1"
and multiprocessing.current_process()._identity
else None,
disable=disable_tqdm,
) as progress:
for chunk in response.iter_content(chunk_size=1024):
progress.update(len(chunk))
Expand Down Expand Up @@ -464,6 +476,7 @@ def get_from_cache(
ignore_url_params=False,
storage_options=None,
download_desc=None,
disable_tqdm=False,
) -> str:
"""
Given a URL, look for the corresponding file in the local cache.
Expand Down Expand Up @@ -629,7 +642,9 @@ def temp_file_manager(mode="w+b"):
if scheme == "ftp":
ftp_get(url, temp_file)
elif scheme not in ("http", "https"):
fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
fsspec_get(
url, temp_file, storage_options=storage_options, desc=download_desc, disable_tqdm=disable_tqdm
)
else:
http_get(
url,
Expand All @@ -640,6 +655,7 @@ def temp_file_manager(mode="w+b"):
cookies=cookies,
max_retries=max_retries,
desc=download_desc,
disable_tqdm=disable_tqdm,
)

logger.info(f"storing {url} in cache at {cache_path}")
Expand Down
Loading