|
1 | | -from typing import TYPE_CHECKING, Any, List, Optional, TypeVar |
2 | | - |
3 | | -import numpy as np |
| 1 | +from typing import List, Optional, TypeVar |
4 | 2 |
|
| 3 | +from .arrow_dataset import Dataset, _concatenate_map_style_datasets, _interleave_map_style_datasets |
5 | 4 | from .info import DatasetInfo |
| 5 | +from .iterable_dataset import IterableDataset, _concatenate_iterable_datasets, _interleave_iterable_datasets |
| 6 | +from .splits import NamedSplit |
6 | 7 | from .utils import logging |
7 | 8 |
|
8 | 9 |
|
9 | 10 | logger = logging.get_logger(__name__) |
10 | 11 |
|
11 | 12 |
|
12 | | -if TYPE_CHECKING: |
13 | | - from .arrow_dataset import Dataset |
14 | | - from .iterable_dataset import IterableDataset |
15 | | - |
16 | | - |
17 | 13 | DatasetType = TypeVar("DatasetType", "Dataset", "IterableDataset") |
18 | 14 |
|
19 | 15 |
|
20 | 16 | def interleave_datasets( |
21 | | - datasets: List[DatasetType], probabilities: Optional[List[float]] = None, seed: Optional[int] = None |
| 17 | + datasets: List[DatasetType], |
| 18 | + probabilities: Optional[List[float]] = None, |
| 19 | + seed: Optional[int] = None, |
| 20 | + info: Optional[DatasetInfo] = None, |
| 21 | + split: Optional[NamedSplit] = None, |
22 | 22 | ) -> DatasetType: |
23 | 23 | """ |
24 | 24 | Interleave several datasets (sources) into a single dataset. |
@@ -86,118 +86,51 @@ def interleave_datasets( |
86 | 86 | f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects." |
87 | 87 | ) |
88 | 88 | if map_style: |
89 | | - return _interleave_map_style_datasets(datasets, probabilities, seed) |
| 89 | + return _interleave_map_style_datasets(datasets, probabilities, seed, info=info, split=split) |
90 | 90 | else: |
91 | | - return _interleave_iterable_datasets(datasets, probabilities, seed) |
| 91 | + return _interleave_iterable_datasets(datasets, probabilities, seed, info=info, split=split) |
92 | 92 |
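For readers skimming the diff, here is a minimal usage sketch of the updated `interleave_datasets` entry point; the toy datasets and the `description`/split values are illustrative, not from this PR:

```python
from datasets import Dataset, DatasetInfo, NamedSplit, interleave_datasets

d1 = Dataset.from_dict({"a": [0, 1, 2]})
d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})

# Round-robin interleaving; stops once the shortest source is exhausted.
# `info` and `split` are now forwarded to the underlying implementation.
ds = interleave_datasets(
    [d1, d2],
    info=DatasetInfo(description="toy example"),
    split=NamedSplit("train"),
)
print(ds["a"])  # [0, 10, 1, 11, 2, 12]

# Probability-weighted sampling of sources, reproducible via `seed`.
ds = interleave_datasets([d1, d2], probabilities=[0.7, 0.3], seed=42)
```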
|
93 | 93 |
|
94 | | -def _interleave_map_style_datasets( |
95 | | - datasets: List["Dataset"], |
96 | | - probabilities: Optional[List[float]] = None, |
97 | | - seed: Optional[int] = None, |
98 | | - info: Optional[Any] = None, |
99 | | - split: Optional[Any] = None, |
100 | | - **kwargs, |
101 | | -) -> "Dataset": |
| 94 | +def concatenate_datasets( |
| 95 | + dsets: List[Dataset], |
| 96 | + info: Optional[DatasetInfo] = None, |
| 97 | + split: Optional[NamedSplit] = None, |
| 98 | + axis: int = 0, |
| 99 | +): |
102 | 100 | """ |
103 | | - Interleave several map-style datasets (sources) into a single map-style dataset. |
104 | | - The new dataset is constructed by alternating between the sources to get the examples. |
105 | | - If `probabilities = None` (default) the new dataset is constructed by cycling between each source to get the examples. |
106 | | - If `probabilities` is not `None`, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities. |
| 101 | + Converts a list of :class:`Dataset` with the same schema into a single :class:`Dataset`. |
107 | 102 |
|
108 | 103 | Args: |
109 | | - datasets (:obj:`List[Dataset]`): list of datasets to interleave |
110 | | - probabilities (:obj:`List[float]`, optional, default None): If specified, the new dataset is constructed by sampling |
111 | | - examples from one source at a time according to these probabilities. |
112 | | - seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example. |
113 | | - **kwargs (additional keyword arguments): Keyword arguments to be passed to :meth:`datasets.Dataset.select` when selecting the indices used to interleave the datasets. |
114 | | -
|
115 | | - Output: |
116 | | - :class:`datasets.Dataset` |
117 | | - """ |
118 | | - from .arrow_dataset import concatenate_datasets |
119 | | - |
120 | | - # To interleave the datasets, we concatenate them and then we re-order the indices |
121 | | - concatenated_datasets = concatenate_datasets(datasets, info=info, split=split) |
122 | | - |
123 | | - # Let's now build the indices to pass to .select() |
124 | | - lengths = [len(dset) for dset in datasets] |
125 | | - offsets = np.cumsum([0] + lengths[:-1]) |
126 | | - if probabilities is None: |
127 | | - # Example:: If lengths of the datasets are [3, 4, 5] |
128 | | - # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9] |
129 | | - # Note that we only take 3 examples per dataset since the first (shortest) dataset runs out of examples |
130 | | - indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist() |
131 | | - else: |
| 104 | + dsets (:obj:`List[datasets.Dataset]`): List of Datasets to concatenate. |
| 105 | + info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc. |
| 106 | + split (:class:`NamedSplit`, optional): Name of the dataset split. |
| 107 | + axis (``{0, 1}``, default ``0``, meaning over rows): |
| 108 | + Axis to concatenate over, where ``0`` means over rows (vertically) and ``1`` means over columns |
| 109 | + (horizontally). |
132 | 110 |
|
133 | | - def iter_random_indices(): |
134 | | - """Get an infinite iterator that randomly samples the index of the source to pick examples from.""" |
135 | | - rng = np.random.default_rng(seed) |
136 | | - while True: |
137 | | - yield from (int(i) for i in rng.choice(len(datasets), size=1000, p=probabilities)) |
138 | | - |
139 | | - current_index = [0] * len(datasets) |
140 | | - indices = [] |
141 | | - for source_idx in iter_random_indices(): |
142 | | - # we ran out of examples, let's stop |
143 | | - if current_index[source_idx] >= lengths[source_idx]: |
144 | | - break |
145 | | - # let's add the example at the current index of the `source_idx`-th dataset |
146 | | - indices.append(current_index[source_idx] + offsets[source_idx]) |
147 | | - current_index[source_idx] += 1 |
148 | | - return concatenated_datasets.select(indices, **kwargs) |
149 | | - |
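The index construction that this PR moves into `arrow_dataset.py` (as `_interleave_map_style_datasets`) is compact enough to illustrate standalone; a sketch using the lengths from the comment above:

```python
import numpy as np

# For source lengths [3, 4, 5], interleaving concatenated rows round-robin:
lengths = [3, 4, 5]
offsets = np.cumsum([0] + lengths[:-1])  # row offset of each source: [0, 3, 7]
# One row per pass, one column per source, truncated at the shortest source.
indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
print(indices)  # [0, 3, 7, 1, 4, 8, 2, 5, 9]
```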
150 | | - |
151 | | -def _interleave_iterable_datasets( |
152 | | - datasets: List["IterableDataset"], |
153 | | - probabilities: Optional[List[float]] = None, |
154 | | - seed: Optional[int] = None, |
155 | | - info: Optional[Any] = None, |
156 | | - split: Optional[Any] = None, |
157 | | -) -> "IterableDataset": |
158 | | - """ |
159 | | - Interleave several iterable datasets (sources) into a single iterable dataset. |
160 | | - The new iterable dataset alternates between the sources to yield examples. |
161 | | - If `probabilities = None` (default) the iterable dataset cycles through the sources in order for each next example in the iteration. |
162 | | - If `probabilities` is not `None`, the iterable dataset samples a random source according to the provided probabilities for each next example in the iteration. |
| 111 | + *New in version 1.6.0* |
163 | 112 |
|
164 | | - Args: |
165 | | - datasets (:obj:`List[IterableDataset]`): list of datasets to interleave |
166 | | - probabilities (:obj:`List[float]`, optional, default None): If specified, the new iterable dataset samples |
167 | | - examples from one source at a time according to these probabilities. |
168 | | - seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example. |
| 113 | + Example: |
169 | 114 |
|
170 | | - Output: |
171 | | - :class:`datasets.IterableDataset` |
| 115 | + ```py |
| 116 | + >>> ds3 = concatenate_datasets([ds1, ds2]) |
| 117 | + ``` |
172 | 118 | """ |
173 | | - from .iterable_dataset import ( |
174 | | - CyclingMultiSourcesExamplesIterable, |
175 | | - RandomlyCyclingMultiSourcesExamplesIterable, |
176 | | - TypedExamplesIterable, |
177 | | - iterable_dataset, |
178 | | - ) |
179 | | - |
180 | | - ex_iterables = [ |
181 | | - TypedExamplesIterable(d._ex_iterable, d.features) |
182 | | - if not isinstance(d._ex_iterable, TypedExamplesIterable) and d.features is not None |
183 | | - else d._ex_iterable |
184 | | - for d in datasets |
185 | | - ] |
186 | | - # Use cycling or random cycling over the sources |
187 | | - if probabilities is None: |
188 | | - ex_iterable = CyclingMultiSourcesExamplesIterable(ex_iterables) |
189 | | - else: |
190 | | - generator = np.random.default_rng(seed) |
191 | | - ex_iterable = RandomlyCyclingMultiSourcesExamplesIterable( |
192 | | - ex_iterables, generator=generator, probabilities=probabilities |
| 119 | + |
| 120 | + if not dsets: |
| 121 | + raise ValueError("Unable to concatenate an empty list of datasets.") |
| 122 | + iterable = isinstance(dsets[0], IterableDataset) |
| 123 | + map_style = isinstance(dsets[0], Dataset) |
| 124 | + if not (iterable ^ map_style): |
| 125 | + raise ValueError( |
| 126 | + f"Expected a list of Dataset objects or a list of IterableDataset objects, but first element is a {type(dsets[0])}" |
193 | 127 | ) |
194 | | - # Set new info - we reset the features |
195 | | - if info is None: |
196 | | - info = DatasetInfo.from_merge([d.info for d in datasets]) |
197 | | - info.features = None |
198 | | - # Get all the auth tokens per repository - in case the datasets come from different private repositories |
199 | | - token_per_repo_id = { |
200 | | - repo_id: token for dataset in datasets for repo_id, token in dataset._token_per_repo_id.items() |
201 | | - } |
202 | | - # Return the new dataset |
203 | | - return iterable_dataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id) |
| 128 | + for dataset in dsets[1:]: |
| 129 | + if (map_style and not isinstance(dataset, Dataset)) or (iterable and not isinstance(dataset, IterableDataset)): |
| 130 | + raise ValueError( |
| 131 | + f"Unable to concatenate a {type(dsets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects." |
| 132 | + ) |
| 133 | + if map_style: |
| 134 | + return _concatenate_map_style_datasets(dsets, info=info, split=split, axis=axis) |
| 135 | + else: |
| 136 | + return _concatenate_iterable_datasets(dsets, info=info, split=split, axis=axis) |
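A minimal usage sketch of the refactored `concatenate_datasets`, covering both axes; the toy datasets are illustrative:

```python
from datasets import Dataset, concatenate_datasets

ds1 = Dataset.from_dict({"a": [1, 2], "b": ["x", "y"]})
ds2 = Dataset.from_dict({"a": [3, 4], "b": ["z", "w"]})

# axis=0 (default): stack rows; the datasets must share a schema.
rows = concatenate_datasets([ds1, ds2])
print(len(rows))  # 4

# axis=1 (since 1.6.0): join columns side by side; row counts must match.
ds3 = Dataset.from_dict({"c": [10, 20]})
cols = concatenate_datasets([ds1, ds3], axis=1)
print(cols.column_names)  # ['a', 'b', 'c']
```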