Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4391,7 +4391,7 @@ def shards_with_embedded_external_files(shards):

shards = shards_with_embedded_external_files(shards)

- files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, token=token)
+ files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, use_auth_token=token)
data_files = [file for file in files if file.startswith("data/")]

def path_in_repo(_index, shard):
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,7 +1375,7 @@ def push_to_hub(
info_to_dump.size_in_bytes = total_uploaded_size + total_dataset_nbytes

api = HfApi(endpoint=config.HF_ENDPOINT)
- repo_files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, token=token)
+ repo_files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, use_auth_token=token)

# push to the deprecated dataset_infos.json
if config.DATASETDICT_INFOS_FILENAME in repo_files:
Expand Down
8 changes: 5 additions & 3 deletions src/datasets/utils/_hf_hub_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,18 @@ def list_repo_files(
repo_id: str,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
- token: Optional[str] = None,
+ use_auth_token: Optional[str] = None,
timeout: Optional[float] = None,
) -> List[str]:
"""
The huggingface_hub.HfApi.list_repo_files parameters changed in 0.10.0 and some of them were deprecated.
This function checks the huggingface_hub version to call the right parameters.
"""
if version.parse(huggingface_hub.__version__) < version.parse("0.10.0"):
- return hf_api.list_repo_files(repo_id, revision=revision, repo_type=repo_type, token=token, timeout=timeout)
+ return hf_api.list_repo_files(
+ repo_id, revision=revision, repo_type=repo_type, token=use_auth_token, timeout=timeout
+ )
else: # the `token` parameter is deprecated in huggingface_hub>=0.10.0
return hf_api.list_repo_files(
- repo_id, revision=revision, repo_type=repo_type, use_auth_token=token, timeout=timeout
+ repo_id, revision=revision, repo_type=repo_type, use_auth_token=use_auth_token, timeout=timeout
)
14 changes: 7 additions & 7 deletions tests/test_upstream_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_push_dataset_dict_to_hub_private(self, temporary_repo):
assert local_ds["train"].features == hub_ds["train"].features

# Ensure that there is a single file on the repository that has the correct name
- files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
+ files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
assert all(
fnmatch.fnmatch(file, expected_file)
for file, expected_file in zip(
Expand All @@ -119,7 +119,7 @@ def test_push_dataset_dict_to_hub(self, temporary_repo):
assert local_ds["train"].features == hub_ds["train"].features

# Ensure that there is a single file on the repository that has the correct name
- files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
+ files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
assert all(
fnmatch.fnmatch(file, expected_file)
for file, expected_file in zip(
Expand All @@ -142,7 +142,7 @@ def test_push_dataset_dict_to_hub_multiple_files(self, temporary_repo):
assert local_ds["train"].features == hub_ds["train"].features

# Ensure that there are two files on the repository that have the correct name
- files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
+ files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
assert all(
fnmatch.fnmatch(file, expected_file)
for file, expected_file in zip(
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_push_dataset_dict_to_hub_multiple_files_with_max_shard_size(self, tempo
assert local_ds["train"].features == hub_ds["train"].features

# Ensure that there are two files on the repository that have the correct name
- files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
+ files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
assert all(
fnmatch.fnmatch(file, expected_file)
for file, expected_file in zip(
Expand Down Expand Up @@ -214,7 +214,7 @@ def test_push_dataset_dict_to_hub_overwrite_files(self, temporary_repo):
local_ds.push_to_hub(ds_name, token=self._token, max_shard_size=500 << 5)

# Ensure that there are two files on the repository that have the correct name
- files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
+ files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))

assert all(
fnmatch.fnmatch(file, expected_file)
Expand Down Expand Up @@ -261,7 +261,7 @@ def test_push_dataset_dict_to_hub_overwrite_files(self, temporary_repo):
local_ds.push_to_hub(ds_name, token=self._token)

# Ensure that there are two files on the repository that have the correct name
- files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
+ files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))

assert all(
fnmatch.fnmatch(file, expected_file)
Expand Down Expand Up @@ -418,7 +418,7 @@ def test_push_dataset_to_hub_skip_identical_files(self, temporary_repo):
mock_hf_api.reset_mock()

# Remove a data file
- files = self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)
+ files = list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token)
data_files = [f for f in files if f.startswith("data/")]
assert len(data_files) > 1
self._api.delete_file(data_files[0], repo_id=ds_name, repo_type="dataset", token=self._token)
Expand Down