@@ -80,7 +80,7 @@ def batch_reader_creator(self,
8080 padding_to = - 1 ,
8181 flatten = False ,
8282 sortagrad = False ,
83- batch_shuffle = False ):
83+ shuffle_method = "batch_shuffle" ):
8484 """
8585 Batch data reader creator for audio data. Return a callable generator
8686 function to produce batches of data.
@@ -104,12 +104,22 @@ def batch_reader_creator(self,
104104 :param sortagrad: If set True, sort the instances by audio duration
105105 in the first epoch for speed up training.
106106 :type sortagrad: bool
107- :param batch_shuffle: If set True, instances are batch-wise shuffled.
108- For more details, please see
109- ``_batch_shuffle.__doc__``.
110- If sortagrad is True, batch_shuffle is disabled
107+ :param shuffle_method: Shuffle method. Options:
108+ '' or None: no shuffle.
109+ 'instance_shuffle': instance-wise shuffle.
110+ 'batch_shuffle': similarly-sized instances are
111+ put into batches, and then
112+ batch-wise shuffle the batches.
113+ For more details, please see
114+ ``_batch_shuffle.__doc__``.
115+ 'batch_shuffle_clipped': 'batch_shuffle' with
116+ head shift and tail
117+ clipping. For more
118+ details, please see
119+ ``_batch_shuffle.__doc__``.
120+ If sortagrad is True, shuffle is disabled
111121 for the first epoch.
112- :type batch_shuffle: bool
122+ :type shuffle_method: None|str
113123 :return: Batch reader function, producing batches of data when called.
114124 :rtype: callable
115125 """
@@ -123,8 +133,20 @@ def batch_reader():
123133 # sort (by duration) or batch-wise shuffle the manifest
124134 if self ._epoch == 0 and sortagrad :
125135 manifest .sort (key = lambda x : x ["duration" ])
126- elif batch_shuffle :
127- manifest = self ._batch_shuffle (manifest , batch_size )
136+ else :
137+ if shuffle_method == "batch_shuffle" :
138+ manifest = self ._batch_shuffle (
139+ manifest , batch_size , clipped = False )
140+ elif shuffle_method == "batch_shuffle_clipped" :
141+ manifest = self ._batch_shuffle (
142+ manifest , batch_size , clipped = True )
143+ elif shuffle_method == "instance_shuffle" :
144+ self ._rng .shuffle (manifest )
145+ elif not shuffle_method :
146+ pass
147+ else :
148+ raise ValueError ("Unknown shuffle method %s." %
149+ shuffle_method )
128150 # prepare batches
129151 instance_reader = self ._instance_reader_creator (manifest )
130152 batch = []
@@ -218,7 +240,7 @@ def _padding_batch(self, batch, padding_to=-1, flatten=False):
218240 new_batch .append ((padded_audio , text ))
219241 return new_batch
220242
221- def _batch_shuffle (self , manifest , batch_size ):
243+ def _batch_shuffle (self , manifest , batch_size , clipped = False ):
222244 """Put similarly-sized instances into minibatches for better efficiency
223245 and make a batch-wise shuffle.
224246
@@ -233,6 +255,9 @@ def _batch_shuffle(self, manifest, batch_size):
233255 :param batch_size: Batch size. This size is also used to generate
234256 a random number for batch shuffle.
235257 :type batch_size: int
258+ :param clipped: Whether to clip the heading (small shift) and trailing
259+ (incomplete batch) instances.
260+ :type clipped: bool
236261 :return: Batch shuffled manifest.
237262 :rtype: list
238263 """
@@ -241,7 +266,8 @@ def _batch_shuffle(self, manifest, batch_size):
241266 batch_manifest = zip (* [iter (manifest [shift_len :])] * batch_size )
242267 self ._rng .shuffle (batch_manifest )
243268 batch_manifest = list (sum (batch_manifest , ()))
244- res_len = len (manifest ) - shift_len - len (batch_manifest )
245- batch_manifest .extend (manifest [- res_len :])
246- batch_manifest .extend (manifest [0 :shift_len ])
269+ if not clipped :
270+ res_len = len (manifest ) - shift_len - len (batch_manifest )
271+ batch_manifest .extend (manifest [- res_len :])
272+ batch_manifest .extend (manifest [0 :shift_len ])
247273 return batch_manifest
0 commit comments