Commit 96ef5ca

[RNN-T] bucketing sampler fix (pytorch#460)
* [RNN-T] bucketing sampler fix: drop random samples instead of replacing them with the longest sequences
* remove samples without repetitions
1 parent 6a4c5eb commit 96ef5ca

1 file changed

Lines changed: 3 additions & 3 deletions

rnn_speech_recognition/pytorch/common/data/dali/sampler.py

@@ -81,10 +81,10 @@ def process_output_files(self, output_files):
         epochs = np.reshape(shuffled_buckets, [self.num_epochs, -1])
         to_drop = epochs.shape[1] - (epochs.shape[1] // gbs * gbs)
         for epoch in epochs:
-            dropped_idxs = self.rng.choice(epochs.shape[1], to_drop)
+            dropped_idxs = self.rng.choice(epochs.shape[1], to_drop, replace=False)
             if dropped_idxs is not None:
-                epoch[dropped_idxs] = epoch[-to_drop:]
-        epochs = epochs[:, :epochs.shape[1] // gbs * gbs]
+                epoch[dropped_idxs] = -1
+        epochs = epochs[epochs != -1].reshape(self.num_epochs, -1)
         self.dataset_size = epochs.shape[1]
 
         epochs_iters_batch = np.reshape(epochs, [self.num_epochs, -1, gbs])
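
The behaviour introduced by this change can be tried in isolation with a minimal sketch. It is not the sampler's actual code: the function name `drop_random_samples`, the toy data, and the use of `np.random.default_rng` are assumptions made for illustration, and the sketch relies on sample indices being non-negative so that -1 can serve as a sentinel.

```python
import numpy as np


def drop_random_samples(epochs, gbs, rng):
    """Trim each epoch to a multiple of gbs by dropping random entries.

    epochs: 2-D integer array [num_epochs, samples_per_epoch] holding
            non-negative sample indices (hypothetical toy input).
    gbs:    global batch size.
    rng:    a NumPy random Generator (assumed; the real sampler uses self.rng).
    """
    epochs = epochs.copy()  # keep the caller's array untouched
    num_epochs, per_epoch = epochs.shape
    to_drop = per_epoch - (per_epoch // gbs * gbs)
    for epoch in epochs:  # each row is a view, so writes land in `epochs`
        # replace=False yields to_drop distinct positions, so every epoch
        # loses exactly the same number of samples.
        dropped_idxs = rng.choice(per_epoch, to_drop, replace=False)
        epoch[dropped_idxs] = -1  # mark for removal instead of overwriting
    # Filtering flattens the array; since every row lost to_drop entries,
    # reshaping restores one row per epoch.
    return epochs[epochs != -1].reshape(num_epochs, -1)


rng = np.random.default_rng(0)
# Two toy epochs, each an independent shuffle of sample indices 0..9.
toy_epochs = np.stack([rng.permutation(10) for _ in range(2)])
trimmed = drop_random_samples(toy_epochs, gbs=4, rng=rng)
print(trimmed.shape)  # (2, 8)
```

Sampling with replacement could return the same position twice, in which case fewer than `to_drop` entries would be marked in that epoch and the final reshape would no longer line up with epoch boundaries. That appears to be why `replace=False` accompanies the switch to the -1 sentinel.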
