2 changes: 2 additions & 0 deletions docs/source/package_reference/builder_classes.rst
@@ -23,6 +23,8 @@ Two main classes are mostly used during the dataset building process.
 
 .. autoclass:: datasets.NamedSplit
 
+.. autoclass:: datasets.NamedSplitAll
+
 .. autoclass:: datasets.ReadInstruction
 
 .. autoclass:: datasets.utils::DownloadConfig
12 changes: 11 additions & 1 deletion src/datasets/__init__.py
@@ -62,7 +62,17 @@
 from .keyhash import KeyHasher
 from .load import import_main_class, load_dataset, load_from_disk, load_metric, prepare_module
 from .metric import Metric
-from .splits import NamedSplit, Split, SplitBase, SplitDict, SplitGenerator, SplitInfo, SubSplitInfo, percent
+from .splits import (
+    NamedSplit,
+    NamedSplitAll,
+    Split,
+    SplitBase,
+    SplitDict,
+    SplitGenerator,
+    SplitInfo,
+    SubSplitInfo,
+    percent,
+)
 from .utils import *
 from .utils.tqdm_utils import disable_progress_bar
 
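With this re-export in place, `NamedSplitAll` is available from the package root. A minimal sanity check, assuming the `Split("all")` dispatch added in splits.py below:

from datasets import NamedSplitAll, Split

# The new export works together with the Split("all") dispatch from splits.py.
assert isinstance(Split("all"), NamedSplitAll)
assert repr(Split("all")) == "NamedSplitAll()"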
4 changes: 2 additions & 2 deletions src/datasets/arrow_dataset.py
@@ -58,7 +58,7 @@
 from .formatting import format_table, get_format_type_from_alias, get_formatter, query_table
 from .info import DatasetInfo
 from .search import IndexableMixin
-from .splits import NamedSplit
+from .splits import NamedSplit, Split
 from .table import (
     ConcatenationTable,
     InMemoryTable,
@@ -718,7 +718,7 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] =
         indices_table = None
 
         split = state["_split"]
-        split = NamedSplit(split) if split is not None else split
+        split = Split(split) if split is not None else split
 
         return Dataset(
             arrow_table=arrow_table,
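The switch from `NamedSplit(split)` to `Split(split)` matters when a dataset whose split was "all" is reloaded: routing the stored name through `Split` restores the `NamedSplitAll` sentinel instead of a plain named split. A small sketch of the difference, assuming the `Split.__new__` dispatch from splits.py:

from datasets import NamedSplit, NamedSplitAll, Split

# Hypothetical deserialized state, mirroring what load_from_disk reads back.
state = {"_split": "all"}

old_split = NamedSplit(state["_split"])  # old code: an ordinary named split
new_split = Split(state["_split"])       # new code: recognized as the sentinel
assert isinstance(new_split, NamedSplitAll)
assert not isinstance(old_split, NamedSplitAll)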
5 changes: 4 additions & 1 deletion src/datasets/builder.py
@@ -761,7 +761,10 @@ def _build_single_dataset(
     ):
         """as_dataset for a single split."""
         verify_infos = not ignore_verifications
-        if isinstance(split, str):
+        if not isinstance(split, ReadInstruction):
+            split = str(split)
+            if split == "all":
+                split = "+".join(self.info.splits.keys())
             split = Split(split)
 
         # Build base dataset
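The builder now normalizes any non-ReadInstruction split to a string first, then expands the literal "all" into a "+"-join of every split the dataset defines before handing it to `Split`. A standalone sketch of the expansion, with hypothetical split names standing in for `self.info.splits.keys()`:

# Hypothetical split registry standing in for self.info.splits.keys().
split_names = ["train", "test"]

split = "all"  # str(Split.ALL) also renders as "all"
if split == "all":
    split = "+".join(split_names)
print(split)  # train+test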
6 changes: 4 additions & 2 deletions src/datasets/splits.py
@@ -377,7 +377,7 @@ def __init__(self):
         super(NamedSplitAll, self).__init__("all")
 
     def __repr__(self):
-        return f"NamedSplitAll({self._name!r})"
+        return "NamedSplitAll()"
 
     def get_read_instruction(self, split_dict):
         # Merge all dataset split together
@@ -398,6 +398,7 @@ class Split:
       model architecture, etc.).
     - `TEST`: the testing data. This is the data to report metrics on. Typically
       you do not want to use this during model iteration as you may overfit to it.
+    - `ALL`: the union of all defined dataset splits.
 
     Note: All splits, including compositions inherit from `datasets.SplitBase`
 
@@ -407,10 +408,11 @@ class Split:
     TRAIN = NamedSplit("train")
     TEST = NamedSplit("test")
     VALIDATION = NamedSplit("validation")
+    ALL = NamedSplitAll()
 
     def __new__(cls, name):
         """Create a custom split with datasets.Split('custom_name')."""
-        return NamedSplit(name)
+        return NamedSplitAll() if name == "all" else NamedSplit(name)
 
 
 # Similar to SplitInfo, but contain an additional slice info
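With the `__new__` dispatch above, the string "all" and the new `Split.ALL` attribute converge on the same sentinel, while every other name still yields an ordinary `NamedSplit`. A quick sketch of the expected behavior:

from datasets import NamedSplit, NamedSplitAll, Split

assert isinstance(Split.ALL, NamedSplitAll)
assert isinstance(Split("all"), NamedSplitAll)
assert isinstance(Split("custom_name"), NamedSplit)
print(repr(Split.ALL))  # NamedSplitAll()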
8 changes: 8 additions & 0 deletions tests/test_builder.py
@@ -267,6 +267,14 @@ def _post_processing_resources(self, split):
         self.assertListEqual(dset.column_names, ["text", "tokens"])
         del dset
 
+        dset = dummy_builder.as_dataset("all")
+        self.assertIsInstance(dset, Dataset)
+        self.assertEqual(dset.split, "train+test")
+        self.assertEqual(len(dset), 20)
+        self.assertDictEqual(dset.features, Features({"text": Value("string"), "tokens": [Value("string")]}))
+        self.assertListEqual(dset.column_names, ["text", "tokens"])
+        del dset
+
         def _post_process(self, dataset, resources_paths):
             return dataset.select([0, 1], keep_in_memory=True)
 
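End to end, the test exercises the new path through `as_dataset("all")`; the same resolution should apply when requesting every split at load time. A hedged usage sketch, with a hypothetical dataset name standing in for a real one:

from datasets import load_dataset

# "my_dataset" is a placeholder; any dataset defining train and test splits works.
dset = load_dataset("my_dataset", split="all")
print(dset.split)  # e.g. train+test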