Skip to content

Commit d963a0e

Browse files
committed
only match separated split names
1 parent de2f6ef commit d963a0e

2 files changed

Lines changed: 101 additions & 22 deletions

File tree

src/datasets/data_files.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,29 @@ class Url(str):
2727
SPLIT_PATTERN_SHARDED = "data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"
2828

2929
DEFAULT_PATTERNS_SPLIT_IN_FILENAME = {
30-
str(Split.TRAIN): ["**train*"],
31-
str(Split.TEST): ["**test*", "**eval*"],
32-
str(Split.VALIDATION): ["**dev*", "**valid*"],
30+
str(Split.TRAIN): ["**[-._/]train[-._]*", "train[-._]*", "**[-._/]training[-._]*", "training[-._]*"],
31+
str(Split.TEST): ["**[-._/]test[-._]*", "test[-._]*", "**[-._/]eval[-._]*", "eval[-._]*"],
32+
str(Split.VALIDATION): [
33+
"**[-._/]dev[-._]*",
34+
"dev[-._]*",
35+
"**[-._/]valid[-._]*",
36+
"valid[-._]*",
37+
"**[-._/]validation[-._]*",
38+
"validation[-._]*",
39+
],
3340
}
3441

3542
DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME = {
36-
str(Split.TRAIN): ["**train*/**"],
37-
str(Split.TEST): ["**test*/**", "**eval*/**"],
38-
str(Split.VALIDATION): ["**dev*/**", "**valid*/**"],
43+
str(Split.TRAIN): ["train[-._/]**", "**[-._/]train[-._/]**", "training[-._/]**", "**[-._/]training[-._/]**"],
44+
str(Split.TEST): ["test[-._/]**", "**[-._/]test[-._/]**", "eval[-._/]**", "**[-._/]eval[-._/]**"],
45+
str(Split.VALIDATION): [
46+
"dev[-._/]**",
47+
"**[-._/]dev[-._/]**",
48+
"valid[-._/]**",
49+
"**[-._/]valid[-._/]**",
50+
"validation[-._/]**",
51+
"**[-._/]validation[-._/]**",
52+
],
3953
}
4054

4155
DEFAULT_PATTERNS_ALL = {

tests/test_data_files.py

Lines changed: 81 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import os
2-
from itertools import chain
2+
from contextlib import contextmanager
33
from pathlib import Path, PurePath
4+
from typing import List
45
from unittest.mock import patch
56

67
import fsspec
78
import pytest
9+
from fsspec.spec import AbstractBufferedFile, AbstractFileSystem
810
from huggingface_hub.hf_api import DatasetInfo
911

1012
from datasets.data_files import (
@@ -491,6 +493,69 @@ def test_DataFilesDict_from_hf_local_or_remote_hashing(text_file):
491493
assert Hasher.hash(data_files1) != Hasher.hash(data_files2)
492494

493495

496+
@contextmanager
497+
def mock_fs(file_paths: List[str]):
498+
"""context manager to set up a mock:// filesystem in sfspec containing the provided files"""
499+
500+
dir_paths = {file_path.rsplit("/")[0] for file_path in file_paths if "/" in file_path}
501+
fs_contents = [{"name": dir_path, "type": "directory"} for dir_path in dir_paths] + [
502+
{"name": file_path, "type": "file", "size": 10} for file_path in file_paths
503+
]
504+
505+
class DummyTestFS(AbstractFileSystem):
506+
protocol = "mock"
507+
_file_class = AbstractBufferedFile
508+
_fs_contents = fs_contents
509+
510+
def __getitem__(self, name):
511+
for item in self._fs_contents:
512+
if item["name"] == name:
513+
return item
514+
raise IndexError(f"{name} not found!")
515+
516+
def ls(self, path, detail=True, refresh=True, **kwargs):
517+
if kwargs.pop("strip_proto", True):
518+
path = self._strip_protocol(path)
519+
520+
files = not refresh and self._ls_from_cache(path)
521+
if not files:
522+
files = [file for file in self._fs_contents if path == self._parent(file["name"])]
523+
files.sort(key=lambda file: file["name"])
524+
self.dircache[path.rstrip("/")] = files
525+
526+
if detail:
527+
return files
528+
return [file["name"] for file in files]
529+
530+
@classmethod
531+
def get_test_paths(cls, start_with=""):
532+
"""Helper to return directory and file paths with no details"""
533+
all = [file["name"] for file in cls._fs_contents if file["name"].startswith(start_with)]
534+
return all
535+
536+
def _open(
537+
self,
538+
path,
539+
mode="rb",
540+
block_size=None,
541+
autocommit=True,
542+
cache_options=None,
543+
**kwargs,
544+
):
545+
return self._file_class(
546+
self,
547+
path,
548+
mode,
549+
block_size,
550+
autocommit,
551+
cache_options=cache_options,
552+
**kwargs,
553+
)
554+
555+
with patch.dict(fsspec.registry.target, {"mock": DummyTestFS}):
556+
yield DummyTestFS()
557+
558+
494559
@pytest.mark.parametrize(
495560
"data_file_per_split",
496561
[
@@ -541,25 +606,25 @@ def test_DataFilesDict_from_hf_local_or_remote_hashing(text_file):
541606
{"validation": "dev/dataset.txt"},
542607
# With other extensions
543608
{"train": "train.parquet", "test": "test.parquet", "validation": "valid.parquet"},
609+
# With "dev" or "eval" without separators
610+
{"train": "developers_list.txt"},
611+
{"train": "data/seqeval_results.txt"},
544612
],
545613
)
546614
def test_get_data_files_patterns(data_file_per_split):
547615
data_file_per_split = {k: v if isinstance(v, list) else [v] for k, v in data_file_per_split.items()}
548-
549-
def resolver(pattern):
550-
return [PurePath(path) for path in chain(*data_file_per_split.values()) if PurePath(path).match(pattern)]
551-
552-
patterns_per_split = _get_data_files_patterns(resolver)
553-
assert sorted(patterns_per_split.keys()) == sorted(data_file_per_split.keys())
554-
for split, patterns in patterns_per_split.items():
555-
matched = [
556-
path
557-
for path in chain(*data_file_per_split.values())
558-
for pattern in patterns
559-
if PurePath(path).match(pattern)
560-
]
561-
assert len(matched) == len(data_file_per_split[split])
562-
assert matched == data_file_per_split[split]
616+
with mock_fs(
617+
[file_path for split_file_paths in data_file_per_split.values() for file_path in split_file_paths]
618+
) as fs:
619+
620+
def resolver(pattern):
621+
return [PurePath(file_path) for file_path in fs.glob(pattern) if fs.isfile(file_path)]
622+
623+
patterns_per_split = _get_data_files_patterns(resolver)
624+
assert sorted(patterns_per_split.keys()) == sorted(data_file_per_split.keys())
625+
for split, patterns in patterns_per_split.items():
626+
matched = [str(file_path) for pattern in patterns for file_path in resolver(pattern)]
627+
assert matched == data_file_per_split[split]
563628

564629

565630
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)