Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
fastapi >=0.100
uvicorn[standard] >=0.29.0
pyzmq >=22.0.0
fpdb >=0.0.0.dev2
4 changes: 3 additions & 1 deletion src/litserve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from litserve.loggers import Logger
from litserve.server import LitServer, Request, Response
from litserve.specs import OpenAIEmbeddingSpec, OpenAISpec
from litserve.utils import configure_logging
from litserve.utils import configure_logging, set_trace, set_trace_if_debug

configure_logging()

Expand All @@ -32,4 +32,6 @@
"test_examples",
"Callback",
"Logger",
"set_trace",
"set_trace_if_debug",
]
12 changes: 12 additions & 0 deletions src/litserve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import TYPE_CHECKING, AsyncIterator

from fastapi import HTTPException
from fpdb import ForkedPdb

if TYPE_CHECKING:
from litserve.server import LitServer
Expand Down Expand Up @@ -151,3 +152,14 @@ def generate_random_zmq_address(temp_dir="/tmp"):
unique_name = f"zmq-{uuid.uuid4().hex}.ipc"
ipc_path = os.path.join(temp_dir, unique_name)
return f"ipc://{ipc_path}"


def set_trace():
"""Set a tracepoint in the code."""
ForkedPdb().set_trace()


def set_trace_if_debug(debug_env_var="LITSERVE_DEBUG", debug_env_var_value="1"):
"""Set a tracepoint in the code if the environment variable LITSERVE_DEBUG is set."""
if os.environ.get(debug_env_var) == debug_env_var_value:
set_trace()
18 changes: 17 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import pickle
import sys
from unittest import mock
from unittest.mock import MagicMock

import pytest
from fastapi import HTTPException

from litserve.utils import call_after_stream, dump_exception, generate_random_zmq_address
from litserve.utils import call_after_stream, dump_exception, generate_random_zmq_address, set_trace_if_debug


def test_dump_exception():
Expand Down Expand Up @@ -50,3 +51,18 @@ def test_generate_random_zmq_address_non_windows(tmpdir):
# Verify the path exists within the specified temp_dir
assert os.path.commonpath([temp_dir, address1[6:]]) == temp_dir
assert os.path.commonpath([temp_dir, address2[6:]]) == temp_dir


@mock.patch("litserve.utils.set_trace")
def test_set_trace_if_debug(mock_set_trace):
# mock environ
with mock.patch("litserve.utils.os.environ", {"LITSERVE_DEBUG": "1"}):
set_trace_if_debug()
mock_set_trace.assert_called_once()


@mock.patch("litserve.utils.ForkedPdb")
def test_set_trace_if_debug_not_set(mock_forked_pdb):
with mock.patch("litserve.utils.os.environ", {"LITSERVE_DEBUG": "0"}):
set_trace_if_debug()
mock_forked_pdb.assert_not_called()