Skip to content

Commit 9dcb717

Browse files
Do not filter out .zip extensions from no-script datasets (#6208)
* Rename zip_csv_path fixture dirname and filename
* Test load no-script dataset with ZIP file
* Fix style
* Avoid filtering out .zip extension
1 parent 3794f77 commit 9dcb717

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

src/datasets/packaged_modules/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -60,5 +60,5 @@ def _hash_python_lines(lines: List[str]) -> str:
6060
for _ext, (_module, _) in _EXTENSION_TO_MODULE.items():
6161
_MODULE_TO_EXTENSIONS.setdefault(_module, []).append(_ext)
6262

63-
_MODULE_TO_EXTENSIONS["imagefolder"].append(".zip")
64-
_MODULE_TO_EXTENSIONS["audiofolder"].append(".zip")
63+
for _module in _MODULE_TO_EXTENSIONS:
64+
_MODULE_TO_EXTENSIONS[_module].append(".zip")

tests/fixtures/files.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -289,7 +289,7 @@ def bz2_csv_path(csv_path, tmp_path_factory):
289289

290290
@pytest.fixture(scope="session")
291291
def zip_csv_path(csv_path, csv2_path, tmp_path_factory):
292-
path = tmp_path_factory.mktemp("data") / "dataset.csv.zip"
292+
path = tmp_path_factory.mktemp("zip_csv_path") / "csv-dataset.zip"
293293
with zipfile.ZipFile(path, "w") as f:
294294
f.write(csv_path, arcname=os.path.basename(csv_path))
295295
f.write(csv2_path, arcname=os.path.basename(csv2_path))

tests/test_load.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1458,3 +1458,12 @@ def test_load_dataset_with_storage_options_with_decoding(mockfs, image_file):
14581458
ds = load_dataset("imagefolder", data_files=data_files, storage_options=mockfs.storage_options)
14591459
assert len(ds["train"]) == 1
14601460
assert isinstance(ds["train"][0]["image"], PIL.Image.Image)
1461+
1462+
1463+
def test_load_dataset_without_script_with_zip(zip_csv_path):
1464+
path = str(zip_csv_path.parent)
1465+
ds = load_dataset(path)
1466+
assert list(ds.keys()) == ["train"]
1467+
assert ds["train"].column_names == ["col_1", "col_2", "col_3"]
1468+
assert ds["train"].num_rows == 8
1469+
assert ds["train"][0] == {"col_1": 0, "col_2": 0, "col_3": 0.0}

0 commit comments

Comments
 (0)