Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def _load_importable_file(

def infer_module_for_data_files_list(
data_files_list: DataFilesList, download_config: Optional[DownloadConfig] = None
) -> Optional[Tuple[str, str]]:
) -> Tuple[Optional[str], dict]:
"""Infer module (and builder kwargs) from list of data files.

It picks the module based on the most common file extension.
Expand All @@ -507,18 +507,18 @@ def infer_module_for_data_files_list(
- dict of builder kwargs
"""
extensions_counter = Counter(
"." + suffix.lower()
("." + suffix.lower(), xbasename(filepath) in ("metadata.jsonl", "metadata.csv"))
for filepath in data_files_list[: config.DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE]
for suffix in xbasename(filepath).split(".")[1:]
)
if extensions_counter:

def sort_key(ext_count: Tuple[str, int]) -> Tuple[int, bool]:
"""Sort by count and set ".parquet" as the favorite in case of a draw"""
ext, count = ext_count
return (count, ext == ".parquet", ext)
def sort_key(ext_count: Tuple[Tuple[str, bool], int]) -> Tuple[int, bool]:
"""Sort by count and set ".parquet" as the favorite in case of a draw, and ignore metadata files"""
(ext, is_metadata), count = ext_count
return (not is_metadata, count, ext == ".parquet", ext)

for ext, _ in sorted(extensions_counter.items(), key=sort_key, reverse=True):
for (ext, _), _ in sorted(extensions_counter.items(), key=sort_key, reverse=True):
if ext in _EXTENSION_TO_MODULE:
return _EXTENSION_TO_MODULE[ext]
elif ext == ".zip":
Expand All @@ -528,7 +528,7 @@ def sort_key(ext_count: Tuple[str, int]) -> Tuple[int, bool]:

def infer_module_for_data_files_list_in_archives(
data_files_list: DataFilesList, download_config: Optional[DownloadConfig] = None
) -> Optional[Tuple[str, str]]:
) -> Tuple[Optional[str], dict]:
"""Infer module (and builder kwargs) from list of archive data files.

Args:
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/packaged_modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
import re
from typing import Dict, List
from typing import Dict, List, Tuple

from huggingface_hub.utils import insecure_hashlib

Expand Down Expand Up @@ -44,7 +44,7 @@ def _hash_python_lines(lines: List[str]) -> str:
}

# Used to infer the module to use based on the data files extensions
_EXTENSION_TO_MODULE = {
_EXTENSION_TO_MODULE: Dict[str, Tuple[str, dict]] = {
".csv": ("csv", {}),
".tsv": ("csv", {"sep": "\t"}),
".json": ("json", {}),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3995,8 +3995,8 @@ def test_build_local_temp_path(uri_or_path):
path_relative_to_tmp_dir = Path(local_temp_path).relative_to(Path(tempfile.gettempdir())).as_posix()

assert (
"hdfs" not in path_relative_to_tmp_dir
and "s3" not in path_relative_to_tmp_dir
"hdfs://" not in path_relative_to_tmp_dir
and "s3://" not in path_relative_to_tmp_dir
and not local_temp_path.startswith(extracted_path_without_anchor)
and local_temp_path.endswith(extracted_path_without_anchor)
), f"Local temp path: {local_temp_path}"
Expand Down