Skip to content
Merged
9 changes: 5 additions & 4 deletions src/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
del pyarrow
del version

from .arrow_dataset import Dataset, concatenate_datasets
from .arrow_dataset import Dataset
from .arrow_reader import ReadInstruction
from .builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
from .combine import interleave_datasets
from .combine import concatenate_datasets, interleave_datasets
from .dataset_dict import DatasetDict, IterableDatasetDict
from .download import *
from .features import *
Expand Down Expand Up @@ -73,15 +73,16 @@


# deprecated modules
from datasets import arrow_dataset as _arrow_dataset # isort:skip
from datasets import utils as _utils # isort:skip
from datasets.utils import download_manager as _deprecated_download_manager # isort:skip


_arrow_dataset.concatenate_datasets = concatenate_datasets
_utils.DownloadConfig = DownloadConfig
_utils.DownloadManager = DownloadManager
_utils.DownloadMode = DownloadMode
_deprecated_download_manager.DownloadConfig = DownloadConfig
_deprecated_download_manager.DownloadMode = DownloadMode
_deprecated_download_manager.DownloadManager = DownloadManager

del _utils, _deprecated_download_manager
del _arrow_dataset, _utils, _deprecated_download_manager
75 changes: 67 additions & 8 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,8 @@ def from_dict(
cls,
mapping: dict,
features: Optional[Features] = None,
info: Optional[Any] = None,
split: Optional[Any] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
) -> "Dataset":
"""
Convert :obj:`dict` to a :obj:`pyarrow.Table` to create a :class:`Dataset`.
Expand Down Expand Up @@ -2493,7 +2493,7 @@ def catch_non_existent_error(func, kwargs):
), "All shards have to be defined Datasets, none should still be missing."

logger.info(f"Concatenating {num_proc} shards")
result = concatenate_datasets(transformed_shards)
result = _concatenate_map_style_datasets(transformed_shards)
if new_fingerprint is not None:
result._fingerprint = new_fingerprint
return result
Expand Down Expand Up @@ -4719,14 +4719,15 @@ def process_label_ids(batch):
return self.map(process_label_ids, features=features, batched=True, desc="Aligning the labels")


def concatenate_datasets(
def _concatenate_map_style_datasets(
dsets: List[Dataset],
info: Optional[Any] = None,
split: Optional[Any] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
axis: int = 0,
):
"""
Converts a list of :class:`Dataset` with the same schema into a single :class:`Dataset`.
When you concatenate on axis 0, missing data are filled with None values.

Args:
dsets (:obj:`List[datasets.Dataset]`): List of Datasets to concatenate.
Expand All @@ -4741,7 +4742,7 @@ def concatenate_datasets(
Example:

```py
>>> ds3 = concatenate_datasets([ds1, ds2])
>>> ds3 = _concatenate_map_style_datasets([ds1, ds2])
```
"""
# Ignore datasets with no rows
Expand Down Expand Up @@ -4817,7 +4818,7 @@ def apply_offset_to_indices_table(table, offset):
if info is None:
info = DatasetInfo.from_merge([dset.info for dset in dsets])
fingerprint = update_fingerprint(
"".join(dset._fingerprint for dset in dsets), concatenate_datasets, {"info": info, "split": split}
"".join(dset._fingerprint for dset in dsets), _concatenate_map_style_datasets, {"info": info, "split": split}
)

# Make final concatenated dataset
Expand All @@ -4832,6 +4833,64 @@ def apply_offset_to_indices_table(table, offset):
return concatenated_dataset


def _interleave_map_style_datasets(
    datasets: List["Dataset"],
    probabilities: Optional[List[float]] = None,
    seed: Optional[int] = None,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    **kwargs,
) -> "Dataset":
    """
    Build a single map-style dataset by interleaving the examples of several map-style datasets (sources).

    The sources are first concatenated, then the rows of the concatenated dataset are re-ordered with ``.select()``.
    If `probabilities` is `None` (default), the sources are visited in a fixed cycle: one example from each source
    per round, for as many rounds as the shortest source has examples.
    If `probabilities` is not `None`, the source of each next example is drawn at random according to these
    probabilities, and the construction stops at the first draw that lands on a source with no examples left.

    Args:
        datasets (:obj:`List[Dataset]`): list of datasets to interleave
        probabilities (:obj:`List[float]`, optional, default None): If specified, the new dataset is constructed by
            sampling examples from one source at a time according to these probabilities.
        seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
        info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
        split (:class:`NamedSplit`, optional): Name of the dataset split.
        **kwargs (additional keyword arguments): Keyword arguments forwarded to the ``.select()`` call that
            re-orders the concatenated dataset.

    Output:
        The dataset returned by ``.select()`` on the concatenation of the sources.
    """
    # Interleaving = concatenating the sources, then picking rows of the result in the right order
    merged = _concatenate_map_style_datasets(datasets, info=info, split=split)

    # Position of the first row of each source inside the concatenated dataset
    sizes = [len(source) for source in datasets]
    starts = np.cumsum([0] + sizes[:-1])

    if probabilities is None:
        # Example: if the sources have lengths [3, 4, 5], their first rows sit at [0, 3, 7]
        # and the selected rows are [0, 3, 7, 1, 4, 8, 2, 5, 9].
        # Only 3 rounds are done since the shortest source has 3 examples.
        rounds = np.arange(min(sizes)).reshape(-1, 1)
        indices = (starts.reshape(1, -1) + rounds).flatten().tolist()
        return merged.select(indices, **kwargs)

    rng = np.random.default_rng(seed)
    # Number of examples already taken from each source
    taken = [0] * len(datasets)
    indices = []
    while True:
        # Source indices are drawn in batches of 1000 and consumed one at a time
        for drawn in rng.choice(len(datasets), size=1000, p=probabilities):
            source_idx = int(drawn)
            if taken[source_idx] >= sizes[source_idx]:
                # The drawn source has no examples left: the interleaving ends here
                return merged.select(indices, **kwargs)
            # Next unused example of the drawn source, expressed as a row of the concatenated dataset
            indices.append(taken[source_idx] + starts[source_idx])
            taken[source_idx] += 1


# This is outside Dataset.filter as it needs to be picklable for multiprocessing


Expand Down
161 changes: 47 additions & 114 deletions src/datasets/combine.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
from typing import TYPE_CHECKING, Any, List, Optional, TypeVar

import numpy as np
from typing import List, Optional, TypeVar

from .arrow_dataset import Dataset, _concatenate_map_style_datasets, _interleave_map_style_datasets
from .info import DatasetInfo
from .iterable_dataset import IterableDataset, _concatenate_iterable_datasets, _interleave_iterable_datasets
from .splits import NamedSplit
from .utils import logging


logger = logging.get_logger(__name__)


if TYPE_CHECKING:
from .arrow_dataset import Dataset
from .iterable_dataset import IterableDataset


DatasetType = TypeVar("DatasetType", "Dataset", "IterableDataset")


def interleave_datasets(
datasets: List[DatasetType], probabilities: Optional[List[float]] = None, seed: Optional[int] = None
datasets: List[DatasetType],
probabilities: Optional[List[float]] = None,
seed: Optional[int] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
) -> DatasetType:
"""
Interleave several datasets (sources) into a single dataset.
Expand Down Expand Up @@ -78,126 +78,59 @@ def interleave_datasets(
map_style = isinstance(datasets[0], Dataset)
if not (iterable ^ map_style):
raise ValueError(
f"Expected a list Dataset objects or a list of IterableDataset objects, but first element is a {type(datasets[0])}"
f"Expected a list of Dataset objects or a list of IterableDataset objects, but first element is a {type(datasets[0])}"
)
for dataset in datasets[1:]:
if (map_style and not isinstance(dataset, Dataset)) or (iterable and not isinstance(dataset, IterableDataset)):
raise ValueError(
f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
)
if map_style:
return _interleave_map_style_datasets(datasets, probabilities, seed)
return _interleave_map_style_datasets(datasets, probabilities, seed, info=info, split=split)
else:
return _interleave_iterable_datasets(datasets, probabilities, seed)
return _interleave_iterable_datasets(datasets, probabilities, seed, info=info, split=split)


def _interleave_map_style_datasets(
datasets: List["Dataset"],
probabilities: Optional[List[float]] = None,
seed: Optional[int] = None,
info: Optional[Any] = None,
split: Optional[Any] = None,
**kwargs,
) -> "Dataset":
def concatenate_datasets(
dsets: List[Dataset],
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
axis: int = 0,
):
"""
Interleave several map-style datasets (sources) into a single map-style dataset.
The new dataset is constructed by alternating between the sources to get the examples.
If `probabilities = None` (default) the new dataset is constructed by cycling between each source to get the examples.
If `probabilities` is not `None, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.
Converts a list of :class:`Dataset` with the same schema into a single :class:`Dataset`.

Args:
datasets (:obj:`List[Dataset]`): list of datasets to interleave
probabilities (:obj:`List[float]`, optional, default None): If specified, the new dataset is constructued by sampling
examples from one source at a time according to these probabilities.
seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
**kwargs (additional keyword arguments): Keyword arguments to be passed to :meth:`datasets.Datasets.select` when selecting the indices used to interleave the datasets.
dsets (:obj:`List[datasets.Dataset]`): List of Datasets to concatenate.
info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
split (:class:`NamedSplit`, optional): Name of the dataset split.
axis (``{0, 1}``, default ``0``, meaning over rows):
Axis to concatenate over, where ``0`` means over rows (vertically) and ``1`` means over columns
(horizontally).

Output:
:class:`datasets.Dataset`
"""
from .arrow_dataset import concatenate_datasets

# To interleave the datasets, we concatenate them and then we re-order the indices
concatenated_datasets = concatenate_datasets(datasets, info=info, split=split)

# Let's now build the indices to pass to .select()
lengths = [len(dset) for dset in datasets]
offsets = np.cumsum([0] + lengths[:-1])
if probabilities is None:
# Example:: If lengths of the datasets are [3, 4, 5]
# Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 6, 9]
# Note that we only have 3 examples per dataset since the first dataset ran out of examples
indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
else:
*New in version 1.6.0*

def iter_random_indices():
"""Get an infinite iterator that randomly samples the index of the source to pick examples from."""
rng = np.random.default_rng(seed)
while True:
yield from (int(i) for i in rng.choice(len(datasets), size=1000, p=probabilities))

current_index = [0] * len(datasets)
indices = []
for source_idx in iter_random_indices():
# we ran out of examples, let's stop
if current_index[source_idx] >= lengths[source_idx]:
break
# let's add the example at the current index of the `source_idx`-th dataset
indices.append(current_index[source_idx] + offsets[source_idx])
current_index[source_idx] += 1
return concatenated_datasets.select(indices, **kwargs)


def _interleave_iterable_datasets(
datasets: List["IterableDataset"],
probabilities: Optional[List[float]] = None,
seed: Optional[int] = None,
info: Optional[Any] = None,
split: Optional[Any] = None,
) -> "IterableDataset":
"""
Interleave several iterable datasets (sources) into a single iterable dataset.
The new iterable dataset alternates between the sources to yield examples.
If `probabilities = None` (default) the iterable dataset will cycles through the sources in order for each next example in the iteration.
If `probabilities` is not `None, the iterable dataset will sample a random source according to the provided probabilities for each next examples in the iteration.

Args:
datasets (:obj:`List[IterableDataset]`): list of datasets to interleave
probabilities (:obj:`List[float]`, optional, default None): If specified, the new iterable dataset samples
examples from one source at a time according to these probabilities.
seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
Example:

Output:
:class:`datasets.IterableDataset`
```py
>>> ds3 = concatenate_datasets([ds1, ds2])
```
"""
from .iterable_dataset import (
CyclingMultiSourcesExamplesIterable,
RandomlyCyclingMultiSourcesExamplesIterable,
TypedExamplesIterable,
iterable_dataset,
)

ex_iterables = [
TypedExamplesIterable(d._ex_iterable, d.features)
if not isinstance(d._ex_iterable, TypedExamplesIterable) and d.features is not None
else d._ex_iterable
for d in datasets
]
# Use cycling or random cycling or sources
if probabilities is None:
ex_iterable = CyclingMultiSourcesExamplesIterable(ex_iterables)
else:
generator = np.random.default_rng(seed)
ex_iterable = RandomlyCyclingMultiSourcesExamplesIterable(
ex_iterables, generator=generator, probabilities=probabilities

if not dsets:
raise ValueError("Unable to concatenate an empty list of datasets.")
iterable = isinstance(dsets[0], IterableDataset)
map_style = isinstance(dsets[0], Dataset)
if not (iterable ^ map_style):
raise ValueError(
f"Expected a list of Dataset objects or a list of IterableDataset objects, but first element is a {type(dsets[0])}"
)
# Set new info - we reset the features
if info is None:
info = DatasetInfo.from_merge([d.info for d in datasets])
info.features = None
# Get all the auth tokens per repository - in case the datasets come from different private repositories
token_per_repo_id = {
repo_id: token for dataset in datasets for repo_id, token in dataset._token_per_repo_id.items()
}
# Return new daset
return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
for dataset in dsets[1:]:
if (map_style and not isinstance(dataset, Dataset)) or (iterable and not isinstance(dataset, IterableDataset)):
raise ValueError(
f"Unable to concatenate a {type(dsets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
)
if map_style:
return _concatenate_map_style_datasets(dsets, info=info, split=split, axis=axis)
else:
return _concatenate_iterable_datasets(dsets, info=info, split=split, axis=axis)
Loading