Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4391,7 +4391,7 @@ def shards_with_embedded_external_files(shards):

shards = shards_with_embedded_external_files(shards)

files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, token=token)
files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, use_auth_token=token)
data_files = [file for file in files if file.startswith("data/")]

def path_in_repo(_index, shard):
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,7 +1375,7 @@ def push_to_hub(
info_to_dump.size_in_bytes = total_uploaded_size + total_dataset_nbytes

api = HfApi(endpoint=config.HF_ENDPOINT)
repo_files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, token=token)
repo_files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, use_auth_token=token)

# push to the deprecated dataset_infos.json
if config.DATASETDICT_INFOS_FILENAME in repo_files:
Expand Down
8 changes: 5 additions & 3 deletions src/datasets/utils/_hf_hub_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,18 @@ def list_repo_files(
repo_id: str,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
token: Optional[str] = None,
use_auth_token: Optional[Union[bool, str]] = None,
timeout: Optional[float] = None,
) -> List[str]:
"""
The huggingface_hub.HfApi.list_repo_files parameters changed in 0.10.0 and some of them were deprecated.
This function checks the huggingface_hub version to call the right parameters.
"""
if version.parse(huggingface_hub.__version__) < version.parse("0.10.0"):
return hf_api.list_repo_files(repo_id, revision=revision, repo_type=repo_type, token=token, timeout=timeout)
return hf_api.list_repo_files(
repo_id, revision=revision, repo_type=repo_type, token=use_auth_token, timeout=timeout
)
else: # the `token` parameter is deprecated in huggingface_hub>=0.10.0
return hf_api.list_repo_files(
repo_id, revision=revision, repo_type=repo_type, use_auth_token=token, timeout=timeout
repo_id, revision=revision, repo_type=repo_type, use_auth_token=use_auth_token, timeout=timeout
)
14 changes: 7 additions & 7 deletions tests/test_upstream_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_push_dataset_dict_to_hub_private(self, temporary_repo):
assert local_ds["train"].features == hub_ds["train"].features

# Ensure that there is a single file on the repository that has the correct name
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
assert all(
fnmatch.fnmatch(file, expected_file)
for file, expected_file in zip(
Expand All @@ -119,7 +119,7 @@ def test_push_dataset_dict_to_hub(self, temporary_repo):
assert local_ds["train"].features == hub_ds["train"].features

# Ensure that there is a single file on the repository that has the correct name
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
assert all(
fnmatch.fnmatch(file, expected_file)
for file, expected_file in zip(
Expand All @@ -142,7 +142,7 @@ def test_push_dataset_dict_to_hub_multiple_files(self, temporary_repo):
assert local_ds["train"].features == hub_ds["train"].features

# Ensure that there are two files on the repository that have the correct name
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
assert all(
fnmatch.fnmatch(file, expected_file)
for file, expected_file in zip(
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_push_dataset_dict_to_hub_multiple_files_with_max_shard_size(self, tempo
assert local_ds["train"].features == hub_ds["train"].features

# Ensure that there are two files on the repository that have the correct name
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
assert all(
fnmatch.fnmatch(file, expected_file)
for file, expected_file in zip(
Expand Down Expand Up @@ -214,7 +214,7 @@ def test_push_dataset_dict_to_hub_overwrite_files(self, temporary_repo):
local_ds.push_to_hub(ds_name, token=self._token, max_shard_size=500 << 5)

# Ensure that there are two files on the repository that have the correct name
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))

assert all(
fnmatch.fnmatch(file, expected_file)
Expand Down Expand Up @@ -261,7 +261,7 @@ def test_push_dataset_dict_to_hub_overwrite_files(self, temporary_repo):
local_ds.push_to_hub(ds_name, token=self._token)

# Ensure that there are two files on the repository that have the correct name
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))

assert all(
fnmatch.fnmatch(file, expected_file)
Expand Down Expand Up @@ -418,7 +418,7 @@ def test_push_dataset_to_hub_skip_identical_files(self, temporary_repo):
mock_hf_api.reset_mock()

# Remove a data file
files = self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)
files = list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token)
data_files = [f for f in files if f.startswith("data/")]
assert len(data_files) > 1
self._api.delete_file(data_files[0], repo_id=ds_name, repo_type="dataset", token=self._token)
Expand Down