# Add Accuracy Test #13
(forked from vllm-project/vllm)
**Merged** -- tlrmchlsmth merged 5 commits into tlrmchlsmth:nixl_integration from robertgshaw2-redhat:add-tests on Apr 30, 2025.
**New file (+52 lines) -- accuracy test launch script:**

```bash
#!/bin/bash

set -xe

# Model to run.
MODEL_NAME=Qwen/Qwen3-0.6B

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT

# Cleanup function
cleanup() {
    echo "Caught Ctrl+C, cleaning up..."
    # Cleanup commands
    pgrep python | xargs kill -9
    pkill -f python
    echo "Cleanup complete. Exiting."
    exit 0
}

# Waits for vLLM to start.
wait_for_server() {
    local port=$1
    timeout 1200 bash -c "
        until curl -s localhost:${port}/v1/completions > /dev/null; do
            sleep 1
        done" && return 0 || return 1
}

# Prefill instance.
CUDA_VISIBLE_DEVICES=0 NIXL_ROLE="SENDER" vllm serve $MODEL_NAME \
    --port 8100 \
    --enforce-eager \
    --disable-log-requests \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' &

# Decode instance.
CUDA_VISIBLE_DEVICES=1 NIXL_ROLE="RECVER" vllm serve $MODEL_NAME \
    --port 8200 \
    --enforce-eager \
    --disable-log-requests \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' &

# Wait until prefill and decode instances are ready.
wait_for_server 8100
wait_for_server 8200

# Proxy server.
python toy_proxy_server.py --port 8192 &

# Run lm eval.
python3 -m pytest -s -x test_accuracy.py
```
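Once the script is up, the proxy on port 8192 exposes an OpenAI-style completions endpoint that fans each request out to the prefill and decode instances. A minimal smoke test against the running stack might look like the following (a sketch only, not part of this PR; the port and model name are the values hard-coded in the script above):

```python
# Minimal smoke test against the proxy started by the script above.
# Assumes the script's defaults: proxy on localhost:8192, model Qwen/Qwen3-0.6B.
import httpx

resp = httpx.post(
    "http://localhost:8192/v1/completions",
    json={
        "model": "Qwen/Qwen3-0.6B",
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "temperature": 0.0,
    },
    timeout=None,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```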
**New file (+28 lines) -- `test_accuracy.py`:**

```python
# SPDX-License-Identifier: Apache-2.0
import lm_eval

MODEL_NAME = "Qwen/Qwen3-0.6B"
NUM_CONCURRENT = 100
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.41


def test_accuracy():
    """Run the end to end accuracy test."""

    model_args = (f"model={MODEL_NAME},"
                  f"base_url=http://localhost:8192/v1/completions,"
                  f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")

    results = lm_eval.simple_evaluate(
        model="local-completions",
        model_args=model_args,
        tasks=TASK,
    )

    measured_value = results["results"][TASK][FILTER]
    assert (measured_value - RTOL < EXPECTED_VALUE
            and measured_value + RTOL > EXPECTED_VALUE
            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
```
**New file (+145 lines) -- `toy_proxy_server.py`:**

> **Author comment** (on the license header): note: this is needed so we can run in the CI (since the other one is in examples)

```python
# SPDX-License-Identifier: Apache-2.0

import argparse
import os
import uuid
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.
    """
    # Startup: Initialize clients
    prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1'
    decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1'

    app.state.prefill_client = httpx.AsyncClient(timeout=None,
                                                 base_url=prefiller_base_url)
    app.state.decode_client = httpx.AsyncClient(timeout=None,
                                                base_url=decoder_base_url)

    yield

    # Shutdown: Close clients
    await app.state.prefill_client.aclose()
    await app.state.decode_client.aclose()


# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--prefiller-host", type=str, default="localhost")
    parser.add_argument("--prefiller-port", type=int, default=8100)
    parser.add_argument("--decoder-host", type=str, default="localhost")
    parser.add_argument("--decoder-port", type=int, default=8200)
    args = parser.parse_args()
    return args


# Initialize variables to hold the persistent clients
app.state.prefill_client = None
app.state.decode_client = None


async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
                                  req_data: dict, request_id: str):
    """
    Send a request to a service using a persistent client.
    """
    req_data = req_data.copy()
    req_data['do_remote_decode'] = True
    req_data["stream"] = False
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    response = await client.post(endpoint, json=req_data, headers=headers)
    response.raise_for_status()

    return response


async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
                                  req_data: dict, remote_block_ids: list[int],
                                  remote_engine_id: str, request_id: str):
    """
    Asynchronously stream the response from a service using a persistent client.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    req_data['do_remote_prefill'] = True
    req_data["remote_block_ids"] = remote_block_ids
    req_data['remote_engine_id'] = remote_engine_id
    async with client.stream("POST", endpoint, json=req_data,
                             headers=headers) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk


@app.post("/v1/completions")
async def handle_completions(request: Request):
    try:
        req_data = await request.json()

        request_id = str(uuid.uuid4())

        # Send request to prefill service
        response = await send_request_to_service(app.state.prefill_client,
                                                 "/completions", req_data,
                                                 request_id)

        # Extract the needed fields
        response_json = response.json()
        remote_block_ids = response_json.get('remote_block_ids', [])
        remote_engine_id = response_json.get('remote_engine_id', '')

        # Add these to the request data for the decoder
        req_data['remote_block_ids'] = remote_block_ids
        req_data['remote_engine_id'] = remote_engine_id

        # Stream response from decode service
        async def generate_stream():
            async for chunk in stream_service_response(
                    app.state.decode_client,
                    "/completions",
                    req_data,
                    remote_block_ids=remote_block_ids,
                    remote_engine_id=remote_engine_id,
                    request_id=request_id):
                yield chunk

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")

    except Exception as e:
        import sys
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              " - completions endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


if __name__ == '__main__':
    global global_args
    global_args = parse_args()

    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)
```
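To make the handoff concrete, here is a sketch of how the proxy rewrites a single client request across the two hops. The field names come from the code above; the block ids and engine id are made-up values for illustration:

```python
# Hop 1: the proxy copies the client request and sends it to the prefill
# instance with remote decode enabled and streaming forced off.
prefill_request = {
    "model": "Qwen/Qwen3-0.6B",
    "prompt": "...",
    "do_remote_decode": True,
    "stream": False,
}

# The prefill response carries the handoff metadata, e.g.:
#   {"remote_block_ids": [0, 1, 2], "remote_engine_id": "engine-0", ...}

# Hop 2: the proxy forwards the original request to the decode instance,
# annotated so it can fetch the prefilled KV cache blocks via the
# NixlConnector, and streams the decode response back to the client.
decode_request = {
    "model": "Qwen/Qwen3-0.6B",
    "prompt": "...",
    "do_remote_prefill": True,
    "remote_block_ids": [0, 1, 2],   # hypothetical values
    "remote_engine_id": "engine-0",  # hypothetical value
}
```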
> **Review comment:** It looks like the `cleanup` function is never used -- should we remove it? It looks like `trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT` will do the job (but we could think about adding a `-9`).