-
Notifications
You must be signed in to change notification settings - Fork 5.1k
OAI Server Skeleton & Core Utility Endpoints #7179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a4fe5a1
e655e32
84dd577
23aa20b
ca6c963
6eb3425
affff1f
2dfeacf
4dedefe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,363 @@ | ||
| # Copyright 2023-2024 SGLang Team | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| """ | ||
| SGLang OpenAI-Compatible API Server. | ||
|
|
||
| This file implements OpenAI-compatible HTTP APIs for the inference engine via FastAPI. | ||
| """ | ||
|
|
||
| import argparse | ||
| import asyncio | ||
| import logging | ||
| import multiprocessing | ||
| import os | ||
| import threading | ||
| import time | ||
| from contextlib import asynccontextmanager | ||
| from typing import Callable, Dict, Optional | ||
|
|
||
| import numpy as np | ||
| import requests | ||
| import uvicorn | ||
| import uvloop | ||
| from fastapi import FastAPI, Request | ||
| from fastapi.middleware.cors import CORSMiddleware | ||
| from fastapi.responses import Response | ||
|
|
||
| from sglang.srt.disaggregation.utils import ( | ||
| FakeBootstrapHost, | ||
| register_disaggregation_server, | ||
| ) | ||
| from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses | ||
| from sglang.srt.managers.tokenizer_manager import TokenizerManager | ||
| from sglang.srt.metrics.func_timer import enable_func_timer | ||
| from sglang.srt.openai_api.protocol import ModelCard, ModelList | ||
| from sglang.srt.server_args import ServerArgs | ||
| from sglang.srt.utils import ( | ||
| add_prometheus_middleware, | ||
| delete_directory, | ||
| get_bool_env_var, | ||
| kill_process_tree, | ||
| set_uvicorn_logging_configs, | ||
| ) | ||
| from sglang.srt.warmup import execute_warmups | ||
| from sglang.utils import get_exception_traceback | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) | ||
|
|
||
|
|
||
| # Store global states | ||
| class AppState: | ||
| engine: Optional[Engine] = None | ||
| server_args: Optional[ServerArgs] = None | ||
| tokenizer_manager: Optional[TokenizerManager] = None | ||
| scheduler_info: Optional[Dict] = None | ||
|
|
||
|
|
||
| @asynccontextmanager | ||
| async def lifespan(app: FastAPI): | ||
| app.state.server_args.enable_metrics = True # By default, we enable metrics | ||
|
|
||
| server_args = app.state.server_args | ||
|
|
||
| # Initialize engine | ||
| logger.info(f"SGLang OpenAI server (PID: {os.getpid()}) is initializing...") | ||
|
|
||
| tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args) | ||
| app.state.tokenizer_manager = tokenizer_manager | ||
| app.state.scheduler_info = scheduler_info | ||
|
|
||
| if server_args.enable_metrics: | ||
| add_prometheus_middleware(app) | ||
| enable_func_timer() | ||
|
|
||
| # Initialize engine state attribute to None for now | ||
| app.state.engine = None | ||
|
|
||
| if server_args.warmups is not None: | ||
| await execute_warmups( | ||
| server_args.warmups.split(","), app.state.tokenizer_manager | ||
| ) | ||
| logger.info("Warmup ended") | ||
|
|
||
| warmup_thread = getattr(app, "warmup_thread", None) | ||
| if warmup_thread is not None: | ||
| warmup_thread.start() | ||
|
|
||
| yield | ||
|
|
||
| # Lifespan shutdown | ||
| if hasattr(app.state, "engine") and app.state.engine is not None: | ||
| logger.info("SGLang engine is shutting down.") | ||
| # Add engine cleanup logic here when implemented | ||
|
|
||
|
|
||
| # Fast API app with CORS enabled | ||
| app = FastAPI( | ||
| lifespan=lifespan, | ||
| # TODO: check where /openai.json is created or why we use this | ||
| openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json", | ||
| ) | ||
| app.add_middleware( | ||
| CORSMiddleware, | ||
| allow_origins=["*"], | ||
| allow_credentials=True, | ||
| allow_methods=["*"], | ||
| allow_headers=["*"], | ||
| ) | ||
|
|
||
|
|
||
| @app.api_route("/health", methods=["GET"]) | ||
| async def health() -> Response: | ||
| """Health check. Used for readiness and liveness probes.""" | ||
| # In the future, this could check engine health more deeply | ||
| # For now, if the server is up, it's healthy. | ||
| return Response(status_code=200) | ||
|
|
||
|
|
||
| @app.api_route("/v1/models", methods=["GET"]) | ||
| async def show_models(): | ||
| """Show available models. Currently, it returns the served model name. | ||
|
|
||
| This endpoint is compatible with the OpenAI API standard. | ||
| """ | ||
| served_model_names = [app.state.tokenizer_manager.served_model_name] | ||
| model_cards = [] | ||
| for served_model_name in served_model_names: | ||
| model_cards.append( | ||
| ModelCard( | ||
| id=served_model_name, | ||
| root=served_model_name, | ||
| max_model_len=app.state.tokenizer_manager.model_config.context_len, | ||
| ) | ||
| ) | ||
| return ModelList(data=model_cards) | ||
|
|
||
|
|
||
| @app.get("/get_model_info") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just curious if we are serving multiple models, which model info will it return?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently, it seems that a single instance of SGLang only supports serving one model at a time. Your suggestion is excellent. To ensure compatibility with the OpenAI API, I will likely add support for the |
||
| async def get_model_info(): | ||
| """Get the model information.""" | ||
| result = { | ||
| "model_path": app.state.tokenizer_manager.model_path, | ||
| "tokenizer_path": app.state.tokenizer_manager.server_args.tokenizer_path, | ||
| "is_generation": app.state.tokenizer_manager.is_generation, | ||
| } | ||
| return result | ||
|
|
||
|
|
||
| @app.post("/v1/completions") | ||
| async def openai_v1_completions(raw_request: Request): | ||
| pass | ||
|
|
||
|
|
||
| @app.post("/v1/chat/completions") | ||
| async def openai_v1_chat_completions(raw_request: Request): | ||
| pass | ||
|
|
||
|
|
||
| @app.post("/v1/embeddings") | ||
| async def openai_v1_embeddings(raw_request: Request): | ||
| pass | ||
|
|
||
|
|
||
| @app.post("/v1/score") | ||
| async def v1_score_request(raw_request: Request): | ||
| """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation.""" | ||
| pass | ||
|
|
||
|
|
||
| # Additional API endpoints will be implemented in separate serving_*.py modules | ||
| # and mounted as APIRouters in future PRs | ||
|
|
||
|
|
||
| def _wait_and_warmup( | ||
| server_args: ServerArgs, | ||
| pipe_finish_writer: Optional[multiprocessing.connection.Connection], | ||
| image_token_text: str, | ||
| launch_callback: Optional[Callable[[], None]] = None, | ||
| ): | ||
| return | ||
| # TODO: Please wait until the /generate implementation is complete, | ||
| # or confirm if modifications are needed before removing this. | ||
|
|
||
| headers = {} | ||
| url = server_args.url() | ||
| if server_args.api_key: | ||
| headers["Authorization"] = f"Bearer {server_args.api_key}" | ||
|
|
||
| # Wait until the server is launched | ||
| success = False | ||
| for _ in range(120): | ||
| time.sleep(1) | ||
| try: | ||
| res = requests.get(url + "/get_model_info", timeout=5, headers=headers) | ||
| assert res.status_code == 200, f"{res=}, {res.text=}" | ||
| success = True | ||
| break | ||
| except (AssertionError, requests.exceptions.RequestException): | ||
| last_traceback = get_exception_traceback() | ||
| pass | ||
|
|
||
| if not success: | ||
| if pipe_finish_writer is not None: | ||
| pipe_finish_writer.send(last_traceback) | ||
| logger.error(f"Initialization failed. warmup error: {last_traceback}") | ||
| kill_process_tree(os.getpid()) | ||
| return | ||
|
|
||
| model_info = res.json() | ||
|
|
||
| # Send a warmup request | ||
| request_name = "/generate" if model_info["is_generation"] else "/encode" | ||
yhyang201 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # TODO: Replace with OpenAI API | ||
| max_new_tokens = 8 if model_info["is_generation"] else 1 | ||
| json_data = { | ||
| "sampling_params": { | ||
| "temperature": 0, | ||
| "max_new_tokens": max_new_tokens, | ||
| }, | ||
| } | ||
| if server_args.skip_tokenizer_init: | ||
| json_data["input_ids"] = [[10, 11, 12] for _ in range(server_args.dp_size)] | ||
| # TODO Workaround the bug that embedding errors for list of size 1 | ||
| if server_args.dp_size == 1: | ||
| json_data["input_ids"] = json_data["input_ids"][0] | ||
| else: | ||
| json_data["text"] = ["The capital city of France is"] * server_args.dp_size | ||
| # TODO Workaround the bug that embedding errors for list of size 1 | ||
| if server_args.dp_size == 1: | ||
| json_data["text"] = json_data["text"][0] | ||
|
|
||
| # Debug dumping | ||
| if server_args.debug_tensor_dump_input_file: | ||
| json_data.pop("text", None) | ||
| json_data["input_ids"] = np.load( | ||
| server_args.debug_tensor_dump_input_file | ||
| ).tolist() | ||
| json_data["sampling_params"]["max_new_tokens"] = 0 | ||
|
|
||
| try: | ||
| if server_args.disaggregation_mode == "null": | ||
| res = requests.post( | ||
| url + request_name, | ||
| json=json_data, | ||
| headers=headers, | ||
| timeout=600, | ||
| ) | ||
| assert res.status_code == 200, f"{res}" | ||
| else: | ||
| logger.info(f"Start of prefill warmup ...") | ||
| json_data = { | ||
| "sampling_params": { | ||
| "temperature": 0.0, | ||
| "max_new_tokens": 8, | ||
| "ignore_eos": True, | ||
| }, | ||
| "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size, | ||
| # This is a hack to ensure fake transfer is enabled during prefill warmup | ||
| # ensure each dp rank has a unique bootstrap_room during prefill warmup | ||
| "bootstrap_room": [ | ||
| i * (2**63 // server_args.dp_size) + (i % server_args.tp_size) | ||
| for i in range(server_args.dp_size) | ||
| ], | ||
| "input_ids": [[0, 1, 2, 3]] * server_args.dp_size, | ||
| } | ||
| res = requests.post( | ||
| url + request_name, | ||
| json=json_data, | ||
| headers=headers, | ||
| timeout=1800, # because of deep gemm precache is very long if not precache. | ||
| ) | ||
| logger.info( | ||
| f"End of prefill warmup with status {res.status_code}, resp: {res.json()}" | ||
| ) | ||
|
|
||
| except Exception: | ||
| last_traceback = get_exception_traceback() | ||
| if pipe_finish_writer is not None: | ||
| pipe_finish_writer.send(last_traceback) | ||
| logger.error(f"Initialization failed. warmup error: {last_traceback}") | ||
| kill_process_tree(os.getpid()) | ||
| return | ||
|
|
||
| # Debug print | ||
| # logger.info(f"{res.json()=}") | ||
|
|
||
| logger.info("The server is fired up and ready to roll!") | ||
| if pipe_finish_writer is not None: | ||
| pipe_finish_writer.send("ready") | ||
|
|
||
| if server_args.delete_ckpt_after_loading: | ||
| delete_directory(server_args.model_path) | ||
|
|
||
| if server_args.debug_tensor_dump_input_file: | ||
| kill_process_tree(os.getpid()) | ||
|
|
||
| if server_args.pdlb_url is not None: | ||
| register_disaggregation_server( | ||
| server_args.disaggregation_mode, | ||
| server_args.port, | ||
| server_args.disaggregation_bootstrap_port, | ||
| server_args.pdlb_url, | ||
| ) | ||
|
|
||
| if launch_callback is not None: | ||
| launch_callback() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser(description="SGLang OpenAI-Compatible API Server") | ||
| # Add arguments from ServerArgs. This allows reuse of existing CLI definitions. | ||
| ServerArgs.add_cli_args(parser) | ||
| # Potentially add server-specific arguments here in the future if needed | ||
|
|
||
| args = parser.parse_args() | ||
| server_args = ServerArgs.from_cli_args(args) | ||
|
|
||
| # Store server_args in app.state for access in lifespan and endpoints | ||
| app.state.server_args = server_args | ||
|
|
||
| # Configure logging | ||
| logging.basicConfig( | ||
| level=server_args.log_level.upper(), | ||
| format="%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s", | ||
| ) | ||
|
|
||
| # Send a warmup request - we will create the thread launch it | ||
| # in the lifespan after all other warmups have fired. | ||
| warmup_thread = threading.Thread( | ||
| target=_wait_and_warmup, | ||
| args=( | ||
| server_args, | ||
| None, | ||
| None, # Never used | ||
| None, | ||
| ), | ||
| ) | ||
| app.warmup_thread = warmup_thread | ||
|
|
||
| try: | ||
| # Start the server | ||
| set_uvicorn_logging_configs() | ||
| uvicorn.run( | ||
| app, | ||
| host=server_args.host, | ||
| port=server_args.port, | ||
| log_level=server_args.log_level.lower(), | ||
| timeout_keep_alive=60, # Increased keep-alive for potentially long requests | ||
| loop="uvloop", # Use uvloop for better performance if available | ||
| ) | ||
| finally: | ||
| warmup_thread.join() | ||
Uh oh!
There was an error while loading. Please reload this page.