71 changes: 47 additions & 24 deletions src/litserve/api.py
@@ -192,30 +192,53 @@ def __init__(

def _validate_async_methods(self):
"""Validate that async methods are properly implemented when enable_async is True."""
-        if self.enable_async:
-            # check if LitAPI methods are coroutines or async generators
-            for method in ["decode_request", "predict", "encode_response"]:
-                method_obj = getattr(self, method)
-                if not (asyncio.iscoroutinefunction(method_obj) or inspect.isasyncgenfunction(method_obj)):
-                    raise ValueError("""LitAPI(enable_async=True) requires all methods to be coroutines.
-
-Please either set enable_async=False or implement the following methods as coroutines:
-Example:
-    class MyLitAPI(LitAPI):
-        async def decode_request(self, request, **kwargs):
-            return request
-        async def predict(self, x, **kwargs):
-            return x
-        async def encode_response(self, output, **kwargs):
-            return output
-
-Streaming example:
-    class MyStreamingAPI(LitAPI):
-        async def predict(self, x, **kwargs):
-            for i in range(10):
-                await asyncio.sleep(0.1)  # simulate async work
-                yield f"Token {i}: {x}"
-""")
+        if not self.enable_async:
+            return
+
+        # Define validation rules for each method
+        validation_rules = {
+            "decode_request": {
+                "required_types": [asyncio.iscoroutinefunction, inspect.isasyncgenfunction],
+                "error_type": "warning",
+                "message": "should be an async function or async generator when enable_async=True",
+            },
+            "encode_response": {
+                "required_types": [asyncio.iscoroutinefunction, inspect.isasyncgenfunction],
+                "error_type": "warning",
+                "message": "should be an async function or async generator when enable_async=True",
+            },
+            "predict": {
+                "required_types": [inspect.isasyncgenfunction, asyncio.iscoroutinefunction],
+                "error_type": "error",
+                "message": "must be an async generator or async function when enable_async=True",
+            },
+        }
+
+        errors = []
+        warnings_list = []
+
+        for method_name, rules in validation_rules.items():
+            method_obj = getattr(self, method_name)
+
+            # Check if method satisfies any of the required types
+            is_valid = any(check_func(method_obj) for check_func in rules["required_types"])
+
+            if not is_valid:
+                message = f"{method_name} {rules['message']}"
+
+                if rules["error_type"] == "error":
+                    errors.append(message)
+                else:
+                    warnings_list.append(message)
+
+        # Emit warnings
+        for warning_msg in warnings_list:
+            warnings.warn(f"{warning_msg}. LitServe will asyncify the method.", UserWarning)
+
+        # Raise errors if any
+        if errors:
+            error_msg = "Async validation failed:\n" + "\n".join(f"- {err}" for err in errors)
+            raise ValueError(error_msg)

@abstractmethod
def setup(self, device):
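Note: with this refactor, predict is the only method that must be async; decode_request and encode_response merely warn and are asyncified. A minimal sketch of the new behavior — the class below is illustrative, not from this PR, and assumes enable_async is accepted by the constructor as in the tests further down:

```
import litserve as ls

class HybridAsyncAPI(ls.LitAPI):
    def setup(self, device):
        self.model = lambda x: x**2

    def decode_request(self, request):
        # Sync on purpose: now emits a UserWarning ("LitServe will
        # asyncify the method.") instead of raising ValueError.
        return request["input"]

    async def predict(self, x):
        # The one hard requirement: async function or async generator.
        return self.model(x)

    def encode_response(self, output):
        # Also merely warned and asyncified.
        return {"output": output}

# Warns twice, raises nothing; a fully sync predict would still raise.
api = HybridAsyncAPI(enable_async=True)
```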
22 changes: 16 additions & 6 deletions src/litserve/loops/base.py
@@ -47,14 +47,24 @@ def _inject_context(context: Union[List[dict], dict], func, *args, **kwargs):

async def _async_inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
sig = inspect.signature(func)
-    is_async_gen = inspect.isasyncgenfunction(func)

# Determine if we need to inject context
if "context" in sig.parameters:
-        result = (
-            await func(*args, **kwargs, context=context) if not is_async_gen else func(*args, **kwargs, context=context)
-        )
-    else:
-        result = await func(*args, **kwargs) if not is_async_gen else func(*args, **kwargs)
+        kwargs["context"] = context
+
+    # Call the function based on its type
+    if inspect.isasyncgenfunction(func):
+        # Async generator - return directly (don't await)
+        return func(*args, **kwargs)
+    if asyncio.iscoroutinefunction(func):
+        # Async function - await the result
+        return await func(*args, **kwargs)
+    # Sync function - call directly, then await if result is awaitable
+    result = func(*args, **kwargs)
+
+    # Check if the result is awaitable (coroutine)
+    if asyncio.iscoroutine(result):
+        return await result
+
return result

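Note: the helper now dispatches on callable type — async generators are returned un-awaited so callers can iterate them, coroutine functions are awaited, and sync functions are called directly with their result awaited only if it turns out to be a coroutine. A self-contained sketch of that dispatch; the three toy callables are invented for illustration:

```
import asyncio
import inspect
from typing import List, Union

async def _async_inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
    # Same dispatch order as the diff above.
    if "context" in inspect.signature(func).parameters:
        kwargs["context"] = context
    if inspect.isasyncgenfunction(func):
        return func(*args, **kwargs)  # return the async generator un-awaited
    if asyncio.iscoroutinefunction(func):
        return await func(*args, **kwargs)
    result = func(*args, **kwargs)
    return await result if asyncio.iscoroutine(result) else result

async def main():
    async def agen(x, context):
        yield x + context["bias"]

    async def coro(x):
        return x * 2

    def sync_fn(x):
        return x - 1

    gen = await _async_inject_context({"bias": 10}, agen, 1)
    print([item async for item in gen])                  # [11]
    print(await _async_inject_context({}, coro, 3))      # 6
    print(await _async_inject_context({}, sync_fn, 3))   # 2

asyncio.run(main())
```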
28 changes: 14 additions & 14 deletions src/litserve/loops/streaming_loops.py
@@ -155,21 +155,17 @@ async def _process_streaming_request(

# When using async, predict should return an async generator
# and encode_response should handle async generators
-        async for item in y_gen:
-            # For each item from predict, pass to encode_response
-            # The _async_inject_context already handles async generators correctly
-            enc_result = await _async_inject_context(
-                context,
-                lit_api.encode_response,
-                [item],  # Wrap in list since encode_response expects an iterable
-            )
+        # The _async_inject_context already handles async generators correctly
+        enc_result = await _async_inject_context(
+            context,
+            lit_api.encode_response,
+            y_gen,
+        )

-            # encode_response should also return an async generator
-            async for y_enc in enc_result:
-                y_enc = lit_api.format_encoded_response(y_enc)
-                self.put_response(
-                    transport, response_queue_id, uid, y_enc, LitAPIStatus.OK, LoopResponseType.STREAMING
-                )
+        # encode_response should also return an async generator
+        async for y_enc in enc_result:
+            y_enc = lit_api.format_encoded_response(y_enc)
+            self.put_response(transport, response_queue_id, uid, y_enc, LitAPIStatus.OK, LoopResponseType.STREAMING)

self.put_response(
transport, response_queue_id, uid, "", LitAPIStatus.FINISH_STREAMING, LoopResponseType.STREAMING
@@ -200,6 +196,10 @@ def run_streaming_loop_async(
transport: MessageTransport,
callback_runner: CallbackRunner,
):
+    if lit_api.spec:
+        # wrap the default implementation of the spec in an async spec wrapper
+        lit_api.spec = lit_api.spec.as_async()
+
async def process_requests():
event_loop = asyncio.get_running_loop()
pending_tasks = set()
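Note: the loop now hands the whole predict generator to encode_response instead of wrapping each item in a one-element list, so encoding can stream token by token. A toy pipeline showing the new contract; both stand-in functions are hypothetical:

```
import asyncio

async def predict(x):
    # Stand-in for LitAPI.predict under enable_async: an async generator.
    for i in range(3):
        await asyncio.sleep(0)  # simulate async work
        yield f"token-{i}-{x}"

async def encode_response(outputs):
    # Now receives the whole predict generator, not per-item lists,
    # and can stream straight through.
    async for out in outputs:
        yield {"output": out}

async def main():
    async for enc in encode_response(predict("hello")):
        print(enc)  # each item is forwarded as soon as it is ready

asyncio.run(main())
```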
20 changes: 19 additions & 1 deletion src/litserve/specs/base.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
-from typing import TYPE_CHECKING, Callable, List
+from typing import TYPE_CHECKING, AsyncGenerator, Callable, Generator, List, Optional, Union

if TYPE_CHECKING:
from litserve import LitAPI, LitServer
@@ -64,3 +64,21 @@ def encode_response(self, output, meta_kwargs):

"""
        pass
+
+    def as_async(self) -> "_AsyncSpecWrapper":
+        return _AsyncSpecWrapper(self)
+
+
+class _AsyncSpecWrapper:
+    def __init__(self, spec: LitSpec):
+        self._spec = spec
+
+    def __getattr__(self, name):
+        # Delegate all other attributes/methods to the wrapped spec
+        return getattr(self._spec, name)
+
+    async def decode_request(self, request, context_kwargs: Optional[dict] = None):
+        return self._spec.decode_request(request, context_kwargs)
+
+    async def encode_response(self, output: Union[Generator, AsyncGenerator], context_kwargs: Optional[dict] = None):
+        return self._spec.encode_response(output, context_kwargs)
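Note: the wrapper gives a synchronous spec an awaitable surface — the two overridden methods simply call through to the sync implementations, while __getattr__ delegates everything else to the wrapped spec. A runnable sketch; EchoSpec is a stand-in, since the real LitSpec carries more machinery:

```
import asyncio

class EchoSpec:
    # Minimal stand-in for a sync LitSpec subclass.
    def decode_request(self, request, context_kwargs=None):
        return request["prompt"]

    def encode_response(self, output, context_kwargs=None):
        return {"text": output}

    def extra(self):
        return "delegated"

class _AsyncSpecWrapper:
    def __init__(self, spec):
        self._spec = spec

    def __getattr__(self, name):
        # Anything not defined on the wrapper falls through to the spec.
        return getattr(self._spec, name)

    async def decode_request(self, request, context_kwargs=None):
        return self._spec.decode_request(request, context_kwargs)

    async def encode_response(self, output, context_kwargs=None):
        return self._spec.encode_response(output, context_kwargs)

async def main():
    spec = _AsyncSpecWrapper(EchoSpec())
    decoded = await spec.decode_request({"prompt": "hi"})  # awaitable now
    print(await spec.encode_response(decoded))             # {'text': 'hi'}
    print(spec.extra())                                    # 'delegated'

asyncio.run(main())
```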
43 changes: 31 additions & 12 deletions src/litserve/specs/openai.py
@@ -26,7 +26,7 @@
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

-from litserve.specs.base import LitSpec
+from litserve.specs.base import LitSpec, _AsyncSpecWrapper
from litserve.utils import LitAPIStatus, azip

if typing.TYPE_CHECKING:
@@ -264,13 +264,14 @@ def encode_response(self, output):
```
"""

-ASYNC_LITAPI_VALIDATION_MSG = """LitAPI.decode_request, LitAPI.predict, and LitAPI.encode_response must all be async
-coroutines (use 'async def') while using the OpenAISpec with async enabled in LitAPI.
+ASYNC_LITAPI_VALIDATION_MSG = """Error: {}

-Additionally, LitAPI.predict and LitAPI.encode_response must be async generators (use 'yield' or 'yield from' inside
-an 'async def' function).
+`enable_async` is set but LitAPI method is not async. To use async with OpenAISpec, you need to make the following changes:
+
+- LitAPI.decode_request can be a regular function or an async function.
+- LitAPI.predict must be an async generator (use 'yield' or 'yield from' inside an 'async def' function).
+- LitAPI.encode_response can be a regular function or an async generator.

-Error: {}

Please follow the examples below for guidance on how to use the spec in async mode:

@@ -323,7 +324,7 @@ async def encode_response(self, output):
async for out in output:
yield ChatMessage(role="assistant", content=out)
```
"""
""" # noqa: E501


class OpenAISpec(LitSpec):
@@ -346,14 +347,23 @@ def pre_setup(self, lit_api: "LitAPI"):
is_encode_response_original = lit_api.encode_response.__code__ is LitAPI.encode_response.__code__

if lit_api.enable_async:
+            # warning for decode_request and encode_response
if not asyncio.iscoroutinefunction(lit_api.decode_request):
-                raise ValueError(ASYNC_LITAPI_VALIDATION_MSG.format("decode_request is not a coroutine"))
+                logger.info("decode_request is not a coroutine function. LitServe will asyncify it.")
+            if not inspect.isasyncgenfunction(lit_api.encode_response):
+                logger.info("encode_response is not an async generator. LitServe will asyncify it.")

if not inspect.isasyncgenfunction(lit_api.predict):
-                raise ValueError(ASYNC_LITAPI_VALIDATION_MSG.format("predict is not a generator"))
-
-            if not inspect.isasyncgenfunction(lit_api.encode_response):
-                raise ValueError(ASYNC_LITAPI_VALIDATION_MSG.format("encode_response is not a generator"))
+                raise ValueError(ASYNC_LITAPI_VALIDATION_MSG.format("predict must be an async generator"))
+
+            if (
+                not is_encode_response_original
+                and not inspect.isgeneratorfunction(lit_api.encode_response)
+                and not inspect.isasyncgenfunction(lit_api.encode_response)
+            ):
+                raise ValueError(
+                    ASYNC_LITAPI_VALIDATION_MSG.format("encode_response is neither a generator nor an async generator")
+                )

else:
for method in ["decode_request", "predict", "encode_response"]:
@@ -375,6 +385,9 @@ def setup(self, server: "LitServer"):
super().setup(server)
print("OpenAI spec setup complete")

+    def as_async(self) -> "_AsyncOpenAISpecWrapper":
+        return _AsyncOpenAISpecWrapper(self)
+
def populate_context(self, context, request):
data = request.dict()
data.pop("messages")
@@ -550,3 +563,9 @@ async def non_streaming_completion(self, request: ChatCompletionRequest, generat
choices.append(choice)

        return ChatCompletionResponse(model=model, choices=choices, usage=sum(usage_infos))
+
+
+class _AsyncOpenAISpecWrapper(_AsyncSpecWrapper):
+    async def encode_response(self, output_generator: AsyncGenerator, context_kwargs: Optional[dict] = None):
+        async for output in output_generator:
+            yield self._spec._encode_response(output)
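Note: unlike the base wrapper, the OpenAI wrapper overrides encode_response as an async generator, re-encoding each predicted chunk through the spec's internal _encode_response so the streaming loop can forward OpenAI-shaped messages one at a time. A sketch of that shape; FakeOpenAISpec is invented, and only the async for/yield pattern mirrors the PR:

```
import asyncio

class FakeOpenAISpec:
    # Stand-in for OpenAISpec's internal per-chunk encoder.
    def _encode_response(self, output):
        return {"role": "assistant", "content": output}

class _AsyncOpenAISpecWrapper:
    def __init__(self, spec):
        self._spec = spec

    async def encode_response(self, output_generator, context_kwargs=None):
        # Re-yield each chunk through the spec's encoder.
        async for output in output_generator:
            yield self._spec._encode_response(output)

async def main():
    async def tokens():
        for t in ["Hel", "lo"]:
            yield t

    wrapper = _AsyncOpenAISpecWrapper(FakeOpenAISpec())
    async for msg in wrapper.encode_response(tokens()):
        print(msg)

asyncio.run(main())
```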
2 changes: 1 addition & 1 deletion tests/e2e/default_async_streaming.py
@@ -26,7 +26,7 @@ async def predict(self, x):
yield self.model(i)

async def encode_response(self, output):
-        for out in output:
+        async for out in output:
yield {"output": out}


2 changes: 1 addition & 1 deletion tests/test_lit_server.py
@@ -639,7 +639,7 @@ async def predict(self, x):
yield self.model(i)

async def encode_response(self, output_stream):
-        for output in output_stream:
+        async for output in output_stream:
yield {"output": output}


4 changes: 3 additions & 1 deletion tests/test_litapi.py
@@ -285,5 +285,7 @@ def test_log():


def test_enable_async_not_set():
-    with pytest.raises(ValueError, match=r"LitAPI\(enable_async=True\) requires all methods to be coroutines\."):
+    with pytest.raises(
+        ValueError, match=r"predict must be an async generator or async function when enable_async=True"
+    ):
ls.test_examples.SimpleLitAPI(enable_async=True)
2 changes: 1 addition & 1 deletion tests/test_loops.py
@@ -235,7 +235,7 @@ async def predict(self, x):
yield {"output": i}

async def encode_response(self, output):
-        for out in output:
+        async for out in output:
yield out["output"]


23 changes: 4 additions & 19 deletions tests/test_specs.py
@@ -370,30 +370,21 @@ class IncorrectDecodeAsyncAPI(IncorrectAsyncAPI):
def decode_request(self, request):
return request

-    def _validate_async_methods(self):
-        return None
-

class IncorrectEncodeAsyncAPI(IncorrectAsyncAPI):
async def predict(self, x):
yield "This is a generated output"


-@pytest.mark.asyncio
-def test_openai_spec_asyncapi_decode_request_validation():
-    with pytest.raises(ValueError, match="decode_request is not a coroutine"):
-        ls.LitServer(IncorrectDecodeAsyncAPI(enable_async=True), spec=OpenAISpec())
-
-
@pytest.mark.asyncio
-def test_openai_spec_asyncapi_predict_validation():
-    with pytest.raises(ValueError, match="predict is not a generator"):
+async def test_openai_spec_asyncapi_predict_validation():
+    with pytest.raises(ValueError, match="predict must be an async generator"):
ls.LitServer(IncorrectAsyncAPI(enable_async=True), spec=OpenAISpec())


@pytest.mark.asyncio
def test_openai_spec_asyncapi_encode_response_validation():
-    with pytest.raises(ValueError, match="encode_response is not a generator"):
+    with pytest.raises(ValueError, match="encode_response is neither a generator nor an async generator"):
ls.LitServer(IncorrectEncodeAsyncAPI(enable_async=True), spec=OpenAISpec())


@@ -414,12 +405,6 @@ async def encode_response(self, output):
yield {"role": "assistant", "content": output}


-@pytest.mark.asyncio
-def test_openai_asyncapi_decode_not_implemented():
-    with pytest.raises(ValueError, match=r"LitAPI\(enable_async=True\) requires all methods to be coroutines\."):
-        ls.LitServer(DecodeNotImplementedAsyncOpenAILitAPI(enable_async=True), spec=OpenAISpec())
-
-
class AsyncOpenAILitAPI(ls.LitAPI):
def setup(self, device):
self.model = None
@@ -433,7 +418,7 @@ async def predict(self, x):
yield token

async def encode_response(self, output_stream, context):
-        for output in output_stream:
+        async for output in output_stream:
yield {"role": "assistant", "content": output}

