Want CLI?
Running this as a Python script in the main folder with the conda env active produces more accurate speech patterns for some reason, at least for me.
```python
import os
import logging
from datetime import datetime

import torch
import numpy as np
import soundfile as sf

from cli.SparkTTS import SparkTTS


def generate_tts_audio(
    text,
    model_dir="pretrained_models/Spark-TTS-0.5B",
    device="cuda:0",
    prompt_speech_path=None,
    prompt_text=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="example/results",
    segmentation_threshold=150,  # Going above this tends to crash unless you have a better GPU.
):
    """
    Generates TTS audio from input text, splitting into segments if necessary.

    Args:
        text (str): Input text for speech synthesis.
        model_dir (str): Path to the model directory.
        device (str): Device identifier (e.g., "cuda:0" or "cpu").
        prompt_speech_path (str, optional): Path to prompt audio for cloning.
        prompt_text (str, optional): Transcript of prompt audio.
        gender (str, optional): Gender parameter ("male"/"female").
        pitch (str, optional): Pitch parameter (e.g., "moderate").
        speed (str, optional): Speed parameter (e.g., "moderate").
        save_dir (str): Directory where generated audio will be saved.
        segmentation_threshold (int): Maximum number of words per segment.

    Returns:
        str: The unique file path where the generated audio is saved.
    """
    logging.info("Initializing TTS model...")
    device = torch.device(device)
    model = SparkTTS(model_dir, device)

    # Ensure the save directory exists.
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    # Check if the text is too long.
    words = text.split()
    if len(words) > segmentation_threshold:
        logging.info("Input text exceeds threshold; splitting into segments...")
        segments = [
            " ".join(words[i:i + segmentation_threshold])
            for i in range(0, len(words), segmentation_threshold)
        ]
        wavs = []
        for seg in segments:
            with torch.no_grad():
                wav = model.inference(
                    seg,
                    prompt_speech_path,
                    prompt_text=prompt_text,
                    gender=gender,
                    pitch=pitch,
                    speed=speed,
                )
            wavs.append(wav)
        final_wav = np.concatenate(wavs, axis=0)
    else:
        with torch.no_grad():
            final_wav = model.inference(
                text,
                prompt_speech_path,
                prompt_text=prompt_text,
                gender=gender,
                pitch=pitch,
                speed=speed,
            )

    # Save the generated audio.
    sf.write(save_path, final_wav, samplerate=16000)
    logging.info(f"Audio saved at: {save_path}")
    return save_path


# Example usage:
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

    # Sample input (feel free to adjust)
    sample_text = (
        "The mind that opens to a new idea never returns to its original size. "
        "Hellstrom's Hive: Chapter 1 - The Awakening. Mara Vance stirred from a deep, dreamless sleep, "
        "her consciousness surfacing like a diver breaking through the ocean's surface. "
        "A dim, amber light filtered through her closed eyelids, warm and pulsing softly. "
        "She hesitated to open her eyes, savoring the fleeting peace before reality set in. "
        "A cool, earthy scent filled her nostrils - damp soil mingled with something sweet and metallic. "
        "The air was thick, almost humid, carrying with it a faint vibration that resonated in her bones. "
        "It wasn't just a sound; it was a presence. "
        "Her eyelids fluttered open. Above her stretched a ceiling unlike any she'd seen - organic and alive, "
        "composed of interwoven tendrils that glowed with the same amber light. They pulsated gently, "
        "like the breathing of some colossal creature. Shadows danced between the strands, creating shifting patterns."
    )

    # Call the function (adjust parameters as needed)
    output_file = generate_tts_audio(
        sample_text,
        gender="male",
        pitch="moderate",
        speed="moderate",
    )
    print("Generated audio file:", output_file)
```
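If you'd rather clone a voice than use the presets, the same function takes a reference clip - a minimal sketch, where the paths and transcript are placeholders you'd swap for your own:

```python
# Hypothetical voice-cloning call: paths and transcript below are placeholders.
cloned_file = generate_tts_audio(
    "This should come out in the reference speaker's voice.",
    prompt_speech_path="example/prompts/reference.wav",   # your reference audio (placeholder)
    prompt_text="Transcript of what is said in reference.wav",  # its transcript (placeholder)
)
print("Cloned audio file:", cloned_file)
```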
Better GUI
And a GUI if someone wants one - it's lightweight, same as running through the CLI - at least on a 3060 it runs fine. It combines the text segments, but unfortunately it will crash if you paste in a ton of text.
It just requires:
`pip install pyside6`
🔹 Buttons?!?
- Text Input: The big text box where you enter text to be converted into speech.
- Load Voice Sample: Loads a voice sample (MP3/WAV) for RVC-like functionality, allowing voice transformation.
- Reset Voice Sample: Clears the loaded voice sample, letting you switch back to gender-based synthesis without restarting the app.
- Gender Selection Dropdown (see the guard sketch after this list):
  - If using Spark-TTS presets, select "Male" or "Female" for a generated voice.
  - If left on "Auto," Spark-TTS will fail.
  - It takes a few seconds before synthesis starts.
- Generate Speech: Starts generating speech based on the entered text and selected parameters.
- Play: Plays the last generated audio file.
- Stop: Stops playback.
- Save Audio: Saves the last generated audio to a file.
- Word Count: That thing that counts words.
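Since "Auto" with no voice sample makes Spark-TTS fail, a small guard at the top of `run_synthesis` could catch it early - a hypothetical tweak, reusing the widget names from the code below:

```python
# Hypothetical guard for the top of run_synthesis: Spark-TTS needs either a
# loaded voice sample or an explicit gender, so refuse to start with neither.
if self.voice_sample is None and self.gender_selector.currentText() == "Auto":
    self.status_label.setText("Select Male/Female or load a voice sample first!")
    return
```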
😎
```python
import sys
import os
import time
import shutil

import torch
import numpy as np
import soundfile as sf
from PySide6.QtWidgets import (
    QApplication, QWidget, QVBoxLayout, QPushButton, QLabel,
    QTextEdit, QSlider, QFileDialog, QComboBox, QHBoxLayout
)
from PySide6.QtCore import Qt, QThread, Signal, QUrl
from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput
from PySide6.QtGui import QPainter, QColor, QPen, QIcon

from cli.SparkTTS import SparkTTS


# --- Worker Thread for TTS Generation (with segmentation support) ---
class TTSWorker(QThread):
    result_ready = Signal(object, float)  # Emits (final result, generation time)
    progress_update = Signal(int, int)    # Emits (current segment, total segments)

    def __init__(self, model, text, voice_sample, gender, pitch, speed):
        """
        text: Either a string or a list of strings (segments).
        """
        super().__init__()
        self.model = model
        self.text = text
        self.voice_sample = voice_sample
        self.gender = gender
        self.pitch = pitch
        self.speed = speed

    def run(self):
        start = time.time()
        try:
            results = []
            if isinstance(self.text, list):
                total = len(self.text)
                for i, segment in enumerate(self.text):
                    with torch.no_grad():
                        wav = self.model.inference(
                            segment,
                            prompt_speech_path=self.voice_sample,
                            gender=self.gender,
                            pitch=self.pitch,
                            speed=self.speed
                        )
                    results.append(wav)
                    self.progress_update.emit(i + 1, total)
                final_wav = np.concatenate(results, axis=0)
            else:
                with torch.no_grad():
                    final_wav = self.model.inference(
                        self.text,
                        prompt_speech_path=self.voice_sample,
                        gender=self.gender,
                        pitch=self.pitch,
                        speed=self.speed
                    )
                self.progress_update.emit(1, 1)
            elapsed = time.time() - start
            self.result_ready.emit(final_wav, elapsed)
        except Exception as e:
            self.result_ready.emit(e, 0)


# --- Waveform Visualization Widget ---
class WaveformWidget(QWidget):
    def __init__(self, parent=None):
        super().__init__(parent)
        self.progress = 0.0  # Range: 0.0 to 1.0

    def set_progress(self, progress):
        self.progress = progress
        self.update()

    def paintEvent(self, event):
        painter = QPainter(self)
        painter.fillRect(self.rect(), QColor("black"))
        pen = QPen(QColor("green"))
        pen.setWidth(5)
        painter.setPen(pen)
        painter.drawLine(0, self.height() // 2, int(self.width() * self.progress), self.height() // 2)


# --- Main Application Class ---
class SparkTTSApp(QWidget):
    def __init__(self, model, device):
        super().__init__()
        self.model = model
        self.voice_sample = None
        self.current_audio_file = None
        self.total_duration = 0
        self.init_ui()
        self.status_label.setText(f"Model loaded on {device}")
        # Set app icon if available.
        icon_path = "src/logo.webp"
        if os.path.exists(icon_path):
            self.setWindowIcon(QIcon(icon_path))
        # Initialize audio player and output.
        self.audio_player = QMediaPlayer()
        self.audio_output = QAudioOutput()
        self.audio_player.setAudioOutput(self.audio_output)
        self.audio_player.positionChanged.connect(self.on_position_changed)
        self.audio_player.durationChanged.connect(self.on_duration_changed)

    def init_ui(self):
        self.setWindowTitle("Spark-TTS GUI")
        self.setMinimumSize(600, 400)
        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(15, 15, 15, 15)

        # Text input.
        self.text_input = QTextEdit()
        self.text_input.setPlaceholderText("Enter text for speech synthesis...")
        main_layout.addWidget(self.text_input)

        self.word_count_label = QLabel("Word Count: 0")
        main_layout.addWidget(self.word_count_label)
        self.text_input.textChanged.connect(self.update_word_count)

        btn_layout = QHBoxLayout()
        self.voice_btn = QPushButton("Load Voice Sample")
        self.voice_btn.clicked.connect(self.select_voice_sample)
        self.reset_voice_btn = QPushButton("Reset Voice Sample")
        self.reset_voice_btn.clicked.connect(self.reset_voice_sample)
        self.generate_btn = QPushButton("Generate Speech")
        self.generate_btn.clicked.connect(self.run_synthesis)
        btn_layout.addWidget(self.voice_btn)
        btn_layout.addWidget(self.reset_voice_btn)
        btn_layout.addWidget(self.generate_btn)
        main_layout.addLayout(btn_layout)

        # Controls layout (only Gender, Pitch, and Speed).
        controls_layout = QHBoxLayout()
        self.gender_selector = QComboBox()
        self.gender_selector.addItems(["Auto", "Male", "Female"])
        controls_layout.addWidget(QLabel("Gender:"))
        controls_layout.addWidget(self.gender_selector)
        self.pitch_slider, pitch_layout = self.create_slider_with_value("Pitch")
        controls_layout.addLayout(pitch_layout)
        self.speed_slider, speed_layout = self.create_slider_with_value("Speed")
        controls_layout.addLayout(speed_layout)
        main_layout.addLayout(controls_layout)

        # Audio controls layout.
        audio_controls = QHBoxLayout()
        self.play_btn = QPushButton("Play")
        self.play_btn.clicked.connect(self.play_audio)
        self.stop_btn = QPushButton("Stop")
        self.stop_btn.clicked.connect(self.stop_audio)
        self.save_btn = QPushButton("Save Audio")
        self.save_btn.clicked.connect(self.save_audio)
        audio_controls.addWidget(self.play_btn)
        audio_controls.addWidget(self.stop_btn)
        audio_controls.addWidget(self.save_btn)
        main_layout.addLayout(audio_controls)

        # Status bar.
        self.status_label = QLabel("Ready")
        self.status_label.setAlignment(Qt.AlignCenter)
        main_layout.addWidget(self.status_label)

        # Waveform visualization widget.
        self.waveform = WaveformWidget()
        main_layout.addWidget(self.waveform)
        self.setLayout(main_layout)

    def create_slider_with_value(self, label_text):
        layout = QVBoxLayout()
        label = QLabel(label_text)
        slider = QSlider(Qt.Horizontal)
        slider.setRange(0, 4)
        slider.setValue(2)
        value_label = QLabel("2")
        slider.valueChanged.connect(lambda val: value_label.setText(str(val)))
        layout.addWidget(label)
        layout.addWidget(slider)
        layout.addWidget(value_label)
        # Descriptive text under the slider.
        desc_label = QLabel(f"Adjust {label_text.lower()} level")
        layout.addWidget(desc_label)
        return slider, layout

    def update_word_count(self):
        """Updates the word count dynamically as the user types."""
        text = self.text_input.toPlainText().strip()
        word_count = len(text.split()) if text else 0
        self.word_count_label.setText(f"Word Count: {word_count}")

    def reset_voice_sample(self):
        """Clears the loaded voice sample and restores gender selection."""
        self.voice_sample = None
        self.gender_selector.setEnabled(True)
        self.status_label.setText("Voice sample cleared. You can now use gender selection.")

    def select_voice_sample(self):
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Select Voice Sample", "", "Audio Files (*.wav *.mp3)"
        )
        if file_path:
            self.voice_sample = file_path
            self.status_label.setText(f"Loaded voice sample: {os.path.basename(file_path)}")

    def save_audio(self):
        if not (self.current_audio_file and os.path.exists(self.current_audio_file)):
            self.status_label.setText("No audio to save!")
            return
        save_path, _ = QFileDialog.getSaveFileName(
            self, "Save Audio", "", "WAV Files (*.wav)"
        )
        if save_path:
            shutil.copy(self.current_audio_file, save_path)
            self.status_label.setText(f"Audio saved to: {os.path.basename(save_path)}")

    def play_audio(self):
        if self.current_audio_file and os.path.exists(self.current_audio_file):
            # setSource expects a QUrl, so wrap the local path explicitly.
            self.audio_player.setSource(QUrl.fromLocalFile(os.path.abspath(self.current_audio_file)))
            self.audio_player.play()

    def stop_audio(self):
        self.audio_player.stop()

    def on_duration_changed(self, duration):
        self.total_duration = duration

    def on_position_changed(self, position):
        if self.total_duration > 0:
            progress = position / self.total_duration
            self.waveform.set_progress(progress)

    def run_synthesis(self):
        text = self.text_input.toPlainText().strip()
        if not text:
            self.status_label.setText("Please enter some text!")
            return
        # Segmentation: limit each segment to 150 words.
        segmentation_threshold = 150
        words = text.split()
        if len(words) > segmentation_threshold:
            text_to_process = [
                ' '.join(words[i:i + segmentation_threshold])
                for i in range(0, len(words), segmentation_threshold)
            ]
            self.status_label.setText("Text too long: processing segments...")
        else:
            text_to_process = text
        # Determine parameters based on whether a voice sample is loaded.
        if self.voice_sample is not None:
            prompt = self.voice_sample
            gender = None
            pitch = None
            speed = None
        else:
            prompt = None
            gender = self.gender_selector.currentText().lower()
            gender = None if gender == "auto" else gender
            pitch_map = ["very_low", "low", "moderate", "high", "very_high"]
            speed_map = ["very_low", "low", "moderate", "high", "very_high"]
            pitch = pitch_map[self.pitch_slider.value()]
            speed = speed_map[self.speed_slider.value()]
        self.generate_btn.setEnabled(False)
        self.status_label.setText("Generating speech...")
        self.worker = TTSWorker(self.model, text_to_process, prompt, gender, pitch, speed)
        self.worker.progress_update.connect(self.on_generation_progress)
        self.worker.result_ready.connect(self.on_generation_complete)
        self.worker.start()

    def on_generation_progress(self, current, total):
        self.status_label.setText(f"Generating segment {current} / {total}...")

    def on_generation_complete(self, result, elapsed):
        if isinstance(result, Exception):
            self.status_label.setText(f"Error: {result}")
        else:
            filename = f"output_{int(time.time())}.wav"
            sf.write(filename, result, samplerate=16000)
            self.current_audio_file = filename
            self.status_label.setText(f"Generated in {elapsed:.1f}s | Saved to {filename}")
        self.generate_btn.setEnabled(True)


if __name__ == "__main__":
    app = QApplication(sys.argv)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SparkTTS("pretrained_models/Spark-TTS-0.5B", device=device)
    window = SparkTTSApp(model, device.type.upper())
    window.show()
    sys.exit(app.exec())
```

I've gotten side-tracked and I'm updating it to look better, but this version works. It's slower than the CLI, but hey - you type, it speaks.
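One known rough edge: the 150-word split can cut mid-sentence. A sentence-aware splitter along these lines (an untested sketch, not part of the code above) would keep sentences intact:

```python
import re

def split_into_segments(text, max_words=150):
    """Greedily pack whole sentences into segments of at most max_words words."""
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    segments, current, count = [], [], 0
    for sentence in sentences:
        n = len(sentence.split())
        # Start a new segment if this sentence would push us over the cap.
        if current and count + n > max_words:
            segments.append(" ".join(current))
            current, count = [], 0
        current.append(sentence)
        count += n
    if current:
        segments.append(" ".join(current))
    # Note: a single sentence longer than max_words still becomes one oversized segment.
    return segments
```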
Cheers.