Skip to content

Commit 3c19812

Browse files
Align signature of list_repo_files with latest hfh (#5063)
* Rename token to use_auth_token in list_repo_files * Remove warning * Fix warning * Update type hint
1 parent e0dd33c commit 3c19812

File tree

4 files changed

+14
-12
lines changed

4 files changed

+14
-12
lines changed

src/datasets/arrow_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4394,7 +4394,7 @@ def shards_with_embedded_external_files(shards):
43944394

43954395
shards = shards_with_embedded_external_files(shards)
43964396

4397-
files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, token=token)
4397+
files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, use_auth_token=token)
43984398
data_files = [file for file in files if file.startswith("data/")]
43994399

44004400
def path_in_repo(_index, shard):

src/datasets/dataset_dict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1375,7 +1375,7 @@ def push_to_hub(
13751375
info_to_dump.size_in_bytes = total_uploaded_size + total_dataset_nbytes
13761376

13771377
api = HfApi(endpoint=config.HF_ENDPOINT)
1378-
repo_files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, token=token)
1378+
repo_files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, use_auth_token=token)
13791379

13801380
# push to the deprecated dataset_infos.json
13811381
if config.DATASETDICT_INFOS_FILENAME in repo_files:

src/datasets/utils/_hf_hub_fixes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,18 @@ def list_repo_files(
161161
repo_id: str,
162162
revision: Optional[str] = None,
163163
repo_type: Optional[str] = None,
164-
token: Optional[str] = None,
164+
use_auth_token: Optional[Union[bool, str]] = None,
165165
timeout: Optional[float] = None,
166166
) -> List[str]:
167167
"""
168168
The huggingface_hub.HfApi.list_repo_files parameters changed in 0.10.0 and some of them were deprecated.
169169
This function checks the huggingface_hub version to call the right parameters.
170170
"""
171171
if version.parse(huggingface_hub.__version__) < version.parse("0.10.0"):
172-
return hf_api.list_repo_files(repo_id, revision=revision, repo_type=repo_type, token=token, timeout=timeout)
172+
return hf_api.list_repo_files(
173+
repo_id, revision=revision, repo_type=repo_type, token=use_auth_token, timeout=timeout
174+
)
173175
else: # the `token` parameter is deprecated in huggingface_hub>=0.10.0
174176
return hf_api.list_repo_files(
175-
repo_id, revision=revision, repo_type=repo_type, use_auth_token=token, timeout=timeout
177+
repo_id, revision=revision, repo_type=repo_type, use_auth_token=use_auth_token, timeout=timeout
176178
)

tests/test_upstream_hub.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_push_dataset_dict_to_hub_private(self, temporary_repo):
9797
assert local_ds["train"].features == hub_ds["train"].features
9898

9999
# Ensure that there is a single file on the repository that has the correct name
100-
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
100+
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
101101
assert all(
102102
fnmatch.fnmatch(file, expected_file)
103103
for file, expected_file in zip(
@@ -119,7 +119,7 @@ def test_push_dataset_dict_to_hub(self, temporary_repo):
119119
assert local_ds["train"].features == hub_ds["train"].features
120120

121121
# Ensure that there is a single file on the repository that has the correct name
122-
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
122+
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
123123
assert all(
124124
fnmatch.fnmatch(file, expected_file)
125125
for file, expected_file in zip(
@@ -142,7 +142,7 @@ def test_push_dataset_dict_to_hub_multiple_files(self, temporary_repo):
142142
assert local_ds["train"].features == hub_ds["train"].features
143143

144144
# Ensure that there are two files on the repository that have the correct name
145-
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
145+
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
146146
assert all(
147147
fnmatch.fnmatch(file, expected_file)
148148
for file, expected_file in zip(
@@ -170,7 +170,7 @@ def test_push_dataset_dict_to_hub_multiple_files_with_max_shard_size(self, tempo
170170
assert local_ds["train"].features == hub_ds["train"].features
171171

172172
# Ensure that there are two files on the repository that have the correct name
173-
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
173+
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
174174
assert all(
175175
fnmatch.fnmatch(file, expected_file)
176176
for file, expected_file in zip(
@@ -214,7 +214,7 @@ def test_push_dataset_dict_to_hub_overwrite_files(self, temporary_repo):
214214
local_ds.push_to_hub(ds_name, token=self._token, max_shard_size=500 << 5)
215215

216216
# Ensure that there are two files on the repository that have the correct name
217-
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
217+
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
218218

219219
assert all(
220220
fnmatch.fnmatch(file, expected_file)
@@ -261,7 +261,7 @@ def test_push_dataset_dict_to_hub_overwrite_files(self, temporary_repo):
261261
local_ds.push_to_hub(ds_name, token=self._token)
262262

263263
# Ensure that there are two files on the repository that have the correct name
264-
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", token=self._token))
264+
files = sorted(list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token))
265265

266266
assert all(
267267
fnmatch.fnmatch(file, expected_file)
@@ -418,7 +418,7 @@ def test_push_dataset_to_hub_skip_identical_files(self, temporary_repo):
418418
mock_hf_api.reset_mock()
419419

420420
# Remove a data file
421-
files = self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)
421+
files = list_repo_files(self._api, ds_name, repo_type="dataset", use_auth_token=self._token)
422422
data_files = [f for f in files if f.startswith("data/")]
423423
assert len(data_files) > 1
424424
self._api.delete_file(data_files[0], repo_id=ds_name, repo_type="dataset", token=self._token)

0 commit comments

Comments
 (0)