
Commit dbb022c

Merge branch 'magpietts_eou_quality' of github.com:rfejgin/NeMo into magpietts_eou_quality
2 parents: 0fda552 + 84917c6


45 files changed (+4801, -1370 lines)

docs/source/speechlm2/intro.rst

Lines changed: 105 additions & 10 deletions
@@ -4,16 +4,21 @@ SpeechLM2

 .. note::

    The SpeechLM2 collection is still in active development and the code is likely to keep changing.

 SpeechLM2 refers to a collection that augments pre-trained Large Language Models (LLMs) with speech understanding and generation capabilities.

 This collection is designed to be compact, efficient, and to support easy swapping of different LLMs backed by HuggingFace AutoModel.
 It has first-class support for dynamic batch sizes via Lhotse and for model parallelism techniques (e.g., FSDP2, Tensor Parallel, Sequence Parallel) via the PyTorch DTensor API.

-We currently support four main model types:
-
-* SALM (Speech-Augmented Language Model) - a simple but effective approach to augmenting pre-trained LLMs with speech understanding capabilities.
-* DuplexS2SModel - a full-duplex speech-to-speech model with an ASR encoder, directly predicting discrete audio codes.
-* DuplexS2SSpeechDecoderModel - a variant of DuplexS2SModel with a separate transformer decoder for speech generation.
-* DuplexSTTModel - a decoder model that generates agent text in duplex, in response to both user speech and text inputs.
+We currently support six main model types:
+
+* **SALM** (Speech-Augmented Language Model) - a simple but effective approach to augmenting pre-trained LLMs with speech understanding capabilities.
+* **DuplexS2SModel** - a full-duplex speech-to-speech model with an ASR encoder, directly predicting discrete audio codes.
+* **DuplexS2SSpeechDecoderModel** - a variant of DuplexS2SModel with a separate transformer decoder for speech generation.
+* **DuplexEARTTS** - a ready-to-use duplex text-to-speech model that supports user interruption via a special text interruption token.
+* **DuplexSTTModel** - a decoder model that generates agent text in duplex, in response to both user speech and text inputs.
+* **NemotronVoiceChat** - an *inference-only* pipeline that chains `DuplexSTTModel` and `DuplexEARTTS` to deliver an end-to-end, full-duplex conversational agent with high-fidelity speech generation.

 Using Pretrained Models
 -----------------------
@@ -148,10 +153,100 @@ You can run inference using the loaded pretrained DuplexSTTModel:
     transcription = results["text"][0]
     print(f"Transcription: {transcription}")

+DuplexEARTTS
+************
+
+Because `DuplexEARTTS` relies on precise token padding and EOS placement to handle potential user interruptions, inference and evaluation are handled via the `duplex_eartts_eval.py` script, following the MagpieTTS dataset format recipe.
+
+The evaluation script processes a `JSONL` file where each line is a dictionary containing the text, the reference audio for the speaker, and the desired output audio filename.
+
+**JSONL Format Examples:**
+
+Single-turn format (evaluates a continuous string):
+
+.. code-block:: json
+
+   {"text": "Like really quickly and then they run off.", "context_audio_filepath": "speaker_1.wav", "audio_filepath": "audio_1.wav"}
+
+Multi-turn format (evaluates sequential conversational turns, padded incrementally):
+
+.. code-block:: json
+
+   {"text": ["Yes.", "Sure.", "Right.", "I get what you're saying."], "context_audio_filepath": "speaker_2.wav", "audio_filepath": "audio_2.wav"}
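A manifest in either format can be sanity-checked before running the script. The helper below is an illustrative sketch (not part of NeMo) that validates the required fields from the examples above and normalizes single-turn entries to the multi-turn shape:

```python
import json

def check_manifest_line(line: str) -> dict:
    """Parse one JSONL manifest line and normalize `text` to a list of turns."""
    entry = json.loads(line)
    for key in ("text", "context_audio_filepath", "audio_filepath"):
        if key not in entry:
            raise ValueError(f"missing required field: {key}")
    # Single-turn entries use a plain string; multi-turn entries use a list.
    if isinstance(entry["text"], str):
        entry["text"] = [entry["text"]]
    return entry

single = '{"text": "Hi there.", "context_audio_filepath": "spk.wav", "audio_filepath": "out.wav"}'
multi = '{"text": ["Yes.", "Sure."], "context_audio_filepath": "spk.wav", "audio_filepath": "out.wav"}'
print(check_manifest_line(single)["text"])  # ['Hi there.']
print(check_manifest_line(multi)["text"])   # ['Yes.', 'Sure.']
```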
+**Running the Evaluation/Inference Script:**
+
+.. code-block:: bash
+
+   python examples/speechlm2/duplex_eartts_eval.py \
+       --config-path=conf/ \
+       --config-name=duplex_eartts.yaml \
+       ++checkpoint_path=/path/to/duplex_eartts/model.ckpt \
+       ++datasets_json_path=/path/to/evalset_config.jsonl \
+       ++out_dir=/path/to/output/audio_samples/ \
+       ++user_custom_speaker_reference=/path/to/optional_override_speaker.wav
+
+The script will decode the text, apply the target speaker conditioning, generate the resulting audio waveforms into `out_dir`, and compute ASR intelligibility metrics (CER/WER) on the generated speech.
+
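For reference, the CER reported here is the character-level Levenshtein edit distance between the ASR transcript of the generated audio and the target text, normalized by the reference length. A minimal self-contained sketch of that metric (not the NeMo implementation):

```python
def edit_distance(ref: str, hyp: str) -> int:
    """Levenshtein distance via dynamic programming over two rows."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (r != h)))  # substitution
        prev = curr
    return prev[-1]

def cer(ref: str, hyp: str) -> float:
    """Character error rate: edits normalized by reference length."""
    return edit_distance(ref, hyp) / max(len(ref), 1)

print(cer("hello there", "hallo there"))  # 1 substitution over 11 chars, ~0.0909
```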
+NemotronVoiceChat
+*****************
+
+You can evaluate and run full-duplex inference using the `NemotronVoiceChat` pipeline. This model natively chains the `DuplexSTTModel` with the `DuplexEARTTS` speech decoder for an end-to-end response:
+
+.. code-block:: python
+
+   import torch
+   import torchaudio
+   import nemo.collections.speechlm2 as slm
+
+   model = slm.models.NemotronVoiceChat.from_pretrained("path/to/pretrained_checkpoint").eval()
+
+   # Load the user audio prompt
+   audio_path = "path/to/user_audio.wav"
+   audio_signal, sample_rate = torchaudio.load(audio_path)
+
+   # Resample to the source sample rate (usually 16 kHz for STT perception)
+   if sample_rate != 16000:
+       audio_signal = torchaudio.functional.resample(audio_signal, sample_rate, 16000)
+       sample_rate = 16000
+
+   # Prepare the audio for the model
+   audio_signal = audio_signal.to(model.device)
+   audio_len = torch.tensor([audio_signal.shape[1]], device=model.device)
+
+   # (Optional) Load an explicit speaker reference audio to condition the agent's voice
+   # speaker_audio, _ = torchaudio.load("path/to/speaker_reference.wav")
+   # speaker_audio = speaker_audio.to(model.device)
+   # speaker_len = torch.tensor([speaker_audio.shape[1]], device=model.device)
+
+   # Note: if an explicit audio reference is not passed into `offline_inference`,
+   # the model relies on the internal config parameters:
+   #   1. model.cfg.inference_speaker_name (highest-priority preset, e.g., 'Megan')
+   #   2. model.cfg.inference_speaker_reference (fallback audio file path)
+
+   # Run full offline inference
+   results = model.offline_inference(
+       input_signal=audio_signal,
+       input_signal_lens=audio_len,
+       # speaker_audio=speaker_audio,    # pass the speaker reference if available
+       # speaker_audio_lens=speaker_len,
+   )
+
+   # Decode the predicted text and the generated speech waveform
+   generated_text = results["text"][0]
+   generated_speech = results["audio"][0]
+
+   print(f"Agent response: {generated_text}")
+   # generated_speech can now be saved or played (sampled at model.target_sample_rate)
+
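The speaker-conditioning priority documented above reduces to a simple fallback chain: explicit audio passed to inference, then the preset name, then the reference file path. The resolver below is a hypothetical illustration of that documented order, not actual NeMo code:

```python
def resolve_speaker(explicit_audio=None, speaker_name=None, speaker_reference=None):
    """Illustrative sketch of the documented priority order:
    explicit audio passed to inference > inference_speaker_name preset
    > inference_speaker_reference file path."""
    if explicit_audio is not None:
        return ("audio_tensor", explicit_audio)
    if speaker_name is not None:
        return ("preset", speaker_name)
    if speaker_reference is not None:
        return ("reference_file", speaker_reference)
    raise ValueError("no speaker conditioning available")

print(resolve_speaker(speaker_name="Megan", speaker_reference="ref.wav"))
# ('preset', 'Megan')
```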
 Training a Model
 ----------------

-This example demonstrates how to train a SALM model. The remaining models can be trained in a similar manner.
+This example demonstrates how to train a SALM model.
+
+.. note::
+
+   **NemotronVoiceChat is an inference-only class.** It does not implement a `training_step` and cannot be trained using the pipeline below. To update its underlying capabilities, train the `DuplexSTTModel` and `DuplexEARTTS` models independently.

 .. code-block:: python
@@ -207,7 +302,7 @@ Alternatively, you can train a model using the provided training scripts in the
         --config-path=examples/speechlm2/conf \
         --config-name=salm

-    # For inference/evaluation
+    # For SALM inference/evaluation
     python examples/speechlm2/salm_eval.py \
         pretrained_name=/path/to/checkpoint \
         inputs=/path/to/test_manifest \
@@ -222,9 +317,9 @@ Collection Structure

 The speechlm2 collection is organized into the following key components:

-- **Models**: Contains implementations of DuplexS2SModel, DuplexS2SSpeechDecoderModel, DuplexSTTModel, and SALM
-- **Modules**: Contains audio perception and speech generation modules
-- **Data**: Includes dataset classes and data loading utilities
+- **Models**: Contains implementations of DuplexS2SModel, DuplexS2SSpeechDecoderModel, DuplexSTTModel, SALM, DuplexEARTTS, and the inference-only NemotronVoiceChat.
+- **Modules**: Contains audio perception and speech generation modules.
+- **Data**: Includes dataset classes and data loading utilities.

 SpeechLM2 Documentation
 -----------------------

docs/source/speechlm2/models.rst

Lines changed: 21 additions & 0 deletions
@@ -99,6 +99,24 @@ This model is particularly useful for:
 * Duplex systems where text responses are needed instead of speech
 * Applications requiring transcript generation from spoken dialogue

+NemotronVoiceChat
+^^^^^^^^^^^^^^^^^
+
+NemotronVoiceChat is an **inference-only**, end-to-end duplex speech-to-speech pipeline. It achieves full-duplex conversational capabilities by chaining the `DuplexSTTModel` with the `DuplexEARTTS` model.
+
+Because it is designed exclusively for evaluation, offline inference, and validation workflows (no training step is implemented), it is optimized for executing the full perception-generation-synthesis loop.
+
+Key components:
+
+* **DuplexSTTModel**: Handles streaming audio perception and text response generation.
+* **DuplexEARTTS**: Serves as the autoregressive speech decoder, generating high-fidelity audio from the STT model's text tokens in a streamable fashion.
+
+This model is particularly useful for:
+
+* End-to-end evaluation of the complete speech-to-speech pipeline.
+* Offline speech-to-speech inference workflows.
+
 Model Components
 ----------------
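The perception-generation-synthesis loop that NemotronVoiceChat executes can be pictured as a thin wrapper chaining the two models. The sketch below uses hypothetical stand-in callables (`VoiceChatSketch` and the stub functions are illustrative, not the real NeMo classes):

```python
class VoiceChatSketch:
    """Illustrative only: chains an STT-role model with a TTS-role decoder."""

    def __init__(self, stt_fn, tts_fn):
        self.stt_fn = stt_fn  # audio -> agent text (DuplexSTTModel role)
        self.tts_fn = tts_fn  # agent text -> waveform (DuplexEARTTS role)

    def offline_inference(self, user_audio):
        text = self.stt_fn(user_audio)  # perception + text generation
        audio = self.tts_fn(text)       # speech synthesis
        return {"text": [text], "audio": [audio]}

# Stub usage with toy functions standing in for the two models
pipe = VoiceChatSketch(stt_fn=lambda a: "hello!", tts_fn=lambda t: [0.0] * len(t))
out = pipe.offline_inference(user_audio=[0.1, 0.2])
print(out["text"][0])  # hello!
```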

@@ -247,6 +265,9 @@ All models in the speechlm2 collection can be instantiated from pretrained checkpoints:

     # Load DuplexEARTTS
     ear_tts_model = slm.models.DuplexEARTTS.from_pretrained("path/to/checkpoint")

+    # Load NemotronVoiceChat (inference only)
+    voicechat_model = slm.models.NemotronVoiceChat.from_pretrained("path/to/checkpoint")
+
 Model Configuration
 -------------------

docs/source/tts/magpietts-longform.rst

Lines changed: 8 additions & 8 deletions
@@ -68,7 +68,7 @@ The input text is split into individual sentences using punctuation markers (``.

 Step 2: State Initialization
 ----------------------------

-A ``LongformChunkState`` object is created to track information across sentence chunks:
+A ``ChunkState`` object is created to track information across sentence chunks:

 - **History text tokens**: Text from previous chunks for context
 - **History encoder context**: Encoder outputs that provide continuity
@@ -112,7 +112,7 @@ Key Components

 1. **Sentence Splitting** (``split_by_sentence``): Intelligently splits text on sentence boundaries while handling abbreviations (e.g., "Dr.", "Mr.").

-2. **Chunk State** (``LongformChunkState``): Maintains context across chunks:
+2. **Chunk State** (``ChunkState``): Maintains context across chunks:

    - ``history_text``: Text tokens from previous chunks
    - ``history_context_tensor``: Encoder outputs for continuity
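The abbreviation handling described for ``split_by_sentence`` can be approximated in a few lines. This is a simplified illustration with an assumed abbreviation set, not the actual NeMo implementation:

```python
import re

ABBREVIATIONS = {"Dr.", "Mr.", "Mrs.", "Ms."}  # assumed set for illustration

def split_sentences(text: str) -> list[str]:
    """Split on ., !, ? followed by whitespace, rejoining known abbreviations."""
    rough = re.split(r"(?<=[.!?])\s+", text.strip())
    sentences = []
    for piece in rough:
        # If the previous piece ended with a known abbreviation, glue them back.
        if sentences and sentences[-1].split()[-1] in ABBREVIATIONS:
            sentences[-1] += " " + piece
        else:
            sentences.append(piece)
    return sentences

print(split_sentences("Dr. Smith arrived. He sat down."))
# ['Dr. Smith arrived.', 'He sat down.']
```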
@@ -211,24 +211,24 @@ Configuration Dataclasses
 #########################

-``LongformConfig``
-------------------
+``ChunkedInferenceConfig``
+--------------------------

 Immutable tuning parameters (set in model):

 .. literalinclude:: ../../../nemo/collections/tts/models/magpietts.py
    :language: python
-   :pyobject: LongformConfig
+   :pyobject: ChunkedInferenceConfig

-``LongformChunkState``
-----------------------
+``ChunkState``
+--------------

 Mutable state passed between chunk iterations:

 .. literalinclude:: ../../../nemo/collections/tts/models/magpietts.py
    :language: python
-   :pyobject: LongformChunkState
+   :pyobject: ChunkState
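For intuition, the mutable chunk state can be pictured as a small dataclass with the fields documented above. The real ``ChunkState`` in ``magpietts.py`` holds tensors, so this is only a sketch:

```python
from dataclasses import dataclass, field

@dataclass
class ChunkStateSketch:
    """Illustrative mutable state carried between chunk iterations."""
    history_text: list = field(default_factory=list)  # text tokens from previous chunks
    history_context_tensor: object = None             # encoder outputs for continuity

    def update(self, new_tokens, new_context):
        """Accumulate text history and replace the encoder context."""
        self.history_text.extend(new_tokens)
        self.history_context_tensor = new_context

state = ChunkStateSketch()
state.update(["Hello", "world."], new_context="enc_out_0")
print(state.history_text)  # ['Hello', 'world.']
```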

Best Practices

docs/source/tts/magpietts-po.rst

Lines changed: 2 additions & 2 deletions
@@ -96,8 +96,8 @@ The final step is fine-tuning the base model on the preference pairs using the DPO
     max_epochs=10 \
     exp_manager.exp_dir=/path/to/dpo_experiment \
     exp_manager.checkpoint_callback_params.always_save_nemo=false \
-    model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \
-    model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \
+    model.train_ds.datasets._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \
+    model.validation_ds.datasets._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \
     +train_ds_meta.dpopreftrain.manifest_path="/path/to/manifests/" \
     +train_ds_meta.dpopreftrain.audio_dir="/" \
     +train_ds_meta.dpopreftrain.feature_dir="/" \
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+checkpoint_path: null  # Path to the pre-trained NemotronVoiceChat checkpoint for evaluation
+model:
+  scoring_asr: stt_en_fastconformer_transducer_large  # ASR model used to transcribe generated audio for ASR-BLEU computation
+  inference_speaker_reference: null  # Path to an audio file used to clone/condition the TTS voice. Set to null when using a preset name below.
+  inference_speaker_name: Megan  # Preset speaker identifier. If provided, this overrides `inference_speaker_reference`.
+
+  stt:
+    model:
+      # evaluation params
+      eval_text_turn_taking: true  # Enables evaluation of turn-taking and text prediction accuracy in the Duplex STT model
+
+  speech_generation:
+    model:
+      # inference params for the Duplex EAR-TTS module
+      inference_guidance_scale: 0.2  # Classifier-Free Guidance (CFG) scale for conditioning the audio generation
+      inference_noise_scale: 0.001  # Sampling temperature/noise for the MoG
+      inference_top_p_or_k: 0.95  # Nucleus sampling (top-p) or top-k threshold for token selection
+      inference_guidance_enabled: true  # Toggle to enable/disable Classifier-Free Guidance
+      inference_force_speech_silence_on_eos: true  # Forces the model to output silence tokens once the End-Of-Sequence (EOS) token is generated
+
+trainer:
+  devices: -1  # Number of GPUs to use (-1 uses all available)
+  accelerator: gpu  # Hardware accelerator type
+  num_nodes: 1  # Number of compute nodes
+  precision: 32  # Numerical precision for inference (32-bit full precision)
+  logger: false  # Disabled here because NeMo's `exp_manager` handles logging
+  limit_val_batches: 1.0  # Fraction of the validation dataset to use (1.0 = the entire dataset)
+  log_every_n_steps: 20  # Frequency of logging metrics to the console/wandb
+  use_distributed_sampler: false  # Disable the distributed sampler
+  strategy:
+    _target_: lightning.pytorch.strategies.DDPStrategy  # Distributed Data Parallel strategy for multi-GPU inference
+    gradient_as_bucket_view: true  # Memory optimization for DDP
+    find_unused_parameters: true  # Required if parts of the model (like text-only branches) don't receive gradients
+
+data:
+  frame_length: 0.08  # Duration of a single audio frame in seconds (80 ms)
+  source_sample_rate: 16000  # Sample rate of the input/user audio prompts (16 kHz)
+  target_sample_rate: 22050  # Sample rate of the generated output speech (22.05 kHz)
+  input_roles: ["user", "User"]  # Conversation roles mapped to the input prompt
+  output_roles: ["agent", "Assistant", "assistant", "Agent"]  # Conversation roles the model is tasked with generating
+
+validation_ds:
+  datasets:
+    evaluation_set:
+      shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/duplex/ultrachat_v2/shar_duplex/manifest_000020  # Path to the Lhotse shar tar-shard manifest
+
+  sample_rate: ${data.target_sample_rate}  # Audio will be resampled to this rate if necessary
+  batch_size: 4  # Number of samples processed per GPU during evaluation
+  seed: 42  # Random seed for reproducibility
+  shard_seed: "randomized"  # Ensures distributed workers get different data shards
+
+exp_manager:
+  explicit_log_dir: nemotron_voicechat_log_dir/  # Root directory for evaluation metrics, JSON logs, and generated audio
+  name: nemotron-voicechat-eval  # Name of the experiment
+  create_tensorboard_logger: false  # Toggle for TensorBoard logging
+  create_checkpoint_callback: false  # Toggle for the checkpoint callback (not needed for evaluation)
+  use_datetime_version: true  # Appends a timestamp to the log directory name
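The `${data.target_sample_rate}` entry uses OmegaConf-style interpolation, so the validation dataset inherits whatever output sample rate is configured. A minimal pure-Python sketch of that lookup (real NeMo configs resolve this via OmegaConf):

```python
import re

def resolve(cfg: dict, value):
    """Resolve a single '${a.b.c}' reference against a nested dict; pass
    through values that are not interpolation strings."""
    match = re.fullmatch(r"\$\{([\w.]+)\}", str(value))
    if not match:
        return value
    node = cfg
    for key in match.group(1).split("."):
        node = node[key]
    return node

cfg = {"data": {"target_sample_rate": 22050},
       "validation_ds": {"sample_rate": "${data.target_sample_rate}"}}
print(resolve(cfg, cfg["validation_ds"]["sample_rate"]))  # 22050
```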
