diff --git a/README.md b/README.md
index d6651ef7..82f9f991 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,8 @@ If you like the model but need to scale or tune it for higher accuracy, check out
 - Watermarked outputs
 - Easy voice conversion script
 - [Outperforms ElevenLabs](https://podonos.com/resembleai/chatterbox)
+- Smart artifact cleaning with pause protection
+- Support for custom pause tags `[pause:xx]`
 
 # Supported Languages
 Arabic (ar) • Danish (da) • German (de) • Greek (el) • English (en) • Spanish (es) • Finnish (fi) • French (fr) • Hebrew (he) • Hindi (hi) • Italian (it) • Japanese (ja) • Korean (ko) • Malay (ms) • Dutch (nl) • Norwegian (no) • Polish (pl) • Portuguese (pt) • Russian (ru) • Swedish (sv) • Swahili (sw) • Turkish (tr) • Chinese (zh)
@@ -84,6 +86,15 @@ ta.save("test-chinese.wav", wav_chinese, model.sr)
 AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
 wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
 ta.save("test-2.wav", wav, model.sr)
+
+# Using the artifact-cleaning feature
+wav = model.generate(
+    text="Hello[pause:0.5s]world!",  # [pause:xx] tags insert pauses
+    use_auto_editor=True,            # enable artifact cleaning
+    ae_threshold=0.06,               # volume threshold (0-1)
+    ae_margin=0.2                    # boundary protection time (seconds)
+)
+ta.save("test-3.wav", wav, model.sr)
 ```
 
 See `example_tts.py` and `example_vc.py` for more examples.
diff --git a/pyproject.toml b/pyproject.toml
index 780ec361..5be8a6fc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,6 +10,7 @@ authors = [
 ]
 dependencies = [
     "numpy>=1.24.0,<1.26.0",
+    "resampy==0.4.3",
     "librosa==0.11.0",
     "s3tokenizer",
     "torch==2.6.0",
@@ -23,7 +24,7 @@ dependencies = [
     "pykakasi==2.3.0",
     "gradio==5.44.1",
     "russian-text-stresser @ git+https://github.com/Vuizur/add-stress-to-epub",
-
+    "auto-editor>=27.0.0"
 ]
 
 [project.urls]
@@ -36,3 +37,10 @@ build-backend = "setuptools.build_meta"
 
 [tool.setuptools.packages.find]
 where = ["src"]
+
+[dependency-groups]
+dev = [
+    "gradio>=4.44.1",
+    "resemble-perth>=1.0.1",
+    "setuptools>=80.9.0",
+]
diff --git a/src/chatterbox/__init__.py b/src/chatterbox/__init__.py
index 190cfbf2..a0fc1513 100644
--- a/src/chatterbox/__init__.py
+++ b/src/chatterbox/__init__.py
@@ -8,4 +8,5 @@
 
 from .tts import ChatterboxTTS
 from .vc import ChatterboxVC
-from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
\ No newline at end of file
+from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
+from .text_utils import split_text_into_segments, split_by_word_boundary, merge_short_sentences
diff --git a/src/chatterbox/text_utils.py b/src/chatterbox/text_utils.py
new file mode 100644
index 00000000..77db8aaa
--- /dev/null
+++ b/src/chatterbox/text_utils.py
@@ -0,0 +1,359 @@
+"""
+Text-processing utility functions.
+
+Contains helpers for long-text splitting, sentence merging, and related processing.
+Supports multiple languages, including English, Chinese, Japanese, and Korean.
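+
+Illustrative example (assumed input; text shorter than ``max_length`` is
+returned unchanged):
+
+    >>> from chatterbox.text_utils import split_text_into_segments
+    >>> split_text_into_segments("Hello world. How are you?", max_length=300)
+    ['Hello world. How are you?']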
+""" + +import re +from typing import List, Optional, Tuple +import logging + + +def detect_language(text: str) -> str: + """ + Simple language detection based on character patterns + + Parameters: + text: Text to detect language for + + Returns: + Language code: 'zh' (Chinese), 'ja' (Japanese), 'ko' (Korean), 'en' (English/others) + """ + # Count different character types + chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text)) # CJK Unified Ideographs + japanese_chars = len(re.findall(r'[\u3040-\u309f\u30a0-\u30ff]', text)) # Hiragana + Katakana + korean_chars = len(re.findall(r'[\uac00-\ud7af]', text)) # Hangul + + total_chars = len(re.sub(r'\s+', '', text)) + + if total_chars == 0: + return 'en' + + # Calculate ratios + chinese_ratio = chinese_chars / total_chars + japanese_ratio = japanese_chars / total_chars + korean_ratio = korean_chars / total_chars + + # Determine primary language + if chinese_ratio > 0.3: + return 'zh' + elif japanese_ratio > 0.3: + return 'ja' + elif korean_ratio > 0.3: + return 'ko' + else: + return 'en' + + +def get_sentence_separators(lang: str) -> str: + """ + Get sentence separator pattern for different languages + + Parameters: + lang: Language code + + Returns: + Regex pattern for sentence separators + """ + if lang == 'zh': + # Chinese punctuation + return r'(?<=[。!?;])\s*' + elif lang == 'ja': + # Japanese punctuation + return r'(?<=[。!?])\s*' + elif lang == 'ko': + # Korean punctuation + return r'(?<=[\.!?。!?])\s*' + else: + # English and other languages + return r'(?<=[.!?])\s+' + + +def get_punctuation_pattern(lang: str) -> str: + """ + Get punctuation pattern for word boundary splitting + + Parameters: + lang: Language code + + Returns: + Regex pattern for punctuation marks + """ + if lang == 'zh': + return r'(?<=[。!?,;:、])\s*' + elif lang == 'ja': + return r'(?<=[。!?、,])\s*' + elif lang == 'ko': + return r'(?<=[\.!?,;。!?,;])\s*' + else: + return r'(?<=[.!?,;:])\s+' + + +def split_by_word_boundary(text: str, max_len: int, lang: Optional[str] = None) -> List[str]: + """ + Split text by word boundaries to ensure words are not broken in the middle + Support for multiple languages + + Parameters: + text: Text to split + max_len: Maximum length of each segment + lang: Language code (auto-detect if None) + + Returns: + List of text segments + """ + if lang is None: + lang = detect_language(text) + + # First try to split at punctuation marks + punct_pattern = get_punctuation_pattern(lang) + punct_splits = re.split(punct_pattern, text) + + segments = [] + for split in punct_splits: + if len(split) <= max_len: + segments.append(split) + continue + + # If the segment after punctuation splitting is still too long + if lang in ['zh', 'ja', 'ko']: + # For CJK languages, split by character for fine-grained control + segments.extend(_split_cjk_text(split, max_len)) + else: + # For space-separated languages, split by word boundaries + segments.extend(_split_spaced_text(split, max_len)) + + return segments + + +def _split_cjk_text(text: str, max_len: int) -> List[str]: + """ + Split CJK text by characters while trying to preserve semantic units + + Parameters: + text: CJK text to split + max_len: Maximum length of each segment + + Returns: + List of text segments + """ + segments = [] + current = "" + + # Try to break at common phrase boundaries + phrase_boundaries = r'[,、;:]' # Common CJK phrase separators + + i = 0 + while i < len(text): + char = text[i] + + # Check if adding this character would exceed max length + if len(current + char) > max_len: + if 
+def _split_cjk_text(text: str, max_len: int) -> List[str]:
+    """
+    Split CJK text by characters while trying to preserve semantic units
+
+    Parameters:
+        text: CJK text to split
+        max_len: Maximum length of each segment
+
+    Returns:
+        List of text segments
+    """
+    segments = []
+    current = ""
+
+    # Prefer breaking at common phrase boundaries
+    phrase_boundaries = r'[,、;:]'  # Common CJK phrase separators
+
+    i = 0
+    while i < len(text):
+        char = text[i]
+
+        # Check whether adding this character would exceed the maximum length
+        if len(current + char) > max_len:
+            if current:
+                segments.append(current)
+                current = char
+            else:
+                # A single character exceeds the maximum length; add it anyway
+                segments.append(char)
+                current = ""
+        else:
+            current += char
+
+            # Break at a phrase boundary when close to the maximum length
+            if (len(current) > max_len * 0.8 and
+                re.search(phrase_boundaries, char) and
+                i < len(text) - 1):
+                segments.append(current)
+                current = ""
+
+        i += 1
+
+    # Add the remaining part
+    if current:
+        segments.append(current)
+
+    return segments
+
+
+def _split_spaced_text(text: str, max_len: int) -> List[str]:
+    """
+    Split space-separated text at word boundaries
+
+    Parameters:
+        text: Text to split
+        max_len: Maximum length of each segment
+
+    Returns:
+        List of text segments
+    """
+    segments = []
+    words = text.split()
+    current = ""
+
+    for word in words:
+        # If the word itself exceeds the maximum length, keep it intact as its own segment
+        if len(word) > max_len:
+            if current:
+                segments.append(current)
+            segments.append(word)
+            current = ""
+            continue
+
+        # Check whether adding this word would exceed the maximum length
+        potential_segment = (current + " " + word).strip() if current else word
+        if len(potential_segment) > max_len:
+            if current:
+                segments.append(current)
+            current = word
+        else:
+            current = potential_segment
+
+    # Add the remaining part
+    if current:
+        segments.append(current)
+
+    return segments
+
+
+def merge_short_sentences(sentences: List[str], max_length: int, min_length: int = 20, lang: Optional[str] = None) -> List[str]:
+    """
+    Merge short sentences into the following sentence without exceeding the maximum length limit.
+    Supports multiple languages.
+
+    Parameters:
+        sentences: List of sentences
+        max_length: Maximum length after merging
+        min_length: Length threshold below which a sentence counts as short
+        lang: Language code (auto-detected if None)
+
+    Returns:
+        List of merged sentences
+    """
+    if not sentences:
+        return []
+
+    # Auto-detect the language from the first non-empty sentence
+    if lang is None:
+        for sentence in sentences:
+            if sentence.strip():
+                lang = detect_language(sentence)
+                break
+        else:
+            lang = 'en'
+
+    # Adjust min_length per language; CJK languages are more information-dense
+    if lang in ['zh', 'ja', 'ko']:
+        min_length = max(10, min_length // 2)
+
+    result = []
+    i = 0
+
+    while i < len(sentences):
+        current = sentences[i].strip()
+
+        # Skip empty sentences
+        if not current:
+            i += 1
+            continue
+
+        # If the current sentence is already long enough, add it directly
+        if len(current) >= min_length:
+            result.append(current)
+            i += 1
+            continue
+
+        # The current sentence is short; try to merge it with the following sentences
+        merged = current
+        j = i + 1
+        while j < len(sentences) and len(merged) < min_length:
+            next_sentence = sentences[j].strip()
+            if not next_sentence:
+                j += 1
+                continue
+
+            # Stop if merging would exceed the maximum length
+            separator = "" if lang in ['zh', 'ja', 'ko'] else " "
+            potential_merge = merged + separator + next_sentence
+            if len(potential_merge) <= max_length:
+                merged = potential_merge
+                j += 1
+            else:
+                break
+
+        # Add the merged result
+        if merged:
+            result.append(merged)
+
+        # Advance past everything that was merged (j is always > i here)
+        i = j
+
+    return result
+
+
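+# Illustrative behaviour (assumed English input, so min_length stays at 20):
+#   merge_short_sentences(["Hi.", "It is a lovely day today.", "Shall we go?"],
+#                         max_length=60, min_length=20)
+#   -> ["Hi. It is a lovely day today.", "Shall we go?"]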
+def split_text_into_segments(text: str, max_length: int = 300, logger: Optional[logging.Logger] = None) -> List[str]:
+    """
+    Split text into segments suitable for TTS processing.
+    Supports multiple languages, including English, Chinese, Japanese, and Korean.
+
+    Parameters:
+        text: Text to split
+        max_length: Maximum character count for each segment
+        logger: Optional logger
+
+    Returns:
+        List of split segments
+    """
+    if not text or not text.strip():
+        return []
+
+    text = text.strip()
+
+    # Auto-detect the language
+    lang = detect_language(text)
+
+    # If the text does not exceed the maximum length, return it directly
+    if len(text) <= max_length:
+        return [text]
+
+    if logger:
+        logger.debug(f"Text length {len(text)} exceeds maximum limit {max_length}, starting split (language: {lang})")
+
+    # First split into paragraphs (double newlines)
+    paragraphs = [p.strip() for p in re.split(r'\n\s*\n', text) if p.strip()]
+
+    segments = []
+    for paragraph in paragraphs:
+        if len(paragraph) <= max_length:
+            segments.append(paragraph)
+            continue
+
+        # The paragraph is too long; split it further
+        # 1. Try to split into sentences using language-appropriate patterns
+        sentence_pattern = get_sentence_separators(lang)
+        sentences = re.split(sentence_pattern, paragraph)
+        sentences = [s.strip() for s in sentences if s.strip()]
+
+        # 2. Merge short sentences to avoid producing too many fragments;
+        #    adjust min_length to the language (CJK text is denser)
+        min_length = 30 if lang == 'en' else 15
+        sentences = merge_short_sentences(sentences, max_length, min_length=min_length, lang=lang)
+
+        # 3. Handle sentences that are still too long
+        for sentence in sentences:
+            if len(sentence) <= max_length:
+                segments.append(sentence)
+            else:
+                # Use language-aware word-boundary splitting
+                word_segments = split_by_word_boundary(sentence, max_length, lang=lang)
+                segments.extend(word_segments)
+
+    # Drop empty segments
+    segments = [s.strip() for s in segments if s.strip()]
+
+    if logger:
+        logger.debug(f"Text splitting completed, total {len(segments)} segments (language: {lang})")
+
+    return segments
\ No newline at end of file
diff --git a/src/chatterbox/tts.py b/src/chatterbox/tts.py
index 6d9b5ad5..b629fa2d 100644
--- a/src/chatterbox/tts.py
+++ b/src/chatterbox/tts.py
@@ -1,5 +1,12 @@
 from dataclasses import dataclass
 from pathlib import Path
+import tempfile
+import subprocess
+import os
+import threading
+import queue
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
 
 import librosa
 import torch
@@ -7,6 +14,9 @@
 import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
+import re
+import numpy as np
+import torchaudio
 
 from .models.t3 import T3
 from .models.s3tokenizer import S3_SR, drop_invalid_tokens
@@ -14,6 +24,7 @@
 from .models.tokenizers import EnTokenizer
 from .models.voice_encoder import VoiceEncoder
 from .models.t3.modules.cond_enc import T3Cond
+from .text_utils import split_text_into_segments
 
 
 REPO_ID = "ResembleAI/chatterbox"
@@ -215,6 +226,12 @@ def generate(
         exaggeration=0.5,
         cfg_weight=0.5,
         temperature=0.8,
+        use_auto_editor=False,
+        ae_threshold=0.06,
+        ae_margin=0.2,
+        disable_watermark=False,
+        max_segment_length=300,
+        max_workers=None,  # None selects the default worker count (currently 3)
     ):
         if audio_prompt_path:
             self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
@@ -230,6 +247,290 @@ def generate(
             emotion_adv=exaggeration * torch.ones(1, 1, 1),
         ).to(device=self.device)
 
+        # Long-text handling: decide from the length whether the text must be split
+        if len(text) > max_segment_length:
+            # Use async batch generation
+            if max_workers is None:
+                max_workers = 3  # default worker count
+            return self._generate_long_text_async(
+                text,
+                max_segment_length=max_segment_length,
+                cfg_weight=cfg_weight,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                min_p=min_p,
+                top_p=top_p,
+                use_auto_editor=use_auto_editor,
+                ae_threshold=ae_threshold,
+                ae_margin=ae_margin,
+                disable_watermark=disable_watermark,
+                max_workers=max_workers
+            )
+
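+        # Illustrative pause-tag parse for the short-text path (assumed input):
+        #   parse_pause_tags("Hi[pause:0.5s]there")
+        #   -> [("Hi", 0.0), ("", 0.5), ("there", 0.0)]
+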
+        # Parse pause tags BEFORE applying punc_norm, so the tags are preserved
+        segments = parse_pause_tags(text)
+
+        # Single-segment processing (simplified path for text without pauses)
+        if len(segments) == 1 and segments[0][1] == 0.0:  # single text, no pause
+            text_segment = segments[0][0]
+            segment_audio = self._generate_single_segment(
+                text_segment, cfg_weight, temperature, repetition_penalty, min_p, top_p, disable_watermark
+            )
+
+            # Clean artifacts (if enabled)
+            if use_auto_editor:
+                # Save a temporary audio file
+                with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
+                    temp_audio_path = temp_file.name
+                    torchaudio.save(temp_audio_path, segment_audio, self.sr)
+
+                # Clean artifacts
+                cleaned_audio_path = self._clean_artifacts(temp_audio_path, ae_threshold, ae_margin)
+
+                # Load the cleaned audio
+                if cleaned_audio_path != temp_audio_path:
+                    try:
+                        cleaned_audio, _ = torchaudio.load(cleaned_audio_path)
+                        # Remove the temporary files
+                        if os.path.exists(temp_audio_path):
+                            os.unlink(temp_audio_path)
+                        if os.path.exists(cleaned_audio_path):
+                            os.unlink(cleaned_audio_path)
+                        return cleaned_audio
+                    except Exception as e:
+                        print(f"[WARNING] Unable to load cleaned audio: {e}")
+                        # Remove the temporary files
+                        if os.path.exists(temp_audio_path):
+                            os.unlink(temp_audio_path)
+                        if os.path.exists(cleaned_audio_path):
+                            os.unlink(cleaned_audio_path)
+                else:
+                    # Cleaning failed; fall back to the original audio
+                    if os.path.exists(temp_audio_path):
+                        os.unlink(temp_audio_path)
+
+            return segment_audio
+
+        # Text with pauses: generate and clean each segment first, then insert pauses
+        audio_segments = []
+        temp_files_to_cleanup = []
+
+        try:
+            for text_segment, pause_duration in segments:
+                if text_segment.strip():  # non-empty text segment
+                    # 1. Generate the audio segment
+                    segment_audio = self._generate_single_segment(
+                        text_segment, cfg_weight, temperature, repetition_penalty, min_p, top_p, disable_watermark
+                    )
+
+                    # 2. Clean artifacts for this segment
+                    if use_auto_editor:
+                        # Save a temporary audio file
+                        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
+                            temp_audio_path = temp_file.name
+                            temp_files_to_cleanup.append(temp_audio_path)
+                            torchaudio.save(temp_audio_path, segment_audio, self.sr)
+
+                        # Clean artifacts
+                        cleaned_audio_path = self._clean_artifacts(temp_audio_path, ae_threshold, ae_margin)
+                        if cleaned_audio_path != temp_audio_path:
+                            temp_files_to_cleanup.append(cleaned_audio_path)
+
+                        # Load the cleaned audio
+                        try:
+                            if cleaned_audio_path != temp_audio_path:
+                                cleaned_audio, _ = torchaudio.load(cleaned_audio_path)
+                                segment_audio = cleaned_audio
+                        except Exception as e:
+                            print(f"[WARNING] Unable to load cleaned audio segment: {e}")
+                            # Keep using the original audio segment
+
+                    audio_segments.append(segment_audio.squeeze(0))
+
+                # 3. Insert the pause (after artifact cleaning)
+                if pause_duration > 0:
+                    silence = create_silence(pause_duration, self.sr)
+                    audio_segments.append(silence.squeeze(0))
+
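+            # Pause protection: silences for [pause:xx] are appended only after
+            # cleaning, so auto-editor's silence removal cannot trim them away.
+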
+            # 4. Concatenate all audio segments
+            if audio_segments:
+                final_audio = torch.cat(audio_segments, dim=0)
+                return final_audio.unsqueeze(0)
+            else:
+                # No valid audio segments; return a brief silence
+                return create_silence(0.1, self.sr)
+
+        finally:
+            # Remove all temporary files
+            for temp_file in temp_files_to_cleanup:
+                if os.path.exists(temp_file):
+                    try:
+                        os.unlink(temp_file)
+                    except OSError:
+                        pass  # Ignore cleanup errors
+
+    def _generate_long_text_async(
+        self,
+        text,
+        max_segment_length=300,
+        cfg_weight=0.5,
+        temperature=0.8,
+        repetition_penalty=1.2,
+        min_p=0.05,
+        top_p=1.0,
+        use_auto_editor=False,
+        ae_threshold=0.06,
+        ae_margin=0.2,
+        disable_watermark=False,
+        max_workers=3
+    ):
+        """
+        Asynchronous generation for long text; the core optimization is
+        concurrent batch generation of the individual segments.
+        """
+        # Split the text into short segments
+        text_segments = split_text_into_segments(text, max_segment_length)
+
+        if not text_segments:
+            return create_silence(0.1, self.sr)
+
+        # Preprocess the segments, preserving pause-tag information
+        all_segments = []  # [(text, pause_duration), ...]
+        for segment_text in text_segments:
+            all_segments.extend(parse_pause_tags(segment_text))
+
+        # Separate text parts from pause information. parse_pause_tags emits
+        # pauses as standalone ("", duration) entries, so attach each pause
+        # to the preceding text part.
+        text_parts = []
+        pause_info = []
+        for text_part, pause_duration in all_segments:
+            if text_part.strip():
+                text_parts.append(text_part.strip())
+                pause_info.append(pause_duration)
+            elif pause_info and pause_duration > 0:
+                pause_info[-1] += pause_duration
+
+        # Core step: asynchronous batch generation
+        audio_segments = self._generate_segments_async(
+            text_parts, cfg_weight, temperature,
+            repetition_penalty, min_p, top_p, disable_watermark, max_workers
+        )
+
+        # Apply audio cleaning (if enabled)
+        if use_auto_editor:
+            audio_segments = self._clean_audio_segments_batch(
+                audio_segments, ae_threshold, ae_margin
+            )
+
+        # Merge the audio, including pauses
+        final_audio_parts = []
+        for i, audio in enumerate(audio_segments):
+            if audio is not None:
+                final_audio_parts.append(audio.squeeze(0))
+                # Insert the pause (if any)
+                if i < len(pause_info) and pause_info[i] > 0:
+                    silence = create_silence(pause_info[i], self.sr)
+                    final_audio_parts.append(silence.squeeze(0))
+
+        if final_audio_parts:
+            final_audio = torch.cat(final_audio_parts, dim=0)
+            return final_audio.unsqueeze(0)
+        else:
+            return create_silence(0.1, self.sr)
+
+    def _generate_segments_async(self, text_list, cfg_weight, temperature, repetition_penalty, min_p, top_p, disable_watermark, max_workers):
+        """Concurrent generation of text segments - the core performance optimization"""
+        audio_results = [None] * len(text_list)
+        result_queue = queue.Queue()
+
+        def generate_worker(index, text):
+            try:
+                if text and text.strip():
+                    # Preprocess the text
+                    text = punc_norm(text.strip())
+                    # Generate the audio
+                    audio = self._generate_single_segment(
+                        text, cfg_weight, temperature, repetition_penalty, min_p, top_p, disable_watermark
+                    )
+                else:
+                    audio = create_silence(0.1, self.sr)
+                result_queue.put((index, audio, None))
+            except Exception as e:
+                result_queue.put((index, None, str(e)))
+
+        # Run the workers concurrently
+        with ThreadPoolExecutor(max_workers=min(max_workers, len(text_list))) as executor:
+            futures = [executor.submit(generate_worker, i, text) for i, text in enumerate(text_list)]
+
+            # Wait for completion
+            for future in as_completed(futures):
+                future.result()
+
+        # Collect the results (indexed assignment restores the original order)
+        while not result_queue.empty():
+            index, audio, error = result_queue.get()
+            if error:
+                audio_results[index] = create_silence(0.1, self.sr)
+            else:
+                audio_results[index] = audio
+
+        return audio_results
+
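+    # Note: the workers above share a single model instance across threads;
+    # this assumes the underlying generation call can be invoked concurrently.
+    # A failed segment degrades to 0.1 s of silence instead of aborting the batch.
+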
+    def _clean_audio_segments_batch(self, audio_segments, ae_threshold, ae_margin):
+        """Batch-clean artifacts from audio segments"""
+        cleaned_segments = []
+        temp_files_to_cleanup = []
+
+        try:
+            for i, audio in enumerate(audio_segments):
+                if audio is not None:
+                    # Save a temporary audio file
+                    with tempfile.NamedTemporaryFile(suffix=f'_segment_{i}.wav', delete=False) as temp_file:
+                        temp_audio_path = temp_file.name
+                        temp_files_to_cleanup.append(temp_audio_path)
+                        torchaudio.save(temp_audio_path, audio, self.sr)
+
+                    # Clean artifacts
+                    cleaned_audio_path = self._clean_artifacts(temp_audio_path, ae_threshold, ae_margin)
+                    if cleaned_audio_path != temp_audio_path:
+                        temp_files_to_cleanup.append(cleaned_audio_path)
+
+                    # Load the cleaned audio
+                    try:
+                        if cleaned_audio_path != temp_audio_path:
+                            cleaned_audio, _ = torchaudio.load(cleaned_audio_path)
+                            cleaned_segments.append(cleaned_audio)
+                        else:
+                            # Cleaning failed; keep the original audio
+                            cleaned_segments.append(audio)
+                    except Exception as e:
+                        print(f"[WARNING] Unable to load cleaned audio segment {i}: {e}")
+                        cleaned_segments.append(audio)
+                else:
+                    cleaned_segments.append(audio)
+
+            return cleaned_segments
+
+        finally:
+            # Remove all temporary files
+            for temp_file in temp_files_to_cleanup:
+                if os.path.exists(temp_file):
+                    try:
+                        os.unlink(temp_file)
+                    except OSError:
+                        pass  # Ignore cleanup errors
+
+    def _generate_single_segment(self, text, cfg_weight, temperature, repetition_penalty=1.2, min_p=0.05, top_p=1.0, disable_watermark=False):
+        """Generate audio for a single text segment"""
         # Norm and tokenize text
         text = punc_norm(text)
         text_tokens = self.tokenizer.text_to_tokens(text).to(self.device)
@@ -268,5 +569,134 @@ def generate(
             ref_dict=self.conds.gen,
         )
         wav = wav.squeeze(0).detach().cpu().numpy()
-        watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
-        return torch.from_numpy(watermarked_wav).unsqueeze(0)
\ No newline at end of file
+
+        if not disable_watermark:
+            watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
+            return torch.from_numpy(watermarked_wav).unsqueeze(0)
+        else:
+            return torch.from_numpy(wav).unsqueeze(0)
+
+    def _clean_artifacts(self, audio_path: str, threshold: float = 0.06, margin: float = 0.2) -> str:
+        """
+        Clean artifacts from audio using auto-editor
+
+        Args:
+            audio_path: Path to the input audio file
+            threshold: Volume threshold; audio below it is treated as silence/artifacts
+            margin: Boundary protection time in seconds
+
+        Returns:
+            Path to the cleaned audio file (or the original path on failure)
+        """
+        # Create the output file
+        output_file = tempfile.NamedTemporaryFile(suffix='_cleaned.wav', delete=False)
+        output_file.close()
+
+        try:
+            # Build the auto-editor command (adapted for version 28.0.0)
+            cmd = [
+                "auto-editor",
+                audio_path,
+                "--edit", f"audio:threshold={threshold}",
+                "--margin", f"{margin}s",
+                "--output-file", output_file.name
+            ]
+
+            print(f"[INFO] Cleaning artifacts: {' '.join(cmd)}")
+
+            # Run auto-editor
+            subprocess.run(
+                cmd,
+                capture_output=True,
+                text=True,
+                timeout=60,
+                check=True
+            )
+
+            if os.path.exists(output_file.name) and os.path.getsize(output_file.name) > 0:
+                print(f"[INFO] Artifact cleaning completed: {output_file.name}")
+                return output_file.name
+            else:
+                raise RuntimeError("auto-editor did not generate a valid output file")
+
+        except subprocess.CalledProcessError as e:
+            print(f"[ERROR] auto-editor execution failed: {e}")
+            print(f"[ERROR] stderr: {e.stderr}")
+            print(f"[ERROR] stdout: {e.stdout}")
+            # Remove the failed output file
+            if os.path.exists(output_file.name):
+                os.unlink(output_file.name)
+            return audio_path  # Fall back to the original file
+
+        except Exception as e:
+            print(f"[ERROR] Exception occurred during artifact cleaning: {e}")
+            # Remove the failed output file
+            if os.path.exists(output_file.name):
+                os.unlink(output_file.name)
+            return audio_path  # Fall back to the original file
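+
+
+# Illustrative command line produced by _clean_artifacts with the default
+# parameters (paths are placeholders):
+#   auto-editor input.wav --edit audio:threshold=0.06 --margin 0.2s \
+#       --output-file /tmp/xxxx_cleaned.wav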
print(f"[ERROR] stderr: {e.stderr}") + print(f"[ERROR] stdout: {e.stdout}") + # Clean up failed output file + if os.path.exists(output_file.name): + os.unlink(output_file.name) + return audio_path # Return original file + + except Exception as e: + print(f"[ERROR] Exception occurred during artifact cleaning: {e}") + # Clean up failed output file + if os.path.exists(output_file.name): + os.unlink(output_file.name) + return audio_path # Return original file + + +def parse_pause_tags(text: str): + """ + Parse pause tags in text and return text segments with corresponding pause durations + + Args: + text: Text containing pause tags like "Hello[pause:0.5s]world[pause:1.0s]end" + + Returns: + segments: [(text_segment, pause_duration), ...] + Example: [("Hello", 0.5), ("world", 1.0), ("end", 0.0)] + """ + if not text: + return [("", 0.0)] + + # Regular expression to match pause tags + pause_pattern = r'\[pause:([\d.]+)s\]' + + segments = [] + last_end = 0 + + # Find all pause tags + for match in re.finditer(pause_pattern, text): + # Extract text before the pause tag + text_segment = text[last_end:match.start()].strip() + if text_segment: + segments.append((text_segment, 0.0)) + + # Extract pause duration + pause_duration = float(match.group(1)) + # Ensure pause duration is a multiple of 0.1s + pause_duration = round(pause_duration / 0.1) * 0.1 + segments.append(("", pause_duration)) + + last_end = match.end() + + # Add the final text segment + final_text = text[last_end:].strip() + if final_text: + segments.append((final_text, 0.0)) + + # If no segments found, return original text + if not segments: + segments = [(text, 0.0)] + + return segments + + +def create_silence(duration_seconds: float, sample_rate: int) -> torch.Tensor: + """ + Create silence of specified duration + + Args: + duration_seconds: Duration of silence in seconds + sample_rate: Sample rate + + Returns: + Silent audio tensor + """ + num_samples = int(duration_seconds * sample_rate) + return torch.zeros(1, num_samples) \ No newline at end of file