-
Notifications
You must be signed in to change notification settings - Fork 15
Qwen tts #172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Qwen tts #172
Changes from 11 commits
c6cb96e
86245d2
401b317
8d5bdc5
316c4ec
c11e071
9272bce
be8212d
34d201e
22709a7
e915f6e
bd5662c
03dad34
ab49ecb
e22ddd3
42aaa66
b364957
bf02930
68e5469
f750bc1
745cd2b
17de7f4
c8725ff
89e6e85
97b25e1
9c47a39
6535fe9
94e55dc
a2cc5aa
d2c62e3
3d18387
bb34a12
28dd128
ecbb8af
e66d401
a51e9a9
00435ac
48acaac
d90c3a1
b05bb06
e9ceb22
e188cb5
2df469c
e3e5cc1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| import json | ||
| import random | ||
| import re | ||
| import shutil | ||
|
|
@@ -10,6 +11,7 @@ | |
| from pathlib import Path | ||
| from typing import Annotated | ||
|
|
||
| import jax.numpy as jnp | ||
| import jax.profiler | ||
| import requests | ||
| import soundfile as sf | ||
|
|
@@ -33,6 +35,7 @@ | |
| from rich.table import Table | ||
| from typer import Argument, Context, Exit, Option, Typer | ||
|
|
||
| from lalamo.audio.tts_message_processor import VoicePrompt | ||
| from lalamo.audio.utils import play_mono_audio | ||
| from lalamo.commands import ( | ||
| CollectTracesCallbacks, | ||
|
|
@@ -62,9 +65,14 @@ | |
| from lalamo.model_import.common import FileSpec | ||
| from lalamo.model_import.remote_registry import RegistryModel, RegistryModelFile, fetch_available_models | ||
| from lalamo.model_registry import ModelRegistry | ||
| from lalamo.models import ClassifierModelConfig, LanguageModelConfig | ||
| from lalamo.models import ( | ||
| ClassifierModelConfig, | ||
| LanguageModelConfig, | ||
| LatentTTSGenerator, | ||
| TTSGenerator, | ||
| ) | ||
| from lalamo.models.common import BatchSizesComputedEvent | ||
| from lalamo.models.tts_model import TTSGenerator, TTSMessage | ||
| from lalamo.models.tts_model import TTSMessage | ||
| from lalamo.speculator.ngram import NGramSpeculator | ||
| from lalamo.speculator.utils import test_speculator | ||
|
|
||
|
|
@@ -115,6 +123,15 @@ def convert(self, value: str, param: ClickParameter | None, ctx: ClickContext | | |
| return model_spec | ||
|
|
||
|
|
||
| def _is_latent_tts_model(model_path: Path) -> bool: | ||
| config_path = model_path / "config.json" | ||
| if not config_path.exists(): | ||
| return False | ||
| with open(config_path) as f: | ||
| config_json = json.load(f) | ||
| return config_json.get("model_type") == "latent_tts_model" | ||
|
|
||
|
|
||
| def _error(message: str) -> None: | ||
| panel = Panel(message, box=box.ROUNDED, title="Error", title_align="left", border_style="red") | ||
| err_console.print(panel) | ||
|
|
@@ -345,6 +362,26 @@ def tts( | |
| help="Render synthesized speech into default audio interface.", | ||
| ), | ||
| ] = False, | ||
| speaker_id: Annotated[ | ||
| str | None, | ||
| Option( | ||
| help="Speaker ID for speech synthesis.", | ||
| show_default="First available speaker from the model", | ||
|
||
| ), | ||
| ] = None, | ||
| style: Annotated[ | ||
| str | None, | ||
| Option( | ||
| help="Style instruction for speech synthesis (e.g. voice description or intonation hint).", | ||
| show_default="Default style from the model", | ||
| ), | ||
| ] = None, | ||
| reference: Annotated[ | ||
| Path | None, | ||
| Option( | ||
| help="Path to reference audio file for voice cloning (WAV format).", | ||
| ), | ||
| ] = None, | ||
| ) -> None: | ||
| if output_file is None: | ||
| output_file = Path.cwd() / "generated_speech.wav" | ||
|
|
@@ -355,9 +392,27 @@ def tts( | |
| raise Exit(1) | ||
|
|
||
| console.print(f"🤖 Loading model from specified path: {model_path}.") | ||
| model = TTSGenerator.load_model(model_path) | ||
|
|
||
| assert model is not None | ||
| voice_prompt: VoicePrompt | None = None | ||
| if reference is not None: | ||
| ref_audio, ref_sr = sf.read(str(reference), dtype="float32") | ||
| if ref_audio.ndim > 1: | ||
| ref_audio = ref_audio.mean(axis=1) | ||
| voice_prompt = VoicePrompt(waveform=jnp.array(ref_audio), sampling_rate=ref_sr) | ||
| console.print(f"🎤 Loaded reference audio from {reference} ({ref_sr}Hz, {len(ref_audio) / ref_sr:.1f}s)") | ||
|
|
||
| model: TTSGenerator | LatentTTSGenerator | ||
| if _is_latent_tts_model(model_path): | ||
| model = LatentTTSGenerator.load_model(model_path) | ||
| else: | ||
| model = TTSGenerator.load_model(model_path) | ||
|
|
||
| if isinstance(model, TTSGenerator): | ||
| if speaker_id is None: | ||
| speaker_id = model.default_speaker_id | ||
| if style is None: | ||
| style = model.default_style | ||
|
|
||
| _stop_word = "/stop" | ||
| while True: | ||
| user_text = console.input(f"[cyan]input text to generate speech({_stop_word} to exit)> [/cyan]") | ||
|
|
@@ -367,7 +422,7 @@ def tts( | |
| if user_text == "": | ||
| continue | ||
|
|
||
| user_message = TTSMessage(content=user_text, speaker_id="speaker:0", style="interleave") | ||
| user_message = TTSMessage(content=user_text, speaker_id=speaker_id, style=style, voice_prompt=voice_prompt) | ||
|
|
||
| tts_result = model.generate_speech([user_message]) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,10 +28,12 @@ | |
| GenerationConfig, | ||
| LanguageModel, | ||
| LanguageModelConfig, | ||
| LatentTTSGenerator, | ||
| LatentTTSGeneratorConfig, | ||
| TTSGenerator, | ||
| TTSGeneratorConfig, | ||
| ) | ||
| from lalamo.modules import Classifier, Decoder, LalamoModule, TTSModel | ||
| from lalamo.modules import Classifier, Decoder, LalamoModule, LatentTTSModel, TTSModel | ||
| from lalamo.modules.common import ShardingConfig, use_sharding | ||
| from lalamo.quantization import QuantizationMode | ||
| from lalamo.utils import process_chat_template | ||
|
|
@@ -90,7 +92,7 @@ class ModelMetadata: | |
| repo: str | ||
| use_cases: tuple[UseCase, ...] | ||
| model_type: ModelType | ||
| model_config: LanguageModelConfig | ClassifierModelConfig | TTSGeneratorConfig | ||
| model_config: LanguageModelConfig | ClassifierModelConfig | TTSGeneratorConfig | LatentTTSGeneratorConfig | ||
| grammar_start_tokens: tuple[str, ...] | ||
|
|
||
|
|
||
|
|
@@ -143,7 +145,7 @@ def download_config_file( | |
|
|
||
|
|
||
| class ImportResults(NamedTuple): | ||
| model: LanguageModel | ClassifierModel | TTSGenerator | ||
| model: LanguageModel | ClassifierModel | TTSGenerator | LatentTTSGenerator | ||
| metadata: ModelMetadata | ||
|
|
||
|
|
||
|
|
@@ -162,15 +164,20 @@ def _instantiate_tokenizer_from_model_spec( | |
| model_spec: ModelSpec, | ||
| output_dir: Path | str | None = None, | ||
| progress_callback: Callable[[StatusEvent], None] | None = None, | ||
| local_dir: Path | None = None, | ||
| ) -> Tokenizer: | ||
| if model_spec.vendor == "NVIDIA" and model_spec.family == "nanocodec": | ||
| # NOTE: once text decoder for Nanocodec is implemented - proper Tokenizer will hopefully become available | ||
| tokenizer = Tokenizer.from_str(dummy_char_level_tokenizer_config()) | ||
| else: | ||
| assert isinstance(model_spec.configs.tokenizer, FileSpec) | ||
| tokenizer_file = download_file(model_spec.configs.tokenizer, model_spec.repo, output_dir, progress_callback) | ||
| tokenizer = Tokenizer.from_file(str(tokenizer_file)) | ||
| return tokenizer | ||
| effective_local_dir = local_dir or model_spec.local_dir | ||
| match model_spec.configs.tokenizer: | ||
| case None: | ||
| return Tokenizer.from_str(dummy_char_level_tokenizer_config()) | ||
| case FileSpec() as file_spec: | ||
| if effective_local_dir is not None: | ||
| tokenizer_file = effective_local_dir / file_spec.filename | ||
| else: | ||
| tokenizer_file = download_file(file_spec, model_spec.repo, output_dir, progress_callback) | ||
| return Tokenizer.from_file(str(tokenizer_file)) | ||
|
Comment on lines
+99
to
+102
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's make sure to check we don't have a tokenizer only for nanocodec - so throw up here if the model is not nanocodec and has no tokenizer; also fail if its not a file spec |
||
| case str() as tokenizer_string: | ||
| return Tokenizer.from_str(tokenizer_string) | ||
|
|
||
knyazer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| def import_message_processor( | ||
|
|
@@ -284,21 +291,41 @@ def _is_safe_to_extract(tar_item_info: TarInfo) -> bool: | |
| yield (weights_paths, config_path) | ||
|
|
||
|
|
||
| _WEIGHTS_EXTENSIONS: dict[WeightsType, str] = { | ||
| WeightsType.SAFETENSORS: ".safetensors", | ||
| WeightsType.TORCH: ".pth", | ||
| WeightsType.NEMO: ".nemo", | ||
| } | ||
|
|
||
|
|
||
knyazer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| @contextmanager | ||
| def _download_weights_and_config_files( | ||
| model_spec: ModelSpec, | ||
| progress_callback: Callable[[StatusEvent], None] | None = None, | ||
| ) -> Iterator[tuple[list[Path], Path]]: | ||
| local_dir: Path | None = None, | ||
| ) -> Iterator[tuple[list[Path], Path, list[Path]]]: | ||
| effective_local_dir = local_dir or model_spec.local_dir | ||
| if effective_local_dir is not None: | ||
| config_path = effective_local_dir / model_spec.configs.model_config.filename | ||
| ext = _WEIGHTS_EXTENSIONS[model_spec.weights_type] | ||
| weights_paths = sorted(effective_local_dir.glob(f"*{ext}")) | ||
| yield (weights_paths, config_path, []) | ||
| return | ||
|
|
||
| if model_spec.weights_type == WeightsType.NEMO: | ||
| (nemo_model_file,) = download_weights(model_spec, progress_callback=progress_callback) | ||
| with _unpack_nemo_model(nemo_model_file) as nemo_file_contents: | ||
| weights_paths, foreign_config_file_path = nemo_file_contents | ||
| yield (weights_paths, foreign_config_file_path) | ||
| yield (weights_paths, foreign_config_file_path, []) | ||
| else: | ||
| weights_paths = download_weights(model_spec, progress_callback=progress_callback) | ||
| foreign_config_file_path = download_config_file(model_spec) | ||
|
|
||
| yield (weights_paths, foreign_config_file_path) | ||
| extra_config_paths = [ | ||
| download_file(extra_config, model_spec.repo) for extra_config in model_spec.configs.extra_configs | ||
| ] | ||
knyazer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| yield (weights_paths, foreign_config_file_path, extra_config_paths) | ||
|
|
||
|
|
||
| def _load_main_processing_module( | ||
|
|
@@ -345,8 +372,8 @@ def _import_language_model( | |
| with _download_weights_and_config_files( | ||
| model_spec, | ||
| progress_callback=progress_callback, | ||
| ) as (model_weights_paths, config_path): | ||
| foreign_decoder_config = model_spec.config_type.from_json(config_path) | ||
| ) as (model_weights_paths, config_path, extra_config_paths): | ||
| foreign_decoder_config = model_spec.config_type.from_json(config_path, extra_config_paths) | ||
| assert isinstance(foreign_decoder_config, ForeignLMConfig) | ||
|
|
||
| if precision is None: | ||
|
|
@@ -402,8 +429,8 @@ def _import_classifier( | |
| with _download_weights_and_config_files( | ||
| model_spec, | ||
| progress_callback=progress_callback, | ||
| ) as (model_weights_paths, config_path): | ||
| foreign_classifier_config = model_spec.config_type.from_json(config_path) | ||
| ) as (model_weights_paths, config_path, extra_config_paths): | ||
| foreign_classifier_config = model_spec.config_type.from_json(config_path, extra_config_paths) | ||
| assert isinstance(foreign_classifier_config, ForeignClassifierConfig) | ||
|
|
||
| if precision is None: | ||
|
|
@@ -444,8 +471,8 @@ def _import_tts_model( | |
| with _download_weights_and_config_files( | ||
| model_spec, | ||
| progress_callback=progress_callback, | ||
| ) as (model_weights_paths, config_path): | ||
| foreign_tts_config = model_spec.config_type.from_json(config_path) | ||
| ) as (model_weights_paths, config_path, extra_config_paths): | ||
| foreign_tts_config = model_spec.config_type.from_json(config_path, extra_config_paths) | ||
| if precision is None: | ||
| precision = foreign_tts_config.default_precision | ||
| if model_spec.vendor == "FishAudio" and model_spec.family == "openaudio": | ||
|
|
@@ -508,6 +535,71 @@ def _import_tts_model( | |
| return (tts_generator, tts_generator_config) | ||
|
|
||
|
|
||
| def _import_latent_tts_model( | ||
| model_spec: ModelSpec, | ||
| *, | ||
| context_length: int | None = None, | ||
| precision: DTypeLike | None = None, | ||
| accumulation_precision: DTypeLike = jnp.float32, | ||
| progress_callback: Callable[[StatusEvent], None] | None = None, | ||
| local_dir: Path | None = None, | ||
| ) -> tuple[LatentTTSGenerator, LatentTTSGeneratorConfig]: | ||
| with _download_weights_and_config_files( | ||
| model_spec, | ||
| progress_callback=progress_callback, | ||
| local_dir=local_dir, | ||
| ) as (model_weights_paths, config_path, extra_config_paths): | ||
| foreign_config = model_spec.config_type.from_json(config_path, extra_config_paths) | ||
| if precision is None: | ||
| precision = foreign_config.default_precision | ||
|
|
||
| tokenizer = _instantiate_tokenizer_from_model_spec(model_spec, None, progress_callback, local_dir=local_dir) | ||
|
|
||
| latent_tts_model = _load_main_processing_module( | ||
| model_spec, | ||
| model_weights_paths, | ||
| precision, | ||
| foreign_config, | ||
| progress_callback, | ||
| context_length, | ||
| accumulation_precision, | ||
| ) | ||
|
|
||
| assert isinstance(latent_tts_model, LatentTTSModel) | ||
| if progress_callback is not None: | ||
| progress_callback(FinishedInitializingModelEvent()) | ||
|
|
||
| assert isinstance(model_spec.configs.chat_template, str) | ||
| tts_request_factory_config = TTSMessageProcessorConfig( | ||
| prompt_template=model_spec.configs.chat_template, | ||
| ) | ||
| message_processor = TTSMessageProcessor(tts_request_factory_config, tokenizer) | ||
|
|
||
| latent_tts_config = foreign_config.to_lalamo_config( | ||
| context_length=context_length, | ||
| activation_precision=precision, | ||
| accumulation_precision=precision, | ||
knyazer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| metadata_dict={}, | ||
| ) | ||
|
|
||
| if not hasattr(latent_tts_config, "default_generation_config"): | ||
| raise ValueError(f"{type(latent_tts_config).__name__} must implement default_generation_config()") | ||
| generation_config = latent_tts_config.default_generation_config() | ||
| generator_config = LatentTTSGeneratorConfig( | ||
| latent_tts_config=latent_tts_config, | ||
| message_processor_config=message_processor.config, | ||
| generation_config=generation_config, | ||
| ) | ||
|
|
||
| generator = LatentTTSGenerator( | ||
| config=generator_config, | ||
knyazer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| latent_tts_model=latent_tts_model, | ||
| message_processor=message_processor, | ||
| ) | ||
|
|
||
| return (generator, generator_config) | ||
|
|
||
|
|
||
| def import_model( | ||
| model_spec: ModelSpec | str, | ||
| *, | ||
|
|
@@ -516,7 +608,11 @@ def import_model( | |
| accumulation_precision: DTypeLike = jnp.float32, | ||
| progress_callback: Callable[[StatusEvent], None] | None = None, | ||
| sharding_config: ShardingConfig | None = None, | ||
| local_dir: Path | str | None = None, | ||
| ) -> ImportResults: | ||
| if isinstance(local_dir, str): | ||
| local_dir = Path(local_dir) | ||
|
|
||
knyazer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if isinstance(model_spec, str): | ||
| try: | ||
| model_spec = ModelRegistry.build().repo_to_model[model_spec] | ||
|
|
@@ -549,6 +645,15 @@ def import_model( | |
| accumulation_precision=accumulation_precision, | ||
| progress_callback=progress_callback, | ||
| ) | ||
| case ModelType.LATENT_TTS_MODEL: | ||
| model, config = _import_latent_tts_model( | ||
| model_spec, | ||
| context_length=context_length, | ||
| precision=precision, | ||
| accumulation_precision=accumulation_precision, | ||
| progress_callback=progress_callback, | ||
| local_dir=local_dir, | ||
| ) | ||
|
|
||
| metadata = ModelMetadata( | ||
| toolchain_version=LALAMO_VERSION, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
waveform: Float[Array, " audio_samples"]