Skip to content
Merged
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
74e0204
Unify longform with standard inference
subhankar-ghosh Feb 2, 2026
7d8e746
Unify longform with standard inference - small fixes
subhankar-ghosh Feb 4, 2026
e920a7f
Hi and Ja
subhankar-ghosh Feb 4, 2026
6c0549c
Rename longform with chunk
subhankar-ghosh Feb 9, 2026
03a487f
Apply isort and black reformatting
subhankar-ghosh Feb 9, 2026
9c40f43
Added Long and short test cases, review comments
subhankar-ghosh Feb 12, 2026
f8a76a9
merge conflict
subhankar-ghosh Feb 12, 2026
e7966fc
merge conflicts
subhankar-ghosh Feb 12, 2026
de64eae
Apply isort and black reformatting
subhankar-ghosh Feb 12, 2026
1768f6b
Potential fix for code scanning alert no. 16979: Unused import
subhankar-ghosh Feb 12, 2026
02572fb
Potential fix for code scanning alert no. 16980: Unused import
subhankar-ghosh Feb 12, 2026
5f13ea2
Fix unit tests
subhankar-ghosh Feb 12, 2026
7aa2c36
Merge branch 'magpietts_longform_unify' of github.com:NVIDIA-NeMo/NeM…
subhankar-ghosh Feb 12, 2026
6106c18
Change ssim_target
subhankar-ghosh Feb 17, 2026
50ff865
Merge branch 'main' into magpietts_longform_unify
blisc Feb 18, 2026
f51c961
Reset kv cache after a batch.
subhankar-ghosh Feb 20, 2026
648083d
Merge branch 'magpietts_longform_unify' of github.com:NVIDIA-NeMo/NeM…
subhankar-ghosh Feb 20, 2026
23da4ab
Fix tests with latest checkpoint, torch load weight_only false
subhankar-ghosh Feb 28, 2026
aa6ac16
Apply isort and black reformatting
subhankar-ghosh Feb 28, 2026
4a42fad
review comments
subhankar-ghosh Feb 28, 2026
f872a4a
Merge branch 'magpietts_longform_unify' of github.com:NVIDIA-NeMo/NeM…
subhankar-ghosh Feb 28, 2026
4d919f0
Merge branch 'main' into magpietts_longform_unify
subhankar-ghosh Mar 2, 2026
e433d76
Fix Framestacking test command.
subhankar-ghosh Mar 2, 2026
55b5b5e
Change checkpoint in magpie tests, review comments
subhankar-ghosh Mar 3, 2026
b3a234f
Frame stacking MagpieTTS generate_speech method
subhankar-ghosh Mar 6, 2026
b673bca
Typo fix in tests
subhankar-ghosh Mar 6, 2026
88da860
Framestacking fix
subhankar-ghosh Mar 9, 2026
ebea092
Refactor audio processing to include frame lengths for Framestacking
subhankar-ghosh Mar 9, 2026
98acf28
Adding back Fix Japanese transcript normalization issue
subhankar-ghosh Mar 9, 2026
4385e71
Apply isort and black reformatting
subhankar-ghosh Mar 9, 2026
2660a10
Fix review comments
subhankar-ghosh Mar 11, 2026
398900b
Merge conflicts
subhankar-ghosh Mar 11, 2026
2bc657f
Apply isort and black reformatting
subhankar-ghosh Mar 11, 2026
88de65b
Merge branch 'main' into magpietts_longform_unify
subhankar-ghosh Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions nemo/collections/tts/models/magpietts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4228,7 +4228,6 @@ def do_tts(
# Workaround for bug in Ja normalizer, Ja normalizer does not work well with spaces.
if language == "ja":
transcript = re.sub(r'\s+', '', transcript)

# Apply text normalization if requested
normalized_text = (
self._get_normalized_text(transcript=transcript, language=language) if apply_TN else transcript
Expand Down Expand Up @@ -4507,6 +4506,7 @@ def _check_eos_and_update_state(
audio_codes_next: torch.Tensor,
all_codes_next_argmax: torch.Tensor,
chunk_end_dict: Dict[int, int],
chunk_end_frame_lens: Dict[int, int],
finished_texts_counter: Dict[int, int],
end_of_text: List[bool],
eos_detection_method: 'EOSDetectionMethod',
Expand All @@ -4518,9 +4518,11 @@ def _check_eos_and_update_state(

Args:
chunk_state: Mutable state object tracking history across chunks.
audio_codes_next: Sampled audio codes. Shape: (B, num_codebooks).
audio_codes_next: Sampled audio codes. Shape: (B, num_codebooks, frame_stacking_factor).
Always 3D; when frame stacking is disabled (frame_stacking_factor=1) the last dim is 1.
all_codes_next_argmax: Argmax sampled codes for EOS detection.
chunk_end_dict: Maps batch indices to chunk end timesteps.
chunk_end_frame_lens: Maps batch indices to frame-level length (for codes_to_audio); aligned with infer().
finished_texts_counter: Counter for near-end timesteps.
end_of_text: Whether text has ended for each batch item.
eos_detection_method: Method for detecting end-of-sequence.
Expand All @@ -4537,6 +4539,8 @@ def _check_eos_and_update_state(

# End of speech detected. Update the state.
if end_frame_index != float('inf'):
frame_len = current_step * self.frame_stacking_factor + end_frame_index
chunk_end_frame_lens[item_idx] = frame_len
if end_of_text[item_idx]:
# Speech for entire multi-chunk text has ended. Update the state.
chunk_state.end_indices[item_idx] = chunk_state.overall_idx
Expand All @@ -4555,6 +4559,7 @@ def _check_eos_and_update_state(
>= self.chunked_inference_config.forceful_chunk_end_threshold
):
chunk_end_dict[item_idx] = current_step
chunk_end_frame_lens[item_idx] = (current_step + 1) * self.frame_stacking_factor
logging.info(f"Forceful chunk end detected for item {item_idx} at local timestep {current_step}")

def _should_terminate_loop(
Expand Down Expand Up @@ -4951,8 +4956,13 @@ def generate_speech(
finished_texts_counter={},
attn_prior=initial_attn_prior,
)
# Frame-level lengths for this chunk only: batch_idx -> number of codec frames to keep
# per item (used for predicted_codes_lens and trimming). Filled when EOS or chunk end
# is detected.
chunk_end_frame_lens: Dict[int, int] = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment saying what this tracks and if it persists across chunked calls to generate_speech()? Because it appears that this one keeps state locally, unlike chunk_state which is persistent between calls but not super clear from the naming.

Side note, maybe we could find a better name than chunk_state since that structure appears not to be associated with a particular chunk but rather tracks overall inference state (I think). E.g. could call it inference_state or chunked_inference_state (the latter is admittedly kind of verbose).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added. It does not maintain state across generate_speech calls. It is local only.


for idx in range(self.inference_parameters.max_decoder_steps):
max_steps = self.inference_parameters.max_decoder_steps // self.frame_stacking_factor
for idx in range(max_steps):
if idx % 30 == 0:
logging.info(f"Decoding timestep {idx}")

Expand Down Expand Up @@ -5143,6 +5153,7 @@ def generate_speech(
audio_codes_next,
all_codes_next_argmax,
state.chunk_end_dict,
chunk_end_frame_lens,
state.finished_texts_counter,
end_of_text,
eos_detection_method,
Expand All @@ -5161,15 +5172,16 @@ def generate_speech(

chunk_state.overall_idx += 1

# Concatenate the list of predictions along the time dimension.
# Note that when frame stacking is on, this also undoes the stacking.
predicted_codes = torch.cat(state.all_predictions, dim=-1) # (B, C, F*T_steps)
num_steps = len(state.all_predictions)
default_frame_len = num_steps * self.frame_stacking_factor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add back the comment that was here originally, I think it got lost:
# Concatenate the list of predictions along the time dimension. Note that when frame stacking is on, this also undoes the stacking.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

predicted_codes_lens = torch.tensor(
[
state.chunk_end_dict.get(item_idx, num_steps) * self.frame_stacking_factor
for item_idx in range(batch_size)
],
[chunk_end_frame_lens.get(item_idx, default_frame_len) for item_idx in range(batch_size)],
device=device,
)
predicted_codes = predicted_codes[:, :, : predicted_codes_lens.max()]

return InferBatchOutput(
predicted_audio=torch.empty(0, device=device),
Expand Down
Loading