Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from .splits import Split, SplitDict, SplitGenerator
from .utils import logging
from .utils.download_manager import DownloadManager, GenerateMode
from .utils.file_utils import DownloadConfig, is_remote_url, request_etag
from .utils.file_utils import DownloadConfig, is_remote_url, request_etags
from .utils.filelock import FileLock
from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits

Expand Down Expand Up @@ -152,12 +152,16 @@ def create_config_id(
}
else:
raise ValueError("Please provide a valid `data_files` in `DatasetBuilder`")
remote_urls = [
data_file for key in data_files for data_file in data_files[key] if is_remote_url(data_file)
]
etags = dict(zip(remote_urls, request_etags(remote_urls, tqdm_kwargs={"desc": "Check remote data files"})))
for key in sorted(data_files.keys()):
m.update(key)
for data_file in data_files[key]:
if is_remote_url(data_file):
m.update(data_file)
m.update(str(request_etag(data_file, use_auth_token=use_auth_token)))
m.update(etags[data_file])
else:
m.update(os.path.abspath(data_file))
m.update(str(os.path.getmtime(data_file)))
Expand Down
29 changes: 25 additions & 4 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@
from functools import partial
from hashlib import sha256
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union
from urllib.parse import urlparse

import numpy as np
import posixpath
import requests
from tqdm.contrib.concurrent import thread_map

from .. import __version__, config, utils
from . import logging
from .extract import ExtractManager
from .filelock import FileLock
from .tqdm_utils import tqdm


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -383,7 +385,7 @@ def _request_with_retry(
try:
response = requests.request(method=method.upper(), url=url, timeout=timeout, **params)
success = True
except requests.exceptions.ConnectTimeout as err:
except (requests.exceptions.ConnectTimeout, requests.exceptions.ConnectionError) as err:
if tries > max_retries:
raise err
else:
Expand Down Expand Up @@ -465,13 +467,32 @@ def http_head(
return response


def request_etag(url: str, use_auth_token: Optional[Union[str, bool]] = None):
def request_etag(url: str, use_auth_token: Optional[Union[str, bool]] = None) -> Optional[str]:
headers = get_authentication_headers_for_url(url, use_auth_token=use_auth_token)
response = http_head(url, headers=headers)
response = http_head(url, headers=headers, max_retries=3)
response.raise_for_status()
etag = response.headers.get("ETag") if response.ok else None
return etag


def request_etags(
urls: List[str],
use_auth_token: Optional[Union[str, bool]] = None,
max_workers=64,
tqdm_kwargs: Optional[dict] = None,
) -> List[Optional[str]]:
tqdm_kwargs = tqdm_kwargs if tqdm_kwargs is not None else {}
tqdm_kwargs["desc"] = tqdm_kwargs.get("desc", "Get ETags")
tqdm_kwargs["disable"] = tqdm_kwargs.get("disable", len(urls) <= 16 or logging.get_verbosity() == logging.NOTSET)
return thread_map(
partial(request_etag, use_auth_token=use_auth_token),
urls,
max_workers=max_workers,
tqdm_class=tqdm,
**tqdm_kwargs,
)


def get_from_cache(
url,
cache_dir=None,
Expand Down
23 changes: 18 additions & 5 deletions src/datasets/utils/tqdm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,24 @@ def __exit__(self, type_, value, traceback):
_active = True


def tqdm(*args, **kwargs):
if _active:
return tqdm_lib.tqdm(*args, **kwargs)
else:
return EmptyTqdm(*args, **kwargs)
class _tqdm_cls:
def __call__(self, *args, **kwargs):
if _active:
return tqdm_lib.tqdm(*args, **kwargs)
else:
return EmptyTqdm(*args, **kwargs)

def set_lock(self, *args, **kwargs):
self._lock = None
if _active:
return tqdm_lib.tqdm.set_lock(*args, **kwargs)

def get_lock(self):
if _active:
return tqdm_lib.tqdm.get_lock()


tqdm = _tqdm_cls()


def async_tqdm(*args, **kwargs):
Expand Down