diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py index d9ada4cc3fd..6ac94b2ec64 100644 --- a/src/datasets/__init__.py +++ b/src/datasets/__init__.py @@ -34,10 +34,10 @@ del pyarrow del version -from .arrow_dataset import Dataset, concatenate_datasets +from .arrow_dataset import Dataset from .arrow_reader import ReadInstruction from .builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder -from .combine import interleave_datasets +from .combine import concatenate_datasets, interleave_datasets from .dataset_dict import DatasetDict, IterableDatasetDict from .download import * from .features import * @@ -73,10 +73,11 @@ # deprecated modules +from datasets import arrow_dataset as _arrow_dataset # isort:skip from datasets import utils as _utils # isort:skip from datasets.utils import download_manager as _deprecated_download_manager # isort:skip - +_arrow_dataset.concatenate_datasets = concatenate_datasets _utils.DownloadConfig = DownloadConfig _utils.DownloadManager = DownloadManager _utils.DownloadMode = DownloadMode @@ -84,4 +85,4 @@ _deprecated_download_manager.DownloadMode = DownloadMode _deprecated_download_manager.DownloadManager = DownloadManager -del _utils, _deprecated_download_manager +del _arrow_dataset, _utils, _deprecated_download_manager diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index faf165a94ee..2f035ac57b7 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -827,8 +827,8 @@ def from_dict( cls, mapping: dict, features: Optional[Features] = None, - info: Optional[Any] = None, - split: Optional[Any] = None, + info: Optional[DatasetInfo] = None, + split: Optional[NamedSplit] = None, ) -> "Dataset": """ Convert :obj:`dict` to a :obj:`pyarrow.Table` to create a :class:`Dataset`. @@ -2493,7 +2493,7 @@ def catch_non_existent_error(func, kwargs): ), "All shards have to be defined Datasets, none should still be missing." logger.info(f"Concatenating {num_proc} shards") - result = concatenate_datasets(transformed_shards) + result = _concatenate_map_style_datasets(transformed_shards) if new_fingerprint is not None: result._fingerprint = new_fingerprint return result @@ -4725,14 +4725,15 @@ def process_label_ids(batch): return self.map(process_label_ids, features=features, batched=True, desc="Aligning the labels") -def concatenate_datasets( +def _concatenate_map_style_datasets( dsets: List[Dataset], - info: Optional[Any] = None, - split: Optional[Any] = None, + info: Optional[DatasetInfo] = None, + split: Optional[NamedSplit] = None, axis: int = 0, ): """ Converts a list of :class:`Dataset` with the same schema into a single :class:`Dataset`. + When you concatenate on axis 0, missing data are filled with None values. Args: dsets (:obj:`List[datasets.Dataset]`): List of Datasets to concatenate. @@ -4747,7 +4748,7 @@ def concatenate_datasets( Example: ```py - >>> ds3 = concatenate_datasets([ds1, ds2]) + >>> ds3 = _concatenate_map_style_datasets([ds1, ds2]) ``` """ # Ignore datasets with no rows @@ -4823,7 +4824,7 @@ def apply_offset_to_indices_table(table, offset): if info is None: info = DatasetInfo.from_merge([dset.info for dset in dsets]) fingerprint = update_fingerprint( - "".join(dset._fingerprint for dset in dsets), concatenate_datasets, {"info": info, "split": split} + "".join(dset._fingerprint for dset in dsets), _concatenate_map_style_datasets, {"info": info, "split": split} ) # Make final concatenated dataset @@ -4838,6 +4839,64 @@ def apply_offset_to_indices_table(table, offset): return concatenated_dataset +def _interleave_map_style_datasets( + datasets: List["Dataset"], + probabilities: Optional[List[float]] = None, + seed: Optional[int] = None, + info: Optional[DatasetInfo] = None, + split: Optional[NamedSplit] = None, + **kwargs, +) -> "Dataset": + """ + Interleave several map-style datasets (sources) into a single map-style dataset. + The new dataset is constructed by alternating between the sources to get the examples. + If `probabilities = None` (default) the new dataset is constructed by cycling between each source to get the examples. + If `probabilities` is not `None, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities. + + Args: + datasets (:obj:`List[Dataset]`): list of datasets to interleave + probabilities (:obj:`List[float]`, optional, default None): If specified, the new dataset is constructued by sampling + examples from one source at a time according to these probabilities. + seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example. + info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc. + split (:class:`NamedSplit`, optional): Name of the dataset split. + **kwargs (additional keyword arguments): Keyword arguments to be passed to :meth:`datasets.Datasets.select` when selecting the indices used to interleave the datasets. + + Output: + :class:`datasets.Dataset` + """ + + # To interleave the datasets, we concatenate them and then we re-order the indices + concatenated_datasets = _concatenate_map_style_datasets(datasets, info=info, split=split) + + # Let's now build the indices to pass to .select() + lengths = [len(dset) for dset in datasets] + offsets = np.cumsum([0] + lengths[:-1]) + if probabilities is None: + # Example:: If lengths of the datasets are [3, 4, 5] + # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 6, 9] + # Note that we only have 3 examples per dataset since the first dataset ran out of examples + indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist() + else: + + def iter_random_indices(): + """Get an infinite iterator that randomly samples the index of the source to pick examples from.""" + rng = np.random.default_rng(seed) + while True: + yield from (int(i) for i in rng.choice(len(datasets), size=1000, p=probabilities)) + + current_index = [0] * len(datasets) + indices = [] + for source_idx in iter_random_indices(): + # we ran out of examples, let's stop + if current_index[source_idx] >= lengths[source_idx]: + break + # let's add the example at the current index of the `source_idx`-th dataset + indices.append(current_index[source_idx] + offsets[source_idx]) + current_index[source_idx] += 1 + return concatenated_datasets.select(indices, **kwargs) + + # This is outside Dataset.filter as it needs to be picklable for multiprocessing diff --git a/src/datasets/combine.py b/src/datasets/combine.py index 49c8d4ff8a3..4f997877840 100644 --- a/src/datasets/combine.py +++ b/src/datasets/combine.py @@ -1,24 +1,24 @@ -from typing import TYPE_CHECKING, Any, List, Optional, TypeVar - -import numpy as np +from typing import List, Optional, TypeVar +from .arrow_dataset import Dataset, _concatenate_map_style_datasets, _interleave_map_style_datasets from .info import DatasetInfo +from .iterable_dataset import IterableDataset, _concatenate_iterable_datasets, _interleave_iterable_datasets +from .splits import NamedSplit from .utils import logging logger = logging.get_logger(__name__) -if TYPE_CHECKING: - from .arrow_dataset import Dataset - from .iterable_dataset import IterableDataset - - DatasetType = TypeVar("DatasetType", "Dataset", "IterableDataset") def interleave_datasets( - datasets: List[DatasetType], probabilities: Optional[List[float]] = None, seed: Optional[int] = None + datasets: List[DatasetType], + probabilities: Optional[List[float]] = None, + seed: Optional[int] = None, + info: Optional[DatasetInfo] = None, + split: Optional[NamedSplit] = None, ) -> DatasetType: """ Interleave several datasets (sources) into a single dataset. @@ -78,7 +78,7 @@ def interleave_datasets( map_style = isinstance(datasets[0], Dataset) if not (iterable ^ map_style): raise ValueError( - f"Expected a list Dataset objects or a list of IterableDataset objects, but first element is a {type(datasets[0])}" + f"Expected a list of Dataset objects or a list of IterableDataset objects, but first element is a {type(datasets[0])}" ) for dataset in datasets[1:]: if (map_style and not isinstance(dataset, Dataset)) or (iterable and not isinstance(dataset, IterableDataset)): @@ -86,118 +86,51 @@ def interleave_datasets( f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects." ) if map_style: - return _interleave_map_style_datasets(datasets, probabilities, seed) + return _interleave_map_style_datasets(datasets, probabilities, seed, info=info, split=split) else: - return _interleave_iterable_datasets(datasets, probabilities, seed) + return _interleave_iterable_datasets(datasets, probabilities, seed, info=info, split=split) -def _interleave_map_style_datasets( - datasets: List["Dataset"], - probabilities: Optional[List[float]] = None, - seed: Optional[int] = None, - info: Optional[Any] = None, - split: Optional[Any] = None, - **kwargs, -) -> "Dataset": +def concatenate_datasets( + dsets: List[Dataset], + info: Optional[DatasetInfo] = None, + split: Optional[NamedSplit] = None, + axis: int = 0, +): """ - Interleave several map-style datasets (sources) into a single map-style dataset. - The new dataset is constructed by alternating between the sources to get the examples. - If `probabilities = None` (default) the new dataset is constructed by cycling between each source to get the examples. - If `probabilities` is not `None, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities. + Converts a list of :class:`Dataset` with the same schema into a single :class:`Dataset`. Args: - datasets (:obj:`List[Dataset]`): list of datasets to interleave - probabilities (:obj:`List[float]`, optional, default None): If specified, the new dataset is constructued by sampling - examples from one source at a time according to these probabilities. - seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example. - **kwargs (additional keyword arguments): Keyword arguments to be passed to :meth:`datasets.Datasets.select` when selecting the indices used to interleave the datasets. + dsets (:obj:`List[datasets.Dataset]`): List of Datasets to concatenate. + info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc. + split (:class:`NamedSplit`, optional): Name of the dataset split. + axis (``{0, 1}``, default ``0``, meaning over rows): + Axis to concatenate over, where ``0`` means over rows (vertically) and ``1`` means over columns + (horizontally). - Output: - :class:`datasets.Dataset` - """ - from .arrow_dataset import concatenate_datasets - - # To interleave the datasets, we concatenate them and then we re-order the indices - concatenated_datasets = concatenate_datasets(datasets, info=info, split=split) - - # Let's now build the indices to pass to .select() - lengths = [len(dset) for dset in datasets] - offsets = np.cumsum([0] + lengths[:-1]) - if probabilities is None: - # Example:: If lengths of the datasets are [3, 4, 5] - # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 6, 9] - # Note that we only have 3 examples per dataset since the first dataset ran out of examples - indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist() - else: + *New in version 1.6.0* - def iter_random_indices(): - """Get an infinite iterator that randomly samples the index of the source to pick examples from.""" - rng = np.random.default_rng(seed) - while True: - yield from (int(i) for i in rng.choice(len(datasets), size=1000, p=probabilities)) - - current_index = [0] * len(datasets) - indices = [] - for source_idx in iter_random_indices(): - # we ran out of examples, let's stop - if current_index[source_idx] >= lengths[source_idx]: - break - # let's add the example at the current index of the `source_idx`-th dataset - indices.append(current_index[source_idx] + offsets[source_idx]) - current_index[source_idx] += 1 - return concatenated_datasets.select(indices, **kwargs) - - -def _interleave_iterable_datasets( - datasets: List["IterableDataset"], - probabilities: Optional[List[float]] = None, - seed: Optional[int] = None, - info: Optional[Any] = None, - split: Optional[Any] = None, -) -> "IterableDataset": - """ - Interleave several iterable datasets (sources) into a single iterable dataset. - The new iterable dataset alternates between the sources to yield examples. - If `probabilities = None` (default) the iterable dataset will cycles through the sources in order for each next example in the iteration. - If `probabilities` is not `None, the iterable dataset will sample a random source according to the provided probabilities for each next examples in the iteration. - - Args: - datasets (:obj:`List[IterableDataset]`): list of datasets to interleave - probabilities (:obj:`List[float]`, optional, default None): If specified, the new iterable dataset samples - examples from one source at a time according to these probabilities. - seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example. + Example: - Output: - :class:`datasets.IterableDataset` + ```py + >>> ds3 = concatenate_datasets([ds1, ds2]) + ``` """ - from .iterable_dataset import ( - CyclingMultiSourcesExamplesIterable, - RandomlyCyclingMultiSourcesExamplesIterable, - TypedExamplesIterable, - iterable_dataset, - ) - - ex_iterables = [ - TypedExamplesIterable(d._ex_iterable, d.features) - if not isinstance(d._ex_iterable, TypedExamplesIterable) and d.features is not None - else d._ex_iterable - for d in datasets - ] - # Use cycling or random cycling or sources - if probabilities is None: - ex_iterable = CyclingMultiSourcesExamplesIterable(ex_iterables) - else: - generator = np.random.default_rng(seed) - ex_iterable = RandomlyCyclingMultiSourcesExamplesIterable( - ex_iterables, generator=generator, probabilities=probabilities + + if not dsets: + raise ValueError("Unable to concatenate an empty list of datasets.") + iterable = isinstance(dsets[0], IterableDataset) + map_style = isinstance(dsets[0], Dataset) + if not (iterable ^ map_style): + raise ValueError( + f"Expected a list of Dataset objects or a list of IterableDataset objects, but first element is a {type(dsets[0])}" ) - # Set new info - we reset the features - if info is None: - info = DatasetInfo.from_merge([d.info for d in datasets]) - info.features = None - # Get all the auth tokens per repository - in case the datasets come from different private repositories - token_per_repo_id = { - repo_id: token for dataset in datasets for repo_id, token in dataset._token_per_repo_id.items() - } - # Return new daset - return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id) + for dataset in dsets[1:]: + if (map_style and not isinstance(dataset, Dataset)) or (iterable and not isinstance(dataset, IterableDataset)): + raise ValueError( + f"Unable to concatenate a {type(dsets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects." + ) + if map_style: + return _concatenate_map_style_datasets(dsets, info=info, split=split, axis=axis) + else: + return _concatenate_iterable_datasets(dsets, info=info, split=split, axis=axis) diff --git a/src/datasets/features/audio.py b/src/datasets/features/audio.py index 77378a2235b..f3d7da75ef3 100644 --- a/src/datasets/features/audio.py +++ b/src/datasets/features/audio.py @@ -101,7 +101,9 @@ def encode_example(self, value: Union[str, dict]) -> dict: f"An audio sample should have one of 'path' or 'bytes' but they are missing or None in {value}." ) - def decode_example(self, value: dict, token_per_repo_id=None) -> dict: + def decode_example( + self, value: dict, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None + ) -> dict: """Decode example audio file into audio data. Args: @@ -211,7 +213,9 @@ def path_to_bytes(path): storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null()) return array_cast(storage, self.pa_type) - def _decode_non_mp3_path_like(self, path, format=None, token_per_repo_id=None): + def _decode_non_mp3_path_like( + self, path, format=None, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None + ): try: import librosa except ImportError as err: diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 7de00a9da75..71a63a0f9d2 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1222,7 +1222,7 @@ def encode_nested_example(schema, obj, level=0): return obj -def decode_nested_example(schema, obj, token_per_repo_id=None): +def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None): """Decode a nested example. This is used since some features (in particular Audio and Image) have some logic during decoding. @@ -1613,7 +1613,7 @@ def encode_batch(self, batch): encoded_batch[key] = [encode_nested_example(self[key], obj) for obj in column] return encoded_batch - def decode_example(self, example: dict, token_per_repo_id=None): + def decode_example(self, example: dict, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None): """Decode example with custom feature decoding. Args: diff --git a/src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py b/src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py index cb874342a3c..33ce181aaad 100644 --- a/src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py +++ b/src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py @@ -1,7 +1,7 @@ import fsspec import torch -from ...iterable_dataset import IterableDataset +from ...iterable_dataset import IterableDataset, _apply_feature_types from ...utils.logging import get_logger @@ -46,7 +46,12 @@ def __iter__(self): ) for shard_idx in shards_indices: for key, example in self._iter_shard(shard_idx): - yield self._apply_feature_types(example) + if self.features: + yield _apply_feature_types( + example, self.features, token_per_repo_id=self._token_per_repo_id + ) + else: + yield example logger.debug( f"dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{self.n_shards} shards." ) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 4c51c83c90c..ea6feb2c894 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1,4 +1,6 @@ import copy +import itertools +from collections import Counter from copy import deepcopy from dataclasses import dataclass from itertools import cycle, islice @@ -8,7 +10,7 @@ import pyarrow as pa from .arrow_dataset import DatasetInfoMixin -from .features import Features +from .features import Features, Value from .features.features import FeatureType from .formatting import PythonFormatter from .info import DatasetInfo @@ -170,6 +172,112 @@ def shard_data_sources(self, shard_idx: int) -> "CyclingMultiSourcesExamplesIter raise NotImplementedError("Sharding a CyclingMultiSourcesExamplesIterable is not implemented") +class VerticallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable): + """ + VerticallyConcatenatedMultiSourcesExamplesIterable simply chains the input iterables. + It doesn't require the examples iterables to always yield the same columns. + Instead, this is handled by the `IterableDataset` class or `TypedExamplesIterable`. + + For information, `IterableDataset` merges the features of all the datasets to concatenate into one. + We use `IterableDataset._resolve_features` to obtain the features of all the datasets to concatenate. + + Then for each example, `IterableDataset` and `TypedExamplesIterable` automatically fill missing columns with None. + This is done with `_apply_feature_types`. + """ + + def __init__(self, ex_iterables: List[_BaseExamplesIterable]): + self.ex_iterables = ex_iterables + + def __iter__(self): + for ex_iterable in self.ex_iterables: + yield from ex_iterable + + def shuffle_data_sources( + self, generator: np.random.Generator + ) -> "VerticallyConcatenatedMultiSourcesExamplesIterable": + """Shuffle the list of examples iterable, as well as each underlying examples iterable.""" + rng = deepcopy(generator) + ex_iterables = list(self.ex_iterables) + rng.shuffle(ex_iterables) + ex_iterables = [ex_iterable.shuffle_data_sources(generator) for ex_iterable in ex_iterables] + return VerticallyConcatenatedMultiSourcesExamplesIterable(ex_iterables) + + @property + def n_shards(self) -> int: + return sum(ex_iterable.n_shards for ex_iterable in self.ex_iterables) + + def shard_data_sources(self, shard_idx: int) -> "VerticallyConcatenatedMultiSourcesExamplesIterable": + """Either keep only the requested shard, or propagate the request to the underlying iterable.""" + raise NotImplementedError("Sharding a VerticallyConcatenatedMultiSourcesExamplesIterable is not implemented") + + +def _check_column_names(column_names: List[str]): + """Check the column names to make sure they don't contain duplicates.""" + counter = Counter(column_names) + if not all(count == 1 for count in counter.values()): + duplicated_columns = [col for col in counter if counter[col] > 1] + raise ValueError( + f"The examples iterables can't have duplicated columns but columns {duplicated_columns} are duplicated." + ) + + +class HorizontallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable): + """ + HorizontallyConcatenatedMultiSourcesExamplesIterable merges examples together for the input list of iterables. + It also checks that there are no duplicate columns (otherwise we don't know which one to keep). + This check is done once when yielding the first example. + + However it doesn't fill missing columns with None. + Instead, this is handled by the `IterableDataset` class or `TypedExamplesIterable`. + + For information, `IterableDataset` merges the features of all the datasets to concatenate into one. + We use `IterableDataset._resolve_features` to obtain the features of all the datasets to concatenate. + + Then for each example, `IterableDataset` and `TypedExamplesIterable` automatically fill missing columns with None. + This is done with `_apply_feature_types`. + """ + + def __init__(self, ex_iterables: List[_BaseExamplesIterable]): + self.ex_iterables = ex_iterables + + def __iter__(self): + ex_iterators = [iter(ex_iterable) for ex_iterable in self.ex_iterables] + for i in itertools.count(): + keys = [] + examples = [] + for ex_iterator in list(ex_iterators): + try: + key, example = next(ex_iterator) + keys.append(key) + examples.append(example) + except StopIteration: + ex_iterators.remove(ex_iterator) + if ex_iterators: + if i == 0: + _check_column_names([column_name for example in examples for column_name in example]) + new_example = {} + for example in examples: + new_example.update(example) + new_key = "_".join(str(key) for key in keys) + yield new_key, new_example + else: + break + + def shuffle_data_sources( + self, generator: np.random.Generator + ) -> "HorizontallyConcatenatedMultiSourcesExamplesIterable": + """Doesn't shuffle the wrapped examples iterable since it would break the alignment between them.""" + return self + + @property + def n_shards(self) -> int: + return 1 + + def shard_data_sources(self, shard_idx: int) -> "HorizontallyConcatenatedMultiSourcesExamplesIterable": + """Either keep only the requested shard, or propagate the request to the underlying iterable.""" + raise NotImplementedError("Sharding a HorizontallyConcatenatedMultiSourcesExamplesIterable is not implemented") + + class RandomlyCyclingMultiSourcesExamplesIterable(CyclingMultiSourcesExamplesIterable): def __init__(self, ex_iterables, generator: np.random.Generator, probabilities: Optional[List[float]] = None): super().__init__(ex_iterables) @@ -484,24 +592,44 @@ def n_shards(self) -> int: return self.ex_iterable.n_shards +def _apply_feature_types( + example: dict, features: Features, token_per_repo_id: Dict[str, Union[str, bool, None]] +) -> dict: + example = dict(example) + # add missing columns + for column_name in features: + if column_name not in example: + example[column_name] = None + # we encode the example for ClassLabel feature types for example + encoded_example = features.encode_example(example) + # Decode example for Audio feature, e.g. + decoded_example = features.decode_example(encoded_example, token_per_repo_id=token_per_repo_id) + return decoded_example + + class TypedExamplesIterable(_BaseExamplesIterable): - def __init__(self, ex_iterable: _BaseExamplesIterable, features: Features): + def __init__( + self, + ex_iterable: _BaseExamplesIterable, + features: Features, + token_per_repo_id: Dict[str, Union[str, bool, None]], + ): self.ex_iterable = ex_iterable self.features = features + self.token_per_repo_id = token_per_repo_id def __iter__(self): + # Then for each example, `TypedExamplesIterable` automatically fills missing columns with None. + # This is done with `_apply_feature_types`. for key, example in self.ex_iterable: - # we encode the example for ClassLabel feature types for example - encoded_example = self.features.encode_example(example) - # Decode example for Audio feature, e.g. - decoded_example = self.features.decode_example(encoded_example) - yield key, decoded_example + yield key, _apply_feature_types(example, self.features, token_per_repo_id=self.token_per_repo_id) def shuffle_data_sources(self, generator: np.random.Generator) -> "TypedExamplesIterable": """Shuffle the wrapped examples iterable.""" return TypedExamplesIterable( self.ex_iterable.shuffle_data_sources(generator), features=self.features, + token_per_repo_id=self.token_per_repo_id, ) def shard_data_sources(self, shard_idx: int) -> "TypedExamplesIterable": @@ -509,6 +637,7 @@ def shard_data_sources(self, shard_idx: int) -> "TypedExamplesIterable": return TypedExamplesIterable( self.ex_iterable.shard_data_sources(shard_idx), features=self.features, + token_per_repo_id=self.token_per_repo_id, ) @property @@ -551,7 +680,7 @@ def __init__( self._format_type = format_type self._shuffling = shuffling self._epoch = 0 - self._token_per_repo_id = token_per_repo_id or {} + self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {} def _head(self, n=5): return _examples_to_batch([x for key, x in islice(self._iter(), n)]) @@ -585,19 +714,14 @@ def _iter_shard(self, shard_idx: int): ex_iterable = self._ex_iterable yield from ex_iterable.shard_data_sources(shard_idx) - def _apply_feature_types(self, example): - if self.features: - # we encode the example for ClassLabel feature types for example - encoded_example = self.features.encode_example(example) - # Decode example for Audio feature, e.g. - decoded_example = self.features.decode_example(encoded_example, token_per_repo_id=self._token_per_repo_id) - return decoded_example - else: - return example - def __iter__(self): for key, example in self._iter(): - yield self._apply_feature_types(example) + if self.features: + # `IterableDataset` automatically fills missing columns with None. + # This is done with `_apply_feature_types`. + yield _apply_feature_types(example, self.features, token_per_repo_id=self._token_per_repo_id) + else: + yield example def with_format( self, @@ -699,7 +823,7 @@ def map( info = self._info.copy() info.features = None ex_iterable = MappedExamplesIterable( - TypedExamplesIterable(self._ex_iterable, self._info.features) + TypedExamplesIterable(self._ex_iterable, self._info.features, token_per_repo_id=self._token_per_repo_id) if self._info.features is not None else self._ex_iterable, function=function, @@ -768,7 +892,7 @@ def filter( # We need the examples to be decoded for certain feature types like Image or Audio, so we use TypedExamplesIterable here ex_iterable = FilteredExamplesIterable( - TypedExamplesIterable(self._ex_iterable, self._info.features) + TypedExamplesIterable(self._ex_iterable, self._info.features, token_per_repo_id=self._token_per_repo_id) if self._info.features is not None else self._ex_iterable, function=function, @@ -1128,6 +1252,24 @@ def cast( token_per_repo_id=self._token_per_repo_id, ) + def _resolve_features(self): + if self.features is not None: + return self + elif isinstance(self._ex_iterable, TypedExamplesIterable): + features = self._ex_iterable.features + else: + features = _infer_features_from_batch(self._head()) + info = self.info.copy() + info.features = features + return iterable_dataset( + ex_iterable=self._ex_iterable, + info=info, + split=self._split, + format_type=self._format_type, + shuffling=copy.deepcopy(self._shuffling), + token_per_repo_id=self._token_per_repo_id, + ) + def iterable_dataset( ex_iterable: Iterable, @@ -1151,3 +1293,128 @@ def iterable_dataset( shuffling=shuffling, token_per_repo_id=token_per_repo_id, ) + + +def _check_if_features_can_be_aligned(features_list: List[Features]): + """Check if the dictionaries of features can be aligned. + + Two dictonaries of features can be aligned if the keys they share have the same type or some of them is of type `Value("null")`. + """ + name2feature = {} + for features in features_list: + for k, v in features.items(): + if k not in name2feature or (isinstance(name2feature[k], Value) and name2feature[k].dtype == "null"): + name2feature[k] = v + + for features in features_list: + for k, v in features.items(): + if not (isinstance(v, Value) and v.dtype == "null") and name2feature[k] != v: + raise ValueError( + f'The features can\'t be aligned because the key {k} of features {features} has unexpected type - {v} (expected either {name2feature[k]} or Value("null").' + ) + + +def _concatenate_iterable_datasets( + dsets: List[IterableDataset], + info: Optional[DatasetInfo] = None, + split: Optional[NamedSplit] = None, + axis: int = 0, +) -> IterableDataset: + """ + Converts a list of :class:`IterableDataset` with the same schema into a single :class:`IterableDataset`. + Missing data are filled with None values. + + Args: + dsets (:obj:`List[datasets.IterableDataset]`): List of Datasets to concatenate. + info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc. + split (:class:`NamedSplit`, optional): Name of the dataset split. + axis (``{0, 1}``, default ``0``, meaning over rows): + Axis to concatenate over, where ``0`` means over rows (vertically) and ``1`` means over columns + (horizontally). + + *New in version 1.6.0* + + Example: + + ```py + >>> ds3 = _concatenate_iterable_datasets([ds1, ds2]) + ``` + """ + dsets = [d._resolve_features() for d in dsets] + + # Perform checks (and a potentional cast if axis=0) + if axis == 0: + _check_if_features_can_be_aligned([dset.features for dset in dsets]) + else: + _check_column_names([col_name for dset in dsets for col_name in dset.features]) + + features = Features() + for dset in dsets: + features.update(dset.features) + + ex_iterables = [d._ex_iterable for d in dsets] + if axis == 0: + ex_iterable = VerticallyConcatenatedMultiSourcesExamplesIterable(ex_iterables) + else: + ex_iterable = HorizontallyConcatenatedMultiSourcesExamplesIterable(ex_iterables) + # Set new info - we update the features + # setting the features also ensures to fill missing columns with None + if info is None: + info = DatasetInfo.from_merge([d.info for d in dsets]) + else: + info = info.copy() + info.features = features + # Get all the auth tokens per repository - in case the datasets come from different private repositories + token_per_repo_id = {repo_id: token for dataset in dsets for repo_id, token in dataset._token_per_repo_id.items()} + # Return new daset + return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id) + + +def _interleave_iterable_datasets( + datasets: List[IterableDataset], + probabilities: Optional[List[float]] = None, + seed: Optional[int] = None, + info: Optional[DatasetInfo] = None, + split: Optional[NamedSplit] = None, +) -> IterableDataset: + """ + Interleave several iterable datasets (sources) into a single iterable dataset. + The new iterable dataset alternates between the sources to yield examples. + If `probabilities = None` (default) the iterable dataset will cycles through the sources in order for each next example in the iteration. + If `probabilities` is not `None, the iterable dataset will sample a random source according to the provided probabilities for each next examples in the iteration. + + Args: + datasets (:obj:`List[IterableDataset]`): list of datasets to interleave + probabilities (:obj:`List[float]`, optional, default None): If specified, the new iterable dataset samples + examples from one source at a time according to these probabilities. + seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example. + + Output: + :class:`datasets.IterableDataset` + """ + # TODO(QL): merge the features as in _concatenate_iterable_datasets() and don't use TypedExamplesIterable + ex_iterables = [ + TypedExamplesIterable(d._ex_iterable, d.features, token_per_repo_id=d._token_per_repo_id) + if not isinstance(d._ex_iterable, TypedExamplesIterable) and d.features is not None + else d._ex_iterable + for d in datasets + ] + # Use cycling or random cycling or sources + if probabilities is None: + ex_iterable = CyclingMultiSourcesExamplesIterable(ex_iterables) + else: + generator = np.random.default_rng(seed) + ex_iterable = RandomlyCyclingMultiSourcesExamplesIterable( + ex_iterables, generator=generator, probabilities=probabilities + ) + # Set new info - we reset the features + # TODO(QL): merge the features as in _concatenate_iterable_datasets() and use them here + if info is None: + info = DatasetInfo.from_merge([d.info for d in datasets]) + info.features = None + # Get all the auth tokens per repository - in case the datasets come from different private repositories + token_per_repo_id = { + repo_id: token for dataset in datasets for repo_id, token in dataset._token_per_repo_id.items() + } + # Return new daset + return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id) diff --git a/src/datasets/table.py b/src/datasets/table.py index 439e03a784d..b5ef48dddec 100644 --- a/src/datasets/table.py +++ b/src/datasets/table.py @@ -1232,6 +1232,9 @@ class ConcatenationTable(Table): The first axis concatenates the tables along the axis 0 (it appends rows), while the second axis concatenates tables along the axis 1 (it appends columns). + If some columns are missing when concatenating on axis 0, they are filled with null values. + This is done using `pyarrow.concat_tables(tables, promote=True)`. + You can access the fully combined table by accessing the ConcatenationTable.table attribute, and the blocks by accessing the ConcatenationTable.blocks attribute. """ @@ -1261,6 +1264,7 @@ def __setstate__(self, state): def _concat_blocks(blocks: List[Union[TableBlock, pa.Table]], axis: int = 0) -> pa.Table: pa_tables = [table.table if hasattr(table, "table") else table for table in blocks] if axis == 0: + # we set promote=True to fill missing columns with null values return pa.concat_tables(pa_tables, promote=True) elif axis == 1: for i, table in enumerate(pa_tables): diff --git a/tests/features/test_image.py b/tests/features/test_image.py index fa4f4a4cb17..469b6b82ae3 100644 --- a/tests/features/test_image.py +++ b/tests/features/test_image.py @@ -6,8 +6,7 @@ import pyarrow as pa import pytest -from datasets import Dataset, Features, Image, Sequence, Value, load_dataset -from datasets.arrow_dataset import concatenate_datasets +from datasets import Dataset, Features, Image, Sequence, Value, concatenate_datasets, load_dataset from datasets.features.image import image_to_bytes from ..utils import require_pil diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 5eb680670d4..6debc4fa750 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -6,7 +6,7 @@ import pytest from datasets import load_dataset -from datasets.combine import interleave_datasets +from datasets.combine import concatenate_datasets, interleave_datasets from datasets.features import ClassLabel, Features, Value from datasets.info import DatasetInfo from datasets.iterable_dataset import ( @@ -14,6 +14,7 @@ CyclingMultiSourcesExamplesIterable, ExamplesIterable, FilteredExamplesIterable, + HorizontallyConcatenatedMultiSourcesExamplesIterable, IterableDataset, MappedExamplesIterable, RandomlyCyclingMultiSourcesExamplesIterable, @@ -21,6 +22,7 @@ SkipExamplesIterable, TakeExamplesIterable, TypedExamplesIterable, + VerticallyConcatenatedMultiSourcesExamplesIterable, _batch_to_examples, _examples_to_batch, iterable_dataset, @@ -468,6 +470,52 @@ def test_take_examples_iterable(): ), "skip examples makes the shards order fixed" +def test_vertically_concatenated_examples_iterable(): + ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label": 10}) + ex_iterable2 = ExamplesIterable(generate_examples_fn, {"label": 5}) + concatenated_ex_iterable = VerticallyConcatenatedMultiSourcesExamplesIterable([ex_iterable1, ex_iterable2]) + expected = list(x for _, x in ex_iterable1) + list(x for _, x in ex_iterable2) + assert list(x for _, x in concatenated_ex_iterable) == expected + + +def test_vertically_concatenated_examples_iterable_with_different_columns(): + # having different columns is supported + # Though iterable datasets fill the missing data with nulls + ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label": 10}) + ex_iterable2 = ExamplesIterable(generate_examples_fn, {}) + concatenated_ex_iterable = VerticallyConcatenatedMultiSourcesExamplesIterable([ex_iterable1, ex_iterable2]) + expected = list(x for _, x in ex_iterable1) + list(x for _, x in ex_iterable2) + assert list(x for _, x in concatenated_ex_iterable) == expected + + +def test_vertically_concatenated_examples_iterable_shuffle_data_sources(): + ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label": 10}) + ex_iterable2 = ExamplesIterable(generate_examples_fn, {"label": 5}) + concatenated_ex_iterable = VerticallyConcatenatedMultiSourcesExamplesIterable([ex_iterable1, ex_iterable2]) + rng = np.random.default_rng(42) + shuffled_ex_iterable = concatenated_ex_iterable.shuffle_data_sources(rng) + # make sure the list of examples iterables is shuffled, and each examples iterable is shuffled + expected = list(x for _, x in ex_iterable2.shuffle_data_sources(rng)) + list( + x for _, x in ex_iterable1.shuffle_data_sources(rng) + ) + assert list(x for _, x in shuffled_ex_iterable) == expected + + +def test_horizontally_concatenated_examples_iterable(): + ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label1": 10}) + ex_iterable2 = ExamplesIterable(generate_examples_fn, {"label2": 5}) + concatenated_ex_iterable = HorizontallyConcatenatedMultiSourcesExamplesIterable([ex_iterable1, ex_iterable2]) + with pytest.raises(ValueError): # column "id" is duplicated -> raise an error + list(concatenated_ex_iterable) + ex_iterable2 = MappedExamplesIterable(ex_iterable2, lambda x: x, remove_columns=["id"]) + concatenated_ex_iterable = HorizontallyConcatenatedMultiSourcesExamplesIterable([ex_iterable1, ex_iterable2]) + expected = list({**x, **y} for (_, x), (_, y) in zip(ex_iterable1, ex_iterable2)) + assert list(x for _, x in concatenated_ex_iterable) == expected + assert ( + concatenated_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is concatenated_ex_iterable + ), "horizontally concatenated examples makes the shards order fixed" + + ############################ # # IterableDataset tests @@ -811,6 +859,78 @@ def test_iterable_dataset_cast(): assert list(casted_dataset) == [new_features.encode_example(ex) for _, ex in ex_iterable] +def test_concatenate_datasets(): + ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label": 10}) + dataset1 = IterableDataset(ex_iterable1) + ex_iterable2 = ExamplesIterable(generate_examples_fn, {"label": 5}) + dataset2 = IterableDataset(ex_iterable2) + concatenated_dataset = concatenate_datasets([dataset1, dataset2]) + assert list(concatenated_dataset) == list(dataset1) + list(dataset2) + + +def test_concatenate_datasets_resolves_features(): + ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label": 10}) + dataset1 = IterableDataset(ex_iterable1) + ex_iterable2 = ExamplesIterable(generate_examples_fn, {"label": 5}) + dataset2 = IterableDataset(ex_iterable2) + concatenated_dataset = concatenate_datasets([dataset1, dataset2]) + assert concatenated_dataset.features is not None + assert sorted(concatenated_dataset.features) == ["id", "label"] + + +def test_concatenate_datasets_with_different_columns(): + ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label": 10}) + dataset1 = IterableDataset(ex_iterable1) + ex_iterable2 = ExamplesIterable(generate_examples_fn, {}) + dataset2 = IterableDataset(ex_iterable2) + # missing column "label" -> it should be replaced with nulls + extended_dataset2_list = [{"label": None, **x} for x in dataset2] + + concatenated_dataset = concatenate_datasets([dataset1, dataset2]) + assert list(concatenated_dataset) == list(dataset1) + extended_dataset2_list + # change order + concatenated_dataset = concatenate_datasets([dataset2, dataset1]) + assert list(concatenated_dataset) == extended_dataset2_list + list(dataset1) + + +def test_concatenate_datasets_axis_1(): + ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label1": 10}) + dataset1 = IterableDataset(ex_iterable1) + ex_iterable2 = ExamplesIterable(generate_examples_fn, {"label2": 5}) + dataset2 = IterableDataset(ex_iterable2) + with pytest.raises(ValueError): # column "id" is duplicated -> raise an error + concatenate_datasets([dataset1, dataset2], axis=1) + concatenated_dataset = concatenate_datasets([dataset1, dataset2.remove_columns("id")], axis=1) + assert list(concatenated_dataset) == [{**x, **y} for x, y in zip(dataset1, dataset2)] + + +def test_concatenate_datasets_axis_1_resolves_features(): + ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label1": 10}) + dataset1 = IterableDataset(ex_iterable1) + ex_iterable2 = ExamplesIterable(generate_examples_fn, {"label2": 5}) + dataset2 = IterableDataset(ex_iterable2).remove_columns("id") + concatenated_dataset = concatenate_datasets([dataset1, dataset2], axis=1) + assert concatenated_dataset.features is not None + assert sorted(concatenated_dataset.features) == ["id", "label1", "label2"] + + +def test_concatenate_datasets_axis_1_with_different_lengths(): + n1 = 10 + ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label1": 10, "n": n1}) + dataset1 = IterableDataset(ex_iterable1) + n2 = 5 + ex_iterable2 = ExamplesIterable(generate_examples_fn, {"label2": 5, "n": n2}) + dataset2 = IterableDataset(ex_iterable2).remove_columns("id") + # missing rows -> they should be replaced with nulls + extended_dataset2_list = list(dataset2) + [{"label2": None}] * (n1 - n2) + + concatenated_dataset = concatenate_datasets([dataset1, dataset2], axis=1) + assert list(concatenated_dataset) == [{**x, **y} for x, y in zip(dataset1, extended_dataset2_list)] + # change order + concatenated_dataset = concatenate_datasets([dataset2, dataset1], axis=1) + assert list(concatenated_dataset) == [{**x, **y} for x, y in zip(extended_dataset2_list, dataset1)] + + @pytest.mark.parametrize( "probas, seed, expected_length", [