-
Notifications
You must be signed in to change notification settings - Fork 3.4k
[MagpieTTS] Magpietts longform unify #15477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
74e0204
7d8e746
e920a7f
6c0549c
03a487f
9c40f43
f8a76a9
e7966fc
de64eae
1768f6b
02572fb
5f13ea2
7aa2c36
6106c18
50ff865
f51c961
648083d
23da4ab
aa6ac16
4a42fad
f872a4a
4d919f0
e433d76
55b5b5e
b3a234f
b673bca
88da860
ebea092
98acf28
4385e71
2660a10
398900b
2bc657f
88de65b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4228,7 +4228,6 @@ def do_tts( | |
| # Workaround for bug in Ja normalizer, Ja normalizer does not work well with spaces. | ||
| if language == "ja": | ||
| transcript = re.sub(r'\s+', '', transcript) | ||
|
|
||
| # Apply text normalization if requested | ||
| normalized_text = ( | ||
| self._get_normalized_text(transcript=transcript, language=language) if apply_TN else transcript | ||
|
|
@@ -4507,6 +4506,7 @@ def _check_eos_and_update_state( | |
| audio_codes_next: torch.Tensor, | ||
| all_codes_next_argmax: torch.Tensor, | ||
| chunk_end_dict: Dict[int, int], | ||
| chunk_end_frame_lens: Dict[int, int], | ||
| finished_texts_counter: Dict[int, int], | ||
| end_of_text: List[bool], | ||
| eos_detection_method: 'EOSDetectionMethod', | ||
|
|
@@ -4518,9 +4518,11 @@ def _check_eos_and_update_state( | |
|
|
||
| Args: | ||
| chunk_state: Mutable state object tracking history across chunks. | ||
| audio_codes_next: Sampled audio codes. Shape: (B, num_codebooks). | ||
| audio_codes_next: Sampled audio codes. Shape: (B, num_codebooks, frame_stacking_factor). | ||
| Always 3D; when frame stacking is disabled (frame_stacking_factor=1) the last dim is 1. | ||
| all_codes_next_argmax: Argmax sampled codes for EOS detection. | ||
| chunk_end_dict: Maps batch indices to chunk end timesteps. | ||
| chunk_end_frame_lens: Maps batch indices to frame-level length (for codes_to_audio); aligned with infer(). | ||
| finished_texts_counter: Counter for near-end timesteps. | ||
| end_of_text: Whether text has ended for each batch item. | ||
| eos_detection_method: Method for detecting end-of-sequence. | ||
|
|
@@ -4537,6 +4539,8 @@ def _check_eos_and_update_state( | |
|
|
||
| # End of speech detected. Update the state. | ||
| if end_frame_index != float('inf'): | ||
| frame_len = current_step * self.frame_stacking_factor + end_frame_index | ||
| chunk_end_frame_lens[item_idx] = frame_len | ||
| if end_of_text[item_idx]: | ||
| # Speech for entire multi-chunk text has ended. Update the state. | ||
| chunk_state.end_indices[item_idx] = chunk_state.overall_idx | ||
|
|
@@ -4555,6 +4559,7 @@ def _check_eos_and_update_state( | |
| >= self.chunked_inference_config.forceful_chunk_end_threshold | ||
| ): | ||
| chunk_end_dict[item_idx] = current_step | ||
| chunk_end_frame_lens[item_idx] = (current_step + 1) * self.frame_stacking_factor | ||
| logging.info(f"Forceful chunk end detected for item {item_idx} at local timestep {current_step}") | ||
|
|
||
| def _should_terminate_loop( | ||
|
|
@@ -4951,8 +4956,13 @@ def generate_speech( | |
| finished_texts_counter={}, | ||
| attn_prior=initial_attn_prior, | ||
| ) | ||
| # Frame-level lengths for this chunk only: batch_idx -> number of codec frames to keep | ||
| # per item (used for predicted_codes_lens and trimming). Filled when EOS or chunk end | ||
| # is detected. | ||
| chunk_end_frame_lens: Dict[int, int] = {} | ||
|
|
||
| for idx in range(self.inference_parameters.max_decoder_steps): | ||
| max_steps = self.inference_parameters.max_decoder_steps // self.frame_stacking_factor | ||
| for idx in range(max_steps): | ||
| if idx % 30 == 0: | ||
| logging.info(f"Decoding timestep {idx}") | ||
|
|
||
|
|
@@ -5143,6 +5153,7 @@ def generate_speech( | |
| audio_codes_next, | ||
| all_codes_next_argmax, | ||
| state.chunk_end_dict, | ||
| chunk_end_frame_lens, | ||
| state.finished_texts_counter, | ||
| end_of_text, | ||
| eos_detection_method, | ||
|
|
@@ -5161,15 +5172,16 @@ def generate_speech( | |
|
|
||
| chunk_state.overall_idx += 1 | ||
|
|
||
| # Concatenate the list of predictions along the time dimension. | ||
| # Note that when frame stacking is on, this also undoes the stacking. | ||
| predicted_codes = torch.cat(state.all_predictions, dim=-1) # (B, C, F*T_steps) | ||
| num_steps = len(state.all_predictions) | ||
| default_frame_len = num_steps * self.frame_stacking_factor | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add back the comment that was here originally, I think it got lost:
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added |
||
| predicted_codes_lens = torch.tensor( | ||
| [ | ||
| state.chunk_end_dict.get(item_idx, num_steps) * self.frame_stacking_factor | ||
| for item_idx in range(batch_size) | ||
| ], | ||
| [chunk_end_frame_lens.get(item_idx, default_frame_len) for item_idx in range(batch_size)], | ||
| device=device, | ||
| ) | ||
| predicted_codes = predicted_codes[:, :, : predicted_codes_lens.max()] | ||
|
|
||
| return InferBatchOutput( | ||
| predicted_audio=torch.empty(0, device=device), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a comment saying what this tracks and if it persists across chunked calls to generate_speech()? Because it appears that this one keeps state locally, unlike chunk_state which is persistent between calls but not super clear from the naming.
Side note, maybe we could find a better name than `chunk_state`, since that structure appears not to be associated with a particular chunk but rather tracks overall inference state (I think). E.g. we could call it inference_state or chunked_inference_state (the latter is admittedly kind of verbose). There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added. It does not maintain state across generate_speech calls. It is local only.