
Commit 83bf6cd

daniel-salib and lulmer authored and committed
[Frontend] track server_load (vllm-project#13950)
Signed-off-by: Louis Ulmer <[email protected]>
1 parent c89fd01 commit 83bf6cd

File tree

4 files changed: +131 -4 lines changed

tests/entrypoints/openai/test_basic.py
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/utils.py

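Usage note (not part of this diff): once a server has been started with --enable-server-load-tracking, the new GET /load endpoint reports how many tracked requests are currently in flight. A minimal sketch of polling it, assuming a server listening on http://localhost:8000 (the URL is an assumption for illustration):

import requests

# Poll the new /load endpoint added in this commit; "server_load" counts
# in-flight requests on the GPU-bound routes decorated with @load_aware_call.
resp = requests.get("http://localhost:8000/load", timeout=5)
resp.raise_for_status()
print(resp.json())  # e.g. {"server_load": 0} when the server is idle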

tests/entrypoints/openai/test_basic.py

Lines changed: 48 additions & 0 deletions
@@ -171,3 +171,51 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
             extra_headers={
                 "Content-Type": "application/x-www-form-urlencoded"
             })
+
+
+@pytest.mark.parametrize(
+    "server_args",
+    [
+        pytest.param(["--enable-server-load-tracking"],
+                     id="enable-server-load-tracking")
+    ],
+    indirect=True,
+)
+@pytest.mark.asyncio
+async def test_server_load(server: RemoteOpenAIServer):
+    # Check initial server load
+    response = requests.get(server.url_for("load"))
+    assert response.status_code == HTTPStatus.OK
+    assert response.json().get("server_load") == 0
+
+    def make_long_completion_request():
+        return requests.post(
+            server.url_for("v1/completions"),
+            headers={"Content-Type": "application/json"},
+            json={
+                "prompt": "Give me a long story",
+                "max_tokens": 1000,
+                "temperature": 0,
+            },
+        )
+
+    # Start the completion request in a background thread.
+    completion_future = asyncio.create_task(
+        asyncio.to_thread(make_long_completion_request))
+
+    # Give a short delay to ensure the request has started.
+    await asyncio.sleep(0.1)
+
+    # Check server load while the completion request is running.
+    response = requests.get(server.url_for("load"))
+    assert response.status_code == HTTPStatus.OK
+    assert response.json().get("server_load") == 1
+
+    # Wait for the completion request to finish.
+    await completion_future
+    await asyncio.sleep(0.1)
+
+    # Check server load after the completion request has finished.
+    response = requests.get(server.url_for("load"))
+    assert response.status_code == HTTPStatus.OK
+    assert response.json().get("server_load") == 0

vllm/entrypoints/openai/api_server.py

Lines changed: 30 additions & 2 deletions
@@ -80,7 +80,7 @@
 from vllm.entrypoints.openai.serving_transcription import (
     OpenAIServingTranscription)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.entrypoints.utils import with_cancellation
+from vllm.entrypoints.utils import load_aware_call, with_cancellation
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
@@ -347,6 +347,24 @@ async def health(raw_request: Request) -> Response:
     return Response(status_code=200)


+@router.get("/load")
+async def get_server_load_metrics(request: Request):
+    # This endpoint returns the current server load metrics.
+    # It tracks requests utilizing the GPU from the following routes:
+    # - /v1/chat/completions
+    # - /v1/completions
+    # - /v1/audio/transcriptions
+    # - /v1/embeddings
+    # - /pooling
+    # - /score
+    # - /v1/score
+    # - /rerank
+    # - /v1/rerank
+    # - /v2/rerank
+    return JSONResponse(
+        content={'server_load': request.app.state.server_load_metrics})
+
+
 @router.api_route("/ping", methods=["GET", "POST"])
 async def ping(raw_request: Request) -> Response:
     """Ping check. Endpoint required for SageMaker"""
@@ -400,6 +418,7 @@ async def show_version():
 @router.post("/v1/chat/completions",
              dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     handler = chat(raw_request)
@@ -421,6 +440,7 @@ async def create_chat_completion(request: ChatCompletionRequest,

 @router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_completion(request: CompletionRequest, raw_request: Request):
     handler = completion(raw_request)
     if handler is None:
@@ -439,6 +459,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):

 @router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     handler = embedding(raw_request)
     if handler is None:
@@ -485,6 +506,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):

 @router.post("/pooling", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_pooling(request: PoolingRequest, raw_request: Request):
     handler = pooling(raw_request)
     if handler is None:
@@ -503,6 +525,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):

 @router.post("/score", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_score(request: ScoreRequest, raw_request: Request):
     handler = score(raw_request)
     if handler is None:
@@ -521,6 +544,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):

 @router.post("/v1/score", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_score_v1(request: ScoreRequest, raw_request: Request):
     logger.warning(
         "To indicate that Score API is not part of standard OpenAI API, we "
@@ -531,10 +555,10 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):

 @router.post("/v1/audio/transcriptions")
 @with_cancellation
+@load_aware_call
 async def create_transcriptions(request: Annotated[TranscriptionRequest,
                                                    Form()],
                                 raw_request: Request):
-
     handler = transcription(raw_request)
     if handler is None:
         return base(raw_request).create_error_response(
@@ -556,6 +580,7 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,

 @router.post("/rerank", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def do_rerank(request: RerankRequest, raw_request: Request):
     handler = rerank(raw_request)
     if handler is None:
@@ -894,6 +919,9 @@ async def init_app_state(
     ) if model_config.runner_type == "transcription" else None
     state.task = model_config.task

+    state.enable_server_load_tracking = args.enable_server_load_tracking
+    state.server_load_metrics = 0
+

 def create_server_socket(addr: tuple[str, int]) -> socket.socket:
     family = socket.AF_INET
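The comment in the new /load handler lists the routes whose in-flight requests are counted. A hedged sketch (not from this PR) of how an external router could use the endpoint to pick the least-busy replica; the replica URLs are assumptions, and each replica must run with --enable-server-load-tracking:

import requests

REPLICAS = [
    "http://localhost:8000",
    "http://localhost:8001",
]


def least_loaded_replica() -> str:
    # Query each replica's /load endpoint and return the one with the
    # smallest server_load value.
    loads = {}
    for base in REPLICAS:
        resp = requests.get(f"{base}/load", timeout=1)
        resp.raise_for_status()
        loads[base] = resp.json()["server_load"]
    return min(loads, key=loads.get)


if __name__ == "__main__":
    print("route next request to:", least_loaded_replica())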

vllm/entrypoints/openai/cli_args.py

Lines changed: 7 additions & 0 deletions
@@ -257,6 +257,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         action='store_true',
         default=False,
         help="If set to True, enable prompt_tokens_details in usage.")
+    parser.add_argument(
+        "--enable-server-load-tracking",
+        action='store_true',
+        default=False,
+        help=
+        "If set to True, enable tracking server_load_metrics in the app state."
+    )

     return parser


vllm/entrypoints/utils.py

Lines changed: 46 additions & 2 deletions
@@ -4,6 +4,8 @@
 import functools

 from fastapi import Request
+from fastapi.responses import JSONResponse, StreamingResponse
+from starlette.background import BackgroundTask, BackgroundTasks


 async def listen_for_disconnect(request: Request) -> None:
@@ -17,9 +19,9 @@ async def listen_for_disconnect(request: Request) -> None:
 def with_cancellation(handler_func):
     """Decorator that allows a route handler to be cancelled by client
     disconnections.
-
+
     This does _not_ use request.is_disconnected, which does not work with
-    middleware. Instead this follows the pattern from
+    middleware. Instead this follows the pattern from
     starlette.StreamingResponse, which simultaneously awaits on two tasks- one
     to wait for an http disconnect message, and the other to do the work that we
     want done. When the first task finishes, the other is cancelled.
@@ -57,3 +59,45 @@ async def wrapper(*args, **kwargs):
         return None

     return wrapper
+
+
+def decrement_server_load(request: Request):
+    request.app.state.server_load_metrics -= 1
+
+
+def load_aware_call(func):
+
+    @functools.wraps(func)
+    async def wrapper(*args, raw_request: Request, **kwargs):
+        if not raw_request.app.state.enable_server_load_tracking:
+            return await func(*args, raw_request=raw_request, **kwargs)
+
+        raw_request.app.state.server_load_metrics += 1
+        try:
+            response = await func(*args, raw_request=raw_request, **kwargs)
+        except Exception:
+            raw_request.app.state.server_load_metrics -= 1
+            raise
+
+        if isinstance(response, (JSONResponse, StreamingResponse)):
+            if response.background is None:
+                response.background = BackgroundTask(decrement_server_load,
+                                                     raw_request)
+            elif isinstance(response.background, BackgroundTasks):
+                response.background.add_task(decrement_server_load,
+                                             raw_request)
+            elif isinstance(response.background, BackgroundTask):
+                # Convert the single BackgroundTask to BackgroundTasks
+                # and chain the decrement_server_load task to it
+                tasks = BackgroundTasks()
+                tasks.add_task(response.background.func,
+                               *response.background.args,
+                               **response.background.kwargs)
+                tasks.add_task(decrement_server_load, raw_request)
+                response.background = tasks
+        else:
+            raw_request.app.state.server_load_metrics -= 1
+
+        return response
+
+    return wrapper
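Design note on load_aware_call: the decrement is not performed before returning; it is attached as a Starlette background task so the counter stays raised until the response body (including streamed tokens) has actually been sent. A small, self-contained sketch of the BackgroundTask-to-BackgroundTasks chaining used above; the functions existing_cleanup and decrement_counter are stand-ins invented for this example:

from starlette.background import BackgroundTask, BackgroundTasks
from starlette.responses import JSONResponse


def existing_cleanup():
    # Stand-in for a background task a handler may already have attached.
    print("original background task")


def decrement_counter():
    # Stand-in for decrement_server_load.
    print("decrement server load")


response = JSONResponse({"ok": True},
                        background=BackgroundTask(existing_cleanup))

# Chain a second task without dropping the one already attached, mirroring
# the BackgroundTask -> BackgroundTasks conversion in load_aware_call.
tasks = BackgroundTasks()
tasks.add_task(response.background.func, *response.background.args,
               **response.background.kwargs)
tasks.add_task(decrement_counter)
response.background = tasks
# Both tasks run, in order, only after the response body has been sent.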
