diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 4d34471af5a1..8f8281416c2b 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -4228,7 +4228,6 @@ def do_tts( # Workaround for bug in Ja normalizer, Ja normalizer does not work well with spaces. if language == "ja": transcript = re.sub(r'\s+', '', transcript) - # Apply text normalization if requested normalized_text = ( self._get_normalized_text(transcript=transcript, language=language) if apply_TN else transcript @@ -4507,6 +4506,7 @@ def _check_eos_and_update_state( audio_codes_next: torch.Tensor, all_codes_next_argmax: torch.Tensor, chunk_end_dict: Dict[int, int], + chunk_end_frame_lens: Dict[int, int], finished_texts_counter: Dict[int, int], end_of_text: List[bool], eos_detection_method: 'EOSDetectionMethod', @@ -4518,9 +4518,11 @@ def _check_eos_and_update_state( Args: chunk_state: Mutable state object tracking history across chunks. - audio_codes_next: Sampled audio codes. Shape: (B, num_codebooks). + audio_codes_next: Sampled audio codes. Shape: (B, num_codebooks, frame_stacking_factor). + Always 3D; when frame stacking is disabled (frame_stacking_factor=1) the last dim is 1. all_codes_next_argmax: Argmax sampled codes for EOS detection. chunk_end_dict: Maps batch indices to chunk end timesteps. + chunk_end_frame_lens: Maps batch indices to frame-level length (for codes_to_audio); aligned with infer(). finished_texts_counter: Counter for near-end timesteps. end_of_text: Whether text has ended for each batch item. eos_detection_method: Method for detecting end-of-sequence. @@ -4537,6 +4539,8 @@ def _check_eos_and_update_state( # End of speech detected. Update the state. if end_frame_index != float('inf'): + frame_len = current_step * self.frame_stacking_factor + end_frame_index + chunk_end_frame_lens[item_idx] = frame_len if end_of_text[item_idx]: # Speech for entire multi-chunk text has ended. Update the state. chunk_state.end_indices[item_idx] = chunk_state.overall_idx @@ -4555,6 +4559,7 @@ def _check_eos_and_update_state( >= self.chunked_inference_config.forceful_chunk_end_threshold ): chunk_end_dict[item_idx] = current_step + chunk_end_frame_lens[item_idx] = (current_step + 1) * self.frame_stacking_factor logging.info(f"Forceful chunk end detected for item {item_idx} at local timestep {current_step}") def _should_terminate_loop( @@ -4951,8 +4956,13 @@ def generate_speech( finished_texts_counter={}, attn_prior=initial_attn_prior, ) + # Frame-level lengths for this chunk only: batch_idx -> number of codec frames to keep + # per item (used for predicted_codes_lens and trimming). Filled when EOS or chunk end + # is detected. + chunk_end_frame_lens: Dict[int, int] = {} - for idx in range(self.inference_parameters.max_decoder_steps): + max_steps = self.inference_parameters.max_decoder_steps // self.frame_stacking_factor + for idx in range(max_steps): if idx % 30 == 0: logging.info(f"Decoding timestep {idx}") @@ -5143,6 +5153,7 @@ def generate_speech( audio_codes_next, all_codes_next_argmax, state.chunk_end_dict, + chunk_end_frame_lens, state.finished_texts_counter, end_of_text, eos_detection_method, @@ -5161,15 +5172,16 @@ def generate_speech( chunk_state.overall_idx += 1 + # Concatenate the list of predictions along the time dimension. + # Note that when frame stacking is on, this also undoes the stacking. predicted_codes = torch.cat(state.all_predictions, dim=-1) # (B, C, F*T_steps) num_steps = len(state.all_predictions) + default_frame_len = num_steps * self.frame_stacking_factor predicted_codes_lens = torch.tensor( - [ - state.chunk_end_dict.get(item_idx, num_steps) * self.frame_stacking_factor - for item_idx in range(batch_size) - ], + [chunk_end_frame_lens.get(item_idx, default_frame_len) for item_idx in range(batch_size)], device=device, ) + predicted_codes = predicted_codes[:, :, : predicted_codes_lens.max()] return InferBatchOutput( predicted_audio=torch.empty(0, device=device),