Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/package_reference/builder_classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ Two main classes are mostly used during the dataset building process.

.. autoclass:: datasets.NamedSplit

.. autoclass:: datasets.NamedSplitAll

.. autoclass:: datasets.utils::DownloadConfig

.. autoclass:: datasets.utils::Version
12 changes: 11 additions & 1 deletion src/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,17 @@
)
from .load import import_main_class, load_dataset, load_from_disk, load_metric, prepare_module
from .metric import Metric
from .splits import NamedSplit, Split, SplitBase, SplitDict, SplitGenerator, SplitInfo, SubSplitInfo, percent
from .splits import (
NamedSplit,
NamedSplitAll,
Split,
SplitBase,
SplitDict,
SplitGenerator,
SplitInfo,
SubSplitInfo,
percent,
)
from .utils import *
from .utils.tqdm_utils import disable_progress_bar

Expand Down
4 changes: 2 additions & 2 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from .formatting import format_table, get_format_type_from_alias, get_formatter, query_table
from .info import DATASET_INFO_FILENAME, DatasetInfo
from .search import IndexableMixin
from .splits import NamedSplit
from .splits import NamedSplit, Split
from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables, list_table_cache_files
from .utils import map_nested
from .utils.deprecation_utils import deprecated
Expand Down Expand Up @@ -674,7 +674,7 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] =
indices_table = None

split = state["_split"]
split = NamedSplit(split) if split is not None else split
split = Split(split) if split is not None else split

return Dataset(
arrow_table=arrow_table,
Expand Down
2 changes: 2 additions & 0 deletions src/datasets/arrow_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def make_file_instructions(name, split_infos, instruction, filetype_suffix=None)
"""
name2len = {info.name: info.num_examples for info in split_infos}
if not isinstance(instruction, ReadInstruction):
if str(instruction) == "all":
instruction = "+".join(name2len.keys())
instruction = ReadInstruction.from_spec(instruction)
# Create the absolute instruction (per split)
absolute_instructions = instruction.to_absolute(name2len)
Expand Down
6 changes: 4 additions & 2 deletions src/datasets/splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def __init__(self):
super(NamedSplitAll, self).__init__("all")

def __repr__(self):
return f"NamedSplitAll({self._name!r})"
return "NamedSplitAll()"

def get_read_instruction(self, split_dict):
# Merge all dataset splits together
Expand All @@ -398,6 +398,7 @@ class Split:
model architecture, etc.).
- `TEST`: the testing data. This is the data to report metrics on. Typically
you do not want to use this during model iteration as you may overfit to it.
- `ALL`: the union of all defined dataset splits.

Note: All splits, including compositions, inherit from `datasets.SplitBase`.

Expand All @@ -407,10 +408,11 @@ class Split:
TRAIN = NamedSplit("train")
TEST = NamedSplit("test")
VALIDATION = NamedSplit("validation")
ALL = NamedSplitAll()

def __new__(cls, name):
"""Create a custom split with datasets.Split('custom_name')."""
return NamedSplit(name)
return NamedSplitAll() if name == "all" else NamedSplit(name)


# Similar to SplitInfo, but contain an additional slice info
Expand Down
7 changes: 6 additions & 1 deletion tests/test_arrow_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datasets.arrow_dataset import Dataset
from datasets.arrow_reader import ArrowReader, BaseReader, ReadInstruction
from datasets.info import DatasetInfo
from datasets.splits import NamedSplit, Split, SplitDict, SplitInfo
from datasets.splits import NamedSplit, NamedSplitAll, Split, SplitDict, SplitInfo

from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases

Expand Down Expand Up @@ -71,6 +71,11 @@ def test_read(self):
self.assertEqual(str(test_dset.split), "test[:33%]")
del train_dset, test_dset

instructions = "all"
dset = Dataset(**reader.read(name, instructions, split_infos))
self.assertEqual(dset.num_rows, train_info.num_examples + test_info.num_examples)
self.assertIsInstance(dset.split, NamedSplitAll)

def test_read_files(self):
train_info = SplitInfo(name="train", num_examples=100)
test_info = SplitInfo(name="test", num_examples=100)
Expand Down