import argparse
import dataclasses
import json
+import pickle
import random
import time
from functools import cache
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
from vllm.utils import FlexibleArgumentParser, merge_async_iterators

+SAMPLING_TEMPERATURE = 0.0
+SAMPLING_TOP_P = 1.0
+

@dataclasses.dataclass
class SampleRequest:
@@ -166,6 +170,7 @@ def run_vllm(
    requests: List[SampleRequest],
    n: int,
    engine_args: EngineArgs,
+    do_profile: bool,
) -> float:
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))
@@ -180,8 +185,8 @@ def run_vllm(
        sampling_params.append(
            SamplingParams(
                n=n,
-                temperature=1.0,
-                top_p=1.0,
+                temperature=SAMPLING_TEMPERATURE,
+                top_p=SAMPLING_TOP_P,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
            ))
@@ -191,13 +196,23 @@ def run_vllm(

    use_beam_search = False

+    outputs = None
    if not use_beam_search:
        start = time.perf_counter()
-        llm.generate(prompts,
-                     sampling_params,
-                     lora_request=lora_requests,
-                     use_tqdm=True)
+        if do_profile:
+            llm.start_profile()
+        outputs = llm.generate(prompts,
+                               sampling_params,
+                               lora_request=lora_requests,
+                               use_tqdm=True)
        end = time.perf_counter()
+
+        if do_profile:
+            llm.stop_profile()
+            # Generating the profile trace takes a while.
+            print("Called llm.stop_profile() ... sleeping for 100s on the "
+                  "client side so the profile trace dump can finish.")
+            time.sleep(100)
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [request.prompt for request in requests]
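
Note on the profiling hooks: `llm.start_profile()` / `llm.stop_profile()` use vLLM's torch-profiler integration. A minimal sketch of how they are typically driven outside this benchmark, assuming a build where the trace output directory is selected via the `VLLM_TORCH_PROFILER_DIR` environment variable (the directory and model name below are placeholders):

```python
import os

# Assumption: the profiler only writes traces when this variable is set
# before the engine is constructed; the path here is just an example.
os.environ.setdefault("VLLM_TORCH_PROFILER_DIR", "/tmp/vllm_profile")

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # small model, for illustration only
llm.start_profile()
llm.generate(["Hello, world"], SamplingParams(max_tokens=8))
llm.stop_profile()  # the trace dump can take a while, hence the sleep above
```
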
@@ -214,14 +229,15 @@ def run_vllm(
                ignore_eos=True,
            ))
        end = time.perf_counter()
-    return end - start
+    return end - start, outputs


async def run_vllm_async(
    requests: List[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
+    do_profile: bool = False,
) -> float:
    from vllm import SamplingParams

@@ -239,14 +255,16 @@ async def run_vllm_async(
            sampling_params.append(
                SamplingParams(
                    n=n,
-                    temperature=1.0,
-                    top_p=1.0,
+                    temperature=SAMPLING_TEMPERATURE,
+                    top_p=SAMPLING_TOP_P,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                ))
            lora_requests.append(request.lora_request)

        generators = []
+        if do_profile:
+            await llm.start_profile()
        start = time.perf_counter()
        for i, (prompt, sp,
                lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
@@ -256,10 +274,25 @@ async def run_vllm_async(
                                     request_id=f"test{i}")
            generators.append(generator)
        all_gens = merge_async_iterators(*generators)
+        outputs_dict = {}
        async for i, res in all_gens:
-            pass
+            outputs_dict[i] = res
+
        end = time.perf_counter()
-        return end - start
+        elapsed = end - start
+
+        if do_profile:
+            await llm.stop_profile()
+            print("Called llm.stop_profile() ... sleeping for 100s on the "
+                  "client side so the profile trace dump can finish.")
+            time.sleep(100)
+
+        num_prompts = len(prompts)
+        outputs = []
+        for i in range(num_prompts):
+            outputs.append(outputs_dict[i])
+
+        return elapsed, outputs


def run_hf(
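
Because `merge_async_iterators` interleaves the per-request streams, results arrive in completion order rather than prompt order, which is why they are keyed by generator index above and rebuilt into a list afterwards. A toy sketch of that pattern, with dummy async generators standing in for the real request streams and assuming the helper yields `(index, item)` pairs as it is used here:

```python
import asyncio

from vllm.utils import merge_async_iterators


async def fake_stream(idx: int, delay: float):
    # Stand-in for the per-request generator returned by llm.generate().
    await asyncio.sleep(delay)
    yield f"result-{idx}"


async def demo():
    gens = [fake_stream(0, 0.02), fake_stream(1, 0.01)]
    by_index = {}
    async for i, item in merge_async_iterators(*gens):
        by_index[i] = item  # completion order, keyed by stream index
    return [by_index[i] for i in range(len(gens))]  # back to prompt order


print(asyncio.run(demo()))
```
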
@@ -392,16 +425,25 @@ def main(args: argparse.Namespace):
                           for request in requests)
    if args.backend == "vllm":
        if args.async_engine:
-            elapsed_time = uvloop.run(
+            elapsed_time, outputs = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    args.disable_frontend_multiprocessing,
+                    do_profile=args.profile,
                ))
        else:
-            elapsed_time = run_vllm(requests, args.n,
-                                    EngineArgs.from_cli_args(args))
+            elapsed_time, outputs = run_vllm(requests,
+                                             args.n,
+                                             EngineArgs.from_cli_args(args),
+                                             do_profile=args.profile)
+
+        if args.pickle_outputs:
+            print("Pickling request outputs:")
+            with open("outputs.pkl", "wb+") as f:
+                pickle.dump(outputs, f)
+
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
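
For reference, a sketch of loading the file written by `--pickle-outputs`; it assumes the `outputs.pkl` path used above and that the pickled objects expose the usual `RequestOutput` attributes (which may vary across vLLM versions):

```python
import pickle

with open("outputs.pkl", "rb") as f:
    outputs = pickle.load(f)

# Inspect the first few completions.
for out in outputs[:3]:
    print(out.request_id, out.outputs[0].text[:80])
```
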
@@ -491,6 +533,16 @@ def main(args: argparse.Namespace):
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.")

+    parser.add_argument("--profile",
+                        action='store_true',
+                        default=False,
+                        help="Profile the entire run")
+
+    parser.add_argument("--pickle-outputs",
+                        action="store_true",
+                        default=False,
+                        help="Pickle the request outputs produced by the benchmark")
+
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None: