Commit d3b2082

Author: Varun Sundar Rabindranath (committed)

add v1 lora kernels

1 parent f02a866, commit d3b2082

29 files changed: +32452 -362 lines changed

benchmarks/benchmark_throughput.py

Lines changed: 66 additions & 14 deletions
@@ -3,6 +3,7 @@
 import argparse
 import dataclasses
 import json
+import pickle
 import random
 import time
 from functools import cache
@@ -26,6 +27,9 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
 from vllm.utils import FlexibleArgumentParser, merge_async_iterators
 
+SAMPLING_TEMPERATURE = 0.0
+SAMPLING_TOP_P = 1.0
+
 
 @dataclasses.dataclass
 class SampleRequest:
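Side note (not part of the diff): temperature 0.0 switches vLLM to greedy decoding, so with these constants the benchmark's generations should be deterministic and comparable across runs, which also makes the pickled outputs introduced further down more useful as a reference. A minimal sketch of how the constants feed into SamplingParams, mirroring the construction used later in this file (the max_tokens value is illustrative):

    from vllm import SamplingParams

    SAMPLING_TEMPERATURE = 0.0  # greedy decoding: always take the most likely token
    SAMPLING_TOP_P = 1.0        # no nucleus truncation

    params = SamplingParams(n=1,
                            temperature=SAMPLING_TEMPERATURE,
                            top_p=SAMPLING_TOP_P,
                            ignore_eos=True,
                            max_tokens=128)  # illustrative value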
@@ -166,6 +170,7 @@ def run_vllm(
     requests: List[SampleRequest],
     n: int,
     engine_args: EngineArgs,
+    do_profile: bool,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(**dataclasses.asdict(engine_args))
@@ -180,8 +185,8 @@ def run_vllm(
         sampling_params.append(
             SamplingParams(
                 n=n,
-                temperature=1.0,
-                top_p=1.0,
+                temperature=SAMPLING_TEMPERATURE,
+                top_p=SAMPLING_TOP_P,
                 ignore_eos=True,
                 max_tokens=request.expected_output_len,
             ))
@@ -191,13 +196,23 @@ def run_vllm(
 
     use_beam_search = False
 
+    outputs = None
     if not use_beam_search:
         start = time.perf_counter()
-        llm.generate(prompts,
-                     sampling_params,
-                     lora_request=lora_requests,
-                     use_tqdm=True)
+        if do_profile:
+            llm.start_profile()
+        outputs = llm.generate(prompts,
+                               sampling_params,
+                               lora_request=lora_requests,
+                               use_tqdm=True)
         end = time.perf_counter()
+
+        if do_profile:
+            llm.stop_profile()
+            # it takes a while to generate the profile !!
+            print("Called llm.stop_profile() ... Sleeping for 100s on client "
+                  "side for profile trace dump to finish !!")
+            time.sleep(100)
     else:
         assert lora_requests is None, "BeamSearch API does not support LoRA"
         prompts = [request.prompt for request in requests]
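Context note, not part of the commit: to my understanding, llm.start_profile() / llm.stop_profile() wrap vLLM's torch-profiler hooks and only produce a trace when a profiler output directory is configured before the engine starts, typically via the VLLM_TORCH_PROFILER_DIR environment variable. A hedged sketch of driving this path (the directory path is an example, not taken from the diff):

    import os

    # Assumption: vLLM's profiler hooks read this env var at engine start-up.
    os.environ["VLLM_TORCH_PROFILER_DIR"] = "/tmp/vllm_profile"  # example path

    # run_vllm(..., do_profile=True) then brackets llm.generate() with
    # llm.start_profile() / llm.stop_profile(), and the trace is written
    # under the directory above.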
@@ -214,14 +229,15 @@ def run_vllm(
                 ignore_eos=True,
             ))
         end = time.perf_counter()
-    return end - start
+    return end - start, outputs
 
 
 async def run_vllm_async(
     requests: List[SampleRequest],
     n: int,
     engine_args: AsyncEngineArgs,
     disable_frontend_multiprocessing: bool = False,
+    do_profile: bool = False,
 ) -> float:
     from vllm import SamplingParams
 
@@ -239,14 +255,16 @@ async def run_vllm_async(
             sampling_params.append(
                 SamplingParams(
                     n=n,
-                    temperature=1.0,
-                    top_p=1.0,
+                    temperature=SAMPLING_TEMPERATURE,
+                    top_p=SAMPLING_TOP_P,
                     ignore_eos=True,
                     max_tokens=request.expected_output_len,
                 ))
             lora_requests.append(request.lora_request)
 
         generators = []
+        if do_profile:
+            await llm.start_profile()
         start = time.perf_counter()
         for i, (prompt, sp,
                 lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
@@ -256,10 +274,25 @@ async def run_vllm_async(
                                      request_id=f"test{i}")
             generators.append(generator)
         all_gens = merge_async_iterators(*generators)
+        outputs_dict = {}
         async for i, res in all_gens:
-            pass
+            outputs_dict[i] = res
+
         end = time.perf_counter()
-        return end - start
+        elapsed = end - start
+
+        if do_profile:
+            await llm.stop_profile()
+            print("Called llm.stop_profile() ... Sleeping for 100s on client"
+                  "side for profile trace dump to finish !!")
+            time.sleep(100)
+
+        num_prompts = len(prompts)
+        outputs = []
+        for i in range(num_prompts):
+            outputs.append(outputs_dict[i])
+
+        return elapsed, outputs
 
 
 def run_hf(
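merge_async_iterators (imported at the top of this file) interleaves results from all per-request generators as they complete, yielding (index, result) pairs, which is why the new code buffers into outputs_dict and then rebuilds a list in request order. A self-contained sketch of the same buffer-and-reorder pattern using plain asyncio rather than the vLLM helper (all names here are illustrative):

    import asyncio

    async def fake_request(i: int) -> tuple[int, str]:
        await asyncio.sleep(0.01 * (3 - i))  # later requests finish first
        return i, f"result-{i}"

    async def main() -> None:
        coros = [fake_request(i) for i in range(3)]
        outputs_dict = {}
        # Completion order is arbitrary, so results are keyed by request index,
        # mirroring `async for i, res in all_gens:` above.
        for fut in asyncio.as_completed(coros):
            i, res = await fut
            outputs_dict[i] = res
        # Rebuild request order, as run_vllm_async now does before returning.
        outputs = [outputs_dict[i] for i in range(len(coros))]
        print(outputs)  # ['result-0', 'result-1', 'result-2']

    asyncio.run(main())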
@@ -392,16 +425,25 @@ def main(args: argparse.Namespace):
                            for request in requests)
     if args.backend == "vllm":
         if args.async_engine:
-            elapsed_time = uvloop.run(
+            elapsed_time, outputs = uvloop.run(
                 run_vllm_async(
                     requests,
                     args.n,
                     AsyncEngineArgs.from_cli_args(args),
                     args.disable_frontend_multiprocessing,
+                    do_profile=args.profile,
                 ))
         else:
-            elapsed_time = run_vllm(requests, args.n,
-                                    EngineArgs.from_cli_args(args))
+            elapsed_time, outputs = run_vllm(requests,
+                                             args.n,
+                                             EngineArgs.from_cli_args(args),
+                                             do_profile=args.profile)
+
+        if args.pickle_outputs:
+            print("Pickling request outputs : ")
+            with open("outputs.pkl", "wb+") as f:
+                pickle.dump(outputs, f)
+
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -491,6 +533,16 @@ def main(args: argparse.Namespace):
         help="Path to the lora adapters to use. This can be an absolute path, "
         "a relative path, or a Hugging Face model identifier.")
 
+    parser.add_argument("--profile",
+                        action='store_true',
+                        default=False,
+                        help="Profile the entire run")
+
+    parser.add_argument("--pickle-outputs",
+                        action="store_true",
+                        default=False,
+                        help="Pickle outputs got from benchmark")
+
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
     if args.tokenizer is None:
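With the two new flags in place, a run might look like the hypothetical invocation below, and the file written by --pickle-outputs can be read back with the standard pickle module. Only the flag names and the outputs.pkl file name come from the diff; the model and the other arguments are placeholders:

    # Hypothetical invocation:
    #   python benchmarks/benchmark_throughput.py --backend vllm \
    #       --model <model> --num-prompts 100 --profile --pickle-outputs

    import pickle

    # Load the request outputs dumped by --pickle-outputs.
    with open("outputs.pkl", "rb") as f:
        outputs = pickle.load(f)
    print(f"loaded {len(outputs)} request outputs")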
