Commit f5826ef

Add concatenate_datasets for iterable datasets (#4500)
* add concatenate_datasets for iterable datasets
* fix
* infer features
* fill missing rows and columns
* comments
* only check for duplicate keys once
* comments
* keep concatenate_datasets in arrow_dataset (to be deprecated)
* style
* comments, typing, fix missing token_per_repo_id
* style
1 parent 6d2a970 commit f5826ef

File tree: 10 files changed, +548 −156 lines
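In user-facing terms, `concatenate_datasets` now lives in `datasets.combine` and dispatches on the input type, so iterable (streaming) datasets can be concatenated the same way map-style ones always could. A minimal sketch of the new capability — the repository names are hypothetical, and both streams are assumed to share the same features:

```py
from itertools import islice

from datasets import concatenate_datasets, load_dataset

# hypothetical repos; streaming=True returns IterableDataset objects
ds1 = load_dataset("user/corpus-part-1", split="train", streaming=True)
ds2 = load_dataset("user/corpus-part-2", split="train", streaming=True)

# before this commit, concatenate_datasets only accepted map-style Dataset objects
combined = concatenate_datasets([ds1, ds2])

# IterableDataset is lazy: examples are only read while iterating
for example in islice(iter(combined), 5):
    print(example)
```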

src/datasets/__init__.py

Lines changed: 5 additions & 4 deletions

@@ -34,10 +34,10 @@
 del pyarrow
 del version
 
-from .arrow_dataset import Dataset, concatenate_datasets
+from .arrow_dataset import Dataset
 from .arrow_reader import ReadInstruction
 from .builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
-from .combine import interleave_datasets
+from .combine import concatenate_datasets, interleave_datasets
 from .dataset_dict import DatasetDict, IterableDatasetDict
 from .download import *
 from .features import *
@@ -73,15 +73,16 @@
 
 
 # deprecated modules
+from datasets import arrow_dataset as _arrow_dataset  # isort:skip
 from datasets import utils as _utils  # isort:skip
 from datasets.utils import download_manager as _deprecated_download_manager  # isort:skip
 
-
+_arrow_dataset.concatenate_datasets = concatenate_datasets
 _utils.DownloadConfig = DownloadConfig
 _utils.DownloadManager = DownloadManager
 _utils.DownloadMode = DownloadMode
 _deprecated_download_manager.DownloadConfig = DownloadConfig
 _deprecated_download_manager.DownloadMode = DownloadMode
 _deprecated_download_manager.DownloadManager = DownloadManager
 
-del _utils, _deprecated_download_manager
+del _arrow_dataset, _utils, _deprecated_download_manager
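The `_arrow_dataset.concatenate_datasets = concatenate_datasets` assignment in the second hunk is a backward-compatibility shim: the canonical definition moves to `datasets.combine`, but the old attribute on `datasets.arrow_dataset` is patched to point at the new function. A quick check of what the shim guarantees, assuming nothing beyond the hunk itself:

```py
# canonical import path after this commit
from datasets import concatenate_datasets

# deprecated path, kept alive by the shim patched in above
from datasets.arrow_dataset import concatenate_datasets as legacy_concatenate_datasets

# both names resolve to the same function object
assert concatenate_datasets is legacy_concatenate_datasets
```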

src/datasets/arrow_dataset.py

Lines changed: 67 additions & 8 deletions

@@ -827,8 +827,8 @@ def from_dict(
         cls,
         mapping: dict,
         features: Optional[Features] = None,
-        info: Optional[Any] = None,
-        split: Optional[Any] = None,
+        info: Optional[DatasetInfo] = None,
+        split: Optional[NamedSplit] = None,
     ) -> "Dataset":
         """
         Convert :obj:`dict` to a :obj:`pyarrow.Table` to create a :class:`Dataset`.
@@ -2493,7 +2493,7 @@ def catch_non_existent_error(func, kwargs):
             ), "All shards have to be defined Datasets, none should still be missing."
 
             logger.info(f"Concatenating {num_proc} shards")
-            result = concatenate_datasets(transformed_shards)
+            result = _concatenate_map_style_datasets(transformed_shards)
             if new_fingerprint is not None:
                 result._fingerprint = new_fingerprint
             return result
@@ -4725,14 +4725,15 @@ def process_label_ids(batch):
         return self.map(process_label_ids, features=features, batched=True, desc="Aligning the labels")
 
 
-def concatenate_datasets(
+def _concatenate_map_style_datasets(
     dsets: List[Dataset],
-    info: Optional[Any] = None,
-    split: Optional[Any] = None,
+    info: Optional[DatasetInfo] = None,
+    split: Optional[NamedSplit] = None,
     axis: int = 0,
 ):
     """
     Converts a list of :class:`Dataset` with the same schema into a single :class:`Dataset`.
+    When you concatenate on axis 0, missing data are filled with None values.
 
     Args:
         dsets (:obj:`List[datasets.Dataset]`): List of Datasets to concatenate.
@@ -4747,7 +4748,7 @@ def concatenate_datasets(
     Example:
 
     ```py
-    >>> ds3 = concatenate_datasets([ds1, ds2])
+    >>> ds3 = _concatenate_map_style_datasets([ds1, ds2])
     ```
     """
     # Ignore datasets with no rows
@@ -4823,7 +4824,7 @@ def apply_offset_to_indices_table(table, offset):
     if info is None:
         info = DatasetInfo.from_merge([dset.info for dset in dsets])
     fingerprint = update_fingerprint(
-        "".join(dset._fingerprint for dset in dsets), concatenate_datasets, {"info": info, "split": split}
+        "".join(dset._fingerprint for dset in dsets), _concatenate_map_style_datasets, {"info": info, "split": split}
     )
 
     # Make final concatenated dataset
@@ -4838,6 +4839,64 @@ def apply_offset_to_indices_table(table, offset):
     return concatenated_dataset
 
 
+def _interleave_map_style_datasets(
+    datasets: List["Dataset"],
+    probabilities: Optional[List[float]] = None,
+    seed: Optional[int] = None,
+    info: Optional[DatasetInfo] = None,
+    split: Optional[NamedSplit] = None,
+    **kwargs,
+) -> "Dataset":
+    """
+    Interleave several map-style datasets (sources) into a single map-style dataset.
+    The new dataset is constructed by alternating between the sources to get the examples.
+    If `probabilities` is `None` (default) the new dataset is constructed by cycling between each source to get the examples.
+    If `probabilities` is not `None`, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.
+
+    Args:
+        datasets (:obj:`List[Dataset]`): list of datasets to interleave
+        probabilities (:obj:`List[float]`, optional, default None): If specified, the new dataset is constructed by sampling
+            examples from one source at a time according to these probabilities.
+        seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
+        info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
+        split (:class:`NamedSplit`, optional): Name of the dataset split.
+        **kwargs (additional keyword arguments): Keyword arguments to be passed to :meth:`datasets.Dataset.select` when selecting the indices used to interleave the datasets.
+
+    Output:
+        :class:`datasets.Dataset`
+    """
+
+    # To interleave the datasets, we concatenate them and then we re-order the indices
+    concatenated_datasets = _concatenate_map_style_datasets(datasets, info=info, split=split)
+
+    # Let's now build the indices to pass to .select()
+    lengths = [len(dset) for dset in datasets]
+    offsets = np.cumsum([0] + lengths[:-1])
+    if probabilities is None:
+        # Example: if the lengths of the datasets are [3, 4, 5]
+        # then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9]
+        # Note that we only have 3 examples per dataset since the first dataset ran out of examples
+        indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
+    else:
+
+        def iter_random_indices():
+            """Get an infinite iterator that randomly samples the index of the source to pick examples from."""
+            rng = np.random.default_rng(seed)
+            while True:
+                yield from (int(i) for i in rng.choice(len(datasets), size=1000, p=probabilities))
+
+        current_index = [0] * len(datasets)
+        indices = []
+        for source_idx in iter_random_indices():
+            # we ran out of examples, let's stop
+            if current_index[source_idx] >= lengths[source_idx]:
+                break
+            # let's add the example at the current index of the `source_idx`-th dataset
+            indices.append(current_index[source_idx] + offsets[source_idx])
+            current_index[source_idx] += 1
+    return concatenated_datasets.select(indices, **kwargs)
+
+
 # This is outside Dataset.filter as it needs to be picklable for multiprocessing
 
 
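The round-robin branch of `_interleave_map_style_datasets` builds its reordering with a single broadcast add. A standalone sketch of that index computation, reusing the lengths from the code comment:

```py
import numpy as np

lengths = [3, 4, 5]                      # number of rows in each source dataset
offsets = np.cumsum([0] + lengths[:-1])  # where each source starts in the concatenation: [0, 3, 7]

# one row per pass over the sources, one column per source,
# truncated at the shortest dataset (3 passes here)
indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
print(indices)  # [0, 3, 7, 1, 4, 8, 2, 5, 9]
```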

src/datasets/combine.py

Lines changed: 47 additions & 114 deletions

@@ -1,24 +1,24 @@
-from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
-
-import numpy as np
+from typing import List, Optional, TypeVar
 
+from .arrow_dataset import Dataset, _concatenate_map_style_datasets, _interleave_map_style_datasets
 from .info import DatasetInfo
+from .iterable_dataset import IterableDataset, _concatenate_iterable_datasets, _interleave_iterable_datasets
+from .splits import NamedSplit
 from .utils import logging
 
 
 logger = logging.get_logger(__name__)
 
 
-if TYPE_CHECKING:
-    from .arrow_dataset import Dataset
-    from .iterable_dataset import IterableDataset
-
-
 DatasetType = TypeVar("DatasetType", "Dataset", "IterableDataset")
 
 
 def interleave_datasets(
-    datasets: List[DatasetType], probabilities: Optional[List[float]] = None, seed: Optional[int] = None
+    datasets: List[DatasetType],
+    probabilities: Optional[List[float]] = None,
+    seed: Optional[int] = None,
+    info: Optional[DatasetInfo] = None,
+    split: Optional[NamedSplit] = None,
 ) -> DatasetType:
     """
     Interleave several datasets (sources) into a single dataset.
@@ -78,126 +78,59 @@ def interleave_datasets(
     map_style = isinstance(datasets[0], Dataset)
     if not (iterable ^ map_style):
         raise ValueError(
-            f"Expected a list Dataset objects or a list of IterableDataset objects, but first element is a {type(datasets[0])}"
+            f"Expected a list of Dataset objects or a list of IterableDataset objects, but first element is a {type(datasets[0])}"
         )
     for dataset in datasets[1:]:
         if (map_style and not isinstance(dataset, Dataset)) or (iterable and not isinstance(dataset, IterableDataset)):
             raise ValueError(
                 f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
             )
     if map_style:
-        return _interleave_map_style_datasets(datasets, probabilities, seed)
+        return _interleave_map_style_datasets(datasets, probabilities, seed, info=info, split=split)
     else:
-        return _interleave_iterable_datasets(datasets, probabilities, seed)
+        return _interleave_iterable_datasets(datasets, probabilities, seed, info=info, split=split)
 
 
-def _interleave_map_style_datasets(
-    datasets: List["Dataset"],
-    probabilities: Optional[List[float]] = None,
-    seed: Optional[int] = None,
-    info: Optional[Any] = None,
-    split: Optional[Any] = None,
-    **kwargs,
-) -> "Dataset":
+def concatenate_datasets(
+    dsets: List[Dataset],
+    info: Optional[DatasetInfo] = None,
+    split: Optional[NamedSplit] = None,
+    axis: int = 0,
+):
     """
-    Interleave several map-style datasets (sources) into a single map-style dataset.
-    The new dataset is constructed by alternating between the sources to get the examples.
-    If `probabilities = None` (default) the new dataset is constructed by cycling between each source to get the examples.
-    If `probabilities` is not `None, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.
+    Converts a list of :class:`Dataset` with the same schema into a single :class:`Dataset`.
 
     Args:
-        datasets (:obj:`List[Dataset]`): list of datasets to interleave
-        probabilities (:obj:`List[float]`, optional, default None): If specified, the new dataset is constructued by sampling
-            examples from one source at a time according to these probabilities.
-        seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
-        **kwargs (additional keyword arguments): Keyword arguments to be passed to :meth:`datasets.Datasets.select` when selecting the indices used to interleave the datasets.
+        dsets (:obj:`List[datasets.Dataset]`): List of Datasets to concatenate.
+        info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
+        split (:class:`NamedSplit`, optional): Name of the dataset split.
+        axis (``{0, 1}``, default ``0``, meaning over rows):
+            Axis to concatenate over, where ``0`` means over rows (vertically) and ``1`` means over columns
            (horizontally).
 
-    Output:
-        :class:`datasets.Dataset`
-    """
-    from .arrow_dataset import concatenate_datasets
-
-    # To interleave the datasets, we concatenate them and then we re-order the indices
-    concatenated_datasets = concatenate_datasets(datasets, info=info, split=split)
-
-    # Let's now build the indices to pass to .select()
-    lengths = [len(dset) for dset in datasets]
-    offsets = np.cumsum([0] + lengths[:-1])
-    if probabilities is None:
-        # Example:: If lengths of the datasets are [3, 4, 5]
-        # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 6, 9]
-        # Note that we only have 3 examples per dataset since the first dataset ran out of examples
-        indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
-    else:
+    *New in version 1.6.0*
 
-        def iter_random_indices():
-            """Get an infinite iterator that randomly samples the index of the source to pick examples from."""
-            rng = np.random.default_rng(seed)
-            while True:
-                yield from (int(i) for i in rng.choice(len(datasets), size=1000, p=probabilities))
-
-        current_index = [0] * len(datasets)
-        indices = []
-        for source_idx in iter_random_indices():
-            # we ran out of examples, let's stop
-            if current_index[source_idx] >= lengths[source_idx]:
-                break
-            # let's add the example at the current index of the `source_idx`-th dataset
-            indices.append(current_index[source_idx] + offsets[source_idx])
-            current_index[source_idx] += 1
-    return concatenated_datasets.select(indices, **kwargs)
-
-
-def _interleave_iterable_datasets(
-    datasets: List["IterableDataset"],
-    probabilities: Optional[List[float]] = None,
-    seed: Optional[int] = None,
-    info: Optional[Any] = None,
-    split: Optional[Any] = None,
-) -> "IterableDataset":
-    """
-    Interleave several iterable datasets (sources) into a single iterable dataset.
-    The new iterable dataset alternates between the sources to yield examples.
-    If `probabilities = None` (default) the iterable dataset will cycles through the sources in order for each next example in the iteration.
-    If `probabilities` is not `None, the iterable dataset will sample a random source according to the provided probabilities for each next examples in the iteration.
-
-    Args:
-        datasets (:obj:`List[IterableDataset]`): list of datasets to interleave
-        probabilities (:obj:`List[float]`, optional, default None): If specified, the new iterable dataset samples
-            examples from one source at a time according to these probabilities.
-        seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
+    Example:
 
-    Output:
-        :class:`datasets.IterableDataset`
+    ```py
+    >>> ds3 = concatenate_datasets([ds1, ds2])
+    ```
     """
-    from .iterable_dataset import (
-        CyclingMultiSourcesExamplesIterable,
-        RandomlyCyclingMultiSourcesExamplesIterable,
-        TypedExamplesIterable,
-        iterable_dataset,
-    )
-
-    ex_iterables = [
-        TypedExamplesIterable(d._ex_iterable, d.features)
-        if not isinstance(d._ex_iterable, TypedExamplesIterable) and d.features is not None
-        else d._ex_iterable
-        for d in datasets
-    ]
-    # Use cycling or random cycling or sources
-    if probabilities is None:
-        ex_iterable = CyclingMultiSourcesExamplesIterable(ex_iterables)
-    else:
-        generator = np.random.default_rng(seed)
-        ex_iterable = RandomlyCyclingMultiSourcesExamplesIterable(
-            ex_iterables, generator=generator, probabilities=probabilities
+
+    if not dsets:
+        raise ValueError("Unable to concatenate an empty list of datasets.")
+    iterable = isinstance(dsets[0], IterableDataset)
+    map_style = isinstance(dsets[0], Dataset)
+    if not (iterable ^ map_style):
+        raise ValueError(
+            f"Expected a list of Dataset objects or a list of IterableDataset objects, but first element is a {type(dsets[0])}"
        )
-    # Set new info - we reset the features
-    if info is None:
-        info = DatasetInfo.from_merge([d.info for d in datasets])
-    info.features = None
-    # Get all the auth tokens per repository - in case the datasets come from different private repositories
-    token_per_repo_id = {
-        repo_id: token for dataset in datasets for repo_id, token in dataset._token_per_repo_id.items()
-    }
-    # Return new daset
-    return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
+    for dataset in dsets[1:]:
+        if (map_style and not isinstance(dataset, Dataset)) or (iterable and not isinstance(dataset, IterableDataset)):
+            raise ValueError(
+                f"Unable to concatenate a {type(dsets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
+            )
+    if map_style:
+        return _concatenate_map_style_datasets(dsets, info=info, split=split, axis=axis)
+    else:
+        return _concatenate_iterable_datasets(dsets, info=info, split=split, axis=axis)
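Both public entry points now share the same shape: validate that the list is homogeneous, then dispatch to the map-style or iterable implementation. A small map-style demonstration with toy in-memory data:

```py
from datasets import Dataset, concatenate_datasets, interleave_datasets

d1 = Dataset.from_dict({"a": [0, 1, 2]})
d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})

# concatenation keeps every row, in order
print(concatenate_datasets([d1, d2])["a"])  # [0, 1, 2, 10, 11, 12, 13]

# round-robin interleaving stops once the shortest source is exhausted
print(interleave_datasets([d1, d2])["a"])   # [0, 10, 1, 11, 2, 12]
```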

src/datasets/features/audio.py

Lines changed: 6 additions & 2 deletions

@@ -101,7 +101,9 @@ def encode_example(self, value: Union[str, dict]) -> dict:
                 f"An audio sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
             )
 
-    def decode_example(self, value: dict, token_per_repo_id=None) -> dict:
+    def decode_example(
+        self, value: dict, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None
+    ) -> dict:
         """Decode example audio file into audio data.
 
         Args:
@@ -211,7 +213,9 @@ def path_to_bytes(path):
         storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null())
         return array_cast(storage, self.pa_type)
 
-    def _decode_non_mp3_path_like(self, path, format=None, token_per_repo_id=None):
+    def _decode_non_mp3_path_like(
+        self, path, format=None, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None
+    ):
         try:
             import librosa
         except ImportError as err:
src/datasets/features/features.py

Lines changed: 2 additions & 2 deletions

@@ -1222,7 +1222,7 @@ def encode_nested_example(schema, obj, level=0):
         return obj
 
 
-def decode_nested_example(schema, obj, token_per_repo_id=None):
+def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None):
     """Decode a nested example.
     This is used since some features (in particular Audio and Image) have some logic during decoding.
 
@@ -1613,7 +1613,7 @@ def encode_batch(self, batch):
             encoded_batch[key] = [encode_nested_example(self[key], obj) for obj in column]
         return encoded_batch
 
-    def decode_example(self, example: dict, token_per_repo_id=None):
+    def decode_example(self, example: dict, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None):
         """Decode example with custom feature decoding.
 
         Args:
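`Features.decode_example` forwards the same mapping into `decode_nested_example` for every column, which is how nested `Audio` and `Image` features can authenticate when they fetch files from private repos. A trivially decodable sketch that needs no token:

```py
from datasets import Features, Value

features = Features({"text": Value("string")})

# each column is run through decode_nested_example; token_per_repo_id is
# threaded down to decodable feature types (Audio, Image)
decoded = features.decode_example({"text": "hello"}, token_per_repo_id=None)
print(decoded)  # {'text': 'hello'}
```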
