@@ -1,3 +1,4 @@
+import fnmatch
 import os
 import tempfile
 import time
@@ -77,7 +78,14 @@ def test_push_dataset_dict_to_hub_no_token(self):
 
             # Ensure that there is a single file on the repository that has the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset"))
-            self.assertListEqual(files, [".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json"])
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files, [".gitattributes", "data/train-00000-of-00001-*.parquet", "dataset_infos.json"]
+                    )
+                )
+            )
         finally:
             self.cleanup_repo(ds_name)
 
@@ -97,7 +105,14 @@ def test_push_dataset_dict_to_hub_name_without_namespace(self):
 
             # Ensure that there is a single file on the repository that has the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset"))
-            self.assertListEqual(files, [".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json"])
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files, [".gitattributes", "data/train-00000-of-00001-*.parquet", "dataset_infos.json"]
+                    )
+                )
+            )
         finally:
             self.cleanup_repo(ds_name)
 
@@ -131,7 +146,14 @@ def test_push_dataset_dict_to_hub_private(self):
 
             # Ensure that there is a single file on the repository that has the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
-            self.assertListEqual(files, [".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json"])
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files, [".gitattributes", "data/train-00000-of-00001-*.parquet", "dataset_infos.json"]
+                    )
+                )
+            )
         finally:
             self.cleanup_repo(ds_name)
 
@@ -151,7 +173,14 @@ def test_push_dataset_dict_to_hub(self):
 
             # Ensure that there is a single file on the repository that has the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
-            self.assertListEqual(files, [".gitattributes", "data/train-00000-of-00001.parquet", "dataset_infos.json"])
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files, [".gitattributes", "data/train-00000-of-00001-*.parquet", "dataset_infos.json"]
+                    )
+                )
+            )
         finally:
             self.cleanup_repo(ds_name)
 
@@ -171,14 +200,19 @@ def test_push_dataset_dict_to_hub_multiple_files(self):
 
             # Ensure that there are two files on the repository that have the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
-            self.assertListEqual(
-                files,
-                [
-                    ".gitattributes",
-                    "data/train-00000-of-00002.parquet",
-                    "data/train-00001-of-00002.parquet",
-                    "dataset_infos.json",
-                ],
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files,
+                        [
+                            ".gitattributes",
+                            "data/train-00000-of-00002-*.parquet",
+                            "data/train-00001-of-00002-*.parquet",
+                            "dataset_infos.json",
+                        ],
+                    )
+                )
             )
         finally:
             self.cleanup_repo(ds_name)
@@ -214,16 +248,22 @@ def test_push_dataset_dict_to_hub_overwrite_files(self):
 
             # Ensure that there are two files on the repository that have the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
-            self.assertListEqual(
-                files,
-                [
-                    ".gitattributes",
-                    "data/random-00000-of-00001.parquet",
-                    "data/train-00000-of-00002.parquet",
-                    "data/train-00001-of-00002.parquet",
-                    "datafile.txt",
-                    "dataset_infos.json",
-                ],
+
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files,
+                        [
+                            ".gitattributes",
+                            "data/random-00000-of-00001-*.parquet",
+                            "data/train-00000-of-00002-*.parquet",
+                            "data/train-00001-of-00002-*.parquet",
+                            "datafile.txt",
+                            "dataset_infos.json",
+                        ],
+                    )
+                )
             )
 
             self._api.delete_file("datafile.txt", repo_id=ds_name, repo_type="dataset", token=self._token)
@@ -260,15 +300,21 @@ def test_push_dataset_dict_to_hub_overwrite_files(self):
 
             # Ensure that there are two files on the repository that have the correct name
             files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token))
-            self.assertListEqual(
-                files,
-                [
-                    ".gitattributes",
-                    "data/random-00000-of-00001.parquet",
-                    "data/train-00000-of-00001.parquet",
-                    "datafile.txt",
-                    "dataset_infos.json",
-                ],
+
+            self.assertTrue(
+                all(
+                    fnmatch.fnmatch(file, expected_file)
+                    for file, expected_file in zip(
+                        files,
+                        [
+                            ".gitattributes",
+                            "data/random-00000-of-00001-*.parquet",
+                            "data/train-00000-of-00001-*.parquet",
+                            "datafile.txt",
+                            "dataset_infos.json",
+                        ],
+                    )
+                )
             )
 
             # Keeping the "datafile.txt" breaks the load_dataset to think it's a text-based dataset
@@ -403,6 +449,34 @@ def test_push_dataset_to_hub_custom_splits(self):
         finally:
             self.cleanup_repo(ds_name)
 
+    def test_push_to_dataset_skip_identical_files(self):
+        ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))})
+        ds_name = f"{USER}/test-{int(time.time() * 10e3)}"
+        try:
+            with patch("datasets.arrow_dataset.HfApi.upload_file", side_effect=self._api.upload_file) as mock_hf_api:
+                # Initial push
+                ds.push_to_hub(ds_name, token=self._token, max_shard_size="1KB")
+                call_count_old = mock_hf_api.call_count
+                mock_hf_api.reset_mock()
+
+                # Remove a data file
+                files = self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)
+                data_files = [f for f in files if f.startswith("data/")]
+                self.assertGreater(len(data_files), 1)
+                self._api.delete_file(data_files[0], repo_id=ds_name, repo_type="dataset", token=self._token)
+
+                # "Resume" push - push missing files
+                ds.push_to_hub(ds_name, token=self._token, max_shard_size="1KB")
+                call_count_new = mock_hf_api.call_count
+                self.assertGreater(call_count_old, call_count_new)
+
+                hub_ds = load_dataset(ds_name, split="train", download_mode="force_redownload")
+                self.assertListEqual(ds.column_names, hub_ds.column_names)
+                self.assertListEqual(list(ds.features.keys()), list(hub_ds.features.keys()))
+                self.assertDictEqual(ds.features, hub_ds.features)
+        finally:
+            self.cleanup_repo(ds_name)
+
     def test_push_dataset_dict_to_hub_custom_splits(self):
         ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
 
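For context, the assertions in this commit move from exact `assertListEqual` checks to `fnmatch` globs because pushed parquet shard names now end in a generated suffix (matched by `-*` in the patterns), presumably the fingerprint that lets the new `test_push_to_dataset_skip_identical_files` re-push skip shards already on the Hub. A minimal standalone sketch of the comparison idiom, with a made-up file listing standing in for the `HfApi.list_repo_files` output (the hex suffix is hypothetical):

import fnmatch

# Hypothetical repo listing; "abc123de" stands in for whatever suffix
# the push generates (the tests themselves only assume "-*").
files = sorted(
    [".gitattributes", "data/train-00000-of-00001-abc123de.parquet", "dataset_infos.json"]
)
patterns = [".gitattributes", "data/train-00000-of-00001-*.parquet", "dataset_infos.json"]

# Same check as in the tests: each listed file must match its glob.
# zip() silently drops unpaired items, so checking the lengths first
# keeps this as strict as the old assertListEqual comparison was.
assert len(files) == len(patterns)
assert all(fnmatch.fnmatch(f, p) for f, p in zip(files, patterns))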