 import logging
 import json
-import re

 import torch
 from pathlib import Path
-from unicodedata import category
+from unicodedata import category, normalize
 from tokenizers import Tokenizer
 from huggingface_hub import hf_hub_download

@@ -33,7 +32,7 @@ def text_to_tokens(self, text: str):
         text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
         return text_tokens

-    def encode(self, txt: str, verbose=False):
+    def encode(self, txt: str):
         """
         clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
         """
@@ -46,8 +45,7 @@ def decode(self, seq):
         if isinstance(seq, torch.Tensor):
             seq = seq.cpu().numpy()

-        txt: str = self.tokenizer.decode(seq,
-                                         skip_special_tokens=False)
+        txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
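         # decode() joins tokens with spaces; drop them, then turn the SPACE
         # placeholder token back into a real space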
         txt = txt.replace(' ', '')
         txt = txt.replace(SPACE, ' ')
         txt = txt.replace(EOT, '')
@@ -61,6 +59,7 @@ def decode(self, seq):
 # Global instances for optional dependencies
 _kakasi = None
 _dicta = None
+_russian_stresser = None


 def is_kanji(c: str) -> bool:
@@ -281,6 +280,25 @@ def __call__(self, text):
281280 return "" .join (output )
282281
283282
+def add_russian_stress(text: str) -> str:
+    """Russian text normalization: adds stress marks to Russian text."""
+    global _russian_stresser
+
+    try:
+        if _russian_stresser is None:
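+            # import lazily so russian_text_stresser stays an optional dependency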
+            from russian_text_stresser.text_stresser import RussianTextStresser
+            _russian_stresser = RussianTextStresser()
+
+        return _russian_stresser.stress_text(text)
+
+    except ImportError:
+        logger.warning("russian_text_stresser not available - Russian stress labeling skipped")
+        return text
+    except Exception as e:
+        logger.warning(f"Russian stress labeling failed: {e}")
+        return text
+
+
 class MTLTokenizer:
     def __init__(self, vocab_file_path):
         self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
@@ -293,12 +311,26 @@ def check_vocabset_sot_eot(self):
         assert SOT in voc
         assert EOT in voc

-    def text_to_tokens(self, text: str, language_id: str = None):
-        text_tokens = self.encode(text, language_id=language_id)
+    def preprocess_text(self, raw_text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
+        """
+        Text preprocessor that handles lowercase conversion and NFKD normalization.
+        """
+        preprocessed_text = raw_text
+        if lowercase:
+            preprocessed_text = preprocessed_text.lower()
+        if nfkd_normalize:
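+            # NFKD splits precomposed characters into base letter + combining marks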
+            preprocessed_text = normalize("NFKD", preprocessed_text)
+
+        return preprocessed_text
+
+    def text_to_tokens(self, text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
+        text_tokens = self.encode(text, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
         text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
         return text_tokens

-    def encode(self, txt: str, language_id: str = None):
+    def encode(self, txt: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
+        txt = self.preprocess_text(txt, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
+
         # Language-specific text processing
         if language_id == 'zh':
             txt = self.cangjie_converter(txt)
@@ -310,6 +342,8 @@ def encode(self, txt: str, language_id: str = None):
             txt = korean_normalize(txt)
         elif language_id == 'fr':  # Author: Rouxin
             txt = decompose_french_text(txt)
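+        # stress marks guide pronunciation, so annotate Russian before tokenizing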
+        elif language_id == 'ru':
+            txt = add_russian_stress(txt)

         # Prepend language token
         if language_id:
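
A minimal usage sketch of the changed API (illustrative only: the vocab path, sample strings, and variable names below are assumptions, not part of this commit):

    # default path: lowercase + NFKD, then Russian stress annotation for 'ru'
    tok = MTLTokenizer("vocab.json")
    ids = tok.text_to_tokens("Привет, мир!", language_id="ru")

    # opt out of the new preprocessing to reproduce the old behaviour
    ids_raw = tok.text_to_tokens("Bonjour", language_id="fr", lowercase=False, nfkd_normalize=False)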