
Examples: GUI and "Advanced" CLI #10

@AcTePuKc

Description

Want CLI?

Running this as plain Python from the main folder (with the conda env active) produces more accurate speech patterns for me, for some reason:

import os
import torch
import numpy as np
import soundfile as sf
import logging
from datetime import datetime
from cli.SparkTTS import SparkTTS

def generate_tts_audio(
    text,
    model_dir="pretrained_models/Spark-TTS-0.5B",
    device="cuda:0",
    prompt_speech_path=None,
    prompt_text=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="example/results",
    segmentation_threshold=150  # Do not go above this unless you have a better GPU, or it may crash
):
    """
    Generates TTS audio from input text, splitting into segments if necessary.

    Args:
        text (str): Input text for speech synthesis.
        model_dir (str): Path to the model directory.
        device (str): Device identifier (e.g., "cuda:0" or "cpu").
        prompt_speech_path (str, optional): Path to prompt audio for cloning.
        prompt_text (str, optional): Transcript of prompt audio.
        gender (str, optional): Gender parameter ("male"/"female").
        pitch (str, optional): Pitch parameter (e.g., "moderate").
        speed (str, optional): Speed parameter (e.g., "moderate").
        save_dir (str): Directory where generated audio will be saved.
        segmentation_threshold (int): Maximum number of words per segment.

    Returns:
        str: The unique file path where the generated audio is saved.
    """
    logging.info("Initializing TTS model...")
    device = torch.device(device)
    model = SparkTTS(model_dir, device)
    
    # Ensure the save directory exists.
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    # Check if the text is too long.
    words = text.split()
    if len(words) > segmentation_threshold:
        logging.info("Input text exceeds threshold; splitting into segments...")
        segments = [' '.join(words[i:i+segmentation_threshold]) for i in range(0, len(words), segmentation_threshold)]
        wavs = []
        for seg in segments:
            with torch.no_grad():
                wav = model.inference(
                    seg,
                    prompt_speech_path,
                    prompt_text=prompt_text,
                    gender=gender,
                    pitch=pitch,
                    speed=speed
                )
            wavs.append(wav)
        final_wav = np.concatenate(wavs, axis=0)
    else:
        with torch.no_grad():
            final_wav = model.inference(
                text,
                prompt_speech_path,
                prompt_text=prompt_text,
                gender=gender,
                pitch=pitch,
                speed=speed
            )
    
    # Save the generated audio.
    sf.write(save_path, final_wav, samplerate=16000)
    logging.info(f"Audio saved at: {save_path}")
    return save_path

# Example usage:
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    
    # Sample input (feel free to adjust)
    sample_text = (
        "The mind that opens to a new idea never returns to its original size. "
        "Hellstrom’s Hive: Chapter 1 – The Awakening. Mara Vance stirred from a deep, dreamless sleep, "
        "her consciousness surfacing like a diver breaking through the ocean's surface. "
        "A dim, amber light filtered through her closed eyelids, warm and pulsing softly. "
        "She hesitated to open her eyes, savoring the fleeting peace before reality set in. "
        "A cool, earthy scent filled her nostrils—damp soil mingled with something sweet and metallic. "
        "The air was thick, almost humid, carrying with it a faint vibration that resonated in her bones. "
        "It wasn't just a sound; it was a presence. "
        "Her eyelids fluttered open. Above her stretched a ceiling unlike any she'd seen—organic and alive, "
        "composed of interwoven tendrils that glowed with the same amber light. They pulsated gently, "
        "like the breathing of some colossal creature. Shadows danced between the strands, creating shifting patterns."
    )
    
    # Call the function (adjust parameters as needed)
    output_file = generate_tts_audio(
        sample_text,
        gender="male",
        pitch="moderate",
        speed="moderate"
    )
    print("Generated audio file:", output_file)

Better GUI

And here's a GUI if someone wants one. It's lightweight - same as running through the CLI - and at least on a 3060 it runs fine. It combines text segments, but unfortunately it will crash if you paste in a ton of text.
It only requires:

 pip install pyside6 
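Then save the script below in the repo root (again, any filename works) and launch it the same way as the CLI one, with the conda env active.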

🔹 Buttons?!?

  • Text Input: The big text box where you enter text to be converted into speech.
  • Load Voice Sample: Loads a voice sample (MP3/WAV) for RVC-like functionality, allowing voice transformation.
  • Reset Voice Sample: Clears the loaded voice sample, letting you switch back to gender-based synthesis without restarting the app.
  • Gender Selection Dropdown:
    • If using Spark-TTS, select "Male" or "Female" for a generated voice.
    • If left on "Auto," Spark-TTS will fail.
    • Takes a few seconds to generate before synthesis starts.
  • Generate Speech: Starts generating speech based on the entered text and selected parameters.
  • Play: Plays the last generated audio file.
  • Stop: Stops playback.
  • Save Audio: Saves the last generated audio to a file.
  • Word Count: That thing that counts words.
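  • Pitch / Speed Sliders: Map slider positions 0-4 to the levels very_low, low, moderate, high, very_high (default 2, "moderate"); both are ignored while a voice sample is loaded.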

😎


import sys
import os
import time
import torch
import shutil
import numpy as np
import soundfile as sf
from PySide6.QtWidgets import (
    QApplication, QWidget, QVBoxLayout, QPushButton, QLabel,
    QTextEdit, QSlider, QFileDialog, QComboBox, QHBoxLayout
)
from PySide6.QtCore import Qt, QThread, Signal, QUrl
from PySide6.QtMultimedia import QMediaPlayer, QAudioOutput
from PySide6.QtGui import QPainter, QColor, QPen, QIcon
from cli.SparkTTS import SparkTTS

# --- Worker Thread for TTS Generation (with segmentation support) ---
class TTSWorker(QThread):
    result_ready = Signal(object, float)  # Emits (final result, generation_time)
    progress_update = Signal(int, int)      # Emits (current_segment, total_segments)

    def __init__(self, model, text, voice_sample, gender, pitch, speed):
        """
        text: Either a string or a list of strings (segments).
        """
        super().__init__()
        self.model = model
        self.text = text
        self.voice_sample = voice_sample
        self.gender = gender
        self.pitch = pitch
        self.speed = speed

    def run(self):
        start = time.time()
        try:
            results = []
            if isinstance(self.text, list):
                total = len(self.text)
                for i, segment in enumerate(self.text):
                    with torch.no_grad():
                        wav = self.model.inference(
                            segment,
                            prompt_speech_path=self.voice_sample,
                            gender=self.gender,
                            pitch=self.pitch,
                            speed=self.speed
                        )
                    results.append(wav)
                    self.progress_update.emit(i + 1, total)
                final_wav = np.concatenate(results, axis=0)
            else:
                with torch.no_grad():
                    final_wav = self.model.inference(
                        self.text,
                        prompt_speech_path=self.voice_sample,
                        gender=self.gender,
                        pitch=self.pitch,
                        speed=self.speed
                    )
                self.progress_update.emit(1, 1)
            elapsed = time.time() - start
            self.result_ready.emit(final_wav, elapsed)
        except Exception as e:
            self.result_ready.emit(e, 0)

# --- Waveform Visualization Widget ---
class WaveformWidget(QWidget):
    def __init__(self, parent=None):
        super().__init__(parent)
        self.progress = 0.0  # Range: 0.0 to 1.0

    def set_progress(self, progress):
        self.progress = progress
        self.update()

    def paintEvent(self, event):
        painter = QPainter(self)
        painter.fillRect(self.rect(), QColor("black"))
        pen = QPen(QColor("green"))
        pen.setWidth(5)
        painter.setPen(pen)
        painter.drawLine(0, self.height() // 2, int(self.width() * self.progress), self.height() // 2)

# --- Main Application Class ---
class SparkTTSApp(QWidget):
    def __init__(self, model, device):
        super().__init__()
        self.model = model
        self.voice_sample = None
        self.current_audio_file = None
        self.total_duration = 0
        self.init_ui()
        self.status_label.setText(f"Model loaded on {device}")

        # Set app icon if available.
        icon_path = "src/logo.webp"
        if os.path.exists(icon_path):
            self.setWindowIcon(QIcon(icon_path))  # Set app icon if found.

        # Initialize audio player and output.    
        self.audio_player = QMediaPlayer()
        self.audio_output = QAudioOutput()
        self.audio_player.setAudioOutput(self.audio_output)
        self.audio_player.positionChanged.connect(self.on_position_changed)
        self.audio_player.durationChanged.connect(self.on_duration_changed)

    def init_ui(self):
        self.setWindowTitle("Spark-TTS GUI")
        self.setMinimumSize(600, 400)
        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(15, 15, 15, 15)
        
        # Text input.
        self.text_input = QTextEdit()
        self.text_input.setPlaceholderText("Enter text for speech synthesis...")
        main_layout.addWidget(self.text_input)

        self.word_count_label = QLabel("Word Count: 0")
        main_layout.addWidget(self.word_count_label)

        self.text_input.textChanged.connect(self.update_word_count)

        
        
        # Action buttons.
        btn_layout = QHBoxLayout()
        self.voice_btn = QPushButton("Load Voice Sample")
        self.voice_btn.clicked.connect(self.select_voice_sample)
        self.reset_voice_btn = QPushButton("Reset Voice Sample")
        self.reset_voice_btn.clicked.connect(self.reset_voice_sample)
        self.generate_btn = QPushButton("Generate Speech")
        self.generate_btn.clicked.connect(self.run_synthesis)
        btn_layout.addWidget(self.voice_btn)
        btn_layout.addWidget(self.reset_voice_btn)
        btn_layout.addWidget(self.generate_btn)
        main_layout.addLayout(btn_layout)
        
        # Controls Layout (only Gender, Pitch, and Speed).
        controls_layout = QHBoxLayout()
        self.gender_selector = QComboBox()
        self.gender_selector.addItems(["Auto", "Male", "Female"])
        controls_layout.addWidget(QLabel("Gender:"))
        controls_layout.addWidget(self.gender_selector)
        
        self.pitch_slider, pitch_layout = self.create_slider_with_value("Pitch")
        controls_layout.addLayout(pitch_layout)
        
        self.speed_slider, speed_layout = self.create_slider_with_value("Speed")
        controls_layout.addLayout(speed_layout)
        
        main_layout.addLayout(controls_layout)
        
        # Audio controls layout.
        audio_controls = QHBoxLayout()
        self.play_btn = QPushButton("Play")
        self.play_btn.clicked.connect(self.play_audio)
        self.stop_btn = QPushButton("Stop")
        self.stop_btn.clicked.connect(self.stop_audio)
        self.save_btn = QPushButton("Save Audio")
        self.save_btn.clicked.connect(self.save_audio)
        audio_controls.addWidget(self.play_btn)
        audio_controls.addWidget(self.stop_btn)
        audio_controls.addWidget(self.save_btn)
        main_layout.addLayout(audio_controls)
        
        # Status bar.
        self.status_label = QLabel("Ready")
        self.status_label.setAlignment(Qt.AlignCenter)
        main_layout.addWidget(self.status_label)
        
        # Waveform visualization widget.
        self.waveform = WaveformWidget()
        main_layout.addWidget(self.waveform)
        
        self.setLayout(main_layout)

    def create_slider_with_value(self, label_text):
        layout = QVBoxLayout()
        label = QLabel(label_text)
        label = QLabel(label_text)
        slider = QSlider(Qt.Horizontal)
        slider.setRange(0, 4)
        slider.setValue(2)
        value_label = QLabel("2")
        slider.valueChanged.connect(lambda val: value_label.setText(str(val)))
        layout.addWidget(label)
        layout.addWidget(slider)
        layout.addWidget(value_label)
        # Descriptive text under the slider.
        desc_label = QLabel(f"Adjust {label_text.lower()} level")
        layout.addWidget(desc_label)
        return slider, layout
    
    def update_word_count(self):
        """Updates the word count dynamically as the user types."""
        text = self.text_input.toPlainText().strip()
        word_count = len(text.split()) if text else 0
        self.word_count_label.setText(f"Word Count: {word_count}")

    def reset_voice_sample(self):
        """Clears the loaded voice sample and restores gender selection."""
        self.voice_sample = None
        self.gender_selector.setEnabled(True)
        self.status_label.setText("Voice sample cleared. You can now use gender selection.")

    def select_voice_sample(self):
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Select Voice Sample", "", "Audio Files (*.wav *.mp3)"
        )
        if file_path:
            self.voice_sample = file_path
            self.status_label.setText(f"Loaded voice sample: {os.path.basename(file_path)}")

    def save_audio(self):
        if not (self.current_audio_file and os.path.exists(self.current_audio_file)):
            self.status_label.setText("No audio to save!")
            return
        save_path, _ = QFileDialog.getSaveFileName(
            self, "Save Audio", "", "WAV Files (*.wav)"
        )
        if save_path:
            shutil.copy(self.current_audio_file, save_path)
            self.status_label.setText(f"Audio saved to: {os.path.basename(save_path)}")

    def play_audio(self):
        if self.current_audio_file and os.path.exists(self.current_audio_file):
            # QMediaPlayer expects a QUrl; build one from the absolute file path.
            self.audio_player.setSource(QUrl.fromLocalFile(os.path.abspath(self.current_audio_file)))
            self.audio_player.play()

    def stop_audio(self):
        self.audio_player.stop()

    def on_duration_changed(self, duration):
        self.total_duration = duration

    def on_position_changed(self, position):
        if self.total_duration > 0:
            progress = position / self.total_duration
            self.waveform.set_progress(progress)

    def run_synthesis(self):
        text = self.text_input.toPlainText().strip()
        if not text:
            self.status_label.setText("Please enter some text!")
            return

        # Segmentation: Limit each segment to 150 words.
        segmentation_threshold = 150
        words = text.split()
        if len(words) > segmentation_threshold:
            text_to_process = [
                ' '.join(words[i:i + segmentation_threshold])
                for i in range(0, len(words), segmentation_threshold)
            ]
            self.status_label.setText("Text too long: processing segments...")
        else:
            text_to_process = text

        # Determine parameters based on whether a voice sample is loaded.
        if self.voice_sample is not None:
            prompt = self.voice_sample
            gender = None
            pitch = None
            speed = None
        else:
            prompt = None
            gender = self.gender_selector.currentText().lower()
            gender = None if gender == "auto" else gender
            pitch_map = ["very_low", "low", "moderate", "high", "very_high"]
            speed_map = ["very_low", "low", "moderate", "high", "very_high"]
            pitch = pitch_map[self.pitch_slider.value()]
            speed = speed_map[self.speed_slider.value()]

        self.generate_btn.setEnabled(False)
        self.status_label.setText("Generating speech...")

        self.worker = TTSWorker(self.model, text_to_process, prompt, gender, pitch, speed)
        self.worker.progress_update.connect(self.on_generation_progress)
        self.worker.result_ready.connect(self.on_generation_complete)
        self.worker.start()

    def on_generation_progress(self, current, total):
        self.status_label.setText(f"Generating segment {current} / {total}...")

    def on_generation_complete(self, result, elapsed):
        if isinstance(result, Exception):
            self.status_label.setText(f"Error: {result}")
        else:
            filename = f"output_{int(time.time())}.wav"
            sf.write(filename, result, samplerate=16000)
            self.current_audio_file = filename
            self.status_label.setText(f"Generated in {elapsed:.1f}s | Saved to {filename}")
        self.generate_btn.setEnabled(True)

if __name__ == "__main__":
    app = QApplication(sys.argv)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SparkTTS("pretrained_models/Spark-TTS-0.5B", device=device)
    window = SparkTTSApp(model, device.type.upper())
    window.show()
    sys.exit(app.exec())

I've gotten side-tracked and I'm updating it to look better, but this version works. It's slower than the CLI, but hey - you type, it speaks.

Cheers.
