From 2b8e763f8f47fc239feacb542e5fe664fe6058eb Mon Sep 17 00:00:00 2001
From: Jimin Park
Date: Mon, 8 Dec 2025 04:32:02 +0000
Subject: [PATCH] Support multi-process benchmark

---
 vllm/benchmarks/serve.py | 317 +++++++++++++++++++++++++++++----------
 1 file changed, 235 insertions(+), 82 deletions(-)

diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py
index e8704fe23c75..f54ddee52fb6 100644
--- a/vllm/benchmarks/serve.py
+++ b/vllm/benchmarks/serve.py
@@ -31,7 +31,10 @@
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Any, Literal
+from typing import Any, Literal, Optional
+from multiprocessing import Process, Queue
+import math
+from setproctitle import setproctitle
 
 import aiohttp
 import numpy as np
@@ -57,6 +60,7 @@
     shutil.which("gnuplot") is not None
 )
 
+INT_MAX = 10**10
 
 class TaskType(Enum):
     GENERATION = "generation"
@@ -483,6 +487,158 @@ def calculate_metrics(
     return metrics, actual_output_lens
 
 
+async def benchmark_core(
+    api_url: str,
+    model_id: str,
+    model_name: str,
+    max_concurrency: Optional[int],
+    input_requests: list[SampleRequest],
+    logprobs: Optional[int],
+    request_rate: float,
+    burstiness: float,
+    ignore_eos: bool,
+    lora_modules: Optional[Iterable[str]],
+    extra_headers: Optional[dict],
+    extra_body: Optional[dict],
+    request_func,
+    disable_tqdm: bool,
+    pidx: int,
+    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
+    ramp_up_start_rps: Optional[int] = None,
+    ramp_up_end_rps: Optional[int] = None,
+    output_q: Optional[Queue] = None,
+    pbar_q: Optional[Queue] = None,
+):
+    setproctitle(f"vLLM Benchmark_core:{pidx:02d}")
+
+    # Create dedicated connector and session for each process
+    connector = aiohttp.TCPConnector(
+        limit=max_concurrency or 0,
+        limit_per_host=max_concurrency or 0,
+        ttl_dns_cache=300,
+        use_dns_cache=True,
+        keepalive_timeout=60,
+        enable_cleanup_closed=True,
+        force_close=False,
+        ssl=("https://" in api_url),
+    )
+
+    session = aiohttp.ClientSession(
+        connector=connector,
+        trust_env=True,
+        timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),
+    )
+
+    # This can be used once the minimum Python version is 3.10 or higher,
+    # and it will simplify the code in limited_request_func.
+    # semaphore = (asyncio.Semaphore(max_concurrency)
+    #              if max_concurrency else contextlib.nullcontext())
+    semaphore = (asyncio.Semaphore(max_concurrency)
+                 if max_concurrency else None)
+
+    async def limited_request_func(request_func_input, session, pbar):
+        if semaphore is None:
+            ret = await request_func(request_func_input=request_func_input,
+                                     session=session,
+                                     pbar=pbar)
+            if not disable_tqdm:
+                pbar_q.put(1)
+            return ret
+        async with semaphore:
+            ret = await request_func(request_func_input=request_func_input,
+                                     session=session,
+                                     pbar=pbar)
+            if not disable_tqdm:
+                pbar_q.put(1)
+            return ret
+
+    tasks: list[asyncio.Task] = []
+
+    # Requests are drawn from get_request() one at a time (paced by the
+    # request rate), but each request runs concurrently as its own task.
+    async for request, _ in get_request(
+        input_requests, request_rate, burstiness,
+        ramp_up_strategy, ramp_up_start_rps, ramp_up_end_rps
+    ):
+        prompt, prompt_len, output_len, mm_content, request_id = (
+            request.prompt,
+            request.prompt_len,
+            request.expected_output_len,
+            request.multi_modal_data,
+            request.request_id,
+        )
+        req_model_id, req_model_name = model_id, model_name
+        if lora_modules:
+            req_lora_module = next(lora_modules)
+            req_model_id, req_model_name = req_lora_module, req_lora_module
+
+        request_func_input = RequestFuncInput(
+            model=req_model_id,
+            model_name=req_model_name,
+            prompt=prompt,
+            api_url=api_url,
+            prompt_len=prompt_len,
+            output_len=output_len,
+            logprobs=logprobs,
+            multi_modal_content=mm_content,
+            ignore_eos=ignore_eos,
+            extra_headers=extra_headers,
+            extra_body=extra_body,
+            request_id=request_id,
+        )
+        tasks.append(
+            asyncio.create_task(
+                limited_request_func(request_func_input=request_func_input,
+                                     session=session,
+                                     pbar=None)))  # progress is reported via pbar_q
+    outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
+
+    # Clean up and hand the results back to the parent process
+    await session.close()
+    output_q.put(outputs)
+
+
+def async_worker_wrapper(
+    api_url: str,
+    model_id: str,
+    model_name: str,
+    max_concurrency: Optional[int],
+    input_requests: list[SampleRequest],
+    logprobs: Optional[int],
+    request_rate: float,
+    burstiness: float,
+    ignore_eos: bool,
+    lora_modules: Optional[Iterable[str]],
+    extra_headers: Optional[dict],
+    extra_body: Optional[dict],
+    request_func,
+    disable_tqdm: bool,
+    pidx: int,
+    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
+    ramp_up_start_rps: Optional[int] = None,
+    ramp_up_end_rps: Optional[int] = None,
+    output_q: Optional[Queue] = None,
+    pbar_q: Optional[Queue] = None,
+):
+    # Entry point of each worker process: run benchmark_core in a fresh event loop.
+    asyncio.run(
+        benchmark_core(
+            api_url=api_url,
+            model_id=model_id,
+            model_name=model_name,
+            max_concurrency=max_concurrency,
+            input_requests=input_requests,
+            logprobs=logprobs,
+            request_rate=request_rate,
+            burstiness=burstiness,
+            ignore_eos=ignore_eos,
+            lora_modules=lora_modules,
+            extra_headers=extra_headers,
+            extra_body=extra_body,
+            request_func=request_func,
+            disable_tqdm=disable_tqdm,
+            pidx=pidx,
+            ramp_up_strategy=ramp_up_strategy,
+            ramp_up_start_rps=ramp_up_start_rps,
+            ramp_up_end_rps=ramp_up_end_rps,
+            output_q=output_q,
+            pbar_q=pbar_q,
+        )
+    )
 
 
 async def benchmark(
     task_type: TaskType,
@@ -513,6 +669,8 @@ async def benchmark(
     ready_check_timeout_sec: int = 600,
     warmup_time: float = 0.0,
     cooldown_time: float = 0.0,
+    num_workers: int = 1,
+    max_connections_per_worker: int = 512,
 ):
     try:
         request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
@@ -652,89 +810,74 @@ async def warmup_limited_request_func():
     print(f"Burstiness factor: {burstiness} ({distribution})")
     print(f"Maximum request concurrency: {max_concurrency}")
+    print(f"Number of worker processes: {num_workers}")
{num_workers}") - pbar = None if disable_tqdm else tqdm(total=len(input_requests)) - - semaphore = ( - asyncio.Semaphore(max_concurrency) - if max_concurrency - else contextlib.nullcontext() - ) - - async def limited_request_func(request_func_input, session, pbar): - async with semaphore: - return await request_func( - request_func_input=request_func_input, session=session, pbar=pbar - ) - + pbar_q = Queue() benchmark_start_time = time.perf_counter() - tasks: list[asyncio.Task] = [] - - rps_change_events = [] - last_int_rps = -1 - if ramp_up_strategy is not None and ramp_up_start_rps is not None: - last_int_rps = ramp_up_start_rps - rps_change_events.append( - { - "rps": last_int_rps, - "timestamp": datetime.now().isoformat(), - } - ) - - async for request, current_request_rate in get_request( - input_requests, - request_rate, - burstiness, - ramp_up_strategy, - ramp_up_start_rps, - ramp_up_end_rps, - ): - if ramp_up_strategy is not None: - current_int_rps = int(current_request_rate) - if current_int_rps > last_int_rps: - timestamp = datetime.now().isoformat() - for rps_val in range(last_int_rps + 1, current_int_rps + 1): - rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) - last_int_rps = current_int_rps - prompt, prompt_len, output_len, mm_content, request_id = ( - request.prompt, - request.prompt_len, - request.expected_output_len, - request.multi_modal_data, - request.request_id, - ) - req_model_id, req_model_name = model_id, model_name - if lora_modules: - req_lora_module = next(lora_modules) - req_model_id, req_model_name = req_lora_module, req_lora_module - request_func_input = RequestFuncInput( - model=req_model_id, - model_name=req_model_name, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - logprobs=logprobs, - multi_modal_content=mm_content, - ignore_eos=ignore_eos, - extra_headers=extra_headers, - extra_body=extra_body, - request_id=request_id, - ) - tasks.append( - asyncio.create_task( - limited_request_func( - request_func_input=request_func_input, session=session, pbar=pbar - ) + if num_workers > 0: + output_q = Queue() + procs = [ + Process( + target=async_worker_wrapper, + kwargs={ + "api_url": api_url, + "model_id": model_id, + "model_name": model_name, + "max_concurrency": max_concurrency/num_workers, + "input_requests": input_requests[ + i*max_connections_per_worker:(i+1)*max_connections_per_worker], + "logprobs": logprobs, + "request_rate": request_rate/num_workers, + "burstiness": burstiness, + "ignore_eos": ignore_eos, + "lora_modules": lora_modules, + "extra_headers": extra_headers, + "extra_body": extra_body, + "request_func": request_func, + "disable_tqdm": disable_tqdm, + "pidx": i, + "ramp_up_strategy": ramp_up_strategy, + "ramp_up_start_rps": + ramp_up_start_rps/num_workers if ramp_up_start_rps is not None else ramp_up_start_rps, + "ramp_up_end_rps": + ramp_up_end_rps/num_workers if ramp_up_end_rps is not None else ramp_up_end_rps, + "output_q": output_q, + "pbar_q": pbar_q, + }, + name=f"vllm_multiproc_bench_{i}", ) - ) - outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) - - if pbar is not None: - pbar.close() + for i in range(num_workers) + ] + + # Start each procs + for p in procs: p.start() + + # Aggregate outputs from all procs + # NOTE: Queue.get() is blocking operation + if not disable_tqdm: + _cnt = 0 + num_inputs = len(input_requests) + with tqdm(total=num_inputs) as pbar: + while _cnt < num_inputs: + inc = pbar_q.get() + _cnt += 1 + pbar.update(inc) + + outputs: 
+        num_done_procs = 0
+        while num_done_procs < num_workers:
+            outputs.extend(output_q.get())
+            num_done_procs += 1
+
+        # Join all worker processes
+        for p in procs:
+            p.join()
+    else:
+        raise RuntimeError(
+            f"Invalid num_workers: {num_workers}. num_workers must be greater than 0.")
 
-    benchmark_duration = time.perf_counter() - benchmark_start_time
+    benchmark_end_time = time.perf_counter()
+    benchmark_duration = benchmark_end_time - benchmark_start_time
 
     if task_type == TaskType.GENERATION:
         metrics, actual_output_lens = calculate_metrics(
@@ -828,9 +971,6 @@ async def limited_request_func(request_func_input, session, pbar):
         "errors": [output.error for output in outputs],
     }
 
-    if rps_change_events:
-        result["rps_change_events"] = rps_change_events
-
     def process_one_metric(
         # E.g., "ttft"
         metric_attribute_name: str,
@@ -1467,6 +1607,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
         default=0.0,
         help="Cool-down time in seconds."
     )
+    parser.add_argument(
+        "--max-connections-per-worker",
+        type=int,
+        default=INT_MAX,
+        help="Maximum number of requests handled by one benchmark worker process. "
+        "The benchmark spawns ceil(num_prompts / max_connections_per_worker) workers.",
+    )
 
 
 def main(args: argparse.Namespace) -> dict[str, Any]:
@@ -1474,6 +1620,7 @@ def main(args: argparse.Namespace) -> dict[str, Any]:
 
 
 async def main_async(args: argparse.Namespace) -> dict[str, Any]:
+    setproctitle("vLLM Benchmark")
     print(args)
     random.seed(args.seed)
     np.random.seed(args.seed)
@@ -1584,6 +1731,10 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
         sampling_params = {}
         default_percentile_metrics = "e2el"
 
+    # Ensure _num_workers is at least 1.
+    _num_workers = math.ceil(args.num_prompts / args.max_connections_per_worker)
+    assert _num_workers > 0, f"Invalid num_workers: {_num_workers}"
+
     extra_body = args.extra_body or {}
     extra_body = {**sampling_params, **extra_body}
 
@@ -1621,6 +1772,8 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
         ready_check_timeout_sec=args.ready_check_timeout_sec,
         warmup_time=args.warmup_time,
         cooldown_time=args.cooldown_time,
+        num_workers=_num_workers,
+        max_connections_per_worker=args.max_connections_per_worker,
     )
 
     # Save config and results to json
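
Note on the worker-count arithmetic: the following is a minimal, self-contained sketch (not part of the patch; all names are illustrative) of how requests end up partitioned across worker processes, using the same ceil-division and slicing as main_async() and benchmark() above.

    # Illustrative sketch only; mirrors the partitioning logic in the patch.
    import math

    num_prompts = 1000                    # e.g. --num-prompts 1000
    max_connections_per_worker = 512      # e.g. --max-connections-per-worker 512

    # Same formula as main_async(): one worker per chunk of requests.
    num_workers = math.ceil(num_prompts / max_connections_per_worker)  # -> 2

    requests = list(range(num_prompts))   # stand-in for input_requests
    chunks = [
        requests[i * max_connections_per_worker:(i + 1) * max_connections_per_worker]
        for i in range(num_workers)
    ]

    # Every request is assigned to exactly one worker; the last chunk may be smaller.
    assert sum(len(c) for c in chunks) == num_prompts
    print([len(c) for c in chunks])       # [512, 488]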