1+ import copy
12import importlib
23import os
4+ from dataclasses import dataclass
35from functools import partial
46from typing import Any , Callable , Dict , Iterable , List , Optional , Union
57
68import fsspec
9+ import numpy as np
710import pyarrow as pa
811
912from .arrow_dataset import DatasetInfoMixin
@@ -44,6 +47,50 @@ def __iter__(self):
4447 yield example
4548
4649
50+ @dataclass
51+ class ShuffingConfig :
52+ buffer_size : int
53+ seed : Optional [int ] = None
54+
55+
56+ @dataclass
57+ class DatasetFormat :
58+ type : Optional [str ] = None
59+ transform : Optional [Callable ] = None
60+
61+
62+ class ShufflingBuffer :
63+ def __init__ (self , iterable : Iterable , buffer_size : int , seed : Optional [int ] = None ):
64+ self ._iterable = iterable
65+ self ._buffer_size = buffer_size
66+ self ._seed = seed
67+ self ._mem_buffer = []
68+
69+ def _iter_random_indices (self , rng : np .random .Generator , buffer_size : int , random_batch_size = 1000 ):
70+ while True :
71+ yield from rng .integers (0 , buffer_size , size = random_batch_size )
72+
73+ def __iter__ (self ):
74+ buffer_size = self ._buffer_size
75+ rng = np .random .default_rng (self ._seed )
76+ indices_iterator = self ._iter_random_indices (rng , buffer_size )
77+ for x in self ._iterable :
78+ if len (self ._mem_buffer ) == buffer_size :
79+ i = next (indices_iterator )
80+ yield self ._mem_buffer [i ]
81+ self ._mem_buffer [i ] = x
82+ else :
83+ self ._mem_buffer .append (x )
84+ if len (self ._mem_buffer ) != buffer_size :
85+ raise ValueError (
86+ "Buffer size is too small. "
87+ "It should be at least bigger than the number of examples. "
88+ f"Got { buffer_size } but expected at least { len (self ._mem_buffer )} ."
89+ )
90+ for i in rng .shuffle (range (buffer_size )):
91+ yield self ._mem_buffer [i ]
92+
93+
4794class IterableDataset (DatasetInfoMixin ):
4895 """A Dataset backed by an iterable."""
4996
@@ -52,15 +99,17 @@ def __init__(
5299 iterable : Iterable ,
53100 info : Optional [DatasetInfo ] = None ,
54101 split : Optional [NamedSplit ] = None ,
55- format : Optional [dict ] = None ,
102+ format : Optional [DatasetFormat ] = None ,
103+ shuffling : Optional [ShuffingConfig ] = None ,
56104 ):
57105 info = info .copy () if info is not None else DatasetInfo ()
58- format = format if format is not None else {}
106+ format = format if format is not None else DatasetFormat ()
59107 DatasetInfoMixin .__init__ (self , info = info , split = split )
60108
61109 self ._iterable = iterable
62- self ._format_type = format .get ("type" )
63- self ._transform = format .get ("transform" )
110+ self ._format = format
111+ self ._shuffling = shuffling
112+ self ._epoch = 0
64113
65114 # Infer features if None
66115
@@ -79,20 +128,28 @@ def __init__(
79128 )
80129
81130 def _head (self , n = 5 ):
82- return _examples_to_batch ([x for x , _ in zip (self , range (n ))])
131+ return _examples_to_batch ([x for x , _ in zip (self . _iter () , range (n ))])
83132
84- def __iter__ (self ):
85- for example in self ._iterable :
86- if self ._transform is not None :
87- yield self ._transform (example )
133+ def _iter (self , epoch = 0 , transform : Optional [Callable ] = None , shuffling : Optional [ShuffingConfig ] = None ):
134+ if shuffling :
135+ effective_seed = shuffling .seed + epoch if shuffling .seed is not None else None
136+ iterable = ShufflingBuffer (self ._iterable , buffer_size = shuffling .buffer_size , seed = effective_seed )
137+ else :
138+ iterable = self ._iterable
139+ for example in iterable :
140+ if transform is not None :
141+ yield transform (example )
88142 else :
89143 yield example
90144
145+ def __iter__ (self ):
146+ yield from self ._iter (epoch = self ._epoch , transform = self ._format .transform , shuffling = self ._shuffling )
147+
91148 def with_format (
92149 self ,
93150 type : Optional [str ] = None ,
94151 transform : Optional [Callable ] = None ,
95- ):
152+ ) -> "IterableDataset" :
96153 if type == "torch" :
97154 import torch
98155
@@ -104,12 +161,28 @@ class TorchIterableDataset(IterableDataset, torch.utils.data.IterableDataset):
104161 cls = IterableDataset
105162 dataset = cls (
106163 iterable = self ._iterable ,
107- info = self ._info ,
164+ info = copy .deepcopy (self ._info ),
165+ split = self ._split ,
166+ format = DatasetFormat (type = type , transform = transform ),
167+ shuffling = copy .deepcopy (self ._shuffling ),
168+ )
169+ return dataset
170+
171+ def shuffle (self , buffer_size , seed = None ) -> "IterableDataset" :
172+ shuffling = ShuffingConfig (buffer_size = buffer_size , seed = seed )
173+ cls = self .__class__
174+ dataset = cls (
175+ iterable = self ._iterable ,
176+ info = copy .deepcopy (self ._info ),
108177 split = self ._split ,
109- format = {"type" : type , "transform" : transform },
178+ format = copy .deepcopy (self ._format ),
179+ shuffling = shuffling ,
110180 )
111181 return dataset
112182
183+ def set_epoch (self , epoch : int ):
184+ self ._epoch = epoch
185+
113186
114187class IterableDatasetDict (dict ):
115188 pass
0 commit comments