30 changes: 21 additions & 9 deletions src/litserve/server.py
@@ -19,6 +19,7 @@
import logging
import multiprocessing as mp
import os
import pickle
import secrets
import sys
import threading
@@ -317,25 +318,36 @@ async def handle_request(self, request, request_type) -> Response:
response, status = self.server.response_buffer.pop(uid)

if status == LitAPIStatus.ERROR:
await self._handle_error_response(response)
self._handle_error_response(response)

# Trigger callback
self.server._callback_runner.trigger_event(EventTypes.ON_RESPONSE.value, litserver=self.server)

return response

except HTTPException as e:
raise e
raise e from None

except Exception as e:
logger.exception(f"Error handling request: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
logger.error(f"Unhandled exception: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error") from e

@staticmethod
def _handle_error_response(response):
"""Raise HTTPException as is and rest as 500 after logging the error."""
try:
if isinstance(response, bytes):
response = pickle.loads(response)
raise HTTPException(status_code=response.status_code, detail=response.detail)
except Exception as e:
logger.debug(f"couldn't unpickle error response {e}")

async def _handle_error_response(self, response):
logger.error("Error in request: %s", response)
if isinstance(response, HTTPException):
raise response

if isinstance(response, Exception):
logger.error(f"Error while handling request: {response}")

raise HTTPException(status_code=500, detail="Internal server error")
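
As an aside (not part of the diff): a condensed, standalone sketch of the error path this hunk introduces. Names and control flow are approximations of the lines above rather than the exact LitServe handler — worker errors may arrive through the response buffer as pickled exception objects, HTTPExceptions keep their original status code and detail, and anything else is converted to a generic 500.

import logging
import pickle

from fastapi import HTTPException

logger = logging.getLogger(__name__)


def handle_error_response(response):
    # Worker errors may cross the process boundary as pickled exception objects.
    if isinstance(response, bytes):
        try:
            response = pickle.loads(response)
        except Exception as e:
            logger.debug(f"couldn't unpickle error response: {e}")

    # Re-raise HTTPExceptions as-is so the client sees the original status and detail.
    if isinstance(response, HTTPException):
        raise response

    if isinstance(response, Exception):
        logger.error(f"Error while handling request: {response}")

    # Everything else is surfaced as a generic 500.
    raise HTTPException(status_code=500, detail="Internal server error")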


@@ -800,7 +812,7 @@ def launch_inference_worker(self, lit_api: LitAPI):
self.workers_setup_status,
self._callback_runner,
),
name=f"lit-inference-{endpoint}_{worker_id}",
name="inference-worker",
)
process.start()
process_list.append(process)
@@ -1363,9 +1375,9 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_w
server = uvicorn.Server(config=uvicorn_config)
if uvicorn_worker_type == "process":
ctx = mp.get_context("fork")
w = ctx.Process(target=server.run, args=(sockets,), name=f"lit-uvicorn-{response_queue_id}")
w = ctx.Process(target=server.run, args=(sockets,), name=f"LitServer-{response_queue_id}")
elif uvicorn_worker_type == "thread":
w = threading.Thread(target=server.run, args=(sockets,), name=f"lit-uvicorn-{response_queue_id}")
w = threading.Thread(target=server.run, args=(sockets,), name=f"LitServer-{response_queue_id}")
else:
raise ValueError("Invalid value for api_server_worker_type. Must be 'process' or 'thread'")
w.start()
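
The worker names set here pair with the %(processName)s field in the default log format added to utils.py below. A minimal, self-contained sketch (illustrative only, not LitServe code; the worker name and format string are assumptions) of how a name given at process creation shows up in log records:

import logging
import multiprocessing as mp


def work():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(processName)s[%(process)d] - %(levelname)s - %(message)s",
    )
    # Emits something like: ... - LitServer-0[12345] - INFO - worker started
    logging.getLogger(__name__).info("worker started")


if __name__ == "__main__":
    w = mp.get_context("spawn").Process(target=work, name="LitServer-0")
    w.start()
    w.join()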
173 changes: 102 additions & 71 deletions src/litserve/specs/openai.py
@@ -327,6 +327,26 @@ async def encode_response(self, output):
""" # noqa: E501


def _openai_format_error(error: Exception):
if isinstance(error, HTTPException):
return "data: " + json.dumps({
"error": {
"message": error.detail,
"type": "internal",
"param": None,
"code": "internal_error",
}
})
return "data: " + json.dumps({
"error": {
"message": "Internal server error",
"type": "internal",
"param": None,
"code": "internal_error",
}
})
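
Not part of the diff: a short sketch of what a streaming client receives when this helper fires — a single "data:" line whose JSON body carries an "error" object instead of a "choices" delta. The payload literal below simply mirrors the fields set above.

import json

chunk = (
    'data: {"error": {"message": "Internal server error", '
    '"type": "internal", "param": null, "code": "internal_error"}}'
)

payload = json.loads(chunk[len("data: "):])
if "error" in payload:
    # A client can surface the message instead of treating the chunk as a token delta.
    print(payload["error"]["message"])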


class OpenAISpec(LitSpec):
def __init__(
self,
@@ -486,82 +506,93 @@ async def chat_completion(self, request: ChatCompletionRequest, background_tasks
return await response_task

async def streaming_completion(self, request: ChatCompletionRequest, pipe_responses: List):
model = request.model
usage_info = None
async for streaming_response in azip(*pipe_responses):
choices = []
usage_infos = []
# iterate over n choices
for i, (response, status) in enumerate(streaming_response):
if status == LitAPIStatus.ERROR and isinstance(response, HTTPException):
raise response
elif status == LitAPIStatus.ERROR:
logger.error("Error in streaming response: %s", response)
raise HTTPException(status_code=500)
encoded_response = json.loads(response)
logger.debug(encoded_response)
chat_msg = ChoiceDelta(**encoded_response)
usage_infos.append(UsageInfo(**encoded_response))
choice = ChatCompletionStreamingChoice(
index=i, delta=chat_msg, system_fingerprint="", finish_reason=None
)
try:
model = request.model
usage_info = None
async for streaming_response in azip(*pipe_responses):
choices = []
usage_infos = []
# iterate over n choices
for i, (response, status) in enumerate(streaming_response):
if status == LitAPIStatus.ERROR and isinstance(response, HTTPException):
raise response
elif status == LitAPIStatus.ERROR:
logger.error("Error in streaming response: %s", response)
raise HTTPException(status_code=500)
encoded_response = json.loads(response)
logger.debug(encoded_response)
chat_msg = ChoiceDelta(**encoded_response)
usage_infos.append(UsageInfo(**encoded_response))
choice = ChatCompletionStreamingChoice(
index=i, delta=chat_msg, system_fingerprint="", finish_reason=None
)

choices.append(choice)
choices.append(choice)

# Only use the last item from encode_response
usage_info = sum(usage_infos)
chunk = ChatCompletionChunk(model=model, choices=choices, usage=None)
logger.debug(chunk)
yield f"data: {chunk.model_dump_json(by_alias=True)}\n\n"

# Only use the last item from encode_response
usage_info = sum(usage_infos)
chunk = ChatCompletionChunk(model=model, choices=choices, usage=None)
logger.debug(chunk)
yield f"data: {chunk.model_dump_json(by_alias=True)}\n\n"

choices = [
ChatCompletionStreamingChoice(
index=i,
delta=ChoiceDelta(),
finish_reason="stop",
choices = [
ChatCompletionStreamingChoice(
index=i,
delta=ChoiceDelta(),
finish_reason="stop",
)
for i in range(request.n)
]
last_chunk = ChatCompletionChunk(
model=model,
choices=choices,
usage=usage_info,
)
for i in range(request.n)
]
last_chunk = ChatCompletionChunk(
model=model,
choices=choices,
usage=usage_info,
)
yield f"data: {last_chunk.model_dump_json(by_alias=True)}\n\n"
yield "data: [DONE]\n\n"
yield f"data: {last_chunk.model_dump_json(by_alias=True)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
logger.error("Error in streaming response: %s", e, exc_info=True)
yield _openai_format_error(e)
return
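
With the try/except above, a stream that fails mid-generation now ends with an error chunk rather than an abruptly closed connection. A hedged client-side sketch — the URL, port, and request body are assumptions — showing how a consumer can distinguish normal deltas, the error chunk, and the [DONE] sentinel:

import json

import httpx

body = {"model": "test", "stream": True, "messages": [{"role": "user", "content": "hi"}]}

with httpx.stream("POST", "http://localhost:8000/v1/chat/completions", json=body) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        payload = json.loads(data)
        if "error" in payload:
            print("server error:", payload["error"]["message"])
            break
        print(payload["choices"][0]["delta"].get("content") or "", end="")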

async def non_streaming_completion(self, request: ChatCompletionRequest, generator_list: List[AsyncGenerator]):
model = request.model
usage_infos = []
choices = []
# iterate over n choices
for i, streaming_response in enumerate(generator_list):
msgs = []
tool_calls = None
usage = None
async for response, status in streaming_response:
if status == LitAPIStatus.ERROR and isinstance(response, HTTPException):
raise response
if status == LitAPIStatus.ERROR:
logger.error("Error in OpenAI non-streaming response: %s", response)
raise HTTPException(status_code=500)

# data from LitAPI.encode_response
encoded_response = json.loads(response)
logger.debug(encoded_response)
chat_msg = ChatMessage(**encoded_response)
usage = UsageInfo(**encoded_response)
usage_infos.append(usage) # Aggregate usage info across all choices
msgs.append(chat_msg.content)
if chat_msg.tool_calls:
tool_calls = chat_msg.tool_calls

content = "".join(msg for msg in msgs if msg is not None)
msg = {"role": "assistant", "content": content, "tool_calls": tool_calls}
choice = ChatCompletionResponseChoice(index=i, message=msg, finish_reason="stop")
choices.append(choice)

return ChatCompletionResponse(model=model, choices=choices, usage=sum(usage_infos))
try:
model = request.model
usage_infos = []
choices = []
# iterate over n choices
for i, streaming_response in enumerate(generator_list):
msgs = []
tool_calls = None
usage = None
async for response, status in streaming_response:
if status == LitAPIStatus.ERROR and isinstance(response, HTTPException):
raise response
if status == LitAPIStatus.ERROR:
logger.error("Error in OpenAI non-streaming response: %s", response)
raise HTTPException(status_code=500)

# data from LitAPI.encode_response
encoded_response = json.loads(response)
logger.debug(encoded_response)
chat_msg = ChatMessage(**encoded_response)
usage = UsageInfo(**encoded_response)
usage_infos.append(usage) # Aggregate usage info across all choices
msgs.append(chat_msg.content)
if chat_msg.tool_calls:
tool_calls = chat_msg.tool_calls

content = "".join(msg for msg in msgs if msg is not None)
msg = {"role": "assistant", "content": content, "tool_calls": tool_calls}
choice = ChatCompletionResponseChoice(index=i, message=msg, finish_reason="stop")
choices.append(choice)

return ChatCompletionResponse(model=model, choices=choices, usage=sum(usage_infos))
except HTTPException as e:
raise e
except Exception as e:
logger.error("Error in non-streaming response: %s", e, exc_info=True)
raise HTTPException(status_code=500)


class _AsyncOpenAISpecWrapper(_AsyncSpecWrapper):
6 changes: 5 additions & 1 deletion src/litserve/utils.py
@@ -31,6 +31,10 @@

logger = logging.getLogger(__name__)

_DEFAULT_LOG_FORMAT = (
"%(asctime)s - %(processName)s[%(process)d] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s"
)


class LitAPIStatus:
OK = "OK"
@@ -128,7 +132,7 @@ def _get_default_handler(stream, format):

def configure_logging(
level: Union[str, int] = logging.INFO,
format: str = "%(processName)s[%(process)d] - %(name)s - %(levelname)s - %(message)s",
format: str = _DEFAULT_LOG_FORMAT,
stream: TextIO = sys.stdout,
use_rich: bool = False,
):
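
A usage sketch for the expanded default format (configure_logging is the function shown above in litserve.utils; the sample record below is illustrative):

from litserve.utils import configure_logging

# Opt into LitServe's logging defaults. With the new format, every record carries
# a timestamp, the process name/PID, logger name, level, and file:line, e.g.:
# 2025-01-01 12:00:00,000 - LitServer-0[12345] - litserve.server - INFO - server.py:123 - message
configure_logging(level="INFO")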
14 changes: 13 additions & 1 deletion tests/unit/test_request_handlers.py
@@ -17,7 +17,7 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import Request
from fastapi import HTTPException, Request

from litserve.server import BaseRequestHandler, RegularRequestHandler
from litserve.test_examples import SimpleLitAPI
@@ -92,3 +92,15 @@ async def test_request_handler_streaming(mock_event, mock_lit_api):
response = await handler.handle_request(mock_request, Request)
assert mock_server.request_queue.qsize() == 1
assert response == "test-response"


def test_regular_handler_error_response():
with pytest.raises(HTTPException) as e:
RegularRequestHandler._handle_error_response(HTTPException(status_code=500, detail="test error response"))
assert e.value.status_code == 500
assert e.value.detail == "test error response"

with pytest.raises(HTTPException) as e:
RegularRequestHandler._handle_error_response(Exception("test exception"))
assert e.value.status_code == 500
assert e.value.detail == "Internal server error"
4 changes: 2 additions & 2 deletions tests/unit/test_specs.py
@@ -379,13 +379,13 @@ def predict(self, prompt):

@pytest.mark.asyncio
async def test_fail_http(openai_request_data):
server = ls.LitServer(WrongLitAPI(), spec=ls.OpenAISpec())
server = ls.LitServer(WrongLitAPI(spec=ls.OpenAISpec()))
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(
transport=ASGITransport(app=manager.app), base_url="http://test"
) as ac:
res = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
assert res.status_code == 501, "Server raises 501 error"
assert res.status_code == 501, f"Server raises 501 error: {res.content}"
assert res.text == '{"detail":"test LitAPI.predict error"}'
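
For reference, the rough shape of the LitAPI this test exercises: predict raises an HTTPException that the OpenAI spec is expected to pass through unchanged. The class below is a sketch with assumed details — the real WrongLitAPI is presumably defined earlier in the test module — and only the status code and detail are taken from the assertions above.

import litserve as ls
from fastapi import HTTPException


class WrongLitAPISketch(ls.LitAPI):
    def setup(self, device):
        self.model = None

    def predict(self, prompt):
        raise HTTPException(status_code=501, detail="test LitAPI.predict error")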

