Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/tts/magpietts-longform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,23 +169,23 @@ The ``do_tts`` method automatically detects whether longform inference is needed
sf.write("output.wav", long_audio[0].cpu().numpy(), 22050)


Method 2: Using CLI (``tts_infer.py``)
Method 2: Using CLI (``magpietts_inference.py``)
------------------------------------------------

For batch inference from manifests:

.. code-block:: bash

# Auto-detect longform based on text length (default)
python examples/tts/tts_infer.py \
python examples/tts/magpietts_inference.py \
--nemo_files /path/to/magpietts.nemo \
--datasets_json_path /path/to/evalset_config.json \
--out_dir /path/to/output \
--codecmodel_path /path/to/codec.nemo \
--longform_mode auto

# Force longform inference for all inputs
python examples/tts/tts_infer.py \
python examples/tts/magpietts_inference.py \
--nemo_files /path/to/magpietts.nemo \
--datasets_json_path /path/to/evalset_config.json \
--out_dir /path/to/output \
Expand Down
2 changes: 1 addition & 1 deletion docs/source/tts/magpietts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ Several parameters control the generation behavior. The temperature setting affe

.. code-block:: bash

python examples/tts/tts_infer.py \
python examples/tts/magpietts_inference.py \
--nemo_files /path/to/magpietts_model.nemo \
--codecmodel_path /path/to/audio_codec.nemo \
--datasets your_evaluation_set \
Expand Down
60 changes: 26 additions & 34 deletions examples/tts/tts_infer.py → examples/tts/magpietts_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,23 @@

Example usage:
# MagpieTTS inference (encoder-decoder, default)
python examples/tts/tts_infer.py \\
python examples/tts/magpietts_inference.py \\
--model_type magpie \\
--nemo_files /path/to/model.nemo \\
--datasets_json_path /path/to/evalset_config.json \\
--out_dir /path/to/output \\
--codecmodel_path /path/to/codec.nemo

# EasyMagpieTTS inference (decoder-only)
python examples/tts/tts_infer.py \\
python examples/tts/magpietts_inference.py \\
--model_type easy_magpie \\
--nemo_files /path/to/model.nemo \\
--datasets_json_path /path/to/evalset_config.json \\
--out_dir /path/to/output \\
--codecmodel_path /path/to/codec.nemo

# With evaluation
python examples/tts/tts_infer.py \\
python examples/tts/magpietts_inference.py \\
--model_type magpie \\
--hparams_files /path/to/hparams.yaml \\
--checkpoint_files /path/to/model.ckpt \\
Expand All @@ -66,7 +66,7 @@
import numpy as np

from nemo.collections.asr.parts.utils.manifest_utils import read_manifest
from nemo.collections.tts.models.easy_magpietts import EasyModelInferenceParameters
from nemo.collections.tts.models.easy_magpietts_inference import EasyModelInferenceParameters
from nemo.collections.tts.models.magpietts import ModelInferenceParameters
from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import load_evalset_config
from nemo.collections.tts.modules.magpietts_inference.evaluation import (
Expand Down Expand Up @@ -161,11 +161,6 @@ def filter_datasets(dataset_meta_info: dict, datasets: Optional[List[str]]) -> L
return datasets


# ---------------------------------------------------------------------------
# Core inference + evaluation orchestration (model-type agnostic)
# ---------------------------------------------------------------------------


def run_inference_and_evaluation(
runner: BaseInferenceRunner,
checkpoint_name: str,
Expand Down Expand Up @@ -355,28 +350,34 @@ def run_inference_and_evaluation(
return None, None


# ---------------------------------------------------------------------------
# CLI argument parser
# ---------------------------------------------------------------------------
def _get_shared_inference_param_names() -> set:
    """Return the set of field names common to both inference-parameter dataclasses.

    These are the parameters exposed by both ``ModelInferenceParameters`` and
    ``EasyModelInferenceParameters``; callers use the intersection to decide
    which CLI flags are shared between the two model types.
    """
    return {f.name for f in fields(ModelInferenceParameters)} & {
        f.name for f in fields(EasyModelInferenceParameters)
    }


def _add_inference_param_fields(
group: argparse._ArgumentGroup,
param_cls: type,
skip_fields: Optional[set] = None,
only_fields: Optional[set] = None,
) -> None:
"""Auto-generate argparse arguments from fields of a dataclass.

Args:
group: The argparse argument group to add arguments to.
param_cls: The dataclass whose fields to add.
skip_fields: Field names to skip (already added by another group).
only_fields: If provided, only add fields whose names are in this set.
"""
if skip_fields is None:
skip_fields = set()
for f in fields(param_cls):
if f.name in skip_fields:
continue
if only_fields is not None and f.name not in only_fields:
continue
extra_args: dict = {"type": f.type}
if f.type == bool:
extra_args = {"action": "store_true"}
Expand All @@ -399,7 +400,7 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None:
default='magpie',
choices=['magpie', 'easy_magpie'],
help='Model type: "magpie" for encoder-decoder MagpieTTSModel, '
'"easy_magpie" for decoder-only EasyMagpieTTSModel',
'"easy_magpie" for decoder-only EasyMagpieTTSInferenceModel',
)

# Model loading
Expand Down Expand Up @@ -469,8 +470,9 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None:
infer_group.add_argument('--use_cfg', action='store_true', help='Enable classifier-free guidance')
infer_group.add_argument('--use_local_transformer', action='store_true')

# Shared model inference parameters (max_decoder_steps, temperature, topk, cfg_scale)
_add_inference_param_fields(infer_group, EasyModelInferenceParameters)
# Model inference parameters shared by both MagpieTTS and EasyMagpieTTS
shared_param_names = _get_shared_inference_param_names()
_add_inference_param_fields(infer_group, ModelInferenceParameters, only_fields=shared_param_names)

# Evaluation
eval_group = parser.add_argument_group('Evaluation')
Expand Down Expand Up @@ -499,9 +501,8 @@ def _add_magpie_args(parser: argparse.ArgumentParser) -> None:
group = parser.add_argument_group('MagpieTTS-specific Parameters')

# MagpieTTS-specific model inference parameters (attention prior, EOS, etc.)
# Skip fields already added by the common inference group.
shared_field_names = {f.name for f in fields(EasyModelInferenceParameters)}
_add_inference_param_fields(group, ModelInferenceParameters, skip_fields=shared_field_names)
shared_param_names = _get_shared_inference_param_names()
_add_inference_param_fields(group, ModelInferenceParameters, skip_fields=shared_param_names)

group.add_argument('--maskgit_n_steps', type=int, default=3)
group.add_argument('--maskgit_noise_scale', type=float, default=0.0)
Expand All @@ -514,7 +515,7 @@ def _add_magpie_args(parser: argparse.ArgumentParser) -> None:


def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None:
"""Add arguments specific to decoder-only EasyMagpieTTSModel."""
"""Add arguments specific to decoder-only EasyMagpieTTSInferenceModel."""
group = parser.add_argument_group('EasyMagpieTTS-specific Parameters')
group.add_argument(
'--phoneme_input_type',
Expand All @@ -532,9 +533,10 @@ def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None:
)
group.add_argument('--dropout_text_input', action='store_true', help='Force dropout on text input')
group.add_argument(
'--legacy_context_stacking',
action='store_true',
help='Use audio_bos_id/audio_eos_id for context stacking',
'--phoneme_tokenizer_path',
type=str,
default=None,
help='Override path to the phoneme tokenizer file (overrides the path stored in the checkpoint config)',
)


Expand All @@ -551,11 +553,6 @@ def create_argument_parser() -> argparse.ArgumentParser:
return parser


# ---------------------------------------------------------------------------
# Config builders (one per model type)
# ---------------------------------------------------------------------------


def _build_inference_params_from_args(param_cls: type, args):
"""Extract inference parameters from parsed CLI args for the given dataclass."""
params = {}
Expand Down Expand Up @@ -592,15 +589,8 @@ def _build_easy_magpie_config(args) -> EasyMagpieInferenceConfig:
phoneme_input_type=args.phoneme_input_type,
phoneme_sampling_method=args.phoneme_sampling_method,
dropout_text_input=args.dropout_text_input,
legacy_context_stacking=args.legacy_context_stacking,
)


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------


def main(argv=None):
"""Entry point for TTS inference and evaluation."""
parser = create_argument_parser()
Expand Down Expand Up @@ -656,6 +646,7 @@ def main(argv=None):
legacy_codebooks=args.legacy_codebooks,
legacy_text_conditioning=args.legacy_text_conditioning,
hparams_from_wandb=args.hparams_file_from_wandb,
phoneme_tokenizer_path=getattr(args, 'phoneme_tokenizer_path', None),
)

model, checkpoint_name = load_fn(model_config)
Expand Down Expand Up @@ -693,6 +684,7 @@ def main(argv=None):
codecmodel_path=args.codecmodel_path,
legacy_codebooks=args.legacy_codebooks,
legacy_text_conditioning=args.legacy_text_conditioning,
phoneme_tokenizer_path=getattr(args, 'phoneme_tokenizer_path', None),
)

model, checkpoint_name = load_fn(model_config)
Expand Down
25 changes: 1 addition & 24 deletions nemo/collections/tts/models/easy_magpietts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import json
import os
import random
from dataclasses import dataclass, fields
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
Expand Down Expand Up @@ -98,29 +98,6 @@ class ProcessBatchOutput:
selected_training_mode: Optional[str]


@dataclass
class EasyModelInferenceParameters:
"""Inference parameters for the decoder-only EasyMagpieTTS model.

Attributes:
max_decoder_steps: Maximum number of decoder steps.
temperature: Sampling temperature.
topk: Number of top-probability tokens to consider in sampling.
cfg_scale: Scale factor for classifier-free guidance.
"""

max_decoder_steps: int = 500
temperature: float = 0.7
topk: int = 80
cfg_scale: float = 2.5

@classmethod
def from_dict(cls, data: dict) -> 'EasyModelInferenceParameters':
field_names = {field.name for field in fields(cls)}
filtered_data = {k: v for k, v in data.items() if k in field_names}
return cls(**filtered_data)


class EasyMagpieTTSModel(EasyMagpieTTSInferenceModel):
"""
Magpie-TTS Model Decoder Only Model with training support.
Expand Down
44 changes: 33 additions & 11 deletions nemo/collections/tts/models/easy_magpietts_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from dataclasses import dataclass
from dataclasses import dataclass, fields
from functools import partial
from typing import Any, Dict, List, Optional, Sequence, Tuple

Expand Down Expand Up @@ -184,6 +184,29 @@ class InferBatchOutput:
phoneme_prediction_start_idx: Optional[torch.Tensor] = None # (B,) start index into predicted_phoneme_tokens


@dataclass
class EasyModelInferenceParameters:
    """Inference-time knobs for the decoder-only EasyMagpieTTS model.

    Attributes:
        max_decoder_steps: Upper bound on the number of autoregressive decoder steps.
        temperature: Softmax sampling temperature.
        topk: Number of highest-probability tokens kept during sampling.
        cfg_scale: Scale factor for classifier-free guidance.
    """

    max_decoder_steps: int = 300
    temperature: float = 0.7
    topk: int = 80
    cfg_scale: float = 2.5

    @classmethod
    def from_dict(cls, data: dict) -> 'EasyModelInferenceParameters':
        """Build an instance from ``data``, silently dropping keys that are not fields."""
        valid_keys = {f.name for f in fields(cls)}
        kwargs = {key: value for key, value in data.items() if key in valid_keys}
        return cls(**kwargs)


class EasyMagpieTTSInferenceModel(BaseMagpieTTSModel):
"""
Inference-only base class for EasyMagpieTTS decoder-only model.
Expand Down Expand Up @@ -319,6 +342,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):

self.pad_context_text_to_max_duration = False
self.add_language_to_context_text = cfg.get('add_language_to_context_text', False)
self.ignore_phoneme_languages = cfg.get('ignore_phoneme_languages', [])

super().__init__(cfg=cfg, trainer=trainer)

Expand Down Expand Up @@ -465,6 +489,12 @@ def _get_state_dict_keys_to_exclude(self):
'_codec_model',
]

def setup_training_data(self, train_data_config=None):
    """No-op: this class is inference-only, so no training dataloader is built."""
    pass

def setup_validation_data(self, val_data_config=None):
    """No-op: this class is inference-only, so no validation dataloader is built."""
    pass

def codes_to_audio(self, codes, codes_len):
# codes: (B, C, T')
self._codec_model.eval()
Expand Down Expand Up @@ -734,19 +764,11 @@ def prepare_context_tensors(
eos_id=self.context_audio_eos_id,
)

# Use legacy audio_bos_id/audio_eos_id if flag is set
stack_bos_id = (
self.audio_bos_id if getattr(self, 'legacy_context_stacking', False) else self.context_audio_bos_id
)
stack_eos_id = (
self.audio_eos_id if getattr(self, 'legacy_context_stacking', False) else self.context_audio_eos_id
)

context_audio_codes, context_audio_codes_lens = self.stack_codes(
context_audio_codes,
context_audio_codes_lens,
stack_bos_id,
stack_eos_id,
self.context_audio_bos_id,
self.context_audio_eos_id,
self.frame_stacking_factor,
self.num_audio_codebooks,
)
Expand Down
Loading
Loading