Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/package_reference/builder_classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ Two main classes are mostly used during the dataset building process.

.. autoclass:: datasets.NamedSplit

.. autoclass:: datasets.ReadInstruction

.. autoclass:: datasets.utils::DownloadConfig

.. autoclass:: datasets.utils::Version
42 changes: 36 additions & 6 deletions docs/source/splits.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ Slicing API
Slicing instructions are specified in :obj:`datasets.load_dataset` or :obj:`datasets.DatasetBuilder.as_dataset`.

Instructions can be provided as either strings or :obj:`ReadInstruction`. Strings
are more compact and readable for simple cases, while :obj:`ReadInstruction` provide
more options and might be easier to use with variable slicing parameters.
are more compact and readable for simple cases, while :obj:`ReadInstruction`
might be easier to use with variable slicing parameters.

Examples
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -35,10 +35,10 @@ Examples using the string API:
# From record 10 (included) to record 20 (excluded) of `train` split.
train_10_20_ds = datasets.load_dataset('bookcorpus', split='train[10:20]')

# The first 10% of train split.
# The first 10% of `train` split.
train_10pct_ds = datasets.load_dataset('bookcorpus', split='train[:10%]')

# The first 10% of train + the last 80% of train.
# The first 10% of `train` + the last 80% of `train`.
train_10_80pct_ds = datasets.load_dataset('bookcorpus', split='train[:10%]+train[-80%:]')

# 10-fold cross-validation (see also next section on rounding behavior):
Expand Down Expand Up @@ -77,11 +77,11 @@ Examples using the ``ReadInstruction`` API (equivalent as above):
train_10_20_ds = datasets.load_dataset('bookcorpus', split=datasets.ReadInstruction(
'train', from_=10, to=20, unit='abs'))

# The first 10% of train split.
# The first 10% of `train` split.
train_10pct_ds = datasets.load_dataset('bookcorpus', split=datasets.ReadInstruction(
'train', to=10, unit='%'))

# The first 10% of train + the last 80% of train.
# The first 10% of `train` + the last 80% of `train`.
ri = (datasets.ReadInstruction('train', to=10, unit='%') +
datasets.ReadInstruction('train', from_=-80, unit='%'))
train_10_80pct_ds = datasets.load_dataset('bookcorpus', split=ri)
Expand All @@ -100,3 +100,33 @@ Examples using the ``ReadInstruction`` API (equivalent as above):
(datasets.ReadInstruction('train', to=k, unit='%') +
datasets.ReadInstruction('train', from_=k+10, unit='%'))
for k in range(0, 100, 10)])

Percent slicing and rounding
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If a slice of a split is requested using the percent (``%``) unit, and the
requested slice boundaries do not divide evenly by 100, then the default
behaviour is to round boundaries to the nearest integer (``closest``). This means
that some slices may contain more examples than others. For example:

.. code-block::

# Assuming `train` split contains 999 records.
# 989 records, from 0 (included) to 989 (excluded).
train_99_ds = datasets.load_dataset('bookcorpus', split='train[:99%]')
# 19 records, from 490 (included) to 509 (excluded).
train_49_51_ds = datasets.load_dataset('bookcorpus', split='train[49%:51%]')

Alternatively, the ``pct1_dropremainder`` rounding can be used, so specified
percentage boundaries are treated as multiples of 1%. This option should be used
when consistency is needed (e.g. ``len(5%) == 5 * len(1%)``). This means the last
examples may be truncated if ``info.splits[split_name].num_examples % 100 != 0``.

.. code-block::

# Records 0 (included) to 891 (excluded).
train_99pct1_ds = datasets.load_dataset('bookcorpus', split=datasets.ReadInstruction(
'train', to=99, unit='%', rounding='pct1_dropremainder'))
# Or equivalently:
train_99pct1_ds = datasets.load_dataset('bookcorpus', split='train[:99%](pct1_dropremainder)')

89 changes: 54 additions & 35 deletions src/datasets/arrow_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _read_files(self, files, in_memory=False) -> Table:
pa_tables = []
files = copy.deepcopy(files)
for f in files:
f.update(filename=os.path.join(self._path, f["filename"]))
f["filename"] = os.path.join(self._path, f["filename"])
for f_dict in files:
pa_table: Table = self._get_table_from_filename(f_dict, in_memory=in_memory)
pa_tables.append(pa_table)
Expand Down Expand Up @@ -380,10 +380,14 @@ class _RelativeInstruction:
def __post_init__(self):
    """Validate the instruction fields and fill in the default rounding.

    After validation, `rounding` is set to "closest" when it was left as
    None and `unit` is "%"; otherwise it keeps the value it was given.

    Raises:
        AssertionError: if `unit` is not None, "%" or "abs"; if `rounding`
            is not None, "closest" or "pct1_dropremainder"; if `rounding`
            is specified while `unit` is not "%"; or if a percent boundary
            (`from_` or `to`) lies outside [-100, 100].
    """
    # Explicit raises rather than `assert` statements, so that validation is
    # not stripped when Python runs with -O. The exception type is unchanged.
    if self.unit is not None and self.unit not in ["%", "abs"]:
        raise AssertionError('unit must be None, "%" or "abs".')
    if self.rounding is not None and self.rounding not in ["closest", "pct1_dropremainder"]:
        raise AssertionError('rounding must be None, "closest" or "pct1_dropremainder".')
    if self.unit != "%" and self.rounding is not None:
        raise AssertionError("It is forbidden to specify rounding if not using percent slicing.")
    if self.unit == "%":
        # Both boundaries obey the same rule; -100 and 100 themselves are valid.
        for boundary in (self.from_, self.to):
            if boundary is not None and abs(boundary) > 100:
                raise AssertionError("Percent slice boundaries must be >= -100 and <= 100.")
    # Update via __dict__ due to instance being "frozen"
    self.__dict__["rounding"] = "closest" if self.rounding is None and self.unit == "%" else self.rounding


def _str_to_read_instruction(spec):
Expand All @@ -394,7 +398,7 @@ def _str_to_read_instruction(spec):
unit = "%" if res.group("from_pct") or res.group("to_pct") else "abs"
return ReadInstruction(
split_name=res.group("split"),
rounding=res.group("rounding") if res.group("rounding") else "closest",
rounding=res.group("rounding"),
from_=int(res.group("from")) if res.group("from") else None,
to=int(res.group("to")) if res.group("to") else None,
unit=unit,
Expand Down Expand Up @@ -455,32 +459,39 @@ class ReadInstruction:

Examples of usage:

```
# The following lines are equivalent:
ds = datasets.load_dataset('mnist', split='test[:33%]')
ds = datasets.load_dataset('mnist', split=ReadInstruction.from_spec('test[:33%]'))
ds = datasets.load_dataset('mnist', split=ReadInstruction('test', to=33, unit='%'))
ds = datasets.load_dataset('mnist', split=ReadInstruction(
'test', from_=0, to=33, unit='%'))

# The following lines are equivalent:
ds = datasets.load_dataset('mnist', split='test[:33%]+train[1:-1]')
ds = datasets.load_dataset('mnist', split=ReadInstruction.from_spec(
'test[:33%]+train[1:-1]'))
ds = datasets.load_dataset('mnist', split=(
ReadInstruction('test', to=33, unit='%') +
ReadInstruction('train', from_=1, to=-1, unit='abs')))

# 10-fold validation:
tests = datasets.load_dataset(
'mnist',
[ReadInstruction('train', from_=k, to=k+10, unit='%')
for k in range(0, 100, 10)])
trains = datasets.load_dataset(
'mnist',
[ReadInstruction('train', to=k, unit='%') + ReadInstruction('train', from_=k+10, unit='%')
for k in range(0, 100, 10)])
```
.. code:: python

# The following lines are equivalent:
ds = datasets.load_dataset('mnist', split='test[:33%]')
ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction.from_spec('test[:33%]'))
ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction('test', to=33, unit='%'))
ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction(
'test', from_=0, to=33, unit='%'))

# The following lines are equivalent:
ds = datasets.load_dataset('mnist', split='test[:33%]+train[1:-1]')
ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction.from_spec(
'test[:33%]+train[1:-1]'))
ds = datasets.load_dataset('mnist', split=(
datasets.ReadInstruction('test', to=33, unit='%') +
datasets.ReadInstruction('train', from_=1, to=-1, unit='abs')))

# The following lines are equivalent:
ds = datasets.load_dataset('mnist', split='test[:33%](pct1_dropremainder)')
ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction.from_spec(
'test[:33%](pct1_dropremainder)'))
ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction(
'test', from_=0, to=33, unit='%', rounding="pct1_dropremainder"))

# 10-fold validation:
tests = datasets.load_dataset(
'mnist',
split=[datasets.ReadInstruction('train', from_=k, to=k+10, unit='%')
for k in range(0, 100, 10)])
trains = datasets.load_dataset(
'mnist',
split=[datasets.ReadInstruction('train', to=k, unit='%') + datasets.ReadInstruction('train', from_=k+10, unit='%')
for k in range(0, 100, 10)])

"""

Expand All @@ -496,12 +507,12 @@ def _read_instruction_from_relative_instructions(cls, relative_instructions):
result._init(relative_instructions) # pylint: disable=protected-access
return result

def __init__(self, split_name, rounding="closest", from_=None, to=None, unit=None):
def __init__(self, split_name, rounding=None, from_=None, to=None, unit=None):
"""Initialize ReadInstruction.

Args:
split_name (str): name of the split to read. Eg: 'train'.
rounding (str): The rounding behaviour to use when percent slicing is
rounding (str, optional): The rounding behaviour to use when percent slicing is
used. Ignored when slicing with absolute indices.
Possible values:
- 'closest' (default): The specified percentages are rounded to the
Expand Down Expand Up @@ -530,14 +541,17 @@ def from_spec(cls, spec):
"""Creates a ReadInstruction instance out of a string spec.

Args:
spec (str): split(s) + optional slice(s) to read. A slice can be
specified, using absolute numbers (int) or percentages (int). E.g.
spec (str): split(s) + optional slice(s) to read + optional rounding
if percents are used as the slicing unit. A slice can be specified,
using absolute numbers (int) or percentages (int). E.g.
`test`: test split.
`test + validation`: test split + validation split.
`test[10:]`: test split, minus its first 10 records.
`test[:10%]`: first 10% records of test split.
`test[:20%](pct1_dropremainder)`: first 20% records, rounded with
the `pct1_dropremainder` rounding.
`test[:-5%]+train[40%:60%]`: first 95% of test + middle 20% of
train.
train.

Returns:
ReadInstruction instance.
Expand Down Expand Up @@ -574,10 +588,15 @@ def __add__(self, other):
if not isinstance(other, ReadInstruction):
msg = "ReadInstruction can only be added to another ReadInstruction obj."
raise AssertionError(msg)
self_ris = self._relative_instructions
other_ris = other._relative_instructions # pylint: disable=protected-access
if self._relative_instructions[0].rounding != other_ris[0].rounding:
if (
self_ris[0].unit != "abs"
and other_ris[0].unit != "abs"
and self._relative_instructions[0].rounding != other_ris[0].rounding
):
raise AssertionError("It is forbidden to sum ReadInstruction instances with different rounding values.")
return self._read_instruction_from_relative_instructions(self._relative_instructions + other_ris)
return self._read_instruction_from_relative_instructions(self_ris + other_ris)

def __str__(self):
return self.to_spec()
Expand Down