21 changes: 18 additions & 3 deletions src/litserve/loops/base.py
@@ -37,6 +37,9 @@
 # renamed in PR: https://github.com/encode/starlette/pull/2780
 MultiPartParser.spool_max_size = sys.maxsize
 
+_DEFAULT_STOP_LOOP_MESSAGE = "Received sentinel value, stopping loop"
+_SENTINEL_VALUE = (None, None, None, None)
+
 
 def _inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
     sig = inspect.signature(func)
@@ -87,6 +90,12 @@ async def _async_inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
     return await _handle_async_function(func, *args, **kwargs)
 
 
+class _StopLoopError(Exception):
+    def __init__(self, message: str = _DEFAULT_STOP_LOOP_MESSAGE):
+        self.message = message
+        super().__init__(self.message)
+
+
 def collate_requests(
     lit_api: LitAPI,
     request_queue: Queue,
@@ -100,7 +109,10 @@ def collate_requests(
     if lit_api.batch_timeout == 0:
         while len(payloads) < lit_api.max_batch_size:
             try:
-                response_queue_id, uid, timestamp, x_enc = request_queue.get_nowait()
+                request_data = request_queue.get_nowait()
+                if request_data == _SENTINEL_VALUE:
+                    raise _StopLoopError()
+                response_queue_id, uid, timestamp, x_enc = request_data
                 if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout:
                     timed_out_uids.append((response_queue_id, uid))
                 else:
@@ -115,7 +127,10 @@ def collate_requests(
             break
 
         try:
-            response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001))
+            request_data = request_queue.get(timeout=min(remaining_time, 0.001))
+            if request_data == _SENTINEL_VALUE:
+                raise _StopLoopError()
+            response_queue_id, uid, timestamp, x_enc = request_data
             if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout:
                 timed_out_uids.append((response_queue_id, uid))
             else:
@@ -264,7 +279,7 @@ def get_batch_requests(
         self,
         lit_api: LitAPI,
         request_queue: Queue,
-    ):
+    ) -> Tuple[List, List]:
        batches, timed_out_uids = collate_requests(
            lit_api,
            request_queue,
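The hunks above implement a classic sentinel-shutdown pattern for queue consumers: the producer enqueues a reserved value, and the consumer converts it into a control-flow exception instead of treating it as work. A minimal, self-contained sketch of the same idea follows; the names mirror the diff, but the wiring is illustrative, not LitServe's actual server code, and note that (as in the PR) any payloads already collected when the sentinel arrives are dropped:

import queue
from queue import Queue

_SENTINEL = (None, None, None, None)  # same shape as a real work item


class _StopLoop(Exception):
    """Raised internally when the sentinel is dequeued."""


def drain_batch(q: Queue, max_batch_size: int) -> list:
    """Collect up to max_batch_size items; raise _StopLoop on the sentinel."""
    payloads = []
    while len(payloads) < max_batch_size:
        try:
            item = q.get_nowait()
        except queue.Empty:
            break  # real code would wait on a timeout instead of spinning
        if item == _SENTINEL:
            raise _StopLoop  # the caller exits its worker loop cleanly
        payloads.append(item)
    return payloads


if __name__ == "__main__":
    q = Queue()
    q.put((0, "uid-1", 0.0, {"input": 1.0}))
    q.put(_SENTINEL)
    try:
        while True:
            print("batch:", drain_batch(q, max_batch_size=1))
    except _StopLoop:
        print("worker stopped")

Raising an exception rather than returning a flag keeps the batching helper's return type unchanged for callers that don't care about shutdown, which is why the diff routes the sentinel through _StopLoopError.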
28 changes: 19 additions & 9 deletions src/litserve/loops/simple_loops.py
@@ -21,7 +21,7 @@
 
 from litserve import LitAPI
 from litserve.callbacks import CallbackRunner, EventTypes
-from litserve.loops.base import DefaultLoop, _async_inject_context, _inject_context, collate_requests
+from litserve.loops.base import _SENTINEL_VALUE, DefaultLoop, _async_inject_context, _inject_context, _StopLoopError
 from litserve.specs.base import LitSpec
 from litserve.transport.base import MessageTransport
 from litserve.utils import LitAPIStatus, LoopResponseType, PickleableHTTPException
@@ -41,7 +41,11 @@ def run_single_loop(
     lit_spec = lit_spec or lit_api.spec
     while True:
         try:
-            response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
+            request_data = request_queue.get(timeout=1.0)
+            if request_data == _SENTINEL_VALUE:
+                logger.debug("Received sentinel value, stopping loop")
+                return
+            response_queue_id, uid, timestamp, x_enc = request_data
         except (Empty, ValueError):
             continue
         except KeyboardInterrupt:  # pragma: no cover
@@ -211,9 +215,11 @@ async def process_requests():
         pending_tasks = set()
         while True:
             try:
-                response_queue_id, uid, timestamp, x_enc = await event_loop.run_in_executor(
-                    None, request_queue.get, 1.0
-                )
+                request_data = await event_loop.run_in_executor(None, request_queue.get, 1.0)
+                if request_data == _SENTINEL_VALUE:
+                    logger.debug("Received sentinel value, stopping loop")
+                    return
+                response_queue_id, uid, timestamp, x_enc = request_data
             except (Empty, ValueError):
                 continue
             except KeyboardInterrupt:
@@ -291,10 +297,14 @@ def run_batched_loop(
     ):
         lit_spec = lit_api.spec
         while True:
-            batches, timed_out_uids = collate_requests(
-                lit_api,
-                request_queue,
-            )
+            try:
+                batches, timed_out_uids = self.get_batch_requests(
+                    lit_api,
+                    request_queue,
+                )
+            except _StopLoopError:
+                logger.debug("Received sentinel value, stopping loop")
+                return
 
             for response_queue_id, uid in timed_out_uids:
                 logger.error(
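This file only shows the consumer side of the shutdown. For the pattern to terminate every worker, the producer must enqueue one sentinel per consumer, since each consumer dequeues exactly one and returns. A sketch of that producer side, under the assumption that workers are plain threads sharing one queue (the stop_workers helper is hypothetical, not LitServe's actual shutdown code):

import threading
from queue import Empty, Queue

_SENTINEL_VALUE = (None, None, None, None)


def run_single_loop(request_queue: Queue) -> None:
    # Consumer: mirrors the diff's check-then-return shutdown.
    while True:
        try:
            request_data = request_queue.get(timeout=1.0)
        except Empty:
            continue
        if request_data == _SENTINEL_VALUE:
            return  # clean exit, in place of the diff's logger.debug + return
        print("processing", request_data)


def stop_workers(request_queue: Queue, workers: list) -> None:
    # One sentinel per worker: each consumer swallows exactly one.
    for _ in workers:
        request_queue.put(_SENTINEL_VALUE)
    for w in workers:
        w.join()


if __name__ == "__main__":
    q = Queue()
    threads = [threading.Thread(target=run_single_loop, args=(q,)) for _ in range(2)]
    for t in threads:
        t.start()
    q.put((0, "uid-1", 0.0, {"input": 1.0}))
    stop_workers(q, threads)
    print("all workers stopped")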
16 changes: 11 additions & 5 deletions src/litserve/loops/streaming_loops.py
@@ -21,7 +21,7 @@
 
 from litserve import LitAPI
 from litserve.callbacks import CallbackRunner, EventTypes
-from litserve.loops.base import DefaultLoop, _async_inject_context, _inject_context, collate_requests
+from litserve.loops.base import _SENTINEL_VALUE, DefaultLoop, _async_inject_context, _inject_context, collate_requests
 from litserve.specs.base import LitSpec
 from litserve.transport.base import MessageTransport
 from litserve.utils import LitAPIStatus, LoopResponseType, PickleableHTTPException
@@ -41,7 +41,11 @@ def run_streaming_loop(
     lit_spec = lit_api.spec
     while True:
         try:
-            response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
+            request_data = request_queue.get(timeout=1.0)
+            if request_data == _SENTINEL_VALUE:
+                logger.debug("Received sentinel value, stopping loop")
+                return
+            response_queue_id, uid, timestamp, x_enc = request_data
             logger.debug("uid=%s", uid)
         except (Empty, ValueError):
             continue
@@ -206,9 +210,11 @@ async def process_requests():
 
         while True:
             try:
-                response_queue_id, uid, timestamp, x_enc = await event_loop.run_in_executor(
-                    None, request_queue.get, 1.0
-                )
+                request_data = await event_loop.run_in_executor(None, request_queue.get, 1.0)
+                if request_data == _SENTINEL_VALUE:
+                    logger.debug("Received sentinel value, stopping loop")
+                    return
+                response_queue_id, uid, timestamp, x_enc = request_data
                 logger.debug("uid=%s", uid)
             except (Empty, ValueError):
                 continue
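In the async loops the blocking Queue.get cannot run on the event loop, so the diff awaits it through run_in_executor. One detail worth knowing when reusing this bridge: run_in_executor forwards positional arguments only, so Queue.get's block and timeout parameters must both be passed positionally (or bound with functools.partial). A self-contained sketch of the bridge plus sentinel check (illustrative, not the LitServe code):

import asyncio
from queue import Empty, Queue

_SENTINEL_VALUE = (None, None, None, None)


async def process_requests(request_queue: Queue) -> None:
    event_loop = asyncio.get_running_loop()
    while True:
        try:
            # Run the blocking get() in a worker thread so the event loop stays free.
            # Positional args only: True -> block, 1.0 -> timeout.
            request_data = await event_loop.run_in_executor(
                None, request_queue.get, True, 1.0
            )
        except Empty:
            continue
        if request_data == _SENTINEL_VALUE:
            return
        print("processing", request_data)


if __name__ == "__main__":
    q = Queue()
    q.put((0, "uid-1", 0.0, {"input": 1.0}))
    q.put(_SENTINEL_VALUE)
    asyncio.run(process_requests(q))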
46 changes: 36 additions & 10 deletions tests/unit/test_batch.py
@@ -14,7 +14,7 @@
 import asyncio
 import time
 from queue import Queue
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock
 
 import pytest
 import torch
@@ -26,9 +26,10 @@
 import litserve as ls
 from litserve import LitAPI, LitServer
 from litserve.callbacks import CallbackRunner
-from litserve.loops.base import collate_requests
+from litserve.loops.base import _SENTINEL_VALUE, _StopLoopError, collate_requests
 from litserve.loops.simple_loops import BatchedLoop
-from litserve.utils import wrap_litserve_start
+from litserve.transport.base import MessageTransport
+from litserve.utils import LoopResponseType, wrap_litserve_start
 
 NOOP_CB_RUNNER = CallbackRunner()
 
@@ -184,11 +185,23 @@ def put(self, *args, block=True, timeout=None):
         raise StopIteration("exit loop")
 
 
+class FakeTransport(MessageTransport):
+    def __init__(self):
+        self.responses = []
+
+    async def areceive(self, **kwargs) -> dict:
+        raise NotImplementedError("This is a fake transport")
+
+    def send(self, response, consumer_id: int):
+        self.responses.append(response)
+
+
 def test_batched_loop():
     requests_queue = Queue()
     response_queue_id = 0
     requests_queue.put((response_queue_id, "uuid-1234", time.monotonic(), {"input": 4.0}))
     requests_queue.put((response_queue_id, "uuid-1235", time.monotonic(), {"input": 5.0}))
+    requests_queue.put(_SENTINEL_VALUE)
 
     lit_api_mock = MagicMock()
     lit_api_mock.request_timeout = 2
@@ -201,13 +214,17 @@ def test_batched_loop():
     lit_api_mock.encode_response = MagicMock(side_effect=lambda x: {"output": x})
 
     loop = BatchedLoop()
-    with patch("pickle.dumps", side_effect=StopIteration("exit loop")), pytest.raises(StopIteration, match="exit loop"):
-        loop.run_batched_loop(
-            lit_api_mock,
-            requests_queue,
-            [FakeResponseQueue()],
-            callback_runner=NOOP_CB_RUNNER,
-        )
+    transport = FakeTransport()
+    loop.run_batched_loop(
+        lit_api_mock,
+        requests_queue,
+        transport=transport,
+        callback_runner=NOOP_CB_RUNNER,
+    )
+
+    assert len(transport.responses) == 2, "response queue should have 2 responses"
+    assert transport.responses[0] == ("uuid-1234", ({"output": 16.0}, "OK", LoopResponseType.REGULAR))
+    assert transport.responses[1] == ("uuid-1235", ({"output": 25.0}, "OK", LoopResponseType.REGULAR))
 
     lit_api_mock.batch.assert_called_once()
     lit_api_mock.batch.assert_called_once_with([4.0, 5.0])
@@ -235,6 +252,15 @@ def test_collate_requests(batch_timeout, batch_size):
     assert len(timed_out_uids) == 0, "No timed out uids"
 
 
+def test_collate_requests_sentinel():
+    api = ls.test_examples.SimpleBatchedAPI(max_batch_size=2, batch_timeout=0)
+    api.request_timeout = 5
+    request_queue = Queue()
+    request_queue.put(_SENTINEL_VALUE)
+    with pytest.raises(_StopLoopError, match="Received sentinel value, stopping loop"):
+        collate_requests(api, request_queue)
+
+
 class BatchSizeMismatchAPI(SimpleBatchLitAPI):
     def predict(self, x):
         assert len(x) == 2, "Expected two concurrent inputs to be batched"
28 changes: 17 additions & 11 deletions tests/unit/test_loops.py
@@ -33,7 +33,13 @@
 from litserve import LitAPI
 from litserve.callbacks import CallbackRunner
 from litserve.loops import BatchedStreamingLoop, LitLoop, Output, StreamingLoop, inference_worker
-from litserve.loops.base import DefaultLoop, _async_inject_context, _handle_async_function, _sync_fn_to_async_fn
+from litserve.loops.base import (
+    _SENTINEL_VALUE,
+    DefaultLoop,
+    _async_inject_context,
+    _handle_async_function,
+    _sync_fn_to_async_fn,
+)
 from litserve.loops.continuous_batching_loop import (
     ContinuousBatchingLoop,
     notify_timed_out_requests,
@@ -82,7 +88,7 @@ def get(self, timeout=None):
             raise Empty  # Simulate queue being empty after sentinel
         item = super().get(timeout=timeout)
         # Sentinel: (None, None, None, None)
-        if item == (None, None, None, None):
+        if item == _SENTINEL_VALUE:
             self._sentinel_seen = True
             raise KeyboardInterrupt  # Triggers loop exit in your code
         return item
@@ -107,7 +113,7 @@ def async_loop_args():
     requests_queue = TestQueue()
     requests_queue.put((0, "uuid-123", time.monotonic(), {"input": 1}))
     requests_queue.put((1, "uuid-234", time.monotonic(), {"input": 2}))
-    requests_queue.put((None, None, None, None))
+    requests_queue.put(_SENTINEL_VALUE)
 
     lit_api = AsyncTestLitAPI()
     return lit_api, requests_queue
@@ -263,7 +269,7 @@ async def test_streaming_loop_process_streaming_request(mock_transport):
 def test_run_streaming_loop_with_async(mock_transport, monkeypatch):
     requests_queue = TestQueue()
     requests_queue.put((0, "uuid-123", time.monotonic(), {"input": 5}))
-    requests_queue.put((None, None, None, None))  # Sentinel to stop the loop
+    requests_queue.put(_SENTINEL_VALUE)  # Sentinel to stop the loop
 
     lit_api = AsyncTestStreamLitAPI()
     loop = StreamingLoop()
@@ -412,7 +418,7 @@ async def test_run_single_loop(mock_transport):
     time.sleep(1)
 
     # Stop the loop by putting a sentinel value in the queue
-    request_queue.put((None, None, None, None))
+    request_queue.put(_SENTINEL_VALUE)
     loop_thread.join()
 
     response = await transport.areceive(consumer_id=0)
@@ -445,7 +451,7 @@ async def test_run_single_loop_timeout():
     assert response.status_code == 504
     assert "Request UUID-001 was waiting in the queue for too long" in stream.getvalue()
 
-    request_queue.put((None, None, None, None))
+    request_queue.put(_SENTINEL_VALUE)
     loop_thread.join()
 
 
@@ -481,7 +487,7 @@ async def test_run_batched_loop():
         actual = await transport.areceive(0, timeout=10)
         assert actual == expected, f"Expected {expected}, got {actual}"
 
-    request_queue.put((None, None, None, None))
+    request_queue.put(_SENTINEL_VALUE)
     loop_thread.join()
 
 
@@ -524,7 +530,7 @@ async def test_run_batched_loop_timeout(mock_transport):
     _, (response2, _, _) = await transport.areceive(consumer_id=0, timeout=10)
     assert response2 == {"output": 25.0}
 
-    request_queue.put((None, None, None, None))
+    request_queue.put(_SENTINEL_VALUE)
     loop_thread.join()
 
 
@@ -548,7 +554,7 @@ async def test_run_streaming_loop(mock_transport):
     time.sleep(1)
 
     # Stop the loop by putting a sentinel value in the queue
-    request_queue.put((None, None, None, None))
+    request_queue.put(_SENTINEL_VALUE)
     loop_thread.join()
 
     for i in range(3):
@@ -579,7 +585,7 @@ async def test_run_streaming_loop_timeout(mock_transport):
     time.sleep(1)
 
     # Stop the loop by putting a sentinel value in the queue
-    request_queue.put((None, None, None, None))
+    request_queue.put(_SENTINEL_VALUE)
     loop_thread.join()
 
     assert "Request UUID-001 was waiting in the queue for too long" in stream.getvalue()
@@ -617,7 +623,7 @@ def off_test_run_batched_streaming_loop(openai_request_data):
     time.sleep(1)
 
     # Stop the loop by putting a sentinel value in the queue
-    request_queue.put((None, None, None, None))
+    request_queue.put(_SENTINEL_VALUE)
     loop_thread.join()
 
     response = response_queues[0].get(timeout=5)[1]