Skip to content

Commit 919a713

Browse files
authored
attempt to fix the repetition/hallucination issue identified in #1046 (#1052)
* attempt to fix the repetition/hallucination issue identified in #1046
* zero-pad the audio instead of the spectrogram
* formatting fix
* delete debug print
1 parent 38e990d commit 919a713

File tree

2 files changed

+38
-27
lines changed

2 files changed

+38
-27
lines changed

whisper/audio.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from functools import lru_cache
3-
from typing import Union
3+
from typing import Optional, Union
44

55
import ffmpeg
66
import numpy as np
@@ -15,10 +15,8 @@
1515
N_MELS = 80
1616
HOP_LENGTH = 160
1717
CHUNK_LENGTH = 30
18-
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
19-
N_FRAMES = exact_div(
20-
N_SAMPLES, HOP_LENGTH
21-
) # 3000: number of frames in a mel spectrogram input
18+
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
19+
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
2220

2321
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions have stride 2
2422
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
@@ -100,7 +98,10 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
10098

10199

102100
def log_mel_spectrogram(
103-
audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
101+
audio: Union[str, np.ndarray, torch.Tensor],
102+
n_mels: int = N_MELS,
103+
padding: int = 0,
104+
device: Optional[Union[str, torch.device]] = None,
104105
):
105106
"""
106107
Compute the log-Mel spectrogram of
@@ -113,6 +114,12 @@ def log_mel_spectrogram(
113114
n_mels: int
114115
The number of Mel-frequency filters, only 80 is supported
115116
117+
padding: int
118+
Number of zero samples to pad to the right
119+
120+
device: Optional[Union[str, torch.device]]
121+
If given, the audio tensor is moved to this device before STFT
122+
116123
Returns
117124
-------
118125
torch.Tensor, shape = (80, n_frames)
@@ -123,6 +130,10 @@ def log_mel_spectrogram(
123130
audio = load_audio(audio)
124131
audio = torch.from_numpy(audio)
125132

133+
if device is not None:
134+
audio = audio.to(device)
135+
if padding > 0:
136+
audio = F.pad(audio, (0, padding))
126137
window = torch.hann_window(N_FFT).to(audio.device)
127138
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
128139
magnitudes = stft[..., :-1].abs() ** 2

whisper/transcribe.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
FRAMES_PER_SECOND,
1212
HOP_LENGTH,
1313
N_FRAMES,
14+
N_SAMPLES,
1415
SAMPLE_RATE,
1516
log_mel_spectrogram,
1617
pad_or_trim,
@@ -116,7 +117,9 @@ def transcribe(
116117
if dtype == torch.float32:
117118
decode_options["fp16"] = False
118119

119-
mel = log_mel_spectrogram(audio)
120+
# Pad 30 seconds of silence to the end of the input audio, for slicing
121+
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
122+
content_frames = mel.shape[-1] - N_FRAMES
120123

121124
if decode_options.get("language", None) is None:
122125
if not model.is_multilingual:
@@ -212,14 +215,13 @@ def new_segment(
212215
}
213216

214217
# show the progress bar when verbose is False (if True, transcribed text will be printed)
215-
num_frames = mel.shape[-1]
216218
with tqdm.tqdm(
217-
total=num_frames, unit="frames", disable=verbose is not False
219+
total=content_frames, unit="frames", disable=verbose is not False
218220
) as pbar:
219-
while seek < num_frames:
221+
while seek < content_frames:
220222
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
221-
mel_segment = mel[:, seek:]
222-
segment_size = min(mel_segment.shape[-1], N_FRAMES)
223+
mel_segment = mel[:, seek : seek + N_FRAMES]
224+
segment_size = min(N_FRAMES, content_frames - seek)
223225
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
224226
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
225227

@@ -246,20 +248,18 @@ def new_segment(
246248
current_tokens = []
247249

248250
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
249-
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
250-
0
251-
].add_(1)
252-
if (
253-
len(consecutive) > 0
254-
): # if the output contains two consecutive timestamp tokens
255-
if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
256-
False,
257-
True,
258-
]:
259-
consecutive = consecutive.tolist() + [len(tokens)]
251+
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
252+
253+
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
254+
consecutive.add_(1)
255+
if len(consecutive) > 0:
256+
# if the output contains two consecutive timestamp tokens
257+
slices = consecutive.tolist()
258+
if single_timestamp_ending:
259+
slices.append(len(tokens))
260260

261261
last_slice = 0
262-
for current_slice in consecutive:
262+
for current_slice in slices:
263263
sliced_tokens = tokens[last_slice:current_slice]
264264
start_timestamp_pos = (
265265
sliced_tokens[0].item() - tokenizer.timestamp_begin
@@ -278,7 +278,7 @@ def new_segment(
278278
current_tokens.append(sliced_tokens.tolist())
279279
last_slice = current_slice
280280

281-
if ended_with_single_timestamp:
281+
if single_timestamp_ending:
282282
# single timestamp at the end means no speech after the last timestamp.
283283
seek += segment_size
284284
else:
@@ -329,7 +329,7 @@ def new_segment(
329329
word_end_timestamps = [
330330
w["end"] for s in current_segments for w in s["words"]
331331
]
332-
if len(consecutive) > 0 and len(word_end_timestamps) > 0:
332+
if not single_timestamp_ending and len(word_end_timestamps) > 0:
333333
seek_shift = round(
334334
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
335335
)
@@ -356,7 +356,7 @@ def new_segment(
356356
)
357357

358358
# update progress bar
359-
pbar.update(min(num_frames, seek) - previous_seek)
359+
pbar.update(min(content_frames, seek) - previous_seek)
360360

361361
return dict(
362362
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),

0 commit comments

Comments
 (0)