Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from .splits import Split, SplitDict, SplitGenerator
from .utils import logging
from .utils.download_manager import DownloadManager, GenerateMode
from .utils.file_utils import DownloadConfig, is_remote_url
from .utils.file_utils import DownloadConfig, is_remote_url, request_etag
from .utils.filelock import FileLock
from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits

Expand Down Expand Up @@ -95,7 +95,12 @@ def __eq__(self, o):
return False
return all((k, getattr(self, k)) == (k, getattr(o, k)) for k in self.__dict__.keys())

def create_config_id(self, config_kwargs: dict, custom_features: Optional[Features] = None) -> str:
def create_config_id(
self,
config_kwargs: dict,
custom_features: Optional[Features] = None,
use_auth_token: Optional[Union[bool, str]] = None,
) -> str:
"""
The config id is used to build the cache directory.
By default it is equal to the config name.
Expand Down Expand Up @@ -152,6 +157,7 @@ def create_config_id(self, config_kwargs: dict, custom_features: Optional[Featur
for data_file in data_files[key]:
if is_remote_url(data_file):
m.update(data_file)
m.update(str(request_etag(data_file, use_auth_token=use_auth_token)))
else:
m.update(os.path.abspath(data_file))
m.update(str(os.path.getmtime(data_file)))
Expand Down Expand Up @@ -209,6 +215,7 @@ def __init__(
hash: Optional[str] = None,
base_path: Optional[str] = None,
features: Optional[Features] = None,
use_auth_token: Optional[Union[bool, str]] = None,
**config_kwargs,
):
"""Constructs a DatasetBuilder.
Expand All @@ -226,13 +233,16 @@ def __init__(
base_path: `str`, base path for relative paths that are used to download files. This can be a remote url.
features: `Features`, optional features that will be used to read/write the dataset
It can be used to change the :obj:`datasets.Features` description of a dataset, for example.
use_auth_token (:obj:`str` or :obj:`bool`, optional): Optional string or boolean to use as Bearer token
for remote files on the Datasets Hub. If True, will get token from ``"~/.huggingface"``.
config_kwargs: will override the defaults kwargs in config

"""
# DatasetBuilder name
self.name: str = camelcase_to_snakecase(self.__class__.__name__)
self.hash: Optional[str] = hash
self.base_path = base_path
self.use_auth_token = use_auth_token

# Prepare config: DatasetConfig contains name, version and description but can be extended by each dataset
config_kwargs = {key: value for key, value in config_kwargs.items() if value is not None}
Expand Down Expand Up @@ -355,7 +365,9 @@ def _create_builder_config(self, name=None, custom_features=None, **config_kwarg
raise ValueError("BuilderConfig must have a name, got %s" % builder_config.name)

# compute the config id that is going to be used for caching
config_id = builder_config.create_config_id(config_kwargs, custom_features=custom_features)
config_id = builder_config.create_config_id(
config_kwargs, custom_features=custom_features, use_auth_token=self.use_auth_token
)
is_custom = config_id not in self.builder_configs
if is_custom:
logger.warning("Using custom data configuration %s", config_id)
Expand Down
1 change: 1 addition & 0 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ def load_dataset_builder(
hash=hash,
base_path=base_path,
features=features,
use_auth_token=use_auth_token,
**config_kwargs,
)

Expand Down
7 changes: 7 additions & 0 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,13 @@ def http_head(
return response


def request_etag(url: str, use_auth_token: Optional[Union[str, bool]] = None):
    """Query ``url`` with a HEAD request and report its ``ETag`` header.

    Args:
        url: Address of the remote resource to query.
        use_auth_token: Forwarded to ``get_authentication_headers_for_url``
            to build the headers sent with the request.

    Returns:
        The value of the ``ETag`` response header, or ``None`` when the
        response is not ``ok`` or does not carry that header.
    """
    auth_headers = get_authentication_headers_for_url(url, use_auth_token=use_auth_token)
    head_response = http_head(url, headers=auth_headers)
    # A non-ok response yields no ETag rather than raising.
    if not head_response.ok:
        return None
    return head_response.headers.get("ETag")


def get_from_cache(
url,
cache_dir=None,
Expand Down