diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp
index 4e84c1b2750..9857a334158 100644
--- a/examples/cli/cli.cpp
+++ b/examples/cli/cli.cpp
@@ -105,6 +105,7 @@ struct whisper_params {

     // Voice Activity Detection (VAD) parameters
     bool vad = false;
+    bool stable_timestamps = false;
     std::string vad_model = "";
     float vad_threshold = 0.5f;
     int vad_min_speech_duration_ms = 250;
@@ -210,6 +211,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
         else if (arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
         // Voice Activity Detection (VAD)
         else if (arg == "--vad") { params.vad = true; }
+        else if (arg == "--stable-timestamps") { params.stable_timestamps = true; }
         else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = ARGV_NEXT; }
         else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(ARGV_NEXT); }
         else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
@@ -293,6 +295,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
     // Voice Activity Detection (VAD) parameters
     fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
    fprintf(stderr, " --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? "true" : "false");
+    fprintf(stderr, " --stable-timestamps [%-7s] enable stable timestamps\n", params.stable_timestamps ? "true" : "false");
     fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str());
     fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold);
     fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms);
@@ -1211,6 +1214,7 @@ int main(int argc, char ** argv) {

     wparams.suppress_nst = params.suppress_nst;

+    wparams.stable_timestamps = params.stable_timestamps;
     wparams.vad = params.vad;
     wparams.vad_model_path = params.vad_model.c_str();
diff --git a/include/whisper.h b/include/whisper.h
index f4cc6bf7abd..642df9bc7ac 100644
--- a/include/whisper.h
+++ b/include/whisper.h
@@ -583,6 +583,10 @@ extern "C" {
         size_t i_start_rule;
         float grammar_penalty;

+        // Stable timestamps - snap word boundaries to speech edges using VAD
+        // Requires vad_model_path to be set. Forces vad=true, token_timestamps=true, max_initial_ts=0.
+        bool stable_timestamps;
+
         // Voice Activity Detection (VAD) params
         bool vad;                      // Enable VAD
         const char * vad_model_path;   // Path to VAD model
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 095a2791de5..ac9f055ff3a 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -106,6 +106,9 @@ endif()

 add_library(whisper
     ../include/whisper.h
     whisper-arch.h
+    whisper-state.h
+    whisper-stable.h
+    whisper-stable.cpp
     whisper.cpp
     )
diff --git a/src/whisper-stable.cpp b/src/whisper-stable.cpp
new file mode 100644
index 00000000000..6dbf4db6bb2
--- /dev/null
+++ b/src/whisper-stable.cpp
@@ -0,0 +1,596 @@
+#include "whisper-stable.h"
+#include "whisper.h"
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+
+// ---------------------------------------------------------------------------
+// Silence map from VAD probabilities
+// ---------------------------------------------------------------------------
+
+std::vector<std::pair<int64_t, int64_t>> whisper_stable_build_silence_map(
+    const float * vad_probs,
+    int n_probs,
+    int n_window,
+    int sample_rate,
+    float threshold,
+    int64_t min_silence_dur_cs) {
+
+    std::vector<std::pair<int64_t, int64_t>> silence;
+    if (!vad_probs || n_probs <= 0) {
+        return silence;
+    }
+
+    const double cs_per_frame = (double)n_window * 100.0 / sample_rate;
+    int64_t region_start = -1;
+
+    for (int i = 0; i < n_probs; i++) {
+        const bool is_silent = vad_probs[i] < threshold;
+
+        if (is_silent && region_start < 0) {
+            region_start = (int64_t)(i * cs_per_frame);
+        } else if (!is_silent && region_start >= 0) {
+            const int64_t region_end = (int64_t)(i * cs_per_frame);
+            if (region_end - region_start >= min_silence_dur_cs) {
+                silence.push_back({region_start, region_end});
+            }
+            region_start = -1;
+        }
+    }
+
+    if (region_start >= 0) {
+        const int64_t region_end = (int64_t)(n_probs * cs_per_frame);
+        if (region_end - region_start >= min_silence_dur_cs) {
+            silence.push_back({region_start, region_end});
+        }
+    }
+
+    return silence;
+}
+
+// ---------------------------------------------------------------------------
+// Silence map from raw PCM energy — mirrors stable-ts wav2mask
+// ---------------------------------------------------------------------------
+
+std::vector<std::pair<int64_t, int64_t>> whisper_stable_build_silence_map_from_pcm(
+    const float * pcm,
+    int n_samples,
+    int sample_rate,
+    int64_t min_silence_dur_cs) {
+
+    std::vector<std::pair<int64_t, int64_t>> silence;
+    if (!pcm || n_samples <= 0 || sample_rate <= 0) {
+        return silence;
+    }
+
+    // Audio token size matches Whisper's resolution: 320 samples @ 16kHz = 20ms
+    const int samples_per_token = 320;
+    const int n_tokens = (int)std::round((double)n_samples / samples_per_token) + 1;
+    if (n_tokens < 2) {
+        return silence;
+    }
+
+    // Step 1+2: abs amplitude, find 99.9th percentile (top 0.1% of samples)
+    const int k = std::max(1, n_samples / 1000);
+    std::vector<float> abs_vals(n_samples);
+    for (int i = 0; i < n_samples; ++i) {
+        abs_vals[i] = std::fabs(pcm[i]);
+    }
+    std::nth_element(abs_vals.begin(), abs_vals.begin() + (n_samples - k), abs_vals.end());
+    float threshold = abs_vals[n_samples - k];
+    if (threshold < 1e-5f) {
+        // Entirely silent audio — everything is silence
+        const double cs_total = (double)n_samples * 100.0 / sample_rate;
+        if ((int64_t)cs_total >= min_silence_dur_cs) {
+            silence.push_back({0, (int64_t)cs_total});
+        }
+        return silence;
+    }
+
+    // Step 3: average abs amplitude per token window
+    std::vector<float> token_energy(n_tokens, 0.0f);
+    for (int t = 0; t < n_tokens; ++t) {
+        const int s0 = t * samples_per_token;
+        const int s1 = std::min(s0 + samples_per_token, n_samples);
+        if (s0 >= n_samples) break;
+        float sum = 0.0f;
+        for (int s = s0; s < s1; ++s) {
+            sum += std::fabs(pcm[s]);
+        }
+        token_energy[t] = sum / (float)(s1 - s0);
+    }
+
+    // Normalize: divide by min(1.0, threshold * 1.75), clamp to [0, 1]
+    const float norm_denom = std::min(1.0f, threshold * 1.75f);
+    for (auto & v : token_energy) {
+        v = std::min(1.0f, v / norm_denom);
+    }
+
+    // Step 4: avg-pool with kernel=5, reflection padding
+    std::vector<float> smoothed(n_tokens, 0.0f);
+    const int k_half = 2; // kernel 5
+    for (int t = 0; t < n_tokens; ++t) {
+        float sum = 0.0f;
+        for (int d = -k_half; d <= k_half; ++d) {
+            int idx = t + d;
+            if (idx < 0) idx = -idx;
+            if (idx >= n_tokens) idx = 2 * n_tokens - 2 - idx;
+            idx = std::max(0, std::min(n_tokens - 1, idx));
+            sum += token_energy[idx];
+        }
+        smoothed[t] = sum / 5.0f;
+    }
+
+    // Step 5: quantize to 20 levels — anything rounding to 0 is silent
+    // Step 6: merge adjacent silent tokens into regions, filter by min duration
+    const double cs_per_token = (double)samples_per_token * 100.0 / sample_rate;
+    int64_t region_start = -1;
+
+    for (int t = 0; t < n_tokens; ++t) {
+        const bool is_silent = std::roundf(smoothed[t] * 20.0f) == 0.0f;
+
+        if (is_silent && region_start < 0) {
+            region_start = (int64_t)(t * cs_per_token);
+        } else if (!is_silent && region_start >= 0) {
+            const int64_t region_end = (int64_t)(t * cs_per_token);
+            if (region_end - region_start >= min_silence_dur_cs) {
+                silence.push_back({region_start, region_end});
+            }
+            region_start = -1;
+        }
+    }
+
+    if (region_start >= 0) {
+        const int64_t region_end = (int64_t)(n_tokens * cs_per_token);
+        if (region_end - region_start >= min_silence_dur_cs) {
+            silence.push_back({region_start, region_end});
+        }
+    }
+
+    return silence;
+}
+
+// ---------------------------------------------------------------------------
+// Silence overlap utility
+// ---------------------------------------------------------------------------
+
+int64_t whisper_stable_silence_overlap_len(
+    int64_t t0,
+    int64_t t1,
+    const std::vector<std::pair<int64_t, int64_t>> & silence_regions_cs) {
+    if (t1 <= t0 || silence_regions_cs.empty()) {
+        return 0;
+    }
+
+    int64_t overlap = 0;
+    auto it = std::lower_bound(
+        silence_regions_cs.begin(), silence_regions_cs.end(), t0,
+        [](const std::pair<int64_t, int64_t> & r, int64_t t) {
+            return r.second <= t;
+        });
+
+    for (; it != silence_regions_cs.end() && it->first < t1; ++it) {
+        const int64_t ss = std::max(t0, it->first);
+        const int64_t se = std::min(t1, it->second);
+        if (se > ss) {
+            overlap += (se - ss);
+        }
+    }
+
+    return overlap;
+}
+
+// ---------------------------------------------------------------------------
+// Timestamp snapping — stable-ts boundary-moving algorithm
+// ---------------------------------------------------------------------------
+
+void whisper_stable_snap_timestamps(
+    int64_t * words_t0,
+    int64_t * words_t1,
+    int n_words,
+    const int * seg_first_word,
+    const int * seg_word_count,
+    int n_segments,
+    const std::vector<std::pair<int64_t, int64_t>> & silence,
+    int64_t min_word_dur_cs,
+    int64_t min_snap_silence_dur_cs) {
+
+    if (n_words <= 0 || silence.empty()) {
+        return;
+    }
+
+    std::vector<bool> is_first(n_words, false);
+    std::vector<bool> is_last(n_words, false);
+    for (int s = 0; s < n_segments; ++s) {
+        const int first = seg_first_word[s];
+        const int count = seg_word_count[s];
+        if (count > 0) {
+            is_first[first] = true;
+            is_last[first + count - 1] = true;
+        }
+    }
+
+    for (int w = 0; w < n_words; ++w) {
+        const int64_t t0 = words_t0[w];
+        const int64_t t1 = words_t1[w];
+        if (t0 >= t1) {
+            continue;
+        }
+
+        int64_t new_t0 = t0;
+        int64_t new_t1 = t1;
+        bool moved_start = false;
+        bool moved_end = false;
+
+        auto it = std::lower_bound(
+            silence.begin(), silence.end(), new_t0,
+            [](const std::pair<int64_t, int64_t> & r, int64_t t) {
+                return r.second <= t;
+            });
+
+        for (; it != silence.end() && it->first < new_t1; ++it) {
+            if (it->second - it->first < min_snap_silence_dur_cs) {
+                continue;
+            }
+
+            const int64_t si_s = it->first;
+            const int64_t si_e = it->second;
+
+            if (si_s <= new_t0 && si_e > new_t0 && si_e <= new_t1) {
+                // Start is inside silence → move start forward
+                new_t0 = si_e;
+                moved_start = true;
+            } else if (si_s >= new_t0 && si_s < new_t1 && si_e >= new_t1) {
+                // End is inside silence → move end backward
+                new_t1 = si_s;
+                moved_end = true;
+                break;
+            } else if (si_s > new_t0 && si_e < new_t1) {
+                // Silence fully inside word → snap boundary with less overshoot
+                const int64_t sil_len = si_e - si_s;
+                const double left_ratio = (double)(si_s - new_t0) / sil_len;
+                const double right_ratio = (double)(new_t1 - si_e) / sil_len;
+
+                bool snap_start;
+                if (is_first[w]) {
+                    snap_start = true;
+                } else if (is_last[w]) {
+                    snap_start = false;
+                } else {
+                    snap_start = (left_ratio >= right_ratio);
+                }
+
+                if (snap_start) {
+                    new_t0 = si_e;
+                    moved_start = true;
+                } else {
+                    new_t1 = si_s;
+                    moved_end = true;
+                    break;
+                }
+            }
+        }
+
+        // Enforce minimum word duration
+        if (new_t1 - new_t0 < min_word_dur_cs) {
+            if (moved_start && !moved_end) {
+                new_t1 = std::min(t1, new_t0 + min_word_dur_cs);
+                if (new_t1 - new_t0 < min_word_dur_cs) {
+                    new_t0 = std::max<int64_t>(0, new_t1 - min_word_dur_cs);
+                }
+            } else if (moved_end && !moved_start) {
+                new_t0 = std::max(t0, new_t1 - min_word_dur_cs);
+                if (new_t1 - new_t0 < min_word_dur_cs) {
+                    new_t1 = new_t0 + min_word_dur_cs;
+                }
+            } else {
+                const int64_t span = t1 - t0;
+                if (span >= min_word_dur_cs) {
+                    const int64_t mid = (t0 + t1) / 2;
+                    new_t0 = mid - min_word_dur_cs / 2;
+                    new_t1 = new_t0 + min_word_dur_cs;
+                } else {
+                    new_t0 = t0;
+                    new_t1 = t1;
+                }
+            }
+        }
+
+        if (new_t1 <= new_t0) {
+            continue;
+        }
+
+        words_t0[w] = new_t0;
+        words_t1[w] = new_t1;
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Segment-level snapping
+// ---------------------------------------------------------------------------
+
+void whisper_stable_snap_segments(
+    struct whisper_context * ctx,
+    std::vector<whisper_segment> & result_all,
+    const std::vector<std::pair<int64_t, int64_t>> & silence_regions_cs,
+    int64_t min_word_dur_cs,
+    int64_t min_snap_silence_dur_cs) {
+
+    if (!ctx || result_all.empty() || silence_regions_cs.empty()) {
+        return;
+    }
+
+    struct word_ref {
+        int64_t * t0 = nullptr;
+        int64_t * t1 = nullptr;
+    };
+
+    std::vector<word_ref> words;
+    std::vector<int> seg_first_word;
+    std::vector<int> seg_word_count;
+    words.reserve(result_all.size() * 8);
+    seg_first_word.reserve(result_all.size());
+    seg_word_count.reserve(result_all.size());
+
+    const int token_eot = whisper_token_eot(ctx);
+    int word_idx = 0;
+
+    for (auto & seg : result_all) {
+        seg_first_word.push_back(word_idx);
+        int count = 0;
+        for (auto & tok : seg.tokens) {
+            if (tok.id >= token_eot) {
+                continue;
+            }
+            words.push_back({&tok.t0, &tok.t1});
+            ++count;
+            ++word_idx;
+        }
+        seg_word_count.push_back(count);
+    }
+
+    const int n_words = (int)words.size();
+    const int n_segs = (int)result_all.size();
+    if (n_words <= 0) {
+        return;
+    }
+
+    // Token timestamps are already in original timeline (offset applied by per-segment VAD decode)
+    std::vector<int64_t> t0_arr(n_words);
+    std::vector<int64_t> t1_arr(n_words);
+    for (int i = 0; i < n_words; ++i) {
+        t0_arr[i] = words[i].t0 ? *words[i].t0 : 0;
+        t1_arr[i] = words[i].t1 ? *words[i].t1 : 0;
+    }
+
+    whisper_stable_snap_timestamps(
+        t0_arr.data(), t1_arr.data(), n_words,
+        seg_first_word.data(), seg_word_count.data(), n_segs,
+        silence_regions_cs, min_word_dur_cs, min_snap_silence_dur_cs);
+
+    // Write back — values are now on original timeline
+    for (int i = 0; i < n_words; ++i) {
+        if (words[i].t0) *words[i].t0 = t0_arr[i];
+        if (words[i].t1) *words[i].t1 = t1_arr[i];
+    }
+
+    // Update segment t0/t1 from first/last valid word
+    for (int s = 0; s < n_segs; ++s) {
+        const int first = seg_first_word[s];
+        const int count = seg_word_count[s];
+        if (count <= 0) {
+            continue;
+        }
+
+        int64_t seg_t0 = std::numeric_limits<int64_t>::max();
+        int64_t seg_t1 = std::numeric_limits<int64_t>::min();
+
+        for (int j = 0; j < count; ++j) {
+            const int wi = first + j;
+            if (t1_arr[wi] <= t0_arr[wi]) {
+                continue;
+            }
+            seg_t0 = std::min(seg_t0, t0_arr[wi]);
+            seg_t1 = std::max(seg_t1, t1_arr[wi]);
+        }
+
+        if (seg_t0 != std::numeric_limits<int64_t>::max() &&
+            seg_t1 != std::numeric_limits<int64_t>::min()) {
+            result_all[s].t0 = seg_t0;
+            result_all[s].t1 = seg_t1;
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// DTW gap padding tokens
+// ---------------------------------------------------------------------------
+
+std::vector<whisper_token> whisper_stable_get_gap_tokens(struct whisper_context * ctx) {
+    static const char * k_gap_text = " ...";
+
+    std::vector<whisper_token> result;
+    if (!ctx) {
+        return result;
+    }
+
+    std::vector<whisper_token> gap_tokens(16);
+    int n_written = whisper_tokenize(ctx, k_gap_text, gap_tokens.data(), (int)gap_tokens.size());
+    if (n_written < 0) {
+        gap_tokens.resize(-n_written);
+        n_written = whisper_tokenize(ctx, k_gap_text, gap_tokens.data(), (int)gap_tokens.size());
+    }
+    if (n_written <= 0) {
+        return result;
+    }
+
+    result.reserve(n_written);
+    for (int i = 0; i < n_written; ++i) {
+        result.push_back(gap_tokens[i]);
+    }
+
+    return result;
+}
+
+// ---------------------------------------------------------------------------
+// Dynamic head selection — score heads by monotonicity, keep top-k
+// ---------------------------------------------------------------------------
+
+void whisper_stable_select_heads(
+    float * data,
+    int n_tokens,
+    int n_audio,
+    int n_heads,
+    int top_k) {
+
+    if (!data || n_tokens <= 1 || n_audio <= 0 || n_heads <= 0) {
+        return;
+    }
+
+    top_k = std::max(1, std::min(top_k, n_heads));
+    if (top_k >= n_heads) {
+        return;
+    }
+
+    struct head_score { int head; float score; };
+
+    std::vector<head_score> scores;
+    scores.reserve(n_heads);
+
+    const double mean_x = 0.5 * (n_tokens - 1);
+    double var_x = 0.0;
+    for (int t = 0; t < n_tokens; ++t) {
+        const double dx = t - mean_x;
+        var_x += dx * dx;
+    }
+    if (var_x <= 0.0) {
+        return;
+    }
+
+    const int head_stride = n_audio * n_tokens;
+    const int audio_stride = n_tokens;
+
+    for (int h = 0; h < n_heads; ++h) {
+        const float * head_data = data + h * head_stride;
+
+        std::vector<int> peaks(n_tokens, 0);
+        for (int t = 0; t < n_tokens; ++t) {
+            float best = -std::numeric_limits<float>::infinity();
+            int best_a = 0;
+            for (int a = 0; a < n_audio; ++a) {
+                const float v = head_data[a * audio_stride + t];
+                if (v > best) { best = v; best_a = a; }
+            }
+            peaks[t] = best_a;
+        }
+
+        double mean_y = 0.0;
+        for (int t = 0; t < n_tokens; ++t) mean_y += peaks[t];
+        mean_y /= n_tokens;
+
+        double cov = 0.0, var_y = 0.0;
+        for (int t = 0; t < n_tokens; ++t) {
+            const double dx = t - mean_x;
+            const double dy = peaks[t] - mean_y;
+            cov += dx * dy;
+            var_y += dy * dy;
+        }
+
+        float corr = -1.0f;
+        if (var_y > 0.0) {
+            corr = (float)(cov / std::sqrt(var_x * var_y));
+        }
+        scores.push_back({h, corr});
+    }
+
+    std::sort(scores.begin(), scores.end(), [](const head_score & a, const head_score & b) {
+        return a.score != b.score ? a.score > b.score : a.head < b.head;
+    });
+
+    std::vector<uint8_t> keep(n_heads, 0);
+    for (int i = 0; i < top_k; ++i) keep[scores[i].head] = 1;
+
+    for (int h = 0; h < n_heads; ++h) {
+        if (keep[h]) continue;
+        float * head_data = data + h * head_stride;
+        std::fill(head_data, head_data + head_stride, 0.0f);
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Constrained decoding filter
+// ---------------------------------------------------------------------------
+
+bool whisper_stable_setup_filter(
+    struct whisper_full_params & params,
+    const std::vector<std::pair<int64_t, int64_t>> & silence_regions_cs,
+    int64_t total_audio_cs,
+    struct whisper_stable_ts_filter_ctx * filter_ctx) {
+
+    if (!filter_ctx) return false;
+
+    filter_ctx->timestamp_silence_mask.clear();
+    filter_ctx->seek_cs = 0;
+    filter_ctx->token_step_cs = 2;
+    filter_ctx->wrapped_callback = params.logits_filter_callback;
+    filter_ctx->wrapped_user_data = params.logits_filter_callback_user_data;
+
+    if (silence_regions_cs.empty() || total_audio_cs <= 0) return false;
+
+    const int64_t step = filter_ctx->token_step_cs;
+    const int64_t n_bins = total_audio_cs / step + 2;
+    filter_ctx->timestamp_silence_mask.assign((size_t)n_bins, 0);
+
+    for (const auto & r : silence_regions_cs) {
+        const int64_t bin_start = r.first / step;
+        const int64_t bin_end = (r.second + step - 1) / step;
+        for (int64_t b = bin_start; b < bin_end && b < n_bins; ++b) {
+            filter_ctx->timestamp_silence_mask[(size_t)b] = 1;
+        }
+    }
+
+    params.logits_filter_callback = whisper_stable_logits_filter_callback;
+    params.logits_filter_callback_user_data = filter_ctx;
+    return true;
+}
+
+void whisper_stable_set_filter_seek(void * user_data, int64_t seek_cs) {
+    if (!user_data) return;
+    reinterpret_cast<whisper_stable_ts_filter_ctx *>(user_data)->seek_cs = seek_cs;
+}
+
+void whisper_stable_logits_filter_callback(
+    struct whisper_context * ctx,
+    struct whisper_state * /*state*/,
+    const whisper_token_data * /*tokens*/,
+    int /*n_tokens*/,
+    float * logits,
+    void * user_data) {
+
+    auto * stable = reinterpret_cast<whisper_stable_ts_filter_ctx *>(user_data);
+    if (!stable || !ctx || !logits) return;
+
+    if (stable->wrapped_callback) {
+        stable->wrapped_callback(ctx, nullptr, nullptr, 0, logits, stable->wrapped_user_data);
+    }
+
+    if (stable->timestamp_silence_mask.empty() || stable->token_step_cs <= 0) return;
+
+    const int token_beg = whisper_token_beg(ctx);
+    const int n_vocab = whisper_n_vocab(ctx);
+    if (token_beg < 0 || token_beg >= n_vocab) return;
+
+    for (int id = token_beg; id < n_vocab; ++id) {
+        const int64_t rel_idx = id - token_beg;
+        const int64_t abs_cs = stable->seek_cs + rel_idx * stable->token_step_cs;
+        if (abs_cs < 0) continue;
+
+        const int64_t bin = abs_cs / stable->token_step_cs;
+        if (bin < 0 || (size_t)bin >= stable->timestamp_silence_mask.size()) continue;
+
+        if (stable->timestamp_silence_mask[(size_t)bin]) {
+            logits[id] = -INFINITY;
+        }
+    }
+}
diff --git a/src/whisper-stable.h b/src/whisper-stable.h
new file mode 100644
index 00000000000..0558e7cc1c5
--- /dev/null
+++ b/src/whisper-stable.h
@@ -0,0 +1,103 @@
+#pragma once
+
+#include "whisper.h"
+#include "whisper-state.h"
+
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+// Build silence regions from VAD probabilities.
+// Returns vector of (start_cs, end_cs) pairs in centiseconds.
+std::vector<std::pair<int64_t, int64_t>> whisper_stable_build_silence_map(
+    const float * vad_probs,
+    int n_probs,
+    int n_window,
+    int sample_rate,
+    float threshold,
+    int64_t min_silence_dur_cs);
+
+// Build silence regions from raw PCM energy — mirrors stable-ts wav2mask.
+// Uses 320-sample (20ms) token resolution, avg-pool smoothing, and quantization.
+// No VAD model required.
+std::vector<std::pair<int64_t, int64_t>> whisper_stable_build_silence_map_from_pcm(
+    const float * pcm,
+    int n_samples,
+    int sample_rate,
+    int64_t min_silence_dur_cs);
+
+// Snap word timestamps away from silence regions (in-place).
+// Implements the stable-ts boundary-moving algorithm:
+//   - start in silence → move start to silence_end
+//   - end in silence → move end to silence_start
+//   - silence in word → snap the boundary with less overshoot
+// min_word_dur_cs: minimum word duration in centiseconds after snapping
+// min_snap_silence_dur_cs: ignore silence regions shorter than this
+void whisper_stable_snap_timestamps(
+    int64_t * words_t0,
+    int64_t * words_t1,
+    int n_words,
+    const int * seg_first_word,
+    const int * seg_word_count,
+    int n_segments,
+    const std::vector<std::pair<int64_t, int64_t>> & silence,
+    int64_t min_word_dur_cs,
+    int64_t min_snap_silence_dur_cs);
+
+// Apply word-level snapping to all segments and update segment boundaries.
+// Token timestamps must already be in the original audio timeline (offset applied
+// by the per-segment VAD decode loop before calling this).
+void whisper_stable_snap_segments(
+    struct whisper_context * ctx,
+    std::vector<whisper_segment> & result_all,
+    const std::vector<std::pair<int64_t, int64_t>> & silence_regions_cs,
+    int64_t min_word_dur_cs,
+    int64_t min_snap_silence_dur_cs);
+
+// Tokenize the DTW gap prefix (" ...") used for gap padding.
+std::vector<whisper_token> whisper_stable_get_gap_tokens(struct whisper_context * ctx);
+
+// Select top-k monotonic heads in-place on attention weight data, zero out the rest.
+// data layout: [n_heads][n_audio][n_tokens] (token-fast contiguous)
+void whisper_stable_select_heads(
+    float * data,
+    int n_tokens,
+    int n_audio,
+    int n_heads,
+    int top_k);
+
+// Compute overlap in centiseconds between [t0, t1) and silence regions.
+int64_t whisper_stable_silence_overlap_len(
+    int64_t t0,
+    int64_t t1,
+    const std::vector<std::pair<int64_t, int64_t>> & silence_regions_cs);
+
+// Context for constrained timestamp decoding filter.
+struct whisper_stable_ts_filter_ctx {
+    std::vector<uint8_t> timestamp_silence_mask;
+    int64_t seek_cs = 0;
+    int64_t token_step_cs = 2; // 20ms per timestamp token
+    whisper_logits_filter_callback wrapped_callback = nullptr;
+    void * wrapped_user_data = nullptr;
+};
+
+// Install constrained decoding filter that suppresses timestamp tokens in silence.
+// Only effective when processed timeline == original timeline (no VAD stripping).
+// Returns true if filter was installed.
+bool whisper_stable_setup_filter(
+    struct whisper_full_params & params,
+    const std::vector<std::pair<int64_t, int64_t>> & silence_regions_cs,
+    int64_t total_audio_cs,
+    struct whisper_stable_ts_filter_ctx * filter_ctx);
+
+// Update the current decode window seek position (centiseconds) for the filter.
+void whisper_stable_set_filter_seek(void * user_data, int64_t seek_cs);
+
+// Logits filter callback that suppresses timestamp tokens mapped to silence bins.
+void whisper_stable_logits_filter_callback(
+    struct whisper_context * ctx,
+    struct whisper_state * state,
+    const whisper_token_data * tokens,
+    int n_tokens,
+    float * logits,
+    void * user_data);
diff --git a/src/whisper-state.h b/src/whisper-state.h
new file mode 100644
index 00000000000..c0f47848662
--- /dev/null
+++ b/src/whisper-state.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include "whisper.h"
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+// Internal segment representation used inside whisper.cpp implementation.
+// Kept in a private header so helper units (e.g. whisper-stable.cpp) can work
+// with result segments without exposing this type in the public API.
+struct whisper_segment {
+    int64_t t0;
+    int64_t t1;
+
+    std::string text;
+    float no_speech_prob;
+
+    std::vector<whisper_token_data> tokens;
+
+    bool speaker_turn_next;
+};
diff --git a/src/whisper.cpp b/src/whisper.cpp
index 796bccfb45d..7abdeb82eb4 100644
--- a/src/whisper.cpp
+++ b/src/whisper.cpp
@@ -1,5 +1,7 @@
 #include "whisper.h"
 #include "whisper-arch.h"
+#include "whisper-state.h"
+#include "whisper-stable.h"

 #include "ggml.h"
 #include "ggml-cpp.h"
@@ -457,18 +459,6 @@ struct whisper_vocab {
     }
 };

-struct whisper_segment {
-    int64_t t0;
-    int64_t t1;
-
-    std::string text;
-    float no_speech_prob;
-
-    std::vector<whisper_token_data> tokens;
-
-    bool speaker_turn_next;
-};
-
 struct whisper_batch {
     int32_t n_tokens;
@@ -826,11 +816,6 @@ struct whisper_aheads_masks {
    ggml_backend_buffer_t buffer = nullptr;
 };

-struct vad_time_mapping {
-    int64_t processed_time; // Time in processed (VAD) audio
-    int64_t original_time;  // Corresponding time in original audio
-};
-
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
@@ -921,17 +906,6 @@ struct whisper_state {
     int32_t exp_n_audio_ctx = 0; // 0 - use default

     whisper_vad_context * vad_context = nullptr;
-
-    struct vad_segment_info {
-        int64_t orig_start;
-        int64_t orig_end;
-        int64_t vad_start;
-        int64_t vad_end;
-    };
-    std::vector<vad_segment_info> vad_segments;
-    bool has_vad_segments = false;
-
-    std::vector<vad_time_mapping> vad_mapping_table;
 };

 struct whisper_context {
@@ -5985,6 +5959,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.i_start_rule     =*/ 0,
         /*.grammar_penalty  =*/ 100.0f,

+        /*.stable_timestamps =*/ false,
+
         /*.vad              =*/ false,
         /*.vad_model_path   =*/ nullptr,
@@ -6618,188 +6594,92 @@ static void whisper_sequence_score(
     }
 }

-static bool whisper_vad(
+// Decode each VAD speech segment independently and accumulate results into
+// state->result_all with timestamps already offset to the original audio timeline.
+// This avoids cross-silence segments and removes the need for a mapping table.
+static int whisper_full_vad_segments(
         struct whisper_context * ctx,
         struct whisper_state * state,
         struct whisper_full_params params,
         const float * samples,
-        int n_samples,
-        std::vector<float> & filtered_samples) {
-    WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
-    int filtered_n_samples = 0;
-
-    // Clear any existing mapping table
-    state->vad_mapping_table.clear();
-    state->has_vad_segments = false;
-
+        int n_samples) {
+    // Initialize VAD context if needed
     if (state->vad_context == nullptr) {
         struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
         struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
         if (vctx == nullptr) {
             WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
-            return false;
+            return -1;
         }
         state->vad_context = vctx;
     }
-    auto vctx = state->vad_context;
-    const whisper_vad_params & vad_params = params.vad_params;
+    whisper_vad_segments * vad_segs = whisper_vad_segments_from_samples(
+        state->vad_context, params.vad_params, samples, n_samples);
-    whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
+    if (!vad_segs) {
+        WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
+        return -1;
+    }
-    if (!vad_segments) {
-        return false;
+    if (vad_segs->data.empty()) {
+        whisper_vad_free_segments(vad_segs);
+        state->result_all.clear();
+        return 0;
     }
-    if (vad_segments->data.size() > 0) {
-        state->has_vad_segments = true;
-        ctx->state->vad_segments.clear();
-        ctx->state->vad_segments.reserve(vad_segments->data.size());
+    WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segs->data.size());
-        // Initialize the time mapping table
-        state->vad_mapping_table.clear();
-        state->vad_mapping_table.reserve(vad_segments->data.size() * 4);
+    const int overlap_samples = (int)(params.vad_params.samples_overlap * WHISPER_SAMPLE_RATE);
-        WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
-        float overlap_seconds = vad_params.samples_overlap;
-        int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
+    // Per-segment params: always start at offset 0, no duration limit
+    auto seg_params = params;
+    seg_params.offset_ms = 0;
+    seg_params.duration_ms = 0;
-        for (int i = 0; i < (int)vad_segments->data.size(); i++) {
-            int segment_start_samples = cs_to_samples(vad_segments->data[i].start);
-            int segment_end_samples = cs_to_samples(vad_segments->data[i].end);
+    std::vector<whisper_segment> all_results;
-            if (i < (int)vad_segments->data.size() - 1) {
-                segment_end_samples += overlap_samples;
-            }
-            segment_end_samples = std::min(segment_end_samples, n_samples - 1);
-            filtered_n_samples += (segment_end_samples - segment_start_samples);
+    for (int i = 0; i < (int)vad_segs->data.size(); i++) {
+        int seg_start = cs_to_samples(vad_segs->data[i].start);
+        int seg_end = cs_to_samples(vad_segs->data[i].end);
-            WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
-                    __func__, i, vad_segments->data[i].start/100.0,
-                    (vad_segments->data[i].end/100.0 + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0)),
-                    (vad_segments->data[i].end - vad_segments->data[i].start)/100.0 +
-                    (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
+        // Add overlap to all but last segment (gives context for boundary words)
+        if (i < (int)vad_segs->data.size() - 1) {
+            seg_end += overlap_samples;
         }
+        seg_start = std::max(0, std::min(seg_start, n_samples - 1));
+        seg_end = std::min(seg_end, n_samples);
+        const int seg_len = seg_end - seg_start;
+        if (seg_len <= 0) continue;
-        int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
-        int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
-        int total_samples_needed = filtered_n_samples + total_silence_samples;
+        WHISPER_LOG_INFO("%s: decoding segment %d: %.2f - %.2f s\n",
+            __func__, i, vad_segs->data[i].start / 100.0, vad_segs->data[i].end / 100.0);
-        WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
-                __func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
-
-        try {
-            filtered_samples.resize(total_samples_needed);
-        } catch (const std::bad_alloc & /* e */) {
-            WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
-            whisper_vad_free_segments(vad_segments);
-            return false;
+        const int ret = whisper_full_with_state(ctx, state, seg_params, samples + seg_start, seg_len);
+        if (ret != 0) {
+            whisper_vad_free_segments(vad_segs);
+            return ret;
         }
-        int offset = 0;
-        for (int i = 0; i < (int)vad_segments->data.size(); i++) {
-            int segment_start_samples = cs_to_samples(vad_segments->data[i].start);
-            int segment_end_samples = cs_to_samples(vad_segments->data[i].end);
-
-            if (i < (int)vad_segments->data.size() - 1) {
-                segment_end_samples += overlap_samples;
-            }
-
-            segment_start_samples = std::min(segment_start_samples, n_samples - 1);
-            segment_end_samples = std::min(segment_end_samples, n_samples - 1);
-            int segment_length = segment_end_samples - segment_start_samples;
-            if (segment_length > 0) {
-                whisper_state::vad_segment_info segment;
-
-                segment.orig_start = vad_segments->data[i].start;
-                segment.orig_end = vad_segments->data[i].end;
-
-                segment.vad_start = samples_to_cs(offset);
-                segment.vad_end = samples_to_cs(offset + segment_length);
-
-                // Add segment boundaries to mapping table
-                vad_time_mapping start_mapping = {segment.vad_start, segment.orig_start};
-                vad_time_mapping end_mapping = {segment.vad_end, segment.orig_end};
-
-                state->vad_mapping_table.push_back(start_mapping);
-                state->vad_mapping_table.push_back(end_mapping);
-
-                // Add intermediate points for longer segments to improve interpolation accuracy
-                const int64_t min_segment_length = 100; // 1 second
-                const int64_t point_interval = 20; // Add a point every 200ms
-
-                if (segment.vad_end - segment.vad_start > min_segment_length) {
-                    int64_t segment_duration = segment.vad_end - segment.vad_start;
-                    int num_points = (int)(segment_duration / point_interval) - 1;
-
-                    for (int j = 1; j <= num_points; j++) {
-                        int64_t vad_time = segment.vad_start + j * point_interval;
-
-                        if (vad_time >= segment.vad_end) continue;
-
-                        int64_t vad_elapsed = vad_time - segment.vad_start;
-                        int64_t vad_total = segment.vad_end - segment.vad_start;
-                        int64_t orig_total = segment.orig_end - segment.orig_start;
-                        int64_t orig_time = segment.orig_start + (vad_elapsed * orig_total) / vad_total;
-
-                        vad_time_mapping intermediate_mapping = {vad_time, orig_time};
-                        state->vad_mapping_table.push_back(intermediate_mapping);
-                    }
-                }
-
-                WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
-                        __func__, segment.orig_start/100.0, segment.orig_end/100.0, segment.vad_start/100.0, segment.vad_end/100.0);
-                ctx->state->vad_segments.push_back(segment);
-
-                // Copy this speech segment
-                memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
-                offset += segment_length;
-
-                // Add silence after this segment (except after the last segment)
-                if (i < (int)vad_segments->data.size() - 1) {
-                    // Calculate the start and end time of the silence gap in processed audio
-                    int64_t silence_start_vad = samples_to_cs(offset);
-                    int64_t silence_end_vad = samples_to_cs(offset + silence_samples);
-                    // Calculate the corresponding original times
-                    int64_t orig_silence_start = segment.orig_end;
-                    int64_t orig_silence_end = vad_segments->data[i+1].start;
-
-                    // Add mapping points for silence boundaries
-                    state->vad_mapping_table.push_back({silence_start_vad, orig_silence_start});
-                    state->vad_mapping_table.push_back({silence_end_vad, orig_silence_end});
-
-                    // Fill with zeros (silence)
-                    memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
-                    offset += silence_samples;
-                }
+        // Shift all timestamps to original timeline (segment starts at orig_start in original audio)
+        const int64_t offset_cs = vad_segs->data[i].start;
+        for (auto & seg : state->result_all) {
+            seg.t0 += offset_cs;
+            seg.t1 += offset_cs;
+            for (auto & tok : seg.tokens) {
+                if (tok.t0 >= 0) tok.t0 += offset_cs;
+                if (tok.t1 >= 0) tok.t1 += offset_cs;
             }
         }
-        // Sort the mapping table by processed time
-        std::sort(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
-            [](const vad_time_mapping& a, const vad_time_mapping& b) {
-                return a.processed_time < b.processed_time;
-            });
-
-        // Remove any duplicate processed times to ensure monotonicity which is
-        // needed for binary search and interpolation later.
-        if (!state->vad_mapping_table.empty()) {
-            auto last = std::unique(state->vad_mapping_table.begin(), state->vad_mapping_table.end(),
-                [](const vad_time_mapping& a, const vad_time_mapping& b) {
-                    return a.processed_time == b.processed_time;
-                });
-            state->vad_mapping_table.erase(last, state->vad_mapping_table.end());
-        }
-
-        WHISPER_LOG_INFO("%s: Created time mapping table with %d points\n", __func__, (int)state->vad_mapping_table.size());
-
-        filtered_n_samples = offset;
-        WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
-                __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
+        all_results.insert(all_results.end(),
+            std::make_move_iterator(state->result_all.begin()),
+            std::make_move_iterator(state->result_all.end()));
     }
-    whisper_vad_free_segments(vad_segments);
-    return true;
+    whisper_vad_free_segments(vad_segs);
+    state->result_all = std::move(all_results);
+    return 0;
 }

 int whisper_full_with_state(
@@ -7022,6 +6902,10 @@ int whisper_full_with_state(
             break;
         }

+        if (params.stable_timestamps && params.logits_filter_callback == whisper_stable_logits_filter_callback) {
+            whisper_stable_set_filter_seek(params.logits_filter_callback_user_data, seek);
+        }
+
         if (params.encoder_begin_callback) {
             if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
                 WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
@@ -7759,21 +7643,49 @@ int whisper_full(
         const float * samples,
         int n_samples) {
-    std::vector<float> vad_samples;
-    if (params.vad) {
-        WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
-        if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
-            WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
-            return -1;
-        }
-        if (vad_samples.empty()) {
-            ctx->state->result_all.clear();
-            return 0;
+    // stable_timestamps forces word-level timestamps and removes the 1s initial constraint
+    if (params.stable_timestamps) {
+        params.token_timestamps = true;
+        params.max_initial_ts = 0.0f;
+    }
+
+    // Build PCM-energy silence map from original audio. Install constrained decoding
+    // filter only when not using VAD (with per-segment VAD, cross-silence decoding can't happen).
+ std::vector> stable_silence; + whisper_stable_ts_filter_ctx stable_filter_ctx; + if (params.stable_timestamps) { + stable_silence = whisper_stable_build_silence_map_from_pcm( + samples, n_samples, WHISPER_SAMPLE_RATE, /*min_silence_dur_cs=*/10); + + if (!stable_silence.empty() && !params.vad) { + const int64_t total_cs = (int64_t)((double)n_samples * 100.0 / WHISPER_SAMPLE_RATE + 0.5); + whisper_stable_setup_filter(params, stable_silence, total_cs, &stable_filter_ctx); } - samples = vad_samples.data(); - n_samples = vad_samples.size(); } - return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples); + + auto * state = ctx->state; + int ret; + if (params.vad) { + ret = whisper_full_vad_segments(ctx, state, params, samples, n_samples); + } else { + ret = whisper_full_with_state(ctx, state, params, samples, n_samples); + } + if (ret != 0) { + return ret; + } + + // Post-process: snap word timestamps away from silence regions. + // Timestamps are already in original timeline (VAD offset applied per-segment above). 
+ if (params.stable_timestamps && !stable_silence.empty()) { + whisper_stable_snap_segments( + ctx, + state->result_all, + stable_silence, + /*min_word_dur_cs=*/5, + /*min_snap_silence_dur_cs=*/10); + } + + return 0; } int whisper_full_parallel( @@ -7787,19 +7699,11 @@ int whisper_full_parallel( return whisper_full(ctx, params, samples, n_samples); } - std::vector vad_samples; + // VAD uses per-segment sequential decoding; delegate to whisper_full if (params.vad) { - WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__); - if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) { - WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__); - return -1; - } - if (vad_samples.empty()) { - return 0; - } - samples = vad_samples.data(); - n_samples = vad_samples.size(); + return whisper_full(ctx, params, samples, n_samples); } + int ret = 0; // prepare separate states for each thread @@ -7922,84 +7826,13 @@ int whisper_full_lang_id(struct whisper_context * ctx) { return ctx->state->lang_id; } -static int64_t map_processed_to_original_time(int64_t processed_time, const std::vector & mapping_table) { - if (mapping_table.empty()) { - return processed_time; - } - - if (processed_time <= mapping_table.front().processed_time) { - return mapping_table.front().original_time; // Before first mapping point - } - - if (processed_time >= mapping_table.back().processed_time) { - return mapping_table.back().original_time; // After last mapping point - } - - // Binary search over the time map that finds the first entry that has a - // processed time greater than or equal to the current processed time. 
- auto upper = std::lower_bound(mapping_table.begin(), mapping_table.end(), processed_time, - [](const vad_time_mapping & entry, int64_t time) { - return entry.processed_time < time; - } - ); - - // If exact match found - if (upper->processed_time == processed_time) { - return upper->original_time; - } - - // Need to interpolate between two points - auto lower = upper - 1; - - int64_t processed_diff = upper->processed_time - lower->processed_time; - int64_t original_diff = upper->original_time - lower->original_time; - int64_t offset = processed_time - lower->processed_time; - - if (processed_diff == 0) { - return lower->original_time; - } - - // Perform linear interpolation - return lower->original_time + (offset * original_diff) / processed_diff; -} - -// Function to get the starting timestamp of a segment +// Timestamps are stored in original audio timeline (VAD offset applied at decode time). int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) { - // If VAD wasn't used, return the original timestamp - if (!state->has_vad_segments || state->vad_mapping_table.empty()) { - return state->result_all[i_segment].t0; - } - - // Get the processed timestamp - int64_t t0 = state->result_all[i_segment].t0; - - // Map to original time using the mapping table - return map_processed_to_original_time(t0, state->vad_mapping_table); + return state->result_all[i_segment].t0; } -// Function to get the ending timestamp of a segment int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) { - // If VAD wasn't used, return the original timestamp - if (!state->has_vad_segments || state->vad_mapping_table.empty()) { - return state->result_all[i_segment].t1; - } - - // Get the processed timestamp - int64_t t1 = state->result_all[i_segment].t1; - - // Map to original time using the mapping table - int64_t orig_t1 = map_processed_to_original_time(t1, state->vad_mapping_table); - - // Get the corresponding t0 for 
this segment - int64_t orig_t0 = whisper_full_get_segment_t0_from_state(state, i_segment); - - // Ensure minimum duration to prevent zero-length segments - const int64_t min_duration = 10; // 10ms minimum - if (orig_t1 - orig_t0 < min_duration) { - orig_t1 = orig_t0 + min_duration; - } - - return orig_t1; + return state->result_all[i_segment].t1; } @@ -8848,6 +8681,13 @@ static void whisper_exp_compute_token_level_timestamps_dtw( } const size_t sot_sequence_length = tokens.size(); tokens.push_back(whisper_token_not(ctx)); + + size_t gap_token_count = 0; + if (params.stable_timestamps) { + const auto gap_tokens = whisper_stable_get_gap_tokens(ctx); + gap_token_count = gap_tokens.size(); + tokens.insert(tokens.end(), gap_tokens.begin(), gap_tokens.end()); + } for (size_t i = i_segment; i < i_segment + n_segments; ++i) { auto & segment = state->result_all[i]; for (auto &t: segment.tokens) { @@ -8898,6 +8738,10 @@ static void whisper_exp_compute_token_level_timestamps_dtw( } } + if (params.stable_timestamps) { + whisper_stable_select_heads((float *) w->data, n_tokens, n_audio_tokens, n_heads, /*top_k=*/6); + } + // Normalize - in original OpenAI code, this is done over dim=-2. In this case, // we already permuted N_TOKENS dimension to columns on last loop, becase ggml_norm // operates over columns. Afterwards, permute to a shape that facilitates mean @@ -8935,6 +8779,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw( // Place timestamps on segments int32_t last_v = 0; + const int32_t first_text_alignment_row = 1 + (int32_t) gap_token_count; auto seg_i = state->result_all.begin() + i_segment; auto tok_i = seg_i->tokens.begin(); for (int i = 0; i < alignment->ne[1]; ++i) { @@ -8944,6 +8789,11 @@ static void whisper_exp_compute_token_level_timestamps_dtw( int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio last_v = v; + // Rows before first_text_alignment_row are [no_timestamps] and optional gap padding. 
+ if (v < first_text_alignment_row) { + continue; + } + // Skip non-text tokens while (!(tok_i->id < whisper_token_eot(ctx))) { ++tok_i;