Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/litserve/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,7 @@ def _inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
return func(*args, **kwargs)


async def _async_inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
sig = inspect.signature(func)

# Determine if we need to inject context
if "context" in sig.parameters:
kwargs["context"] = context

async def _handle_async_function(func, *args, **kwargs):
# Call the function based on its type
if inspect.isasyncgenfunction(func):
# Async generator - return directly (don't await)
Expand All @@ -69,6 +63,16 @@ async def _async_inject_context(context: Union[List[dict], dict], func, *args, *
return result


async def _async_inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
sig = inspect.signature(func)

# Determine if we need to inject context
if "context" in sig.parameters:
kwargs["context"] = context

return await _handle_async_function(func, *args, **kwargs)


def collate_requests(
lit_api: LitAPI,
request_queue: Queue,
Expand Down
2 changes: 1 addition & 1 deletion src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ async def handle_request(self, request, request_type) -> Response:
class RegularRequestHandler(BaseRequestHandler):
async def handle_request(self, request, request_type) -> Response:
try:
logger.info(f"Handling request: {request}")
logger.debug(f"Handling request: {request}")
# Prepare request
payload = await self._prepare_request(request, request_type)

Expand Down
26 changes: 26 additions & 0 deletions tests/integration/test_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
from asgi_lifespan import LifespanManager
from httpx import ASGITransport, AsyncClient

import litserve as ls
from litserve.utils import wrap_litserve_start


class MinimalAsyncAPI(ls.LitAPI):
def setup(self, device):
self.model = None

async def predict(self, x):
y = x["input"] ** 2
return {"output": y}


@pytest.mark.asyncio
async def test_async_api():
server = ls.LitServer(MinimalAsyncAPI(enable_async=True))
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(
transport=ASGITransport(app=manager.app), base_url="http://test"
) as ac:
response = await ac.post("/predict", json={"input": 2})
assert response.json() == {"output": 4}
31 changes: 29 additions & 2 deletions tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import threading
import time
from queue import Empty, Queue
from typing import Dict, List, Optional
from typing import AsyncGenerator, Dict, List, Optional
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -33,7 +33,7 @@
from litserve import LitAPI
from litserve.callbacks import CallbackRunner
from litserve.loops import BatchedStreamingLoop, LitLoop, Output, StreamingLoop, inference_worker
from litserve.loops.base import DefaultLoop
from litserve.loops.base import DefaultLoop, _async_inject_context, _handle_async_function
from litserve.loops.continuous_batching_loop import (
ContinuousBatchingLoop,
notify_timed_out_requests,
Expand Down Expand Up @@ -918,3 +918,30 @@ async def test_continuous_batching_run(continuous_batching_setup):
assert o == ""
assert status == LitAPIStatus.FINISH_STREAMING
assert response_type == LoopResponseType.STREAMING


@pytest.mark.asyncio
async def test_handle_async_function():
async def async_func():
return "async"

def sync_func():
return "sync"

async def async_gen():
for i in range(3):
yield i

assert await _handle_async_function(async_func) == "async"
assert await _handle_async_function(sync_func) == "sync"
async_gen = await _handle_async_function(async_gen)
assert isinstance(async_gen, AsyncGenerator)


@pytest.mark.asyncio
async def test_async_inject_context():
async def async_func(x, context=0):
return x * context["a"]

context = {"a": 1}
assert await _async_inject_context(context, async_func, 2) == 2
Loading