diff --git a/src/litserve/api.py b/src/litserve/api.py index 8d4e9879..cf5671e6 100644 --- a/src/litserve/api.py +++ b/src/litserve/api.py @@ -89,7 +89,7 @@ def __init__( "but the max_batch_size parameter was not set." ) - self.api_path = api_path + self._api_path = api_path self.stream = stream self._loop = loop self._spec = spec @@ -128,6 +128,7 @@ async def predict(self, x, **kwargs): @abstractmethod def setup(self, device): """Setup the model so it can be called in `predict`.""" + pass def decode_request(self, request, **kwargs): """Convert the request payload to your model input.""" @@ -210,7 +211,8 @@ def device(self): def device(self, value): self._device = value - def pre_setup(self, spec: Optional[LitSpec]): + def pre_setup(self, spec: Optional[LitSpec] = None): + spec = spec or self._spec if self.stream: self._default_unbatch = self._unbatch_stream else: @@ -274,3 +276,13 @@ def spec(self): @spec.setter def spec(self, value: LitSpec): self._spec = value + + @property + def api_path(self): + if self._spec: + return self._spec.api_path + return self._api_path + + @api_path.setter + def api_path(self, value: str): + self._api_path = value diff --git a/src/litserve/loggers.py b/src/litserve/loggers.py index 821ba5c0..dbf91680 100644 --- a/src/litserve/loggers.py +++ b/src/litserve/loggers.py @@ -142,7 +142,7 @@ def _process_logger_queue(logger_proxies: List[_LoggerProxy], queue): @functools.cache # Run once per LitServer instance def run(self, lit_server: "LitServer"): queue = lit_server.logger_queue - lit_server.lit_api.set_logger_queue(queue) + lit_server.litapi_connector.set_logger_queue(queue) # Disconnect the logger connector from the LitServer to avoid pickling issues self._lit_server = None diff --git a/src/litserve/loops/base.py b/src/litserve/loops/base.py index a0efda04..add93827 100644 --- a/src/litserve/loops/base.py +++ b/src/litserve/loops/base.py @@ -147,7 +147,7 @@ def run( """ - def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]): + def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec] = None): pass async def schedule_task( @@ -162,15 +162,14 @@ async def schedule_task( def __call__( self, lit_api: LitAPI, - lit_spec: Optional[LitSpec], device: str, worker_id: int, request_queue: Queue, transport: MessageTransport, - stream: bool, workers_setup_status: Dict[int, str], callback_runner: CallbackRunner, ): + lit_spec = lit_api.spec if asyncio.iscoroutinefunction(self.run): event_loop = asyncio.new_event_loop() @@ -182,12 +181,10 @@ async def _wrapper(): try: await self.run( lit_api, - lit_spec, device, worker_id, request_queue, transport, - stream, workers_setup_status, callback_runner, ) @@ -200,12 +197,10 @@ async def _wrapper(): while True: self.run( lit_api, - lit_spec, device, worker_id, request_queue, transport, - stream, workers_setup_status, callback_runner, ) @@ -213,12 +208,10 @@ async def _wrapper(): def run( self, lit_api: LitAPI, - lit_spec: Optional[LitSpec], device: str, worker_id: int, request_queue: Queue, transport: MessageTransport, - stream: bool, workers_setup_status: Dict[int, str], callback_runner: CallbackRunner, ): @@ -273,14 +266,13 @@ def put_error_response( class DefaultLoop(LitLoop): - def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]): + def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec] = None): # we will sanitize regularly if no spec # in case, we have spec then: # case 1: spec implements a streaming API # Case 2: spec implements a non-streaming API - if spec: + if lit_api.spec: # TODO: Implement 
sanitization - lit_api._spec = spec return original = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__ diff --git a/src/litserve/loops/continuous_batching_loop.py b/src/litserve/loops/continuous_batching_loop.py index 0ecfad0b..1be8d228 100644 --- a/src/litserve/loops/continuous_batching_loop.py +++ b/src/litserve/loops/continuous_batching_loop.py @@ -68,7 +68,7 @@ def __init__(self, max_sequence_length: int = 2048): self.max_sequence_length = max_sequence_length self.response_queue_ids: Dict[str, int] = {} # uid -> response_queue_id - def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]): + def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec] = None): """Check if the lit_api has the necessary methods and if streaming is enabled.""" if not lit_api.stream: raise ValueError( @@ -180,16 +180,15 @@ async def step( async def run( self, lit_api: LitAPI, - lit_spec: Optional[LitSpec], device: str, worker_id: int, request_queue: Queue, transport: MessageTransport, - stream: bool, workers_setup_status: Dict[int, str], callback_runner: CallbackRunner, ): """Main loop that processes batches of requests.""" + lit_spec = lit_api.spec try: prev_outputs = None while lit_api.has_active_requests(): diff --git a/src/litserve/loops/loops.py b/src/litserve/loops/loops.py index 5fbd3978..8920bb7e 100644 --- a/src/litserve/loops/loops.py +++ b/src/litserve/loops/loops.py @@ -13,14 +13,13 @@ # limitations under the License. import logging from queue import Queue -from typing import Dict, Optional, Union +from typing import Dict from litserve import LitAPI from litserve.callbacks import CallbackRunner, EventTypes -from litserve.loops.base import _BaseLoop +from litserve.loops.base import LitLoop, _BaseLoop from litserve.loops.simple_loops import BatchedLoop, SingleLoop from litserve.loops.streaming_loops import BatchedStreamingLoop, StreamingLoop -from litserve.specs.base import LitSpec from litserve.transport.base import MessageTransport from litserve.utils import WorkerSetupStatus @@ -61,30 +60,34 @@ def get_default_loop(stream: bool, max_batch_size: int, enable_async: bool = Fal def inference_worker( lit_api: LitAPI, - lit_spec: Optional[LitSpec], device: str, worker_id: int, request_queue: Queue, transport: MessageTransport, - stream: bool, workers_setup_status: Dict[int, str], callback_runner: CallbackRunner, - loop: Union[str, _BaseLoop], ): + print("workers_setup_status", workers_setup_status) + lit_spec = lit_api.spec + loop: LitLoop = lit_api.loop + stream = lit_api.stream + + endpoint = lit_api.api_path.split("/")[-1] + callback_runner.trigger_event(EventTypes.BEFORE_SETUP.value, lit_api=lit_api) try: lit_api.setup(device) except Exception: logger.exception(f"Error setting up worker {worker_id}.") - workers_setup_status[worker_id] = WorkerSetupStatus.ERROR + workers_setup_status[f"{endpoint}_{worker_id}"] = WorkerSetupStatus.ERROR return lit_api.device = device callback_runner.trigger_event(EventTypes.AFTER_SETUP.value, lit_api=lit_api) - print(f"Setup complete for worker {worker_id}.") + print(f"Setup complete for worker {f'{endpoint}_{worker_id}'}.") if workers_setup_status: - workers_setup_status[worker_id] = WorkerSetupStatus.READY + workers_setup_status[f"{endpoint}_{worker_id}"] = WorkerSetupStatus.READY if lit_spec: logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec") @@ -94,12 +97,10 @@ def inference_worker( loop( lit_api, - lit_spec, device, worker_id, request_queue, transport, - stream, workers_setup_status, callback_runner, ) diff --git 
a/src/litserve/loops/simple_loops.py b/src/litserve/loops/simple_loops.py index 650342f9..4c6000a7 100644 --- a/src/litserve/loops/simple_loops.py +++ b/src/litserve/loops/simple_loops.py @@ -33,11 +33,12 @@ class SingleLoop(DefaultLoop): def run_single_loop( self, lit_api: LitAPI, - lit_spec: Optional[LitSpec], request_queue: Queue, transport: MessageTransport, callback_runner: CallbackRunner, + lit_spec: Optional[LitSpec] = None, ): + lit_spec = lit_spec or lit_api.spec while True: try: response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) @@ -125,10 +126,11 @@ async def _process_single_request( self, request, lit_api: LitAPI, - lit_spec: Optional[LitSpec], transport: MessageTransport, callback_runner: CallbackRunner, + lit_spec: Optional[LitSpec] = None, ): + lit_spec = lit_spec or lit_api.spec response_queue_id, uid, timestamp, x_enc = request try: context = {} @@ -191,7 +193,6 @@ async def _process_single_request( def _run_single_loop_with_async( self, lit_api: LitAPI, - lit_spec: Optional[LitSpec], request_queue: Queue, transport: MessageTransport, callback_runner: CallbackRunner, @@ -232,7 +233,6 @@ async def process_requests(): self._process_single_request( (response_queue_id, uid, timestamp, x_enc), lit_api, - lit_spec, transport, callback_runner, ), @@ -255,30 +255,31 @@ async def process_requests(): def __call__( self, lit_api: LitAPI, - lit_spec: Optional[LitSpec], device: str, worker_id: int, request_queue: Queue, transport: MessageTransport, - stream: bool, workers_setup_status: Dict[int, str], callback_runner: CallbackRunner, + lit_spec: Optional[LitSpec] = None, + stream: bool = False, ): if lit_api.enable_async: - self._run_single_loop_with_async(lit_api, lit_spec, request_queue, transport, callback_runner) + self._run_single_loop_with_async(lit_api, request_queue, transport, callback_runner) else: - self.run_single_loop(lit_api, lit_spec, request_queue, transport, callback_runner) + self.run_single_loop(lit_api, request_queue, transport, callback_runner) class BatchedLoop(DefaultLoop): def run_batched_loop( self, lit_api: LitAPI, - lit_spec: LitSpec, request_queue: Queue, transport: MessageTransport, callback_runner: CallbackRunner, + lit_spec: Optional[LitSpec] = None, ): + lit_spec = lit_api.spec while True: batches, timed_out_uids = collate_requests( lit_api, @@ -368,18 +369,17 @@ def run_batched_loop( def __call__( self, lit_api: LitAPI, - lit_spec: Optional[LitSpec], device: str, worker_id: int, request_queue: Queue, transport: MessageTransport, - stream: bool, workers_setup_status: Dict[int, str], callback_runner: CallbackRunner, + lit_spec: Optional[LitSpec] = None, + stream: bool = False, ): self.run_batched_loop( lit_api, - lit_spec, request_queue, transport, callback_runner, diff --git a/src/litserve/loops/streaming_loops.py b/src/litserve/loops/streaming_loops.py index de3e295d..1f3d78c5 100644 --- a/src/litserve/loops/streaming_loops.py +++ b/src/litserve/loops/streaming_loops.py @@ -33,11 +33,12 @@ class StreamingLoop(DefaultLoop): def run_streaming_loop( self, lit_api: LitAPI, - lit_spec: LitSpec, request_queue: Queue, transport: MessageTransport, callback_runner: CallbackRunner, + lit_spec: Optional[LitSpec] = None, ): + lit_spec = lit_api.spec while True: try: response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) @@ -113,10 +114,11 @@ async def _process_streaming_request( self, request, lit_api: LitAPI, - lit_spec: Optional[LitSpec], transport: MessageTransport, callback_runner: CallbackRunner, + lit_spec: 
Optional[LitSpec] = None, ): + lit_spec = lit_api.spec response_queue_id, uid, timestamp, x_enc = request try: context = {} @@ -179,7 +181,6 @@ async def _process_streaming_request( def run_streaming_loop_async( self, lit_api: LitAPI, - lit_spec: Optional[LitSpec], request_queue: Queue, transport: MessageTransport, callback_runner: CallbackRunner, @@ -214,7 +215,6 @@ async def process_requests(): self._process_streaming_request( (response_queue_id, uid, timestamp, x_enc), lit_api, - lit_spec, transport, callback_runner, ), @@ -233,30 +233,29 @@ async def process_requests(): def __call__( self, lit_api: LitAPI, - lit_spec: Optional[LitSpec], device: str, worker_id: int, request_queue: Queue, transport: MessageTransport, - stream: bool, workers_setup_status: Dict[int, str], callback_runner: CallbackRunner, ): if lit_api.enable_async: - self.run_streaming_loop_async(lit_api, lit_spec, request_queue, transport, callback_runner) + self.run_streaming_loop_async(lit_api, request_queue, transport, callback_runner) else: - self.run_streaming_loop(lit_api, lit_spec, request_queue, transport, callback_runner) + self.run_streaming_loop(lit_api, request_queue, transport, callback_runner) class BatchedStreamingLoop(DefaultLoop): def run_batched_streaming_loop( self, lit_api: LitAPI, - lit_spec: LitSpec, request_queue: Queue, transport: MessageTransport, callback_runner: CallbackRunner, + lit_spec: Optional[LitSpec] = None, ): + lit_spec = lit_api.spec while True: batches, timed_out_uids = collate_requests( lit_api, @@ -338,18 +337,15 @@ def run_batched_streaming_loop( def __call__( self, lit_api: LitAPI, - lit_spec: Optional[LitSpec], device: str, worker_id: int, request_queue: Queue, transport: MessageTransport, - stream: bool, workers_setup_status: Dict[int, str], callback_runner: CallbackRunner, ): self.run_batched_streaming_loop( lit_api, - lit_spec, request_queue, transport, callback_runner, diff --git a/src/litserve/server.py b/src/litserve/server.py index 865a5bd1..63fa1701 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -27,8 +27,9 @@ from collections import deque from contextlib import asynccontextmanager from multiprocessing.context import Process +from queue import Queue from threading import Thread -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import uvicorn import uvicorn.server @@ -132,10 +133,91 @@ def _migration_warning(feature_name): ) +class _LitAPIConnector: + """A helper class to manage one or more `LitAPI` instances. + + This class provides utilities for performing setup tasks, managing request + and batch timeouts, and interacting with `LitAPI` instances in a unified way. + It ensures that all `LitAPI` instances are properly initialized and configured + before use. + + Attributes: + lit_apis (List[LitAPI]): A list of `LitAPI` instances managed by this connector. + + Methods: + pre_setup(): Calls the `pre_setup` method on all managed `LitAPI` instances. + set_request_timeout(timeout): Sets the request timeout for all `LitAPI` instances + and validates that batch timeouts are within acceptable limits. + __iter__(): Allows iteration over the managed `LitAPI` instances. + any_stream(): Checks if any of the `LitAPI` instances have streaming enabled. + set_logger_queue(queue): Sets a logger queue for all `LitAPI` instances. 
+ + """ + + def __init__(self, lit_apis: Union[LitAPI, Iterable[LitAPI]]): + if isinstance(lit_apis, LitAPI): + self.lit_apis = [lit_apis] + elif isinstance(lit_apis, Iterable): + self.lit_apis = list(lit_apis) + if not self.lit_apis: # Check if the iterable is empty + raise ValueError("lit_apis must not be an empty iterable") + self._detect_path_collision() + self._check_mixed_streaming_configuration() + else: + raise ValueError(f"lit_apis must be a LitAPI or an iterable of LitAPI, but got {type(lit_apis)}") + + def _check_mixed_streaming_configuration(self): + """Ensure consistent streaming configuration across all endpoints. + + Streaming must be either enabled for all endpoints or disabled for all. Mixing streaming and non-streaming + endpoints is currently not supported. + + """ + streams_enabled = [api.stream for api in self.lit_apis] + if any(streams_enabled) and not all(streams_enabled): + raise ValueError( + "Inconsistent streaming configuration: all endpoints must either " + "enable streaming or disable it. " + "Mixed configurations are not supported yet." + ) + + def _detect_path_collision(self): + paths = {"/health": "LitServe healthcheck", "/info": "LitServe info"} + for lit_api in self.lit_apis: + if lit_api.api_path in paths: + raise ValueError(f"api_path {lit_api.api_path} is already in use by {paths[lit_api.api_path]}") + paths[lit_api.api_path] = lit_api + + def pre_setup(self): + for lit_api in self.lit_apis: + lit_api.pre_setup() + # Ideally LitAPI should not know about LitLoop + # LitLoop can keep litapi as a class variable + lit_api.loop.pre_setup(lit_api) + + def set_request_timeout(self, timeout: float): + for lit_api in self.lit_apis: + lit_api.request_timeout = timeout + + for lit_api in self.lit_apis: + if lit_api.batch_timeout > timeout and timeout not in (False, -1): + raise ValueError("batch_timeout must be less than request_timeout") + + def __iter__(self): + return iter(self.lit_apis) + + def any_stream(self): + return any(lit_api.stream for lit_api in self.lit_apis) + + def set_logger_queue(self, queue: Queue): + for lit_api in self.lit_apis: + lit_api.set_logger_queue(queue) + + class LitServer: def __init__( self, - lit_api: LitAPI, + lit_api: Union[LitAPI, List[LitAPI]], accelerator: str = "auto", devices: Union[str, int] = "auto", workers_per_device: int = 1, @@ -224,8 +306,8 @@ def __init__( lit_api.stream = spec.stream # pre setup - lit_api.pre_setup(spec=spec) - lit_api.loop.pre_setup(lit_api, spec=spec) + self.litapi_connector = _LitAPIConnector(lit_api) + self.litapi_connector.pre_setup() if api_path and not api_path.startswith("/"): raise ValueError( @@ -249,16 +331,6 @@ def __init__( except (TypeError, ValueError): raise ValueError("model_metadata must be JSON serializable.") - # Check if the batch and unbatch methods are overridden in the lit_api instance - batch_overridden = lit_api.batch.__code__ is not LitAPI.batch.__code__ - unbatch_overridden = lit_api.unbatch.__code__ is not LitAPI.unbatch.__code__ - - if batch_overridden and unbatch_overridden and lit_api.max_batch_size == 1: - warnings.warn( - "The LitServer has both batch and unbatch methods implemented, " - "but the max_batch_size parameter was not set." - ) - if sys.platform == "win32" and fast_queue: warnings.warn("ZMQ is not supported on Windows with LitServe. 
Disabling ZMQ.") fast_queue = False @@ -267,17 +339,13 @@ def __init__( self.info_path = info_path self.track_requests = track_requests self.timeout = timeout - # TODO: Connector - lit_api.request_timeout = timeout - if lit_api.batch_timeout > timeout and timeout not in (False, -1): - raise ValueError("batch_timeout must be less than request_timeout") + self.litapi_connector.set_request_timeout(timeout) self.app = FastAPI(lifespan=self.lifespan) self.app.response_queue_id = None self.response_queue_id = None self.response_buffer = {} # gzip does not play nicely with streaming, see https://github.com/tiangolo/fastapi/discussions/8448 - # TODO: Connector - if not lit_api.stream: + if not self.litapi_connector.any_stream(): middlewares.append((GZipMiddleware, {"minimum_size": 1000})) if max_payload_size is not None: middlewares.append((MaxSizeMiddleware, {"max_size": max_payload_size})) @@ -293,20 +361,7 @@ def __init__( self._callback_runner = CallbackRunner(callbacks) self.use_zmq = fast_queue self.transport_config = None - - # specs = spec if spec is not None else [] - # self._specs = specs if isinstance(specs, Sequence) else [specs] - - decode_request_signature = inspect.signature(lit_api.decode_request) - encode_response_signature = inspect.signature(lit_api.encode_response) - - self.request_type = decode_request_signature.parameters["request"].annotation - if self.request_type == decode_request_signature.empty: - self.request_type = Request - - self.response_type = encode_response_signature.return_annotation - if self.response_type == encode_response_signature.empty: - self.response_type = Response + self.litapi_request_queues = {} accelerator = self._connector.accelerator devices = self._connector.devices @@ -322,51 +377,40 @@ def __init__( self.transport_config = TransportConfig(transport_config="zmq" if self.use_zmq else "mp") self.register_endpoints() - def launch_inference_worker(self, num_uvicorn_servers: int): - self.transport_config.num_consumers = num_uvicorn_servers - manager = self.transport_config.manager = mp.Manager() - self._transport = create_transport_from_config(self.transport_config) - self.workers_setup_status = manager.dict() - self.request_queue = manager.Queue() - if self._logger_connector._loggers: - self.logger_queue = manager.Queue() - - self._logger_connector.run(self) - - specs = [self.lit_api.spec] if self.lit_api.spec else [] + def launch_inference_worker(self, lit_api: LitAPI): + specs = [lit_api.spec] if lit_api.spec else [] for spec in specs: # Objects of Server class are referenced (not copied) logging.debug(f"shallow copy for Server is created for for spec {spec}") server_copy = copy.copy(self) - del server_copy.app, server_copy.transport_config + del server_copy.app, server_copy.transport_config, server_copy.litapi_connector + print(self.litapi_request_queues) spec.setup(server_copy) process_list = [] + endpoint = lit_api.api_path.split("/")[-1] for worker_id, device in enumerate(self.inference_workers): if len(device) == 1: device = device[0] - self.workers_setup_status[worker_id] = WorkerSetupStatus.STARTING + self.workers_setup_status[f"{endpoint}_{worker_id}"] = WorkerSetupStatus.STARTING ctx = mp.get_context("spawn") process = ctx.Process( target=inference_worker, args=( - self.lit_api, - self.lit_api.spec, + lit_api, device, worker_id, - self.request_queue, + self._get_request_queue(lit_api.api_path), self._transport, - self.lit_api.stream, self.workers_setup_status, self._callback_runner, - self.lit_api.loop, ), ) 
process.start() process_list.append(process) - return manager, process_list + return process_list @asynccontextmanager async def lifespan(self, app: FastAPI): @@ -383,7 +427,7 @@ async def lifespan(self, app: FastAPI): future = response_queue_to_buffer( transport, self.response_buffer, - self.lit_api.stream, + self.litapi_connector.any_stream(), app.response_queue_id, ) task = loop.create_task(future, name=f"response_queue_to_buffer-{app.response_queue_id}") @@ -404,7 +448,8 @@ def device_identifiers(self, accelerator, device): return [f"{accelerator}:{el}" for el in device] return [f"{accelerator}:{device}"] - async def data_streamer(self, q: deque, data_available: asyncio.Event, send_status: bool = False): + @staticmethod + async def data_streamer(q: deque, data_available: asyncio.Event, send_status: bool = False): while True: await data_available.wait() while len(q) > 0: @@ -432,9 +477,7 @@ def active_requests(self): return sum(counter.value for counter in self.active_counters) return None - def register_endpoints(self): - """Register endpoint routes for the FastAPI app and setup middlewares.""" - self._callback_runner.trigger_event(EventTypes.ON_SERVER_START.value, litserver=self) + def _register_internal_endpoints(self): workers_ready = False @self.app.get("/", dependencies=[Depends(self.setup_auth())]) @@ -469,7 +512,31 @@ async def info(request: Request) -> Response: } ) - async def predict(request: self.request_type) -> self.response_type: + def register_endpoints(self): + self._register_internal_endpoints() + for lit_api in self.litapi_connector: + decode_request_signature = inspect.signature(lit_api.decode_request) + encode_response_signature = inspect.signature(lit_api.encode_response) + + request_type = decode_request_signature.parameters["request"].annotation + if request_type == decode_request_signature.empty: + request_type = Request + + response_type = encode_response_signature.return_annotation + if response_type == encode_response_signature.empty: + response_type = Response + self._register_api_endpoints(lit_api, request_type, response_type) + + def _get_request_queue(self, api_path: str): + return self.litapi_request_queues[api_path] + + def _register_api_endpoints(self, lit_api: LitAPI, request_type, response_type): + """Register endpoint routes for the FastAPI app and setup middlewares.""" + + self._callback_runner.trigger_event(EventTypes.ON_SERVER_START.value, litserver=self) + + async def predict(request: request_type) -> response_type: + request_queue = self._get_request_queue(lit_api.api_path) self._callback_runner.trigger_event( EventTypes.ON_REQUEST.value, active_requests=self.active_requests, @@ -482,7 +549,7 @@ async def predict(request: self.request_type) -> self.response_type: logger.debug(f"Received request uid={uid}") payload = request - if self.request_type == Request: + if request_type == Request: if request.headers["Content-Type"] == "application/x-www-form-urlencoded" or request.headers[ "Content-Type" ].startswith("multipart/form-data"): @@ -490,7 +557,7 @@ async def predict(request: self.request_type) -> self.response_type: else: payload = await request.json() - self.request_queue.put((response_queue_id, uid, time.monotonic(), payload)) + request_queue.put((response_queue_id, uid, time.monotonic(), payload)) await event.wait() response, status = self.response_buffer.pop(uid) @@ -503,7 +570,8 @@ async def predict(request: self.request_type) -> self.response_type: self._callback_runner.trigger_event(EventTypes.ON_RESPONSE.value, litserver=self) return 
response - async def stream_predict(request: self.request_type) -> self.response_type: + async def stream_predict(request: request_type) -> response_type: + request_queue = self._get_request_queue(lit_api.api_path) self._callback_runner.trigger_event( EventTypes.ON_REQUEST.value, active_requests=self.active_requests, @@ -517,9 +585,9 @@ async def stream_predict(request: self.request_type) -> self.response_type: logger.debug(f"Received request uid={uid}") payload = request - if self.request_type == Request: + if request_type == Request: payload = await request.json() - self.request_queue.put((response_queue_id, uid, time.monotonic(), payload)) + request_queue.put((response_queue_id, uid, time.monotonic(), payload)) response = call_after_stream( self.data_streamer(q, data_available=event), @@ -529,11 +597,11 @@ async def stream_predict(request: self.request_type) -> self.response_type: ) return StreamingResponse(response) - if not self.lit_api.spec: - stream = self.lit_api.stream + if not lit_api.spec: + stream = lit_api.stream # In the future we might want to differentiate endpoints for streaming vs non-streaming # For now we allow either one or the other - endpoint = self.lit_api.api_path + endpoint = lit_api.api_path methods = ["POST"] self.app.add_api_route( endpoint, @@ -542,7 +610,7 @@ async def stream_predict(request: self.request_type) -> self.response_type: dependencies=[Depends(self.setup_auth())], ) - specs = [self.lit_api.spec] if self.lit_api.spec else [] + specs = [lit_api.spec] if lit_api.spec else [] for spec in specs: spec: LitSpec # TODO check that path is not clashing @@ -581,6 +649,21 @@ def verify_worker_status(self): time.sleep(0.05) logger.debug("One or more workers are ready to serve requests") + def _init_manager(self, num_api_servers: int): + manager = self.transport_config.manager = mp.Manager() + self.transport_config.num_consumers = num_api_servers + self.workers_setup_status = manager.dict() + + # create request queues for each unique lit_api api_path + for lit_api in self.litapi_connector: + self.litapi_request_queues[lit_api.api_path] = manager.Queue() + + if self._logger_connector._loggers: + self.logger_queue = manager.Queue() + self._logger_connector.run(self) + self._transport = create_transport_from_config(self.transport_config) + return manager + def run( self, host: str = "0.0.0.0", @@ -624,7 +707,12 @@ def run( elif api_server_worker_type is None: api_server_worker_type = "process" - manager, inference_workers = self.launch_inference_worker(num_api_servers) + manager = self._init_manager(num_api_servers) + self._logger_connector.run(self) + inference_workers = [] + for lit_api in self.litapi_connector: + _inference_workers = self.launch_inference_worker(lit_api) + inference_workers.extend(_inference_workers) self.verify_worker_status() try: @@ -658,8 +746,10 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_w workers = [] for response_queue_id in range(num_uvicorn_servers): self.app.response_queue_id = response_queue_id - if self.lit_api.spec: - self.lit_api.spec.response_queue_id = response_queue_id + for lit_api in self.litapi_connector: + if lit_api.spec: + lit_api.spec.response_queue_id = response_queue_id + app: FastAPI = copy.copy(self.app) self._prepare_app_run(app) diff --git a/src/litserve/specs/base.py b/src/litserve/specs/base.py index 5c4e81ef..b958139c 100644 --- a/src/litserve/specs/base.py +++ b/src/litserve/specs/base.py @@ -23,9 +23,12 @@ class LitSpec: def __init__(self): self._endpoints = [] - + 
self.api_path = None self._server: LitServer = None self._max_batch_size = 1 + self.response_buffer = None + self.request_queue = None + self.response_queue_id = None @property def stream(self): @@ -35,7 +38,10 @@ def pre_setup(self, lit_api: "LitAPI"): pass def setup(self, server: "LitServer"): - self._server = server + """This method is called by the server to connect the spec to the server.""" + self.response_buffer = server.response_buffer + self.request_queue = server._get_request_queue(self.api_path) + self.data_streamer = server.data_streamer def add_endpoint(self, path: str, endpoint: Callable, methods: List[str]): """Register an endpoint in the spec.""" diff --git a/src/litserve/specs/openai.py b/src/litserve/specs/openai.py index 4a4f0158..166d8d7a 100644 --- a/src/litserve/specs/openai.py +++ b/src/litserve/specs/openai.py @@ -331,6 +331,7 @@ def __init__( ): super().__init__() # register the endpoint + self.api_path = "/v1/chat/completions" # default api path self.add_endpoint("/v1/chat/completions", self.chat_completion, ["POST"]) self.add_endpoint("/v1/chat/completions", self.options_chat_completions, ["OPTIONS"]) @@ -436,7 +437,7 @@ def encode_response( async def get_from_queues(self, uids) -> List[AsyncGenerator]: choice_pipes = [] for uid, q, event in zip(uids, self.queues, self.events): - data = self._server.data_streamer(q, event, send_status=True) + data = self.data_streamer(q, event, send_status=True) choice_pipes.append(data) return choice_pipes @@ -454,8 +455,8 @@ async def chat_completion(self, request: ChatCompletionRequest, background_tasks request_el.n = 1 q = deque() event = asyncio.Event() - self._server.response_buffer[uid] = (q, event) - self._server.request_queue.put((response_queue_id, uid, time.monotonic(), request_el)) + self.response_buffer[uid] = (q, event) + self.request_queue.put((response_queue_id, uid, time.monotonic(), request_el)) self.queues.append(q) self.events.append(event) diff --git a/src/litserve/specs/openai_embedding.py b/src/litserve/specs/openai_embedding.py index 9100403e..4cacf61c 100644 --- a/src/litserve/specs/openai_embedding.py +++ b/src/litserve/specs/openai_embedding.py @@ -125,6 +125,7 @@ class OpenAIEmbeddingSpec(LitSpec): def __init__(self): super().__init__() # register the endpoint + self.api_path = "/v1/embeddings" # default api path self.add_endpoint("/v1/embeddings", self.embeddings_endpoint, ["POST"]) self.add_endpoint("/v1/embeddings", self.options_embeddings, ["GET"]) @@ -133,7 +134,7 @@ def setup(self, server: "LitServer"): super().setup(server) - lit_api = self._server.lit_api + lit_api = server.lit_api if inspect.isgeneratorfunction(lit_api.predict): raise ValueError( "You are using yield in your predict method, which is used for streaming.", @@ -245,12 +246,12 @@ async def embeddings_endpoint(self, request: EmbeddingRequest) -> EmbeddingRespo logger.debug("Received embedding request: %s", request) uid = uuid.uuid4() event = asyncio.Event() - self._server.response_buffer[uid] = event + self.response_buffer[uid] = event - self._server.request_queue.put_nowait((response_queue_id, uid, time.monotonic(), request.model_copy())) + self.request_queue.put_nowait((response_queue_id, uid, time.monotonic(), request.model_copy())) await event.wait() - response, status = self._server.response_buffer.pop(uid) + response, status = self.response_buffer.pop(uid) if status == LitAPIStatus.ERROR and isinstance(response, HTTPException): logger.error("Error in embedding request: %s", response) diff --git a/src/litserve/utils.py 
b/src/litserve/utils.py index d78467e7..2036ed86 100644 --- a/src/litserve/utils.py +++ b/src/litserve/utils.py @@ -64,10 +64,16 @@ async def azip(*async_iterables): @contextmanager def wrap_litserve_start(server: "LitServer"): + """Pytest utility to start the server in a context manager.""" server.app.response_queue_id = 0 - if server.lit_api._spec: - server.lit_api._spec.response_queue_id = 0 - manager, processes = server.launch_inference_worker(num_uvicorn_servers=1) + for lit_api in server.litapi_connector: + if lit_api.spec: + lit_api.spec.response_queue_id = 0 + + manager = server._init_manager(1) + processes = [] + for lit_api in server.litapi_connector: + processes.extend(server.launch_inference_worker(lit_api)) server._prepare_app_run(server.app) try: yield server diff --git a/tests/test_batch.py b/tests/test_batch.py index b1069f31..690f0841 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -203,7 +203,6 @@ def test_batched_loop(): loop = BatchedLoop() with patch("pickle.dumps", side_effect=StopIteration("exit loop")), pytest.raises(StopIteration, match="exit loop"): loop.run_batched_loop( - lit_api_mock, lit_api_mock, requests_queue, [FakeResponseQueue()], diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index c99e7598..e8143e4d 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -234,7 +234,7 @@ def test_start_server(mock_uvicon): def server_for_api_worker_test(simple_litapi): server = ls.LitServer(simple_litapi, devices=1) server.verify_worker_status = MagicMock() - server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]]) + server.launch_inference_worker = MagicMock(return_value=[MagicMock()]) server._start_server = MagicMock() server._transport = MagicMock() return server @@ -246,7 +246,7 @@ def test_server_run_with_api_server_worker_type(mock_uvicorn, server_for_api_wor server = server_for_api_worker_test server.run(api_server_worker_type="process", num_api_servers=10) - server.launch_inference_worker.assert_called_with(10) + server.launch_inference_worker.assert_called_with(server.lit_api) @pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix") @@ -258,7 +258,7 @@ def test_server_run_with_process_api_worker( server = server_for_api_worker_test server.run(api_server_worker_type=api_server_worker_type, num_api_workers=num_api_workers) - server.launch_inference_worker.assert_called_with(num_api_workers) + server.launch_inference_worker.assert_called_with(server.lit_api) actual = server._start_server.call_args assert actual[0][4] == "process", "Server should run in process mode" mock_uvicorn.Config.assert_called() @@ -269,7 +269,7 @@ def test_server_run_with_process_api_worker( def test_server_run_with_thread_api_worker(mock_uvicorn, server_for_api_worker_test): server = server_for_api_worker_test server.run(api_server_worker_type="thread") - server.launch_inference_worker.assert_called_with(1) + server.launch_inference_worker.assert_called_with(server.lit_api) assert server._start_server.call_args[0][4] == "thread", "Server should run in thread mode" mock_uvicorn.Config.assert_called() @@ -291,7 +291,7 @@ def test_server_run_windows(mock_uvicorn): api = ls.test_examples.SimpleLitAPI() server = ls.LitServer(api) server.verify_worker_status = MagicMock() - server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]]) + server.launch_inference_worker = MagicMock(return_value=[MagicMock()]) server._transport = MagicMock() server._start_server = MagicMock() @@ 
-306,14 +306,14 @@ def test_server_terminate(): server._transport = MagicMock() with ( + patch("litserve.server.LitServer._init_manager", return_value=MagicMock()) as mock_init_manager, patch("litserve.server.LitServer._start_server", side_effect=Exception("mocked error")) as mock_start, - patch( - "litserve.server.LitServer.launch_inference_worker", return_value=(MagicMock(), [MagicMock()]) - ) as mock_launch, + patch("litserve.server.LitServer.launch_inference_worker", return_value=([MagicMock()])) as mock_launch, ): with pytest.raises(Exception, match="mocked error"): server.run(port=8001) + mock_init_manager.assert_called() mock_launch.assert_called() mock_start.assert_called() server._transport.close.assert_called() diff --git a/tests/test_litapi.py b/tests/test_litapi.py index 170585fb..4a6b4f16 100644 --- a/tests/test_litapi.py +++ b/tests/test_litapi.py @@ -277,7 +277,9 @@ def test_log(): api.request_timeout = 30 assert api._logger_queue is None, "Logger queue should be None" server = ls.LitServer(api, loggers=TestLogger()) - server.launch_inference_worker(1) + server._init_manager(1) + server._logger_connector.run(server) + server.launch_inference_worker(api) api.log("time", 0.1) assert server.logger_queue.get() == ("time", 0.1) diff --git a/tests/test_loops.py b/tests/test_loops.py index fd2b095d..c8c6ff8b 100644 --- a/tests/test_loops.py +++ b/tests/test_loops.py @@ -127,7 +127,7 @@ def test_single_loop(loop_args): lit_loop = SingleLoop() with pytest.raises(StopIteration, match="exit loop"): - lit_loop.run_single_loop(lit_api_mock, None, requests_queue, transport, callback_runner=NOOP_CB_RUNNER) + lit_loop.run_single_loop(lit_api_mock, requests_queue, transport, callback_runner=NOOP_CB_RUNNER) @pytest.mark.asyncio @@ -140,7 +140,6 @@ async def test_single_loop_process_single_async_request(async_loop_args, mock_tr await loop._process_single_request( request, lit_api_mock, - None, mock_transport, NOOP_CB_RUNNER, ) @@ -162,7 +161,7 @@ def test_run_single_loop_with_async(async_loop_args, monkeypatch): import contextlib with contextlib.suppress(KeyboardInterrupt): - loop._run_single_loop_with_async(lit_api_mock, None, requests_queue, mock_transport, NOOP_CB_RUNNER) + loop._run_single_loop_with_async(lit_api_mock, requests_queue, mock_transport, NOOP_CB_RUNNER) response = asyncio.get_event_loop().run_until_complete(mock_transport.areceive(consumer_id=0)) assert response == ("uuid-123", ({"output": 1}, ls.utils.LitAPIStatus.OK)) @@ -211,7 +210,6 @@ def fake_encode(output): lit_loop = StreamingLoop() with pytest.raises(StopIteration, match="exit loop"): lit_loop.run_streaming_loop( - fake_stream_api, fake_stream_api, requests_queue, transport, @@ -278,7 +276,6 @@ def fake_encode(output_iter): transport = FakeBatchStreamTransport(num_streamed_outputs) with pytest.raises(StopIteration, match="finish streaming"): lit_loop.run_batched_streaming_loop( - fake_stream_api, fake_stream_api, requests_queue, transport=transport, @@ -295,14 +292,18 @@ def test_inference_worker(mock_single_loop, mock_batched_loop): lit_api_mock.max_batch_size = 2 lit_api_mock.batch_timeout = 0 lit_api_mock.enable_async = False + lit_api_mock.stream = False + lit_api_mock.api_path = "/predict" + lit_api_mock.loop = "auto" inference_worker( lit_api_mock, - *[MagicMock()] * 5, - stream=False, + "cpu", + 0, + MagicMock(), + MagicMock(), workers_setup_status={}, callback_runner=NOOP_CB_RUNNER, - loop="auto", ) mock_batched_loop.assert_called_once() @@ -310,14 +311,18 @@ def test_inference_worker(mock_single_loop, 
mock_batched_loop): lit_api_mock.max_batch_size = 1 lit_api_mock.batch_timeout = 0 lit_api_mock.enable_async = False + lit_api_mock.stream = False + lit_api_mock.api_path = "/predict" + lit_api_mock.loop = "auto" inference_worker( lit_api_mock, - *[MagicMock()] * 5, - stream=False, + "cpu", + 0, + MagicMock(), + MagicMock(), workers_setup_status={}, callback_runner=NOOP_CB_RUNNER, - loop="auto", ) mock_single_loop.assert_called_once() @@ -335,7 +340,7 @@ async def test_run_single_loop(mock_transport): # Run the loop in a separate thread to allow it to be stopped lit_loop = SingleLoop() loop_thread = threading.Thread( - target=lit_loop.run_single_loop, args=(lit_api, None, request_queue, transport, NOOP_CB_RUNNER) + target=lit_loop.run_single_loop, args=(lit_api, request_queue, transport, NOOP_CB_RUNNER) ) loop_thread.start() @@ -367,7 +372,7 @@ async def test_run_single_loop_timeout(): lit_loop = SingleLoop() loop_thread = threading.Thread( - target=lit_loop.run_single_loop, args=(lit_api, None, request_queue, transport, NOOP_CB_RUNNER) + target=lit_loop.run_single_loop, args=(lit_api, request_queue, transport, NOOP_CB_RUNNER) ) loop_thread.start() @@ -399,7 +404,7 @@ async def test_run_batched_loop(): lit_loop = BatchedLoop() loop_thread = threading.Thread( target=lit_loop.run_batched_loop, - args=(lit_api, None, request_queue, transport, NOOP_CB_RUNNER), + args=(lit_api, request_queue, transport, NOOP_CB_RUNNER), ) loop_thread.start() @@ -442,7 +447,7 @@ async def test_run_batched_loop_timeout(mock_transport): lit_loop = BatchedLoop() loop_thread = threading.Thread( target=lit_loop.run_batched_loop, - args=(lit_api, None, request_queue, transport, NOOP_CB_RUNNER), + args=(lit_api, request_queue, transport, NOOP_CB_RUNNER), ) loop_thread.start() @@ -471,7 +476,7 @@ async def test_run_streaming_loop(mock_transport): # Run the loop in a separate thread to allow it to be stopped lit_loop = StreamingLoop() loop_thread = threading.Thread( - target=lit_loop.run_streaming_loop, args=(lit_api, None, request_queue, mock_transport, NOOP_CB_RUNNER) + target=lit_loop.run_streaming_loop, args=(lit_api, request_queue, mock_transport, NOOP_CB_RUNNER) ) loop_thread.start() @@ -502,7 +507,7 @@ async def test_run_streaming_loop_timeout(mock_transport): # Run the loop in a separate thread to allow it to be stopped lit_loop = StreamingLoop() loop_thread = threading.Thread( - target=lit_loop.run_streaming_loop, args=(lit_api, None, request_queue, mock_transport, NOOP_CB_RUNNER) + target=lit_loop.run_streaming_loop, args=(lit_api, request_queue, mock_transport, NOOP_CB_RUNNER) ) loop_thread.start() @@ -559,24 +564,20 @@ class TestLoop(LitLoop): def __call__( self, lit_api: LitAPI, - lit_spec: Optional[LitSpec], device: str, worker_id: int, request_queue: Queue, transport: MessageTransport, - stream: bool, workers_setup_status: Dict[int, str], callback_runner: CallbackRunner, ): try: self.run( lit_api, - lit_spec, device, worker_id, request_queue, transport, - stream, workers_setup_status, callback_runner, ) @@ -586,12 +587,10 @@ def __call__( def run( self, lit_api: LitAPI, - lit_spec: Optional[LitSpec], device: str, worker_id: int, request_queue: Queue, transport: MessageTransport, - stream: bool, workers_setup_status: Dict[int, str], callback_runner: CallbackRunner, ): @@ -617,7 +616,7 @@ async def test_custom_loop(mock_transport): request_queue = Queue() request_queue.put((0, "UUID-001", time.monotonic(), {"input": 4.0})) - loop(lit_api, None, "cpu", 0, request_queue, mock_transport, False, {}, 
NOOP_CB_RUNNER)
+    loop(lit_api, "cpu", 0, request_queue, mock_transport, {}, NOOP_CB_RUNNER)
     response = await mock_transport.areceive(0)
     assert response[0] == "UUID-001"
     assert response[1][0] == {"output": 16.0}
@@ -837,7 +836,7 @@ async def test_continuous_batching_run(continuous_batching_setup):
     response_queue_id, uid, _, input = (0, "UUID-001", time.monotonic(), {"input": "Hello"})
     lit_loop.add_request(uid, input, lit_api, None)
     lit_loop.response_queue_ids[uid] = response_queue_id
-    await lit_loop.run(lit_api, None, "cpu", 0, request_queue, mock_transport, True, {}, NOOP_CB_RUNNER)
+    await lit_loop.run(lit_api, "cpu", 0, request_queue, mock_transport, {}, NOOP_CB_RUNNER)

     results = []
     for i in range(5):
diff --git a/tests/test_multiple_endpoints.py b/tests/test_multiple_endpoints.py
new file mode 100644
index 00000000..19ea91b8
--- /dev/null
+++ b/tests/test_multiple_endpoints.py
@@ -0,0 +1,70 @@
+import pytest
+from asgi_lifespan import LifespanManager
+from httpx import ASGITransport, AsyncClient
+
+import litserve as ls
+from litserve.utils import wrap_litserve_start
+
+
+class InferencePipeline(ls.LitAPI):
+    def __init__(self, name=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.name = name
+
+    def setup(self, device):
+        self.model = lambda x: x**2
+
+    def decode_request(self, request):
+        return request["input"]
+
+    def predict(self, x):
+        return self.model(x)
+
+    def encode_response(self, output):
+        return {"output": output, "name": self.name}
+
+
+@pytest.mark.asyncio
+async def test_multiple_endpoints():
+    api1 = InferencePipeline(name="api1", api_path="/api1")
+    api2 = InferencePipeline(name="api2", api_path="/api2")
+    server = ls.LitServer([api1, api2])
+
+    with wrap_litserve_start(server) as server:
+        async with LifespanManager(server.app) as manager, AsyncClient(
+            transport=ASGITransport(app=manager.app), base_url="http://test"
+        ) as ac:
+            resp = await ac.post("/api1", json={"input": 2.0}, timeout=10)
+            assert resp.status_code == 200, "Server response should be 200 (OK)"
+            assert resp.json()["output"] == 4.0, "output from Identity server must be same as input"
+            assert resp.json()["name"] == "api1", "name from Identity server must be same as input"
+
+            resp = await ac.post("/api2", json={"input": 5.0}, timeout=10)
+            assert resp.status_code == 200, "Server response should be 200 (OK)"
+            assert resp.json()["output"] == 25.0, "output from Identity server must be same as input"
+            assert resp.json()["name"] == "api2", "name from Identity server must be same as input"
+
+
+def test_multiple_endpoints_with_same_path():
+    api1 = InferencePipeline(name="api1", api_path="/api1")
+    api2 = InferencePipeline(name="api2", api_path="/api1")
+    with pytest.raises(ValueError, match="api_path /api1 is already in use by"):
+        ls.LitServer([api1, api2])
+
+
+def test_reserved_paths():
+    api1 = InferencePipeline(name="api1", api_path="/health")
+    api2 = InferencePipeline(name="api2", api_path="/info")
+    with pytest.raises(ValueError, match="api_path /health is already in use by LitServe healthcheck"):
+        ls.LitServer([api1, api2])
+
+
+def test_check_mixed_streaming_configuration():
+    api1 = InferencePipeline(name="api1", api_path="/api1", stream=True)
+    api2 = InferencePipeline(name="api2", api_path="/api2", stream=False)
+    with pytest.raises(
+        ValueError,
+        match="Inconsistent streaming configuration: all endpoints must either enable streaming or disable it. "
+        "Mixed configurations are not supported yet.",
+    ):
+        ls.LitServer([api1, api2])
diff --git a/tests/test_openai_embedding.py b/tests/test_openai_embedding.py
index 7306fb30..00a93496 100644
--- a/tests/test_openai_embedding.py
+++ b/tests/test_openai_embedding.py
@@ -37,7 +37,7 @@
 @pytest.mark.asyncio
 async def test_openai_embedding_spec_with_single_input(openai_embedding_request_data):
     spec = OpenAIEmbeddingSpec()
-    server = ls.LitServer(TestEmbedAPI(), spec=spec)
+    server = ls.LitServer(TestEmbedAPI(spec=spec))
     with wrap_litserve_start(server) as server:
         async with LifespanManager(server.app) as manager, AsyncClient(
             transport=ASGITransport(app=manager.app), base_url="http://test"
@@ -54,8 +54,7 @@ async def test_openai_embedding_spec_with_single_input(openai_embedding_request_
 @pytest.mark.asyncio
 async def test_openai_embedding_spec_with_multiple_inputs(openai_embedding_request_data_array):
     spec = OpenAIEmbeddingSpec()
-    server = ls.LitServer(TestEmbedAPI(), spec=spec)
-
+    server = ls.LitServer(TestEmbedAPI(spec=spec))
     with wrap_litserve_start(server) as server:
         async with LifespanManager(server.app) as manager, AsyncClient(
             transport=ASGITransport(app=manager.app), base_url="http://test"
         ) as ac:
@@ -73,7 +72,7 @@ async def test_openai_embedding_spec_with_multiple_inputs(openai_embedding_reque
 @pytest.mark.asyncio
 async def test_openai_embedding_spec_with_usage(openai_embedding_request_data):
     spec = OpenAIEmbeddingSpec()
-    server = ls.LitServer(TestEmbedAPIWithUsage(), spec=spec)
+    server = ls.LitServer(TestEmbedAPIWithUsage(spec=spec))
     with wrap_litserve_start(server) as server:
         async with LifespanManager(server.app) as manager, AsyncClient(