From b6bdee3487e496e17227d28e3534942a38f37d82 Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Wed, 11 Mar 2026 20:49:41 -0700 Subject: [PATCH 1/2] clean up code, rename back to magpietts_inference.py Signed-off-by: Shehzeen Hussain --- docs/source/tts/magpietts-longform.rst | 6 +-- docs/source/tts/magpietts.rst | 2 +- .../{tts_infer.py => magpietts_inference.py} | 50 +++++++------------ nemo/collections/tts/models/easy_magpietts.py | 2 +- .../tts/models/easy_magpietts_inference.py | 12 +---- .../modules/magpietts_inference/inference.py | 2 - ...S_InferEvaluate_Magpietts_FrameStacking.sh | 2 +- ...TS_InferEvaluate_Magpietts_MoE_ZeroShot.sh | 2 +- ...TS_InferEvaluate_Magpietts_SeenSpeakers.sh | 2 +- ...L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh | 2 +- ...Evaluatelongform_Magpietts_MoE_ZeroShot.sh | 2 +- ...nferEvaluatelongform_Magpietts_ZeroShot.sh | 2 +- 12 files changed, 30 insertions(+), 56 deletions(-) rename examples/tts/{tts_infer.py => magpietts_inference.py} (95%) diff --git a/docs/source/tts/magpietts-longform.rst b/docs/source/tts/magpietts-longform.rst index fb3eeb659d33..33aef42a5abe 100644 --- a/docs/source/tts/magpietts-longform.rst +++ b/docs/source/tts/magpietts-longform.rst @@ -169,7 +169,7 @@ The ``do_tts`` method automatically detects whether longform inference is needed sf.write("output.wav", long_audio[0].cpu().numpy(), 22050) -Method 2: Using CLI (``tts_infer.py``) +Method 2: Using CLI (``magpietts_inference.py``) ------------------------------------------------ For batch inference from manifests: @@ -177,7 +177,7 @@ For batch inference from manifests: .. code-block:: bash # Auto-detect longform based on text length (default) - python examples/tts/tts_infer.py \ + python examples/tts/magpietts_inference.py \ --nemo_files /path/to/magpietts.nemo \ --datasets_json_path /path/to/evalset_config.json \ --out_dir /path/to/output \ @@ -185,7 +185,7 @@ For batch inference from manifests: --longform_mode auto # Force longform inference for all inputs - python examples/tts/tts_infer.py \ + python examples/tts/magpietts_inference.py \ --nemo_files /path/to/magpietts.nemo \ --datasets_json_path /path/to/evalset_config.json \ --out_dir /path/to/output \ diff --git a/docs/source/tts/magpietts.rst b/docs/source/tts/magpietts.rst index 6d297a694596..b79c11ea88ff 100644 --- a/docs/source/tts/magpietts.rst +++ b/docs/source/tts/magpietts.rst @@ -130,7 +130,7 @@ Several parameters control the generation behavior. The temperature setting affe .. code-block:: bash - python examples/tts/tts_infer.py \ + python examples/tts/magpietts_inference.py \ --nemo_files /path/to/magpietts_model.nemo \ --codecmodel_path /path/to/audio_codec.nemo \ --datasets your_evaluation_set \ diff --git a/examples/tts/tts_infer.py b/examples/tts/magpietts_inference.py similarity index 95% rename from examples/tts/tts_infer.py rename to examples/tts/magpietts_inference.py index 2c3bec0aa7f7..50333f2ab7e9 100644 --- a/examples/tts/tts_infer.py +++ b/examples/tts/magpietts_inference.py @@ -26,7 +26,7 @@ Example usage: # MagpieTTS inference (encoder-decoder, default) - python examples/tts/tts_infer.py \\ + python examples/tts/magpietts_inference.py \\ --model_type magpie \\ --nemo_files /path/to/model.nemo \\ --datasets_json_path /path/to/evalset_config.json \\ @@ -34,7 +34,7 @@ --codecmodel_path /path/to/codec.nemo # EasyMagpieTTS inference (decoder-only) - python examples/tts/tts_infer.py \\ + python examples/tts/magpietts_inference.py \\ --model_type easy_magpie \\ --nemo_files /path/to/model.nemo \\ --datasets_json_path /path/to/evalset_config.json \\ @@ -42,7 +42,7 @@ --codecmodel_path /path/to/codec.nemo # With evaluation - python examples/tts/tts_infer.py \\ + python examples/tts/magpietts_inference.py \\ --model_type magpie \\ --hparams_files /path/to/hparams.yaml \\ --checkpoint_files /path/to/model.ckpt \\ @@ -161,11 +161,6 @@ def filter_datasets(dataset_meta_info: dict, datasets: Optional[List[str]]) -> L return datasets -# --------------------------------------------------------------------------- -# Core inference + evaluation orchestration (model-type agnostic) -# --------------------------------------------------------------------------- - - def run_inference_and_evaluation( runner: BaseInferenceRunner, checkpoint_name: str, @@ -355,15 +350,18 @@ def run_inference_and_evaluation( return None, None -# --------------------------------------------------------------------------- -# CLI argument parser -# --------------------------------------------------------------------------- +def _get_shared_inference_param_names() -> set: + """Return the field names shared by ModelInferenceParameters and EasyModelInferenceParameters.""" + magpie_fields = {f.name for f in fields(ModelInferenceParameters)} + easy_fields = {f.name for f in fields(EasyModelInferenceParameters)} + return magpie_fields & easy_fields def _add_inference_param_fields( group: argparse._ArgumentGroup, param_cls: type, skip_fields: Optional[set] = None, + only_fields: Optional[set] = None, ) -> None: """Auto-generate argparse arguments from fields of a dataclass. @@ -371,12 +369,15 @@ def _add_inference_param_fields( group: The argparse argument group to add arguments to. param_cls: The dataclass whose fields to add. skip_fields: Field names to skip (already added by another group). + only_fields: If provided, only add fields whose names are in this set. """ if skip_fields is None: skip_fields = set() for f in fields(param_cls): if f.name in skip_fields: continue + if only_fields is not None and f.name not in only_fields: + continue extra_args: dict = {"type": f.type} if f.type == bool: extra_args = {"action": "store_true"} @@ -469,8 +470,9 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: infer_group.add_argument('--use_cfg', action='store_true', help='Enable classifier-free guidance') infer_group.add_argument('--use_local_transformer', action='store_true') - # Shared model inference parameters (max_decoder_steps, temperature, topk, cfg_scale) - _add_inference_param_fields(infer_group, EasyModelInferenceParameters) + # Model inference parameters shared by both MagpieTTS and EasyMagpieTTS + shared_param_names = _get_shared_inference_param_names() + _add_inference_param_fields(infer_group, ModelInferenceParameters, only_fields=shared_param_names) # Evaluation eval_group = parser.add_argument_group('Evaluation') @@ -499,9 +501,8 @@ def _add_magpie_args(parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group('MagpieTTS-specific Parameters') # MagpieTTS-specific model inference parameters (attention prior, EOS, etc.) - # Skip fields already added by the common inference group. - shared_field_names = {f.name for f in fields(EasyModelInferenceParameters)} - _add_inference_param_fields(group, ModelInferenceParameters, skip_fields=shared_field_names) + shared_param_names = _get_shared_inference_param_names() + _add_inference_param_fields(group, ModelInferenceParameters, skip_fields=shared_param_names) group.add_argument('--maskgit_n_steps', type=int, default=3) group.add_argument('--maskgit_noise_scale', type=float, default=0.0) @@ -531,11 +532,6 @@ def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None: help='Sampling method for phoneme prediction', ) group.add_argument('--dropout_text_input', action='store_true', help='Force dropout on text input') - group.add_argument( - '--legacy_context_stacking', - action='store_true', - help='Use audio_bos_id/audio_eos_id for context stacking', - ) def create_argument_parser() -> argparse.ArgumentParser: @@ -551,11 +547,6 @@ def create_argument_parser() -> argparse.ArgumentParser: return parser -# --------------------------------------------------------------------------- -# Config builders (one per model type) -# --------------------------------------------------------------------------- - - def _build_inference_params_from_args(param_cls: type, args): """Extract inference parameters from parsed CLI args for the given dataclass.""" params = {} @@ -592,15 +583,8 @@ def _build_easy_magpie_config(args) -> EasyMagpieInferenceConfig: phoneme_input_type=args.phoneme_input_type, phoneme_sampling_method=args.phoneme_sampling_method, dropout_text_input=args.dropout_text_input, - legacy_context_stacking=args.legacy_context_stacking, ) - -# --------------------------------------------------------------------------- -# Entry point -# --------------------------------------------------------------------------- - - def main(argv=None): """Entry point for TTS inference and evaluation.""" parser = create_argument_parser() diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 19705eed1ad3..5c20a09292ff 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -109,7 +109,7 @@ class EasyModelInferenceParameters: cfg_scale: Scale factor for classifier-free guidance. """ - max_decoder_steps: int = 500 + max_decoder_steps: int = 300 temperature: float = 0.7 topk: int = 80 cfg_scale: float = 2.5 diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 765c234e2683..5f9f47c22d04 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -734,19 +734,11 @@ def prepare_context_tensors( eos_id=self.context_audio_eos_id, ) - # Use legacy audio_bos_id/audio_eos_id if flag is set - stack_bos_id = ( - self.audio_bos_id if getattr(self, 'legacy_context_stacking', False) else self.context_audio_bos_id - ) - stack_eos_id = ( - self.audio_eos_id if getattr(self, 'legacy_context_stacking', False) else self.context_audio_eos_id - ) - context_audio_codes, context_audio_codes_lens = self.stack_codes( context_audio_codes, context_audio_codes_lens, - stack_bos_id, - stack_eos_id, + self.context_audio_bos_id, + self.context_audio_eos_id, self.frame_stacking_factor, self.num_audio_codebooks, ) diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index c343a9d31f9a..2f283b6c181d 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -129,7 +129,6 @@ class EasyMagpieInferenceConfig(BaseInferenceConfig): phoneme_input_type: str = "gt" phoneme_sampling_method: str = "argmax" dropout_text_input: bool = False - legacy_context_stacking: bool = False def build_identifier(self) -> str: parts = [ @@ -550,7 +549,6 @@ class EasyMagpieInferenceRunner(BaseInferenceRunner): def __init__(self, model, config: EasyMagpieInferenceConfig): super().__init__(model, config) - self.model.legacy_context_stacking = config.legacy_context_stacking def create_dataset( self, diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_FrameStacking.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_FrameStacking.sh index b6d87e91a254..368b5c83bba5 100644 --- a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_FrameStacking.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_FrameStacking.sh @@ -14,7 +14,7 @@ # Tests a 4x-stacked model with local transformer inference. -TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/tts_infer.py \ +TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ --codecmodel_path /home/TestData/tts/21fps_causal_codecmodel.nemo \ --datasets_json_path examples/tts/evalset_config.json \ --datasets an4_val_ci \ diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_MoE_ZeroShot.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_MoE_ZeroShot.sh index 4e917733f59a..a591497f22e0 100755 --- a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_MoE_ZeroShot.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_MoE_ZeroShot.sh @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/tts_infer.py \ +TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ --nemo_files "/home/TestData/tts/2602_MoE/moe16_sinkhorn_top1_valLoss5.0469_step2625132_epoch524.nemo" \ --codecmodel_path "/home/TestData/tts/21fps_causal_codecmodel.nemo" \ --datasets_json_path "examples/tts/evalset_config.json" \ diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh index 8eb30eb40c36..5ed8d48f5aff 100644 --- a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/tts_infer.py \ +TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ --codecmodel_path /home/TestData/tts/21fps_causal_codecmodel.nemo \ --datasets_json_path examples/tts/evalset_config.json \ --datasets an4_val_ci \ diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh index eed95fc5a64e..3a9415bbc2b3 100644 --- a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/tts_infer.py \ +TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ --codecmodel_path /home/TestData/tts/21fps_causal_codecmodel.nemo \ --datasets_json_path examples/tts/evalset_config.json \ --datasets an4_val_ci \ diff --git a/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_MoE_ZeroShot.sh b/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_MoE_ZeroShot.sh index c21454d39cb1..ec8b6b885212 100755 --- a/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_MoE_ZeroShot.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_MoE_ZeroShot.sh @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/tts_infer.py \ +TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ --nemo_files "/home/TestData/tts/2602_MoE/moe16_sinkhorn_top1_valLoss5.0469_step2625132_epoch524.nemo" \ --codecmodel_path "/home/TestData/tts/21fps_causal_codecmodel.nemo" \ --datasets_json_path "examples/tts/evalset_config.json" \ diff --git a/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot.sh b/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot.sh index 96e20304197a..a0694c16b9ba 100644 --- a/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot.sh @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/tts_infer.py \ +TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ --codecmodel_path /home/TestData/tts/21fps_causal_codecmodel.nemo \ --datasets_json_path examples/tts/evalset_config.json \ --datasets an4_val_ci_longform_tiny \ From dbcb02c1fd16e295c29d1e5b1990696410c3ae55 Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Wed, 11 Mar 2026 21:22:23 -0700 Subject: [PATCH 2/2] bug fixes, inference runs now Signed-off-by: Shehzeen Hussain --- examples/tts/magpietts_inference.py | 14 ++++++-- nemo/collections/tts/models/easy_magpietts.py | 25 +------------- .../tts/models/easy_magpietts_inference.py | 32 ++++++++++++++++- .../modules/magpietts_inference/inference.py | 12 ++++--- .../tts/modules/magpietts_inference/utils.py | 34 ++++++++++++++----- 5 files changed, 75 insertions(+), 42 deletions(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 50333f2ab7e9..fca92fccddc4 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -66,7 +66,7 @@ import numpy as np from nemo.collections.asr.parts.utils.manifest_utils import read_manifest -from nemo.collections.tts.models.easy_magpietts import EasyModelInferenceParameters +from nemo.collections.tts.models.easy_magpietts_inference import EasyModelInferenceParameters from nemo.collections.tts.models.magpietts import ModelInferenceParameters from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import load_evalset_config from nemo.collections.tts.modules.magpietts_inference.evaluation import ( @@ -400,7 +400,7 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: default='magpie', choices=['magpie', 'easy_magpie'], help='Model type: "magpie" for encoder-decoder MagpieTTSModel, ' - '"easy_magpie" for decoder-only EasyMagpieTTSModel', + '"easy_magpie" for decoder-only EasyMagpieTTSInferenceModel', ) # Model loading @@ -515,7 +515,7 @@ def _add_magpie_args(parser: argparse.ArgumentParser) -> None: def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None: - """Add arguments specific to decoder-only EasyMagpieTTSModel.""" + """Add arguments specific to decoder-only EasyMagpieTTSInferenceModel.""" group = parser.add_argument_group('EasyMagpieTTS-specific Parameters') group.add_argument( '--phoneme_input_type', @@ -532,6 +532,12 @@ def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None: help='Sampling method for phoneme prediction', ) group.add_argument('--dropout_text_input', action='store_true', help='Force dropout on text input') + group.add_argument( + '--phoneme_tokenizer_path', + type=str, + default=None, + help='Override path to the phoneme tokenizer file (overrides the path stored in the checkpoint config)', + ) def create_argument_parser() -> argparse.ArgumentParser: @@ -640,6 +646,7 @@ def main(argv=None): legacy_codebooks=args.legacy_codebooks, legacy_text_conditioning=args.legacy_text_conditioning, hparams_from_wandb=args.hparams_file_from_wandb, + phoneme_tokenizer_path=getattr(args, 'phoneme_tokenizer_path', None), ) model, checkpoint_name = load_fn(model_config) @@ -677,6 +684,7 @@ def main(argv=None): codecmodel_path=args.codecmodel_path, legacy_codebooks=args.legacy_codebooks, legacy_text_conditioning=args.legacy_text_conditioning, + phoneme_tokenizer_path=getattr(args, 'phoneme_tokenizer_path', None), ) model, checkpoint_name = load_fn(model_config) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 5c20a09292ff..5a117432b986 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -14,7 +14,7 @@ import json import os import random -from dataclasses import dataclass, fields +from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import numpy as np @@ -98,29 +98,6 @@ class ProcessBatchOutput: selected_training_mode: Optional[str] -@dataclass -class EasyModelInferenceParameters: - """Inference parameters for the decoder-only EasyMagpieTTS model. - - Attributes: - max_decoder_steps: Maximum number of decoder steps. - temperature: Sampling temperature. - topk: Number of top-probability tokens to consider in sampling. - cfg_scale: Scale factor for classifier-free guidance. - """ - - max_decoder_steps: int = 300 - temperature: float = 0.7 - topk: int = 80 - cfg_scale: float = 2.5 - - @classmethod - def from_dict(cls, data: dict) -> 'EasyModelInferenceParameters': - field_names = {field.name for field in fields(cls)} - filtered_data = {k: v for k, v in data.items() if k in field_names} - return cls(**filtered_data) - - class EasyMagpieTTSModel(EasyMagpieTTSInferenceModel): """ Magpie-TTS Model Decoder Only Model with training support. diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 5f9f47c22d04..59db7decda0e 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import time -from dataclasses import dataclass +from dataclasses import dataclass, fields from functools import partial from typing import Any, Dict, List, Optional, Sequence, Tuple @@ -184,6 +184,29 @@ class InferBatchOutput: phoneme_prediction_start_idx: Optional[torch.Tensor] = None # (B,) start index into predicted_phoneme_tokens +@dataclass +class EasyModelInferenceParameters: + """Inference parameters for the decoder-only EasyMagpieTTS model. + + Attributes: + max_decoder_steps: Maximum number of decoder steps. + temperature: Sampling temperature. + topk: Number of top-probability tokens to consider in sampling. + cfg_scale: Scale factor for classifier-free guidance. + """ + + max_decoder_steps: int = 300 + temperature: float = 0.7 + topk: int = 80 + cfg_scale: float = 2.5 + + @classmethod + def from_dict(cls, data: dict) -> 'EasyModelInferenceParameters': + field_names = {field.name for field in fields(cls)} + filtered_data = {k: v for k, v in data.items() if k in field_names} + return cls(**filtered_data) + + class EasyMagpieTTSInferenceModel(BaseMagpieTTSModel): """ Inference-only base class for EasyMagpieTTS decoder-only model. @@ -319,6 +342,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.pad_context_text_to_max_duration = False self.add_language_to_context_text = cfg.get('add_language_to_context_text', False) + self.ignore_phoneme_languages = cfg.get('ignore_phoneme_languages', []) super().__init__(cfg=cfg, trainer=trainer) @@ -465,6 +489,12 @@ def _get_state_dict_keys_to_exclude(self): '_codec_model', ] + def setup_training_data(self, train_data_config=None): + pass + + def setup_validation_data(self, val_data_config=None): + pass + def codes_to_audio(self, codes, codes_len): # codes: (B, C, T') self._codec_model.eval() diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 2f283b6c181d..ab501075c98d 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -21,7 +21,7 @@ MagpieInferenceRunner handles the encoder-decoder MagpieTTSModel (chunked text, generate_speech + codes_to_audio). -EasyMagpieInferenceRunner handles the decoder-only EasyMagpieTTSModel +EasyMagpieInferenceRunner handles the decoder-only EasyMagpieTTSInferenceModel (infer_batch, returns audio directly). """ from __future__ import annotations @@ -40,7 +40,7 @@ from nemo.collections.asr.parts.utils.manifest_utils import read_manifest from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPATokenizer from nemo.collections.tts.data.text_to_speech_dataset import ChunkedTTSInferenceDataset, MagpieTTSDataset -from nemo.collections.tts.models.easy_magpietts import EasyModelInferenceParameters +from nemo.collections.tts.models.easy_magpietts_inference import EasyModelInferenceParameters from nemo.collections.tts.models.magpietts import ModelInferenceParameters from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors from nemo.utils import logging @@ -123,7 +123,7 @@ def build_identifier(self) -> str: @dataclass class EasyMagpieInferenceConfig(BaseInferenceConfig): - """Configuration for decoder-only EasyMagpieTTSModel inference.""" + """Configuration for decoder-only EasyMagpieTTSInferenceModel inference.""" model_inference_parameters: EasyModelInferenceParameters = field(default_factory=EasyModelInferenceParameters) phoneme_input_type: str = "gt" @@ -537,12 +537,12 @@ def _compute_end_of_text_flags( # --------------------------------------------------------------------------- -# EasyMagpieInferenceRunner (decoder-only EasyMagpieTTSModel) +# EasyMagpieInferenceRunner (decoder-only EasyMagpieTTSInferenceModel) # --------------------------------------------------------------------------- class EasyMagpieInferenceRunner(BaseInferenceRunner): - """Runner for decoder-only EasyMagpieTTSModel. + """Runner for decoder-only EasyMagpieTTSInferenceModel. Uses MagpieTTSDataset and model.infer_batch() which returns audio directly. """ @@ -581,6 +581,8 @@ def create_dataset( pad_context_text_to_max_duration=False, context_duration_min=context_duration_min, context_duration_max=context_duration_max, + ignore_phoneme_languages=self.config.get('ignore_phoneme_languages', []), + add_language_to_context_text=self.model.add_language_to_context_text ) dataset.text_tokenizer = self.model.tokenizer diff --git a/nemo/collections/tts/modules/magpietts_inference/utils.py b/nemo/collections/tts/modules/magpietts_inference/utils.py index a14cd0789f7a..9c67125f4343 100644 --- a/nemo/collections/tts/modules/magpietts_inference/utils.py +++ b/nemo/collections/tts/modules/magpietts_inference/utils.py @@ -28,7 +28,7 @@ import torch from omegaconf import DictConfig, OmegaConf, open_dict -from nemo.collections.tts.models import EasyMagpieTTSModel, MagpieTTSModel +from nemo.collections.tts.models import EasyMagpieTTSInferenceModel, MagpieTTSModel from nemo.utils import logging @@ -119,6 +119,7 @@ class ModelLoadConfig: legacy_codebooks: Use legacy codebook indices for old checkpoints. legacy_text_conditioning: Use legacy text conditioning for old checkpoints. hparams_from_wandb: Whether hparams file is from wandb export. + phoneme_tokenizer_path: Override path to the phoneme tokenizer file (EasyMagpieTTS only). """ hparams_file: Optional[str] = None @@ -128,6 +129,7 @@ class ModelLoadConfig: legacy_codebooks: bool = False legacy_text_conditioning: bool = False hparams_from_wandb: bool = False + phoneme_tokenizer_path: Optional[str] = None def validate(self) -> None: """Validate that the configuration is complete and consistent.""" @@ -336,8 +338,13 @@ def load_magpie_model(config: ModelLoadConfig, device: str = "cuda") -> Tuple[Ma return model, checkpoint_name -def load_easy_magpie_model(config: ModelLoadConfig, device: str = "cuda") -> Tuple[EasyMagpieTTSModel, str]: - """Load an EasyMagpieTTSModel (decoder-only) from checkpoint or NeMo archive. +def load_easy_magpie_model( + config: ModelLoadConfig, device: str = "cuda" +) -> Tuple[EasyMagpieTTSInferenceModel, str]: + """Load an EasyMagpieTTSInferenceModel (decoder-only) from checkpoint or NeMo archive. + + Uses the inference-only base class rather than the full training model, + which avoids pulling in training-specific dependencies. Supports two loading modes: 1. Checkpoint mode: hparams.yaml + .ckpt file @@ -367,8 +374,10 @@ def load_easy_magpie_model(config: ModelLoadConfig, device: str = "cuda") -> Tup model_cfg.codecmodel_path = config.codecmodel_path model_cfg.train_ds = None model_cfg.validation_ds = None + if config.phoneme_tokenizer_path and hasattr(model_cfg, 'phoneme_tokenizer'): + model_cfg.phoneme_tokenizer.tokenizer_path = config.phoneme_tokenizer_path - model = EasyMagpieTTSModel(cfg=model_cfg) + model = EasyMagpieTTSInferenceModel(cfg=model_cfg) logging.info(f"Loading weights from checkpoint: {config.checkpoint_file}") ckpt = torch.load(config.checkpoint_file) @@ -378,22 +387,29 @@ def load_easy_magpie_model(config: ModelLoadConfig, device: str = "cuda") -> Tup checkpoint_name = os.path.basename(config.checkpoint_file).replace(".ckpt", "") else: if config.nemo_file.startswith("nvidia/"): - model = EasyMagpieTTSModel.from_pretrained(config.nemo_file) + model = EasyMagpieTTSInferenceModel.from_pretrained(config.nemo_file) checkpoint_name = config.nemo_file.split("/")[-1] else: logging.info(f"Loading model from NeMo archive: {config.nemo_file}") - model_cfg = EasyMagpieTTSModel.restore_from(config.nemo_file, return_config=True) + model_cfg = EasyMagpieTTSInferenceModel.restore_from(config.nemo_file, return_config=True) with open_dict(model_cfg): model_cfg.codecmodel_path = config.codecmodel_path model_cfg.train_ds = None model_cfg.validation_ds = None + if config.phoneme_tokenizer_path and hasattr(model_cfg, 'phoneme_tokenizer'): + model_cfg.phoneme_tokenizer.tokenizer_path = config.phoneme_tokenizer_path + # Override target so restore_from instantiates the inference class, + # not the training subclass stored in the .nemo config. + model_cfg.target = ( + 'nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel' + ) - model = EasyMagpieTTSModel.restore_from(config.nemo_file, override_config_path=model_cfg) + model = EasyMagpieTTSInferenceModel.restore_from(config.nemo_file, override_config_path=model_cfg) checkpoint_name = os.path.basename(config.nemo_file).replace(".nemo", "") model.to(device) - model.eval() + model.eval().float() logging.info("EasyMagpieTTS model loaded and ready for inference.") return model, checkpoint_name @@ -480,7 +496,7 @@ def log_model_architecture_summary(model) -> Tuple[str, Dict[str, dict]]: Detects and logs MoE configuration for each transformer component, computing FLOPs metrics and parameter counts. Gracefully handles - decoder-only models (EasyMagpieTTSModel) that use HuggingFace/Nemotron + decoder-only models (EasyMagpieTTSInferenceModel) that use HuggingFace/Nemotron decoders without the d_model/d_ffn config structure. Args: