Skip to content
1 change: 1 addition & 0 deletions vllm/entrypoints/fast_sync_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
self.result_queue = result_queue
self.finish = False
self.need_restart = False
self.llm_engine: LLMEngine

def _add_request(
self,
Expand Down
225 changes: 214 additions & 11 deletions vllm/entrypoints/sync_openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,33 @@
import threading
import time
from contextlib import asynccontextmanager
from typing import Dict
from http import HTTPStatus
from typing import Dict, Iterable, List, Union, cast

import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import Mount
from openai.types.chat import ChatCompletionContentPartTextParam
from prometheus_client import make_asgi_app

import vllm
from vllm import FastSyncLLM as LLM
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.sync_openai.protocol import (CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
UsageInfo)
from vllm.entrypoints.openai.protocol import (
ChatCompletionContentPartParam, ChatCompletionMessageParam,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, CompletionRequest,
CompletionResponse, CompletionResponseChoice, DeltaMessage, ErrorResponse,
ModelCard, ModelList, ModelPermission, UsageInfo)
from vllm.entrypoints.openai.serving_chat import (ChatMessageParseResult,
ConversationMessage)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid

mp = multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)
Expand All @@ -41,14 +50,19 @@ class BackgroundRunner:

def __init__(self):
self.value = 0
self.engine_args = None
self.engine_args: EngineArgs
self.input_queue: multiprocessing.Queue = mp.Queue()
self.result_queue: multiprocessing.Queue = mp.Queue()
self.result_queues: Dict[str, asyncio.Queue] = {}
self.t: threading.Thread = threading.Thread(target=self.thread_proc)
self.loop = None
self.llm: LLM
self.proc: multiprocessing.Process
self.tokenizer = None
self.response_role: str

def set_response_role(self, role):
self.response_role = role

def set_engine_args(self, engine_args):
self.engine_args = engine_args
Expand All @@ -75,6 +89,7 @@ async def run_main(self):
input_queue=self.input_queue,
result_queue=self.result_queue,
)

self.loop = asyncio.get_event_loop()
self.proc = mp.Process(target=self.llm.run_engine)
self.t.start()
Expand Down Expand Up @@ -103,6 +118,15 @@ async def lifespan(app: FastAPI):
asyncio.create_task(runner.run_main())
await runner.result_queues["Ready"].get()
del runner.result_queues["Ready"]

tokenizer = get_tokenizer(
engine_args.tokenizer,
tokenizer_mode=engine_args.tokenizer_mode,
tokenizer_revision=engine_args.tokenizer_revision,
trust_remote_code=engine_args.trust_remote_code,
truncation_side="left")
runner.tokenizer = tokenizer

yield


Expand All @@ -115,6 +139,33 @@ async def lifespan(app: FastAPI):
app.routes.append(route)


@app.get("/v1/models")
async def show_available_models():
models = [
ModelCard(id=runner.engine_args.model,
root=runner.engine_args.model,
permission=[ModelPermission()])
]
model_list = ModelList(data=models)
return JSONResponse(content=model_list.model_dump())


@app.get("/version")
async def show_version():
ver = {"version": vllm.__version__}
return JSONResponse(content=ver)


async def _check_model(request: Union[CompletionRequest,
ChatCompletionRequest]):
model = request.model
if model != runner.engine_args.model:
return ErrorResponse(message=f"The model {model} does not exist.",
type="NotFoundError",
code=HTTPStatus.NOT_FOUND)
return None


async def completion_generator(model, result_queue, choices, created_time,
ids):
completed = 0
Expand All @@ -139,8 +190,9 @@ async def completion_generator(model, result_queue, choices, created_time,
res.usage = UsageInfo()
res.usage.completion_tokens = stats.get("tokens", 0)
res.usage.prompt_tokens = stats.get("prompt", 0)
res.usage.total_tokens = (res.usage.completion_tokens +
res.usage.prompt_tokens)
res.usage.total_tokens = (
res.usage.completion_tokens + # type: ignore
res.usage.prompt_tokens)
res.choices[0].finish_reason = stats["finish_reason"]
res.choices[0].stop_reason = stats["stop_reason"]
completed += 1
Expand All @@ -158,6 +210,10 @@ async def completion_generator(model, result_queue, choices, created_time,

@app.post("/v1/completions")
async def completions(request: CompletionRequest, raw_request: Request):
error_check_ret = await _check_model(request)
if error_check_ret is not None:
return JSONResponse(content=error_check_ret.model_dump(),
status_code=error_check_ret.code)
sampling_params = request.to_sampling_params()
ids, result_queue = await runner.add_request(request.prompt,
sampling_params)
Expand All @@ -179,8 +235,7 @@ async def completions(request: CompletionRequest, raw_request: Request):
created_time = int(time.time())
return StreamingResponse(content=completion_generator(
request.model, result_queue, choices, created_time, ids),
media_type="text/event-stream",
headers={"Access-Control-Allow-Origin": "*"})
media_type="text/event-stream")
while True:
request_id, token, stats = await result_queue.get()
choice_idx = choices[request_id]
Expand All @@ -200,6 +255,153 @@ async def completions(request: CompletionRequest, raw_request: Request):
return res


def parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
) -> ChatMessageParseResult:
texts: List[str] = []

for _, part in enumerate(parts):
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]

texts.append(text)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")

messages = [ConversationMessage(role=role, content="\n".join(texts))]

return ChatMessageParseResult(messages=messages)


def parse_chat_message_content(
message: ChatCompletionMessageParam, ) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")

if content is None:
return ChatMessageParseResult(messages=[])
if isinstance(content, str):
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages)

return parse_chat_message_content_parts(role, content)


async def chat_completion_generator(model, result_queue, created_time, id):
try:
first_token = ChatCompletionStreamResponse(
id=id,
created=created_time,
model=model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role=runner.response_role),
logprobs=None,
finish_reason=None,
stop_reason=None)
],
usage=None)
response_json = first_token.model_dump_json(exclude_unset=True)
yield f"data: {response_json}\n\n"

while True:
request_id, token, stats = await result_queue.get()
assert request_id == id

res = ChatCompletionStreamResponse(
id=request_id,
created=created_time,
model=model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=token),
logprobs=None,
finish_reason=None,
stop_reason=None)
],
usage=None)
if stats is not None:
res.usage = UsageInfo()
res.usage.completion_tokens = stats.get("tokens", 0)
res.usage.prompt_tokens = stats.get("prompt", 0)
res.usage.total_tokens = (
res.usage.completion_tokens + # type: ignore
res.usage.prompt_tokens)
res.choices[0].finish_reason = stats["finish_reason"]
res.choices[0].stop_reason = stats["stop_reason"]
response_json = res.model_dump_json(exclude_unset=True)
yield f"data: {response_json}\n\n"
if stats is not None:
runner.remove_result_queues([id])
break

yield "data: [DONE]\n\n"
except Exception as e:
logger.error("Error in completion_generator: %s", e)
return


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest,
raw_request: Request):
error_check_ret = await _check_model(request)
if error_check_ret is not None:
return JSONResponse(content=error_check_ret.model_dump(),
status_code=error_check_ret.code)
sampling_params = request.to_sampling_params()
conversation: List[ConversationMessage] = []

res = ChatCompletionResponse(model=request.model,
choices=[],
usage=UsageInfo(prompt_tokens=0,
total_tokens=0,
completion_tokens=0))

for msg in request.messages:
parsed_msg = parse_chat_message_content(msg)
conversation.extend(parsed_msg.messages)

prompt = runner.tokenizer.apply_chat_template( # type: ignore
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt,
)

ids, result_queue = await runner.add_request(prompt, sampling_params)
assert len(ids) == 1

if request.stream:
created_time = int(time.time())
return StreamingResponse(content=chat_completion_generator(
request.model, result_queue, created_time, ids[0]),
media_type="text/event-stream")

res.choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role=runner.response_role, content=""),
finish_reason=None,
stop_reason=None))

while True:
_, token, stats = await result_queue.get()
res.choices[0].message.content += str(token)
if stats is not None:
res.usage.completion_tokens += stats["tokens"] # type: ignore
res.usage.prompt_tokens += stats["prompt"] # type: ignore
res.choices[0].finish_reason = stats["finish_reason"]
res.choices[0].stop_reason = stats["stop_reason"]
runner.remove_result_queues(ids)
break
res.usage.total_tokens = ( # type: ignore
res.usage.completion_tokens + res.usage.prompt_tokens) # type: ignore
return res


def parse_args():
parser = make_arg_parser()
return parser.parse_args()
Expand All @@ -209,6 +411,7 @@ def parse_args():
args = parse_args()
engine_args = EngineArgs.from_cli_args(args)
runner.set_engine_args(engine_args)
runner.set_response_role(args.response_role)

app.add_middleware(
CORSMiddleware,
Expand Down
Loading