From 636f80b91bafe6593ba8e339c00e7974feeb56f6 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 24 Jul 2025 15:57:58 +0530 Subject: [PATCH 1/5] feat: introduce _StopLoopError for graceful loop termination * Added a new exception class _StopLoopError to handle sentinel values in request processing loops. * Updated collate_requests and loop classes to raise _StopLoopError when a sentinel value is encountered, allowing for clean exit from loops. * Refactored request handling in SingleLoop and StreamingLoop to improve clarity and maintainability. --- src/litserve/loops/base.py | 16 ++++++++++++--- src/litserve/loops/simple_loops.py | 28 ++++++++++++++++++--------- src/litserve/loops/streaming_loops.py | 6 +++++- 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/src/litserve/loops/base.py b/src/litserve/loops/base.py index 4388efe8..cb3e8716 100644 --- a/src/litserve/loops/base.py +++ b/src/litserve/loops/base.py @@ -87,6 +87,10 @@ async def _async_inject_context(context: Union[List[dict], dict], func, *args, * return await _handle_async_function(func, *args, **kwargs) +class _StopLoopError(Exception): + pass + + def collate_requests( lit_api: LitAPI, request_queue: Queue, @@ -100,7 +104,10 @@ def collate_requests( if lit_api.batch_timeout == 0: while len(payloads) < lit_api.max_batch_size: try: - response_queue_id, uid, timestamp, x_enc = request_queue.get_nowait() + request_data = request_queue.get_nowait() + if request_data == (None, None, None, None): + raise _StopLoopError() + response_queue_id, uid, timestamp, x_enc = request_data if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout: timed_out_uids.append((response_queue_id, uid)) else: @@ -115,7 +122,10 @@ def collate_requests( break try: - response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001)) + request_data = request_queue.get(timeout=min(remaining_time, 0.001)) + if request_data == (None, None, None, None): + raise _StopLoopError() + response_queue_id, uid, timestamp, x_enc = request_data if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout: timed_out_uids.append((response_queue_id, uid)) else: @@ -264,7 +274,7 @@ def get_batch_requests( self, lit_api: LitAPI, request_queue: Queue, - ): + ) -> Tuple[List, List]: batches, timed_out_uids = collate_requests( lit_api, request_queue, diff --git a/src/litserve/loops/simple_loops.py b/src/litserve/loops/simple_loops.py index 3db0f11a..f9db22c5 100644 --- a/src/litserve/loops/simple_loops.py +++ b/src/litserve/loops/simple_loops.py @@ -21,7 +21,7 @@ from litserve import LitAPI from litserve.callbacks import CallbackRunner, EventTypes -from litserve.loops.base import DefaultLoop, _async_inject_context, _inject_context, collate_requests +from litserve.loops.base import DefaultLoop, _async_inject_context, _inject_context, _StopLoopError from litserve.specs.base import LitSpec from litserve.transport.base import MessageTransport from litserve.utils import LitAPIStatus, LoopResponseType, PickleableHTTPException @@ -41,7 +41,11 @@ def run_single_loop( lit_spec = lit_spec or lit_api.spec while True: try: - response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) + request_data = request_queue.get(timeout=1.0) + if request_data == (None, None, None, None): + logger.debug("Received sentinel value, stopping loop") + return + response_queue_id, uid, timestamp, x_enc = request_data except (Empty, ValueError): continue except KeyboardInterrupt: # pragma: no cover @@ -211,9 +215,11 @@ async def process_requests(): pending_tasks = set() while True: try: - response_queue_id, uid, timestamp, x_enc = await event_loop.run_in_executor( - None, request_queue.get, 1.0 - ) + request_data = await event_loop.run_in_executor(None, request_queue.get, 1.0) + if request_data == (None, None, None, None): + logger.debug("Received sentinel value, stopping loop") + return + response_queue_id, uid, timestamp, x_enc = request_data except (Empty, ValueError): continue except KeyboardInterrupt: @@ -291,10 +297,14 @@ def run_batched_loop( ): lit_spec = lit_api.spec while True: - batches, timed_out_uids = collate_requests( - lit_api, - request_queue, - ) + try: + batches, timed_out_uids = self.get_batch_requests( + lit_api, + request_queue, + ) + except _StopLoopError: + logger.debug("Received sentinel value, stopping loop") + return for response_queue_id, uid in timed_out_uids: logger.error( diff --git a/src/litserve/loops/streaming_loops.py b/src/litserve/loops/streaming_loops.py index 57d951bb..6179572e 100644 --- a/src/litserve/loops/streaming_loops.py +++ b/src/litserve/loops/streaming_loops.py @@ -41,7 +41,11 @@ def run_streaming_loop( lit_spec = lit_api.spec while True: try: - response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) + request_data = request_queue.get(timeout=1.0) + if request_data == (None, None, None, None): + logger.debug("Received sentinel value, stopping loop") + return + response_queue_id, uid, timestamp, x_enc = request_data logger.debug("uid=%s", uid) except (Empty, ValueError): continue From 8d44a1b1d10514ad2d5cf5a3b53447ad09a29267 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 24 Jul 2025 16:13:00 +0530 Subject: [PATCH 2/5] add test --- src/litserve/loops/base.py | 4 +++- tests/unit/test_batch.py | 11 ++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/litserve/loops/base.py b/src/litserve/loops/base.py index cb3e8716..b981a16d 100644 --- a/src/litserve/loops/base.py +++ b/src/litserve/loops/base.py @@ -88,7 +88,9 @@ async def _async_inject_context(context: Union[List[dict], dict], func, *args, * class _StopLoopError(Exception): - pass + def __init__(self, message: str = "Received sentinel value, stopping loop"): + self.message = message + super().__init__(self.message) def collate_requests( diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index b306f0b7..96e1b0ad 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -26,7 +26,7 @@ import litserve as ls from litserve import LitAPI, LitServer from litserve.callbacks import CallbackRunner -from litserve.loops.base import collate_requests +from litserve.loops.base import _StopLoopError, collate_requests from litserve.loops.simple_loops import BatchedLoop from litserve.utils import wrap_litserve_start @@ -235,6 +235,15 @@ def test_collate_requests(batch_timeout, batch_size): assert len(timed_out_uids) == 0, "No timed out uids" +def test_collate_requests_sentinel(): + api = ls.test_examples.SimpleBatchedAPI(max_batch_size=2, batch_timeout=0) + api.request_timeout = 5 + request_queue = Queue() + request_queue.put((None, None, None, None)) + with pytest.raises(_StopLoopError, match="Received sentinel value, stopping loop"): + collate_requests(api, request_queue) + + class BatchSizeMismatchAPI(SimpleBatchLitAPI): def predict(self, x): assert len(x) == 2, "Expected two concurrent inputs to be batched" From fef68ae5185b0d72a29f7c948b45c6599df5d73c Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 24 Jul 2025 16:22:05 +0530 Subject: [PATCH 3/5] apply code suggestion --- src/litserve/loops/base.py | 9 ++++++--- src/litserve/loops/simple_loops.py | 6 +++--- src/litserve/loops/streaming_loops.py | 12 +++++++----- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/litserve/loops/base.py b/src/litserve/loops/base.py index b981a16d..5ccedc19 100644 --- a/src/litserve/loops/base.py +++ b/src/litserve/loops/base.py @@ -37,6 +37,9 @@ # renamed in PR: https://github.com/encode/starlette/pull/2780 MultiPartParser.spool_max_size = sys.maxsize +_DEFAULT_STOP_LOOP_MESSAGE = "Received sentinel value, stopping loop" +_SENTINEL_VALUE = (None, None, None, None) + def _inject_context(context: Union[List[dict], dict], func, *args, **kwargs): sig = inspect.signature(func) @@ -88,7 +91,7 @@ async def _async_inject_context(context: Union[List[dict], dict], func, *args, * class _StopLoopError(Exception): - def __init__(self, message: str = "Received sentinel value, stopping loop"): + def __init__(self, message: str = _DEFAULT_STOP_LOOP_MESSAGE): self.message = message super().__init__(self.message) @@ -107,7 +110,7 @@ def collate_requests( while len(payloads) < lit_api.max_batch_size: try: request_data = request_queue.get_nowait() - if request_data == (None, None, None, None): + if request_data == _SENTINEL_VALUE: raise _StopLoopError() response_queue_id, uid, timestamp, x_enc = request_data if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout: @@ -125,7 +128,7 @@ def collate_requests( try: request_data = request_queue.get(timeout=min(remaining_time, 0.001)) - if request_data == (None, None, None, None): + if request_data == _SENTINEL_VALUE: raise _StopLoopError() response_queue_id, uid, timestamp, x_enc = request_data if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout: diff --git a/src/litserve/loops/simple_loops.py b/src/litserve/loops/simple_loops.py index f9db22c5..2d566119 100644 --- a/src/litserve/loops/simple_loops.py +++ b/src/litserve/loops/simple_loops.py @@ -21,7 +21,7 @@ from litserve import LitAPI from litserve.callbacks import CallbackRunner, EventTypes -from litserve.loops.base import DefaultLoop, _async_inject_context, _inject_context, _StopLoopError +from litserve.loops.base import _SENTINEL_VALUE, DefaultLoop, _async_inject_context, _inject_context, _StopLoopError from litserve.specs.base import LitSpec from litserve.transport.base import MessageTransport from litserve.utils import LitAPIStatus, LoopResponseType, PickleableHTTPException @@ -42,7 +42,7 @@ def run_single_loop( while True: try: request_data = request_queue.get(timeout=1.0) - if request_data == (None, None, None, None): + if request_data == _SENTINEL_VALUE: logger.debug("Received sentinel value, stopping loop") return response_queue_id, uid, timestamp, x_enc = request_data @@ -216,7 +216,7 @@ async def process_requests(): while True: try: request_data = await event_loop.run_in_executor(None, request_queue.get, 1.0) - if request_data == (None, None, None, None): + if request_data == _SENTINEL_VALUE: logger.debug("Received sentinel value, stopping loop") return response_queue_id, uid, timestamp, x_enc = request_data diff --git a/src/litserve/loops/streaming_loops.py b/src/litserve/loops/streaming_loops.py index 6179572e..22e34376 100644 --- a/src/litserve/loops/streaming_loops.py +++ b/src/litserve/loops/streaming_loops.py @@ -21,7 +21,7 @@ from litserve import LitAPI from litserve.callbacks import CallbackRunner, EventTypes -from litserve.loops.base import DefaultLoop, _async_inject_context, _inject_context, collate_requests +from litserve.loops.base import _SENTINEL_VALUE, DefaultLoop, _async_inject_context, _inject_context, collate_requests from litserve.specs.base import LitSpec from litserve.transport.base import MessageTransport from litserve.utils import LitAPIStatus, LoopResponseType, PickleableHTTPException @@ -42,7 +42,7 @@ def run_streaming_loop( while True: try: request_data = request_queue.get(timeout=1.0) - if request_data == (None, None, None, None): + if request_data == _SENTINEL_VALUE: logger.debug("Received sentinel value, stopping loop") return response_queue_id, uid, timestamp, x_enc = request_data @@ -210,9 +210,11 @@ async def process_requests(): while True: try: - response_queue_id, uid, timestamp, x_enc = await event_loop.run_in_executor( - None, request_queue.get, 1.0 - ) + request_data = await event_loop.run_in_executor(None, request_queue.get, 1.0) + if request_data == _SENTINEL_VALUE: + logger.debug("Received sentinel value, stopping loop") + return + response_queue_id, uid, timestamp, x_enc = request_data logger.debug("uid=%s", uid) except (Empty, ValueError): continue From 5ba8dde6a5a3eade3f8dcb9f043bc9263a46d90e Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 24 Jul 2025 17:21:10 +0530 Subject: [PATCH 4/5] refactor: enhance test_batched_loop with FakeTransport * Introduced FakeTransport class to simulate message transport in tests. * Updated test_batched_loop to utilize FakeTransport for response handling. * Improved assertions to verify the correctness of responses from the batched loop. --- tests/unit/test_batch.py | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 96e1b0ad..8312044e 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -14,7 +14,7 @@ import asyncio import time from queue import Queue -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest import torch @@ -28,7 +28,8 @@ from litserve.callbacks import CallbackRunner from litserve.loops.base import _StopLoopError, collate_requests from litserve.loops.simple_loops import BatchedLoop -from litserve.utils import wrap_litserve_start +from litserve.transport.base import MessageTransport +from litserve.utils import LoopResponseType, wrap_litserve_start NOOP_CB_RUNNER = CallbackRunner() @@ -184,11 +185,23 @@ def put(self, *args, block=True, timeout=None): raise StopIteration("exit loop") +class FakeTransport(MessageTransport): + def __init__(self): + self.responses = [] + + async def areceive(self, **kwargs) -> dict: + raise NotImplementedError("This is a fake transport") + + def send(self, response, consumer_id: int): + self.responses.append(response) + + def test_batched_loop(): requests_queue = Queue() response_queue_id = 0 requests_queue.put((response_queue_id, "uuid-1234", time.monotonic(), {"input": 4.0})) requests_queue.put((response_queue_id, "uuid-1235", time.monotonic(), {"input": 5.0})) + requests_queue.put((None, None, None, None)) lit_api_mock = MagicMock() lit_api_mock.request_timeout = 2 @@ -201,13 +214,17 @@ def test_batched_loop(): lit_api_mock.encode_response = MagicMock(side_effect=lambda x: {"output": x}) loop = BatchedLoop() - with patch("pickle.dumps", side_effect=StopIteration("exit loop")), pytest.raises(StopIteration, match="exit loop"): - loop.run_batched_loop( - lit_api_mock, - requests_queue, - [FakeResponseQueue()], - callback_runner=NOOP_CB_RUNNER, - ) + transport = FakeTransport() + loop.run_batched_loop( + lit_api_mock, + requests_queue, + transport=transport, + callback_runner=NOOP_CB_RUNNER, + ) + + assert len(transport.responses) == 2, "response queue should have 2 responses" + assert transport.responses[0] == ("uuid-1234", ({"output": 16.0}, "OK", LoopResponseType.REGULAR)) + assert transport.responses[1] == ("uuid-1235", ({"output": 25.0}, "OK", LoopResponseType.REGULAR)) lit_api_mock.batch.assert_called_once() lit_api_mock.batch.assert_called_once_with([4.0, 5.0]) From c0c786793a0a7b078acd0faae8425de82f7df404 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 24 Jul 2025 17:22:40 +0530 Subject: [PATCH 5/5] const _SENTINEL_VALUE --- tests/unit/test_batch.py | 6 +++--- tests/unit/test_loops.py | 28 +++++++++++++++++----------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 8312044e..2d418ba6 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -26,7 +26,7 @@ import litserve as ls from litserve import LitAPI, LitServer from litserve.callbacks import CallbackRunner -from litserve.loops.base import _StopLoopError, collate_requests +from litserve.loops.base import _SENTINEL_VALUE, _StopLoopError, collate_requests from litserve.loops.simple_loops import BatchedLoop from litserve.transport.base import MessageTransport from litserve.utils import LoopResponseType, wrap_litserve_start @@ -201,7 +201,7 @@ def test_batched_loop(): response_queue_id = 0 requests_queue.put((response_queue_id, "uuid-1234", time.monotonic(), {"input": 4.0})) requests_queue.put((response_queue_id, "uuid-1235", time.monotonic(), {"input": 5.0})) - requests_queue.put((None, None, None, None)) + requests_queue.put(_SENTINEL_VALUE) lit_api_mock = MagicMock() lit_api_mock.request_timeout = 2 @@ -256,7 +256,7 @@ def test_collate_requests_sentinel(): api = ls.test_examples.SimpleBatchedAPI(max_batch_size=2, batch_timeout=0) api.request_timeout = 5 request_queue = Queue() - request_queue.put((None, None, None, None)) + request_queue.put(_SENTINEL_VALUE) with pytest.raises(_StopLoopError, match="Received sentinel value, stopping loop"): collate_requests(api, request_queue) diff --git a/tests/unit/test_loops.py b/tests/unit/test_loops.py index 9109e1b4..c17695cb 100644 --- a/tests/unit/test_loops.py +++ b/tests/unit/test_loops.py @@ -33,7 +33,13 @@ from litserve import LitAPI from litserve.callbacks import CallbackRunner from litserve.loops import BatchedStreamingLoop, LitLoop, Output, StreamingLoop, inference_worker -from litserve.loops.base import DefaultLoop, _async_inject_context, _handle_async_function, _sync_fn_to_async_fn +from litserve.loops.base import ( + _SENTINEL_VALUE, + DefaultLoop, + _async_inject_context, + _handle_async_function, + _sync_fn_to_async_fn, +) from litserve.loops.continuous_batching_loop import ( ContinuousBatchingLoop, notify_timed_out_requests, @@ -82,7 +88,7 @@ def get(self, timeout=None): raise Empty # Simulate queue being empty after sentinel item = super().get(timeout=timeout) # Sentinel: (None, None, None, None) - if item == (None, None, None, None): + if item == _SENTINEL_VALUE: self._sentinel_seen = True raise KeyboardInterrupt # Triggers loop exit in your code return item @@ -107,7 +113,7 @@ def async_loop_args(): requests_queue = TestQueue() requests_queue.put((0, "uuid-123", time.monotonic(), {"input": 1})) requests_queue.put((1, "uuid-234", time.monotonic(), {"input": 2})) - requests_queue.put((None, None, None, None)) + requests_queue.put(_SENTINEL_VALUE) lit_api = AsyncTestLitAPI() return lit_api, requests_queue @@ -263,7 +269,7 @@ async def test_streaming_loop_process_streaming_request(mock_transport): def test_run_streaming_loop_with_async(mock_transport, monkeypatch): requests_queue = TestQueue() requests_queue.put((0, "uuid-123", time.monotonic(), {"input": 5})) - requests_queue.put((None, None, None, None)) # Sentinel to stop the loop + requests_queue.put(_SENTINEL_VALUE) # Sentinel to stop the loop lit_api = AsyncTestStreamLitAPI() loop = StreamingLoop() @@ -412,7 +418,7 @@ async def test_run_single_loop(mock_transport): time.sleep(1) # Stop the loop by putting a sentinel value in the queue - request_queue.put((None, None, None, None)) + request_queue.put(_SENTINEL_VALUE) loop_thread.join() response = await transport.areceive(consumer_id=0) @@ -445,7 +451,7 @@ async def test_run_single_loop_timeout(): assert response.status_code == 504 assert "Request UUID-001 was waiting in the queue for too long" in stream.getvalue() - request_queue.put((None, None, None, None)) + request_queue.put(_SENTINEL_VALUE) loop_thread.join() @@ -481,7 +487,7 @@ async def test_run_batched_loop(): actual = await transport.areceive(0, timeout=10) assert actual == expected, f"Expected {expected}, got {actual}" - request_queue.put((None, None, None, None)) + request_queue.put(_SENTINEL_VALUE) loop_thread.join() @@ -524,7 +530,7 @@ async def test_run_batched_loop_timeout(mock_transport): _, (response2, _, _) = await transport.areceive(consumer_id=0, timeout=10) assert response2 == {"output": 25.0} - request_queue.put((None, None, None, None)) + request_queue.put(_SENTINEL_VALUE) loop_thread.join() @@ -548,7 +554,7 @@ async def test_run_streaming_loop(mock_transport): time.sleep(1) # Stop the loop by putting a sentinel value in the queue - request_queue.put((None, None, None, None)) + request_queue.put(_SENTINEL_VALUE) loop_thread.join() for i in range(3): @@ -579,7 +585,7 @@ async def test_run_streaming_loop_timeout(mock_transport): time.sleep(1) # Stop the loop by putting a sentinel value in the queue - request_queue.put((None, None, None, None)) + request_queue.put(_SENTINEL_VALUE) loop_thread.join() assert "Request UUID-001 was waiting in the queue for too long" in stream.getvalue() @@ -617,7 +623,7 @@ def off_test_run_batched_streaming_loop(openai_request_data): time.sleep(1) # Stop the loop by putting a sentinel value in the queue - request_queue.put((None, None, None, None)) + request_queue.put(_SENTINEL_VALUE) loop_thread.join() response = response_queues[0].get(timeout=5)[1]