Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/process.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ In this case, the new dataset is constructed by getting examples one by one from
You can also specify the `stopping_strategy`. The default strategy, `first_exhausted`, is a subsampling strategy, i.e. the dataset construction is stopped as soon as one of the datasets runs out of samples.
You can specify `stopping_strategy=all_exhausted` to execute an oversampling strategy. In this case, the dataset construction is stopped as soon as every sample in every dataset has been added at least once. In practice, it means that if a dataset is exhausted, it will return to the beginning of this dataset until the stop criterion has been reached.
Note that if no sampling probabilities are specified, the new dataset will have `max_length_datasets * nb_dataset` samples.
There is also `stopping_strategy=all_exhausted_without_replacement` to ensure that every sample is seen exactly once.

```py
>>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
Expand Down
1 change: 1 addition & 0 deletions docs/source/stream.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ Around 80% of the final dataset is made of the `es_dataset`, and 20% of the `fr_
You can also specify the `stopping_strategy`. The default strategy, `first_exhausted`, is a subsampling strategy, i.e. the dataset construction is stopped as soon as one of the datasets runs out of samples.
You can specify `stopping_strategy=all_exhausted` to execute an oversampling strategy. In this case, the dataset construction is stopped as soon as every sample in every dataset has been added at least once. In practice, it means that if a dataset is exhausted, it will return to the beginning of this dataset until the stop criterion has been reached.
Note that if no sampling probabilities are specified, the new dataset will have `max_length_datasets * nb_dataset` samples.
There is also `stopping_strategy=all_exhausted_without_replacement` to ensure that every sample is seen exactly once.

## Rename, remove, and cast

Expand Down
21 changes: 15 additions & 6 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6566,7 +6566,9 @@ def _interleave_map_style_datasets(
seed: Optional[int] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
stopping_strategy: Literal[
"first_exhausted", "all_exhausted", "all_exhausted_without_replacement"
] = "first_exhausted",
**kwargs,
) -> "Dataset":
"""
Expand All @@ -6586,6 +6588,7 @@ def _interleave_map_style_datasets(
Two strategies are proposed right now.
By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
When the strategy is `all_exhausted_without_replacement`, we make sure that each sample in each dataset is sampled only once.
Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
- with no probabilities, the resulting dataset will have max_length_datasets*nb_dataset samples.
- with given probabilities, the resulting dataset will have more samples if some datasets have really low probability of visiting.
Expand All @@ -6594,7 +6597,7 @@ def _interleave_map_style_datasets(
Output:
:class:`datasets.Dataset`
"""
if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
if stopping_strategy not in ["first_exhausted", "all_exhausted", "all_exhausted_without_replacement"]:
raise ValueError(
f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}"
)
Expand Down Expand Up @@ -6637,7 +6640,9 @@ def _interleave_map_style_datasets(

# if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
# if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
bool_strategy_func = np.all if oversampling else np.any
bool_strategy_func = (
np.all if (oversampling or stopping_strategy == "all_exhausted_without_replacement") else np.any
)

def iter_random_indices():
"""Get an infinite iterator that randomly samples the index of the source to pick examples from."""
Expand All @@ -6655,13 +6660,17 @@ def iter_random_indices():
break

# let's add the example at the current index of the `source_idx`-th dataset
indices.append(current_index[source_idx] + offsets[source_idx])
current_index[source_idx] += 1
# For sampling without replacement we additionally need to make sure the current source is not exhausted, so that we do not oversample.
if stopping_strategy != "all_exhausted_without_replacement" or not is_exhausted[source_idx]:
indices.append(current_index[source_idx] + offsets[source_idx])
current_index[source_idx] += 1

# we've run out of examples for the current dataset, let's update our boolean array and bring the current_index back to 0
if current_index[source_idx] >= lengths[source_idx]:
is_exhausted[source_idx] = True
current_index[source_idx] = 0
# We don't want to reset the iterator when stopping strategy is without replacement.
if stopping_strategy != "all_exhausted_without_replacement":
current_index[source_idx] = 0

return concatenated_datasets.select(indices, **kwargs)

Expand Down
16 changes: 12 additions & 4 deletions src/datasets/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def interleave_datasets(
seed: Optional[int] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
stopping_strategy: Literal[
"first_exhausted", "all_exhausted", "all_exhausted_without_replacement"
] = "first_exhausted",
) -> DatasetType:
"""
Interleave several datasets (sources) into a single dataset.
Expand Down Expand Up @@ -55,9 +57,10 @@ def interleave_datasets(
Name of the dataset split.
<Added version="2.4.0"/>
stopping_strategy (`str`, defaults to `first_exhausted`):
Two strategies are proposed right now, `first_exhausted` and `all_exhausted`.
Three strategies are proposed right now, `first_exhausted`, `all_exhausted` and `all_exhausted_without_replacement`.
By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
When the strategy is `all_exhausted_without_replacement`, we make sure that each sample in each dataset is sampled only once.
Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
- with no probabilities, the resulting dataset will have `max_length_datasets*nb_dataset` samples.
- with given probabilities, the resulting dataset will have more samples if some datasets have really low probability of visiting.
Expand Down Expand Up @@ -143,15 +146,20 @@ def interleave_datasets(
raise ValueError(
f"Unable to interleave a {dataset_type.__name__} (at position 0) with a {other_type.__name__} (at position {i}). Expected a list of Dataset objects or a list of IterableDataset objects."
)
if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
if stopping_strategy not in ["first_exhausted", "all_exhausted", "all_exhausted_without_replacement"]:
raise ValueError(f"{stopping_strategy} is not supported. Please enter a valid stopping_strategy.")
if dataset_type is Dataset:
return _interleave_map_style_datasets(
datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
)
else:
return _interleave_iterable_datasets(
datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
datasets,
probabilities,
seed,
info=info,
split=split,
stopping_strategy=stopping_strategy,
)


Expand Down
113 changes: 84 additions & 29 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,15 +673,20 @@ class CyclingMultiSourcesExamplesIterable(_BaseExamplesIterable):
def __init__(
self,
ex_iterables: list[_BaseExamplesIterable],
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
stopping_strategy: Literal[
"first_exhausted", "all_exhausted", "all_exhausted_without_replacement"
] = "first_exhausted",
):
super().__init__()
self.ex_iterables = ex_iterables
self.stopping_strategy = stopping_strategy

# if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
# if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
self.bool_strategy_func = np.all if (stopping_strategy == "all_exhausted") else np.any
# if sampling without replacement ("all_exhausted_without_replacement"), we stop once all samples of every dataset have been visited exactly once.
self.bool_strategy_func = (
np.all if (stopping_strategy in ("all_exhausted", "all_exhausted_without_replacement")) else np.any
)

@property
def is_typed(self):
Expand Down Expand Up @@ -734,6 +739,9 @@ def _iter_arrow(self):
# if the stopping criteria is met, break the main for loop
if self.bool_strategy_func(is_exhausted):
break
# Skip exhausted iterators if we sample without replacement
if is_exhausted[i] and self.stopping_strategy in ["all_exhausted_without_replacement"]:
continue
# let's pick one example from the iterator at index i
if nexts[i] is None:
nexts[i] = next(iterators[i], False)
Expand All @@ -747,12 +755,13 @@ def _iter_arrow(self):
is_exhausted[i] = True
if self._state_dict:
self._state_dict["is_exhausted"][i] = True
# we reset it in case the stopping criterion isn't met yet
nexts[i] = None
if self._state_dict:
self._state_dict["ex_iterables"][i] = self.ex_iterables[i]._init_state_dict()
self._state_dict["previous_states"][i] = None
iterators[i] = self.ex_iterables[i].iter_arrow()
# we reset it in case the stopping criterion isn't met yet and we sample with replacement
if self.stopping_strategy not in ["all_exhausted_without_replacement"]:
nexts[i] = None
if self._state_dict:
self._state_dict["ex_iterables"][i] = self.ex_iterables[i]._init_state_dict()
self._state_dict["previous_states"][i] = None
iterators[i] = self.ex_iterables[i]._iter_arrow()

if result is not False:
yield result
Expand All @@ -777,6 +786,8 @@ def __iter__(self):
if self.bool_strategy_func(is_exhausted):
break
# let's pick one example from the iterator at index i
if is_exhausted[i] and self.stopping_strategy in ["all_exhausted_without_replacement"]:
continue
if nexts[i] is None:
nexts[i] = next(iterators[i], False)
result = nexts[i]
Expand All @@ -790,12 +801,12 @@ def __iter__(self):
if self._state_dict:
self._state_dict["is_exhausted"][i] = True
# we reset it in case the stopping criterion isn't met yet
nexts[i] = None
if self._state_dict:
self._state_dict["ex_iterables"][i] = self.ex_iterables[i]._init_state_dict()
self._state_dict["previous_states"][i] = None
iterators[i] = iter(self.ex_iterables[i])

if self.stopping_strategy not in ["all_exhausted_without_replacement"]:
nexts[i] = None
if self._state_dict:
self._state_dict["ex_iterables"][i] = self.ex_iterables[i]._init_state_dict()
self._state_dict["previous_states"][i] = None
iterators[i] = iter(self.ex_iterables[i])
if result is not False:
yield result

Expand All @@ -806,16 +817,33 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "CyclingMultiS

@property
def num_shards(self) -> int:
return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables)
return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables) if self.ex_iterables else 0

def shard_data_sources(
self, num_shards: int, index: int, contiguous=True
) -> "CyclingMultiSourcesExamplesIterable":
"""Either keep only the requested shard, or propagate the request to the underlying iterable."""
return CyclingMultiSourcesExamplesIterable(
[iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables],
stopping_strategy=self.stopping_strategy,
)
if num_shards < self.num_shards:
return CyclingMultiSourcesExamplesIterable(
[
iterable.shard_data_sources(num_shards, index, contiguous=contiguous)
for iterable in self.ex_iterables
],
stopping_strategy=self.stopping_strategy,
)
elif index < self.num_shards:
return CyclingMultiSourcesExamplesIterable(
[
iterable.shard_data_sources(self.num_shards, index, contiguous=contiguous)
for iterable in self.ex_iterables
],
stopping_strategy=self.stopping_strategy,
)
else:
return CyclingMultiSourcesExamplesIterable(
[],
stopping_strategy=self.stopping_strategy,
)


class VerticallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable):
Expand Down Expand Up @@ -987,12 +1015,13 @@ def __init__(
ex_iterables: list[_BaseExamplesIterable],
generator: np.random.Generator,
probabilities: Optional[list[float]] = None,
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
stopping_strategy: Literal[
"first_exhausted", "all_exhausted", "all_exhausted_without_replacement"
] = "first_exhausted",
):
super().__init__(ex_iterables, stopping_strategy)
self.generator = deepcopy(generator)
self.probabilities = probabilities
# TODO(QL): implement iter_arrow

@property
def is_typed(self):
Expand Down Expand Up @@ -1056,12 +1085,33 @@ def shard_data_sources(
self, num_shards: int, index: int, contiguous=True
) -> "RandomlyCyclingMultiSourcesExamplesIterable":
"""Either keep only the requested shard, or propagate the request to the underlying iterable."""
return RandomlyCyclingMultiSourcesExamplesIterable(
[iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables],
self.generator,
self.probabilities,
self.stopping_strategy,
)
if num_shards < self.num_shards:
return RandomlyCyclingMultiSourcesExamplesIterable(
[
iterable.shard_data_sources(num_shards, index, contiguous=contiguous)
for iterable in self.ex_iterables
],
self.generator,
self.probabilities,
self.stopping_strategy,
)
elif index < self.num_shards:
return RandomlyCyclingMultiSourcesExamplesIterable(
[
iterable.shard_data_sources(self.num_shards, index, contiguous=contiguous)
for iterable in self.ex_iterables
],
self.generator,
self.probabilities,
self.stopping_strategy,
)
else:
return RandomlyCyclingMultiSourcesExamplesIterable(
[],
self.generator,
self.probabilities,
self.stopping_strategy,
)


def _table_output_to_arrow(output) -> pa.Table:
Expand Down Expand Up @@ -4489,7 +4539,9 @@ def _interleave_iterable_datasets(
seed: Optional[int] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
stopping_strategy: Literal[
"first_exhausted", "all_exhausted", "all_exhausted_without_replacement"
] = "first_exhausted",
) -> IterableDataset:
"""
Interleave several iterable datasets (sources) into a single iterable dataset.
Expand Down Expand Up @@ -4535,7 +4587,10 @@ def _interleave_iterable_datasets(
else:
generator = np.random.default_rng(seed)
ex_iterable = RandomlyCyclingMultiSourcesExamplesIterable(
ex_iterables, generator=generator, probabilities=probabilities, stopping_strategy=stopping_strategy
ex_iterables,
generator=generator,
probabilities=probabilities,
stopping_strategy=stopping_strategy,
)
# Set new info - we update the features
# setting the features also ensures to fill missing columns with None
Expand Down