Skip to content
Open
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
c6cb96e
qwen tts 1/n
knyazer Mar 17, 2026
86245d2
neater config loading
knyazer Mar 17, 2026
401b317
linters
knyazer Mar 18, 2026
8d5bdc5
latent tts + formatting fixes
knyazer Mar 18, 2026
316c4ec
more strict generation config
knyazer Mar 19, 2026
c11e071
update
knyazer Mar 21, 2026
9272bce
Merge branch 'main' into qwen-tts-new
knyazer Mar 21, 2026
be8212d
finalizing qwen-tts and latent tts
knyazer Mar 23, 2026
34d201e
Merge branch 'main' into qwen-tts-new
knyazer Mar 23, 2026
22709a7
fix up little issues
knyazer Mar 23, 2026
e915f6e
finishing up
knyazer Mar 24, 2026
bd5662c
fixing minor bugs
knyazer Mar 24, 2026
03dad34
update trans deps
knyazer Mar 24, 2026
ab49ecb
correct dep
knyazer Mar 24, 2026
e22ddd3
fix fish
knyazer Mar 25, 2026
42aaa66
Merge branch 'main' into qwen-tts-new
knyazer Mar 25, 2026
b364957
transformers version
knyazer Mar 25, 2026
bf02930
Adressing review
knyazer Mar 27, 2026
68e5469
fix
knyazer Mar 27, 2026
f750bc1
compat
knyazer Mar 27, 2026
745cd2b
Address PR review comments: remove slop and clean up TTS code
claude Mar 27, 2026
17de7f4
iterating
knyazer Mar 28, 2026
c8725ff
fix import
knyazer Mar 29, 2026
89e6e85
fix tests :( i wish it could be automatable
knyazer Mar 29, 2026
97b25e1
unslop pass
knyazer Mar 29, 2026
9c47a39
fix up broky imports
knyazer Mar 29, 2026
6535fe9
fix up tests
knyazer Mar 29, 2026
94e55dc
tweaks
knyazer Mar 30, 2026
a2cc5aa
fix nemo
knyazer Mar 30, 2026
d2c62e3
merge
knyazer Apr 1, 2026
3d18387
iterating
knyazer Apr 1, 2026
bb34a12
tweaks
knyazer Apr 2, 2026
28dd128
Merge branch 'main' into qwen-tts-new
knyazer Apr 3, 2026
ecbb8af
fix precommit
knyazer Apr 3, 2026
e66d401
add new models
knyazer Apr 5, 2026
a51e9a9
final clean up
knyazer Apr 7, 2026
00435ac
microfixes
knyazer Apr 7, 2026
48acaac
Merge branch 'main' into qwen-tts-new
knyazer Apr 7, 2026
d90c3a1
few fixes for the tests precision stuff
knyazer Apr 7, 2026
b05bb06
fill in defaults
knyazer Apr 7, 2026
e9ceb22
new style of origins
knyazer Apr 9, 2026
e188cb5
better origin structuring
knyazer Apr 9, 2026
2df469c
wip: origins cli
knyazer Apr 10, 2026
e3e5cc1
new origin system
knyazer Apr 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions lalamo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
UserMessage,
)
from lalamo.model_import import ModelSpec, import_model
from lalamo.model_import.model_specs.common import ConfigMap, FileSpec, JSONFieldSpec, ModelType, UseCase, WeightsType
from lalamo.model_import.model_specs.common import ConfigMap, FileSpec, JSONFieldSpec, ModelType, UseCase
from lalamo.models import ClassifierModel, LanguageModel
from lalamo.modules.common import ShardingConfig, pad_and_apply_data_sharding
from lalamo.quantization import QuantizationMode
Expand Down Expand Up @@ -60,7 +60,6 @@
"TrainCallbacks",
"UseCase",
"UserMessage",
"WeightsType",
"collect_traces",
"convert",
"estimate_batchsize",
Expand Down
16 changes: 10 additions & 6 deletions lalamo/audio/tts_message_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,24 @@
from functools import cached_property
from typing import TypedDict

from jaxtyping import Array, Float
from jinja2 import Template
from tokenizers import Tokenizer


@dataclass(frozen=True)
class VoicePrompt:
"""
Current class is reserved for future usage of audio prompts
to condition style of generated audio
"""
waveform: Float[Array, " audio_samples"]
sampling_rate: int


@dataclass(frozen=True)
class TTSMessage:
content: str
speaker_id: str
style: str
speaker_id: str | None = None
style: str | None = None
language: str | None = None
voice_prompt: VoicePrompt | None = None


class TTSRequest(TypedDict):
Expand Down Expand Up @@ -62,6 +63,9 @@ def render_request(self, messages: Iterable[TTSMessage]) -> str:
prompt_text = prompt_text[1:]
return prompt_text

def preprocess(self, text: str, language: str) -> str: # noqa: ARG002
return text

def tokenize_text(self, text: str) -> list[int]:
return self.tokenizer.encode(text, add_special_tokens=False).ids

Expand Down
2 changes: 0 additions & 2 deletions lalamo/audio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ def play_mono_audio(audio: np.ndarray, samplerate: int, audio_chunk_size: int =
audio = np.clip(audio, -1.0, 1.0)
# very dumb conversion to PCM16
pcm_audio = (audio * np.iinfo(np.int16).max).astype(np.int16)

audio_chunk_size = 1024
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

huh? why did we remove this line

num_chunks = int(np.ceil(n_samples / audio_chunk_size))

# actual size of each chunk might not be exactly 'audio_chunk_size' but not critical here
Expand Down
2 changes: 1 addition & 1 deletion lalamo/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def progress_callback(event: StatusEvent) -> None:
weights = flatten_parameters(model.export_weights())
del model

with Path(output_dir / "model.safetensors").open("wb") as fd:
with (output_dir / "model.safetensors").open("wb") as fd:
safe_write(fd, weights)

config_json = config_converter.unstructure(metadata, ModelMetadata)
Expand Down
3 changes: 3 additions & 0 deletions lalamo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@

from lalamo.utils import MapDictValues, MapSequence

type WeightShard = tuple[Mapping[str, Array], Mapping[str, str]]

__all__ = [
"DEFAULT_PRECISION",
"ArrayLike",
"LalamoWarning",
"ParameterPath",
"ParameterTree",
"WeightShard",
"dummy_array",
"flatten_parameters",
"require_array",
Expand Down
95 changes: 86 additions & 9 deletions lalamo/main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json
import random
import re
import shutil
import sys
from contextlib import ExitStack
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from functools import partial
from importlib.util import find_spec
from itertools import islice
from pathlib import Path
from typing import Annotated

import jax.numpy as jnp
import jax.profiler
import requests
import soundfile as sf
Expand All @@ -33,6 +35,7 @@
from rich.table import Table
from typer import Argument, Context, Exit, Option, Typer

from lalamo.audio.tts_message_processor import VoicePrompt
from lalamo.audio.utils import play_mono_audio
from lalamo.commands import (
CollectTracesCallbacks,
Expand All @@ -58,13 +61,19 @@
)
from lalamo.data.lalamo_completions import iter_completions
from lalamo.message_processor import UserMessage
from lalamo.model_import import ModelSpec
from lalamo.model_import import ModelSpec, ModelType
from lalamo.model_import.common import FileSpec
from lalamo.model_import.model_specs.common import structure_origin
from lalamo.model_import.remote_registry import RegistryModel, RegistryModelFile, fetch_available_models
from lalamo.model_registry import ModelRegistry
from lalamo.models import ClassifierModelConfig, LanguageModelConfig
from lalamo.models import (
ClassifierModelConfig,
LanguageModelConfig,
LatentTTSGenerator,
TTSGenerator,
)
from lalamo.models.common import BatchSizesComputedEvent
from lalamo.models.tts_model import TTSGenerator, TTSMessage
from lalamo.models.tts_model import TTSMessage
from lalamo.speculator.ngram import NGramSpeculator
from lalamo.speculator.utils import test_speculator

Expand Down Expand Up @@ -339,12 +348,39 @@ def tts(
),
],
output_file: Annotated[Path | None, Argument(help="Path to output WAV file with synthesized speech")] = None,
message: Annotated[
str | None,
Option(
help="Text to synthesize in non-interactive mode. Generates speech and exits.",
show_default="None, run interactively",
),
] = None,
replay: Annotated[
bool,
Option(
help="Render synthesized speech into default audio interface.",
),
] = False,
speaker_id: Annotated[
str | None,
Option(
help="Speaker ID for speech synthesis.",
show_default="A pre-selected speaker available for the specified model",
),
] = None,
style: Annotated[
str | None,
Option(
help="Style instruction for speech synthesis (e.g. voice description or intonation hint).",
show_default="Default style from the model",
),
] = None,
reference: Annotated[
Path | None,
Option(
help="Path to reference audio file for voice cloning (WAV format).",
),
] = None,
) -> None:
if output_file is None:
output_file = Path.cwd() / "generated_speech.wav"
Expand All @@ -355,9 +391,35 @@ def tts(
raise Exit(1)

console.print(f"🤖 Loading model from specified path: {model_path}.")
model = TTSGenerator.load_model(model_path)

assert model is not None
voice_prompt: VoicePrompt | None = None
if reference is not None:
ref_audio, ref_sr = sf.read(str(reference), dtype="float32")
if ref_audio.ndim > 1:
ref_audio = ref_audio.mean(axis=1)
voice_prompt = VoicePrompt(waveform=jnp.array(ref_audio), sampling_rate=ref_sr)
console.print(f"🎤 Loaded reference audio from {reference} ({ref_sr}Hz, {len(ref_audio) / ref_sr:.1f}s)")

config_json = json.loads((model_path / "config.json").read_text())
model_type = ModelType(config_json["model_type"])
model: TTSGenerator | LatentTTSGenerator
match model_type:
case ModelType.TTS_MODEL:
model = TTSGenerator.load_model(model_path)
case ModelType.LATENT_TTS_MODEL:
model = LatentTTSGenerator.load_model(model_path)
case _:
raise ValueError(f"Expected a TTS model, got: {model_type}")

if message is not None:
user_message = TTSMessage(content=message, speaker_id=speaker_id, style=style, voice_prompt=voice_prompt)
tts_result = model.generate_speech([user_message])
if replay:
play_mono_audio(tts_result.audio, tts_result.audio_params.samplerate)
sf.write(str(output_file), tts_result.audio, tts_result.audio_params.samplerate)
console.print(f"[green] ... saved generated audio to {output_file}[/green]")
return

_stop_word = "/stop"
while True:
user_text = console.input(f"[cyan]input text to generate speech({_stop_word} to exit)> [/cyan]")
Expand All @@ -367,7 +429,7 @@ def tts(
if user_text == "":
continue

user_message = TTSMessage(content=user_text, speaker_id="speaker:0", style="interleave")
user_message = TTSMessage(content=user_text, speaker_id=speaker_id, style=style, voice_prompt=voice_prompt)

tts_result = model.generate_speech([user_message])

Expand Down Expand Up @@ -430,13 +492,28 @@ def convert(
show_default="Model's native maximum context length.",
),
] = None,
custom_origin: Annotated[
str | None,
Option(
"--custom-origin",
help=(
"Origin JSON to override the model's default origin."
' Example: \'{"type": "LocalOrigin", "root": "/path/to/weights"}\''
),
show_default="Use the model's default origin",
),
] = None,
overwrite: Annotated[
bool,
Option(
help="Overwrite existing model files.",
),
] = False,
) -> None:
if custom_origin is not None:
origin = structure_origin(json.loads(custom_origin))
model_repo = replace(model_repo, origin=origin)

if output_dir is None:
output_dir = DEFAULT_OUTPUT_DIR / model_repo.name

Expand Down Expand Up @@ -634,7 +711,7 @@ def list_models(

if plain:
for spec in sorted_specs:
console.print(spec.repo)
console.print(spec.origin.description)
return

table = Table(
Expand All @@ -654,7 +731,7 @@ def list_models(
spec.family,
spec.size,
str(spec.quantization),
spec.repo,
spec.origin.description,
)
console.print(table)

Expand Down
3 changes: 2 additions & 1 deletion lalamo/model_import/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .common import ModelMetadata, ModelSpec, import_model
from .common import ModelMetadata, ModelSpec, ModelType, import_model

__all__ = [
"ModelMetadata",
"ModelSpec",
"ModelType",
"import_model",
]
Loading
Loading