
Commit c0187ce

ZihanJin authored and mudler committed
Multilingual v2 update (resemble-ai#295)
* multilingual v2 vocab and russian stresser update
* multilingual tokenizer fix
1 parent 4db78fe commit c0187ce

File tree

pyproject.toml
src/chatterbox/models/t3/modules/t3_config.py
src/chatterbox/models/tokenizers/tokenizer.py
src/chatterbox/mtl_tts.py

4 files changed: +48 −13 lines


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ dependencies = [
     "spacy-pkuseg",
     "pykakasi==2.3.0",
     # "gradio==5.44.1",
+    "russian-text-stresser @ git+https://github.com/Vuizur/add-stress-to-epub",
 ]

 [project.urls]
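
The new dependency is imported lazily by the tokenizer (see tokenizer.py below), so a missing install degrades gracefully instead of breaking startup. A quick way to sanity-check the install; the module, class, and method names come from this commit's tokenizer change, and the sample string is illustrative only:

try:
    from russian_text_stresser.text_stresser import RussianTextStresser
    stresser = RussianTextStresser()
    print(stresser.stress_text("Привет, мир"))  # text returned with stress marks added
except ImportError:
    print("russian-text-stresser is not installed")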

src/chatterbox/models/t3/modules/t3_config.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ def n_channels(self):

     @property
     def is_multilingual(self):
-        return self.text_tokens_dict_size == 2352
+        return self.text_tokens_dict_size == 2454

     @classmethod
     def english_only(cls):
@@ -38,4 +38,4 @@ def english_only(cls):
     @classmethod
     def multilingual(cls):
         """Create configuration for multilingual TTS model."""
-        return cls(text_tokens_dict_size=2352)
+        return cls(text_tokens_dict_size=2454)
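
is_multilingual keys off the exact vocab size, so the bump from 2352 to 2454 (102 new v2 tokens) means a config sized for the v1 vocab no longer reports as multilingual. A minimal sketch of the resulting behavior, using only the constructor shown above:

cfg = T3Config.multilingual()
assert cfg.text_tokens_dict_size == 2454
assert cfg.is_multilingual                   # matches the new v2 vocab size

old = T3Config(text_tokens_dict_size=2352)   # v1 vocab size
assert not old.is_multilingual               # v1-sized configs no longer pass the check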

src/chatterbox/models/tokenizers/tokenizer.py

Lines changed: 42 additions & 8 deletions
@@ -1,10 +1,9 @@
 import logging
 import json
-import re

 import torch
 from pathlib import Path
-from unicodedata import category
+from unicodedata import category, normalize
 from tokenizers import Tokenizer
 from huggingface_hub import hf_hub_download

@@ -33,7 +32,7 @@ def text_to_tokens(self, text: str):
         text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
         return text_tokens

-    def encode( self, txt: str, verbose=False):
+    def encode(self, txt: str):
         """
         clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
         """
@@ -46,8 +45,7 @@ def decode(self, seq):
         if isinstance(seq, torch.Tensor):
             seq = seq.cpu().numpy()

-        txt: str = self.tokenizer.decode(seq,
-                                         skip_special_tokens=False)
+        txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
         txt = txt.replace(' ', '')
         txt = txt.replace(SPACE, ' ')
         txt = txt.replace(EOT, '')
@@ -61,6 +59,7 @@ def decode(self, seq):
 # Global instances for optional dependencies
 _kakasi = None
 _dicta = None
+_russian_stresser = None


 def is_kanji(c: str) -> bool:
@@ -281,6 +280,25 @@ def __call__(self, text):
         return "".join(output)


+def add_russian_stress(text: str) -> str:
+    """Russian text normalization: adds stress marks to Russian text."""
+    global _russian_stresser
+
+    try:
+        if _russian_stresser is None:
+            from russian_text_stresser.text_stresser import RussianTextStresser
+            _russian_stresser = RussianTextStresser()
+
+        return _russian_stresser.stress_text(text)
+
+    except ImportError:
+        logger.warning("russian_text_stresser not available - Russian stress labeling skipped")
+        return text
+    except Exception as e:
+        logger.warning(f"Russian stress labeling failed: {e}")
+        return text
+
+
 class MTLTokenizer:
     def __init__(self, vocab_file_path):
         self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
@@ -293,12 +311,26 @@ def check_vocabset_sot_eot(self):
         assert SOT in voc
         assert EOT in voc

-    def text_to_tokens(self, text: str, language_id: str = None):
-        text_tokens = self.encode(text, language_id=language_id)
+    def preprocess_text(self, raw_text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
+        """
+        Text preprocessor that handles lowercase conversion and NFKD normalization.
+        """
+        preprocessed_text = raw_text
+        if lowercase:
+            preprocessed_text = preprocessed_text.lower()
+        if nfkd_normalize:
+            preprocessed_text = normalize("NFKD", preprocessed_text)
+
+        return preprocessed_text
+
+    def text_to_tokens(self, text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
+        text_tokens = self.encode(text, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
         text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
         return text_tokens

-    def encode(self, txt: str, language_id: str = None):
+    def encode(self, txt: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
+        txt = self.preprocess_text(txt, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
+
         # Language-specific text processing
         if language_id == 'zh':
             txt = self.cangjie_converter(txt)
@@ -310,6 +342,8 @@ def encode(self, txt: str, language_id: str = None):
             txt = korean_normalize(txt)
         elif language_id == 'fr': # Author: Rouxin
             txt = decompose_french_text(txt)
+        elif language_id == 'ru':
+            txt = add_russian_stress(txt)

         # Prepend language token
         if language_id:
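
The new preprocess_text step is plain stdlib behavior: lowercasing plus NFKD normalization, which decomposes precomposed characters into a base letter followed by combining marks so the grapheme vocab sees one canonical form. A minimal sketch of the effect (sample strings illustrative; note that in encode() the Russian stresser runs after this normalization step):

from unicodedata import normalize

print(normalize("NFKD", "Café".lower()))              # 'cafe\u0301' - 'é' split into 'e' + U+0301
print([hex(ord(c)) for c in normalize("NFKD", "é")])  # ['0x65', '0x301']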

src/chatterbox/mtl_tts.py

Lines changed: 3 additions & 3 deletions
@@ -168,7 +168,7 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':
         ve.to(device).eval()

         t3 = T3(T3Config.multilingual())
-        t3_state = load_safetensors(ckpt_dir / "t3_23lang.safetensors")
+        t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
         if "model" in t3_state.keys():
             t3_state = t3_state["model"][0]
         t3.load_state_dict(t3_state)
@@ -181,7 +181,7 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':
         s3gen.to(device).eval()

         tokenizer = MTLTokenizer(
-            str(ckpt_dir / "mtl_tokenizer.json")
+            str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
         )

         conds = None
@@ -197,7 +197,7 @@ def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS':
             repo_id=REPO_ID,
             repo_type="model",
             revision="main",
-            allow_patterns=["ve.pt", "t3_23lang.safetensors", "s3gen.pt", "mtl_tokenizer.json", "conds.pt", "Cangjie5_TC.json"],
+            allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
             token=os.getenv("HF_TOKEN"),
         )
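
Both loaders now resolve the v2 artifact names, so existing callers pick up the new checkpoint and tokenizer vocab without code changes. A minimal usage sketch; from_pretrained and its device argument come from the diff above, while the synthesis call is hypothetical and shown only to indicate intent:

import torch
from chatterbox.mtl_tts import ChatterboxMultilingualTTS

device = "cuda" if torch.cuda.is_available() else "cpu"

# Downloads ve.pt, t3_mtl23ls_v2.safetensors, s3gen.pt,
# grapheme_mtl_merged_expanded_v1.json, conds.pt, and Cangjie5_TC.json.
model = ChatterboxMultilingualTTS.from_pretrained(device)

# wav = model.generate("Привет, мир!", language_id="ru")  # hypothetical call; see the repo README for the real API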
