diff --git a/docs/source/package_reference/main_classes.rst b/docs/source/package_reference/main_classes.rst
index 84f1cc67d2c..951d55de79f 100644
--- a/docs/source/package_reference/main_classes.rst
+++ b/docs/source/package_reference/main_classes.rst
@@ -38,6 +38,8 @@ The base class :class:`datasets.Dataset` implements a Dataset backed by an Apach
 
 .. autofunction:: datasets.concatenate_datasets
 
+.. autofunction:: datasets.interleave_datasets
+
 .. autofunction:: datasets.set_caching_enabled
 
 .. autofunction:: datasets.is_caching_enabled
diff --git a/docs/source/processing.rst b/docs/source/processing.rst
index bd178b3161b..fd72dd8ca5a 100644
--- a/docs/source/processing.rst
+++ b/docs/source/processing.rst
@@ -582,6 +582,8 @@ When you have several :obj:`datasets.Dataset` objects that share the same column
 >>> assert bookcorpus.features.type == wiki.features.type
 >>> bert_dataset = concatenate_datasets([bookcorpus, wiki])
 
+If you want to interleave the datasets instead of concatenating them, use :func:`datasets.interleave_datasets`.
+
 Saving a processed dataset on disk and reload it
 ------------------------------------------------
 
diff --git a/src/datasets/combine.py b/src/datasets/combine.py
index f0aad71063a..7c6b7e84a35 100644
--- a/src/datasets/combine.py
+++ b/src/datasets/combine.py
@@ -1,28 +1,184 @@
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
+
+import numpy as np
 
 from .info import DatasetInfo
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
 
 
 if TYPE_CHECKING:
+    from .arrow_dataset import Dataset
     from .iterable_dataset import IterableDataset
 
 
+DatasetType = TypeVar("DatasetType", "Dataset", "IterableDataset")
+
+
 def interleave_datasets(
-    datasets: List["IterableDataset"], probabilities: Optional[List[float]] = None, seed: Optional[int] = None
+    datasets: List[DatasetType], probabilities: Optional[List[float]] = None, seed: Optional[int] = None
+) -> DatasetType:
+    """
+    Interleave several datasets (sources) into a single dataset.
+    The new dataset is constructed by alternating between the sources to get the examples.
+
+    You can use this function on a list of :class:`Dataset` objects, or on a list of :class:`IterableDataset` objects.
+
+    If ``probabilities`` is ``None`` (default), the new dataset is constructed by cycling between each source to get the examples.
+    If ``probabilities`` is not ``None``, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.
+
+    The resulting dataset ends when one of the source datasets runs out of examples.
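+
+    For instance, if the sources have 3, 4 and 5 examples and ``probabilities`` is ``None``, the interleaved dataset contains ``3 * min(3, 4, 5) == 9`` examples: iteration stops as soon as the shortest source is exhausted.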
+
+    Examples:
+
+        For regular datasets (map-style):
+
+        >>> from datasets import Dataset, interleave_datasets
+        >>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
+        >>> d2 = Dataset.from_dict({"a": [10, 11, 12]})
+        >>> d3 = Dataset.from_dict({"a": [20, 21, 22]})
+        >>> dataset = interleave_datasets([d1, d2, d3])
+        >>> dataset["a"]
+        [0, 10, 20, 1, 11, 21, 2, 12, 22]
+        >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42)
+        >>> dataset["a"]
+        [10, 0, 11, 1, 2, 20, 12]
+
+        For datasets in streaming mode (iterable):
+
+        >>> from datasets import load_dataset, interleave_datasets
+        >>> d1 = load_dataset("oscar", "unshuffled_deduplicated_en", split="train", streaming=True)
+        >>> d2 = load_dataset("oscar", "unshuffled_deduplicated_fr", split="train", streaming=True)
+        >>> dataset = interleave_datasets([d1, d2])
+        >>> iterator = iter(dataset)
+        >>> next(iterator)
+        {'text': 'Mtendere Village was inspired by the vision...
+        >>> next(iterator)
+        {'text': "Média de débat d'idées, de culture...
+
+    Args:
+        datasets (:obj:`List[Dataset]` or :obj:`List[IterableDataset]`): list of datasets to interleave
+        probabilities (:obj:`List[float]`, optional, default None): If specified, the new dataset is constructed by sampling
+            examples from one source at a time according to these probabilities.
+        seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
+
+    Output:
+        :class:`datasets.Dataset` if the input is a list of :class:`datasets.Dataset`,
+        or :class:`datasets.IterableDataset` if the input is a list of :class:`datasets.IterableDataset`.
+    """
+    from .arrow_dataset import Dataset
+    from .iterable_dataset import IterableDataset
+
+    if not datasets:
+        raise ValueError("Unable to interleave an empty list of datasets.")
+    iterable = isinstance(datasets[0], IterableDataset)
+    map_style = isinstance(datasets[0], Dataset)
+    if not (iterable ^ map_style):
+        raise ValueError(
+            f"Expected a list of Dataset objects or a list of IterableDataset objects, but first element is a {type(datasets[0])}"
+        )
+    for dataset in datasets[1:]:
+        if (map_style and not isinstance(dataset, Dataset)) or (iterable and not isinstance(dataset, IterableDataset)):
+            raise ValueError(
+                f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
+            )
+    if map_style:
+        return _interleave_map_style_datasets(datasets, probabilities, seed)
+    else:
+        return _interleave_iterable_datasets(datasets, probabilities, seed)
+
+
+def _interleave_map_style_datasets(
+    datasets: List["Dataset"],
+    probabilities: Optional[List[float]] = None,
+    seed: Optional[int] = None,
+    info: Optional[Any] = None,
+    split: Optional[Any] = None,
+    **kwargs,
+) -> "Dataset":
+    """
+    Interleave several map-style datasets (sources) into a single map-style dataset.
+    The new dataset is constructed by alternating between the sources to get the examples.
+    If ``probabilities`` is ``None`` (default), the new dataset is constructed by cycling between each source to get the examples.
+    If ``probabilities`` is not ``None``, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.
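+
+    Example (a minimal sketch of the cycling case on toy data; the output follows
+    the index construction implemented below):
+
+        >>> from datasets import Dataset
+        >>> d1 = Dataset.from_dict({"a": [0, 1]})
+        >>> d2 = Dataset.from_dict({"a": [10, 11, 12]})
+        >>> _interleave_map_style_datasets([d1, d2])["a"]
+        [0, 10, 1, 11]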
+
+    Args:
+        datasets (:obj:`List[Dataset]`): list of datasets to interleave
+        probabilities (:obj:`List[float]`, optional, default None): If specified, the new dataset is constructed by sampling
+            examples from one source at a time according to these probabilities.
+        seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
+        **kwargs: Keyword arguments to be passed to :meth:`datasets.Dataset.select` when selecting the indices used to interleave the datasets.
+
+    Output:
+        :class:`datasets.Dataset`
+    """
+    from .arrow_dataset import concatenate_datasets
+
+    if not all(dset.features.type == datasets[0].features.type for dset in datasets):
+        raise ValueError("Features must match for all datasets")
+
+    # Find common format or reset format
+    format = datasets[0].format
+    if any(dset.format != format for dset in datasets):
+        format = {}
+        logger.info("Some of the datasets have disparate format. Resetting the format of the interleaved dataset.")
+
+    # To interleave the datasets, we concatenate them and then we re-order the indices
+    concatenated_datasets = concatenate_datasets(datasets, info=info, split=split)
+
+    # Let's now build the indices to pass to .select()
+    lengths = [len(dset) for dset in datasets]
+    offsets = np.cumsum([0] + lengths[:-1])
+    if probabilities is None:
+        # Example: if the lengths of the datasets are [3, 4, 5]
+        # then the resulting indices are [0, 3, 7, 1, 4, 8, 2, 5, 9]
+        # Note that we only keep 3 examples per dataset since the shortest dataset has 3 examples
+        indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
+    else:
+
+        def iter_random_indices():
+            """Get an infinite iterator that randomly samples the index of the source to pick examples from."""
+            rng = np.random.default_rng(seed)
+            while True:
+                yield from (int(i) for i in rng.choice(len(datasets), size=1000, p=probabilities))
+
+        current_index = [0] * len(datasets)
+        indices = []
+        for source_idx in iter_random_indices():
+            # we ran out of examples, let's stop
+            if current_index[source_idx] >= lengths[source_idx]:
+                break
+            # let's add the example at the current index of the `source_idx`-th dataset
+            indices.append(current_index[source_idx] + offsets[source_idx])
+            current_index[source_idx] += 1
+    return concatenated_datasets.select(indices, **kwargs)
+
+
+def _interleave_iterable_datasets(
+    datasets: List["IterableDataset"],
+    probabilities: Optional[List[float]] = None,
+    seed: Optional[int] = None,
+    info: Optional[Any] = None,
+    split: Optional[Any] = None,
 ) -> "IterableDataset":
     """
     Interleave several iterable datasets (sources) into a single iterable dataset.
     The new iterable dataset alternates between the sources to yield examples.
-    If `probabilities = None` (default) the iterable dataset will cycles through the sources
-    in order for each next example in the iteration.
-    If `probabilities` is not `None, the iterable dataset will sample a random source according
-    to the provided probabilities for each next examples in the iteration.
+    If ``probabilities`` is ``None`` (default), the iterable dataset cycles through the sources in order for each next example in the iteration.
+    If ``probabilities`` is not ``None``, the iterable dataset samples a random source according to the provided probabilities for each next example in the iteration.
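+
+    For instance, with ``probabilities=[0.7, 0.3]``, each next example is drawn from the first source with probability 0.7 and from the second with probability 0.3, until one of the sources runs out of examples.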
 
     Args:
-        datasets (:obj:`List[IterableDataset]`): list of datasets to merge
-        probabilities (:obj:`List[float]`, optional, default None): If specified, the new iterable datasets will sample
+        datasets (:obj:`List[IterableDataset]`): list of datasets to interleave
+        probabilities (:obj:`List[float]`, optional, default None): If specified, the new iterable dataset samples
             examples from one source at a time according to these probabilities.
-        seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example to yield.
+        seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
+
+    Output:
+        :class:`datasets.IterableDataset`
     """
     from .iterable_dataset import (
         CyclingMultiSourcesExamplesIterable,
@@ -42,7 +198,8 @@ def interleave_datasets(
     else:
         ex_iterable = RandomlyCyclingMultiSourcesExamplesIterable(ex_iterables, seed=seed, probabilities=probabilities)
     # Set new info - we reset the features
-    info = DatasetInfo.from_merge([d.info for d in datasets])
-    info.features = None
+    if info is None:
+        info = DatasetInfo.from_merge([d.info for d in datasets])
+        info.features = None
     # Return new daset
-    return iterable_dataset(ex_iterable=ex_iterable, info=info)
+    return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split)
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 5d4cdb53860..54e7b85177c 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -1,4 +1,5 @@
 import copy
+import itertools
 import json
 import os
 import pickle
@@ -13,7 +14,7 @@
 from absl.testing import parameterized
 
 import datasets.arrow_dataset
-from datasets import concatenate_datasets, load_from_disk, temp_seed
+from datasets import concatenate_datasets, interleave_datasets, load_from_disk, temp_seed
 from datasets.arrow_dataset import Dataset, transmit_format, update_metadata_with_features
 from datasets.dataset_dict import DatasetDict
 from datasets.features import Array2D, Array3D, ClassLabel, Features, Sequence, Value
@@ -2047,6 +2048,36 @@ def test_concatenate_datasets_duplicate_columns(dataset):
     assert "duplicated" in str(excinfo.value)
 
 
+def test_interleave_datasets():
+    d1 = Dataset.from_dict({"a": [0, 1, 2]})
+    d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
+    d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
+    dataset = interleave_datasets([d1, d2, d3])
+    expected_length = 3 * min(len(d1), len(d2), len(d3))
+    expected_values = [x["a"] for x in itertools.chain(*zip(d1, d2, d3))]
+    assert isinstance(dataset, Dataset)
+    assert len(dataset) == expected_length
+    assert dataset["a"] == expected_values
+    assert dataset._fingerprint == interleave_datasets([d1, d2, d3])._fingerprint
+
+
+def test_interleave_datasets_probabilities():
+    seed = 42
+    probabilities = [0.3, 0.5, 0.2]
+    d1 = Dataset.from_dict({"a": [0, 1, 2]})
+    d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
+    d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
+    dataset = interleave_datasets([d1, d2, d3], probabilities=probabilities, seed=seed)
+    expected_length = 7  # hardcoded
+    expected_values = [10, 11, 20, 12, 0, 21, 13]  # hardcoded
+    assert isinstance(dataset, Dataset)
+    assert len(dataset) == expected_length
+    assert dataset["a"] == expected_values
+    assert (
+        dataset._fingerprint == interleave_datasets([d1, d2, d3], probabilities=probabilities, seed=seed)._fingerprint
+    )
+
+
 @pytest.mark.parametrize(
     "column, expected_dtype", [(["a", "b", "c", "d"], "string"), ([1, 2, 3, 4], "int64"), ([1.0, 2.0, 3.0, 4.0], "float64")],