1 change: 1 addition & 0 deletions src/datasets/config.py
@@ -24,6 +24,7 @@
 # Hub
 HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
 HUB_DATASETS_URL = HF_ENDPOINT + "/datasets/{repo_id}/resolve/{revision}/{path}"
+HUB_DATASETS_HFFS_URL = "hf://datasets/{repo_id}@{revision}/{path}"
 HUB_DEFAULT_VERSION = "main"
 
 PY_VERSION = version.parse(platform.python_version())
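The new constant is a parse pattern rather than a URL template to fill in: the audio and image decoders below run streamed source URLs through string_to_dict (from datasets.utils.py_utils) to recover the repo_id. A minimal sketch of that matching, using a simplified regex-based stand-in for string_to_dict, whose real implementation is not shown in this diff:

import re

HUB_DATASETS_HFFS_URL = "hf://datasets/{repo_id}@{revision}/{path}"

def string_to_dict_sketch(string: str, pattern: str) -> dict:
    # Turn each {field} placeholder into a named, non-greedy regex group,
    # then require a full match against the input string.
    regex = re.sub(r"{(\w+)}", r"(?P<\1>.+?)", pattern)
    match = re.fullmatch(regex, string)
    if match is None:
        raise ValueError(f"{string!r} does not match pattern {pattern!r}")
    return match.groupdict()

# Hypothetical hffs-style URL of the kind streaming produces:
fields = string_to_dict_sketch("hf://datasets/user/my_dataset@main/data/train.txt", HUB_DATASETS_HFFS_URL)
assert fields == {"repo_id": "user/my_dataset", "revision": "main", "path": "data/train.txt"}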
8 changes: 8 additions & 0 deletions src/datasets/download/download_config.py
@@ -92,3 +92,11 @@ def __post_init__(self, use_auth_token):
 
     def copy(self) -> "DownloadConfig":
         return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()})
+
+    def __setattr__(self, name, value):
+        if name == "token" and getattr(self, "storage_options", None) is not None:
+            if "hf" not in self.storage_options:
+                self.storage_options["hf"] = {"token": value, "endpoint": config.HF_ENDPOINT}
+            elif self.storage_options["hf"].get("token") is None:
+                self.storage_options["hf"]["token"] = value
+        super().__setattr__(name, value)
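Net effect of the new __setattr__ hook, sketched below: assigning token on a DownloadConfig mirrors the credential into the fsspec storage options for the "hf" protocol, so the filesystem that resolves hf:// URLs authenticates with the same token. The token value is a placeholder:

from datasets import DownloadConfig

dc = DownloadConfig()
dc.token = "hf_xxx"  # placeholder, not a real credential

# The hook copied the token into the "hf" protocol's storage options:
assert dc.storage_options["hf"]["token"] == "hf_xxx"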
5 changes: 4 additions & 1 deletion src/datasets/features/audio.py
@@ -173,8 +173,11 @@ def decode_example(
         if file is None:
             token_per_repo_id = token_per_repo_id or {}
             source_url = path.split("::")[-1]
+            pattern = (
+                config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL
+            )
             try:
-                repo_id = string_to_dict(source_url, config.HUB_DATASETS_URL)["repo_id"]
+                repo_id = string_to_dict(source_url, pattern)["repo_id"]
                 token = token_per_repo_id[repo_id]
             except (ValueError, KeyError):
                 token = None
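Streaming paths reaching the decoder can be fsspec chained URLs, which is why the code above takes path.split("::")[-1] before choosing a pattern: the last segment is the origin URL, and its scheme decides whether the https pattern or the new hffs pattern applies. The image decoder below makes the same choice. A small illustration with hypothetical paths:

# Chained path as produced when streaming a file out of a zipped repo:
path = "zip://audio/sample.wav::hf://datasets/user/my_audio@abc123/data.zip"
source_url = path.split("::")[-1]
assert source_url == "hf://datasets/user/my_audio@abc123/data.zip"
assert not source_url.startswith("https://huggingface.co")  # -> HUB_DATASETS_HFFS_URL

# An https origin instead selects HUB_DATASETS_URL:
https_path = "zip://audio/sample.wav::https://huggingface.co/datasets/user/my_audio/resolve/main/data.zip"
assert https_path.split("::")[-1].startswith("https://huggingface.co")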
7 changes: 6 additions & 1 deletion src/datasets/features/image.py
@@ -166,8 +166,13 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "PIL.Image.Image":
                 image = PIL.Image.open(path)
             else:
                 source_url = path.split("::")[-1]
+                pattern = (
+                    config.HUB_DATASETS_URL
+                    if source_url.startswith(config.HF_ENDPOINT)
+                    else config.HUB_DATASETS_HFFS_URL
+                )
                 try:
-                    repo_id = string_to_dict(source_url, config.HUB_DATASETS_URL)["repo_id"]
+                    repo_id = string_to_dict(source_url, pattern)["repo_id"]
                     token = token_per_repo_id.get(repo_id)
                 except ValueError:
                     token = None
3 changes: 3 additions & 0 deletions src/datasets/load.py
@@ -539,6 +539,7 @@ def create_builder_configs_from_metadata_configs(
     base_path: Optional[str] = None,
     default_builder_kwargs: Dict[str, Any] = None,
     allowed_extensions: Optional[List[str]] = None,
+    download_config: Optional[DownloadConfig] = None,
 ) -> Tuple[List[BuilderConfig], str]:
     builder_cls = import_main_class(module_path)
     builder_config_cls = builder_cls.BUILDER_CONFIG_CLASS
@@ -560,6 +561,7 @@
                 config_patterns,
                 base_path=config_base_path,
                 allowed_extensions=ALL_ALLOWED_EXTENSIONS,
+                download_config=download_config,
             )
         except EmptyDatasetError as e:
             raise EmptyDatasetError(
@@ -1070,6 +1072,7 @@ def get_module(self) -> DatasetModule:
                 base_path=base_path,
                 supports_metadata=supports_metadata,
                 default_builder_kwargs=default_builder_kwargs,
+                download_config=self.download_config,
             )
         else:
             builder_configs, default_config_name = None, None
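These load.py additions thread the factory's DownloadConfig, which now carries the token thanks to the download_config.py hook above, into data-file pattern resolution, so listing the files of a private no-script repo is authenticated. In user terms the fix amounts to the following sketch; the repo id and token are placeholders and the call needs access to the Hub:

from datasets import load_dataset_builder

# Previously this could fail while resolving data files for a private repo
# even though a token was supplied:
builder = load_dataset_builder("user/private_dataset", token="hf_xxx")
print(builder.config.data_files)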
6 changes: 1 addition & 5 deletions tests/fixtures/hub.py
@@ -48,12 +48,8 @@ def hf_api():
 
 
 @pytest.fixture(scope="session")
-def hf_token(hf_api: HfApi):
-    previous_token = HfFolder.get_token()
-    HfFolder.save_token(CI_HUB_USER_TOKEN)
+def hf_token():
     yield CI_HUB_USER_TOKEN
-    if previous_token is not None:
-        HfFolder.save_token(previous_token)
 
 
 @pytest.fixture
5 changes: 5 additions & 0 deletions tests/test_inspect.py
@@ -47,6 +47,11 @@ def test_get_dataset_config_info(path, config_name, expected_splits):
     assert list(info.splits.keys()) == expected_splits
 
 
+def test_get_dataset_config_info_private(hf_token, hf_private_dataset_repo_txt_data):
+    info = get_dataset_config_info(hf_private_dataset_repo_txt_data, config_name="default", token=hf_token)
+    assert list(info.splits.keys()) == ["train"]
+
+
 @pytest.mark.parametrize(
     "path, config_name, expected_exception",
     [
19 changes: 11 additions & 8 deletions tests/test_load.py
@@ -38,6 +38,7 @@
     PackagedDatasetModuleFactory,
     infer_module_for_data_files_list,
     infer_module_for_data_files_list_in_archives,
+    load_dataset_builder,
 )
 from datasets.packaged_modules.audiofolder.audiofolder import AudioFolder, AudioFolderConfig
 from datasets.packaged_modules.imagefolder.imagefolder import ImageFolder, ImageFolderConfig
@@ -1223,13 +1224,19 @@ def assert_auth(method, url, *args, headers, **kwargs):
 
 @pytest.mark.integration
 def test_load_streaming_private_dataset(hf_token, hf_private_dataset_repo_txt_data):
-    ds = load_dataset(hf_private_dataset_repo_txt_data, streaming=True)
+    ds = load_dataset(hf_private_dataset_repo_txt_data, streaming=True, token=hf_token)
     assert next(iter(ds)) is not None
 
 
 @pytest.mark.integration
+def test_load_dataset_builder_private_dataset(hf_token, hf_private_dataset_repo_txt_data):
+    builder = load_dataset_builder(hf_private_dataset_repo_txt_data, token=hf_token)
+    assert isinstance(builder, DatasetBuilder)
+
+
+@pytest.mark.integration
 def test_load_streaming_private_dataset_with_zipped_data(hf_token, hf_private_dataset_repo_zipped_txt_data):
-    ds = load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True)
+    ds = load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True, token=hf_token)
     assert next(iter(ds)) is not None
 
 
@@ -1309,13 +1316,9 @@ def test_load_hub_dataset_without_script_with_metadata_config_in_parallel():
 
 @require_pil
 @pytest.mark.integration
-@pytest.mark.parametrize("implicit_token", [True])
 @pytest.mark.parametrize("streaming", [True])
-def test_load_dataset_private_zipped_images(
-    hf_private_dataset_repo_zipped_img_data, hf_token, streaming, implicit_token
-):
-    token = None if implicit_token else hf_token
-    ds = load_dataset(hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, token=token)
+def test_load_dataset_private_zipped_images(hf_private_dataset_repo_zipped_img_data, hf_token, streaming):
+    ds = load_dataset(hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, token=hf_token)
     assert isinstance(ds, IterableDataset if streaming else Dataset)
     ds_items = list(ds)
     assert len(ds_items) == 2
6 changes: 4 additions & 2 deletions tests/test_upstream_hub.py
@@ -33,12 +33,12 @@
 
 
 @for_all_test_methods(xfail_if_500_502_http_error)
-@pytest.mark.usefixtures("set_ci_hub_access_token", "ci_hfh_hf_hub_url")
+@pytest.mark.usefixtures("ci_hub_config", "ci_hfh_hf_hub_url")
 class TestPushToHub:
     _api = HfApi(endpoint=CI_HUB_ENDPOINT)
     _token = CI_HUB_USER_TOKEN
 
-    def test_push_dataset_dict_to_hub_no_token(self, temporary_repo):
+    def test_push_dataset_dict_to_hub_no_token(self, temporary_repo, set_ci_hub_access_token):
         ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
 
         local_ds = DatasetDict({"train": ds})
@@ -778,6 +778,7 @@ def test_push_dataset_to_hub_with_config_no_metadata_configs(self, temporary_repo):
             path_in_repo="data/train-00000-of-00001.parquet",
             repo_id=ds_name,
             repo_type="dataset",
+            token=self._token,
         )
         ds_another_config.push_to_hub(ds_name, "another_config", token=self._token)
         ds_builder = load_dataset_builder(ds_name, download_mode="force_redownload")
@@ -811,6 +812,7 @@ def test_push_dataset_dict_to_hub_with_config_no_metadata_configs(self, temporary_repo):
             path_in_repo="data/random-00000-of-00001.parquet",
             repo_id=ds_name,
             repo_type="dataset",
+            token=self._token,
         )
         local_ds_another_config.push_to_hub(ds_name, "another_config", token=self._token)
         ds_builder = load_dataset_builder(ds_name, download_mode="force_redownload")