Skip to content

Commit 66659aa

Browse files
authored
Merge pull request #4 from Dstack-TEE/feat/add_proxy_metrics
feat: add proxy metrics
2 parents 61943c1 + c0ef050 commit 66659aa

File tree

5 files changed

+108
-26
lines changed

5 files changed

+108
-26
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ dstack-sdk = "^0.5.0"
1717
cryptography = "^43.0.1"
1818
redis = "^5.2.1"
1919
nv-ppcie-verifier = "^1.5.0"
20+
prometheus-fastapi-instrumentator = "^7.0.0"
21+
prometheus-client = "^0.21.1"
2022

2123
[tool.poetry.group.dev.dependencies]
2224
pytest = "^8.3.4"

src/app/api/v1/openai.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from app.cache.cache import cache
2222
from app.logger import log
23+
from app.metrics import get_proxy_metrics
2324
from app.quote.quote import (
2425
ECDSA,
2526
ED25519,
@@ -65,7 +66,6 @@ async def stream_vllm_response(
6566
request_body: bytes,
6667
modified_request_body: bytes,
6768
request_hash: Optional[str] = None,
68-
requested_model: Optional[str] = None,
6969
):
7070
"""
7171
Handle streaming vllm request
@@ -75,7 +75,6 @@ async def stream_vllm_response(
7575
request_hash: Optional hash from request header (X-Request-Hash). Used by trusted clients to provide
7676
pre-calculated request hash, avoiding redundant hash computation. Falls back to
7777
calculating hash from request_body if not provided
78-
requested_model: The model name requested by the client
7978
Returns:
8079
A streaming response
8180
"""
@@ -103,11 +102,6 @@ async def generate_stream(response):
103102
# Extract the cache key (data.id) from the first chunk
104103
if not chat_id:
105104
chat_id = chunk_data.get("id")
106-
107-
# Override the model name if requested_model is provided
108-
if requested_model and "model" in chunk_data:
109-
chunk_data["model"] = requested_model
110-
final_chunk = f"data: {json.dumps(chunk_data)}\n"
111105

112106
except Exception as e:
113107
error_message = f"Failed to parse chunk: {e}\n The original data is: {data}"
@@ -148,6 +142,7 @@ async def generate_stream(response):
148142
generate_stream(response),
149143
background=BackgroundTasks([response.aclose, client.aclose]),
150144
media_type="text/event-stream",
145+
headers={"X-Accel-Buffering": "no"},
151146
)
152147

153148

@@ -157,7 +152,6 @@ async def non_stream_vllm_response(
157152
request_body: bytes,
158153
modified_request_body: bytes,
159154
request_hash: Optional[str] = None,
160-
requested_model: Optional[str] = None,
161155
):
162156
"""
163157
Handle non-streaming responses
@@ -167,7 +161,6 @@ async def non_stream_vllm_response(
167161
request_hash: Optional hash from request header (X-Request-Hash). Used by trusted clients to provide
168162
pre-calculated request hash, avoiding redundant hash computation. Falls back to
169163
calculating hash from request_body if not provided
170-
requested_model: The model name requested by the client
171164
Returns:
172165
The response data
173166
"""
@@ -186,10 +179,6 @@ async def non_stream_vllm_response(
186179
raise HTTPException(status_code=response.status_code, detail=response.text)
187180

188181
response_data = response.json()
189-
190-
# Override the model name if requested_model is provided
191-
if requested_model and "model" in response_data:
192-
response_data["model"] = requested_model
193182

194183
# Cache the request-response pair using the chat ID
195184
chat_id = response_data.get("id")
@@ -270,18 +259,16 @@ async def chat_completions(
270259
is_stream = modified_json.get(
271260
"stream", False
272261
) # Default to non-streaming if not specified
273-
requested_model = modified_json.get("model")
274-
275262
modified_request_body = json.dumps(modified_json).encode("utf-8")
276263
if is_stream:
277264
# Create a streaming response
278265
return await stream_vllm_response(
279-
VLLM_URL, request_body, modified_request_body, x_request_hash, requested_model
266+
VLLM_URL, request_body, modified_request_body, x_request_hash
280267
)
281268
else:
282269
# Handle non-streaming response
283270
response_data = await non_stream_vllm_response(
284-
VLLM_URL, request_body, modified_request_body, x_request_hash, requested_model
271+
VLLM_URL, request_body, modified_request_body, x_request_hash
285272
)
286273
return JSONResponse(content=response_data)
287274

@@ -301,18 +288,16 @@ async def completions(
301288
is_stream = modified_json.get(
302289
"stream", False
303290
) # Default to non-streaming if not specified
304-
requested_model = modified_json.get("model")
305-
306291
modified_request_body = json.dumps(modified_json).encode("utf-8")
307292
if is_stream:
308293
# Create a streaming response
309294
return await stream_vllm_response(
310-
VLLM_COMPLETIONS_URL, request_body, modified_request_body, x_request_hash, requested_model
295+
VLLM_COMPLETIONS_URL, request_body, modified_request_body, x_request_hash
311296
)
312297
else:
313298
# Handle non-streaming response
314299
response_data = await non_stream_vllm_response(
315-
VLLM_COMPLETIONS_URL, request_body, modified_request_body, x_request_hash, requested_model
300+
VLLM_COMPLETIONS_URL, request_body, modified_request_body, x_request_hash
316301
)
317302
return JSONResponse(content=response_data)
318303

@@ -355,11 +340,25 @@ async def signature(request: Request, chat_id: str, signing_algo: str = None):
355340
# Metrics of vLLM instance
@router.get("/metrics")
async def metrics(request: Request):
    """Serve a combined Prometheus exposition: the proxy's own metrics
    followed by the metrics scraped from the vLLM backend.

    The backend scrape is best-effort — a non-200 status or a transport
    error is logged and recorded as a Prometheus comment line, so the
    endpoint itself always returns 200 with at least the local metrics.
    """
    # Metrics collected in-process by the proxy (prometheus-client registry).
    local_metrics = get_proxy_metrics()

    # Best-effort scrape of the vLLM backend's metrics endpoint.
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(TIMEOUT)) as client:
            backend_response = await client.get(VLLM_METRICS_URL)
            if backend_response.status_code != 200:
                log.warning(f"Failed to fetch vLLM metrics: {backend_response.status_code}")
                remote_metrics = f"# Failed to fetch vLLM metrics: {backend_response.status_code}"
            else:
                remote_metrics = backend_response.text
    except Exception as e:
        log.error(f"Error fetching vLLM metrics: {e}")
        remote_metrics = f"# Error fetching vLLM metrics: {e}"

    # Stitch local and remote sections together, separated by a comment banner.
    combined_metrics = "\n\n".join(
        [local_metrics, "# --- vLLM Backend Metrics ---", remote_metrics]
    )
    return PlainTextResponse(combined_metrics)
363362

364363

365364
@router.get("/models")

src/app/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@
33
from .api import router as api_router
44
from .api.response.response import ok, error, http_exception
55
from .logger import log
6+
from .metrics import vllm_proxy_errors_total
7+
from prometheus_fastapi_instrumentator import Instrumentator
68

79
app = FastAPI()
810
app.include_router(api_router)
911

12+
# Initialize Prometheus Instrumentator
13+
Instrumentator().instrument(app).expose(app, endpoint="/local-metrics", include_in_schema=False)
14+
1015

1116
@app.get("/")
1217
async def root():
@@ -25,6 +30,7 @@ async def global_exception_handler(request: Request, exc: Exception):
2530
return http_exception(exc.status_code, exc.detail)
2631

2732
log.error(f"Unhandled exception: {exc}")
33+
vllm_proxy_errors_total.labels(type=type(exc).__name__).inc()
2834
return error(
2935
status_code=500,
3036
message=str(exc),

src/app/metrics.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from prometheus_client import REGISTRY, Counter, generate_latest

# Custom metrics.
# prometheus-client registers metrics with the default REGISTRY unless told
# otherwise, so this counter is picked up by get_proxy_metrics() below.
vllm_proxy_errors_total = Counter(
    "vllm_proxy_errors_total",
    "Total number of unhandled exceptions in the vLLM proxy",
    labelnames=["type"],
)


def get_proxy_metrics() -> str:
    """Render the default prometheus-client registry in text exposition format."""
    return generate_latest(REGISTRY).decode("utf-8")
16+

tests/app/test_metrics.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import pytest
2+
import httpx
3+
from fastapi.testclient import TestClient
4+
from unittest.mock import patch
5+
6+
# Standard test setup
7+
from tests.app.test_helpers import setup_test_environment, TEST_AUTH_HEADER
8+
setup_test_environment()
9+
10+
import sys
11+
sys.modules["app.quote.quote"] = __import__("tests.app.mock_quote", fromlist=[""])
12+
13+
from app.main import app
14+
from app.api.v1.openai import VLLM_METRICS_URL
15+
16+
client = TestClient(app)
17+
18+
@pytest.mark.asyncio
@pytest.mark.respx
async def test_metrics_endpoint_combined(respx_mock):
    """/v1/metrics returns proxy metrics followed by the vLLM backend metrics
    when the backend scrape succeeds."""
    # Stub the vLLM backend's metrics endpoint with a minimal exposition.
    fake_backend_metrics = (
        "# HELP vllm_some_metric\n# TYPE vllm_some_metric counter\nvllm_some_metric 1.0"
    )
    respx_mock.get(VLLM_METRICS_URL).mock(
        return_value=httpx.Response(200, text=fake_backend_metrics)
    )

    # Hit the proxy's combined metrics endpoint.
    response = client.get("/v1/metrics")
    assert response.status_code == 200
    body = response.text

    # Local metrics: custom counter plus the instrumentator's request counter.
    assert "vllm_proxy_errors_total" in body
    assert "http_requests_total" in body

    # Backend metrics and the separator banner must both be appended.
    assert "vllm_some_metric" in body
    assert "vLLM Backend Metrics" in body
40+
41+
@pytest.mark.asyncio
@pytest.mark.respx
async def test_metrics_endpoint_vllm_fail(respx_mock):
    """When the vLLM backend errors, /v1/metrics still returns 200 with the
    local metrics plus a comment line noting the backend failure."""
    # Stub the backend to fail with a 500.
    respx_mock.get(VLLM_METRICS_URL).mock(
        return_value=httpx.Response(500, text="Internal Server Error")
    )

    response = client.get("/v1/metrics")
    assert response.status_code == 200
    body = response.text

    # Local metrics must survive the backend failure.
    assert "vllm_proxy_errors_total" in body
    # The failure is surfaced as a comment in the exposition output.
    assert "Failed to fetch vLLM metrics: 500" in body

0 commit comments

Comments
 (0)