Skip to content

Commit 45bd319

Browse files
committed
add shuffle
1 parent 00a4a57 commit 45bd319

1 file changed

Lines changed: 85 additions & 12 deletions

File tree

src/datasets/streaming.py

Lines changed: 85 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,12 @@
1+
import copy
12
import importlib
23
import os
4+
from dataclasses import dataclass
35
from functools import partial
46
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
57

68
import fsspec
9+
import numpy as np
710
import pyarrow as pa
811

912
from .arrow_dataset import DatasetInfoMixin
@@ -44,6 +47,50 @@ def __iter__(self):
4447
yield example
4548

4649

50+
@dataclass
class ShufflingConfig:
    """Settings describing how a streamed dataset should be shuffled.

    Attributes are forwarded to ``ShufflingBuffer`` when the dataset is
    iterated.
    """

    # Number of examples held in memory by the shuffling buffer.
    buffer_size: int
    # Base seed for the random generator; ``None`` means non-deterministic.
    # ``IterableDataset._iter`` adds the current epoch to it when it is set.
    seed: Optional[int] = None


# Backward-compatible alias: the class was originally published under this
# misspelled name, and the rest of the module still refers to it.
ShuffingConfig = ShufflingConfig
54+
55+
56+
@dataclass
class DatasetFormat:
    """Output formatting options attached to an ``IterableDataset``.

    Both fields default to ``None``, meaning "no formatting".
    """

    # Name of the output format (``with_format`` special-cases "torch").
    type: Optional[str] = None
    # Callable applied to every example before it is yielded, if set.
    transform: Optional[Callable] = None
60+
61+
62+
class ShufflingBuffer:
    """Approximately shuffle an iterable using a bounded in-memory buffer.

    The buffer is first filled with ``buffer_size`` examples. After that,
    every new example replaces a randomly chosen buffered one, which is
    yielded. When the source is exhausted, the examples still in the buffer
    are shuffled and yielded, so every input example is produced exactly once.

    Args:
        iterable: Source of examples. It is re-iterated on every ``__iter__``
            call.
        buffer_size: Maximum number of examples kept in memory. Must be >= 1.
        seed: Seed for ``numpy.random.default_rng``. With a fixed seed, each
            iteration over the same source yields the same order.

    Raises:
        ValueError: If ``buffer_size`` is smaller than 1.
    """

    def __init__(self, iterable: Iterable, buffer_size: int, seed: Optional[int] = None):
        if buffer_size < 1:
            raise ValueError(f"buffer_size must be a positive integer, got {buffer_size}.")
        self._iterable = iterable
        self._buffer_size = buffer_size
        self._seed = seed

    def _iter_random_indices(self, rng: np.random.Generator, buffer_size: int, random_batch_size=1000):
        """Yield an endless stream of random indices in ``[0, buffer_size)``.

        Indices are drawn ``random_batch_size`` at a time so the generator is
        called once per batch rather than once per example.
        """
        while True:
            yield from rng.integers(0, buffer_size, size=random_batch_size)

    def __iter__(self):
        """Yield every example of the source exactly once, in shuffled order."""
        buffer_size = self._buffer_size
        # A fresh generator per iteration keeps seeded runs reproducible.
        rng = np.random.default_rng(self._seed)
        indices_iterator = self._iter_random_indices(rng, buffer_size)
        # The buffer is local so that repeated (or concurrent) iterations do
        # not see examples left over from a previous pass.
        mem_buffer = []
        for x in self._iterable:
            if len(mem_buffer) == buffer_size:
                # Buffer is full: emit a random slot and refill it.
                i = next(indices_iterator)
                yield mem_buffer[i]
                mem_buffer[i] = x
            else:
                # Still filling the buffer.
                mem_buffer.append(x)
        # Source exhausted: drain what is left in random order. This also
        # covers sources shorter than the buffer (including empty ones).
        # ``Generator.shuffle`` works in place and returns ``None``.
        rng.shuffle(mem_buffer)
        yield from mem_buffer
92+
93+
4794
class IterableDataset(DatasetInfoMixin):
4895
"""A Dataset backed by an iterable."""
4996

@@ -52,15 +99,17 @@ def __init__(
5299
iterable: Iterable,
53100
info: Optional[DatasetInfo] = None,
54101
split: Optional[NamedSplit] = None,
55-
format: Optional[dict] = None,
102+
format: Optional[DatasetFormat] = None,
103+
shuffling: Optional[ShuffingConfig] = None,
56104
):
57105
info = info.copy() if info is not None else DatasetInfo()
58-
format = format if format is not None else {}
106+
format = format if format is not None else DatasetFormat()
59107
DatasetInfoMixin.__init__(self, info=info, split=split)
60108

61109
self._iterable = iterable
62-
self._format_type = format.get("type")
63-
self._transform = format.get("transform")
110+
self._format = format
111+
self._shuffling = shuffling
112+
self._epoch = 0
64113

65114
# Infer features if None
66115

@@ -79,20 +128,28 @@ def __init__(
79128
)
80129

81130
def _head(self, n=5):
82-
return _examples_to_batch([x for x, _ in zip(self, range(n))])
131+
return _examples_to_batch([x for x, _ in zip(self._iter(), range(n))])
83132

84-
def __iter__(self):
85-
for example in self._iterable:
86-
if self._transform is not None:
87-
yield self._transform(example)
133+
def _iter(self, epoch=0, transform: Optional[Callable] = None, shuffling: Optional[ShuffingConfig] = None):
    """Yield examples from the underlying iterable, optionally shuffled and transformed.

    Args:
        epoch: Added to ``shuffling.seed`` (when that seed is not ``None``)
            to derive the seed actually used for this pass.
        transform: If given, applied to every example before it is yielded.
        shuffling: If truthy, examples are routed through a
            ``ShufflingBuffer`` built from its ``buffer_size`` and seed.
    """
    source = self._iterable
    if shuffling:
        # Leave the seed unset when no base seed was configured.
        seed = None if shuffling.seed is None else shuffling.seed + epoch
        source = ShufflingBuffer(source, buffer_size=shuffling.buffer_size, seed=seed)
    for item in source:
        yield item if transform is None else transform(item)
90144

145+
def __iter__(self):
    """Iterate over the dataset using its stored epoch, format transform and shuffling config."""
    examples = self._iter(
        epoch=self._epoch,
        transform=self._format.transform,
        shuffling=self._shuffling,
    )
    yield from examples
147+
91148
def with_format(
92149
self,
93150
type: Optional[str] = None,
94151
transform: Optional[Callable] = None,
95-
):
152+
) -> "IterableDataset":
96153
if type == "torch":
97154
import torch
98155

@@ -104,12 +161,28 @@ class TorchIterableDataset(IterableDataset, torch.utils.data.IterableDataset):
104161
cls = IterableDataset
105162
dataset = cls(
106163
iterable=self._iterable,
107-
info=self._info,
164+
info=copy.deepcopy(self._info),
165+
split=self._split,
166+
format=DatasetFormat(type=type, transform=transform),
167+
shuffling=copy.deepcopy(self._shuffling),
168+
)
169+
return dataset
170+
171+
def shuffle(self, buffer_size, seed=None) -> "IterableDataset":
172+
shuffling = ShuffingConfig(buffer_size=buffer_size, seed=seed)
173+
cls = self.__class__
174+
dataset = cls(
175+
iterable=self._iterable,
176+
info=copy.deepcopy(self._info),
108177
split=self._split,
109-
format={"type": type, "transform": transform},
178+
format=copy.deepcopy(self._format),
179+
shuffling=shuffling,
110180
)
111181
return dataset
112182

183+
def set_epoch(self, epoch: int):
    """Record ``epoch`` on the dataset.

    The stored value is passed to ``_iter`` by ``__iter__``, where it is
    added to the shuffling seed when one is configured.
    """
    self._epoch = epoch
185+
113186

114187
class IterableDatasetDict(dict):
    """A plain ``dict`` subclass that adds no behavior of its own.

    NOTE(review): presumably maps split names to ``IterableDataset``
    objects; confirm against callers.
    """

0 commit comments

Comments
 (0)