diff --git a/docs/reference/vllm.md b/docs/reference/vllm.md
index 063699a6d..02711c7c0 100644
--- a/docs/reference/vllm.md
+++ b/docs/reference/vllm.md
@@ -45,6 +45,52 @@ curl http://127.0.0.1:8000/generate \
     }'
 ```
 
+To start the vLLM OpenAI-compatible server, use:
+
+```bash
+python3 -m outlines.serve.openai_server
+```
+
+For example, to serve a local GPTQ-quantized model on `http://localhost:8000`:
+
+```bash
+python3 -m outlines.serve.openai_server --port 8000 --model ./Mistral-7B-Instruct-v0.2-GPTQ --dtype float16
+```
+
+To generate an integer through `v1/chat/completions` (or `v1/completions`), for example, pass the JSON Schema `{"type": "integer"}` in the `schema` field:
+
+```bash
+curl http://localhost:8000/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "./Mistral-7B-Instruct-v0.2-GPTQ",
+        "messages": [{"role": "user", "content": "Hello!"}],
+        "schema": {"type": "integer"},
+        "max_tokens": 5
+    }'
+```
+
+```bash
+curl http://localhost:8000/v1/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "./Mistral-7B-Instruct-v0.2-GPTQ",
+        "prompt": "Hello!",
+        "schema": {"type": "integer"},
+        "max_tokens": 5
+    }'
+```
+
+To use the server with Langchain, or with any other module that relies on `openai.OpenAI`, pass the schema through the `extra_body` argument:
+
+```python
+from langchain_openai import ChatOpenAI
+
+llm = ChatOpenAI(
+    openai_api_key="EMPTY",
+    openai_api_base="http://localhost:8000/v1",
+    model_name="./Mistral-7B-Instruct-v0.2-GPTQ",
+    extra_body={"schema": {"type": "integer"}},
+)
+```
+
 Instead of `curl`, you can also use the [requests][requests]{:target="_blank"} library from another python program.
 
 Please consult the [vLLM documentation][vllm]{:target="_blank"} for details on additional request parameters. You can also [read the code](https://github.com/outlines-dev/outlines/blob/main/outlines/serve/serve.py) in case you need to customize the solution to your needs.
diff --git a/outlines/serve/openai_server.py b/outlines/serve/openai_server.py
new file mode 100644
index 000000000..c7fdad988
--- /dev/null
+++ b/outlines/serve/openai_server.py
@@ -0,0 +1,836 @@
+# Copyright 2023 the vLLM developers
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
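+
+# Adapted from vLLM's OpenAI-compatible API server: the `/v1/completions` and
+# `/v1/chat/completions` endpoints below accept the usual OpenAI request
+# fields plus a JSON-schema or regex constraint, which is applied during
+# decoding through Outlines' `JSONLogitsProcessor` / `RegexLogitsProcessor`.
+#
+# Launch it with, for example (the model path is only illustrative):
+#
+#   python3 -m outlines.serve.openai_server \
+#       --model ./Mistral-7B-Instruct-v0.2-GPTQ --port 8000
+#
+# Clients built on the `openai` v1 Python SDK (including LangChain's
+# `ChatOpenAI`, see the docs) can pass the schema through `extra_body`;
+# a minimal sketch, assuming the server above is running locally:
+#
+#   from openai import OpenAI
+#
+#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+#   client.chat.completions.create(
+#       model="./Mistral-7B-Instruct-v0.2-GPTQ",
+#       messages=[{"role": "user", "content": "Hello!"}],
+#       extra_body={"schema": {"type": "integer"}},
+#       max_tokens=5,
+#   )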
+ +import argparse +import asyncio +import codecs +import json +import time +from http import HTTPStatus +from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union + +import fastapi +import uvicorn +import vllm.model_executor.layers.sampler as sampler +from aioprometheus import MetricsMiddleware +from aioprometheus.asgi.starlette import metrics +from fastapi import Request +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response, StreamingResponse +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.metrics import add_global_metrics_labels +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + DeltaMessage, + ErrorResponse, + LogProbs, + ModelCard, + ModelList, + ModelPermission, + UsageInfo, +) +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.utils import random_uuid + +from outlines.serve.vllm import ( + JSONLogitsProcessor, + RegexLogitsProcessor, + _patched_apply_logits_processors, +) + +TIMEOUT_KEEP_ALIVE = 5 # seconds + +sampler._apply_logits_processors = _patched_apply_logits_processors + +logger = init_logger(__name__) +served_model = None +app = fastapi.FastAPI() +engine = None +response_role = None + + +def parse_args(): + parser = argparse.ArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server." + ) + parser.add_argument("--host", type=str, default=None, help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" + ) + parser.add_argument( + "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" + ) + parser.add_argument( + "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" + ) + parser.add_argument( + "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" + ) + parser.add_argument( + "--served-model-name", + type=str, + default=None, + help="The model name used in the API. 
If not " + "specified, the model name will be the same as " + "the huggingface name.", + ) + parser.add_argument( + "--chat-template", + type=str, + default=None, + help="The file path to the chat template, " + "or the template in single-line form " + "for the specified model", + ) + parser.add_argument( + "--response-role", + type=str, + default="assistant", + help="The role name to return if " "`request.add_generation_prompt=true`.", + ) + parser.add_argument( + "--ssl-keyfile", + type=str, + default=None, + help="The file path to the SSL key file", + ) + parser.add_argument( + "--ssl-certfile", + type=str, + default=None, + help="The file path to the SSL cert file", + ) + + parser = AsyncEngineArgs.add_cli_args(parser) + return parser.parse_args() + + +app.add_middleware(MetricsMiddleware) # Trace HTTP server metrics +app.add_route("/metrics", metrics) # Exposes HTTP metrics + + +def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse: + return JSONResponse( + ErrorResponse(message=message, type="invalid_request_error").dict(), + status_code=status_code.value, + ) + + +def load_chat_template(args, tokenizer): + if args.chat_template is not None: + try: + with open(args.chat_template, "r") as f: + chat_template = f.read() + except OSError: + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + chat_template = codecs.decode(args.chat_template, "unicode_escape") + + tokenizer.chat_template = chat_template + logger.info(f"Using supplied chat template:\n{tokenizer.chat_template}") + elif tokenizer.chat_template is not None: + logger.info(f"Using default chat template:\n{tokenizer.chat_template}") + else: + logger.warning("No chat template provided. Chat API will not work.") + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(_, exc): + return create_error_response(HTTPStatus.BAD_REQUEST, str(exc)) + + +async def check_model(request) -> Optional[JSONResponse]: + if request.model == served_model: + return + ret = create_error_response( + HTTPStatus.NOT_FOUND, + f"The model `{request.model}` does not exist.", + ) + return ret + + +async def check_length( + request: Union[ChatCompletionRequest, CompletionRequest], + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None, +) -> Tuple[List[int], Optional[JSONResponse]]: + assert not (prompt is None and prompt_ids is None) and not ( + prompt is not None and prompt_ids is not None + ), "Either prompt or prompt_ids should be provided." + input_ids = prompt_ids if prompt_ids is not None else tokenizer(prompt).input_ids + token_num = len(input_ids) + + if request.max_tokens is None: + request.max_tokens = max_model_len - token_num + if token_num + request.max_tokens > max_model_len: + return input_ids, create_error_response( + HTTPStatus.BAD_REQUEST, + f"This model's maximum context length is {max_model_len} tokens. " + f"However, you requested {request.max_tokens + token_num} tokens " + f"({token_num} in the messages, " + f"{request.max_tokens} in the completion). " + f"Please reduce the length of the messages or completion.", + ) + else: + return input_ids, None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.get("/v1/models") +async def show_available_models(): + """Show available models. 
Right now we only have one model.""" + model_cards = [ + ModelCard(id=served_model, root=served_model, permission=[ModelPermission()]) + ] + return ModelList(data=model_cards) + + +def create_logprobs( + token_ids: List[int], + top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None, + num_output_top_logprobs: Optional[int] = None, + initial_text_offset: int = 0, +) -> LogProbs: + """Create OpenAI-style logprobs.""" + logprobs = LogProbs() + last_token_len = 0 + if num_output_top_logprobs: + logprobs.top_logprobs = [] + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is not None: + token_logprob = step_top_logprobs[token_id] + else: + token_logprob = None + token = tokenizer.convert_ids_to_tokens(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(token_logprob) + if len(logprobs.text_offset) == 0: + logprobs.text_offset.append(initial_text_offset) + else: + logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len) + last_token_len = len(token) + + if num_output_top_logprobs: + logprobs.top_logprobs.append( + { + tokenizer.convert_ids_to_tokens(i): p + for i, p in step_top_logprobs.items() + } + if step_top_logprobs + else None + ) + return logprobs + + +@app.post("/v1/chat/completions") +async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/chat/create + for the API specification. This API mimics the OpenAI ChatCompletion API. + + NOTE: Currently we do not support the following features: + - function_call (Users should implement this by themselves) + - logit_bias (to be supported by vLLM engine) + """ + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + + if request.logit_bias is not None and len(request.logit_bias) > 0: + # TODO: support logit_bias in vLLM engine. 
+ return create_error_response( + HTTPStatus.BAD_REQUEST, "logit_bias is not currently supported" + ) + + try: + prompt = tokenizer.apply_chat_template( + conversation=request.messages, + tokenize=False, + add_generation_prompt=request.add_generation_prompt, + ) + except Exception as e: + logger.error(f"Error in applying chat template from request: {str(e)}") + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + + token_ids, error_check_ret = await check_length(request, prompt=prompt) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + created_time = int(time.monotonic()) + chunk_object_type = "chat.completion.chunk" + + schema = request.json_schema + regex = request.regex_string + if schema is not None: + logits_processors = [JSONLogitsProcessor(schema, engine.engine)] + elif regex is not None: + logits_processors = [RegexLogitsProcessor(regex, engine.engine)] + else: + logits_processors = [] + + try: + spaces_between_special_tokens = request.spaces_between_special_tokens + sampling_params = SamplingParams( + n=request.n, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + repetition_penalty=request.repetition_penalty, + temperature=request.temperature, + top_p=request.top_p, + min_p=request.min_p, + stop=request.stop, + stop_token_ids=request.stop_token_ids, + max_tokens=request.max_tokens, + best_of=request.best_of, + top_k=request.top_k, + ignore_eos=request.ignore_eos, + use_beam_search=request.use_beam_search, + skip_special_tokens=request.skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + logits_processors=logits_processors, + ) + except ValueError as e: + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + + result_generator = engine.generate(prompt, sampling_params, request_id, token_ids) + + def get_role() -> str: + if request.add_generation_prompt: + return response_role + else: + return request.messages[-1]["role"] + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + # Send first response for each request.n (index) with the role + role = get_role() + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, delta=DeltaMessage(role=role), finish_reason=None + ) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name, + ) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the last message + if request.echo: + last_msg_content = "" + if ( + request.messages + and isinstance(request.messages, list) + and request.messages[-1].get("content") + and request.messages[-1].get("role") == role + ): + last_msg_content = request.messages[-1]["content"] + if last_msg_content: + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=last_msg_content), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name, + ) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + + # Send response for each token for each request.n (index) + previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n + finish_reason_sent = [False] * request.n + async for res in 
result_generator: + res: RequestOutput + for output in res.outputs: + i = output.index + + if finish_reason_sent[i]: + continue + + if output.finish_reason is None: + # Send token-by-token response for each request.n + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + else: + # Send the finish response for each request.n only once + prompt_tokens = len(res.prompt_token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=previous_num_tokens[i], + total_tokens=prompt_tokens + previous_num_tokens[i], + ) + choice_data = ChatCompletionResponseStreamChoice( + index=i, delta=[], finish_reason=output.finish_reason + ) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name, + ) + if final_usage is not None: + chunk.usage = final_usage + data = chunk.json( + exclude_unset=True, exclude_none=True, ensure_ascii=False + ) + yield f"data: {data}\n\n" + finish_reason_sent[i] = True + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def completion_full_generator(): + final_res: RequestOutput = None + async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await engine.abort(request_id) + return create_error_response( + HTTPStatus.BAD_REQUEST, "Client disconnected" + ) + final_res = res + assert final_res is not None + + choices = [] + role = get_role() + for output in final_res.outputs: + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=ChatMessage(role=role, content=output.text), + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + if request.echo: + last_msg_content = "" + if ( + request.messages + and isinstance(request.messages, list) + and request.messages[-1].get("content") + and request.messages[-1].get("role") == role + ): + last_msg_content = request.messages[-1]["content"] + + for choice in choices: + full_message = last_msg_content + choice.message.content + choice.message.content = full_message + + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs + ) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + return response + + # Streaming response + return StreamingResponse( + completion_stream_generator(), media_type="text/event-stream" + ) + else: + return await completion_full_generator() + + +@app.post("/v1/completions") +async def create_completion(request: CompletionRequest, raw_request: Request): + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/completions/create + for the API specification. This API mimics the OpenAI Completion API. 
+ + NOTE: Currently we do not support the following features: + - suffix (the language models we currently support do not support + suffix) + - logit_bias (to be supported by vLLM engine) + """ + + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + + # OpenAI API supports echoing the prompt when max_tokens is 0. + echo_without_generation = request.echo and request.max_tokens == 0 + + if request.suffix is not None: + # The language models we currently support do not support suffix. + return create_error_response( + HTTPStatus.BAD_REQUEST, "suffix is not currently supported" + ) + + if request.logit_bias is not None and len(request.logit_bias) > 0: + # TODO: support logit_bias in vLLM engine. + return create_error_response( + HTTPStatus.BAD_REQUEST, "logit_bias is not currently supported" + ) + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + + use_token_ids = False + if isinstance(request.prompt, list): + if len(request.prompt) == 0: + return create_error_response( + HTTPStatus.BAD_REQUEST, "please provide at least one prompt" + ) + first_element = request.prompt[0] + if isinstance(first_element, int): + use_token_ids = True + prompt = request.prompt + elif isinstance(first_element, (str, list)): + # TODO: handles multiple prompt case in list[list[int]] + if len(request.prompt) > 1: + return create_error_response( + HTTPStatus.BAD_REQUEST, + "multiple prompts in a batch is not currently supported", + ) + use_token_ids = not isinstance(first_element, str) + prompt = request.prompt[0] + else: + prompt = request.prompt + + if use_token_ids: + _, error_check_ret = await check_length(request, prompt_ids=prompt) + else: + token_ids, error_check_ret = await check_length(request, prompt=prompt) + if error_check_ret is not None: + return error_check_ret + + created_time = int(time.monotonic()) + + schema = request.json_schema + regex = request.regex_string + if schema is not None: + logits_processors = [JSONLogitsProcessor(schema, engine.engine)] + elif regex is not None: + logits_processors = [RegexLogitsProcessor(regex, engine.engine)] + else: + logits_processors = [] + + try: + spaces_between_special_tokens = request.spaces_between_special_tokens + sampling_params = SamplingParams( + n=request.n, + best_of=request.best_of, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + repetition_penalty=request.repetition_penalty, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + min_p=request.min_p, + stop=request.stop, + stop_token_ids=request.stop_token_ids, + ignore_eos=request.ignore_eos, + max_tokens=request.max_tokens if not echo_without_generation else 1, + logprobs=request.logprobs, + use_beam_search=request.use_beam_search, + prompt_logprobs=request.logprobs if request.echo else None, + skip_special_tokens=request.skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + logits_processors=logits_processors, + ) + except ValueError as e: + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + + if use_token_ids: + result_generator = engine.generate( + None, sampling_params, request_id, prompt_token_ids=prompt + ) + else: + result_generator = engine.generate( + prompt, sampling_params, request_id, token_ids + ) + + # Similar to the OpenAI API, when n != best_of, we do not stream the + # results. In addition, we do not stream the results when use beam search. 
+ stream = ( + request.stream + and (request.best_of is None or request.n == request.best_of) + and not request.use_beam_search + ) + def create_stream_response_json( + index: int, + text: str, + logprobs: Optional[LogProbs] = None, + finish_reason: Optional[str] = None, + usage: Optional[UsageInfo] = None, + ) -> str: + choice_data = CompletionResponseStreamChoice( + index=index, + text=text, + logprobs=logprobs, + finish_reason=finish_reason, + ) + response = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[choice_data], + ) + if usage is not None: + response.usage = usage + response_json = response.json(exclude_unset=True, ensure_ascii=False) + + return response_json + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n + has_echoed = [False] * request.n + async for res in result_generator: + res: RequestOutput + for output in res.outputs: + i = output.index + delta_text = output.text[len(previous_texts[i]) :] + token_ids = output.token_ids[previous_num_tokens[i] :] + if request.logprobs is not None: + top_logprobs = output.logprobs[previous_num_tokens[i] :] + else: + top_logprobs = None + offsets = len(previous_texts[i]) + if request.echo and not has_echoed[i]: + if not echo_without_generation: + delta_text = res.prompt + delta_text + token_ids = res.prompt_token_ids + token_ids + if top_logprobs: + top_logprobs = res.prompt_logprobs + top_logprobs + else: # only just return the prompt + delta_text = res.prompt + token_ids = res.prompt_token_ids + if top_logprobs: + top_logprobs = res.prompt_logprobs + has_echoed[i] = True + if request.logprobs is not None: + logprobs = create_logprobs( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + initial_text_offset=offsets, + ) + else: + logprobs = None + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + finish_reason = output.finish_reason + response_json = create_stream_response_json( + index=i, + text=delta_text, + logprobs=logprobs, + finish_reason=finish_reason, + ) + yield f"data: {response_json}\n\n" + if output.finish_reason is not None: + logprobs = LogProbs() if request.logprobs is not None else None + prompt_tokens = len(res.prompt_token_ids) + completion_tokens = len(output.token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + response_json = create_stream_response_json( + index=i, + text="", + logprobs=logprobs, + finish_reason=output.finish_reason, + usage=final_usage, + ) + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + # Streaming response + if stream: + return StreamingResponse( + completion_stream_generator(), media_type="text/event-stream" + ) + + # Non-streaming response + final_res: RequestOutput = None + async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await engine.abort(request_id) + return create_error_response(HTTPStatus.BAD_REQUEST, "Client disconnected") + final_res = res + assert final_res is not None + choices = [] + prompt_token_ids = final_res.prompt_token_ids + prompt_logprobs = final_res.prompt_logprobs + prompt_text = final_res.prompt + for output in final_res.outputs: + if request.logprobs is not None: + if not echo_without_generation: + token_ids = output.token_ids + top_logprobs = output.logprobs + if request.echo: + token_ids = prompt_token_ids + token_ids + top_logprobs = prompt_logprobs + top_logprobs + else: + token_ids = prompt_token_ids + top_logprobs = prompt_logprobs + logprobs = create_logprobs( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + ) + else: + logprobs = None + if not echo_without_generation: + output_text = output.text + if request.echo: + output_text = prompt_text + output_text + else: + output_text = prompt_text + choice_data = CompletionResponseChoice( + index=output.index, + text=output_text, + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + response = CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + if request.stream: + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. + response_json = response.json(ensure_ascii=False) + + async def fake_stream_generator() -> AsyncGenerator[str, None]: + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse( + fake_stream_generator(), media_type="text/event-stream" + ) + + return response + + +if __name__ == "__main__": + args = parse_args() + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + logger.info(f"args: {args}") + + if args.served_model_name is not None: + served_model = args.served_model_name + else: + served_model = args.model + + response_role = args.response_role + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args(engine_args) + engine_model_config = asyncio.run(engine.get_model_config()) + max_model_len = engine_model_config.max_model_len + + # A separate tokenizer to map token IDs to strings. + tokenizer = get_tokenizer( + engine_model_config.tokenizer, + tokenizer_mode=engine_model_config.tokenizer_mode, + trust_remote_code=engine_model_config.trust_remote_code, + ) + load_chat_template(args, tokenizer) + + # Register labels for metrics + add_global_metrics_labels(model_name=engine_args.model) + + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + )