Commit cd674a3

[data_files] Only match separated split names (#4633)

* only match separated split names
* docs
* add space separator
* fix win
* add testing
* add evaluation
* suggestions in doc
* use list comprehension + support numbers
* update tests
* remove unnecessary patching and context manager
* style

1 parent 4fb3ed0 commit cd674a3

File tree

3 files changed: +129 −20 lines changed

docs/source/repository_structure.mdx

Lines changed: 35 additions & 5 deletions
````diff
@@ -22,10 +22,13 @@ my_dataset_repository/
 
 ## Splits and file names
 
-🤗 Datasets automatically infer a dataset's train, validation, and test splits from the file names. Files that contain *train* in their names are considered part of the train split. The same idea applies to the test and validation split:
+🤗 Datasets automatically infer a dataset's train, validation, and test splits from the file names.
 
-- All the files that contain *test* in their names are considered part of the test split.
-- All the files that contain *valid* in their names are considered part of the validation split.
+All the files that contain a split name in their names (delimited by non-word characters, see below) are considered part of that split:
+
+- train split: `train.csv`, `my_train_file.csv`, `train1.csv`
+- validation split: `validation.csv`, `my_validation_file.csv`, `validation1.csv`
+- test split: `test.csv`, `my_test_file.csv`, `test1.csv`
 
 Here is an example where all the files are placed into a directory named `data`:
 
@@ -35,9 +38,13 @@ my_dataset_repository/
 └── data/
     ├── train.csv
     ├── test.csv
-    └── valid.csv
+    └── validation.csv
 ```
 
+Note that if a file contains *test* but is embedded in another word (e.g. `testfile.csv`), it's not counted as a test file.
+It must be delimited by non-word characters, e.g. `test_file.csv`.
+Supported delimiters are underscores, dashes, spaces, dots and numbers.
+
 ## Multiple files per split
 
 If one of your splits comprises several files, 🤗 Datasets can still infer whether it is the train, validation, and test split from the file name.
@@ -58,7 +65,8 @@ Make sure all the files of your `train` set have *train* in their names (same fo
 Even if you add a prefix or suffix to `train` in the file name (like `my_train_file_00001.csv` for example),
 🤗 Datasets can still infer the appropriate split.
 
-For convenience, you can also place your data files into different directories. In this case, the split name is inferred from the directory name.
+For convenience, you can also place your data files into different directories.
+In this case, the split name is inferred from the directory name.
 
 ```
 my_dataset_repository/
@@ -80,6 +88,28 @@ Eventually, you'll also be able to structure your repository to specify differen
 
 </Tip>
 
+## Split names keywords
+
+Validation splits are sometimes called "dev", and test splits are called "eval".
+These other names are also supported.
+In particular, these keywords are equivalent:
+
+- train, training
+- validation, valid, dev
+- test, testing, eval, evaluation
+
+Therefore this is also a valid repository:
+
+```
+my_dataset_repository/
+├── README.md
+└── data/
+    ├── training.csv
+    ├── eval.csv
+    └── valid.csv
+```
+
 ## Custom split names
 
 If you have other data files in addition to the traditional train, validation, and test sets, you must use a different structure.
````
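The delimiter rule documented above can be illustrated with Python's `fnmatch`, whose `[...]` character classes behave like the glob patterns this commit generates. This is a rough sketch, not the library's actual resolver (which globs over a filesystem), using the constants introduced in `src/datasets/data_files.py` below:

```python
from fnmatch import fnmatchcase

# Constants from the data_files.py diff in this commit
NON_WORDS_CHARS = "-._ 0-9"
KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"]

def matches_split_keyword(filename: str, keyword: str) -> bool:
    """Return True if `filename` contains `keyword` delimited by non-word characters."""
    patterns = [p.format(keyword=keyword, sep=NON_WORDS_CHARS) for p in KEYWORDS_IN_FILENAME_BASE_PATTERNS]
    return any(fnmatchcase(filename, pattern) for pattern in patterns)

print(matches_split_keyword("my_test_file.csv", "test"))  # True: "_" is a supported delimiter
print(matches_split_keyword("test1.csv", "test"))         # True: digits count as delimiters
print(matches_split_keyword("testfile.csv", "test"))      # False: "test" runs into the next word
```

The helper name `matches_split_keyword` is hypothetical; the point is that the first base pattern requires a delimiter (or `/`) on both sides of the keyword, while the second only applies when the filename starts with the keyword.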

src/datasets/data_files.py

Lines changed: 37 additions & 6 deletions
````diff
@@ -26,16 +26,47 @@ class Url(str):
 
 SPLIT_PATTERN_SHARDED = "data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"
 
+TRAIN_KEYWORDS = ["train", "training"]
+TEST_KEYWORDS = ["test", "testing", "eval", "evaluation"]
+VALIDATION_KEYWORDS = ["validation", "valid", "dev"]
+NON_WORDS_CHARS = "-._ 0-9"
+KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"]
+KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = ["{keyword}[{sep}/]**", "**[{sep}/]{keyword}[{sep}/]**"]
+
 DEFAULT_PATTERNS_SPLIT_IN_FILENAME = {
-    str(Split.TRAIN): ["**train*"],
-    str(Split.TEST): ["**test*", "**eval*"],
-    str(Split.VALIDATION): ["**dev*", "**valid*"],
+    str(Split.TRAIN): [
+        pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
+        for keyword in TRAIN_KEYWORDS
+        for pattern in KEYWORDS_IN_FILENAME_BASE_PATTERNS
+    ],
+    str(Split.TEST): [
+        pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
+        for keyword in TEST_KEYWORDS
+        for pattern in KEYWORDS_IN_FILENAME_BASE_PATTERNS
+    ],
+    str(Split.VALIDATION): [
+        pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
+        for keyword in VALIDATION_KEYWORDS
+        for pattern in KEYWORDS_IN_FILENAME_BASE_PATTERNS
+    ],
 }
 
 DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME = {
-    str(Split.TRAIN): ["**train*/**"],
-    str(Split.TEST): ["**test*/**", "**eval*/**"],
-    str(Split.VALIDATION): ["**dev*/**", "**valid*/**"],
+    str(Split.TRAIN): [
+        pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
+        for keyword in TRAIN_KEYWORDS
+        for pattern in KEYWORDS_IN_DIR_NAME_BASE_PATTERNS
+    ],
+    str(Split.TEST): [
+        pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
+        for keyword in TEST_KEYWORDS
+        for pattern in KEYWORDS_IN_DIR_NAME_BASE_PATTERNS
+    ],
+    str(Split.VALIDATION): [
+        pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
+        for keyword in VALIDATION_KEYWORDS
+        for pattern in KEYWORDS_IN_DIR_NAME_BASE_PATTERNS
+    ],
 }
 
 DEFAULT_PATTERNS_ALL = {
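To see what the new list comprehensions actually produce, the expansion for the train split can be reproduced standalone (constants copied from the diff above; the substituted `{sep}` is the literal character class `-._ 0-9`):

```python
# Expand the filename base patterns for the train keywords, as in the diff above
TRAIN_KEYWORDS = ["train", "training"]
NON_WORDS_CHARS = "-._ 0-9"
KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"]

train_patterns = [
    pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
    for keyword in TRAIN_KEYWORDS
    for pattern in KEYWORDS_IN_FILENAME_BASE_PATTERNS
]
print(train_patterns)
# ['**[-._ 0-9/]train[-._ 0-9]*', 'train[-._ 0-9]*',
#  '**[-._ 0-9/]training[-._ 0-9]*', 'training[-._ 0-9]*']
```

Each keyword yields two glob patterns: one matching the keyword anywhere after a delimiter or `/`, and one matching a filename that starts with the keyword, both requiring a delimiter right after it.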

tests/test_data_files.py

Lines changed: 57 additions & 9 deletions
````diff
@@ -1,10 +1,11 @@
 import os
-from itertools import chain
 from pathlib import Path, PurePath
+from typing import List
 from unittest.mock import patch
 
 import fsspec
 import pytest
+from fsspec.spec import AbstractFileSystem
 from huggingface_hub.hf_api import DatasetInfo
 
 from datasets.data_files import (
@@ -491,6 +492,47 @@ def test_DataFilesDict_from_hf_local_or_remote_hashing(text_file):
     assert Hasher.hash(data_files1) != Hasher.hash(data_files2)
 
 
+def mock_fs(file_paths: List[str]):
+    """
+    Set up a mock filesystem for fsspec containing the provided files
+
+    Example:
+
+    ```py
+    >>> fs = mock_fs(["data/train.txt", "data.test.txt"])
+    >>> assert fsspec.get_filesystem_class("mock").__name__ == "DummyTestFS"
+    >>> assert type(fs).__name__ == "DummyTestFS"
+    >>> print(fs.glob("**"))
+    ["data", "data/train.txt", "data.test.txt"]
+    ```
+    """
+
+    dir_paths = {file_path.rsplit("/")[0] for file_path in file_paths if "/" in file_path}
+    fs_contents = [{"name": dir_path, "type": "directory"} for dir_path in dir_paths] + [
+        {"name": file_path, "type": "file", "size": 10} for file_path in file_paths
+    ]
+
+    class DummyTestFS(AbstractFileSystem):
+        protocol = "mock"
+        _fs_contents = fs_contents
+
+        def ls(self, path, detail=True, refresh=True, **kwargs):
+            if kwargs.pop("strip_proto", True):
+                path = self._strip_protocol(path)
+
+            files = not refresh and self._ls_from_cache(path)
+            if not files:
+                files = [file for file in self._fs_contents if path == self._parent(file["name"])]
+                files.sort(key=lambda file: file["name"])
+                self.dircache[path.rstrip("/")] = files
+
+            if detail:
+                return files
+            return [file["name"] for file in files]
+
+    return DummyTestFS()
+
+
 @pytest.mark.parametrize(
     "data_file_per_split",
     [
@@ -541,24 +583,30 @@ def test_DataFilesDict_from_hf_local_or_remote_hashing(text_file):
         {"validation": "dev/dataset.txt"},
         # With other extensions
         {"train": "train.parquet", "test": "test.parquet", "validation": "valid.parquet"},
+        # With "dev" or "eval" without separators
+        {"train": "developers_list.txt"},
+        {"train": "data/seqeval_results.txt"},
+        {"train": "contest.txt"},
+        # With supported separators
+        {"test": "my.test.file.txt"},
+        {"test": "my-test-file.txt"},
+        {"test": "my_test_file.txt"},
+        {"test": "my test file.txt"},
+        {"test": "test00001.txt"},
     ],
 )
 def test_get_data_files_patterns(data_file_per_split):
     data_file_per_split = {k: v if isinstance(v, list) else [v] for k, v in data_file_per_split.items()}
+    file_paths = [file_path for split_file_paths in data_file_per_split.values() for file_path in split_file_paths]
+    fs = mock_fs(file_paths)
 
     def resolver(pattern):
-        return [PurePath(path) for path in chain(*data_file_per_split.values()) if PurePath(path).match(pattern)]
+        return [PurePath(file_path) for file_path in fs.glob(pattern) if fs.isfile(file_path)]
 
     patterns_per_split = _get_data_files_patterns(resolver)
     assert sorted(patterns_per_split.keys()) == sorted(data_file_per_split.keys())
     for split, patterns in patterns_per_split.items():
-        matched = [
-            path
-            for path in chain(*data_file_per_split.values())
-            for pattern in patterns
-            if PurePath(path).match(pattern)
-        ]
-        assert len(matched) == len(data_file_per_split[split])
+        matched = [file_path.as_posix() for pattern in patterns for file_path in resolver(pattern)]
         assert matched == data_file_per_split[split]
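The rewritten test resolves patterns against a mock fsspec filesystem so that glob character classes are honored (plain `PurePath.match` does not expand `**` the same way). A stdlib-only variant of the same idea, hypothetical and not part of the test suite, creates real files in a temporary directory and globs against them with the two "test" filename patterns:

```python
import tempfile
from pathlib import Path

def resolve_test_files(file_names):
    """Create the given files in a temp dir; return names matching the 'test' split patterns."""
    # Filename patterns for the "test" keyword, mirroring the ones generated in data_files.py
    patterns = ["test[-._ 0-9]*", "*[-._ 0-9]test[-._ 0-9]*"]
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for name in file_names:
            (root / name).touch()
        return sorted({p.name for pattern in patterns for p in root.glob(pattern)})

print(resolve_test_files(["my_test_file.txt", "contest.txt", "test00001.txt"]))
# ['my_test_file.txt', 'test00001.txt']
```

As in the new parametrized cases, `contest.txt` is excluded because "test" is not preceded by a delimiter, while digits after the keyword (`test00001.txt`) still count as separators.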