Skip to content

Commit c47cc14

Browse files
authored
Fix imagefolder with one image (#6556)
* fix imagefolder with one image * better typing * fix flaky test
1 parent d26abad commit c47cc14

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

src/datasets/load.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ def _load_importable_file(
491491

492492
def infer_module_for_data_files_list(
493493
data_files_list: DataFilesList, download_config: Optional[DownloadConfig] = None
494-
) -> Optional[Tuple[str, str]]:
494+
) -> Tuple[Optional[str], dict]:
495495
"""Infer module (and builder kwargs) from list of data files.
496496
497497
It picks the module based on the most common file extension.
@@ -507,18 +507,18 @@ def infer_module_for_data_files_list(
507507
- dict of builder kwargs
508508
"""
509509
extensions_counter = Counter(
510-
"." + suffix.lower()
510+
("." + suffix.lower(), xbasename(filepath) in ("metadata.jsonl", "metadata.csv"))
511511
for filepath in data_files_list[: config.DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE]
512512
for suffix in xbasename(filepath).split(".")[1:]
513513
)
514514
if extensions_counter:
515515

516-
def sort_key(ext_count: Tuple[str, int]) -> Tuple[int, bool]:
517-
"""Sort by count and set ".parquet" as the favorite in case of a draw"""
518-
ext, count = ext_count
519-
return (count, ext == ".parquet", ext)
516+
def sort_key(ext_count: Tuple[Tuple[str, bool], int]) -> Tuple[int, bool]:
517+
"""Sort by count and set ".parquet" as the favorite in case of a draw, and ignore metadata files"""
518+
(ext, is_metadata), count = ext_count
519+
return (not is_metadata, count, ext == ".parquet", ext)
520520

521-
for ext, _ in sorted(extensions_counter.items(), key=sort_key, reverse=True):
521+
for (ext, _), _ in sorted(extensions_counter.items(), key=sort_key, reverse=True):
522522
if ext in _EXTENSION_TO_MODULE:
523523
return _EXTENSION_TO_MODULE[ext]
524524
elif ext == ".zip":
@@ -528,7 +528,7 @@ def sort_key(ext_count: Tuple[str, int]) -> Tuple[int, bool]:
528528

529529
def infer_module_for_data_files_list_in_archives(
530530
data_files_list: DataFilesList, download_config: Optional[DownloadConfig] = None
531-
) -> Optional[Tuple[str, str]]:
531+
) -> Tuple[Optional[str], dict]:
532532
"""Infer module (and builder kwargs) from list of archive data files.
533533
534534
Args:

src/datasets/packaged_modules/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22
import re
3-
from typing import Dict, List
3+
from typing import Dict, List, Tuple
44

55
from huggingface_hub.utils import insecure_hashlib
66

@@ -44,7 +44,7 @@ def _hash_python_lines(lines: List[str]) -> str:
4444
}
4545

4646
# Used to infer the module to use based on the data files extensions
47-
_EXTENSION_TO_MODULE = {
47+
_EXTENSION_TO_MODULE: Dict[str, Tuple[str, dict]] = {
4848
".csv": ("csv", {}),
4949
".tsv": ("csv", {"sep": "\t"}),
5050
".json": ("json", {}),

tests/test_arrow_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3995,8 +3995,8 @@ def test_build_local_temp_path(uri_or_path):
39953995
path_relative_to_tmp_dir = Path(local_temp_path).relative_to(Path(tempfile.gettempdir())).as_posix()
39963996

39973997
assert (
3998-
"hdfs" not in path_relative_to_tmp_dir
3999-
and "s3" not in path_relative_to_tmp_dir
3998+
"hdfs://" not in path_relative_to_tmp_dir
3999+
and "s3://" not in path_relative_to_tmp_dir
40004000
and not local_temp_path.startswith(extracted_path_without_anchor)
40014001
and local_temp_path.endswith(extracted_path_without_anchor)
40024002
), f"Local temp path: {local_temp_path}"

0 commit comments

Comments
 (0)