Skip to content

Commit 0a307ac

Browse files
authored
Merge pull request #101 from xinghai-sun/ds2_improve
Add shuffle type of instance_shuffle and batch_shuffle_clipped.
2 parents 8e80e20 + f545367 commit 0a307ac

File tree

21 files changed

+88
-33
lines changed

21 files changed

+88
-33
lines changed

deep_speech_2/compute_mean_std.py

100755100644
File mode changed.

deep_speech_2/data_utils/__init__.py

100755100644
File mode changed.

deep_speech_2/data_utils/audio.py

100755100644
File mode changed.

deep_speech_2/data_utils/augmentor/__init__.py

100755100644
File mode changed.

deep_speech_2/data_utils/augmentor/augmentation.py

100755100644
File mode changed.

deep_speech_2/data_utils/augmentor/base.py

100755100644
File mode changed.

deep_speech_2/data_utils/augmentor/volume_perturb.py

100755100644
File mode changed.

deep_speech_2/data_utils/data.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def batch_reader_creator(self,
8080
padding_to=-1,
8181
flatten=False,
8282
sortagrad=False,
83-
batch_shuffle=False):
83+
shuffle_method="batch_shuffle"):
8484
"""
8585
Batch data reader creator for audio data. Return a callable generator
8686
function to produce batches of data.
@@ -104,12 +104,22 @@ def batch_reader_creator(self,
104104
:param sortagrad: If set True, sort the instances by audio duration
105105
in the first epoch for speed up training.
106106
:type sortagrad: bool
107-
:param batch_shuffle: If set True, instances are batch-wise shuffled.
108-
For more details, please see
109-
``_batch_shuffle.__doc__``.
110-
If sortagrad is True, batch_shuffle is disabled
107+
:param shuffle_method: Shuffle method. Options:
108+
'' or None: no shuffle.
109+
'instance_shuffle': instance-wise shuffle.
110+
'batch_shuffle': similarly-sized instances are
111+
put into batches, and then
112+
batch-wise shuffle the batches.
113+
For more details, please see
114+
``_batch_shuffle.__doc__``.
115+
'batch_shuffle_clipped': 'batch_shuffle' with
116+
head shift and tail
117+
clipping. For more
118+
details, please see
119+
``_batch_shuffle``.
120+
If sortagrad is True, shuffle is disabled
111121
for the first epoch.
112-
:type batch_shuffle: bool
122+
:type shuffle_method: None|str
113123
:return: Batch reader function, producing batches of data when called.
114124
:rtype: callable
115125
"""
@@ -123,8 +133,20 @@ def batch_reader():
123133
# sort (by duration) or batch-wise shuffle the manifest
124134
if self._epoch == 0 and sortagrad:
125135
manifest.sort(key=lambda x: x["duration"])
126-
elif batch_shuffle:
127-
manifest = self._batch_shuffle(manifest, batch_size)
136+
else:
137+
if shuffle_method == "batch_shuffle":
138+
manifest = self._batch_shuffle(
139+
manifest, batch_size, clipped=False)
140+
elif shuffle_method == "batch_shuffle_clipped":
141+
manifest = self._batch_shuffle(
142+
manifest, batch_size, clipped=True)
143+
elif shuffle_method == "instance_shuffle":
144+
self._rng.shuffle(manifest)
145+
elif not shuffle_method:
146+
pass
147+
else:
148+
raise ValueError("Unknown shuffle method %s." %
149+
shuffle_method)
128150
# prepare batches
129151
instance_reader = self._instance_reader_creator(manifest)
130152
batch = []
@@ -218,7 +240,7 @@ def _padding_batch(self, batch, padding_to=-1, flatten=False):
218240
new_batch.append((padded_audio, text))
219241
return new_batch
220242

221-
def _batch_shuffle(self, manifest, batch_size):
243+
def _batch_shuffle(self, manifest, batch_size, clipped=False):
222244
"""Put similarly-sized instances into minibatches for better efficiency
223245
and make a batch-wise shuffle.
224246
@@ -233,6 +255,9 @@ def _batch_shuffle(self, manifest, batch_size):
233255
:param batch_size: Batch size. This size is also used for generate
234256
a random number for batch shuffle.
235257
:type batch_size: int
258+
:param clipped: Whether to clip the heading (small shift) and trailing
259+
(incomplete batch) instances.
260+
:type clipped: bool
236261
:return: Batch shuffled mainifest.
237262
:rtype: list
238263
"""
@@ -241,7 +266,8 @@ def _batch_shuffle(self, manifest, batch_size):
241266
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
242267
self._rng.shuffle(batch_manifest)
243268
batch_manifest = list(sum(batch_manifest, ()))
244-
res_len = len(manifest) - shift_len - len(batch_manifest)
245-
batch_manifest.extend(manifest[-res_len:])
246-
batch_manifest.extend(manifest[0:shift_len])
269+
if not clipped:
270+
res_len = len(manifest) - shift_len - len(batch_manifest)
271+
batch_manifest.extend(manifest[-res_len:])
272+
batch_manifest.extend(manifest[0:shift_len])
247273
return batch_manifest

deep_speech_2/data_utils/featurizer/__init__.py

100755100644
File mode changed.

deep_speech_2/data_utils/featurizer/audio_featurizer.py

100755100644
File mode changed.

0 commit comments

Comments
 (0)