diff --git a/docs/source/tts/magpietts-longform.rst b/docs/source/tts/magpietts-longform.rst index 33aef42a5abe..fb3eeb659d33 100644 --- a/docs/source/tts/magpietts-longform.rst +++ b/docs/source/tts/magpietts-longform.rst @@ -169,7 +169,7 @@ The ``do_tts`` method automatically detects whether longform inference is needed sf.write("output.wav", long_audio[0].cpu().numpy(), 22050) -Method 2: Using CLI (``magpietts_inference.py``) +Method 2: Using CLI (``tts_infer.py``) ------------------------------------------------ For batch inference from manifests: @@ -177,7 +177,7 @@ For batch inference from manifests: .. code-block:: bash # Auto-detect longform based on text length (default) - python examples/tts/magpietts_inference.py \ + python examples/tts/tts_infer.py \ --nemo_files /path/to/magpietts.nemo \ --datasets_json_path /path/to/evalset_config.json \ --out_dir /path/to/output \ @@ -185,7 +185,7 @@ For batch inference from manifests: --longform_mode auto # Force longform inference for all inputs - python examples/tts/magpietts_inference.py \ + python examples/tts/tts_infer.py \ --nemo_files /path/to/magpietts.nemo \ --datasets_json_path /path/to/evalset_config.json \ --out_dir /path/to/output \ diff --git a/docs/source/tts/magpietts.rst b/docs/source/tts/magpietts.rst index b79c11ea88ff..6d297a694596 100644 --- a/docs/source/tts/magpietts.rst +++ b/docs/source/tts/magpietts.rst @@ -130,7 +130,7 @@ Several parameters control the generation behavior. The temperature setting affe .. code-block:: bash - python examples/tts/magpietts_inference.py \ + python examples/tts/tts_infer.py \ --nemo_files /path/to/magpietts_model.nemo \ --codecmodel_path /path/to/audio_codec.nemo \ --datasets your_evaluation_set \ diff --git a/examples/tts/evalset_config.json b/examples/tts/evalset_config.json index 4be3056020ce..2d61a601f880 100644 --- a/examples/tts/evalset_config.json +++ b/examples/tts/evalset_config.json @@ -15,3 +15,4 @@ "feature_dir": null } } + diff --git a/examples/tts/magpietts_inference.py b/examples/tts/tts_infer.py similarity index 68% rename from examples/tts/magpietts_inference.py rename to examples/tts/tts_infer.py index f1ed60c27428..2c3bec0aa7f7 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/tts_infer.py @@ -12,25 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -MagpieTTS Inference and Evaluation Script. +TTS Inference and Evaluation Script. -Supports both standard and Mixture of Experts (MoE) models with: +Supports both encoder-decoder MagpieTTS and decoder-only EasyMagpieTTS models +with: - Automatic MoE detection and FLOPs calculation - Comprehensive evaluation metrics (RTF, FLOPs, CER, SSIM, etc.) -This script provides a clean CLI for running MagpieTTS inference with optional evaluation. -It decouples inference and evaluation into separate modules for better maintainability. +This script provides a clean CLI for running TTS inference with optional +evaluation. Model-specific behaviour (dataset creation, inference loop, CLI +arguments) is handled by separate runner classes so there is no scattered +if/else branching. Example usage: - # Inference only (from .nemo file) - default behavior - python examples/tts/magpietts_inference.py \\ + # MagpieTTS inference (encoder-decoder, default) + python examples/tts/tts_infer.py \\ + --model_type magpie \\ --nemo_files /path/to/model.nemo \\ --datasets_json_path /path/to/evalset_config.json \\ --out_dir /path/to/output \\ --codecmodel_path /path/to/codec.nemo - # Inference with evaluation (from checkpoint) - python examples/tts/magpietts_inference.py \\ + # EasyMagpieTTS inference (decoder-only) + python examples/tts/tts_infer.py \\ + --model_type easy_magpie \\ + --nemo_files /path/to/model.nemo \\ + --datasets_json_path /path/to/evalset_config.json \\ + --out_dir /path/to/output \\ + --codecmodel_path /path/to/codec.nemo + + # With evaluation + python examples/tts/tts_infer.py \\ + --model_type magpie \\ --hparams_files /path/to/hparams.yaml \\ --checkpoint_files /path/to/model.ckpt \\ --datasets_json_path /path/to/evalset_config.json \\ @@ -53,20 +66,27 @@ import numpy as np from nemo.collections.asr.parts.utils.manifest_utils import read_manifest +from nemo.collections.tts.models.easy_magpietts import EasyModelInferenceParameters from nemo.collections.tts.models.magpietts import ModelInferenceParameters from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import load_evalset_config - -# Import the modular components from nemo.collections.tts.modules.magpietts_inference.evaluation import ( DEFAULT_VIOLIN_METRICS, EvaluationConfig, compute_mean_with_confidence_interval, evaluate_generated_audio_dir, ) -from nemo.collections.tts.modules.magpietts_inference.inference import InferenceConfig, MagpieInferenceRunner +from nemo.collections.tts.modules.magpietts_inference.inference import ( + BaseInferenceConfig, + BaseInferenceRunner, + EasyMagpieInferenceConfig, + EasyMagpieInferenceRunner, + MagpieInferenceConfig, + MagpieInferenceRunner, +) from nemo.collections.tts.modules.magpietts_inference.utils import ( ModelLoadConfig, get_experiment_name_from_checkpoint_path, + load_easy_magpie_model, load_magpie_model, log_model_architecture_summary, ) @@ -132,50 +152,54 @@ def create_formatted_metrics_mean_ci(metrics_mean_ci: dict) -> dict: def filter_datasets(dataset_meta_info: dict, datasets: Optional[List[str]]) -> List[str]: """Select datasets from the dataset meta info.""" if datasets is None: - # Dataset filtering not specified, return all datasets return list(dataset_meta_info.keys()) else: datasets = datasets.split(",") - # Check if datasets are valid for dataset in datasets: if dataset not in dataset_meta_info: raise ValueError(f"Dataset {dataset} not found in dataset meta info") - # Return all requsted datasets return datasets +# --------------------------------------------------------------------------- +# Core inference + evaluation orchestration (model-type agnostic) +# --------------------------------------------------------------------------- + + def run_inference_and_evaluation( - model_config: ModelLoadConfig, - inference_config: InferenceConfig, + runner: BaseInferenceRunner, + checkpoint_name: str, + inference_config: BaseInferenceConfig, eval_config: EvaluationConfig, dataset_meta_info: dict, - datasets: Optional[List[str]], + datasets: List[str], out_dir: str, + flops_per_component: dict, + moe_info: str, num_repeats: int = 1, confidence_level: float = 0.95, violin_plot_metrics: Optional[List[str]] = None, - log_exp_name: bool = False, clean_up_disk: bool = False, skip_evaluation: bool = False, ) -> Tuple[Optional[float], Optional[float]]: """Run inference and optional evaluation on specified datasets. - Uses unified inference path with automatic text chunking based on - per-sample language thresholds. Short texts are processed as single chunks, - long texts are automatically split into sentences. + This function is model-type agnostic -- it delegates dataset creation + and batch inference to the provided ``runner``. Args: - model_config: Configuration for loading the model. + runner: Concrete inference runner (MagpieInferenceRunner or EasyMagpieInferenceRunner). + checkpoint_name: Human-readable checkpoint identifier for output naming. inference_config: Configuration for inference. eval_config: Configuration for evaluation. dataset_meta_info: Dictionary containing dataset metadata. - datasets: List of dataset names to run inference and evaluation on. If None, all datasets in the - dataset meta info will be processed. + datasets: List of dataset names to process. out_dir: Output directory for results. + flops_per_component: FLOPs info dict from log_model_architecture_summary. + moe_info: MoE identifier string from log_model_architecture_summary. num_repeats: Number of times to repeat inference (for CI estimation). confidence_level: Confidence level for CI calculation. violin_plot_metrics: Metrics to include in violin plots. - log_exp_name: Whether to include experiment name in output paths. clean_up_disk: Whether to clean up output directory after completion. skip_evaluation: Whether to skip evaluation (inference only mode). @@ -185,40 +209,17 @@ def run_inference_and_evaluation( if violin_plot_metrics is None: violin_plot_metrics = list(DEFAULT_VIOLIN_METRICS) - # Remove UTMOSv2 from plots if disabled if not eval_config.with_utmosv2 and 'utmosv2' in violin_plot_metrics: violin_plot_metrics.remove('utmosv2') - # Load model - model, checkpoint_name = load_magpie_model( - model_config, is_decoder_only_model=inference_config.is_decoder_only_model - ) - # change model to fp32 for inference - model = model.float() - - # Log architecture summary and get MoE info + FLOPs metrics - moe_info, flops_per_component = log_model_architecture_summary(model) - - # Add experiment name prefix if requested - if log_exp_name and model_config.checkpoint_file: - exp_name = get_experiment_name_from_checkpoint_path(model_config.checkpoint_file) - checkpoint_name = f"{exp_name}__{checkpoint_name}" - - # Build full checkpoint identifier (include MoE info if present) full_checkpoint_name = ( f"{checkpoint_name}_{moe_info}{inference_config.build_identifier()}_SV_{eval_config.sv_model}" ) - # Create inference runner (uses unified path with automatic text chunking) - logging.info("Using unified inference with automatic text chunking based on language thresholds") - runner = MagpieInferenceRunner(model, inference_config) - - # Tracking metrics across datasets ssim_per_dataset = [] cer_per_dataset = [] all_datasets_filewise_metrics = {} - # CSV headers csv_header = ( "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative," "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg," @@ -234,17 +235,14 @@ def run_inference_and_evaluation( manifest_records = read_manifest(meta['manifest_path']) language = meta.get('whisper_language', 'en') - # Prepare dataset metadata (remove evaluation-specific keys) dataset_meta_for_dl = copy.deepcopy(meta) for key in ["whisper_language", "load_cached_codes_if_available"]: dataset_meta_for_dl.pop(key, None) - # Setup output directories eval_dir = os.path.join(out_dir, f"{full_checkpoint_name}_{dataset}") audio_dir = os.path.join(eval_dir, "audio") os.makedirs(eval_dir, exist_ok=True) - # Setup CSV files per_run_csv = os.path.join(eval_dir, "all_experiment_metrics.csv") write_csv_header_if_needed(per_run_csv, csv_header) @@ -257,7 +255,6 @@ def run_inference_and_evaluation( repeat_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}") os.makedirs(repeat_audio_dir, exist_ok=True) - # Create dataset and run inference test_dataset = runner.create_dataset({dataset: dataset_meta_for_dl}) if len(test_dataset) != len(manifest_records): @@ -271,14 +268,12 @@ def run_inference_and_evaluation( manifest_records=manifest_records, audio_base_dir=meta['audio_dir'], save_cross_attention_maps=True, - save_context_audio=(repeat_idx == 0), # Only save context audio once - save_predicted_codes=eval_config.with_fcd, # Code files are only needed for FCD computation + save_context_audio=(repeat_idx == 0), + save_predicted_codes=eval_config.with_fcd, ) - # Compute mean RTF metrics mean_rtf = runner.compute_mean_rtf_metrics(rtf_metrics_list) - # Add FLOPs metrics per component for component_name, component_flops in flops_per_component.items(): for key, value in component_flops.items(): mean_rtf[f"{component_name}_{key}"] = value @@ -291,7 +286,6 @@ def run_inference_and_evaluation( logging.info("Skipping evaluation as requested.") continue - # Run evaluation eval_config_for_dataset = EvaluationConfig( sv_model=eval_config.sv_model, asr_model_name=eval_config.asr_model_name, @@ -312,7 +306,6 @@ def run_inference_and_evaluation( metrics_all_repeats.append(metrics) filewise_metrics_all_repeats.extend(filewise_metrics) - # Save metrics with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w") as f: json.dump(metrics, f, indent=4) @@ -320,24 +313,19 @@ def run_inference_and_evaluation( with open(os.path.join(eval_dir, f"{dataset}_filewise_metrics_{repeat_idx}.json"), "w") as f: json.dump(sorted_filewise, f, indent=4) - # Append to per-run CSV append_metrics_to_csv(per_run_csv, full_checkpoint_name, dataset, metrics) - # Create violin plot for this repeat violin_path = Path(eval_dir) / f"{dataset}_violin_{repeat_idx}.png" create_violin_plot(filewise_metrics, violin_plot_metrics, violin_path) - # Delete temporary predicted codes files for codec_file_path in codec_file_paths: os.remove(codec_file_path) if skip_evaluation or not metrics_all_repeats: continue - # Store for combined plot all_datasets_filewise_metrics[dataset] = filewise_metrics_all_repeats - # Compute mean with confidence interval across repeats metrics_mean_ci = compute_mean_with_confidence_interval( metrics_all_repeats, confidence=confidence_level, @@ -345,42 +333,76 @@ def run_inference_and_evaluation( formatted_metrics_mean_ci = create_formatted_metrics_mean_ci(metrics_mean_ci) - # Write to aggregated CSV ci_csv = os.path.join(out_dir, "all_experiment_metrics_with_ci.csv") write_csv_header_if_needed(ci_csv, csv_header) append_metrics_to_csv(ci_csv, full_checkpoint_name, dataset, formatted_metrics_mean_ci) - # Track per-dataset means ssim_values = [m['ssim_pred_context_avg'] for m in metrics_all_repeats] cer_values = [m['cer_cumulative'] for m in metrics_all_repeats] ssim_per_dataset.append(np.mean(ssim_values)) cer_per_dataset.append(np.mean(cer_values)) - # Create combined plot if we have multiple datasets if len(all_datasets_filewise_metrics) > 1: combined_plot_path = os.path.join(out_dir, f"{full_checkpoint_name}_combined_violin_plot.png") create_combined_box_plot(all_datasets_filewise_metrics, violin_plot_metrics, combined_plot_path) - # Clean up if requested if clean_up_disk: logging.info(f"Cleaning up output directory: {out_dir}") shutil.rmtree(out_dir) - # Return averaged metrics if ssim_per_dataset and cer_per_dataset: return np.mean(cer_per_dataset), np.mean(ssim_per_dataset) return None, None -def create_argument_parser() -> argparse.ArgumentParser: - """Create the CLI argument parser.""" - parser = argparse.ArgumentParser( - description='MagpieTTS Inference and Evaluation', - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__, +# --------------------------------------------------------------------------- +# CLI argument parser +# --------------------------------------------------------------------------- + + +def _add_inference_param_fields( + group: argparse._ArgumentGroup, + param_cls: type, + skip_fields: Optional[set] = None, +) -> None: + """Auto-generate argparse arguments from fields of a dataclass. + + Args: + group: The argparse argument group to add arguments to. + param_cls: The dataclass whose fields to add. + skip_fields: Field names to skip (already added by another group). + """ + if skip_fields is None: + skip_fields = set() + for f in fields(param_cls): + if f.name in skip_fields: + continue + extra_args: dict = {"type": f.type} + if f.type == bool: + extra_args = {"action": "store_true"} + if f.name in ("estimate_alignment_from_layers", "apply_prior_to_layers"): + extra_args = { + "help": "Must be a comma separate string. Not enclosed in brackets", + "type": str, + } + elif f.name == "eos_detection_method": + extra_args["choices"] = [m.value for m in EOSDetectionMethod] + group.add_argument(f"--{f.name}", **extra_args) + + +def _add_common_args(parser: argparse.ArgumentParser) -> None: + """Add arguments shared by all model types.""" + + parser.add_argument( + '--model_type', + type=str, + default='magpie', + choices=['magpie', 'easy_magpie'], + help='Model type: "magpie" for encoder-decoder MagpieTTSModel, ' + '"easy_magpie" for decoder-only EasyMagpieTTSModel', ) - # Model loading arguments + # Model loading model_group = parser.add_argument_group('Model Loading') model_group.add_argument( '--hparams_files', @@ -422,73 +444,37 @@ def create_argument_parser() -> argparse.ArgumentParser: help='Use legacy text conditioning (for old checkpoints)', ) - # Dataset and output arguments + # Dataset and output data_group = parser.add_argument_group('Dataset and Output') data_group.add_argument( '--datasets_json_path', type=str, required=True, default=None, - help='Path to dataset configuration JSON file (will process all datasets in the file if --datasets is not specified)', + help='Path to dataset configuration JSON file', ) data_group.add_argument( '--datasets', type=str, default=None, - help='Comma-separated list of dataset names to process using names from the datasets_json_path file. If not specified, all datasets in the datasets_json_path will be processed.', - ) - data_group.add_argument( - '--out_dir', - type=str, - required=True, - help='Output directory for generated audio and metrics', - ) - data_group.add_argument( - '--log_exp_name', - action='store_true', - help='Include experiment name in output folder name', - ) - data_group.add_argument( - '--clean_up_disk', - action='store_true', - help='Delete output directory after completion', + help='Comma-separated list of dataset names to process', ) + data_group.add_argument('--out_dir', type=str, required=True, help='Output directory') + data_group.add_argument('--log_exp_name', action='store_true') + data_group.add_argument('--clean_up_disk', action='store_true') - # Inference arguments - infer_group = parser.add_argument_group('Inference Parameters') - # Add model specific parameters - for field in fields(ModelInferenceParameters): - extra_args = {"type": field.type} - if field.type == bool: - extra_args["action"] = "store_true" - del extra_args["type"] - if field.name == "estimate_alignment_from_layers" or field.name == "apply_prior_to_layers": - extra_args["help"] = "Must be a comma separate string. Not enclosed in brackets" - extra_args["type"] = str - elif field.name == "eos_detection_method": - extra_args["choices"] = [m.value for m in EOSDetectionMethod] - infer_group.add_argument(f"--{field.name}", **extra_args) + # Common inference parameters + infer_group = parser.add_argument_group('Common Inference Parameters') infer_group.add_argument('--batch_size', type=int, default=32) infer_group.add_argument('--use_cfg', action='store_true', help='Enable classifier-free guidance') - - # Local transformer / MaskGit arguments infer_group.add_argument('--use_local_transformer', action='store_true') - infer_group.add_argument('--maskgit_n_steps', type=int, default=3) - infer_group.add_argument('--maskgit_noise_scale', type=float, default=0.0) - infer_group.add_argument('--maskgit_fixed_schedule', type=int, nargs='+', default=None) - infer_group.add_argument( - '--maskgit_sampling_type', - default=None, - choices=["default", "causal", "purity_causal", "purity_default"], - ) - # Evaluation arguments + # Shared model inference parameters (max_decoder_steps, temperature, topk, cfg_scale) + _add_inference_param_fields(infer_group, EasyModelInferenceParameters) + + # Evaluation eval_group = parser.add_argument_group('Evaluation') - eval_group.add_argument( - '--run_evaluation', - action='store_true', - help='Run evaluation after inference (default: False, inference only)', - ) + eval_group.add_argument('--run_evaluation', action='store_true', help='Run evaluation after inference') eval_group.add_argument('--sv_model', type=str, default="titanet", choices=["titanet", "wavlm"]) eval_group.add_argument('--asr_model_name', type=str, default="nvidia/parakeet-tdt-1.1b") eval_group.add_argument('--num_repeats', type=int, default=1) @@ -500,70 +486,92 @@ def create_argument_parser() -> argparse.ArgumentParser: nargs='*', default=['cer', 'pred_context_ssim', 'utmosv2'], ) - eval_group.add_argument('--disable_fcd', action='store_true', help="Disable Frechet Codec Distance computation") + eval_group.add_argument('--disable_fcd', action='store_true') - # Quality targets (for CI/CD) + # Quality targets target_group = parser.add_argument_group('Quality Targets') target_group.add_argument('--cer_target', type=float, default=None) target_group.add_argument('--ssim_target', type=float, default=None) - target_group.add_argument('--is_decoder_only_model', action='store_true') - target_group.add_argument( - '--legacy_context_stacking', - action='store_true', - help='Use audio_bos_id/audio_eos_id instead of context_audio_bos_id/context_audio_eos_id for context stacking', - ) - target_group.add_argument('--phoneme_input_type', type=str, default='gt', choices=['predicted', 'gt']) - target_group.add_argument( - '--phoneme_sampling_method', type=str, default='argmax', choices=['argmax', 'multinomial'] - ) - target_group.add_argument('--dropout_text_input', action='store_true') - return parser +def _add_magpie_args(parser: argparse.ArgumentParser) -> None: + """Add arguments specific to encoder-decoder MagpieTTSModel.""" + group = parser.add_argument_group('MagpieTTS-specific Parameters') -def main(argv=None): - """Entry point for MagpieTTS inference and evaluation. + # MagpieTTS-specific model inference parameters (attention prior, EOS, etc.) + # Skip fields already added by the common inference group. + shared_field_names = {f.name for f in fields(EasyModelInferenceParameters)} + _add_inference_param_fields(group, ModelInferenceParameters, skip_fields=shared_field_names) - Args: - argv: Command-line arguments. If None, uses sys.argv. - """ - parser = create_argument_parser() - args = parser.parse_args(argv) + group.add_argument('--maskgit_n_steps', type=int, default=3) + group.add_argument('--maskgit_noise_scale', type=float, default=0.0) + group.add_argument('--maskgit_fixed_schedule', type=int, nargs='+', default=None) + group.add_argument( + '--maskgit_sampling_type', + default=None, + choices=["default", "causal", "purity_causal", "purity_default"], + ) - dataset_meta_info = load_evalset_config(args.datasets_json_path) - datasets = filter_datasets(dataset_meta_info, args.datasets) - logging.info(f"Loaded {len(datasets)} datasets: {', '.join(datasets)}") +def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None: + """Add arguments specific to decoder-only EasyMagpieTTSModel.""" + group = parser.add_argument_group('EasyMagpieTTS-specific Parameters') + group.add_argument( + '--phoneme_input_type', + type=str, + default='gt', + choices=['gt', 'predicted'], + help='Source of phoneme input for decoder-only model', + ) + group.add_argument( + '--phoneme_sampling_method', + type=str, + default='argmax', + choices=['argmax', 'multinomial'], + help='Sampling method for phoneme prediction', + ) + group.add_argument('--dropout_text_input', action='store_true', help='Force dropout on text input') + group.add_argument( + '--legacy_context_stacking', + action='store_true', + help='Use audio_bos_id/audio_eos_id for context stacking', + ) - # Determine mode and validate - has_checkpoint_mode = ( - args.hparams_files is not None - and args.checkpoint_files is not None - and args.hparams_files != "null" - and args.checkpoint_files != "null" + +def create_argument_parser() -> argparse.ArgumentParser: + """Create the CLI argument parser with all argument groups.""" + parser = argparse.ArgumentParser( + description='TTS Inference and Evaluation (MagpieTTS & EasyMagpieTTS)', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, ) - has_nemo_mode = args.nemo_files is not None and args.nemo_files != "null" + _add_common_args(parser) + _add_magpie_args(parser) + _add_easy_magpie_args(parser) + return parser - if not has_checkpoint_mode and not has_nemo_mode: - parser.error("You must provide either:\n 1. --hparams_files and --checkpoint_files\n 2. --nemo_files") - # Build configurations - model_inference_parameters = {} - for field in fields(ModelInferenceParameters): - field_name = field.name - arg_from_cmdline = vars(args)[field_name] - if arg_from_cmdline is not None: - if field_name in ["estimate_alignment_from_layers", "apply_prior_to_layers"]: - model_inference_parameters[field_name] = parse_layer_list(arg_from_cmdline) +# --------------------------------------------------------------------------- +# Config builders (one per model type) +# --------------------------------------------------------------------------- + + +def _build_inference_params_from_args(param_cls: type, args): + """Extract inference parameters from parsed CLI args for the given dataclass.""" + params = {} + for f in fields(param_cls): + arg_val = vars(args).get(f.name) + if arg_val is not None: + if f.name in ("estimate_alignment_from_layers", "apply_prior_to_layers"): + params[f.name] = parse_layer_list(arg_val) else: - model_inference_parameters[field_name] = arg_from_cmdline + params[f.name] = arg_val + return param_cls.from_dict(params) - if "max_decoder_steps" not in model_inference_parameters: - if args.is_decoder_only_model: - model_inference_parameters["max_decoder_steps"] = 300 - inference_config = InferenceConfig( - model_inference_parameters=ModelInferenceParameters.from_dict(model_inference_parameters), +def _build_magpie_config(args) -> MagpieInferenceConfig: + return MagpieInferenceConfig( + model_inference_parameters=_build_inference_params_from_args(ModelInferenceParameters, args), batch_size=args.batch_size, use_cfg=args.use_cfg, apply_attention_prior=args.apply_attention_prior, @@ -572,13 +580,54 @@ def main(argv=None): maskgit_noise_scale=args.maskgit_noise_scale, maskgit_fixed_schedule=args.maskgit_fixed_schedule, maskgit_sampling_type=args.maskgit_sampling_type, - is_decoder_only_model=args.is_decoder_only_model, + ) + + +def _build_easy_magpie_config(args) -> EasyMagpieInferenceConfig: + return EasyMagpieInferenceConfig( + model_inference_parameters=_build_inference_params_from_args(EasyModelInferenceParameters, args), + batch_size=args.batch_size, + use_cfg=args.use_cfg, + use_local_transformer=args.use_local_transformer, phoneme_input_type=args.phoneme_input_type, phoneme_sampling_method=args.phoneme_sampling_method, dropout_text_input=args.dropout_text_input, legacy_context_stacking=args.legacy_context_stacking, ) + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(argv=None): + """Entry point for TTS inference and evaluation.""" + parser = create_argument_parser() + args = parser.parse_args(argv) + + dataset_meta_info = load_evalset_config(args.datasets_json_path) + datasets = filter_datasets(dataset_meta_info, args.datasets) + logging.info(f"Loaded {len(datasets)} datasets: {', '.join(datasets)}") + + # Validate model loading args + has_checkpoint_mode = ( + args.hparams_files is not None + and args.checkpoint_files is not None + and args.hparams_files != "null" + and args.checkpoint_files != "null" + ) + has_nemo_mode = args.nemo_files is not None and args.nemo_files != "null" + + if not has_checkpoint_mode and not has_nemo_mode: + parser.error("You must provide either:\n 1. --hparams_files and --checkpoint_files\n 2. --nemo_files") + + # Select model loader and config builder based on --model_type + is_easy_magpie = args.model_type == 'easy_magpie' + load_fn = load_easy_magpie_model if is_easy_magpie else load_magpie_model + inference_config = _build_easy_magpie_config(args) if is_easy_magpie else _build_magpie_config(args) + runner_cls = EasyMagpieInferenceRunner if is_easy_magpie else MagpieInferenceRunner + eval_config = EvaluationConfig( sv_model=args.sv_model, asr_model_name=args.asr_model_name, @@ -589,7 +638,7 @@ def main(argv=None): cer, ssim = None, None - # Run for each model (checkpoint or nemo) + # Iterate over model files (checkpoint or nemo) if has_checkpoint_mode: hparam_files = args.hparams_files.split(",") checkpoint_files = args.checkpoint_files.split(",") @@ -609,17 +658,28 @@ def main(argv=None): hparams_from_wandb=args.hparams_file_from_wandb, ) + model, checkpoint_name = load_fn(model_config) + moe_info, flops_per_component = log_model_architecture_summary(model) + + if args.log_exp_name and model_config.checkpoint_file: + exp_name = get_experiment_name_from_checkpoint_path(model_config.checkpoint_file) + checkpoint_name = f"{exp_name}__{checkpoint_name}" + + runner = runner_cls(model, inference_config) + cer, ssim = run_inference_and_evaluation( - model_config=model_config, + runner=runner, + checkpoint_name=checkpoint_name, inference_config=inference_config, eval_config=eval_config, dataset_meta_info=dataset_meta_info, datasets=datasets, out_dir=args.out_dir, + flops_per_component=flops_per_component, + moe_info=moe_info, num_repeats=args.num_repeats, confidence_level=args.confidence_level, violin_plot_metrics=args.violin_plot_metrics, - log_exp_name=args.log_exp_name, clean_up_disk=args.clean_up_disk, skip_evaluation=not args.run_evaluation, ) @@ -635,17 +695,24 @@ def main(argv=None): legacy_text_conditioning=args.legacy_text_conditioning, ) + model, checkpoint_name = load_fn(model_config) + moe_info, flops_per_component = log_model_architecture_summary(model) + + runner = runner_cls(model, inference_config) + cer, ssim = run_inference_and_evaluation( - model_config=model_config, + runner=runner, + checkpoint_name=checkpoint_name, inference_config=inference_config, eval_config=eval_config, dataset_meta_info=dataset_meta_info, datasets=datasets, out_dir=args.out_dir, + flops_per_component=flops_per_component, + moe_info=moe_info, num_repeats=args.num_repeats, confidence_level=args.confidence_level, violin_plot_metrics=args.violin_plot_metrics, - log_exp_name=args.log_exp_name, clean_up_disk=args.clean_up_disk, skip_evaluation=not args.run_evaluation, ) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 5a117432b986..19705eed1ad3 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -14,7 +14,7 @@ import json import os import random -from dataclasses import dataclass +from dataclasses import dataclass, fields from typing import Dict, List, Optional, Tuple import numpy as np @@ -98,6 +98,29 @@ class ProcessBatchOutput: selected_training_mode: Optional[str] +@dataclass +class EasyModelInferenceParameters: + """Inference parameters for the decoder-only EasyMagpieTTS model. + + Attributes: + max_decoder_steps: Maximum number of decoder steps. + temperature: Sampling temperature. + topk: Number of top-probability tokens to consider in sampling. + cfg_scale: Scale factor for classifier-free guidance. + """ + + max_decoder_steps: int = 500 + temperature: float = 0.7 + topk: int = 80 + cfg_scale: float = 2.5 + + @classmethod + def from_dict(cls, data: dict) -> 'EasyModelInferenceParameters': + field_names = {field.name for field in fields(cls)} + filtered_data = {k: v for k, v in data.items() if k in field_names} + return cls(**filtered_data) + + class EasyMagpieTTSModel(EasyMagpieTTSInferenceModel): """ Magpie-TTS Model Decoder Only Model with training support. diff --git a/nemo/collections/tts/modules/magpietts_inference/__init__.py b/nemo/collections/tts/modules/magpietts_inference/__init__.py index fd99780f21b2..b1ff0aefe91e 100644 --- a/nemo/collections/tts/modules/magpietts_inference/__init__.py +++ b/nemo/collections/tts/modules/magpietts_inference/__init__.py @@ -12,35 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -MagpieTTS inference and evaluation subpackage. +TTS inference and evaluation subpackage. This package provides modular components for: - Model loading and configuration (utils.py) -- Batch inference (inference.py) +- Batch inference (inference.py) for both MagpieTTS and EasyMagpieTTS - Audio quality evaluation (evaluation.py) - Metrics visualization (visualization.py) -Example Usage: - from examples.tts.magpietts import ( - InferenceConfig, +Example Usage (MagpieTTS - encoder-decoder): + from nemo.collections.tts.modules.magpietts_inference import ( + MagpieInferenceConfig, MagpieInferenceRunner, load_magpie_model, ModelLoadConfig, ) - # Load model - model_config = ModelLoadConfig( - nemo_file="/path/to/model.nemo", - codecmodel_path="/path/to/codec.nemo", - ) - model, checkpoint_name = load_magpie_model(model_config) + model_config = ModelLoadConfig(nemo_file="/path/to/model.nemo", codecmodel_path="/path/to/codec.nemo") + model, name = load_magpie_model(model_config) + runner = MagpieInferenceRunner(model, MagpieInferenceConfig()) - # Log architecture summary and retrieve MoE info + FLOPs metrics - moe_info, flops_per_component = log_model_architecture_summary(model) +Example Usage (EasyMagpieTTS - decoder-only): + from nemo.collections.tts.modules.magpietts_inference import ( + EasyMagpieInferenceConfig, + EasyMagpieInferenceRunner, + load_easy_magpie_model, + ModelLoadConfig, + ) - # Create runner and run inference - inference_config = InferenceConfig() - runner = MagpieInferenceRunner(model, inference_config) + model_config = ModelLoadConfig(nemo_file="/path/to/model.nemo", codecmodel_path="/path/to/codec.nemo") + model, name = load_easy_magpie_model(model_config) + runner = EasyMagpieInferenceRunner(model, EasyMagpieInferenceConfig()) """ from nemo.collections.tts.modules.magpietts_inference.evaluation import ( @@ -49,11 +51,20 @@ compute_mean_with_confidence_interval, evaluate_generated_audio_dir, ) -from nemo.collections.tts.modules.magpietts_inference.inference import InferenceConfig, MagpieInferenceRunner +from nemo.collections.tts.modules.magpietts_inference.inference import ( + BaseInferenceConfig, + BaseInferenceRunner, + EasyMagpieInferenceConfig, + EasyMagpieInferenceRunner, + InferenceConfig, + MagpieInferenceConfig, + MagpieInferenceRunner, +) from nemo.collections.tts.modules.magpietts_inference.utils import ( ModelLoadConfig, compute_ffn_flops_per_token, get_experiment_name_from_checkpoint_path, + load_easy_magpie_model, load_magpie_model, log_model_architecture_summary, ) @@ -63,12 +74,19 @@ # Utils "ModelLoadConfig", "load_magpie_model", + "load_easy_magpie_model", "compute_ffn_flops_per_token", "get_experiment_name_from_checkpoint_path", "log_model_architecture_summary", - # Inference + # Inference configs + "BaseInferenceConfig", + "MagpieInferenceConfig", + "EasyMagpieInferenceConfig", "InferenceConfig", + # Inference runners + "BaseInferenceRunner", "MagpieInferenceRunner", + "EasyMagpieInferenceRunner", # Evaluation "EvaluationConfig", "evaluate_generated_audio_dir", diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index cf325b91d71c..d5d34537e088 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -12,15 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Core inference logic for MagpieTTS. +Core inference logic for MagpieTTS models. -This module provides: -- InferenceConfig: Dataclass for inference hyperparameters -- MagpieInferenceRunner: Class for running batch inference with a loaded model - (uses unified inference path with automatic text chunking based on language thresholds) +This module provides a strategy-pattern based inference framework with: +- BaseInferenceConfig / MagpieInferenceConfig / EasyMagpieInferenceConfig +- BaseInferenceRunner / MagpieInferenceRunner / EasyMagpieInferenceRunner + +MagpieInferenceRunner handles the encoder-decoder MagpieTTSModel +(chunked text, generate_speech + codes_to_audio). + +EasyMagpieInferenceRunner handles the decoder-only EasyMagpieTTSModel +(infer_batch, returns audio directly). """ from __future__ import annotations +import abc import glob import os import shutil @@ -34,65 +40,56 @@ from nemo.collections.asr.parts.utils.manifest_utils import read_manifest from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPATokenizer from nemo.collections.tts.data.text_to_speech_dataset import ChunkedTTSInferenceDataset, MagpieTTSDataset -from nemo.collections.tts.models import EasyMagpieTTSModel, MagpieTTSModel +from nemo.collections.tts.models.easy_magpietts import EasyModelInferenceParameters from nemo.collections.tts.models.magpietts import ModelInferenceParameters from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors from nemo.utils import logging +# --------------------------------------------------------------------------- +# Inference config hierarchy +# --------------------------------------------------------------------------- + + @dataclass -class InferenceConfig: - """Configuration for MagpieTTS inference. - - Attributes: - batch_size: Batch size for inference. - use_cfg: Whether to use classifier-free guidance. - apply_attention_prior: Whether to apply attention prior during decoding. - - # Model specific inference parameters - model_inference_parameters: See ModelInferenceParameters dataclass - - # Local transformer / MaskGit parameters - use_local_transformer: Whether to use local transformer for inference. - maskgit_n_steps: Number of MaskGit refinement steps. - maskgit_noise_scale: Noise scale for MaskGit sampling. - maskgit_fixed_schedule: Fixed schedule for MaskGit (optional). - maskgit_sampling_type: Type of MaskGit sampling. +class BaseInferenceConfig(abc.ABC): + """Shared inference configuration fields. + + Subclasses must declare their own ``model_inference_parameters`` field + with the appropriate type (ModelInferenceParameters or + EasyModelInferenceParameters). """ - # Core sampling parameters batch_size: int = 32 use_cfg: bool = False - apply_attention_prior: bool = False + use_local_transformer: bool = False + + @abc.abstractmethod + def build_identifier(self) -> str: + """Build a unique identifier string for naming output directories.""" + ... + + @staticmethod + def _format_layer_list(layers: Optional[List[int]]) -> str: + if layers is None: + return "None" + return "".join(str(_layer) for _layer in layers) + + +@dataclass +class MagpieInferenceConfig(BaseInferenceConfig): + """Configuration for encoder-decoder MagpieTTSModel inference.""" + model_inference_parameters: ModelInferenceParameters = field(default_factory=ModelInferenceParameters) + apply_attention_prior: bool = False - # Local transformer / MaskGit parameters - use_local_transformer: bool = False + # MaskGit parameters maskgit_n_steps: int = 3 maskgit_noise_scale: float = 0.0 maskgit_fixed_schedule: Optional[List[int]] = None maskgit_sampling_type: Optional[str] = None - # Decoder-only inference options - phoneme_input_type: str = "gt" # gt or predicted - phoneme_sampling_method: str = "argmax" # argmax or multinomial - dropout_text_input: bool = False - legacy_context_stacking: bool = False # Use audio_bos_id/audio_eos_id for context stacking - - # Longform inference mode - longform_mode: str = "auto" # "auto" | "always" | "never" - longform_word_threshold: int = 40 # Word threshold for auto-detection - - is_decoder_only_model: bool = False - def build_identifier(self) -> str: - """Build a unique identifier string for this configuration. - - Used for naming output directories and files. - - Returns: - String identifier incorporating key config values. - """ parts = [ f"Temp{self.model_inference_parameters.temperature}", f"Topk{self.model_inference_parameters.topk}", @@ -123,134 +120,69 @@ def build_identifier(self) -> str: return "_".join(parts) - @staticmethod - def _format_layer_list(layers: Optional[List[int]]) -> str: - """Format a list of layer indices as a compact string.""" - if layers is None: - return "None" - return "".join(str(_layer) for _layer in layers) + +@dataclass +class EasyMagpieInferenceConfig(BaseInferenceConfig): + """Configuration for decoder-only EasyMagpieTTSModel inference.""" + + model_inference_parameters: EasyModelInferenceParameters = field( + default_factory=EasyModelInferenceParameters + ) + phoneme_input_type: str = "gt" + phoneme_sampling_method: str = "argmax" + dropout_text_input: bool = False + legacy_context_stacking: bool = False + + def build_identifier(self) -> str: + parts = [ + f"Temp{self.model_inference_parameters.temperature}", + f"Topk{self.model_inference_parameters.topk}", + f"Cfg_{self.use_cfg}_{self.model_inference_parameters.cfg_scale}", + f"LT_{self.use_local_transformer}", + f"Phoneme_{self.phoneme_input_type}_{self.phoneme_sampling_method}", + ] + return "_".join(parts) + + +# Backwards-compatible aliases +InferenceConfig = MagpieInferenceConfig + + +# --------------------------------------------------------------------------- +# Inference runner hierarchy +# --------------------------------------------------------------------------- -class MagpieInferenceRunner: - """Runner class for MagpieTTS batch inference. +class BaseInferenceRunner(abc.ABC): + """Abstract base for TTS inference runners. - Encapsulates the logic for running inference on a dataset, saving outputs, - and collecting metrics. + Provides shared utilities (batch-to-cuda, file cleanup, reference audio + copying, RTF metrics) and declares the interface that concrete runners + must implement. """ - def __init__( - self, # model can be MagpieTTSModel or DecoderOnlyMagpieTTSModel - model: Union[MagpieTTSModel, EasyMagpieTTSModel], - config: InferenceConfig, - ): - """Initialize the inference runner. - - Args: - model: Loaded MagpieTTS model (should be on GPU and in eval mode). - config: Inference configuration. - """ + def __init__(self, model, config: BaseInferenceConfig): self.model = model self.config = config - - # Set legacy context stacking flag on model - self.model.legacy_context_stacking = config.legacy_context_stacking - - # Set phoneme probability to 1 for inference self._configure_tokenizer() - - # Cached state from create_dataset (set when create_dataset is called) self._manifest_records: Optional[List[dict]] = None self._audio_base_dir: Optional[str] = None - def _configure_tokenizer(self) -> None: - """Configure the tokenizer for inference (phoneme prob = 1.0).""" - g2p = None - if isinstance(self.model.tokenizer, AggregatedTTSTokenizer): - if "english_phoneme" in self.model.tokenizer.tokenizers and hasattr( - self.model.tokenizer.tokenizers["english_phoneme"], "g2p" - ): - g2p = self.model.tokenizer.tokenizers["english_phoneme"].g2p - elif isinstance(self.model.tokenizer, IPATokenizer): - g2p = self.model.tokenizer.g2p - - if g2p is not None: - g2p.phoneme_probability = 1.0 + # -- interface ----------------------------------------------------------- + @abc.abstractmethod def create_dataset( self, dataset_meta: dict, context_duration_min: Optional[float] = None, context_duration_max: Optional[float] = None, ) -> Union[ChunkedTTSInferenceDataset, MagpieTTSDataset]: - """Create an inference dataset. - - Standard MagpieTTS uses the chunked inference dataset from `main`. - Decoder-only MagpieTTS uses the regular dataset and its dedicated - `infer_batch()` inference path. - - Args: - dataset_meta: Dataset metadata dictionary with 'manifest_path' and 'audio_dir'. - context_duration_min: Minimum context duration (uses model default if None). - context_duration_max: Maximum context duration (uses model default if None). - - Returns: - Configured ChunkedTTSInferenceDataset instance. - """ - # Use model defaults if not specified - if context_duration_min is None: - context_duration_min = self.model.cfg.get('context_duration_min', 5.0) - if context_duration_max is None: - context_duration_max = self.model.cfg.get('context_duration_max', 5.0) - - # For multi-encoder models, use fixed 5s context for fair evaluation - if context_duration_min < 5.0 and context_duration_max > 5.0: - context_duration_min = 5.0 - context_duration_max = 5.0 - - # Read manifest and cache for later use - dataset_name = list(dataset_meta.keys())[0] - dataset_info = dataset_meta[dataset_name] - manifest_path = dataset_info.get('manifest_path') - audio_dir = dataset_info.get('audio_dir', '') - logging.info(f"Dataset name: {dataset_name}, manifest_path: {manifest_path}, audio_dir: {audio_dir}") - - self._manifest_records = read_manifest(manifest_path) - self._audio_base_dir = audio_dir - if self.config.is_decoder_only_model: - logging.info("Creating standard inference dataset for decoder-only model") - dataset = MagpieTTSDataset( - dataset_meta=dataset_meta, - sample_rate=self.model.sample_rate, - min_duration=0.5, - max_duration=20, - codec_model_samples_per_frame=self.model.codec_model_samples_per_frame, - bos_id=getattr(self.model, "bos_id", None), - eos_id=self.model.eos_id, - num_audio_codebooks=self.model.num_audio_codebooks, - prior_scaling_factor=None, - load_cached_codes_if_available=False, - dataset_type='test', - tokenizer_config=None, - load_16khz_audio=False, - use_text_conditioning_tokenizer=True, - text_conditioning_tokenizer_name=self.model.text_conditioning_tokenizer_name, - pad_context_text_to_max_duration=False, - context_duration_min=context_duration_min, - context_duration_max=context_duration_max, - ) - dataset.text_tokenizer = self.model.tokenizer - else: - logging.info("Creating unified inference dataset") - dataset = self._create_chunked_inference_dataset(dataset_meta, context_duration_min, context_duration_max) - - if hasattr(self.model, 'phoneme_tokenizer'): - dataset.phoneme_tokenizer = self.model.phoneme_tokenizer - - return dataset + ... + @abc.abstractmethod def run_inference_on_dataset( self, - dataset: ChunkedTTSInferenceDataset, + dataset, output_dir: str, manifest_records: Optional[List[dict]] = None, audio_base_dir: Optional[str] = None, @@ -258,127 +190,66 @@ def run_inference_on_dataset( save_context_audio: bool = True, save_predicted_codes: bool = True, ) -> Tuple[List[dict], List[str], List[str]]: - """Run inference on a dataset. - - Args: - dataset: The inference dataset (created by create_dataset()). - output_dir: Directory to save generated audio and artifacts. - manifest_records: Original manifest records (uses cached if None). - audio_base_dir: Base directory for audio paths (uses cached if None). - save_cross_attention_maps: Whether to save attention map images (not used in unified path). - save_context_audio: Whether to copy context audio files. - save_predicted_codes: Whether to save predicted code files. - - Returns: - Tuple of: - - rtf_metrics: List of real-time factor metrics per batch. - - generated_audio_paths: List of paths to generated audio files. - - codec_file_paths: List of paths to predicted codes files. - """ - # Use cached values if not provided + ... + + # -- shared helpers ------------------------------------------------------ + + def _configure_tokenizer(self) -> None: + """Configure the tokenizer for inference (phoneme prob = 1.0).""" + g2p = None + if isinstance(self.model.tokenizer, AggregatedTTSTokenizer): + if "english_phoneme" in self.model.tokenizer.tokenizers and hasattr( + self.model.tokenizer.tokenizers["english_phoneme"], "g2p" + ): + g2p = self.model.tokenizer.tokenizers["english_phoneme"].g2p + elif isinstance(self.model.tokenizer, IPATokenizer): + g2p = self.model.tokenizer.g2p + + if g2p is not None: + g2p.phoneme_probability = 1.0 + + def _resolve_manifest_and_audio_dir( + self, + manifest_records: Optional[List[dict]], + audio_base_dir: Optional[str], + ) -> Tuple[List[dict], str]: if manifest_records is None: if self._manifest_records is None: raise ValueError("manifest_records not provided and not cached from create_dataset()") manifest_records = self._manifest_records - if audio_base_dir is None: if self._audio_base_dir is None: raise ValueError("audio_base_dir not provided and not cached from create_dataset()") audio_base_dir = self._audio_base_dir + return manifest_records, audio_base_dir - if self.config.is_decoder_only_model: - logging.info("Using decoder-only inference path") - return self._run_decoder_only_inference( - dataset, output_dir, manifest_records, audio_base_dir, save_context_audio, save_predicted_codes - ) - - logging.info("Using unified inference path") - return self._run_unified_inference( - dataset, output_dir, manifest_records, audio_base_dir, save_context_audio, save_predicted_codes - ) + def _read_and_cache_manifest(self, dataset_meta: dict) -> Tuple[str, str]: + """Read manifest from dataset_meta, cache records, return (manifest_path, audio_dir).""" + dataset_name = list(dataset_meta.keys())[0] + dataset_info = dataset_meta[dataset_name] + manifest_path = dataset_info.get('manifest_path') + audio_dir = dataset_info.get('audio_dir', '') + logging.info(f"Dataset name: {dataset_name}, manifest_path: {manifest_path}, audio_dir: {audio_dir}") + self._manifest_records = read_manifest(manifest_path) + self._audio_base_dir = audio_dir + return manifest_path, audio_dir - def _run_decoder_only_inference( + def _get_context_durations( self, - dataset: MagpieTTSDataset, - output_dir: str, - manifest_records: List[dict], - audio_base_dir: str, - save_context_audio: bool = True, - save_predicted_codes: bool = True, - ) -> Tuple[List[dict], List[str], List[str]]: - """Run inference for decoder-only models via `infer_batch()`.""" - os.makedirs(output_dir, exist_ok=True) - self._delete_old_generated_files(output_dir) - - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=self.config.batch_size, - collate_fn=dataset.collate_fn, - num_workers=0, - shuffle=False, - ) - - all_rtf_metrics = [] - generated_audio_paths = [] - codec_file_paths = [] - item_idx = 0 - phoneme_sampling_method = ( - "argmax" if self.config.phoneme_sampling_method == "greedy" else self.config.phoneme_sampling_method - ) - - for batch_idx, batch in enumerate(dataloader): - logging.info(f"Processing batch {batch_idx + 1}/{len(dataloader)}") - batch = self._batch_to_cuda(batch) - output = self.model.infer_batch( - batch, - max_decoder_steps=self.config.model_inference_parameters.max_decoder_steps, - temperature=self.config.model_inference_parameters.temperature, - topk=self.config.model_inference_parameters.topk, - use_cfg=self.config.use_cfg, - cfg_scale=self.config.model_inference_parameters.cfg_scale, - use_local_transformer_for_inference=self.config.use_local_transformer, - phoneme_input_type=self.config.phoneme_input_type, - phoneme_sampling_method=phoneme_sampling_method, - force_dropout_text=self.config.dropout_text_input, - ) - predicted_audio = output.predicted_audio - predicted_audio_lens = output.predicted_audio_lens - predicted_codes = output.predicted_codes - predicted_codes_lens = output.predicted_codes_lens - rtf_metrics = output.rtf_metrics - - all_rtf_metrics.append(rtf_metrics) - logging.info(f"Output shape: {predicted_audio.size()}") - - for idx in range(predicted_audio.size(0)): - audio_len = predicted_audio_lens[idx].item() - audio_np = predicted_audio[idx].float().detach().cpu().numpy()[:audio_len] - audio_path = os.path.join(output_dir, f"predicted_audio_{item_idx}.wav") - sample_rate = getattr(self.model, "output_sample_rate", self.model.sample_rate) - sf.write(audio_path, audio_np, sample_rate) - generated_audio_paths.append(audio_path) - - if save_context_audio and item_idx < len(manifest_records): - self._copy_reference_audio( - manifest_records[item_idx], - audio_base_dir, - output_dir, - item_idx, - ) - - if save_predicted_codes: - code_len = predicted_codes_lens[idx].item() - codes_path = os.path.join(output_dir, f"predicted_codes_{item_idx}.pt") - torch.save(predicted_codes[idx, :, :code_len].detach().cpu(), codes_path) - codec_file_paths.append(codes_path) - - item_idx += 1 - - return all_rtf_metrics, generated_audio_paths, codec_file_paths + context_duration_min: Optional[float], + context_duration_max: Optional[float], + ) -> Tuple[float, float]: + if context_duration_min is None: + context_duration_min = self.model.cfg.get('context_duration_min', 5.0) + if context_duration_max is None: + context_duration_max = self.model.cfg.get('context_duration_max', 5.0) + if context_duration_min < 5.0 and context_duration_max > 5.0: + context_duration_min = 5.0 + context_duration_max = 5.0 + return context_duration_min, context_duration_max @staticmethod def _batch_to_cuda(batch: dict) -> dict: - """Move batch tensors to CUDA device.""" batch_cuda = {} for key, value in batch.items(): if isinstance(value, torch.Tensor): @@ -389,7 +260,6 @@ def _batch_to_cuda(batch: dict) -> dict: @staticmethod def _delete_old_generated_files(output_dir: str) -> None: - """Delete leftover generated files from previous runs.""" logging.info(f"Cleaning up old generated files in: {output_dir}") patterns = [ "predicted_codes*.pt", @@ -407,7 +277,6 @@ def _copy_reference_audio( output_dir: str, item_idx: int, ) -> None: - """Copy context and target audio files to output directory.""" context_path = record.get('context_audio_filepath') target_path = record.get('audio_filepath') @@ -425,48 +294,69 @@ def _copy_reference_audio( @staticmethod def compute_mean_rtf_metrics(rtf_metrics_list: List[dict]) -> Dict[str, float]: - """Compute mean RTF metrics across batches.""" if not rtf_metrics_list or not rtf_metrics_list[0]: return {} - mean_metrics = {} for key in rtf_metrics_list[0]: values = [m[key] for m in rtf_metrics_list if key in m] mean_metrics[key] = float(sum(values) / len(values)) if values else 0.0 - return mean_metrics - def _create_chunked_inference_dataset( + +# --------------------------------------------------------------------------- +# MagpieInferenceRunner (encoder-decoder MagpieTTSModel) +# --------------------------------------------------------------------------- + + +class MagpieInferenceRunner(BaseInferenceRunner): + """Runner for encoder-decoder MagpieTTSModel. + + Uses ChunkedTTSInferenceDataset and model.generate_speech() per chunk, + then model.codes_to_audio() to produce waveforms. + """ + + def __init__(self, model, config: MagpieInferenceConfig): + super().__init__(model, config) + + def create_dataset( self, dataset_meta: dict, context_duration_min: Optional[float] = None, context_duration_max: Optional[float] = None, ) -> ChunkedTTSInferenceDataset: - """Create a unified inference dataset. - - Creates ChunkedTTSInferenceDataset which uses language-aware chunking - to automatically handle both short and long texts. + context_duration_min, context_duration_max = self._get_context_durations( + context_duration_min, context_duration_max + ) + self._read_and_cache_manifest(dataset_meta) - Args: - dataset_meta: Dataset metadata dictionary (same format as MagpieTTSDataset). - context_duration_min: Minimum context duration (uses model default if None). - context_duration_max: Maximum context duration (uses model default if None). + logging.info("Creating unified inference dataset") + dataset = self._create_chunked_inference_dataset(dataset_meta, context_duration_min, context_duration_max) + return dataset - Returns: - Configured ChunkedTTSInferenceDataset instance. - """ - # Use model defaults if not specified - if context_duration_min is None: - context_duration_min = self.model.cfg.get('context_duration_min', 5.0) - if context_duration_max is None: - context_duration_max = self.model.cfg.get('context_duration_max', 5.0) + def run_inference_on_dataset( + self, + dataset: ChunkedTTSInferenceDataset, + output_dir: str, + manifest_records: Optional[List[dict]] = None, + audio_base_dir: Optional[str] = None, + save_cross_attention_maps: bool = True, + save_context_audio: bool = True, + save_predicted_codes: bool = True, + ) -> Tuple[List[dict], List[str], List[str]]: + manifest_records, audio_base_dir = self._resolve_manifest_and_audio_dir(manifest_records, audio_base_dir) + logging.info("Using unified inference path") + return self._run_unified_inference( + dataset, output_dir, manifest_records, audio_base_dir, save_context_audio, save_predicted_codes + ) - # For multi-encoder models, use fixed 5s context for fair evaluation - if context_duration_min < 5.0 and context_duration_max > 5.0: - context_duration_min = 5.0 - context_duration_max = 5.0 + # -- private ------------------------------------------------------------- - # Create unified dataset - language and tokenizer are determined per-sample from manifest + def _create_chunked_inference_dataset( + self, + dataset_meta: dict, + context_duration_min: float, + context_duration_max: float, + ) -> ChunkedTTSInferenceDataset: dataset = ChunkedTTSInferenceDataset( dataset_meta=dataset_meta, sample_rate=self.model.output_sample_rate, @@ -480,10 +370,7 @@ def _create_chunked_inference_dataset( pad_context_text_to_max_duration=self.model.pad_context_text_to_max_duration, load_16khz_audio=self.model.model_type == 'single_encoder_sv_tts', ) - - # Attach model's tokenizer dataset.text_tokenizer = self.model.tokenizer - return dataset def _run_unified_inference( @@ -495,26 +382,6 @@ def _run_unified_inference( save_context_audio: bool = True, save_predicted_codes: bool = True, ) -> Tuple[List[dict], List[str], List[str]]: - """Run unified inference with automatic single/multi-chunk handling. - - Processes all samples through generate_speech, passing - beginning_of_text and end_of_text so the model can handle both - single-chunk (short text) and multi-chunk (long text) cases correctly. - - Args: - dataset: ChunkedTTSInferenceDataset created by create_dataset(). - output_dir: Directory to save generated audio and artifacts. - manifest_records: List of manifest record dictionaries. - audio_base_dir: Base directory for resolving audio paths. - save_context_audio: Whether to copy context audio files. - save_predicted_codes: Whether to save predicted code files. - - Returns: - Tuple of: - - rtf_metrics: List of real-time factor metrics per batch. - - generated_audio_paths: List of paths to generated audio files. - - codec_file_paths: List of paths to predicted codes files. - """ os.makedirs(output_dir, exist_ok=True) self._delete_old_generated_files(output_dir) @@ -522,7 +389,7 @@ def _run_unified_inference( dataset, batch_size=self.config.batch_size, collate_fn=dataset.collate_fn, - num_workers=0, # Avoid multiprocessing issues with CUDA + num_workers=0, shuffle=False, ) @@ -534,54 +401,42 @@ def _run_unified_inference( for batch_idx, batch in enumerate(dataloader): logging.info(f"Processing batch {batch_idx + 1}/{len(dataloader)}") - # Move batch tensors to CUDA batch = self._batch_to_cuda(batch) - batch['sample_rate'] = self.model.output_sample_rate batch['context_sample_rate'] = self.model.output_sample_rate batch_size = len(batch['chunked_tokens']) max_num_chunks = max(len(tokens) for tokens in batch['chunked_tokens']) - # Clear stale KV cache from prior inference calls (e.g., the previous batch or dataset - # may have left with populated tensors). logging.info(f"Resetting KV cache for decoder: {self.model.use_kv_cache_for_inference}") use_kv_cache_for_this_batch = self.model.use_kv_cache_for_inference if max_num_chunks == 1 else False self.model.decoder.reset_cache(use_cache=use_kv_cache_for_this_batch) - # Create chunk state for this batch chunk_state = self.model.create_chunk_state(batch_size=batch_size) - # Accumulators for predicted codes predicted_codes_per_sample = [[] for _ in range(batch_size)] predicted_codes_lens = [0 for _ in range(batch_size)] - # Overwrite the model's parameters since we want to use the arguments from the commandline self.model.inference_parameters = self.config.model_inference_parameters start_time = time.time() - # Iterate over text chunks (1 for short text, N for long text) for chunk_idx in range(max_num_chunks): - # Extract current chunk tokens for each sample current_tokens = [] current_tokens_lens = [] for b_idx in range(batch_size): current_tokens.append(batch['chunked_tokens'][b_idx][chunk_idx]) current_tokens_lens.append(batch['chunked_tokens_lens'][b_idx][chunk_idx]) - # Pad tokens to max length in this chunk max_len = max(current_tokens_lens) batch['text'] = stack_tensors(current_tokens, max_lens=[max_len]).cuda() batch['text_lens'] = torch.tensor(current_tokens_lens, dtype=torch.int32).cuda() - # Compute is_end_of_text flags (per-sample) is_end_of_text = self._compute_end_of_text_flags( batch, chunk_idx, max_num_chunks, current_tokens_lens, batch_size ) beginning_of_text = chunk_idx == 0 - # Call generate_speech (unified entry point) output = self.model.generate_speech( batch, chunk_state=chunk_state, @@ -595,16 +450,12 @@ def _run_unified_inference( maskgit_sampling_type=self.config.maskgit_sampling_type, ) - # Unpack output chunk_codes = output.predicted_codes chunk_codes_lens = output.predicted_codes_lens - # Accumulate codes for each sample for b_idx in range(batch_size): - # Skip if this sample's text has ended (padding chunks) if is_end_of_text[b_idx] and current_tokens_lens[b_idx] == 1: continue - code_len = chunk_codes_lens[b_idx] if code_len > 0: codes_slice = chunk_codes[b_idx][:, :code_len] @@ -614,17 +465,14 @@ def _run_unified_inference( elapsed = time.time() - start_time logging.info(f"Batch inference time: {elapsed:.2f}s") - # Concatenate codes and convert to audio predicted_codes_list = [] for b_idx in range(batch_size): if predicted_codes_per_sample[b_idx]: concatenated = torch.cat(predicted_codes_per_sample[b_idx], dim=1).cuda() else: - # Empty placeholder concatenated = torch.zeros((self.model.num_audio_codebooks, 1), dtype=torch.long, device='cuda') predicted_codes_list.append(concatenated) - # Stack and convert to audio max_code_len = max(predicted_codes_lens) if any(predicted_codes_lens) else 1 predicted_codes = stack_tensors(predicted_codes_list, max_lens=[max_code_len]).cuda() predicted_codes_lens_tensor = torch.tensor(predicted_codes_lens, dtype=torch.long, device='cuda') @@ -633,7 +481,6 @@ def _run_unified_inference( predicted_codes, predicted_codes_lens_tensor ) - # Compute RTF metrics total_audio_samples = sum(predicted_audio_lens.cpu().tolist()) total_audio_seconds = total_audio_samples / self.model.output_sample_rate rtf = elapsed / total_audio_seconds if total_audio_seconds > 0 else 0.0 @@ -644,7 +491,6 @@ def _run_unified_inference( } all_rtf_metrics.append(rtf_metrics) - # Save outputs predicted_audio_np = predicted_audio.float().detach().cpu().numpy() for b_idx in range(batch_size): @@ -656,7 +502,6 @@ def _run_unified_inference( sf.write(audio_path, audio_np, self.model.output_sample_rate) generated_audio_paths.append(audio_path) - # Copy reference audio if requested if save_context_audio and sample_idx < len(manifest_records): self._copy_reference_audio( manifest_records[sample_idx], @@ -667,7 +512,7 @@ def _run_unified_inference( if save_predicted_codes: codes_path = os.path.join(output_dir, f"predicted_codes_{sample_idx}.pt") - predicted_codes_current = predicted_codes[b_idx, :, : predicted_codes_lens[b_idx]] # C, T + predicted_codes_current = predicted_codes[b_idx, :, : predicted_codes_lens[b_idx]] torch.save(predicted_codes_current, codes_path) codec_file_paths.append(codes_path) @@ -675,38 +520,173 @@ def _run_unified_inference( return all_rtf_metrics, generated_audio_paths, codec_file_paths + @staticmethod def _compute_end_of_text_flags( - self, batch: Dict[str, Any], chunk_idx: int, max_num_chunks: int, current_tokens_lens: List[int], batch_size: int, ) -> List[bool]: - """Compute end-of-text flags for each sample in batch. - - Args: - batch: Current batch dictionary. - chunk_idx: Current chunk index. - max_num_chunks: Maximum number of chunks in this batch. - current_tokens_lens: Token lengths for current chunk per sample. - batch_size: Number of samples in batch. - - Returns: - List of booleans indicating if each sample has reached end of text. - """ is_end_of_text = [] for b_idx in range(batch_size): if chunk_idx == max_num_chunks - 1: - # Last chunk is_end_of_text.append(True) elif current_tokens_lens[b_idx] == 1: - # Current chunk is padding is_end_of_text.append(True) elif batch['chunked_tokens_lens'][b_idx][chunk_idx + 1] == 1: - # Next chunk is padding is_end_of_text.append(True) else: is_end_of_text.append(False) - return is_end_of_text + + +# --------------------------------------------------------------------------- +# EasyMagpieInferenceRunner (decoder-only EasyMagpieTTSModel) +# --------------------------------------------------------------------------- + + +class EasyMagpieInferenceRunner(BaseInferenceRunner): + """Runner for decoder-only EasyMagpieTTSModel. + + Uses MagpieTTSDataset and model.infer_batch() which returns audio directly. + """ + + def __init__(self, model, config: EasyMagpieInferenceConfig): + super().__init__(model, config) + self.model.legacy_context_stacking = config.legacy_context_stacking + + def create_dataset( + self, + dataset_meta: dict, + context_duration_min: Optional[float] = None, + context_duration_max: Optional[float] = None, + ) -> MagpieTTSDataset: + context_duration_min, context_duration_max = self._get_context_durations( + context_duration_min, context_duration_max + ) + self._read_and_cache_manifest(dataset_meta) + + logging.info("Creating inference dataset for decoder-only model") + dataset = MagpieTTSDataset( + dataset_meta=dataset_meta, + sample_rate=self.model.sample_rate, + min_duration=0.5, + max_duration=20, + codec_model_samples_per_frame=self.model.codec_model_samples_per_frame, + bos_id=getattr(self.model, "bos_id", None), + eos_id=self.model.eos_id, + num_audio_codebooks=self.model.num_audio_codebooks, + prior_scaling_factor=None, + load_cached_codes_if_available=False, + dataset_type='test', + tokenizer_config=None, + load_16khz_audio=False, + use_text_conditioning_tokenizer=True, + text_conditioning_tokenizer_name=self.model.text_conditioning_tokenizer_name, + pad_context_text_to_max_duration=False, + context_duration_min=context_duration_min, + context_duration_max=context_duration_max, + ) + dataset.text_tokenizer = self.model.tokenizer + + if hasattr(self.model, 'phoneme_tokenizer'): + dataset.phoneme_tokenizer = self.model.phoneme_tokenizer + + return dataset + + def run_inference_on_dataset( + self, + dataset: MagpieTTSDataset, + output_dir: str, + manifest_records: Optional[List[dict]] = None, + audio_base_dir: Optional[str] = None, + save_cross_attention_maps: bool = True, + save_context_audio: bool = True, + save_predicted_codes: bool = True, + ) -> Tuple[List[dict], List[str], List[str]]: + manifest_records, audio_base_dir = self._resolve_manifest_and_audio_dir(manifest_records, audio_base_dir) + logging.info("Using decoder-only inference path") + return self._run_decoder_only_inference( + dataset, output_dir, manifest_records, audio_base_dir, save_context_audio, save_predicted_codes + ) + + # -- private ------------------------------------------------------------- + + def _run_decoder_only_inference( + self, + dataset: MagpieTTSDataset, + output_dir: str, + manifest_records: List[dict], + audio_base_dir: str, + save_context_audio: bool = True, + save_predicted_codes: bool = True, + ) -> Tuple[List[dict], List[str], List[str]]: + os.makedirs(output_dir, exist_ok=True) + self._delete_old_generated_files(output_dir) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=self.config.batch_size, + collate_fn=dataset.collate_fn, + num_workers=0, + shuffle=False, + ) + + all_rtf_metrics = [] + generated_audio_paths = [] + codec_file_paths = [] + item_idx = 0 + phoneme_sampling_method = ( + "argmax" if self.config.phoneme_sampling_method == "greedy" else self.config.phoneme_sampling_method + ) + + for batch_idx, batch in enumerate(dataloader): + logging.info(f"Processing batch {batch_idx + 1}/{len(dataloader)}") + batch = self._batch_to_cuda(batch) + output = self.model.infer_batch( + batch, + max_decoder_steps=self.config.model_inference_parameters.max_decoder_steps, + temperature=self.config.model_inference_parameters.temperature, + topk=self.config.model_inference_parameters.topk, + use_cfg=self.config.use_cfg, + cfg_scale=self.config.model_inference_parameters.cfg_scale, + use_local_transformer_for_inference=self.config.use_local_transformer, + phoneme_input_type=self.config.phoneme_input_type, + phoneme_sampling_method=phoneme_sampling_method, + force_dropout_text=self.config.dropout_text_input, + ) + predicted_audio = output.predicted_audio + predicted_audio_lens = output.predicted_audio_lens + predicted_codes = output.predicted_codes + predicted_codes_lens = output.predicted_codes_lens + rtf_metrics = output.rtf_metrics + + all_rtf_metrics.append(rtf_metrics) + logging.info(f"Output shape: {predicted_audio.size()}") + + for idx in range(predicted_audio.size(0)): + audio_len = predicted_audio_lens[idx].item() + audio_np = predicted_audio[idx].float().detach().cpu().numpy()[:audio_len] + audio_path = os.path.join(output_dir, f"predicted_audio_{item_idx}.wav") + sample_rate = getattr(self.model, "output_sample_rate", self.model.sample_rate) + sf.write(audio_path, audio_np, sample_rate) + generated_audio_paths.append(audio_path) + + if save_context_audio and item_idx < len(manifest_records): + self._copy_reference_audio( + manifest_records[item_idx], + audio_base_dir, + output_dir, + item_idx, + ) + + if save_predicted_codes: + code_len = predicted_codes_lens[idx].item() + codes_path = os.path.join(output_dir, f"predicted_codes_{item_idx}.pt") + torch.save(predicted_codes[idx, :, :code_len].detach().cpu(), codes_path) + codec_file_paths.append(codes_path) + + item_idx += 1 + + return all_rtf_metrics, generated_audio_paths, codec_file_paths diff --git a/nemo/collections/tts/modules/magpietts_inference/utils.py b/nemo/collections/tts/modules/magpietts_inference/utils.py index 580a6e32ebc7..ca89356494fa 100644 --- a/nemo/collections/tts/modules/magpietts_inference/utils.py +++ b/nemo/collections/tts/modules/magpietts_inference/utils.py @@ -23,7 +23,7 @@ import os from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple import torch from omegaconf import DictConfig, OmegaConf, open_dict @@ -253,9 +253,7 @@ def update_checkpoint_state_dict(state_dict: dict) -> dict: return new_state_dict -def load_magpie_model( - config: ModelLoadConfig, device: str = "cuda", is_decoder_only_model: bool = False -) -> Tuple[Union[MagpieTTSModel, EasyMagpieTTSModel], str]: +def load_magpie_model(config: ModelLoadConfig, device: str = "cuda") -> Tuple[MagpieTTSModel, str]: """Load a MagpieTTS model from checkpoint or NeMo archive. Supports two loading modes: @@ -273,7 +271,7 @@ def load_magpie_model( ValueError: If configuration is invalid or sample rates don't match. """ config.validate() - model_cls = EasyMagpieTTSModel if is_decoder_only_model else MagpieTTSModel + if config.hparams_file is not None and config.checkpoint_file is not None: # Mode 1: Load from hparams + checkpoint model_cfg = OmegaConf.load(config.hparams_file) @@ -292,7 +290,7 @@ def load_magpie_model( config.legacy_text_conditioning, ) - model = model_cls(cfg=model_cfg) + model = MagpieTTSModel(cfg=model_cfg) model.use_kv_cache_for_inference = True # Load weights @@ -304,15 +302,15 @@ def load_magpie_model( checkpoint_name = os.path.basename(config.checkpoint_file).replace(".ckpt", "") else: - if config.nemo_file.startswith("nvidia/"): - model = model_cls.from_pretrained(config.nemo_file) + if config.nemo_file.startswith("nvidia/"): # TODO @xueyang: why ignore `update_config_for_inference`? + model = MagpieTTSModel.from_pretrained(config.nemo_file) model.use_kv_cache_for_inference = True checkpoint_name = config.nemo_file.split("/")[-1] cfg_sample_rate = None else: # Mode 2: Load from .nemo archive logging.info(f"Loading model from NeMo archive: {config.nemo_file}") - model_cfg = model_cls.restore_from(config.nemo_file, return_config=True) + model_cfg = MagpieTTSModel.restore_from(config.nemo_file, return_config=True) with open_dict(model_cfg): model_cfg, cfg_sample_rate = update_config_for_inference( @@ -322,7 +320,7 @@ def load_magpie_model( config.legacy_text_conditioning, ) - model = model_cls.restore_from(config.nemo_file, override_config_path=model_cfg) + model = MagpieTTSModel.restore_from(config.nemo_file, override_config_path=model_cfg) model.use_kv_cache_for_inference = True checkpoint_name = os.path.basename(config.nemo_file).replace(".nemo", "") @@ -338,6 +336,69 @@ def load_magpie_model( return model, checkpoint_name +def load_easy_magpie_model(config: ModelLoadConfig, device: str = "cuda") -> Tuple[EasyMagpieTTSModel, str]: + """Load an EasyMagpieTTSModel (decoder-only) from checkpoint or NeMo archive. + + Supports two loading modes: + 1. Checkpoint mode: hparams.yaml + .ckpt file + 2. NeMo mode: .nemo archive file + + Args: + config: Model loading configuration. + device: Device to load the model onto ("cuda" or "cpu"). + + Returns: + Tuple of (loaded model, checkpoint name for output labeling). + + Raises: + ValueError: If configuration is invalid. + """ + config.validate() + + if config.hparams_file is not None and config.checkpoint_file is not None: + model_cfg = OmegaConf.load(config.hparams_file) + + if "cfg" in model_cfg: + model_cfg = model_cfg.cfg + if config.hparams_from_wandb: + model_cfg = model_cfg.value + + with open_dict(model_cfg): + model_cfg.codecmodel_path = config.codecmodel_path + model_cfg.train_ds = None + model_cfg.validation_ds = None + + model = EasyMagpieTTSModel(cfg=model_cfg) + + logging.info(f"Loading weights from checkpoint: {config.checkpoint_file}") + ckpt = torch.load(config.checkpoint_file) + state_dict = ckpt['state_dict'] + model.load_state_dict(state_dict) + + checkpoint_name = os.path.basename(config.checkpoint_file).replace(".ckpt", "") + else: + if config.nemo_file.startswith("nvidia/"): + model = EasyMagpieTTSModel.from_pretrained(config.nemo_file) + checkpoint_name = config.nemo_file.split("/")[-1] + else: + logging.info(f"Loading model from NeMo archive: {config.nemo_file}") + model_cfg = EasyMagpieTTSModel.restore_from(config.nemo_file, return_config=True) + + with open_dict(model_cfg): + model_cfg.codecmodel_path = config.codecmodel_path + model_cfg.train_ds = None + model_cfg.validation_ds = None + + model = EasyMagpieTTSModel.restore_from(config.nemo_file, override_config_path=model_cfg) + checkpoint_name = os.path.basename(config.nemo_file).replace(".nemo", "") + + model.to(device) + model.eval() + logging.info("EasyMagpieTTS model loaded and ready for inference.") + + return model, checkpoint_name + + def _log_transformer_component(name: str, cfg: DictConfig, use_moe: bool = False) -> dict: """Log architecture info for a single transformer component and return its FLOPs metrics. @@ -414,23 +475,22 @@ def _log_transformer_component(name: str, cfg: DictConfig, use_moe: bool = False return flops_info -def log_model_architecture_summary(model: MagpieTTSModel) -> Tuple[str, Dict[str, dict]]: +def log_model_architecture_summary(model) -> Tuple[str, Dict[str, dict]]: """Log model architecture summary including MoE configuration. Detects and logs MoE configuration for each transformer component, - computing FLOPs metrics and parameter counts. + computing FLOPs metrics and parameter counts. Gracefully handles + decoder-only models (EasyMagpieTTSModel) that use HuggingFace/Nemotron + decoders without the d_model/d_ffn config structure. Args: - model: Loaded MagpieTTS model. + model: Loaded MagpieTTS or EasyMagpieTTS model. Returns: Tuple of: - moe_info: String for checkpoint naming (e.g., "MoE_8x2_d2048_softmax_"), empty for dense models - flops_per_component: Dict mapping component name (e.g., "decoder") to its FLOPs metrics dict """ - if isinstance(model, EasyMagpieTTSModel): - return "", {} - logging.info("=" * 60) logging.info("MODEL ARCHITECTURE SUMMARY") logging.info("=" * 60) @@ -438,23 +498,28 @@ def log_model_architecture_summary(model: MagpieTTSModel) -> Tuple[str, Dict[str flops_per_component: Dict[str, dict] = {} use_moe = getattr(model.cfg, 'use_moe', False) - # Log optional encoder if present - if hasattr(model.cfg, 'encoder'): + # Log optional encoder if present (encoder-decoder models) + if hasattr(model.cfg, 'encoder') and hasattr(model.cfg.encoder, 'd_model'): flops_per_component['encoder'] = _log_transformer_component('encoder', model.cfg.encoder) # Log optional context_encoder if present - if hasattr(model.cfg, 'context_encoder'): + if hasattr(model.cfg, 'context_encoder') and hasattr(model.cfg.context_encoder, 'd_model'): flops_per_component['context_encoder'] = _log_transformer_component( 'context_encoder', model.cfg.context_encoder ) - # Decoder is required - always present in MagpieTTS. MoE only applies to decoder. - flops_per_component['decoder'] = _log_transformer_component('decoder', model.cfg.decoder, use_moe=use_moe) + # Decoder -- only log detailed FLOPs for encoder-decoder models whose + # decoder config exposes d_model/d_ffn. Decoder-only models (EasyMagpieTTS) + # use HuggingFace or Nemotron decoders with a different config shape. + decoder_cfg = getattr(model.cfg, 'decoder', None) + if decoder_cfg is not None and hasattr(decoder_cfg, 'd_model'): + flops_per_component['decoder'] = _log_transformer_component('decoder', decoder_cfg, use_moe=use_moe) + else: + logging.info("DECODER: detailed FLOPs logging not available for this model type") # Build MoE info string for checkpoint naming moe_info = "" - if use_moe: - decoder_cfg = model.cfg.decoder + if use_moe and decoder_cfg is not None and hasattr(decoder_cfg, 'num_experts'): moe_info = ( f"decoder-MoE_{decoder_cfg.num_experts}x{decoder_cfg.top_k_experts}" f"_d{decoder_cfg.d_ffn}_{decoder_cfg.routing_strategy}_" @@ -488,4 +553,4 @@ def get_experiment_name_from_checkpoint_path(checkpoint_path: str) -> str: Returns: The experiment name (parent directory of checkpoints folder). """ - return os.path.basename(os.path.dirname(os.path.dirname(checkpoint_path))) + return os.path.basename(os.path.dirname(os.path.dirname(checkpoint_path))) \ No newline at end of file diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 6d91ad25f976..027ca47a4e82 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -1411,7 +1411,7 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st if isinstance(cfg.init_from_ptl_ckpt, str): # Restore checkpoint ckpt_path = cfg.pop('init_from_ptl_ckpt') - ckpt = torch.load(ckpt_path, map_location=map_location, weights_only=False) + ckpt = torch.load(ckpt_path, map_location=map_location) # Restore checkpoint into current model self.load_state_dict(ckpt['state_dict'], strict=False) diff --git a/tests/collections/tts/test_infer_vs_process_batch.py b/tests/collections/tts/test_infer_vs_process_batch.py deleted file mode 100644 index 0ea66e2870ef..000000000000 --- a/tests/collections/tts/test_infer_vs_process_batch.py +++ /dev/null @@ -1,491 +0,0 @@ -""" -Test script to verify that infer_batch (teacher-forced) produces the same audio code -and phoneme predictions as process_batch (single forward pass). - -Usage: - python tests/collections/tts/test_infer_vs_process_batch.py --codecmodel_path /path/to/codec.nemo - -The script: -1. Builds a tiny NemotronH-backed EasyMagpieTTSModel with a real codec model. -2. Creates synthetic random inputs (with variable lengths per batch item). -3. Runs process_batch (full-sequence forward) and infer_batch (streaming, teacher-forced). -4. Compares the argmax audio code predictions and phoneme predictions from both paths. -5. Repeats for multiple configurations. -""" - -import argparse -import sys -import torch -from omegaconf import OmegaConf - -from nemo.collections.tts.models.easy_magpietts import EasyMagpieTTSModel - - -def build_minimal_config(codecmodel_path: str) -> OmegaConf: - """Build a minimal OmegaConf config for a tiny NemotronH model.""" - hidden_size = 256 - - cfg_dict = { - # Decoder backend - 'decoder_type': 'nemotron_h', - 'nemotron_h_config': { - 'hidden_size': hidden_size, - 'num_hidden_layers': 2, - 'vocab_size': 131072, - 'num_attention_heads': 4, - 'num_key_value_heads': 2, - 'attention_dropout': 0.0, - 'attention_bias': False, - 'max_position_embeddings': 4096, - 'mamba_num_heads': 16, - 'mamba_head_dim': 16, - 'ssm_state_size': 128, - 'conv_kernel': 4, - 'n_groups': 8, - 'chunk_size': 256, - 'mamba_hidden_act': 'silu', - 'use_conv_bias': True, - 'use_bias': False, - 'intermediate_size': 512, - 'mlp_hidden_act': 'silu', - 'mlp_bias': False, - 'hybrid_override_pattern': 'M*', # All Mamba layers - 'layer_norm_epsilon': 1e-5, - 'residual_in_fp32': True, - }, - 'embedding_dim': hidden_size, - 'hidden_dim': hidden_size, - 'audio_embedding_dim': hidden_size, - 'codecmodel_path': codecmodel_path, - # Text tokenizer - use a simple AutoTokenizer - 'text_tokenizers': { - 'test_tokenizer': { - '_target_': 'AutoTokenizer', - 'pretrained_model': 'gpt2', - }, - }, - # Phoneme tokenizer - 'phoneme_tokenizer': { - '_target_': 'nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPABPETokenizer', - 'tokenizer_path': 'scripts/tts_dataset_files/bpe_ipa_tokenizer_2048_en_de_es_fr_hi_it_vi_zh.json', - }, - 'phoneme_stacking_factor': 1, - # Training modes (single streaming mode) - 'training_modes': [ - { - 'text_input_mode': 'streaming', - 'streaming_phonemes_delay': 4, - 'streaming_speech_delay': 8, - }, - ], - 'frame_stacking_factor': 2, - 'cfg_unconditional_prob': 0.0, - 'dropout_text_input_prob': 0.0, - 'local_transformer_type': 'none', - 'run_val_inference': False, - # Optim placeholder (required by ModelPT but not used) - 'optim': { - '_target_': 'torch.optim.AdamW', - 'lr': 1e-4, - }, - # No dataloaders - } - return OmegaConf.create(cfg_dict) - - -def create_synthetic_batch( - model, - batch_size=2, - text_lens_list=None, - audio_frames_list=None, - context_text_lens_list=None, - context_audio_frames_list=None, - phoneme_lens_list=None, - device='cpu', -): - """Create a synthetic batch with random valid token IDs and variable lengths per item. - - If *_list args are None, defaults to uniform lengths for all items. - """ - num_codebooks = model.num_audio_codebooks - codebook_size = model.codebook_size - text_vocab_size = model.bos_id # valid text tokens are [0, bos_id) - phoneme_vocab_size = model.phoneme_tokenizer.vocab_size - 2 # exclude BOS/EOS - - # Defaults - if text_lens_list is None: - text_lens_list = [20] * batch_size - if audio_frames_list is None: - audio_frames_list = [30] * batch_size - if context_text_lens_list is None: - context_text_lens_list = [10] * batch_size - if context_audio_frames_list is None: - context_audio_frames_list = [15] * batch_size - if phoneme_lens_list is None: - phoneme_lens_list = [25] * batch_size - - assert len(text_lens_list) == batch_size - assert len(audio_frames_list) == batch_size - assert len(context_text_lens_list) == batch_size - assert len(context_audio_frames_list) == batch_size - assert len(phoneme_lens_list) == batch_size - - # Max lengths for padding - max_text_len = max(text_lens_list) - max_audio_frames = max(audio_frames_list) - max_context_text_len = max(context_text_lens_list) - max_context_audio_frames = max(context_audio_frames_list) - max_phoneme_len = max(phoneme_lens_list) - - # Text tokens: random tokens + EOS at the end (matching dataset behavior) - text = torch.zeros(batch_size, max_text_len, dtype=torch.long, device=device) - for b in range(batch_size): - tl = text_lens_list[b] - text[b, : tl - 1] = torch.randint(0, text_vocab_size, (tl - 1,), device=device) - text[b, tl - 1] = model.eos_id # EOS as last valid token - text_lens = torch.tensor(text_lens_list, dtype=torch.long, device=device) - - # Context text tokens - context_text_tokens = torch.zeros(batch_size, max_context_text_len, dtype=torch.long, device=device) - for b in range(batch_size): - cl = context_text_lens_list[b] - context_text_tokens[b, :cl] = torch.randint(0, text_vocab_size, (cl,), device=device) - context_text_tokens_lens = torch.tensor(context_text_lens_list, dtype=torch.long, device=device) - - # Audio codes (raw, without BOS/EOS) - audio_codes = torch.zeros(batch_size, num_codebooks, max_audio_frames, dtype=torch.long, device=device) - for b in range(batch_size): - af = audio_frames_list[b] - audio_codes[b, :, :af] = torch.randint(0, codebook_size, (num_codebooks, af), device=device) - audio_codes_lens = torch.tensor(audio_frames_list, dtype=torch.long, device=device) - - # Context audio codes (raw, without BOS/EOS) - context_audio_codes = torch.zeros( - batch_size, num_codebooks, max_context_audio_frames, dtype=torch.long, device=device - ) - for b in range(batch_size): - caf = context_audio_frames_list[b] - context_audio_codes[b, :, :caf] = torch.randint(0, codebook_size, (num_codebooks, caf), device=device) - context_audio_codes_lens = torch.tensor(context_audio_frames_list, dtype=torch.long, device=device) - - # Phoneme tokens (raw IDs, BOS/EOS will be added by the model) - phoneme_tokens = torch.zeros(batch_size, max_phoneme_len, dtype=torch.long, device=device) - for b in range(batch_size): - pl = phoneme_lens_list[b] - phoneme_tokens[b, :pl] = torch.randint(0, phoneme_vocab_size, (pl,), device=device) - phoneme_tokens_lens = torch.tensor(phoneme_lens_list, dtype=torch.long, device=device) - - batch = { - 'text': text, - 'text_lens': text_lens, - 'context_text_tokens': context_text_tokens, - 'context_text_tokens_lens': context_text_tokens_lens, - 'audio_codes': audio_codes, - 'audio_codes_lens': audio_codes_lens, - 'context_audio_codes': context_audio_codes, - 'context_audio_codes_lens': context_audio_codes_lens, - 'phoneme_tokens': phoneme_tokens, - 'phoneme_tokens_lens': phoneme_tokens_lens, - } - return batch - - -def compare_audio_codes(model, pb_output, ib_output, batch): - """Compare audio codes from process_batch and infer_batch. Returns True if all match.""" - C = model.num_audio_codebooks - S = model.frame_stacking_factor - C_stacked = C * S - V = model.num_all_tokens_per_codebook - pb_logits = pb_output.logits # (B, T_stacked, C_stacked * V) - T_stacked = pb_logits.size(1) - batch_size = batch['text'].size(0) - - # Extract per-codebook argmax at stacked resolution - pb_stacked_codes_list = [] - for cb_idx in range(C_stacked): - si = cb_idx * V - ei = si + V - cb_logits = pb_logits[:, :, si:ei] # (B, T_stacked, V) - cb_preds = cb_logits.argmax(dim=-1) # (B, T_stacked) - pb_stacked_codes_list.append(cb_preds) - pb_stacked_codes = torch.stack(pb_stacked_codes_list, dim=1) # (B, C_stacked, T_stacked) - - # Unstack: (B, C*S, T_stacked) -> (B, C, S, T_stacked) -> (B, C, T_stacked, S) -> (B, C, T_stacked*S) - pb_unstacked = pb_stacked_codes.view(batch_size, C, S, T_stacked) - pb_unstacked = pb_unstacked.permute(0, 1, 3, 2).contiguous() - pb_unstacked = pb_unstacked.reshape(batch_size, C, T_stacked * S) - pb_unstacked_lens = pb_output.audio_codes_lens_target * S - - ib_codes = ib_output.predicted_codes - ib_codes_lens = ib_output.predicted_codes_lens - - print(f" process_batch argmax codes (unstacked): {pb_unstacked.shape}, lens: {pb_unstacked_lens.tolist()}") - print(f" infer_batch predicted codes: {ib_codes.shape}, lens: {ib_codes_lens.tolist()}") - - all_match = True - for b in range(batch_size): - pb_len = pb_unstacked_lens[b].item() - ib_len = ib_codes_lens[b].item() - compare_len = min(pb_len, ib_len) - - if compare_len == 0: - print(f" Batch item {b}: No codes to compare (pb_len={pb_len}, ib_len={ib_len})") - continue - - pb_codes_b = pb_unstacked[b, :, :compare_len] - ib_codes_b = ib_codes[b, :, :compare_len] - - matches = (pb_codes_b == ib_codes_b).all() - num_matching = (pb_codes_b == ib_codes_b).sum().item() - total = pb_codes_b.numel() - match_pct = 100.0 * num_matching / total if total > 0 else 0.0 - - print(f" Batch item {b}: pb_len={pb_len}, ib_len={ib_len}, compare_len={compare_len}") - print(f" Audio match: {matches.item()}, {num_matching}/{total} ({match_pct:.1f}%)") - - if not matches: - all_match = False - mismatch_mask = pb_codes_b != ib_codes_b - mismatch_positions = mismatch_mask.nonzero(as_tuple=False) - num_show = min(10, mismatch_positions.size(0)) - for i in range(num_show): - cb, t = mismatch_positions[i].tolist() - print( - f" Mismatch at codebook={cb}, time={t}: " - f"pb={pb_codes_b[cb, t].item()}, ib={ib_codes_b[cb, t].item()}" - ) - - return all_match - - -def compare_phoneme_predictions(model, pb_output, ib_output, batch): - """Compare phoneme predictions from process_batch and infer_batch. Returns True if all match.""" - if pb_output.phoneme_logits is None: - print(" No phoneme logits from process_batch (no phoneme tokenizer?). Skipping.") - return True - if ib_output.predicted_phoneme_tokens is None: - print(" No phoneme predictions from infer_batch. Skipping.") - return True - - batch_size = batch['text'].size(0) - phoneme_stacking_factor = model.phoneme_stacking_factor - phoneme_vocab_size = model.phoneme_vocab_size - - # Extract argmax phoneme predictions from process_batch logits - # phoneme_logits: (B, T_phoneme, phoneme_stacking_factor * phoneme_vocab_size) - pb_phoneme_logits = pb_output.phoneme_logits - T_phoneme = pb_phoneme_logits.size(1) - - pb_phoneme_preds_list = [] - for sf_idx in range(phoneme_stacking_factor): - si = sf_idx * phoneme_vocab_size - ei = si + phoneme_vocab_size - sf_logits = pb_phoneme_logits[:, :, si:ei] # (B, T_phoneme, V_phoneme) - sf_preds = sf_logits.argmax(dim=-1) # (B, T_phoneme) - pb_phoneme_preds_list.append(sf_preds) - pb_phoneme_preds = torch.stack(pb_phoneme_preds_list, dim=1) # (B, phoneme_stacking_factor, T_phoneme) - pb_phoneme_lens = pb_output.phoneme_tokens_lens_target # (B,) number of phoneme prediction steps - - # infer_batch phoneme predictions: (B, phoneme_stacking_factor, T_all_steps) - ib_phoneme_preds = ib_output.predicted_phoneme_tokens - ib_phoneme_lens = ib_output.predicted_phoneme_tokens_lens - - print(f" process_batch phoneme preds: {pb_phoneme_preds.shape}, lens: {pb_phoneme_lens.tolist()}") - print(f" infer_batch phoneme preds: {ib_phoneme_preds.shape}, lens: {ib_phoneme_lens.tolist()}") - - # Get start indices for infer_batch phoneme predictions - ib_start_idx = ib_output.phoneme_prediction_start_idx # (B,) - - all_match = True - for b in range(batch_size): - pb_len = pb_phoneme_lens[b].item() - ib_len = ib_phoneme_lens[b].item() - compare_len = min(pb_len, ib_len) - - if compare_len == 0: - print(f" Batch item {b}: No phonemes to compare (pb_len={pb_len}, ib_len={ib_len})") - continue - - # process_batch phoneme preds start from 0 (already sliced to prediction region) - pb_ph_b = pb_phoneme_preds[b, :, :compare_len] - - # infer_batch phoneme preds: slice from start_idx for this batch item - start = max(0, ib_start_idx[b].item()) - ib_ph_b = ib_phoneme_preds[b, :, start : start + compare_len] - - matches = (pb_ph_b == ib_ph_b).all() - num_matching = (pb_ph_b == ib_ph_b).sum().item() - total = pb_ph_b.numel() - match_pct = 100.0 * num_matching / total if total > 0 else 0.0 - - print(f" Batch item {b}: pb_len={pb_len}, ib_len={ib_len}, compare_len={compare_len}") - print(f" Phoneme match: {matches.item()}, {num_matching}/{total} ({match_pct:.1f}%)") - - if not matches: - all_match = False - mismatch_mask = pb_ph_b != ib_ph_b - mismatch_positions = mismatch_mask.nonzero(as_tuple=False) - num_show = min(10, mismatch_positions.size(0)) - for i in range(num_show): - sf, t = mismatch_positions[i].tolist() - print( - f" Mismatch at stacking_factor={sf}, time={t}: " - f"pb={pb_ph_b[sf, t].item()}, ib={ib_ph_b[sf, t].item()}" - ) - - return all_match - - -def run_single_test(model, batch, test_name, device): - """Run a single test comparing process_batch and infer_batch outputs.""" - print(f"\n{'='*60}") - print(f"TEST: {test_name}") - print(f"{'='*60}") - - for k, v in batch.items(): - if isinstance(v, torch.Tensor): - print(f" {k}: shape={v.shape}, dtype={v.dtype}") - - # Run process_batch - print("\n Running process_batch...") - training_mode = model.training_modes[0] - with torch.inference_mode(): - pb_output = model.process_batch( - text=batch['text'], - text_lens=batch['text_lens'], - context_text_tokens=batch['context_text_tokens'], - context_text_tokens_lens=batch['context_text_tokens_lens'], - audio_codes=batch['audio_codes'], - audio_codes_lens=batch['audio_codes_lens'], - context_audio_codes=batch['context_audio_codes'], - context_audio_codes_lens=batch['context_audio_codes_lens'], - phoneme_tokens=batch['phoneme_tokens'], - phoneme_tokens_lens=batch['phoneme_tokens_lens'], - mode='val', - training_mode=training_mode, - ) - - # Run infer_batch (teacher-forced) - print(" Running infer_batch (teacher-forced)...") - ib_output = model.infer_batch( - batch=batch, - max_decoder_steps=1000, - temperature=0.0, - topk=80, - use_cfg=False, - use_local_transformer_for_inference=False, - phoneme_input_type='gt', - phoneme_sampling_method='argmax', - use_teacher_forced=True, - ) - - # Compare audio codes - print("\n --- Audio Codes Comparison ---") - audio_match = compare_audio_codes(model, pb_output, ib_output, batch) - - # Compare phoneme predictions - print("\n --- Phoneme Predictions Comparison ---") - phoneme_match = compare_phoneme_predictions(model, pb_output, ib_output, batch) - - success = audio_match and phoneme_match - if success: - print(f"\n ✓ {test_name}: PASSED (audio + phoneme match)") - else: - parts = [] - if not audio_match: - parts.append("audio") - if not phoneme_match: - parts.append("phoneme") - print(f"\n ✗ {test_name}: FAILED ({' and '.join(parts)} mismatch)") - - return success - - -def main(): - parser = argparse.ArgumentParser(description='Test infer_batch vs process_batch') - parser.add_argument('--codecmodel_path', type=str, required=True, help='Path to codec model .nemo file') - parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') - args = parser.parse_args() - - device = args.device - print(f"Using device: {device}") - - # 1. Build config and model - print("Building minimal config...") - cfg = build_minimal_config(args.codecmodel_path) - - print("Instantiating EasyMagpieTTSModel (tiny NemotronH + real codec)...") - model = EasyMagpieTTSModel(cfg=cfg, trainer=None) - model = model.to(device) - model.eval() - print(f" num_audio_codebooks={model.num_audio_codebooks}, codebook_size={model.codebook_size}") - print(f" frame_stacking_factor={model.frame_stacking_factor}") - print(f" phoneme_vocab_size={model.phoneme_tokenizer.vocab_size}") - - # Define test configurations: (test_name, kwargs_for_create_synthetic_batch) - test_configs = [ - ( - "Uniform lengths (B=2, text=20, audio=30, ctx_text=10, ctx_audio=15, phoneme=25)", - dict( - batch_size=2, - text_lens_list=[20, 20], - audio_frames_list=[30, 30], - context_text_lens_list=[10, 10], - context_audio_frames_list=[15, 15], - phoneme_lens_list=[25, 25], - ), - ), - ( - "Variable text & context lens (B=2, text=[15,25], ctx_text=[8,12], ctx_audio=[10,20])", - dict( - batch_size=2, - text_lens_list=[15, 25], - audio_frames_list=[30, 30], - context_text_lens_list=[8, 12], - context_audio_frames_list=[10, 20], - phoneme_lens_list=[20, 30], - ), - ), - ( - "Variable audio & phoneme lens (B=2, audio=[20,40], phoneme=[15,35])", - dict( - batch_size=2, - text_lens_list=[20, 20], - audio_frames_list=[20, 40], - context_text_lens_list=[10, 10], - context_audio_frames_list=[15, 15], - phoneme_lens_list=[15, 35], - ), - ), - ( - "All different (B=3)", - dict( - batch_size=3, - text_lens_list=[12, 20, 28], - audio_frames_list=[20, 30, 40], - context_text_lens_list=[6, 10, 14], - context_audio_frames_list=[8, 15, 22], - phoneme_lens_list=[15, 25, 35], - ), - ), - ] - - all_passed = True - for test_name, kwargs in test_configs: - batch = create_synthetic_batch(model, device=device, **kwargs) - passed = run_single_test(model, batch, test_name, device) - if not passed: - all_passed = False - - # Final summary - print(f"\n{'='*60}") - if all_passed: - print("✓ ALL TESTS PASSED") - else: - print("✗ SOME TESTS FAILED") - sys.exit(1) - print(f"{'='*60}") - - -if __name__ == '__main__': - main() diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_FrameStacking.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_FrameStacking.sh index 368b5c83bba5..b6d87e91a254 100644 --- a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_FrameStacking.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_FrameStacking.sh @@ -14,7 +14,7 @@ # Tests a 4x-stacked model with local transformer inference. -TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ +TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/tts_infer.py \ --codecmodel_path /home/TestData/tts/21fps_causal_codecmodel.nemo \ --datasets_json_path examples/tts/evalset_config.json \ --datasets an4_val_ci \ diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_MoE_ZeroShot.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_MoE_ZeroShot.sh index a591497f22e0..4e917733f59a 100755 --- a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_MoE_ZeroShot.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_MoE_ZeroShot.sh @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ +TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/tts_infer.py \ --nemo_files "/home/TestData/tts/2602_MoE/moe16_sinkhorn_top1_valLoss5.0469_step2625132_epoch524.nemo" \ --codecmodel_path "/home/TestData/tts/21fps_causal_codecmodel.nemo" \ --datasets_json_path "examples/tts/evalset_config.json" \ diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh index 5ed8d48f5aff..8eb30eb40c36 100644 --- a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_SeenSpeakers.sh @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ +TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/tts_infer.py \ --codecmodel_path /home/TestData/tts/21fps_causal_codecmodel.nemo \ --datasets_json_path examples/tts/evalset_config.json \ --datasets an4_val_ci \ diff --git a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh index 3a9415bbc2b3..eed95fc5a64e 100644 --- a/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluate_Magpietts_ZeroShot.sh @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ +TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/tts_infer.py \ --codecmodel_path /home/TestData/tts/21fps_causal_codecmodel.nemo \ --datasets_json_path examples/tts/evalset_config.json \ --datasets an4_val_ci \ diff --git a/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_MoE_ZeroShot.sh b/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_MoE_ZeroShot.sh index ec8b6b885212..c21454d39cb1 100755 --- a/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_MoE_ZeroShot.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_MoE_ZeroShot.sh @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ +TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/tts_infer.py \ --nemo_files "/home/TestData/tts/2602_MoE/moe16_sinkhorn_top1_valLoss5.0469_step2625132_epoch524.nemo" \ --codecmodel_path "/home/TestData/tts/21fps_causal_codecmodel.nemo" \ --datasets_json_path "examples/tts/evalset_config.json" \ diff --git a/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot.sh b/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot.sh index a0694c16b9ba..96e20304197a 100644 --- a/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot.sh +++ b/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot.sh @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ +TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/tts_infer.py \ --codecmodel_path /home/TestData/tts/21fps_causal_codecmodel.nemo \ --datasets_json_path examples/tts/evalset_config.json \ --datasets an4_val_ci_longform_tiny \