
Commit 1493751

add benchmark doc and scripts

Signed-off-by: wangli <[email protected]>
1 parent 78530c0

File tree: 9 files changed, +2474 −0 lines

benchmarks/backend_request_func.py

Lines changed: 193 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0

import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Union

import aiohttp
import huggingface_hub.constants
from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str
    model_name: Optional[str] = None
    best_of: int = 1
    logprobs: Optional[int] = None
    extra_body: Optional[dict] = None
    multi_modal_content: Optional[dict] = None
    ignore_eos: bool = False


@dataclass
class RequestFuncOutput:
    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    output_tokens: int = 0
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(
        default_factory=list)  # List of inter-token latencies
    tpot: float = 0.0  # avg next-token latencies
    prompt_len: int = 0
    error: str = ""


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model_name \
                if request_func_input.model_name else request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk != "[DONE]":
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if choices := data.get("choices"):
                                # Note that text could be empty here
                                # e.g. for special tokens
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()
                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")
                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!")
                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


def get_model(pretrained_model_name_or_path: str) -> str:
    if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
        from modelscope import snapshot_download

        model_path = snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])

        return model_path
    return pretrained_model_name_or_path


def get_tokenizer(
    pretrained_model_name_or_path: str,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    if pretrained_model_name_or_path is not None and not os.path.exists(
            pretrained_model_name_or_path):
        pretrained_model_name_or_path = get_model(
            pretrained_model_name_or_path)
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False
    if tokenizer_mode == "mistral":
        try:
            from vllm.transformers_utils.tokenizer import MistralTokenizer
        except ImportError as e:
            raise ImportError("MistralTokenizer requires vllm package.\n"
                              "Please install it with `pip install vllm` "
                              "to use mistral tokenizer mode.") from e
        return MistralTokenizer.from_pretrained(
            str(pretrained_model_name_or_path))
    else:
        return AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )


ASYNC_REQUEST_FUNCS = {
    "vllm": async_request_openai_completions,
}
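A minimal sketch of driving this request function directly, assuming a locally running OpenAI-compatible vLLM server; the API URL and model id below are placeholders, not part of the diff:

import asyncio

from backend_request_func import ASYNC_REQUEST_FUNCS, RequestFuncInput


async def _demo():
    # "vllm" maps to async_request_openai_completions above.
    request_func = ASYNC_REQUEST_FUNCS["vllm"]
    result = await request_func(
        RequestFuncInput(
            prompt="Hello, world!",
            # Placeholder URL: must end with "completions" or "profile".
            api_url="http://localhost:8000/v1/completions",
            prompt_len=4,
            output_len=16,
            model="your-model-name",  # placeholder model id
        ))
    print(result.success, result.ttft, result.generated_text)


if __name__ == "__main__":
    asyncio.run(_demo())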

benchmarks/benchmark_latency.py

Lines changed: 152 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import dataclasses
import json
import time
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser


def main(args: argparse.Namespace):
    print(args)

    engine_args = EngineArgs.from_cli_args(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(engine_args))

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    def llm_generate():
        if not args.use_beam_search:
            llm.generate(dummy_prompts,
                         sampling_params=sampling_params,
                         use_tqdm=False)
        else:
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
                    ignore_eos=True,
                ))

    def run_to_completion(profile_dir: Optional[str] = None):
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm_generate()
            print(p.key_averages().table(sort_by="self_cuda_time_total"))
        else:
            start_time = time.perf_counter()
            llm_generate()
            end_time = time.perf_counter()
            latency = end_time - start_time
            return latency

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = Path(
                "."
            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []

    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))
    latencies = np.array(latencies)
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    print(f'Avg latency: {np.mean(latencies)} seconds')
    for percentage, percentile in zip(percentages, percentiles):
        print(f'{percentage}% percentile latency: {percentile} seconds')

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": np.mean(latencies),
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == '__main__':
    parser = FlexibleArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search', action='store_true')
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=10,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num-iters',
                        type=int,
                        default=30,
                        help='Number of iterations to run.')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the latency results in JSON format.')

    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)
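Example invocation with the defaults shown above (the model and other engine flags come from EngineArgs.add_cli_args and depend on the installed vllm version; the output path is a placeholder):

python benchmarks/benchmark_latency.py --input-len 32 --output-len 128 --batch-size 8 --num-iters 30 --output-json latency.json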
