Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions tests/v1/kv_connector/run_accuracy_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/bin/bash
# End-to-end accuracy test for disaggregated prefill/decode via NixlConnector.
# Launches a prefill and a decode vLLM instance plus a toy proxy, then runs
# the lm-eval-based accuracy test against the proxy.

set -xe

# Model to run.
MODEL_NAME=Qwen/Qwen3-0.6B

# Kill all background jobs (the vLLM servers and the proxy) on Ctrl+C,
# termination, or normal exit so no orphaned processes are left behind.
# NOTE: a previous `cleanup()` helper was defined here but never invoked;
# it has been removed since this trap already performs the teardown.
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
Comment on lines +8 to +19
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `cleanup` function appears to be never invoked — should we remove it? The `trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT` should do the job on its own (though we could consider adding `-9`).


# Block until a vLLM server on the given port starts answering HTTP,
# giving up after 20 minutes.
# $1 - TCP port of the server to probe.
# Returns 0 once the endpoint responds, 1 on timeout.
wait_for_server() {
  local port=$1
  if timeout 1200 bash -c "
    while ! curl -s localhost:${port}/v1/completions > /dev/null; do
      sleep 1
    done"; then
    return 0
  else
    return 1
  fi
}

# Prefill instance.
CUDA_VISIBLE_DEVICES=0 NIXL_ROLE="SENDER" vllm serve $MODEL_NAME \
    --port 8100 \
    --enforce-eager \
    --disable-log-requests \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' &

# Decode instance.
CUDA_VISIBLE_DEVICES=1 NIXL_ROLE="RECVER" vllm serve $MODEL_NAME \
    --port 8200 \
    --enforce-eager \
    --disable-log-requests \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' &

# Wait until prefill and decode instances are ready.
wait_for_server 8100
wait_for_server 8200

# Proxy server fronting the two instances.
# Use python3 consistently with the pytest invocation below.
python3 toy_proxy_server.py --port 8192 &

# Wait for the proxy to come up before firing requests at it; without this
# the test could race ahead of the proxy's startup and fail to connect.
wait_for_server 8192

# Run lm eval through the proxy.
python3 -m pytest -s -x test_accuracy.py
28 changes: 28 additions & 0 deletions tests/v1/kv_connector/test_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
import lm_eval

# Model served by the prefill/decode instances (must match run_accuracy_test.sh).
MODEL_NAME = "Qwen/Qwen3-0.6B"
# Number of concurrent requests lm-eval sends to the proxy.
NUM_CONCURRENT = 100
# lm-eval task to run.
TASK = "gsm8k"
# Metric key to read from the task results.
FILTER = "exact_match,strict-match"
# Tolerance around EXPECTED_VALUE (used as an absolute band below).
RTOL = 0.03
# Baseline gsm8k score for this model; presumably measured on a
# single-instance run — TODO confirm provenance.
EXPECTED_VALUE = 0.41


def test_accuracy():
    """End-to-end accuracy check: run gsm8k through the local proxy."""

    model_args = ",".join([
        f"model={MODEL_NAME}",
        "base_url=http://localhost:8192/v1/completions",
        f"num_concurrent={NUM_CONCURRENT}",
        "tokenized_requests=False",
    ])

    results = lm_eval.simple_evaluate(
        model="local-completions",
        model_args=model_args,
        tasks=TASK,
    )

    measured_value = results["results"][TASK][FILTER]
    # Equivalent to: EXPECTED_VALUE - RTOL < measured_value < EXPECTED_VALUE + RTOL
    assert abs(measured_value - EXPECTED_VALUE) < RTOL, (
        f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}")
145 changes: 145 additions & 0 deletions tests/v1/kv_connector/toy_proxy_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: this is needed so we can run in the CI (since the other one is in examples)


import argparse
import os
import uuid
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage the app's startup and shutdown lifecycle.

    On startup, build one persistent async HTTP client per backend
    (prefiller and decoder) from the module-level ``global_args``;
    on shutdown, close both clients.
    """
    prefill_url = (f"http://{global_args.prefiller_host}:"
                   f"{global_args.prefiller_port}/v1")
    decode_url = (f"http://{global_args.decoder_host}:"
                  f"{global_args.decoder_port}/v1")

    # timeout=None: completions can take arbitrarily long; never time out.
    app.state.prefill_client = httpx.AsyncClient(base_url=prefill_url,
                                                 timeout=None)
    app.state.decode_client = httpx.AsyncClient(base_url=decode_url,
                                                timeout=None)

    yield

    # Shutdown: release the persistent clients.
    await app.state.prefill_client.aclose()
    await app.state.decode_client.aclose()


# FastAPI app wired to the lifespan handler above so the persistent HTTP
# clients are created on startup and closed on shutdown.
app = FastAPI(lifespan=lifespan)


def parse_args(argv=None):
    """Parse command-line options for the proxy server.

    Args:
        argv: Optional list of argument strings to parse. Defaults to
            ``None``, in which case ``sys.argv[1:]`` is used — backward
            compatible with the previous zero-argument form.

    Returns:
        argparse.Namespace with host/port settings for the proxy itself
        and for the prefiller and decoder backends.
    """
    parser = argparse.ArgumentParser()

    # Proxy listen address.
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    # Prefill backend.
    parser.add_argument("--prefiller-host", type=str, default="localhost")
    parser.add_argument("--prefiller-port", type=int, default=8100)
    # Decode backend.
    parser.add_argument("--decoder-host", type=str, default="localhost")
    parser.add_argument("--decoder-port", type=int, default=8200)
    args = parser.parse_args(argv)
    return args


# Placeholders for the persistent clients; populated by lifespan() at startup.
app.state.prefill_client = None
app.state.decode_client = None


async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
                                  req_data: dict, request_id: str):
    """
    POST a non-streaming request to a backend via a persistent client.

    The caller's dict is not mutated: a copy is tagged with
    ``do_remote_decode=True`` and ``stream=False`` before being sent.
    Raises on a non-2xx response and returns the httpx response otherwise.
    """
    payload = {**req_data, "do_remote_decode": True, "stream": False}
    auth = f"Bearer {os.environ.get('OPENAI_API_KEY')}"
    response = await client.post(endpoint,
                                 json=payload,
                                 headers={
                                     "Authorization": auth,
                                     "X-Request-Id": request_id,
                                 })
    response.raise_for_status()

    return response


async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
                                  req_data: dict, remote_block_ids: list[int],
                                  remote_engine_id: str, request_id: str):
    """
    Asynchronously stream the response from a service using a persistent client.

    Tags the request with ``do_remote_prefill=True`` plus the KV-transfer
    metadata (block ids and engine id) from the prefill step, then yields
    the decoder's response body chunk by chunk.
    Raises on a non-2xx response before yielding anything.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    # Work on a copy so the caller's dict is not mutated — consistent with
    # send_request_to_service(), which also copies before tagging.
    req_data = req_data.copy()
    req_data['do_remote_prefill'] = True
    req_data["remote_block_ids"] = remote_block_ids
    req_data['remote_engine_id'] = remote_engine_id
    async with client.stream("POST", endpoint, json=req_data,
                             headers=headers) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk


@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Disaggregated completions endpoint.

    Flow: (1) forward the request to the prefill service and wait for it
    to complete, (2) extract the KV-transfer metadata from its response,
    (3) stream the final completion from the decode service, passing that
    metadata along so the decoder can fetch the prefilled KV cache.
    """
    try:
        req_data = await request.json()

        # One id correlates the prefill and decode halves of this request
        # on both backends (sent as the X-Request-Id header).
        request_id = str(uuid.uuid4())

        # Send request to prefill service; awaits until prefill is done.
        response = await send_request_to_service(app.state.prefill_client,
                                                 "/completions", req_data,
                                                 request_id)

        # Extract the needed fields. Defaults ([] / '') are used when the
        # prefiller omits them — presumably only in non-NIXL setups; verify.
        response_json = response.json()
        remote_block_ids = response_json.get('remote_block_ids', [])
        remote_engine_id = response_json.get('remote_engine_id', '')

        # Add these to the request data for the decoder.
        req_data['remote_block_ids'] = remote_block_ids
        req_data['remote_engine_id'] = remote_engine_id

        # Stream response from decode service back to the caller verbatim.
        async def generate_stream():
            async for chunk in stream_service_response(
                    app.state.decode_client,
                    "/completions",
                    req_data,
                    remote_block_ids=remote_block_ids,
                    remote_engine_id=remote_engine_id,
                    request_id=request_id):
                yield chunk

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")

    except Exception as e:
        # Print the full traceback for CI logs, then re-raise so FastAPI
        # returns a 500 to the client instead of silently swallowing it.
        import sys
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              " - completions endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


if __name__ == '__main__':
    # Parse CLI args into the module-level global read by lifespan().
    # NOTE: a `global` statement at module scope is a no-op, so the
    # previous `global global_args` line was removed; a module-level
    # assignment already creates/overwrites the global binding.
    global_args = parse_args()

    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)