
Commit 37bb701

add concatenate_datasets for iterable datasets

1 parent 5eac250

6 files changed: +359 −125 lines changed

src/datasets/__init__.py

Lines changed: 2 additions & 2 deletions
````diff
@@ -34,10 +34,10 @@
 del pyarrow
 del version

-from .arrow_dataset import Dataset, concatenate_datasets
+from .arrow_dataset import Dataset
 from .arrow_reader import ReadInstruction
 from .builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
-from .combine import interleave_datasets
+from .combine import concatenate_datasets, interleave_datasets
 from .dataset_dict import DatasetDict, IterableDatasetDict
 from .download import *
 from .features import *
````
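
With the re-export moved to `combine.py`, the top-level `concatenate_datasets` now dispatches on the dataset type instead of being map-style only. A minimal usage sketch, assuming two streaming datasets with the same schema (the dataset names below are placeholders, not from the commit):

```python
from itertools import islice

from datasets import concatenate_datasets, load_dataset

# Streaming loads return IterableDataset objects, which the old
# arrow_dataset.concatenate_datasets could not handle.
ds1 = load_dataset("user/dataset_a", split="train", streaming=True)
ds2 = load_dataset("user/dataset_b", split="train", streaming=True)

# Dispatches to _concatenate_iterable_datasets under the hood.
combined = concatenate_datasets([ds1, ds2])

for example in islice(iter(combined), 3):
    print(example)
```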

src/datasets/arrow_dataset.py

Lines changed: 65 additions & 7 deletions
````diff
@@ -826,8 +826,8 @@ def from_dict(
         cls,
         mapping: dict,
         features: Optional[Features] = None,
-        info: Optional[Any] = None,
-        split: Optional[Any] = None,
+        info: Optional[DatasetInfo] = None,
+        split: Optional[NamedSplit] = None,
     ) -> "Dataset":
         """
         Convert :obj:`dict` to a :obj:`pyarrow.Table` to create a :class:`Dataset`.
````
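
A minimal sketch, with toy data, of what the tightened `info`/`split` annotations describe (the values are illustrative):

```python
from datasets import Dataset, DatasetInfo, NamedSplit

# info and split are now typed as DatasetInfo / NamedSplit instead of Any.
ds = Dataset.from_dict(
    {"text": ["hello", "world"], "label": [0, 1]},
    info=DatasetInfo(description="toy example"),
    split=NamedSplit("train"),
)
print(ds.split, ds.info.description)
```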
````diff
@@ -4578,10 +4578,10 @@ def process_label_ids(batch):
         return self.map(process_label_ids, features=features, batched=True, desc="Aligning the labels")


-def concatenate_datasets(
+def _concatenate_map_style_datasets(
     dsets: List[Dataset],
-    info: Optional[Any] = None,
-    split: Optional[Any] = None,
+    info: Optional[DatasetInfo] = None,
+    split: Optional[NamedSplit] = None,
     axis: int = 0,
 ):
     """
@@ -4600,7 +4600,7 @@ def concatenate_datasets(
     Example:

     ```py
-    >>> ds3 = concatenate_datasets([ds1, ds2])
+    >>> ds3 = _concatenate_map_style_datasets([ds1, ds2])
     ```
     """
     # Ignore datasets with no rows
@@ -4676,7 +4676,7 @@ def apply_offset_to_indices_table(table, offset):
     if info is None:
         info = DatasetInfo.from_merge([dset.info for dset in dsets])
     fingerprint = update_fingerprint(
-        "".join(dset._fingerprint for dset in dsets), concatenate_datasets, {"info": info, "split": split}
+        "".join(dset._fingerprint for dset in dsets), _concatenate_map_style_datasets, {"info": info, "split": split}
     )

     # Make final concatenated dataset
@@ -4691,6 +4691,64 @@ def apply_offset_to_indices_table(table, offset):
     return concatenated_dataset


+def _interleave_map_style_datasets(
+    datasets: List["Dataset"],
+    probabilities: Optional[List[float]] = None,
+    seed: Optional[int] = None,
+    info: Optional[DatasetInfo] = None,
+    split: Optional[NamedSplit] = None,
+    **kwargs,
+) -> "Dataset":
+    """
+    Interleave several map-style datasets (sources) into a single map-style dataset.
+    The new dataset is constructed by alternating between the sources to get the examples.
+    If `probabilities = None` (default) the new dataset is constructed by cycling between each source to get the examples.
+    If `probabilities` is not `None`, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.
+
+    Args:
+        datasets (:obj:`List[Dataset]`): list of datasets to interleave
+        probabilities (:obj:`List[float]`, optional, default None): If specified, the new dataset is constructed by sampling
+            examples from one source at a time according to these probabilities.
+        seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
+        info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
+        split (:class:`NamedSplit`, optional): Name of the dataset split.
+        **kwargs (additional keyword arguments): Keyword arguments to be passed to :meth:`datasets.Dataset.select` when selecting the indices used to interleave the datasets.
+
+    Output:
+        :class:`datasets.Dataset`
+    """
+
+    # To interleave the datasets, we concatenate them and then we re-order the indices
+    concatenated_datasets = _concatenate_map_style_datasets(datasets, info=info, split=split)
+
+    # Let's now build the indices to pass to .select()
+    lengths = [len(dset) for dset in datasets]
+    offsets = np.cumsum([0] + lengths[:-1])
+    if probabilities is None:
+        # Example: if the lengths of the datasets are [3, 4, 5]
+        # then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9]
+        # Note that we only take 3 examples per dataset since the first dataset runs out of examples
+        indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
+    else:
+
+        def iter_random_indices():
+            """Get an infinite iterator that randomly samples the index of the source to pick examples from."""
+            rng = np.random.default_rng(seed)
+            while True:
+                yield from (int(i) for i in rng.choice(len(datasets), size=1000, p=probabilities))
+
+        current_index = [0] * len(datasets)
+        indices = []
+        for source_idx in iter_random_indices():
+            # we ran out of examples, let's stop
+            if current_index[source_idx] >= lengths[source_idx]:
+                break
+            # let's add the example at the current index of the `source_idx`-th dataset
+            indices.append(current_index[source_idx] + offsets[source_idx])
+            current_index[source_idx] += 1
+    return concatenated_datasets.select(indices, **kwargs)
+
+
 # This is outside Dataset.filter as it needs to be picklable for multiprocessing
````
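
The index arithmetic in `_interleave_map_style_datasets` is easy to check in isolation. Below is a standalone sketch of both branches, using the lengths from the comment in the diff above; the probabilities and seed are illustrative, not from the commit:

```python
import numpy as np

lengths = [3, 4, 5]                      # rows in each source dataset
offsets = np.cumsum([0] + lengths[:-1])  # [0, 3, 7]: where each source starts
                                         # inside the concatenated dataset

# probabilities is None: strict round-robin, one example per source per round,
# stopping after min(lengths) rounds so the shortest source is never over-read.
round_robin = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
assert round_robin == [0, 3, 7, 1, 4, 8, 2, 5, 9]

# probabilities given: draw a source per step until any source runs out,
# mirroring the iter_random_indices() loop in the diff (simplified to a
# single batch of draws instead of an infinite generator).
rng = np.random.default_rng(seed=42)
probabilities = [0.2, 0.3, 0.5]
current_index = [0] * len(lengths)
sampled = []
for source_idx in (int(i) for i in rng.choice(len(lengths), size=1000, p=probabilities)):
    if current_index[source_idx] >= lengths[source_idx]:
        break  # this source is exhausted, stop interleaving
    sampled.append(current_index[source_idx] + offsets[source_idx])
    current_index[source_idx] += 1
print(sampled)  # a prefix of the concatenated indices, in sampled order
```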

src/datasets/combine.py

Lines changed: 46 additions & 113 deletions
````diff
@@ -1,24 +1,24 @@
-from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
-
-import numpy as np
+from typing import List, Optional, TypeVar

+from .arrow_dataset import Dataset, _concatenate_map_style_datasets, _interleave_map_style_datasets
 from .info import DatasetInfo
+from .iterable_dataset import IterableDataset, _concatenate_iterable_datasets, _interleave_iterable_datasets
+from .splits import NamedSplit
 from .utils import logging


 logger = logging.get_logger(__name__)


-if TYPE_CHECKING:
-    from .arrow_dataset import Dataset
-    from .iterable_dataset import IterableDataset
-
-
 DatasetType = TypeVar("DatasetType", "Dataset", "IterableDataset")


 def interleave_datasets(
-    datasets: List[DatasetType], probabilities: Optional[List[float]] = None, seed: Optional[int] = None
+    datasets: List[DatasetType],
+    probabilities: Optional[List[float]] = None,
+    seed: Optional[int] = None,
+    info: Optional[DatasetInfo] = None,
+    split: Optional[NamedSplit] = None,
 ) -> DatasetType:
     """
     Interleave several datasets (sources) into a single dataset.
@@ -86,118 +86,51 @@ def interleave_datasets(
             f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
         )
     if map_style:
-        return _interleave_map_style_datasets(datasets, probabilities, seed)
+        return _interleave_map_style_datasets(datasets, probabilities, seed, info=info, split=split)
     else:
-        return _interleave_iterable_datasets(datasets, probabilities, seed)
+        return _interleave_iterable_datasets(datasets, probabilities, seed, info=info, split=split)


-def _interleave_map_style_datasets(
-    datasets: List["Dataset"],
-    probabilities: Optional[List[float]] = None,
-    seed: Optional[int] = None,
-    info: Optional[Any] = None,
-    split: Optional[Any] = None,
-    **kwargs,
-) -> "Dataset":
+def concatenate_datasets(
+    dsets: List[Dataset],
+    info: Optional[DatasetInfo] = None,
+    split: Optional[NamedSplit] = None,
+    axis: int = 0,
+):
     """
-    Interleave several map-style datasets (sources) into a single map-style dataset.
-    The new dataset is constructed by alternating between the sources to get the examples.
-    If `probabilities = None` (default) the new dataset is constructed by cycling between each source to get the examples.
-    If `probabilities` is not `None`, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.
+    Converts a list of :class:`Dataset` with the same schema into a single :class:`Dataset`.

     Args:
-        datasets (:obj:`List[Dataset]`): list of datasets to interleave
-        probabilities (:obj:`List[float]`, optional, default None): If specified, the new dataset is constructed by sampling
-            examples from one source at a time according to these probabilities.
-        seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
-        **kwargs (additional keyword arguments): Keyword arguments to be passed to :meth:`datasets.Dataset.select` when selecting the indices used to interleave the datasets.
-
-    Output:
-        :class:`datasets.Dataset`
-    """
-    from .arrow_dataset import concatenate_datasets
-
-    # To interleave the datasets, we concatenate them and then we re-order the indices
-    concatenated_datasets = concatenate_datasets(datasets, info=info, split=split)
-
-    # Let's now build the indices to pass to .select()
-    lengths = [len(dset) for dset in datasets]
-    offsets = np.cumsum([0] + lengths[:-1])
-    if probabilities is None:
-        # Example: if the lengths of the datasets are [3, 4, 5]
-        # then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9]
-        # Note that we only take 3 examples per dataset since the first dataset runs out of examples
-        indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
-    else:
+        dsets (:obj:`List[datasets.Dataset]`): List of Datasets to concatenate.
+        info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
+        split (:class:`NamedSplit`, optional): Name of the dataset split.
+        axis (``{0, 1}``, default ``0``, meaning over rows):
+            Axis to concatenate over, where ``0`` means over rows (vertically) and ``1`` means over columns
+            (horizontally).

-        def iter_random_indices():
-            """Get an infinite iterator that randomly samples the index of the source to pick examples from."""
-            rng = np.random.default_rng(seed)
-            while True:
-                yield from (int(i) for i in rng.choice(len(datasets), size=1000, p=probabilities))
-
-        current_index = [0] * len(datasets)
-        indices = []
-        for source_idx in iter_random_indices():
-            # we ran out of examples, let's stop
-            if current_index[source_idx] >= lengths[source_idx]:
-                break
-            # let's add the example at the current index of the `source_idx`-th dataset
-            indices.append(current_index[source_idx] + offsets[source_idx])
-            current_index[source_idx] += 1
-        return concatenated_datasets.select(indices, **kwargs)
-
-
-def _interleave_iterable_datasets(
-    datasets: List["IterableDataset"],
-    probabilities: Optional[List[float]] = None,
-    seed: Optional[int] = None,
-    info: Optional[Any] = None,
-    split: Optional[Any] = None,
-) -> "IterableDataset":
-    """
-    Interleave several iterable datasets (sources) into a single iterable dataset.
-    The new iterable dataset alternates between the sources to yield examples.
-    If `probabilities = None` (default) the iterable dataset cycles through the sources in order for each next example in the iteration.
-    If `probabilities` is not `None`, the iterable dataset samples a random source according to the provided probabilities for each next example in the iteration.
+    *New in version 1.6.0*

-    Args:
-        datasets (:obj:`List[IterableDataset]`): list of datasets to interleave
-        probabilities (:obj:`List[float]`, optional, default None): If specified, the new iterable dataset samples
-            examples from one source at a time according to these probabilities.
-        seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
+    Example:

-    Output:
-        :class:`datasets.IterableDataset`
+    ```py
+    >>> ds3 = concatenate_datasets([ds1, ds2])
+    ```
     """
-    from .iterable_dataset import (
-        CyclingMultiSourcesExamplesIterable,
-        RandomlyCyclingMultiSourcesExamplesIterable,
-        TypedExamplesIterable,
-        iterable_dataset,
-    )
-
-    ex_iterables = [
-        TypedExamplesIterable(d._ex_iterable, d.features)
-        if not isinstance(d._ex_iterable, TypedExamplesIterable) and d.features is not None
-        else d._ex_iterable
-        for d in datasets
-    ]
-    # Use cycling or random cycling of sources
-    if probabilities is None:
-        ex_iterable = CyclingMultiSourcesExamplesIterable(ex_iterables)
-    else:
-        generator = np.random.default_rng(seed)
-        ex_iterable = RandomlyCyclingMultiSourcesExamplesIterable(
-            ex_iterables, generator=generator, probabilities=probabilities
+
+    if not dsets:
+        raise ValueError("Unable to concatenate an empty list of datasets.")
+    iterable = isinstance(dsets[0], IterableDataset)
+    map_style = isinstance(dsets[0], Dataset)
+    if not (iterable ^ map_style):
+        raise ValueError(
+            f"Expected a list of Dataset objects or a list of IterableDataset objects, but first element is a {type(dsets[0])}"
         )
-    # Set new info - we reset the features
-    if info is None:
-        info = DatasetInfo.from_merge([d.info for d in datasets])
-    info.features = None
-    # Get all the auth tokens per repository - in case the datasets come from different private repositories
-    token_per_repo_id = {
-        repo_id: token for dataset in datasets for repo_id, token in dataset._token_per_repo_id.items()
-    }
-    # Return new dataset
-    return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
+    for dataset in dsets[1:]:
+        if (map_style and not isinstance(dataset, Dataset)) or (iterable and not isinstance(dataset, IterableDataset)):
+            raise ValueError(
+                f"Unable to concatenate a {type(dsets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
+            )
+    if map_style:
+        return _concatenate_map_style_datasets(dsets, info=info, split=split, axis=axis)
+    else:
+        return _concatenate_iterable_datasets(dsets, info=info, split=split, axis=axis)
````
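
Taken together, both public entry points now share the same type-dispatch pattern. A quick usage sketch with in-memory map-style datasets (the toy values are illustrative, not from the commit):

```python
from datasets import Dataset, concatenate_datasets, interleave_datasets

ds1 = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
ds2 = Dataset.from_dict({"text": ["b1", "b2", "b3", "b4"]})

# Round-robin interleaving: alternates sources until the shortest runs out.
mixed = interleave_datasets([ds1, ds2])
print(mixed["text"])  # ['a1', 'b1', 'a2', 'b2', 'a3', 'b3']

# Weighted interleaving: samples a source per example with the given probabilities.
weighted = interleave_datasets([ds1, ds2], probabilities=[0.3, 0.7], seed=42)

# Plain concatenation over rows (axis=0, the default).
stacked = concatenate_datasets([ds1, ds2])
print(len(stacked))  # 7
```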
