Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
c6cb96e
qwen tts 1/n
knyazer Mar 17, 2026
86245d2
neater config loading
knyazer Mar 17, 2026
401b317
linters
knyazer Mar 18, 2026
8d5bdc5
latent tts + formatting fixes
knyazer Mar 18, 2026
316c4ec
more strict generation config
knyazer Mar 19, 2026
c11e071
update
knyazer Mar 21, 2026
9272bce
Merge branch 'main' into qwen-tts-new
knyazer Mar 21, 2026
be8212d
finalizing qwen-tts and latent tts
knyazer Mar 23, 2026
34d201e
Merge branch 'main' into qwen-tts-new
knyazer Mar 23, 2026
22709a7
fix up little issues
knyazer Mar 23, 2026
e915f6e
finishing up
knyazer Mar 24, 2026
bd5662c
fixing minor bugs
knyazer Mar 24, 2026
03dad34
update trans deps
knyazer Mar 24, 2026
ab49ecb
correct dep
knyazer Mar 24, 2026
e22ddd3
fix fish
knyazer Mar 25, 2026
42aaa66
Merge branch 'main' into qwen-tts-new
knyazer Mar 25, 2026
b364957
transformers version
knyazer Mar 25, 2026
bf02930
Addressing review
knyazer Mar 27, 2026
68e5469
fix
knyazer Mar 27, 2026
f750bc1
compat
knyazer Mar 27, 2026
745cd2b
Address PR review comments: remove slop and clean up TTS code
claude Mar 27, 2026
17de7f4
iterating
knyazer Mar 28, 2026
c8725ff
fix import
knyazer Mar 29, 2026
89e6e85
fix tests :( i wish it could be automatable
knyazer Mar 29, 2026
97b25e1
unslop pass
knyazer Mar 29, 2026
9c47a39
fix up broky imports
knyazer Mar 29, 2026
6535fe9
fix up tests
knyazer Mar 29, 2026
94e55dc
tweaks
knyazer Mar 30, 2026
a2cc5aa
fix nemo
knyazer Mar 30, 2026
d2c62e3
merge
knyazer Apr 1, 2026
3d18387
iterating
knyazer Apr 1, 2026
bb34a12
tweaks
knyazer Apr 2, 2026
28dd128
Merge branch 'main' into qwen-tts-new
knyazer Apr 3, 2026
ecbb8af
fix precommit
knyazer Apr 3, 2026
e66d401
add new models
knyazer Apr 5, 2026
a51e9a9
final clean up
knyazer Apr 7, 2026
00435ac
microfixes
knyazer Apr 7, 2026
48acaac
Merge branch 'main' into qwen-tts-new
knyazer Apr 7, 2026
d90c3a1
few fixes for the tests precision stuff
knyazer Apr 7, 2026
b05bb06
fill in defaults
knyazer Apr 7, 2026
e9ceb22
new style of origins
knyazer Apr 9, 2026
e188cb5
better origin structuring
knyazer Apr 9, 2026
2df469c
wip: origins cli
knyazer Apr 10, 2026
e3e5cc1
new origin system
knyazer Apr 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions lalamo/audio/tts_message_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,24 @@
from functools import cached_property
from typing import TypedDict

from jaxtyping import Array, Float
from jinja2 import Template
from tokenizers import Tokenizer


@dataclass(frozen=True)
class VoicePrompt:
    """Reference audio used to condition the style of generated speech.

    Reserved for future use of audio prompts (voice cloning / style
    conditioning).
    """

    # Mono waveform samples; axis name documents the single audio dimension.
    waveform: Float[Array, " audio_samples"]
    # Sampling rate of ``waveform`` in Hz.
    sampling_rate: int


@dataclass(frozen=True)
class TTSMessage:
    """A single text-to-speech request item.

    Only ``content`` is required; speaker, style, language, and voice prompt
    are optional so models can fall back to their own defaults.
    """

    # Text to synthesize.
    content: str
    # Model-specific speaker identifier; None means "use the model default".
    speaker_id: str | None = None
    # Style / intonation hint; None means "use the model default".
    style: str | None = None
    # Target language code; None lets the model decide.
    language: str | None = None
    # Optional reference audio for voice cloning.
    voice_prompt: VoicePrompt | None = None


class TTSRequest(TypedDict):
Expand Down
65 changes: 60 additions & 5 deletions lalamo/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import random
import re
import shutil
Expand All @@ -10,6 +11,7 @@
from pathlib import Path
from typing import Annotated

import jax.numpy as jnp
import jax.profiler
import requests
import soundfile as sf
Expand All @@ -33,6 +35,7 @@
from rich.table import Table
from typer import Argument, Context, Exit, Option, Typer

from lalamo.audio.tts_message_processor import VoicePrompt
from lalamo.audio.utils import play_mono_audio
from lalamo.commands import (
CollectTracesCallbacks,
Expand Down Expand Up @@ -62,9 +65,14 @@
from lalamo.model_import.common import FileSpec
from lalamo.model_import.remote_registry import RegistryModel, RegistryModelFile, fetch_available_models
from lalamo.model_registry import ModelRegistry
from lalamo.models import ClassifierModelConfig, LanguageModelConfig
from lalamo.models import (
ClassifierModelConfig,
LanguageModelConfig,
LatentTTSGenerator,
TTSGenerator,
)
from lalamo.models.common import BatchSizesComputedEvent
from lalamo.models.tts_model import TTSGenerator, TTSMessage
from lalamo.models.tts_model import TTSMessage
from lalamo.speculator.ngram import NGramSpeculator
from lalamo.speculator.utils import test_speculator

Expand Down Expand Up @@ -115,6 +123,15 @@ def convert(self, value: str, param: ClickParameter | None, ctx: ClickContext |
return model_spec


def _is_latent_tts_model(model_path: Path) -> bool:
config_path = model_path / "config.json"
if not config_path.exists():
return False
with open(config_path) as f:
config_json = json.load(f)
return config_json.get("model_type") == "latent_tts_model"


def _error(message: str) -> None:
    """Render *message* to the error console inside a red, rounded panel."""
    err_console.print(
        Panel(
            message,
            box=box.ROUNDED,
            title="Error",
            title_align="left",
            border_style="red",
        ),
    )
Expand Down Expand Up @@ -345,6 +362,26 @@ def tts(
help="Render synthesized speech into default audio interface.",
),
] = False,
speaker_id: Annotated[
str | None,
Option(
help="Speaker ID for speech synthesis.",
show_default="First available speaker from the model",
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

show_default should be "A pre-selected speaker available for the specified model" or something like that

),
] = None,
style: Annotated[
str | None,
Option(
help="Style instruction for speech synthesis (e.g. voice description or intonation hint).",
show_default="Default style from the model",
),
] = None,
reference: Annotated[
Path | None,
Option(
help="Path to reference audio file for voice cloning (WAV format).",
),
] = None,
) -> None:
if output_file is None:
output_file = Path.cwd() / "generated_speech.wav"
Expand All @@ -355,9 +392,27 @@ def tts(
raise Exit(1)

console.print(f"🤖 Loading model from specified path: {model_path}.")
model = TTSGenerator.load_model(model_path)

assert model is not None
voice_prompt: VoicePrompt | None = None
if reference is not None:
ref_audio, ref_sr = sf.read(str(reference), dtype="float32")
if ref_audio.ndim > 1:
ref_audio = ref_audio.mean(axis=1)
voice_prompt = VoicePrompt(waveform=jnp.array(ref_audio), sampling_rate=ref_sr)
console.print(f"🎤 Loaded reference audio from {reference} ({ref_sr}Hz, {len(ref_audio) / ref_sr:.1f}s)")

model: TTSGenerator | LatentTTSGenerator
if _is_latent_tts_model(model_path):
model = LatentTTSGenerator.load_model(model_path)
else:
model = TTSGenerator.load_model(model_path)

if isinstance(model, TTSGenerator):
if speaker_id is None:
speaker_id = model.default_speaker_id
if style is None:
style = model.default_style

_stop_word = "/stop"
while True:
user_text = console.input(f"[cyan]input text to generate speech({_stop_word} to exit)> [/cyan]")
Expand All @@ -367,7 +422,7 @@ def tts(
if user_text == "":
continue

user_message = TTSMessage(content=user_text, speaker_id="speaker:0", style="interleave")
user_message = TTSMessage(content=user_text, speaker_id=speaker_id, style=style, voice_prompt=voice_prompt)

tts_result = model.generate_speech([user_message])

Expand Down
145 changes: 125 additions & 20 deletions lalamo/model_import/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
GenerationConfig,
LanguageModel,
LanguageModelConfig,
LatentTTSGenerator,
LatentTTSGeneratorConfig,
TTSGenerator,
TTSGeneratorConfig,
)
from lalamo.modules import Classifier, Decoder, LalamoModule, TTSModel
from lalamo.modules import Classifier, Decoder, LalamoModule, LatentTTSModel, TTSModel
from lalamo.modules.common import ShardingConfig, use_sharding
from lalamo.quantization import QuantizationMode
from lalamo.utils import process_chat_template
Expand Down Expand Up @@ -90,7 +92,7 @@ class ModelMetadata:
repo: str
use_cases: tuple[UseCase, ...]
model_type: ModelType
model_config: LanguageModelConfig | ClassifierModelConfig | TTSGeneratorConfig
model_config: LanguageModelConfig | ClassifierModelConfig | TTSGeneratorConfig | LatentTTSGeneratorConfig
grammar_start_tokens: tuple[str, ...]


Expand Down Expand Up @@ -143,7 +145,7 @@ def download_config_file(


class ImportResults(NamedTuple):
    """Outcome of a model import: the instantiated model and its metadata."""

    model: LanguageModel | ClassifierModel | TTSGenerator | LatentTTSGenerator
    metadata: ModelMetadata


Expand All @@ -162,15 +164,20 @@ def _instantiate_tokenizer_from_model_spec(
model_spec: ModelSpec,
output_dir: Path | str | None = None,
progress_callback: Callable[[StatusEvent], None] | None = None,
local_dir: Path | None = None,
) -> Tokenizer:
if model_spec.vendor == "NVIDIA" and model_spec.family == "nanocodec":
# NOTE: once text decoder for Nanocodec is implemented - proper Tokenizer will hopefully become available
tokenizer = Tokenizer.from_str(dummy_char_level_tokenizer_config())
else:
assert isinstance(model_spec.configs.tokenizer, FileSpec)
tokenizer_file = download_file(model_spec.configs.tokenizer, model_spec.repo, output_dir, progress_callback)
tokenizer = Tokenizer.from_file(str(tokenizer_file))
return tokenizer
effective_local_dir = local_dir or model_spec.local_dir
match model_spec.configs.tokenizer:
case None:
return Tokenizer.from_str(dummy_char_level_tokenizer_config())
case FileSpec() as file_spec:
if effective_local_dir is not None:
tokenizer_file = effective_local_dir / file_spec.filename
else:
tokenizer_file = download_file(file_spec, model_spec.repo, output_dir, progress_callback)
return Tokenizer.from_file(str(tokenizer_file))
Comment on lines +99 to +102
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's make sure to check we don't have a tokenizer only for nanocodec - so throw up here if the model is not nanocodec and has no tokenizer; also fail if its not a file spec

case str() as tokenizer_string:
return Tokenizer.from_str(tokenizer_string)


def import_message_processor(
Expand Down Expand Up @@ -284,21 +291,41 @@ def _is_safe_to_extract(tar_item_info: TarInfo) -> bool:
yield (weights_paths, config_path)


# File extension used for each supported weights container format when
# discovering weight files in a local model directory.
_WEIGHTS_EXTENSIONS: dict[WeightsType, str] = {
    WeightsType.SAFETENSORS: ".safetensors",
    WeightsType.TORCH: ".pth",
    WeightsType.NEMO: ".nemo",
}


@contextmanager
def _download_weights_and_config_files(
    model_spec: ModelSpec,
    progress_callback: Callable[[StatusEvent], None] | None = None,
    local_dir: Path | None = None,
) -> Iterator[tuple[list[Path], Path, list[Path]]]:
    """Yield ``(weights_paths, config_path, extra_config_paths)`` for *model_spec*.

    When a local directory is available (argument or spec), files are resolved
    there without any download. NEMO archives are unpacked into a temporary
    location for the duration of the context. Extra configs are only fetched
    for non-NEMO remote models; other branches yield an empty list.
    """
    effective_local_dir = local_dir or model_spec.local_dir
    if effective_local_dir is not None:
        config_path = effective_local_dir / model_spec.configs.model_config.filename
        ext = _WEIGHTS_EXTENSIONS[model_spec.weights_type]
        # sorted() gives a deterministic shard order.
        weights_paths = sorted(effective_local_dir.glob(f"*{ext}"))
        yield (weights_paths, config_path, [])
        return

    if model_spec.weights_type == WeightsType.NEMO:
        (nemo_model_file,) = download_weights(model_spec, progress_callback=progress_callback)
        with _unpack_nemo_model(nemo_model_file) as nemo_file_contents:
            weights_paths, foreign_config_file_path = nemo_file_contents
            yield (weights_paths, foreign_config_file_path, [])
    else:
        weights_paths = download_weights(model_spec, progress_callback=progress_callback)
        foreign_config_file_path = download_config_file(model_spec)

        extra_config_paths = [
            download_file(extra_config, model_spec.repo) for extra_config in model_spec.configs.extra_configs
        ]

        yield (weights_paths, foreign_config_file_path, extra_config_paths)


def _load_main_processing_module(
Expand Down Expand Up @@ -345,8 +372,8 @@ def _import_language_model(
with _download_weights_and_config_files(
model_spec,
progress_callback=progress_callback,
) as (model_weights_paths, config_path):
foreign_decoder_config = model_spec.config_type.from_json(config_path)
) as (model_weights_paths, config_path, extra_config_paths):
foreign_decoder_config = model_spec.config_type.from_json(config_path, extra_config_paths)
assert isinstance(foreign_decoder_config, ForeignLMConfig)

if precision is None:
Expand Down Expand Up @@ -402,8 +429,8 @@ def _import_classifier(
with _download_weights_and_config_files(
model_spec,
progress_callback=progress_callback,
) as (model_weights_paths, config_path):
foreign_classifier_config = model_spec.config_type.from_json(config_path)
) as (model_weights_paths, config_path, extra_config_paths):
foreign_classifier_config = model_spec.config_type.from_json(config_path, extra_config_paths)
assert isinstance(foreign_classifier_config, ForeignClassifierConfig)

if precision is None:
Expand Down Expand Up @@ -444,8 +471,8 @@ def _import_tts_model(
with _download_weights_and_config_files(
model_spec,
progress_callback=progress_callback,
) as (model_weights_paths, config_path):
foreign_tts_config = model_spec.config_type.from_json(config_path)
) as (model_weights_paths, config_path, extra_config_paths):
foreign_tts_config = model_spec.config_type.from_json(config_path, extra_config_paths)
if precision is None:
precision = foreign_tts_config.default_precision
if model_spec.vendor == "FishAudio" and model_spec.family == "openaudio":
Expand Down Expand Up @@ -508,6 +535,71 @@ def _import_tts_model(
return (tts_generator, tts_generator_config)


def _import_latent_tts_model(
    model_spec: ModelSpec,
    *,
    context_length: int | None = None,
    precision: DTypeLike | None = None,
    accumulation_precision: DTypeLike = jnp.float32,
    progress_callback: Callable[[StatusEvent], None] | None = None,
    local_dir: Path | None = None,
) -> tuple[LatentTTSGenerator, LatentTTSGeneratorConfig]:
    """Import a latent TTS model and wrap it in a ``LatentTTSGenerator``.

    Downloads (or resolves locally via *local_dir*) the weights and configs,
    instantiates the tokenizer and message processor, and assembles the
    generator with the model's default generation config.

    Raises:
        ValueError: if the converted config does not provide
            ``default_generation_config``.
    """
    with _download_weights_and_config_files(
        model_spec,
        progress_callback=progress_callback,
        local_dir=local_dir,
    ) as (model_weights_paths, config_path, extra_config_paths):
        foreign_config = model_spec.config_type.from_json(config_path, extra_config_paths)
        if precision is None:
            precision = foreign_config.default_precision

        tokenizer = _instantiate_tokenizer_from_model_spec(model_spec, None, progress_callback, local_dir=local_dir)

        latent_tts_model = _load_main_processing_module(
            model_spec,
            model_weights_paths,
            precision,
            foreign_config,
            progress_callback,
            context_length,
            accumulation_precision,
        )

        assert isinstance(latent_tts_model, LatentTTSModel)
        if progress_callback is not None:
            progress_callback(FinishedInitializingModelEvent())

        assert isinstance(model_spec.configs.chat_template, str)
        tts_request_factory_config = TTSMessageProcessorConfig(
            prompt_template=model_spec.configs.chat_template,
        )
        message_processor = TTSMessageProcessor(tts_request_factory_config, tokenizer)

        latent_tts_config = foreign_config.to_lalamo_config(
            context_length=context_length,
            activation_precision=precision,
            # BUGFIX: previously passed `precision` here, silently discarding the
            # caller-supplied accumulation precision and diverging from the value
            # used to load the processing module above.
            accumulation_precision=accumulation_precision,
            metadata_dict={},
        )

        if not hasattr(latent_tts_config, "default_generation_config"):
            raise ValueError(f"{type(latent_tts_config).__name__} must implement default_generation_config()")
        generation_config = latent_tts_config.default_generation_config()
        generator_config = LatentTTSGeneratorConfig(
            latent_tts_config=latent_tts_config,
            message_processor_config=message_processor.config,
            generation_config=generation_config,
        )

        generator = LatentTTSGenerator(
            config=generator_config,
            latent_tts_model=latent_tts_model,
            message_processor=message_processor,
        )

        return (generator, generator_config)


def import_model(
model_spec: ModelSpec | str,
*,
Expand All @@ -516,7 +608,11 @@ def import_model(
accumulation_precision: DTypeLike = jnp.float32,
progress_callback: Callable[[StatusEvent], None] | None = None,
sharding_config: ShardingConfig | None = None,
local_dir: Path | str | None = None,
) -> ImportResults:
if isinstance(local_dir, str):
local_dir = Path(local_dir)

if isinstance(model_spec, str):
try:
model_spec = ModelRegistry.build().repo_to_model[model_spec]
Expand Down Expand Up @@ -549,6 +645,15 @@ def import_model(
accumulation_precision=accumulation_precision,
progress_callback=progress_callback,
)
case ModelType.LATENT_TTS_MODEL:
model, config = _import_latent_tts_model(
model_spec,
context_length=context_length,
precision=precision,
accumulation_precision=accumulation_precision,
progress_callback=progress_callback,
local_dir=local_dir,
)

metadata = ModelMetadata(
toolchain_version=LALAMO_VERSION,
Expand Down
Loading
Loading