
Commit 7c8106d

Skip identical files in push_to_hub instead of overwriting (#4402)
* Resume upload instead of pushing identical files
* Update tests
* Update glob
* Add test
* Use fnmatch in tests
* Add warning when resuming upload
1 parent 7103969 commit 7c8106d
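
For context, a minimal usage sketch of the behaviour this commit targets (the repository id and shard size below are made up, and an authenticated Hub login is assumed): pushing the same dataset twice should re-upload only the shards that are missing from the repository instead of overwriting identical files.

from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))})

# First push uploads every Parquet shard under data/, with the shard fingerprint in the file name.
ds.push_to_hub("my-username/my-dataset", max_shard_size="1KB")

# A second push of identical data finds those fingerprinted shard files already on the Hub,
# logs a warning about resuming the upload, and skips re-uploading them.
ds.push_to_hub("my-username/my-dataset", max_shard_size="1KB")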

File tree

3 files changed: +158 −74 lines changed


src/datasets/arrow_dataset.py

Lines changed: 52 additions & 42 deletions
@@ -17,6 +17,7 @@

 import contextlib
 import copy
+import itertools
 import json
 import os
 import shutil
@@ -3949,60 +3950,69 @@ def shards_with_embedded_external_files(shards):
         shards = shards_with_embedded_external_files(shards)

         files = api.list_repo_files(repo_id, repo_type="dataset", revision=branch, token=token)
-        files = [file for file in files if file.startswith("data/")]
+        data_files = [file for file in files if file.startswith("data/")]

-        def path_in_repo(_index):
-            return f"data/{split}-{_index:05d}-of-{num_shards:05d}.parquet"
+        def path_in_repo(_index, shard):
+            return f"data/{split}-{_index:05d}-of-{num_shards:05d}-{shard._fingerprint}.parquet"

-        # Only delete file shards that don't currently exist. Others will be overwritten if the content is different
-        # or will be left intact is the content is identical.
-        def should_delete_file(file_name):
-            file_to_overwrite = file_name in [path_in_repo(i) for i in range(num_shards)]
-            file_from_same_split = file_name.startswith(f"data/{split}-")
+        shards_iter = iter(shards)
+        first_shard = next(shards_iter)
+        first_shard_path_in_repo = path_in_repo(0, first_shard)
+        if first_shard_path_in_repo in data_files:
+            logger.warning("Resuming upload of dataset shards")

-            return file_from_same_split and not file_to_overwrite
+        uploaded_size = 0
+        shards_path_in_repo = []
+        for index, shard in logging.tqdm(
+            enumerate(itertools.chain([first_shard], shards_iter)),
+            desc="Pushing dataset shards to the dataset hub",
+            total=num_shards,
+            disable=not logging.is_progress_bar_enabled(),
+        ):
+            shard_path_in_repo = path_in_repo(index, shard)
+            # Upload a shard only if it doesn't already exist in the repository
+            if shard_path_in_repo not in data_files:
+                buffer = BytesIO()
+                shard.to_parquet(buffer)
+                uploaded_size += buffer.tell()
+                _retry(
+                    api.upload_file,
+                    func_kwargs=dict(
+                        path_or_fileobj=buffer.getvalue(),
+                        path_in_repo=shard_path_in_repo,
+                        repo_id=repo_id,
+                        token=token,
+                        repo_type="dataset",
+                        revision=branch,
+                        identical_ok=False,
+                    ),
+                    exceptions=HTTPError,
+                    status_codes=[504],
+                    base_wait_time=2.0,
+                    max_retries=5,
+                    max_wait_time=20.0,
+                )
+            shards_path_in_repo.append(shard_path_in_repo)

-        file_shards_to_delete = [file for file in files if should_delete_file(file)]
+        # Cleanup to remove unused files
+        data_files_to_delete = [
+            data_file
+            for data_file in data_files
+            if data_file.startswith(f"data/{split}-") and data_file not in shards_path_in_repo
+        ]

         def delete_file(file):
             api.delete_file(file, repo_id=repo_id, token=token, repo_type="dataset", revision=branch)

-        if len(file_shards_to_delete):
-            for file in logging.tqdm(
-                file_shards_to_delete,
+        if len(data_files_to_delete):
+            for data_file in logging.tqdm(
+                data_files_to_delete,
                 desc="Deleting unused files from dataset repository",
-                total=len(file_shards_to_delete),
+                total=len(data_files_to_delete),
                 disable=not logging.is_progress_bar_enabled(),
             ):
-                delete_file(file)
+                delete_file(data_file)

-        uploaded_size = 0
-        for index, shard in logging.tqdm(
-            enumerate(shards),
-            desc="Pushing dataset shards to the dataset hub",
-            total=num_shards,
-            disable=not logging.is_progress_bar_enabled(),
-        ):
-            buffer = BytesIO()
-            shard.to_parquet(buffer)
-            uploaded_size += buffer.tell()
-            _retry(
-                api.upload_file,
-                func_kwargs=dict(
-                    path_or_fileobj=buffer.getvalue(),
-                    path_in_repo=path_in_repo(index),
-                    repo_id=repo_id,
-                    token=token,
-                    repo_type="dataset",
-                    revision=branch,
-                    identical_ok=True,
-                ),
-                exceptions=HTTPError,
-                status_codes=[504],
-                base_wait_time=2.0,
-                max_retries=5,
-                max_wait_time=20.0,
-            )
         return repo_id, split, uploaded_size, dataset_nbytes

     def push_to_hub(
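
As a reader's aid, the skip-and-clean-up logic in the diff above can be summarised outside the Hub context. This is a simplified sketch, not part of the datasets codebase: shard_path mirrors the fingerprint-suffixed naming scheme from the diff, while plan_push, the sample fingerprints, and the file names are invented for illustration.

from typing import List, Tuple


def shard_path(split: str, index: int, num_shards: int, fingerprint: str) -> str:
    # Mirrors the fingerprint-suffixed naming scheme introduced by the commit.
    return f"data/{split}-{index:05d}-of-{num_shards:05d}-{fingerprint}.parquet"


def plan_push(split: str, fingerprints: List[str], existing_files: List[str]) -> Tuple[List[str], List[str]]:
    """Return (shard paths to upload, stale paths to delete) for one split."""
    num_shards = len(fingerprints)
    wanted = [shard_path(split, i, fp and num_shards or num_shards, fp) for i, fp in enumerate(fingerprints)]
    wanted = [shard_path(split, i, num_shards, fp) for i, fp in enumerate(fingerprints)]
    # Upload only the shards that are not already in the repository.
    to_upload = [path for path in wanted if path not in existing_files]
    # Delete files of the same split that no longer correspond to any current shard.
    to_delete = [f for f in existing_files if f.startswith(f"data/{split}-") and f not in wanted]
    return to_upload, to_delete


if __name__ == "__main__":
    existing = ["data/train-00000-of-00002-abc123.parquet", "data/train-00001-of-00002-stale.parquet"]
    upload, delete = plan_push("train", ["abc123", "def456"], existing)
    print(upload)  # ['data/train-00001-of-00002-def456.parquet']
    print(delete)  # ['data/train-00001-of-00002-stale.parquet']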

src/datasets/data_files.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ class Url(str):
     pass


-SPLIT_PATTERN_SHARDED = "data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9].*"
+SPLIT_PATTERN_SHARDED = "data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"

DEFAULT_PATTERNS_SPLIT_IN_FILENAME = {
    str(Split.TRAIN): ["**train*"],
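
The trailing * added to the pattern lets the data-files resolver accept the new fingerprint suffix while still matching the old file names. A quick, self-contained check with fnmatch (the file names are invented, and the library's actual pattern resolution is more involved than a bare fnmatch call):

import fnmatch

SPLIT_PATTERN_SHARDED = "data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"

pattern = SPLIT_PATTERN_SHARDED.format(split="train")
print(fnmatch.fnmatch("data/train-00000-of-00002-abc123.parquet", pattern))  # True: fingerprinted shard name
print(fnmatch.fnmatch("data/train-00000-of-00002.parquet", pattern))         # True: legacy shard name still matches
print(fnmatch.fnmatch("data/test-00000-of-00002-abc123.parquet", pattern))   # False: different split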

tests/test_upstream_hub.py

Lines changed: 105 additions & 31 deletions
@@ -1,3 +1,4 @@
+import fnmatch
 import os
 import tempfile
 import time
@@ -77,7 +78,14 @@ def test_push_dataset_dict_to_hub_no_token(self):

             # Ensure that there is a single file on the repository that has the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset"))
-            self.assertListEqual(files, [".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json"])
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files, [".gitattributes", "data/train-00000-of-00001-*.parquet", "dataset_infos.json"]
+                    )
+                )
+            )
         finally:
             self.cleanup_repo(ds_name)
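
Because shard file names now embed a fingerprint that varies between datasets, the tests compare against wildcard patterns rather than exact names. A stand-alone version of that check (the file list here is invented); the extra length assertion is added because zip would otherwise silently truncate when the lists differ in size:

import fnmatch

files = sorted([".gitattributes", "data/train-00000-of-00001-abc123.parquet", "dataset_infos.json"])
expected = [".gitattributes", "data/train-00000-of-00001-*.parquet", "dataset_infos.json"]

assert len(files) == len(expected)  # guard against zip() dropping unmatched entries
assert all(fnmatch.fnmatch(file, pattern) for file, pattern in zip(files, expected))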

@@ -97,7 +105,14 @@ def test_push_dataset_dict_to_hub_name_without_namespace(self):

             # Ensure that there is a single file on the repository that has the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset"))
-            self.assertListEqual(files, [".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json"])
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files, [".gitattributes", "data/train-00000-of-00001-*.parquet", "dataset_infos.json"]
+                    )
+                )
+            )
         finally:
             self.cleanup_repo(ds_name)

@@ -131,7 +146,14 @@ def test_push_dataset_dict_to_hub_private(self):

             # Ensure that there is a single file on the repository that has the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
-            self.assertListEqual(files, [".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json"])
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files, [".gitattributes", "data/train-00000-of-00001-*.parquet", "dataset_infos.json"]
+                    )
+                )
+            )
         finally:
             self.cleanup_repo(ds_name)

@@ -151,7 +173,14 @@ def test_push_dataset_dict_to_hub(self):

             # Ensure that there is a single file on the repository that has the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
-            self.assertListEqual(files, [".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json"])
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files, [".gitattributes", "data/train-00000-of-00001-*.parquet", "dataset_infos.json"]
+                    )
+                )
+            )
         finally:
             self.cleanup_repo(ds_name)

@@ -171,14 +200,19 @@ def test_push_dataset_dict_to_hub_multiple_files(self):

             # Ensure that there are two files on the repository that have the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
-            self.assertListEqual(
-                files,
-                [
-                    ".gitattributes",
-                    "data/train-00000-of-00002.parquet",
-                    "data/train-00001-of-00002.parquet",
-                    "dataset_infos.json",
-                ],
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files,
+                        [
+                            ".gitattributes",
+                            "data/train-00000-of-00002-*.parquet",
+                            "data/train-00001-of-00002-*.parquet",
+                            "dataset_infos.json",
+                        ],
+                    )
+                )
             )
         finally:
             self.cleanup_repo(ds_name)
@@ -214,16 +248,22 @@ def test_push_dataset_dict_to_hub_overwrite_files(self):

             # Ensure that there are two files on the repository that have the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
-            self.assertListEqual(
-                files,
-                [
-                    ".gitattributes",
-                    "data/random-00000-of-00001.parquet",
-                    "data/train-00000-of-00002.parquet",
-                    "data/train-00001-of-00002.parquet",
-                    "datafile.txt",
-                    "dataset_infos.json",
-                ],
+
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files,
+                        [
+                            ".gitattributes",
+                            "data/random-00000-of-00001-*.parquet",
+                            "data/train-00000-of-00002-*.parquet",
+                            "data/train-00001-of-00002-*.parquet",
+                            "datafile.txt",
+                            "dataset_infos.json",
+                        ],
+                    )
+                )
             )

             self._api.delete_file("datafile.txt", repo_id=ds_name, repo_type="dataset", token=self._token)
@@ -260,15 +300,21 @@ def test_push_dataset_dict_to_hub_overwrite_files(self):

             # Ensure that there are two files on the repository that have the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
-            self.assertListEqual(
-                files,
-                [
-                    ".gitattributes",
-                    "data/random-00000-of-00001.parquet",
-                    "data/train-00000-of-00001.parquet",
-                    "datafile.txt",
-                    "dataset_infos.json",
-                ],
+
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files,
+                        [
+                            ".gitattributes",
+                            "data/random-00000-of-00001-*.parquet",
+                            "data/train-00000-of-00001-*.parquet",
+                            "datafile.txt",
+                            "dataset_infos.json",
+                        ],
+                    )
+                )
             )

             # Keeping the "datafile.txt" breaks the load_dataset to think it's a text-based dataset
@@ -403,6 +449,34 @@ def test_push_dataset_to_hub_custom_splits(self):
         finally:
             self.cleanup_repo(ds_name)

+    def test_push_to_dataset_skip_identical_files(self):
+        ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))})
+        ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
+        try:
+            with patch("datasets.arrow_dataset.HfApi.upload_file", side_effect=self._api.upload_file) as mock_hf_api:
+                # Initial push
+                ds.push_to_hub(ds_name, token=self._token, max_shard_size="1KB")
+                call_count_old = mock_hf_api.call_count
+                mock_hf_api.reset_mock()
+
+                # Remove a data file
+                files = self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)
+                data_files = [f for f in files if f.startswith("data/")]
+                self.assertGreater(len(data_files), 1)
+                self._api.delete_file(data_files[0], repo_id=ds_name, repo_type="dataset", token=self._token)
+
+                # "Resume" push - push missing files
+                ds.push_to_hub(ds_name, token=self._token, max_shard_size="1KB")
+                call_count_new = mock_hf_api.call_count
+                self.assertGreater(call_count_old, call_count_new)
+
+            hub_ds = load_dataset(ds_name, split="train", download_mode="force_redownload")
+            self.assertListEqual(ds.column_names, hub_ds.column_names)
+            self.assertListEqual(list(ds.features.keys()), list(hub_ds.features.keys()))
+            self.assertDictEqual(ds.features, hub_ds.features)
+        finally:
+            self.cleanup_repo(ds_name)
+
     def test_push_dataset_dict_to_hub_custom_splits(self):
         ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
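
The new test patches HfApi.upload_file and compares call_count across the two pushes. Below is a minimal sketch of that call-counting pattern in isolation, with an invented push helper and a MagicMock standing in for the real Hub client; it is an illustration, not the library's test code.

from unittest.mock import MagicMock

upload_file = MagicMock(name="upload_file")  # stand-in for HfApi.upload_file


def push(shard_paths, existing_files):
    # Upload only the shards that are not already in the repository, as in the patched push_to_hub.
    for path in shard_paths:
        if path not in existing_files:
            upload_file(path_in_repo=path)


push(["a.parquet", "b.parquet"], existing_files=[])
first_push_calls = upload_file.call_count  # 2: nothing exists yet, both shards are uploaded
upload_file.reset_mock()

push(["a.parquet", "b.parquet"], existing_files=["a.parquet"])
assert upload_file.call_count < first_push_calls  # only the missing shard triggered an upload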
