317 changes: 235 additions & 82 deletions vllm/benchmarks/serve.py
@@ -31,7 +31,10 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Literal
from typing import Any, Literal, Optional
from multiprocessing import Process, Queue
import math
from setproctitle import setproctitle

import aiohttp
import numpy as np
@@ -57,6 +60,7 @@
shutil.which("gnuplot") is not None
)

INT_MAX = 10**10
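# INT_MAX is the effectively-unlimited default for --max-connections-per-worker,
# so a single worker handles all requests unless the flag is set.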

class TaskType(Enum):
GENERATION = "generation"
@@ -483,6 +487,158 @@

return metrics, actual_output_lens

async def benchmark_core(
api_url: str,
model_id: str,
model_name: str,
max_concurrency: Optional[int],
input_requests: list[SampleRequest],
logprobs: Optional[int],
request_rate: float,
burstiness: float,
ignore_eos: bool,
lora_modules: Optional[Iterable[str]],
extra_headers: Optional[dict],
extra_body: Optional[dict],
request_func,
disable_tqdm: bool,
pidx: int,
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
ramp_up_start_rps: Optional[int] = None,
ramp_up_end_rps: Optional[int] = None,
output_q: Optional[Queue] = None,
pbar_q: Optional[Queue] = None,
):
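# Runs in a single worker process: issues this worker's slice of requests
# against the server and reports results via output_q and progress via pbar_q.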
setproctitle(f"vLLM Benchmark_core:{pidx:02d}")

# Create dedicated connector and session for each process
connector = aiohttp.TCPConnector(
limit=max_concurrency or 0,
limit_per_host=max_concurrency or 0,
ttl_dns_cache=300,
use_dns_cache=True,
keepalive_timeout=60,
enable_cleanup_closed=True,
force_close=False,
ssl=("https://" in api_url),
)

session = aiohttp.ClientSession(
connector=connector,
trust_env=True,
timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),
)

# This can be used once the minimum Python version is 3.10 or higher,
# and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext())
semaphore = (asyncio.Semaphore(max_concurrency)
if max_concurrency else None)

async def limited_request_func(request_func_input, session, pbar):
if semaphore is None:
ret = await request_func(request_func_input=request_func_input,
session=session,
pbar=pbar)
if not disable_tqdm:
    pbar_q.put(1)
return ret
async with semaphore:
ret = await request_func(request_func_input=request_func_input,
session=session,
pbar=pbar)
if not disable_tqdm:
    pbar_q.put(1)
return ret

tasks: list[asyncio.Task] = []

async for request, _ in get_request(
input_requests, request_rate, burstiness,
ramp_up_strategy, ramp_up_start_rps, ramp_up_end_rps
):
prompt, prompt_len, output_len, mm_content, request_id = (
request.prompt,
request.prompt_len,
request.expected_output_len,
request.multi_modal_data,
request.request_id,
)
req_model_id, req_model_name = model_id, model_name
if lora_modules:
req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module

request_func_input = RequestFuncInput(
model=req_model_id,
model_name=req_model_name,
prompt=prompt,
api_url=api_url,  # requests are created sequentially by the async for loop
prompt_len=prompt_len,
output_len=output_len,
logprobs=logprobs,
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body,
request_id=request_id,
)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input,
session=session,
pbar=None)))  # progress is reported to the parent via pbar_q
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)

# clean-up & enqueue
await session.close()
output_q.put(outputs)

def async_worker_wrapper(
api_url: str,
model_id: str,
model_name: str,
max_concurrency: Optional[int],
input_requests: list[SampleRequest],
logprobs: Optional[int],
request_rate: float,
burstiness: float,
ignore_eos: bool,
lora_modules: Optional[Iterable[str]],
extra_headers: Optional[dict],
extra_body: Optional[dict],
request_func,
disable_tqdm: bool,
pidx: int,
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
ramp_up_start_rps: Optional[int] = None,
ramp_up_end_rps: Optional[int] = None,
output_q: Optional[Queue] = None,
pbar_q: Optional[Queue] = None,
):
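# Process entry point: run the async benchmark core in a fresh event loop
# owned by this worker process.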
asyncio.run(
benchmark_core(
api_url=api_url,
model_id=model_id,
model_name=model_name,
max_concurrency=max_concurrency,
input_requests=input_requests,
logprobs=logprobs,
request_rate=request_rate,
burstiness=burstiness,
ignore_eos=ignore_eos,
lora_modules=lora_modules,
extra_headers=extra_headers,
extra_body=extra_body,
request_func=request_func,
disable_tqdm=disable_tqdm,
pidx=pidx,
ramp_up_strategy=ramp_up_strategy,
ramp_up_start_rps=ramp_up_start_rps,
ramp_up_end_rps=ramp_up_end_rps,
output_q=output_q,
pbar_q=pbar_q,
)
)

async def benchmark(
task_type: TaskType,
@@ -513,6 +669,8 @@
ready_check_timeout_sec: int = 600,
warmup_time: float = 0.0,
cooldown_time: float = 0.0,
num_workers: int = 1,
max_connections_per_worker: int = 512,
):
try:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
@@ -652,89 +810,74 @@

print(f"Burstiness factor: {burstiness} ({distribution})")
print(f"Maximum request concurrency: {max_concurrency}")
print(f"num_workers: {num_workers}")

pbar = None if disable_tqdm else tqdm(total=len(input_requests))

semaphore = (
asyncio.Semaphore(max_concurrency)
if max_concurrency
else contextlib.nullcontext()
)

async def limited_request_func(request_func_input, session, pbar):
async with semaphore:
return await request_func(
request_func_input=request_func_input, session=session, pbar=pbar
)

# Workers push one item per completed request onto pbar_q so the parent
# process can drive a single shared progress bar.
pbar_q = Queue()
benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = []

rps_change_events = []
last_int_rps = -1
if ramp_up_strategy is not None and ramp_up_start_rps is not None:
last_int_rps = ramp_up_start_rps
rps_change_events.append(
{
"rps": last_int_rps,
"timestamp": datetime.now().isoformat(),
}
)

async for request, current_request_rate in get_request(
input_requests,
request_rate,
burstiness,
ramp_up_strategy,
ramp_up_start_rps,
ramp_up_end_rps,
):
if ramp_up_strategy is not None:
current_int_rps = int(current_request_rate)
if current_int_rps > last_int_rps:
timestamp = datetime.now().isoformat()
for rps_val in range(last_int_rps + 1, current_int_rps + 1):
rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
last_int_rps = current_int_rps
prompt, prompt_len, output_len, mm_content, request_id = (
request.prompt,
request.prompt_len,
request.expected_output_len,
request.multi_modal_data,
request.request_id,
)
req_model_id, req_model_name = model_id, model_name
if lora_modules:
req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module

request_func_input = RequestFuncInput(
model=req_model_id,
model_name=req_model_name,
prompt=prompt,
api_url=api_url,
prompt_len=prompt_len,
output_len=output_len,
logprobs=logprobs,
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body,
request_id=request_id,
)
tasks.append(
asyncio.create_task(
limited_request_func(
request_func_input=request_func_input, session=session, pbar=pbar
)
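# Launch one process per worker; each worker receives a contiguous slice of
# at most max_connections_per_worker requests and a proportional share of
# the request rate, concurrency limit, and ramp-up rates.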
if num_workers > 0:
output_q = Queue()
procs = [
Process(
target=async_worker_wrapper,
kwargs={
"api_url": api_url,
"model_id": model_id,
"model_name": model_name,
"max_concurrency": max_concurrency/num_workers,
"input_requests": input_requests[
i*max_connections_per_worker:(i+1)*max_connections_per_worker],
"logprobs": logprobs,
"request_rate": request_rate/num_workers,
"burstiness": burstiness,
"ignore_eos": ignore_eos,
"lora_modules": lora_modules,
"extra_headers": extra_headers,
"extra_body": extra_body,
"request_func": request_func,
"disable_tqdm": disable_tqdm,
"pidx": i,
"ramp_up_strategy": ramp_up_strategy,
"ramp_up_start_rps":
ramp_up_start_rps/num_workers if ramp_up_start_rps is not None else ramp_up_start_rps,

"ramp_up_end_rps":
ramp_up_end_rps/num_workers if ramp_up_end_rps is not None else ramp_up_end_rps,

"output_q": output_q,
"pbar_q": pbar_q,
},
name=f"vllm_multiproc_bench_{i}",
)
)
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)

if pbar is not None:
pbar.close()
for i in range(num_workers)
]

# Start each worker process
for p in procs:
    p.start()

# Aggregate outputs from all worker processes.
# NOTE: Queue.get() is a blocking operation.
if not disable_tqdm:
_cnt = 0
num_inputs = len(input_requests)
with tqdm(total=num_inputs) as pbar:
while _cnt < num_inputs:
inc = pbar_q.get()
_cnt += 1
pbar.update(inc)

outputs: list[RequestFuncOutput] = []
num_done_procs = 0
while num_done_procs < num_workers:
outputs.extend(output_q.get())
num_done_procs += 1

# Join all worker processes
for p in procs:
    p.join()
else:
raise RuntimeError(
f"Invalid num_workers: {num_workers}. num_workers must be bigger than 0")

benchmark_duration = time.perf_counter() - benchmark_start_time
benchmark_end_time = time.perf_counter()
benchmark_duration = benchmark_end_time - benchmark_start_time

if task_type == TaskType.GENERATION:
metrics, actual_output_lens = calculate_metrics(
@@ -828,9 +971,6 @@
"errors": [output.error for output in outputs],
}

if rps_change_events:
result["rps_change_events"] = rps_change_events

def process_one_metric(
# E.g., "ttft"
metric_attribute_name: str,
@@ -960,13 +1100,13 @@
goodput_config_dict=goodput_config_dict,
is_trim=True,
)
print("{s:{c}^{n}}".format(s="Serving Benchmark Result after warmup before cooldown", n=50, c="="))

print("{:<40} {:<10}".format("Warm-up Time:", warmup_time))
print("{:<40} {:<10}".format("Cool-down Time:", cooldown_time))
print("{:<40} {:<10}".format("Total counted tokens at filtering:", num_counted_tokens))

print("{:<40} {:<10.2f}".format("Benchmark duration (s):", t_duration))
if isinstance(metrics, BenchmarkMetrics):
    print("{:<40} {:<10}".format(
        "Total generated tokens:", t_metrics.total_output))
if isinstance(metrics, BenchmarkMetrics):
print(
"{:<40} {:<10.2f}".format(
@@ -980,7 +1120,7 @@
"total_input_tokens": t_metrics.total_input,
"total_output_tokens": t_metrics.total_output,
"request_throughput": t_metrics.request_throughput,
"request_goodput": t_metrics.request_goodput if goodput_config_dict else None,

"output_throughput": t_metrics.output_throughput,
"total_token_throughput": t_metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
@@ -1467,13 +1607,20 @@
default=0.0,
help="Cool-down time in seconds."
)
parser.add_argument(
    "--max-connections-per-worker",
    type=int,
    default=INT_MAX,
    help="Maximum number of requests assigned to each benchmark worker "
    "process. The number of worker processes is "
    "ceil(num_prompts / max_connections_per_worker).",
)


def main(args: argparse.Namespace) -> dict[str, Any]:
return asyncio.run(main_async(args))


async def main_async(args: argparse.Namespace) -> dict[str, Any]:
setproctitle("vLLM Benchmark")
print(args)
random.seed(args.seed)
np.random.seed(args.seed)
@@ -1584,6 +1731,10 @@
sampling_params = {}
default_percentile_metrics = "e2el"

# Ensure _num_workers is at least 1.
_num_workers = math.ceil(args.num_prompts / args.max_connections_per_worker)
assert _num_workers > 0, f"Invalid number of workers: {_num_workers}"
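# e.g. 10_000 prompts with --max-connections-per-worker=2048 -> 5 workers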

extra_body = args.extra_body or {}
extra_body = {**sampling_params, **extra_body}

@@ -1621,6 +1772,8 @@
ready_check_timeout_sec=args.ready_check_timeout_sec,
warmup_time=args.warmup_time,
cooldown_time=args.cooldown_time,
num_workers=_num_workers,
max_connections_per_worker=args.max_connections_per_worker,
)

# Save config and results to json