Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions examples/tts/magpietts_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,12 +555,8 @@ def main(argv=None):
model_inference_parameters[field_name] = arg_from_cmdline

if "max_decoder_steps" not in model_inference_parameters:
if args.longform_mode in {'always', 'auto'}:
model_inference_parameters["max_decoder_steps"] = args.longform_max_decoder_steps
elif args.is_decoder_only_model:
if args.is_decoder_only_model:
model_inference_parameters["max_decoder_steps"] = 300
else:
model_inference_parameters["max_decoder_steps"] = 440

inference_config = InferenceConfig(
model_inference_parameters=ModelInferenceParameters.from_dict(model_inference_parameters),
Expand All @@ -577,8 +573,6 @@ def main(argv=None):
phoneme_sampling_method=args.phoneme_sampling_method,
dropout_text_input=args.dropout_text_input,
legacy_context_stacking=args.legacy_context_stacking,
longform_mode=args.longform_mode,
longform_word_threshold=args.longform_word_threshold,
)

eval_config = EvaluationConfig(
Expand Down
4 changes: 4 additions & 0 deletions nemo/collections/tts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

from nemo.collections.tts.models.aligner import AlignerModel
from nemo.collections.tts.models.audio_codec import AudioCodecModel
from nemo.collections.tts.models.base_magpietts import BaseMagpieTTSModel
from nemo.collections.tts.models.easy_magpietts import EasyMagpieTTSModel
from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel
from nemo.collections.tts.models.easy_magpietts_preference_optimization import EasyMagpieTTSModelOnlinePO
from nemo.collections.tts.models.fastpitch import FastPitchModel
from nemo.collections.tts.models.fastpitch_ssl import FastPitchModel_SSL
Expand All @@ -30,13 +32,15 @@
__all__ = [
"AlignerModel",
"AudioCodecModel",
"BaseMagpieTTSModel",
"FastPitchModel",
"FastPitchModel_SSL",
"SSLDisentangler",
"HifiGanModel",
"InferBatchOutput",
"MagpieTTSModel",
"EasyMagpieTTSModel",
"EasyMagpieTTSInferenceModel",
"EasyMagpieTTSModelOnlinePO",
"MagpieTTSModelOfflinePODataGen",
"MagpieTTSModelOfflinePO",
Expand Down
Loading
Loading