Merged · Changes from 5 commits
21 changes: 14 additions & 7 deletions src/litserve/server.py
@@ -19,6 +19,7 @@
import logging
import multiprocessing as mp
import os
import pickle
import secrets
import sys
import threading
@@ -325,17 +326,23 @@ async def handle_request(self, request, request_type) -> Response:
return response

except HTTPException as e:
raise e
raise e from None

except Exception as e:
logger.exception(f"Error handling request: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
logger.error(f"Unhandled exception: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error") from e

async def _handle_error_response(self, response):
logger.error("Error in request: %s", response)
"""Raise HTTPException as is and rest as 500 after logging the error."""
if isinstance(response, bytes):
response = pickle.loads(response)

if isinstance(response, HTTPException):
raise response

if isinstance(response, Exception):
logger.error(f"Error while handling request: {response}")

raise HTTPException(status_code=500, detail="Internal server error")


@@ -800,7 +807,7 @@ def launch_inference_worker(self, lit_api: LitAPI):
self.workers_setup_status,
self._callback_runner,
),
name=f"lit-inference-{endpoint}_{worker_id}",
name="inference-worker",
)
process.start()
process_list.append(process)
@@ -1363,9 +1370,9 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_w
server = uvicorn.Server(config=uvicorn_config)
if uvicorn_worker_type == "process":
ctx = mp.get_context("fork")
w = ctx.Process(target=server.run, args=(sockets,), name=f"lit-uvicorn-{response_queue_id}")
w = ctx.Process(target=server.run, args=(sockets,), name=f"LitServer-{response_queue_id}")
elif uvicorn_worker_type == "thread":
w = threading.Thread(target=server.run, args=(sockets,), name=f"lit-uvicorn-{response_queue_id}")
w = threading.Thread(target=server.run, args=(sockets,), name=f"LitServer-{response_queue_id}")
else:
raise ValueError("Invalid value for api_server_worker_type. Must be 'process' or 'thread'")
w.start()
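
For context on the server-side change above: `_handle_error_response` now accepts either an exception object or a pickled payload coming back from an inference worker, re-raises `HTTPException` unchanged, and maps everything else to a generic 500 after logging it. A minimal standalone sketch of that flow (outside LitServe; only `fastapi.HTTPException`, `pickle`, and the stdlib are assumed, and the worker/queue plumbing is omitted; `handle_error_payload` is a hypothetical name used here, not the method in the PR):

```python
import logging
import pickle

from fastapi import HTTPException

logger = logging.getLogger(__name__)


def handle_error_payload(payload):
    """Sketch of the patched flow: unpickle worker payloads, keep HTTPException, map the rest to 500."""
    if isinstance(payload, bytes):
        # Worker processes ship errors back over the response queue as pickled bytes.
        payload = pickle.loads(payload)

    if isinstance(payload, HTTPException):
        # User-raised HTTPExceptions keep their original status code and detail.
        raise payload

    if isinstance(payload, Exception):
        logger.error(f"Error while handling request: {payload}")

    raise HTTPException(status_code=500, detail="Internal server error")


# A plain exception from a worker surfaces as a generic 500 ...
try:
    handle_error_payload(pickle.dumps(ValueError("boom")))
except HTTPException as exc:
    print(exc.status_code, exc.detail)  # 500 Internal server error

# ... while an HTTPException keeps its own status code and detail.
try:
    handle_error_payload(HTTPException(status_code=422, detail="bad input"))
except HTTPException as exc:
    print(exc.status_code, exc.detail)  # 422 bad input
```
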
171 changes: 100 additions & 71 deletions src/litserve/specs/openai.py
@@ -327,6 +327,26 @@ async def encode_response(self, output):
""" # noqa: E501


def _openai_format_error(error: Exception):
if isinstance(error, HTTPException):
return "data: " + json.dumps({
"error": {
"message": error.detail,
"type": "internal",
"param": None,
"code": "internal_error",
}
})
return "data: " + json.dumps({
"error": {
"message": "Internal server error",
"type": "internal",
"param": None,
"code": "internal_error",
}
})


class OpenAISpec(LitSpec):
def __init__(
self,
@@ -486,82 +506,91 @@ async def chat_completion(self, request: ChatCompletionRequest, background_tasks
return await response_task

async def streaming_completion(self, request: ChatCompletionRequest, pipe_responses: List):
model = request.model
usage_info = None
async for streaming_response in azip(*pipe_responses):
choices = []
usage_infos = []
# iterate over n choices
for i, (response, status) in enumerate(streaming_response):
if status == LitAPIStatus.ERROR and isinstance(response, HTTPException):
raise response
elif status == LitAPIStatus.ERROR:
logger.error("Error in streaming response: %s", response)
raise HTTPException(status_code=500)
encoded_response = json.loads(response)
logger.debug(encoded_response)
chat_msg = ChoiceDelta(**encoded_response)
usage_infos.append(UsageInfo(**encoded_response))
choice = ChatCompletionStreamingChoice(
index=i, delta=chat_msg, system_fingerprint="", finish_reason=None
)
try:
model = request.model
usage_info = None
async for streaming_response in azip(*pipe_responses):
choices = []
usage_infos = []
# iterate over n choices
for i, (response, status) in enumerate(streaming_response):
if status == LitAPIStatus.ERROR and isinstance(response, HTTPException):
raise response
elif status == LitAPIStatus.ERROR:
logger.error("Error in streaming response: %s", response)
raise HTTPException(status_code=500)
encoded_response = json.loads(response)
logger.debug(encoded_response)
chat_msg = ChoiceDelta(**encoded_response)
usage_infos.append(UsageInfo(**encoded_response))
choice = ChatCompletionStreamingChoice(
index=i, delta=chat_msg, system_fingerprint="", finish_reason=None
)

choices.append(choice)
choices.append(choice)

# Only use the last item from encode_response
usage_info = sum(usage_infos)
chunk = ChatCompletionChunk(model=model, choices=choices, usage=None)
logger.debug(chunk)
yield f"data: {chunk.model_dump_json(by_alias=True)}\n\n"

# Only use the last item from encode_response
usage_info = sum(usage_infos)
chunk = ChatCompletionChunk(model=model, choices=choices, usage=None)
logger.debug(chunk)
yield f"data: {chunk.model_dump_json(by_alias=True)}\n\n"

choices = [
ChatCompletionStreamingChoice(
index=i,
delta=ChoiceDelta(),
finish_reason="stop",
choices = [
ChatCompletionStreamingChoice(
index=i,
delta=ChoiceDelta(),
finish_reason="stop",
)
for i in range(request.n)
]
last_chunk = ChatCompletionChunk(
model=model,
choices=choices,
usage=usage_info,
)
for i in range(request.n)
]
last_chunk = ChatCompletionChunk(
model=model,
choices=choices,
usage=usage_info,
)
yield f"data: {last_chunk.model_dump_json(by_alias=True)}\n\n"
yield "data: [DONE]\n\n"
yield f"data: {last_chunk.model_dump_json(by_alias=True)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
logger.error("Error in streaming response: %s", e, exc_info=True)
yield _openai_format_error(e)
return

async def non_streaming_completion(self, request: ChatCompletionRequest, generator_list: List[AsyncGenerator]):
model = request.model
usage_infos = []
choices = []
# iterate over n choices
for i, streaming_response in enumerate(generator_list):
msgs = []
tool_calls = None
usage = None
async for response, status in streaming_response:
if status == LitAPIStatus.ERROR and isinstance(response, HTTPException):
raise response
if status == LitAPIStatus.ERROR:
logger.error("Error in OpenAI non-streaming response: %s", response)
raise HTTPException(status_code=500)

# data from LitAPI.encode_response
encoded_response = json.loads(response)
logger.debug(encoded_response)
chat_msg = ChatMessage(**encoded_response)
usage = UsageInfo(**encoded_response)
usage_infos.append(usage) # Aggregate usage info across all choices
msgs.append(chat_msg.content)
if chat_msg.tool_calls:
tool_calls = chat_msg.tool_calls

content = "".join(msg for msg in msgs if msg is not None)
msg = {"role": "assistant", "content": content, "tool_calls": tool_calls}
choice = ChatCompletionResponseChoice(index=i, message=msg, finish_reason="stop")
choices.append(choice)

return ChatCompletionResponse(model=model, choices=choices, usage=sum(usage_infos))
try:
model = request.model
usage_infos = []
choices = []
# iterate over n choices
for i, streaming_response in enumerate(generator_list):
msgs = []
tool_calls = None
usage = None
async for response, status in streaming_response:
if status == LitAPIStatus.ERROR and isinstance(response, HTTPException):
raise response
if status == LitAPIStatus.ERROR:
logger.error("Error in OpenAI non-streaming response: %s", response)
raise HTTPException(status_code=500)

# data from LitAPI.encode_response
encoded_response = json.loads(response)
logger.debug(encoded_response)
chat_msg = ChatMessage(**encoded_response)
usage = UsageInfo(**encoded_response)
usage_infos.append(usage) # Aggregate usage info across all choices
msgs.append(chat_msg.content)
if chat_msg.tool_calls:
tool_calls = chat_msg.tool_calls

content = "".join(msg for msg in msgs if msg is not None)
msg = {"role": "assistant", "content": content, "tool_calls": tool_calls}
choice = ChatCompletionResponseChoice(index=i, message=msg, finish_reason="stop")
choices.append(choice)

return ChatCompletionResponse(model=model, choices=choices, usage=sum(usage_infos))
except Exception as e:
logger.error("Error in non-streaming response: %s", e, exc_info=True)
raise HTTPException(status_code=500)


class _AsyncOpenAISpecWrapper(_AsyncSpecWrapper):
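
To make the new streaming error path concrete: when a stream fails, `streaming_completion` now yields one final SSE frame produced by `_openai_format_error` instead of dropping the connection. The snippet below simply re-states that helper from the diff so the two possible payloads can be printed locally (only the `fastapi` import is assumed):

```python
import json

from fastapi import HTTPException


def _openai_format_error(error: Exception) -> str:
    """Format an exception as an OpenAI-style SSE error frame, as added in this PR."""
    if isinstance(error, HTTPException):
        # User-raised HTTPExceptions expose their detail message to the client.
        return "data: " + json.dumps({
            "error": {
                "message": error.detail,
                "type": "internal",
                "param": None,
                "code": "internal_error",
            }
        })
    # Anything else is masked behind a generic message.
    return "data: " + json.dumps({
        "error": {
            "message": "Internal server error",
            "type": "internal",
            "param": None,
            "code": "internal_error",
        }
    })


print(_openai_format_error(HTTPException(status_code=429, detail="Too many requests")))
# data: {"error": {"message": "Too many requests", "type": "internal", "param": null, "code": "internal_error"}}
print(_openai_format_error(ValueError("boom")))
# data: {"error": {"message": "Internal server error", "type": "internal", "param": null, "code": "internal_error"}}
```

The payload mirrors OpenAI's error envelope (`error.message`, `error.type`, `error.param`, `error.code`), so clients that already parse OpenAI error events can handle it without changes.
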
6 changes: 5 additions & 1 deletion src/litserve/utils.py
@@ -31,6 +31,10 @@

logger = logging.getLogger(__name__)

_DEFAULT_LOG_FORMAT = (
"%(asctime)s - %(processName)s[%(process)d] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s"
)


class LitAPIStatus:
OK = "OK"
@@ -128,7 +132,7 @@ def _get_default_handler(stream, format):

def configure_logging(
level: Union[str, int] = logging.INFO,
format: str = "%(processName)s[%(process)d] - %(name)s - %(levelname)s - %(message)s",
format: str = _DEFAULT_LOG_FORMAT,
stream: TextIO = sys.stdout,
use_rich: bool = False,
):
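
The logging change only swaps the default format string, adding a timestamp and the `filename:lineno` of the call site. A quick way to preview the output shape without touching LitServe internals is to apply the same format string via the stdlib (illustration only; the logger name and messages below are arbitrary, and LitServe users would keep calling `configure_logging()`):

```python
import logging

# Same format string that configure_logging() now uses by default.
_DEFAULT_LOG_FORMAT = (
    "%(asctime)s - %(processName)s[%(process)d] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s"
)

logging.basicConfig(level=logging.INFO, format=_DEFAULT_LOG_FORMAT)
logging.getLogger("litserve.server").info("inference worker started")
# Example output (timestamp, PID, and filename will differ):
# 2024-01-01 12:00:00,123 - MainProcess[4321] - litserve.server - INFO - demo.py:10 - inference worker started
```
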