diff --git a/tests/v1/kv_connector/run_accuracy_test.sh b/tests/v1/kv_connector/run_accuracy_test.sh
new file mode 100644
index 000000000000..0aab60e4adca
--- /dev/null
+++ b/tests/v1/kv_connector/run_accuracy_test.sh
@@ -0,0 +1,57 @@
+#!/bin/bash
+
+set -xe
+
+# Model to run.
+MODEL_NAME=Qwen/Qwen3-0.6B
+
+# Kill background jobs on exit; run cleanup on Ctrl+C or termination.
+trap 'kill $(jobs -pr)' EXIT
+trap cleanup SIGINT SIGTERM
+
+# Cleanup function
+cleanup() {
+    echo "Caught termination signal, cleaning up..."
+    # Cleanup commands
+    pgrep python | xargs kill -9
+    pkill -f python
+    echo "Cleanup complete. Exiting."
+    exit 0
+}
+
+# Waits for a server to start accepting requests on the given port.
+wait_for_server() {
+  local port=$1
+  timeout 1200 bash -c "
+    until curl -s localhost:${port}/v1/completions > /dev/null; do
+      sleep 1
+    done" && return 0 || return 1
+}
+
+# Prefill instance.
+CUDA_VISIBLE_DEVICES=0 NIXL_ROLE="SENDER" vllm serve $MODEL_NAME \
+    --port 8100 \
+    --enforce-eager \
+    --disable-log-requests \
+    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' &
+
+# Decode instance.
+CUDA_VISIBLE_DEVICES=1 NIXL_ROLE="RECVER" vllm serve $MODEL_NAME \
+    --port 8200 \
+    --enforce-eager \
+    --disable-log-requests \
+    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' &
+
+# Wait until the prefill and decode instances are ready.
+wait_for_server 8100
+wait_for_server 8200
+
+# Proxy server; wait until it accepts connections before starting the eval.
+python3 toy_proxy_server.py --port 8192 &
+wait_for_server 8192
+
+# Run lm eval.
+python3 -m pytest -s -x test_accuracy.py
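+
+# Note: the EXIT trap above tears down the vLLM servers and the proxy once
+# pytest returns, so no explicit teardown step is needed here.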
diff --git a/tests/v1/kv_connector/test_accuracy.py b/tests/v1/kv_connector/test_accuracy.py
new file mode 100644
index 000000000000..60878a664eb9
--- /dev/null
+++ b/tests/v1/kv_connector/test_accuracy.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+import lm_eval
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+NUM_CONCURRENT = 100
+TASK = "gsm8k"
+FILTER = "exact_match,strict-match"
+RTOL = 0.03
+EXPECTED_VALUE = 0.41
+
+
+def test_accuracy():
+    """Run the end-to-end accuracy test."""
+
+    model_args = (f"model={MODEL_NAME},"
+                  f"base_url=http://localhost:8192/v1/completions,"
+                  f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
+
+    results = lm_eval.simple_evaluate(
+        model="local-completions",
+        model_args=model_args,
+        tasks=TASK,
+    )
+
+    measured_value = results["results"][TASK][FILTER]
+    assert (measured_value - RTOL < EXPECTED_VALUE
+            and measured_value + RTOL > EXPECTED_VALUE
+            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
diff --git a/tests/v1/kv_connector/toy_proxy_server.py b/tests/v1/kv_connector/toy_proxy_server.py
new file mode 100644
index 000000000000..89e3c4493fdb
--- /dev/null
+++ b/tests/v1/kv_connector/toy_proxy_server.py
@@ -0,0 +1,153 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import os
+import uuid
+from contextlib import asynccontextmanager
+
+import httpx
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """
+    Lifespan context manager to handle startup and shutdown events.
+    """
+    # Startup: Initialize clients
+    prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1'
+    decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1'
+
+    app.state.prefill_client = httpx.AsyncClient(timeout=None,
+                                                 base_url=prefiller_base_url)
+    app.state.decode_client = httpx.AsyncClient(timeout=None,
+                                                base_url=decoder_base_url)
+
+    yield
+
+    # Shutdown: Close clients
+    await app.state.prefill_client.aclose()
+    await app.state.decode_client.aclose()
+
+
+# Create the FastAPI app with the lifespan handler above.
+app = FastAPI(lifespan=lifespan)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--prefiller-host", type=str, default="localhost")
+    parser.add_argument("--prefiller-port", type=int, default=8100)
+    parser.add_argument("--decoder-host", type=str, default="localhost")
+    parser.add_argument("--decoder-port", type=int, default=8200)
+    args = parser.parse_args()
+    return args
+
+
+# Initialize variables to hold the persistent clients
+app.state.prefill_client = None
+app.state.decode_client = None
+
+
+async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
+                                  req_data: dict, request_id: str):
+    """
+    Send a request to a service using a persistent client.
+    """
+    req_data = req_data.copy()
+    req_data['do_remote_decode'] = True
+    req_data["stream"] = False
+    headers = {
+        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+        "X-Request-Id": request_id
+    }
+    response = await client.post(endpoint, json=req_data, headers=headers)
+    response.raise_for_status()
+
+    return response
+
+
+async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
+                                  req_data: dict, remote_block_ids: list[int],
+                                  remote_engine_id: str, request_id: str):
+    """
+    Asynchronously stream the response from a service using a persistent client.
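+
+    Setting do_remote_prefill and forwarding the prefill instance's
+    remote_block_ids / remote_engine_id lets the decode instance pull the
+    KV cache produced by the remote prefill instead of recomputing it.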
+    """
+    headers = {
+        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+        "X-Request-Id": request_id
+    }
+    req_data['do_remote_prefill'] = True
+    req_data["remote_block_ids"] = remote_block_ids
+    req_data['remote_engine_id'] = remote_engine_id
+    async with client.stream("POST", endpoint, json=req_data,
+                             headers=headers) as response:
+        response.raise_for_status()
+        async for chunk in response.aiter_bytes():
+            yield chunk
+
+
+@app.post("/v1/completions")
+async def handle_completions(request: Request):
+    try:
+        req_data = await request.json()
+
+        request_id = str(uuid.uuid4())
+
+        # Send request to prefill service
+        response = await send_request_to_service(app.state.prefill_client,
+                                                 "/completions", req_data,
+                                                 request_id)
+
+        # Extract the needed fields
+        response_json = response.json()
+        remote_block_ids = response_json.get('remote_block_ids', [])
+        remote_engine_id = response_json.get('remote_engine_id', '')
+
+        # Add these to the request data for the decoder
+        req_data['remote_block_ids'] = remote_block_ids
+        req_data['remote_engine_id'] = remote_engine_id
+
+        # Stream response from decode service
+        async def generate_stream():
+            async for chunk in stream_service_response(
+                    app.state.decode_client,
+                    "/completions",
+                    req_data,
+                    remote_block_ids=remote_block_ids,
+                    remote_engine_id=remote_engine_id,
+                    request_id=request_id):
+                yield chunk
+
+        return StreamingResponse(generate_stream(),
+                                 media_type="application/json")
+
+    except Exception as e:
+        import sys
+        import traceback
+        exc_info = sys.exc_info()
+        print("Error occurred in disagg prefill proxy server"
+              " - completions endpoint")
+        print(e)
+        print("".join(traceback.format_exception(*exc_info)))
+        raise
+
+
+if __name__ == '__main__':
+    # Parsed args are stored globally so lifespan() can read them.
+    global_args = parse_args()
+
+    import uvicorn
+    uvicorn.run(app, host=global_args.host, port=global_args.port)
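+
+# Example invocation (ports matching run_accuracy_test.sh):
+#   python3 toy_proxy_server.py --port 8192 \
+#       --prefiller-port 8100 --decoder-port 8200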