@@ -22,6 +22,7 @@ class RequestFuncInput:
2222 prompt_len : int
2323 output_len : int
2424 model : str
25+ model_name : str = None
2526 best_of : int = 1
2627 logprobs : Optional [int ] = None
2728 extra_body : Optional [dict ] = None
@@ -43,8 +44,8 @@ class RequestFuncOutput:
4344
4445
4546async def async_request_tgi (
46- request_func_input : RequestFuncInput ,
47- pbar : Optional [tqdm ] = None ,
47+ request_func_input : RequestFuncInput ,
48+ pbar : Optional [tqdm ] = None ,
4849) -> RequestFuncOutput :
4950 api_url = request_func_input .api_url
5051 assert api_url .endswith ("generate_stream" )
@@ -78,7 +79,7 @@ async def async_request_tgi(
7879 continue
7980 chunk_bytes = chunk_bytes .decode ("utf-8" )
8081
81- #NOTE: Sometimes TGI returns a ping response without
82+ # NOTE: Sometimes TGI returns a ping response without
8283 # any data, we should skip it.
8384 if chunk_bytes .startswith (":" ):
8485 continue
@@ -115,8 +116,8 @@ async def async_request_tgi(
115116
116117
117118async def async_request_trt_llm (
118- request_func_input : RequestFuncInput ,
119- pbar : Optional [tqdm ] = None ,
119+ request_func_input : RequestFuncInput ,
120+ pbar : Optional [tqdm ] = None ,
120121) -> RequestFuncOutput :
121122 api_url = request_func_input .api_url
122123 assert api_url .endswith ("generate_stream" )
@@ -182,8 +183,8 @@ async def async_request_trt_llm(
182183
183184
184185async def async_request_deepspeed_mii (
185- request_func_input : RequestFuncInput ,
186- pbar : Optional [tqdm ] = None ,
186+ request_func_input : RequestFuncInput ,
187+ pbar : Optional [tqdm ] = None ,
187188) -> RequestFuncOutput :
188189 async with aiohttp .ClientSession (timeout = AIOHTTP_TIMEOUT ) as session :
189190 assert request_func_input .best_of == 1
@@ -225,8 +226,8 @@ async def async_request_deepspeed_mii(
225226
226227
227228async def async_request_openai_completions (
228- request_func_input : RequestFuncInput ,
229- pbar : Optional [tqdm ] = None ,
229+ request_func_input : RequestFuncInput ,
230+ pbar : Optional [tqdm ] = None ,
230231) -> RequestFuncOutput :
231232 api_url = request_func_input .api_url
232233 assert api_url .endswith (
@@ -235,7 +236,8 @@ async def async_request_openai_completions(
235236
236237 async with aiohttp .ClientSession (timeout = AIOHTTP_TIMEOUT ) as session :
237238 payload = {
238- "model" : request_func_input .model ,
239+ "model" : request_func_input .model_name \
240+ if request_func_input .model_name else request_func_input .model ,
239241 "prompt" : request_func_input .prompt ,
240242 "temperature" : 0.0 ,
241243 "best_of" : request_func_input .best_of ,
@@ -315,8 +317,8 @@ async def async_request_openai_completions(
315317
316318
317319async def async_request_openai_chat_completions (
318- request_func_input : RequestFuncInput ,
319- pbar : Optional [tqdm ] = None ,
320+ request_func_input : RequestFuncInput ,
321+ pbar : Optional [tqdm ] = None ,
320322) -> RequestFuncOutput :
321323 api_url = request_func_input .api_url
322324 assert api_url .endswith (
@@ -328,7 +330,8 @@ async def async_request_openai_chat_completions(
328330 if request_func_input .multi_modal_content :
329331 content .append (request_func_input .multi_modal_content )
330332 payload = {
331- "model" : request_func_input .model ,
333+ "model" : request_func_input .model_name \
334+ if request_func_input .model_name else request_func_input .model ,
332335 "messages" : [
333336 {
334337 "role" : "user" ,
@@ -417,10 +420,10 @@ def get_model(pretrained_model_name_or_path: str) -> str:
417420
418421
419422def get_tokenizer (
420- pretrained_model_name_or_path : str ,
421- tokenizer_mode : str = "auto" ,
422- trust_remote_code : bool = False ,
423- ** kwargs ,
423+ pretrained_model_name_or_path : str ,
424+ tokenizer_mode : str = "auto" ,
425+ trust_remote_code : bool = False ,
426+ ** kwargs ,
424427) -> Union [PreTrainedTokenizer , PreTrainedTokenizerFast ]:
425428 if pretrained_model_name_or_path is not None and not os .path .exists (
426429 pretrained_model_name_or_path ):
0 commit comments