
WhisperRT - Causal Whisper Streaming Model

WhisperRT Streaming is a fine-tuned version of OpenAI Whisper that handles causal data and performs real-time transcription.

arXiv Demo on Hugging Face

📄 Paper

For more details, see our paper.

🔧 Setup

We used Python 3.9.16, PyTorch 2.6.0, and PyTorch-Lightning 2.5.0 to train and test our models. Portions of this code are adapted from OpenAI's Whisper.

To set up the project environment using conda, follow these steps:

  1. Clone the repository

    git clone https://github.com/tomer9080/WhisperRT-Streaming
    cd WhisperRT-Streaming

💡 Make sure you have Miniconda or Anaconda installed before proceeding.

  2. Create the conda environment

    conda env create -f environment.yml
  3. Activate the environment

    conda activate whisper_rt
  4. Install the appropriate PyTorch version
    Depending on your hardware and CUDA version, install PyTorch by following the instructions at https://pytorch.org/get-started/locally.
    This project was tested with CUDA 12.4, but it should also work with compatible earlier or later versions. The following command installs the versions used while building this project:

    pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124

After installing all of the dependencies, you can try to run inference.

🤖 Available Models

We fine-tuned three sizes of Whisper; all of them support English-only transcription. A large-v2 model fine-tuned on multilingual data is also available. It supports English, French, Spanish, German, and Portuguese with a chunk size of 300 milliseconds.

  • RCS Models (Random Chunk Size): RCS denotes checkpoints trained using a mask with random chunk sizes ranging from 0.1 to 1.0 seconds. These models are optimized for transcription with any chunk size within that interval.

    Note: These models were initialized with a base training chunk size of 600ms.

| Size     | Chunk Size [msec]            | Multilingual |
|----------|------------------------------|--------------|
| base     | 40, 100, 200, 300, RCS       | N/A          |
| small    | 40, 100, 200, 300, 1000, RCS | N/A          |
| large-v2 | 40, 100, 200, 300, 1000, RCS | 300          |
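
The table can be read as a lookup: a fixed-chunk checkpoint exists only for some size and chunk-size combinations, while the RCS checkpoints cover any chunk size between 100 and 1000 msec. The sketch below encodes that reading. `choose_checkpoint` and its return labels are hypothetical names used for illustration; they are not part of `whisper_rt`.

```python
# Fixed chunk sizes (msec) per model size, copied from the table above.
FIXED_CHUNKS_MS = {
    "base": (40, 100, 200, 300),
    "small": (40, 100, 200, 300, 1000),
    "large-v2": (40, 100, 200, 300, 1000),
}
# RCS checkpoints were trained with chunk sizes between 0.1 and 1.0 seconds.
RCS_RANGE_MS = (100, 1000)


def choose_checkpoint(size: str, chunk_ms: int) -> str:
    """Return "fixed" or "rcs" for a requested size and chunk size (hypothetical helper)."""
    if size not in FIXED_CHUNKS_MS:
        raise ValueError(f"unknown model size: {size!r}")
    if chunk_ms in FIXED_CHUNKS_MS[size]:
        return "fixed"
    low, high = RCS_RANGE_MS
    if low <= chunk_ms <= high:
        return "rcs"
    raise ValueError(f"no {size} checkpoint covers a {chunk_ms} msec chunk")


print(choose_checkpoint("small", 300))  # fixed
print(choose_checkpoint("base", 240))   # rcs
```

A fixed-chunk checkpoint is preferred here whenever one exists; that ordering is a design choice of this sketch, not a recommendation from the authors.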

🎤 Running Inference

To run inference, download the repository and run the commands from the repository root, as described in the following sections.

Note: The models are hosted on the Hugging Face Hub, which requires an access token.
Make sure you are logged in with your token to access the models.

How to Apply Your Hugging Face 🤗 Access Token

  1. Create a Hugging Face account (if you don’t have one) at https://huggingface.co/join.

  2. Generate an access token:

    • Go to your Hugging Face account settings: https://huggingface.co/settings/tokens
    • Click on "New token", give it a name, select the appropriate scopes (usually read is enough), and create it.
  3. Log in using the Hugging Face CLI:
    Install the CLI if you don’t have it:

    pip install huggingface_hub

    Then log in:

    huggingface-cli login

    Paste your token when prompted.
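
For non-interactive environments (a remote job or a container, for example), the same login can be done from Python. The sketch below assumes the token is stored in the `HF_TOKEN` environment variable; `read_hf_token` and `login_from_env` are hypothetical helper names, while `huggingface_hub.login` is the library's own function.

```python
import os


def read_hf_token(env_var: str = "HF_TOKEN") -> str:
    """Return the token stored in env_var, or raise if it is missing."""
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(f"set {env_var} to your Hugging Face access token")
    return token


def login_from_env() -> None:
    """Log in to the Hugging Face Hub without an interactive prompt."""
    # Imported lazily so read_hf_token works even without huggingface_hub installed.
    from huggingface_hub import login

    login(token=read_hf_token())
```

Call `login_from_env()` once before loading a model. Keeping the token in an environment variable avoids hard-coding it in scripts.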

🖥️ CLI Usage

Start streaming transcription with the following command:

# Using a local microphone for streaming transcription, dumping the recording to out.wav
python transcribe.py \
--output_filename out.wav \
--channels 2 \
--model small \
--chunk_size 300 \
--device cuda \
--beam_size 5 \
--ca_kv_cache

A simulation of a stream on a wav file is also available:

# Simulating a stream on a wav file
python transcribe.py \
--model small \
--chunk_size 300 \
--device cuda \
--beam_size 5 \
--ca_kv_cache \
--wav_file /path/to/audio.wav \
--simulate_stream \
--use_latency
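
To give a feel for what simulating a stream on a wav file involves, the sketch below slices a WAV file into consecutive 300 msec blocks using only the standard library. It is an illustration under assumptions, not the repository's implementation: how `--simulate_stream` works internally is not described here, and `iter_wav_chunks` is a hypothetical helper.

```python
import io
import wave
from typing import BinaryIO, Iterator, Union


def iter_wav_chunks(source: Union[str, BinaryIO], chunk_ms: int = 300) -> Iterator[bytes]:
    """Yield consecutive chunk_ms-long blocks of raw PCM bytes from a WAV file."""
    with wave.open(source, "rb") as wav:
        frames_per_chunk = wav.getframerate() * chunk_ms // 1000
        if frames_per_chunk <= 0:
            raise ValueError("chunk_ms is too small for this sample rate")
        while True:
            block = wav.readframes(frames_per_chunk)
            if not block:
                break
            yield block  # the last block may be shorter than chunk_ms


# Demo: one second of 16 kHz, 16-bit mono silence held in memory.
buffer = io.BytesIO()
with wave.open(buffer, "wb") as out:
    out.setnchannels(1)
    out.setsampwidth(2)
    out.setframerate(16000)
    out.writeframes(b"\x00\x00" * 16000)
buffer.seek(0)

chunks = list(iter_wav_chunks(buffer, chunk_ms=300))
print(len(chunks), len(chunks[0]), len(chunks[-1]))  # 4 9600 3200
```

A streaming model would consume these blocks one at a time, in order, without access to later audio.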

🐍 Python Usage

If you prefer using Python, a code snippet that uses a microphone or a wav file is provided below:

import torch
import whisper_rt

model_size = "small" # model size
chunk_size = 300 # chunk size in milliseconds
multilingual = False # currently only large-v2 with a 300 msec chunk size supports languages other than English.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loading a fixed chunk size model
model = whisper_rt.load_streaming_model(name=model_size,
                                        gran=chunk_size,
                                        multilingual=multilingual,
                                        device=device)


# using a local microphone recording 
texts_microphone = model.transcribe(output_filename="/path/to/dump/file.wav",
                                    channels=2,
                                    beam_size=5,
                                    ca_kv_cache=True)

# Simulating on a wav file
texts_wav_simulation = model.transcribe(simulate_stream=True,
                                        wav_file="/path/to/file/you/want/to/transcribe.wav",
                                        beam_size=5,
                                        ca_kv_cache=True)

# Loading an RCS model; the gran argument is not needed
model_rcs = whisper_rt.load_streaming_model(name=model_size,
                                            varying_chunk_size=True,
                                            multilingual=multilingual,
                                            device=device)

# Simulating on a wav file using an RCS model.
# Note: the ms_gran and extra_initial_blocks arguments must be specified when using an RCS model!
texts_wav_simulation_rcs = model_rcs.transcribe(simulate_stream=True,
                                                wav_file="/path/to/file/you/want/to/transcribe.wav",
                                                beam_size=5,
                                                ca_kv_cache=True,
                                                ms_gran=240,
                                                extra_initial_blocks=2)

🦾 Training

In order to train using LoRA, you can use our existing code. Make sure all the requirements are installed.

📂 Dataset Structure

Before starting model training using the command-line interface provided below, you must first configure your dataset dictionary file located at training_code/ds_dict.py.

This file defines a Python dictionary named ds_paths, where you should specify paths to the train, val, and test partitions of your dataset. Each partition should be a CSV file with the following three columns:

  1. wav_path — Path to the WAV audio file.
  2. tg_path — Path to the corresponding .TextGrid file containing forced alignment.
  3. raw_text — Ground truth transcription.

Note: The dictionary key (i.e., the name of the dataset) will be used by the training script to identify and load the dataset correctly.

You can find an example entry in training_code/ds_dict.py.
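
Before training, it can help to check that each partition CSV has the three required columns. The sketch below writes and validates such a file. It assumes the CSV starts with a header row naming the columns exactly `wav_path`, `tg_path`, and `raw_text`; the helper names and the `/data/...` paths are placeholders for illustration.

```python
import csv
import os
import tempfile

REQUIRED_COLUMNS = ("wav_path", "tg_path", "raw_text")


def write_partition_csv(csv_path, rows):
    """Write (wav_path, tg_path, raw_text) rows to a partition CSV with a header."""
    with open(csv_path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(REQUIRED_COLUMNS)
        writer.writerows(rows)


def check_partition_csv(csv_path):
    """Return the number of data rows; raise if a required column is missing or empty."""
    with open(csv_path, newline="", encoding="utf-8") as handle:
        reader = csv.DictReader(handle)
        missing = [c for c in REQUIRED_COLUMNS if c not in (reader.fieldnames or [])]
        if missing:
            raise ValueError(f"{csv_path}: missing columns {missing}")
        count = 0
        # Records are numbered from 2 because the header is record 1.
        for record_no, row in enumerate(reader, start=2):
            empty = [c for c in REQUIRED_COLUMNS if not (row[c] or "").strip()]
            if empty:
                raise ValueError(f"{csv_path}: record {record_no}: empty fields {empty}")
            count += 1
    return count


# Demo with placeholder paths in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    train_csv = os.path.join(tmp, "train.csv")
    write_partition_csv(train_csv, [
        ("/data/a.wav", "/data/a.TextGrid", "hello world"),
        ("/data/b.wav", "/data/b.TextGrid", "good morning"),
    ])
    n_rows = check_partition_csv(train_csv)
print(n_rows)  # 2
```

The check only looks at column presence and non-empty fields; it does not verify that the audio or TextGrid files exist.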

Note: We used Montreal Forced Aligner (MFA) to force-align our dataset.

To run the same force-alignment process as described in the paper, use:

mfa align --clean /dataset/root/path english_us_arpa english_us_arpa /aligned_dataset/root/path

For more details on the mfa command, see the MFA documentation.

🖥️ CLI Interface

Below is an example of training a model of size base, using a fixed chunk size.

python training_code/train.py \
--lora \
--streaming_train \
--simulate_stream \
--dataset LIBRI-960-ALIGNED \
--name example_training_base_model \
--size base \
--batch_size 32 \
--epochs 10 \
--learning_rate 1e-5 \
--rank 32 \
--gran 15 \
--extra_gran_blocks 1 \
--streaming_fraction 0.25 \
--top_k 5
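
The command above passes `--gran 15`, whereas the model table lists chunk sizes in milliseconds, and the unit of `--gran` is not stated here. One plausible reading, which is an assumption and not confirmed by this README, is that it counts Whisper encoder frames. The Whisper encoder emits one frame per 20 msec of audio, so 15 frames would correspond to a 300 msec chunk. Under that assumption, the conversion looks like this (both helpers are hypothetical):

```python
ENCODER_FRAME_MS = 20  # Whisper's encoder emits one frame per 20 msec of audio


def chunk_ms_to_frames(chunk_ms: int) -> int:
    """Convert a chunk size in msec to encoder frames (assumed unit of --gran)."""
    if chunk_ms <= 0 or chunk_ms % ENCODER_FRAME_MS:
        raise ValueError(f"chunk size must be a positive multiple of {ENCODER_FRAME_MS} msec")
    return chunk_ms // ENCODER_FRAME_MS


def frames_to_chunk_ms(frames: int) -> int:
    """Convert a number of encoder frames back to msec."""
    return frames * ENCODER_FRAME_MS


print(frames_to_chunk_ms(15))   # 300
print(chunk_ms_to_frames(40))   # 2
```

Check `python training_code/train.py --help` for the authoritative meaning of `--gran` before relying on this conversion.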

An example of training a large-v2 model with a random chunk size mask:

python training_code/train.py \
--lora \
--streaming_train \
--simulate_stream \
--dataset LIBRI-960-ALIGNED \
--name training-name \
--size large-v2 \
--batch_size 4 \
--rank 4 \
--learning_rate 1e-5 \
--epochs 3 \
--random_masking \
--num_slices 30

For more options and training configurations, run:

python training_code/train.py --help

📜 License

This repository uses a dual license:

MIT License
Portions derived from OpenAI Whisper are licensed under the MIT License.

CC BY-NC 4.0 License
All other original code in this repository is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).

See the LICENSE file for full details.
