Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
39b9995
start
aniketmaurya May 25, 2025
ad6b344
update
aniketmaurya May 25, 2025
60098f5
Add LitAPIV2 class and refactor LitServerV2 initialization
aniketmaurya May 25, 2025
21654f4
Refactor LitAPI and LitServer initialization for improved configuration
aniketmaurya May 25, 2025
539d459
Refactor LitAPI loop handling and update test assertions
aniketmaurya May 25, 2025
08024dc
Implement validation for api_path in LitServer initialization
aniketmaurya May 25, 2025
900ad8a
Merge branch 'main' into aniket/multiple-endpoints
aniketmaurya May 25, 2025
4c9f3ed
update
aniketmaurya May 25, 2025
6ef9477
Refactor LitServer and LitAPI initialization for deprecation handling
aniketmaurya May 25, 2025
429bddc
fix
aniketmaurya May 25, 2025
9be1b99
Remove debug print statements from api.py and server.py; add test for…
aniketmaurya May 25, 2025
5bf18f8
Refactor LitServer to utilize _LitAPIConnector for improved API manag…
aniketmaurya May 25, 2025
66000aa
Merge branch 'main' into endpoint-2
aniketmaurya May 26, 2025
8fe8d20
Refactor pre_setup methods in LitAPI and loops for improved handling …
aniketmaurya May 26, 2025
319955c
Enhance worker setup in LitServer for improved inference handling
aniketmaurya May 26, 2025
00ba243
fixes
aniketmaurya May 26, 2025
12464b8
Refactor LitServer and BatchedLoop for improved worker management and…
aniketmaurya May 26, 2025
e81bd53
update
aniketmaurya May 26, 2025
0f8a6e8
Update test_openai_embedding.py to pass spec directly to TestEmbedAPI…
aniketmaurya May 26, 2025
fe522bd
Refactor LitServer and utility functions for improved worker manageme…
aniketmaurya May 27, 2025
6d25c15
Refactor data_streamer method in LitServer for improved accessibility
aniketmaurya May 27, 2025
c6b91bd
fix
aniketmaurya May 27, 2025
d42c5ce
fix tests
aniketmaurya May 27, 2025
f0a3261
fix test
aniketmaurya May 27, 2025
01f59a8
update
aniketmaurya May 27, 2025
7943cc7
update
aniketmaurya May 27, 2025
0055ae6
Apply suggestions from code review
aniketmaurya May 27, 2025
8170b3f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2025
58a04ad
fix
aniketmaurya May 27, 2025
91dd7e1
Request queue for each LitAPI
aniketmaurya May 27, 2025
0afc18b
Refactor LitAPI and LitServer for improved API path handling
aniketmaurya May 27, 2025
0d41e21
fix test
aniketmaurya May 27, 2025
5f2a386
fix
aniketmaurya May 27, 2025
ceb1b7b
fix windows
aniketmaurya May 27, 2025
81fff9a
Implement path collision detection in LitServer
aniketmaurya May 27, 2025
630cffe
Add mixed streaming configuration check in LitServer
aniketmaurya May 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/litserve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
"but the max_batch_size parameter was not set."
)

self.api_path = api_path
self._api_path = api_path
self.stream = stream
self._loop = loop
self._spec = spec
Expand Down Expand Up @@ -128,6 +128,7 @@ async def predict(self, x, **kwargs):
@abstractmethod
def setup(self, device):
"""Setup the model so it can be called in `predict`."""
pass

def decode_request(self, request, **kwargs):
"""Convert the request payload to your model input."""
Expand Down Expand Up @@ -210,7 +211,8 @@ def device(self):
def device(self, value):
self._device = value

def pre_setup(self, spec: Optional[LitSpec]):
def pre_setup(self, spec: Optional[LitSpec] = None):
spec = spec or self._spec
if self.stream:
self._default_unbatch = self._unbatch_stream
else:
Expand Down Expand Up @@ -274,3 +276,13 @@ def spec(self):
@spec.setter
def spec(self, value: LitSpec):
self._spec = value

@property
def api_path(self):
if self._spec:
return self._spec.api_path
return self._api_path

@api_path.setter
def api_path(self, value: str):
self._api_path = value
2 changes: 1 addition & 1 deletion src/litserve/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _process_logger_queue(logger_proxies: List[_LoggerProxy], queue):
@functools.cache # Run once per LitServer instance
def run(self, lit_server: "LitServer"):
queue = lit_server.logger_queue
lit_server.lit_api.set_logger_queue(queue)
lit_server.litapi_connector.set_logger_queue(queue)

# Disconnect the logger connector from the LitServer to avoid pickling issues
self._lit_server = None
Expand Down
16 changes: 4 additions & 12 deletions src/litserve/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def run(

"""

def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec] = None):
pass

async def schedule_task(
Expand All @@ -162,15 +162,14 @@ async def schedule_task(
def __call__(
self,
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
device: str,
worker_id: int,
request_queue: Queue,
transport: MessageTransport,
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
):
lit_spec = lit_api.spec
if asyncio.iscoroutinefunction(self.run):
event_loop = asyncio.new_event_loop()

Expand All @@ -182,12 +181,10 @@ async def _wrapper():
try:
await self.run(
lit_api,
lit_spec,
device,
worker_id,
request_queue,
transport,
stream,
workers_setup_status,
callback_runner,
)
Expand All @@ -200,25 +197,21 @@ async def _wrapper():
while True:
self.run(
lit_api,
lit_spec,
device,
worker_id,
request_queue,
transport,
stream,
workers_setup_status,
callback_runner,
)

def run(
self,
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
device: str,
worker_id: int,
request_queue: Queue,
transport: MessageTransport,
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
):
Expand Down Expand Up @@ -273,14 +266,13 @@ def put_error_response(


class DefaultLoop(LitLoop):
def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec] = None):
# we will sanitize regularly if no spec
# in case, we have spec then:
# case 1: spec implements a streaming API
# Case 2: spec implements a non-streaming API
if spec:
if lit_api.spec:
# TODO: Implement sanitization
lit_api._spec = spec
return

original = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__
Expand Down
5 changes: 2 additions & 3 deletions src/litserve/loops/continuous_batching_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(self, max_sequence_length: int = 2048):
self.max_sequence_length = max_sequence_length
self.response_queue_ids: Dict[str, int] = {} # uid -> response_queue_id

def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec] = None):
"""Check if the lit_api has the necessary methods and if streaming is enabled."""
if not lit_api.stream:
raise ValueError(
Expand Down Expand Up @@ -180,16 +180,15 @@ async def step(
async def run(
self,
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
device: str,
worker_id: int,
request_queue: Queue,
transport: MessageTransport,
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
):
"""Main loop that processes batches of requests."""
lit_spec = lit_api.spec
try:
prev_outputs = None
while lit_api.has_active_requests():
Expand Down
23 changes: 12 additions & 11 deletions src/litserve/loops/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
# limitations under the License.
import logging
from queue import Queue
from typing import Dict, Optional, Union
from typing import Dict

from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.loops.base import _BaseLoop
from litserve.loops.base import LitLoop, _BaseLoop
from litserve.loops.simple_loops import BatchedLoop, SingleLoop
from litserve.loops.streaming_loops import BatchedStreamingLoop, StreamingLoop
from litserve.specs.base import LitSpec
from litserve.transport.base import MessageTransport
from litserve.utils import WorkerSetupStatus

Expand Down Expand Up @@ -61,30 +60,34 @@ def get_default_loop(stream: bool, max_batch_size: int, enable_async: bool = Fal

def inference_worker(
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
device: str,
worker_id: int,
request_queue: Queue,
transport: MessageTransport,
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
loop: Union[str, _BaseLoop],
):
print("workers_setup_status", workers_setup_status)
lit_spec = lit_api.spec
loop: LitLoop = lit_api.loop
stream = lit_api.stream

endpoint = lit_api.api_path.split("/")[-1]

callback_runner.trigger_event(EventTypes.BEFORE_SETUP.value, lit_api=lit_api)
try:
lit_api.setup(device)
except Exception:
logger.exception(f"Error setting up worker {worker_id}.")
workers_setup_status[worker_id] = WorkerSetupStatus.ERROR
workers_setup_status[f"{endpoint}_{worker_id}"] = WorkerSetupStatus.ERROR
return
lit_api.device = device
callback_runner.trigger_event(EventTypes.AFTER_SETUP.value, lit_api=lit_api)

print(f"Setup complete for worker {worker_id}.")
print(f"Setup complete for worker {f'{endpoint}_{worker_id}'}.")

if workers_setup_status:
workers_setup_status[worker_id] = WorkerSetupStatus.READY
workers_setup_status[f"{endpoint}_{worker_id}"] = WorkerSetupStatus.READY

if lit_spec:
logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec")
Expand All @@ -94,12 +97,10 @@ def inference_worker(

loop(
lit_api,
lit_spec,
device,
worker_id,
request_queue,
transport,
stream,
workers_setup_status,
callback_runner,
)
24 changes: 12 additions & 12 deletions src/litserve/loops/simple_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ class SingleLoop(DefaultLoop):
def run_single_loop(
self,
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
request_queue: Queue,
transport: MessageTransport,
callback_runner: CallbackRunner,
lit_spec: Optional[LitSpec] = None,
):
lit_spec = lit_spec or lit_api.spec
while True:
try:
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
Expand Down Expand Up @@ -125,10 +126,11 @@ async def _process_single_request(
self,
request,
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
transport: MessageTransport,
callback_runner: CallbackRunner,
lit_spec: Optional[LitSpec] = None,
):
lit_spec = lit_spec or lit_api.spec
response_queue_id, uid, timestamp, x_enc = request
try:
context = {}
Expand Down Expand Up @@ -191,7 +193,6 @@ async def _process_single_request(
def _run_single_loop_with_async(
self,
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
request_queue: Queue,
transport: MessageTransport,
callback_runner: CallbackRunner,
Expand Down Expand Up @@ -232,7 +233,6 @@ async def process_requests():
self._process_single_request(
(response_queue_id, uid, timestamp, x_enc),
lit_api,
lit_spec,
transport,
callback_runner,
),
Expand All @@ -255,30 +255,31 @@ async def process_requests():
def __call__(
self,
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
device: str,
worker_id: int,
request_queue: Queue,
transport: MessageTransport,
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
lit_spec: Optional[LitSpec] = None,
stream: bool = False,
):
if lit_api.enable_async:
self._run_single_loop_with_async(lit_api, lit_spec, request_queue, transport, callback_runner)
self._run_single_loop_with_async(lit_api, request_queue, transport, callback_runner)
else:
self.run_single_loop(lit_api, lit_spec, request_queue, transport, callback_runner)
self.run_single_loop(lit_api, request_queue, transport, callback_runner)


class BatchedLoop(DefaultLoop):
def run_batched_loop(
self,
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
transport: MessageTransport,
callback_runner: CallbackRunner,
lit_spec: Optional[LitSpec] = None,
):
lit_spec = lit_api.spec
while True:
batches, timed_out_uids = collate_requests(
lit_api,
Expand Down Expand Up @@ -368,18 +369,17 @@ def run_batched_loop(
def __call__(
self,
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
device: str,
worker_id: int,
request_queue: Queue,
transport: MessageTransport,
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
lit_spec: Optional[LitSpec] = None,
stream: bool = False,
):
self.run_batched_loop(
lit_api,
lit_spec,
request_queue,
transport,
callback_runner,
Expand Down
Loading