diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 7f10ece15e1..7f038fb6164 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -42,7 +42,7 @@
 from .splits import Split, SplitDict, SplitGenerator
 from .utils import logging
 from .utils.download_manager import DownloadManager, GenerateMode
-from .utils.file_utils import DownloadConfig, is_remote_url
+from .utils.file_utils import DownloadConfig, is_remote_url, request_etag
 from .utils.filelock import FileLock
 from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits
@@ -95,7 +95,12 @@ def __eq__(self, o):
             return False
         return all((k, getattr(self, k)) == (k, getattr(o, k)) for k in self.__dict__.keys())
 
-    def create_config_id(self, config_kwargs: dict, custom_features: Optional[Features] = None) -> str:
+    def create_config_id(
+        self,
+        config_kwargs: dict,
+        custom_features: Optional[Features] = None,
+        use_auth_token: Optional[Union[bool, str]] = None,
+    ) -> str:
         """
         The config id is used to build the cache directory.
         By default it is equal to the config name.
@@ -152,6 +157,7 @@ def create_config_id(self, config_kwargs: dict, custom_features: Optional[Featur
                for data_file in data_files[key]:
                    if is_remote_url(data_file):
                        m.update(data_file)
+                       m.update(str(request_etag(data_file, use_auth_token=use_auth_token)))
                    else:
                        m.update(os.path.abspath(data_file))
                        m.update(str(os.path.getmtime(data_file)))
@@ -209,6 +215,7 @@ def __init__(
         hash: Optional[str] = None,
         base_path: Optional[str] = None,
         features: Optional[Features] = None,
+        use_auth_token: Optional[Union[bool, str]] = None,
         **config_kwargs,
     ):
         """Constructs a DatasetBuilder.
@@ -226,6 +233,8 @@ def __init__(
             base_path: `str`, base path for relative paths that are used to download files. This can be a remote url.
             features: `Features`, optional features that will be used to read/write the dataset
                 It can be used to changed the :obj:`datasets.Features` description of a dataset for example.
+            use_auth_token (:obj:`str` or :obj:`bool`, optional): Optional string or boolean to use as Bearer token
+                for remote files on the Datasets Hub. If True, will get token from ``"~/.huggingface"``.
             config_kwargs: will override the defaults kwargs in config
 
         """
@@ -233,6 +242,7 @@ def __init__(
         self.name: str = camelcase_to_snakecase(self.__class__.__name__)
         self.hash: Optional[str] = hash
         self.base_path = base_path
+        self.use_auth_token = use_auth_token
 
         # Prepare config: DatasetConfig contains name, version and description but can be extended by each dataset
         config_kwargs = {key: value for key, value in config_kwargs.items() if value is not None}
@@ -355,7 +365,9 @@ def _create_builder_config(self, name=None, custom_features=None, **config_kwarg
                 raise ValueError("BuilderConfig must have a name, got %s" % builder_config.name)
 
         # compute the config id that is going to be used for caching
-        config_id = builder_config.create_config_id(config_kwargs, custom_features=custom_features)
+        config_id = builder_config.create_config_id(
+            config_kwargs, custom_features=custom_features, use_auth_token=self.use_auth_token
+        )
         is_custom = config_id not in self.builder_configs
         if is_custom:
             logger.warning("Using custom data configuration %s", config_id)
diff --git a/src/datasets/load.py b/src/datasets/load.py
index 46dece0e3ba..aaf01fc7109 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -701,6 +701,7 @@ def load_dataset_builder(
         hash=hash,
         base_path=base_path,
         features=features,
+        use_auth_token=use_auth_token,
         **config_kwargs,
     )
 
diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py
index 8b2dbdf5528..f6ab0392a36 100644
--- a/src/datasets/utils/file_utils.py
+++ b/src/datasets/utils/file_utils.py
@@ -464,6 +464,13 @@ def http_head(
     return response
 
 
+def request_etag(url: str, use_auth_token: Optional[Union[str, bool]] = None):
+    headers = get_authentication_headers_for_url(url, use_auth_token=use_auth_token)
+    response = http_head(url, headers=headers)
+    etag = response.headers.get("ETag") if response.ok else None
+    return etag
+
+
 def get_from_cache(
     url,
     cache_dir=None,
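With this change, the ETag of each remote data file is hashed into the config id, so the cache directory changes when a remote file changes, and the HEAD request that fetches the ETag can be authenticated for private files. A minimal usage sketch follows; the organization name and file URL are hypothetical, and it assumes a token has already been saved with `huggingface-cli login`:

from datasets import load_dataset

# use_auth_token=True reads the saved Hub token and is forwarded to the builder,
# so the ETag HEAD request on the remote data file is authenticated as well.
# If the remote file changes, its ETag changes, the config id changes, and the
# dataset is rebuilt instead of being served from a stale cache.
dataset = load_dataset(
    "csv",
    data_files="https://huggingface.co/datasets/my-org/my-private-data/resolve/main/train.csv",  # hypothetical URL
    use_auth_token=True,
)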