     FRAMES_PER_SECOND,
     HOP_LENGTH,
     N_FRAMES,
+    N_SAMPLES,
     SAMPLE_RATE,
     log_mel_spectrogram,
     pad_or_trim,
@@ -116,7 +117,9 @@ def transcribe(
     if dtype == torch.float32:
         decode_options["fp16"] = False

-    mel = log_mel_spectrogram(audio)
+    # Pad 30-seconds of silence to the input audio, for slicing
+    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+    content_frames = mel.shape[-1] - N_FRAMES

     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
@@ -212,14 +215,13 @@ def new_segment(
         }

     # show the progress bar when verbose is False (if True, transcribed text will be printed)
-    num_frames = mel.shape[-1]
     with tqdm.tqdm(
-        total=num_frames, unit="frames", disable=verbose is not False
+        total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
-        while seek < num_frames:
+        while seek < content_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            mel_segment = mel[:, seek:]
-            segment_size = min(mel_segment.shape[-1], N_FRAMES)
+            mel_segment = mel[:, seek : seek + N_FRAMES]
+            segment_size = min(N_FRAMES, content_frames - seek)
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

@@ -246,20 +248,18 @@ def new_segment(
             current_tokens = []

             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
-                0
-            ].add_(1)
-            if (
-                len(consecutive) > 0
-            ):  # if the output contains two consecutive timestamp tokens
-                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
-                    False,
-                    True,
-                ]:
-                    consecutive = consecutive.tolist() + [len(tokens)]
+            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+            consecutive.add_(1)
+            if len(consecutive) > 0:
+                # if the output contains two consecutive timestamp tokens
+                slices = consecutive.tolist()
+                if single_timestamp_ending:
+                    slices.append(len(tokens))

                 last_slice = 0
-                for current_slice in consecutive:
+                for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_pos = (
                         sliced_tokens[0].item() - tokenizer.timestamp_begin
@@ -278,7 +278,7 @@ def new_segment(
                     current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice

-                if ended_with_single_timestamp:
+                if single_timestamp_ending:
                     # single timestamp at the end means no speech after the last timestamp.
                     seek += segment_size
                 else:
@@ -329,7 +329,7 @@ def new_segment(
                 word_end_timestamps = [
                     w["end"] for s in current_segments for w in s["words"]
                 ]
-                if len(consecutive) > 0 and len(word_end_timestamps) > 0:
+                if not single_timestamp_ending and len(word_end_timestamps) > 0:
                     seek_shift = round(
                         (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
                     )
@@ -356,7 +356,7 @@ def new_segment(
             )

             # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek)
+            pbar.update(min(content_frames, seek) - previous_seek)

     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
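
For context, a rough standalone sketch (not part of the patch) of the slicing arithmetic this change introduces: padding the log-mel spectrogram with an extra 30 seconds of silence means every `mel[:, seek : seek + N_FRAMES]` slice is already a full window, while `content_frames` marks where the real audio ends so the seek loop and the progress bar stop there. The constants below mirror `whisper.audio`, and `iter_windows` is a made-up helper that only reproduces the window bookkeeping, assuming the simplest seek advance (`seek += segment_size`).

import numpy as np

SAMPLE_RATE = 16000                     # audio sample rate used by Whisper
HOP_LENGTH = 160                        # samples per mel frame
CHUNK_LENGTH = 30                       # seconds per decoding window
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480,000 samples per window
N_FRAMES = N_SAMPLES // HOP_LENGTH      # 3,000 mel frames per window


def iter_windows(n_content_frames: int):
    """Yield (seek, segment_size) pairs the way the patched loop slices the mel."""
    # the padding appended by log_mel_spectrogram(audio, padding=N_SAMPLES)
    # guarantees that mel[:, seek : seek + N_FRAMES] is always N_FRAMES wide
    mel = np.zeros((80, n_content_frames + N_FRAMES), dtype=np.float32)
    content_frames = mel.shape[-1] - N_FRAMES
    seek = 0
    while seek < content_frames:
        mel_segment = mel[:, seek : seek + N_FRAMES]
        assert mel_segment.shape[-1] == N_FRAMES
        segment_size = min(N_FRAMES, content_frames - seek)
        yield seek, segment_size
        seek += segment_size  # simplest advance; the real loop may seek less


# 70 s of audio -> 7,000 content frames -> windows of 3,000 / 3,000 / 1,000 frames
print(list(iter_windows(7000)))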