-
Notifications
You must be signed in to change notification settings - Fork 15
Qwen tts #172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Qwen tts #172
Changes from 35 commits
c6cb96e
86245d2
401b317
8d5bdc5
316c4ec
c11e071
9272bce
be8212d
34d201e
22709a7
e915f6e
bd5662c
03dad34
ab49ecb
e22ddd3
42aaa66
b364957
bf02930
68e5469
f750bc1
745cd2b
17de7f4
c8725ff
89e6e85
97b25e1
9c47a39
6535fe9
94e55dc
a2cc5aa
d2c62e3
3d18387
bb34a12
28dd128
ecbb8af
e66d401
a51e9a9
00435ac
48acaac
d90c3a1
b05bb06
e9ceb22
e188cb5
2df469c
e3e5cc1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,23 +3,24 @@ | |
| from functools import cached_property | ||
| from typing import TypedDict | ||
|
|
||
| from jaxtyping import Array, Float | ||
| from jinja2 import Template | ||
| from tokenizers import Tokenizer | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class VoicePrompt: | ||
| """ | ||
| Current class is reserved for future usage of audio prompts | ||
| to condition style of generated audio | ||
| """ | ||
| waveform: Float[Array, "*"] | ||
| sampling_rate: int | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class TTSMessage: | ||
| content: str | ||
| speaker_id: str | ||
| style: str | ||
| speaker_id: str | None = None | ||
| style: str | None = None | ||
| language: str | None = None | ||
| voice_prompt: VoicePrompt | None = None | ||
|
|
||
|
|
||
| class TTSRequest(TypedDict): | ||
|
|
@@ -62,6 +63,9 @@ def render_request(self, messages: Iterable[TTSMessage]) -> str: | |
| prompt_text = prompt_text[1:] | ||
| return prompt_text | ||
|
|
||
| def preprocess(self, text: str, language: str = "en") -> str: # noqa: ARG002 | ||
|
||
| return text | ||
|
|
||
| def tokenize_text(self, text: str) -> list[int]: | ||
| return self.tokenizer.encode(text, add_special_tokens=False).ids | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,8 +19,6 @@ def play_mono_audio(audio: np.ndarray, samplerate: int, audio_chunk_size: int = | |
| audio = np.clip(audio, -1.0, 1.0) | ||
| # very dumb conversion to PCM16 | ||
| pcm_audio = (audio * np.iinfo(np.int16).max).astype(np.int16) | ||
|
|
||
| audio_chunk_size = 1024 | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. huh? why did we remove this line |
||
| num_chunks = int(np.ceil(n_samples / audio_chunk_size)) | ||
|
|
||
| # actual size of each chunk might not be exactly 'audio_chunk_size' but not critical here | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| import json | ||
| import random | ||
| import re | ||
| import shutil | ||
|
|
@@ -10,6 +11,7 @@ | |
| from pathlib import Path | ||
| from typing import Annotated | ||
|
|
||
| import jax.numpy as jnp | ||
| import jax.profiler | ||
| import requests | ||
| import soundfile as sf | ||
|
|
@@ -33,6 +35,7 @@ | |
| from rich.table import Table | ||
| from typer import Argument, Context, Exit, Option, Typer | ||
|
|
||
| from lalamo.audio.tts_message_processor import VoicePrompt | ||
| from lalamo.audio.utils import play_mono_audio | ||
| from lalamo.commands import ( | ||
| CollectTracesCallbacks, | ||
|
|
@@ -58,13 +61,18 @@ | |
| ) | ||
| from lalamo.data.lalamo_completions import LalamoCompletion | ||
| from lalamo.message_processor import UserMessage | ||
| from lalamo.model_import import ModelSpec | ||
| from lalamo.model_import import ModelSpec, ModelType | ||
| from lalamo.model_import.common import FileSpec | ||
| from lalamo.model_import.remote_registry import RegistryModel, RegistryModelFile, fetch_available_models | ||
| from lalamo.model_registry import ModelRegistry | ||
| from lalamo.models import ClassifierModelConfig, LanguageModelConfig | ||
| from lalamo.models import ( | ||
| ClassifierModelConfig, | ||
| LanguageModelConfig, | ||
| LatentTTSGenerator, | ||
| TTSGenerator, | ||
| ) | ||
| from lalamo.models.common import BatchSizesComputedEvent | ||
| from lalamo.models.tts_model import TTSGenerator, TTSMessage | ||
| from lalamo.models.tts_model import TTSMessage | ||
| from lalamo.speculator.ngram import NGramSpeculator | ||
| from lalamo.speculator.utils import test_speculator | ||
|
|
||
|
|
@@ -345,6 +353,26 @@ def tts( | |
| help="Render synthesized speech into default audio interface.", | ||
| ), | ||
| ] = False, | ||
| speaker_id: Annotated[ | ||
| str | None, | ||
| Option( | ||
| help="Speaker ID for speech synthesis.", | ||
| show_default="First available speaker from the model", | ||
|
||
| ), | ||
| ] = None, | ||
| style: Annotated[ | ||
| str | None, | ||
| Option( | ||
| help="Style instruction for speech synthesis (e.g. voice description or intonation hint).", | ||
| show_default="Default style from the model", | ||
| ), | ||
| ] = None, | ||
| reference: Annotated[ | ||
| Path | None, | ||
| Option( | ||
| help="Path to reference audio file for voice cloning (WAV format).", | ||
| ), | ||
| ] = None, | ||
| ) -> None: | ||
| if output_file is None: | ||
| output_file = Path.cwd() / "generated_speech.wav" | ||
|
|
@@ -355,9 +383,26 @@ def tts( | |
| raise Exit(1) | ||
|
|
||
| console.print(f"🤖 Loading model from specified path: {model_path}.") | ||
| model = TTSGenerator.load_model(model_path) | ||
|
|
||
| assert model is not None | ||
| voice_prompt: VoicePrompt | None = None | ||
| if reference is not None: | ||
| ref_audio, ref_sr = sf.read(str(reference), dtype="float32") | ||
| if ref_audio.ndim > 1: | ||
| ref_audio = ref_audio.mean(axis=1) | ||
| voice_prompt = VoicePrompt(waveform=jnp.array(ref_audio), sampling_rate=ref_sr) | ||
| console.print(f"🎤 Loaded reference audio from {reference} ({ref_sr}Hz, {len(ref_audio) / ref_sr:.1f}s)") | ||
|
|
||
| config_json = json.loads((model_path / "config.json").read_text()) | ||
| model_type = ModelType(config_json["model_type"]) | ||
| model: TTSGenerator | LatentTTSGenerator | ||
| match model_type: | ||
| case ModelType.TTS_MODEL: | ||
| model = TTSGenerator.load_model(model_path) | ||
| case ModelType.LATENT_TTS_MODEL: | ||
| model = LatentTTSGenerator.load_model(model_path) | ||
| case _: | ||
| raise ValueError(f"Expected a TTS model, got: {model_type}") | ||
|
|
||
| _stop_word = "/stop" | ||
| while True: | ||
| user_text = console.input(f"[cyan]input text to generate speech({_stop_word} to exit)> [/cyan]") | ||
|
|
@@ -367,7 +412,7 @@ def tts( | |
| if user_text == "": | ||
| continue | ||
|
|
||
| user_message = TTSMessage(content=user_text, speaker_id="speaker:0", style="interleave") | ||
| user_message = TTSMessage(content=user_text, speaker_id=speaker_id, style=style, voice_prompt=voice_prompt) | ||
|
|
||
| tts_result = model.generate_speech([user_message]) | ||
|
|
||
|
|
@@ -634,7 +679,7 @@ def list_models( | |
|
|
||
| if plain: | ||
| for spec in sorted_specs: | ||
| console.print(spec.repo) | ||
| console.print(spec.origin.description) | ||
| return | ||
|
|
||
| table = Table( | ||
|
|
@@ -654,7 +699,7 @@ def list_models( | |
| spec.family, | ||
| spec.size, | ||
| str(spec.quantization), | ||
| spec.repo, | ||
| spec.origin.description, | ||
| ) | ||
| console.print(table) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,8 @@ | ||
| from .common import ModelMetadata, ModelSpec, import_model | ||
| from .common import ModelMetadata, ModelSpec, ModelType, import_model | ||
|
|
||
| __all__ = [ | ||
| "ModelMetadata", | ||
| "ModelSpec", | ||
| "ModelType", | ||
| "import_model", | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
waveform: Float[Array, " audio_samples"]