 from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
 from transformers import AutoProcessor, MusicgenForConditionalGeneration
 from scipy.io import wavfile
-
+import outetts
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 
@@ -87,6 +87,7 @@ def LoadModel(self, request, context):
 
         self.CUDA = torch.cuda.is_available()
         self.OV = False
+        self.is_outetts = False
 
         device_map = "cpu"
 
@@ -195,7 +196,45 @@ def LoadModel(self, request, context):
                 self.OV = True
             elif request.Type == "MusicgenForConditionalGeneration":
                 self.processor = AutoProcessor.from_pretrained(model_name)
-                self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
+                self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
+            elif request.Type == "OuteTTS":
+                options = request.Options
+                MODELNAME = "OuteAI/OuteTTS-0.3-1B"
+                TOKENIZER = "OuteAI/OuteTTS-0.3-1B"
+                VERSION = "0.3"
+                SPEAKER = "en_male_1"
+                for opt in options:
+                    if opt.startswith("tokenizer:"):
+                        TOKENIZER = opt.split(":")[1]
+                        continue
+                    if opt.startswith("version:"):
+                        VERSION = opt.split(":")[1]
+                        continue
+                    if opt.startswith("speaker:"):
+                        SPEAKER = opt.split(":")[1]
+                        continue
+
+                if model_name != "":
+                    MODELNAME = model_name
+
+                # Configure the model
+                model_config = outetts.HFModelConfig_v2(
+                    model_path=MODELNAME,
+                    tokenizer_path=TOKENIZER
+                )
+                # Initialize the interface
+                self.interface = outetts.InterfaceHF(model_version=VERSION, cfg=model_config)
+                self.is_outetts = True
+
+                self.interface.print_default_speakers()
+                if request.AudioPath:
+                    if os.path.isabs(request.AudioPath):
+                        self.AudioPath = request.AudioPath
+                    else:
+                        self.AudioPath = os.path.join(request.ModelPath, request.AudioPath)
+                    self.speaker = self.interface.create_speaker(audio_path=self.AudioPath)
+                else:
+                    self.speaker = self.interface.load_default_speaker(name=SPEAKER)
             else:
                 print("Automodel", file=sys.stderr)
                 self.model = AutoModel.from_pretrained(model_name,
@@ -206,7 +245,7 @@ def LoadModel(self, request, context):
                                                        torch_dtype=compute)
             if request.ContextSize > 0:
                 self.max_tokens = request.ContextSize
-            elif request.Type != "MusicgenForConditionalGeneration":
+            elif hasattr(self, 'model') and hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'):
                 self.max_tokens = self.model.config.max_position_embeddings
             else:
                 self.max_tokens = 512
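
For reference, a minimal standalone sketch of what this new `OuteTTS` branch ends up doing once the options are parsed, using only the `outetts` calls that already appear in the diff (`HFModelConfig_v2`, `InterfaceHF`, `load_default_speaker`). The concrete option values and the script framing are illustrative assumptions, not part of the PR:

```python
import outetts

# Illustrative defaults; in the backend these come from request.Model and from
# request.Options entries of the form "tokenizer:<path>", "version:<v>", "speaker:<name>".
MODELNAME = "OuteAI/OuteTTS-0.3-1B"
TOKENIZER = "OuteAI/OuteTTS-0.3-1B"
VERSION = "0.3"
SPEAKER = "en_male_1"

# Same initialization sequence as the LoadModel branch above.
model_config = outetts.HFModelConfig_v2(
    model_path=MODELNAME,
    tokenizer_path=TOKENIZER,
)
interface = outetts.InterfaceHF(model_version=VERSION, cfg=model_config)
speaker = interface.load_default_speaker(name=SPEAKER)
```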
@@ -445,9 +484,30 @@ def SoundGeneration(self, request, context):
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         return backend_pb2.Result(success=True)
 
+    def OuteTTS(self, request, context):
+        try:
+            print("[OuteTTS] generating TTS", file=sys.stderr)
+            gen_cfg = outetts.GenerationConfig(
+                text=request.text,
+                temperature=0.1,
+                repetition_penalty=1.1,
+                max_length=self.max_tokens,
+                speaker=self.speaker,
+                # voice_characteristics="upbeat enthusiasm, friendliness, clarity, professionalism, and trustworthiness"
+            )
+            output = self.interface.generate(config=gen_cfg)
+            print("[OuteTTS] Generated TTS", file=sys.stderr)
+            output.save(request.dst)
+            print("[OuteTTS] TTS done", file=sys.stderr)
+        except Exception as err:
+            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+        return backend_pb2.Result(success=True)
 
     # The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
     def TTS(self, request, context):
+        if self.is_outetts:
+            return self.OuteTTS(request, context)
+
         model_name = request.model
         try:
             if self.processor is None:
@@ -463,7 +523,7 @@ def TTS(self, request, context):
                 padding=True,
                 return_tensors="pt",
             )
-            tokens = 512 # No good place to set the "length" in TTS, so use 10s as a sane default
+            tokens = self.max_tokens # No good place to set the "length" in TTS, so use the model's max context size
             audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
             print("[transformers-musicgen] TTS generated!", file=sys.stderr)
             sampling_rate = self.model.config.audio_encoder.sampling_rate
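
As a rough usage sketch of the generation path the new `OuteTTS` helper wraps, continuing the snippet above (the `interface` and `speaker` objects are assumed to have been created as in the LoadModel branch; the text, length, and output file are illustrative stand-ins for `request.text`, `self.max_tokens`, and `request.dst`):

```python
gen_cfg = outetts.GenerationConfig(
    text="Hello from the transformers backend.",  # the backend passes request.text here
    temperature=0.1,
    repetition_penalty=1.1,
    max_length=4096,                              # the backend passes self.max_tokens here
    speaker=speaker,
)
output = interface.generate(config=gen_cfg)
output.save("output.wav")                         # the backend writes to request.dst
```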