[Feature] vLLM CLI for serving and querying OpenAI compatible server #5090
Changes from 11 commits
```diff
@@ -1,3 +1,4 @@
+import argparse
 import asyncio
 import importlib
 import inspect
@@ -7,7 +8,7 @@
 import fastapi
 import uvicorn
-from fastapi import Request
+from fastapi import APIRouter, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -24,8 +25,11 @@
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext

 TIMEOUT_KEEP_ALIVE = 5  # seconds

+engine: AsyncLLMEngine = None
+engine_args: AsyncEngineArgs = None
 openai_serving_chat: OpenAIServingChat = None
 openai_serving_completion: OpenAIServingCompletion = None
 logger = init_logger(__name__)
@@ -45,45 +49,33 @@ async def _force_log():
     yield


-app = fastapi.FastAPI(lifespan=lifespan)
-
-
-def parse_args():
-    parser = make_arg_parser()
-    return parser.parse_args()
+router = APIRouter()


 # Add prometheus asgi middleware to route /metrics requests
 metrics_app = make_asgi_app()
-app.mount("/metrics", metrics_app)
+router.mount("/metrics", metrics_app)


-@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(_, exc):
-    err = openai_serving_chat.create_error_response(message=str(exc))
-    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
-
-
-@app.get("/health")
+@router.get("/health")
 async def health() -> Response:
     """Health check."""
     await openai_serving_chat.engine.check_health()
     return Response(status_code=200)


-@app.get("/v1/models")
+@router.get("/v1/models")
 async def show_available_models():
     models = await openai_serving_chat.show_available_models()
     return JSONResponse(content=models.model_dump())


-@app.get("/version")
+@router.get("/version")
 async def show_version():
     ver = {"version": vllm.__version__}
     return JSONResponse(content=ver)


-@app.post("/v1/chat/completions")
+@router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     generator = await openai_serving_chat.create_chat_completion(
@@ -98,7 +90,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
         return JSONResponse(content=generator.model_dump())


-@app.post("/v1/completions")
+@router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
@@ -112,8 +104,10 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         return JSONResponse(content=generator.model_dump())


-if __name__ == "__main__":
-    args = parse_args()
+def build_app(args):
+    app = fastapi.FastAPI(lifespan=lifespan)
+    app.include_router(router)
+    app.root_path = args.root_path

     app.add_middleware(
         CORSMiddleware,
@@ -123,6 +117,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         allow_headers=args.allowed_headers,
     )

+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(_, exc):
+        err = openai_serving_chat.create_error_response(message=str(exc))
+        return JSONResponse(err.model_dump(),
+                            status_code=HTTPStatus.BAD_REQUEST)
+
     if token := os.environ.get("VLLM_API_KEY") or args.api_key:

         @app.middleware("http")
@@ -146,13 +146,21 @@ async def authentication(request: Request, call_next):
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

+    return app
+
+
+def run_server(args):
+    app = build_app(args)
+
     logger.info(f"vLLM API server version {vllm.__version__}")
     logger.info(f"args: {args}")

     if args.served_model_name is not None:
         served_model_names = args.served_model_name
     else:
-        served_model_names = [args.model]
+        served_model_names = [args.model_tag]

+    global engine_args, engine, openai_serving_chat, openai_serving_completion
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(
         engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
@@ -163,7 +171,6 @@ async def authentication(request: Request, call_next):
     openai_serving_completion = OpenAIServingCompletion(
         engine, served_model_names, args.lora_modules)

-    app.root_path = args.root_path
     uvicorn.run(app,
                 host=args.host,
                 port=args.port,
@@ -173,3 +180,13 @@ async def authentication(request: Request, call_next):
                 ssl_certfile=args.ssl_certfile,
                 ssl_ca_certs=args.ssl_ca_certs,
                 ssl_cert_reqs=args.ssl_cert_reqs)
+
+
+if __name__ == "__main__":
+    # NOTE(simon):
+    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
```
Collaborator
Is this note true? They seem to be different? (Also, in this case, should we have a common main method to share?)

Collaborator
In sync in their usage of …
```diff
+    parser = argparse.ArgumentParser(
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args()
+    run_server(args)
```
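For readers following the NOTE and the thread above, here is a rough sketch of how a vllm/scripts.py CLI entry point can stay in sync with this file by reusing `make_arg_parser()` and `run_server()`. The import paths, the `serve` subcommand name, and the `model_tag` positional are assumptions inferred from this diff and the PR title, not a copy of the PR's actual scripts.py.

```python
# Hypothetical sketch of a `vllm` console entry point (cf. vllm/scripts.py).
# It shares make_arg_parser() and run_server() with the API server module so
# the two entry points cannot drift apart. Import paths are assumptions.
import argparse

from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser


def main():
    parser = argparse.ArgumentParser(description="vLLM CLI")
    subparsers = parser.add_subparsers(dest="command", required=True)

    serve_parser = subparsers.add_parser(
        "serve", help="Start the vLLM OpenAI-compatible API server")
    serve_parser.add_argument("model_tag", type=str,
                              help="The model tag to serve")
    # Reuse the exact flags of the standalone API server entry point.
    serve_parser = make_arg_parser(serve_parser)
    serve_parser.set_defaults(dispatch_function=run_server)

    args = parser.parse_args()
    # Each subcommand dispatches to the handler registered via set_defaults.
    args.dispatch_function(args)


if __name__ == "__main__":
    main()
```

Installed as a console script, something like `vllm serve <model_tag> --port 8000` would then go through the same `run_server()` path as running the API server module directly.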
Is it possible to make `engine` an optional arg to this function? This can help external applications reuse the LLM engine and attach other API interfaces (like gRPC) to the same LLM engine. To be used with the other suggestion of changing line 204 to:
+1, this would be useful.
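To make the request concrete, here is a minimal sketch of how `run_server()` in this file could accept an externally built engine. It assumes the function keeps the shape shown in the diff, so `build_app()`, the `engine`/`engine_args` globals, and the existing imports all come from the same module; the optional `llm_engine` parameter is hypothetical and not part of this PR.

```python
from typing import Optional


def run_server(args, llm_engine: Optional[AsyncLLMEngine] = None):
    # AsyncLLMEngine, AsyncEngineArgs, UsageContext, build_app and the
    # engine/engine_args globals are the ones defined in this file (see diff).
    app = build_app(args)

    global engine, engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)
    if llm_engine is not None:
        # Reuse an engine owned by the embedding application so other
        # frontends (e.g. a gRPC server) can share it with the OpenAI HTTP API.
        engine = llm_engine
    else:
        engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)

    # ... the remainder (OpenAIServing* setup and uvicorn.run) stays exactly
    # as in the diff above.
```

An embedding application could then call `run_server(args, llm_engine=my_engine)` with an engine it already uses elsewhere.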