
Commit 25ef3db

bhimrazy, pre-commit-ci[bot], and Borda authored
refactor: move validation logic to the pre_setup method in Embed Spec, where there is access to the correct API instance (#573)
* refactor: move validation to pre_setup from setup, where there is access to the correct API instance
* fix: improve validation tests for yield usage in OpenAIEmbeddingSpec
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 9aa4dd5 commit 25ef3db
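In practice, the change means a misconfigured API now fails fast: pre_setup runs while LitServer is being constructed, so the validation error surfaces before the server ever starts. Below is a minimal sketch of the new behavior; the BadEmbedAPI class is hypothetical and written only to trigger the check, and the OpenAIEmbeddingSpec import path is assumed from the file layout in this commit.

```python
import litserve as ls
from litserve.specs.openai_embedding import OpenAIEmbeddingSpec  # path assumed from this commit


class BadEmbedAPI(ls.LitAPI):
    """Hypothetical API that misuses yield in predict."""

    def setup(self, device):
        pass

    def predict(self, inputs):
        # yield makes predict a generator (streaming), which the embedding
        # spec rejects because producing embeddings is not sequential
        yield [[0.0] * 8 for _ in inputs]


# With validation in pre_setup, constructing the server is enough to fail:
# ValueError: ('You are using yield in your predict method, ...')
server = ls.LitServer(BadEmbedAPI(), spec=OpenAIEmbeddingSpec())
```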

File tree

2 files changed: +26 −46 lines changed

src/litserve/specs/openai_embedding.py
tests/unit/test_openai_embedding.py

src/litserve/specs/openai_embedding.py

Lines changed: 21 additions & 33 deletions
```diff
@@ -32,7 +32,7 @@
 import numpy as np
 import torch
 
-from litserve import LitServer
+from litserve import LitAPI, LitServer
 
 
 class EmbeddingRequest(BaseModel):
@@ -129,44 +129,32 @@ def __init__(self):
         self.add_endpoint("/v1/embeddings", self.embeddings_endpoint, ["POST"])
         self.add_endpoint("/v1/embeddings", self.options_embeddings, ["GET"])
 
-    def setup(self, server: "LitServer"):
+    def pre_setup(self, lit_api: "LitAPI"):
         from litserve import LitAPI
 
-        super().setup(server)
-
-        lit_api = server.lit_api
+        if inspect.isgeneratorfunction(lit_api.predict):
+            raise ValueError(
+                "You are using yield in your predict method, which is used for streaming.",
+                "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ",
+                "is not a sequential operation.",
+                "Please consider replacing yield with return in predict.\n",
+                EMBEDDING_API_EXAMPLE,
+            )
 
-        if isinstance(lit_api, LitAPI):
-            self._check_lit_api(lit_api)
-        elif isinstance(lit_api, list):
-            for api in lit_api:
-                self._check_lit_api(api)
+        is_encode_response_original = lit_api.encode_response.__code__ is LitAPI.encode_response.__code__
+        if not is_encode_response_original and inspect.isgeneratorfunction(lit_api.encode_response):
+            raise ValueError(
+                "You are using yield in your encode_response method, which is used for streaming.",
+                "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ",
+                "is not a sequential operation.",
+                "Please consider replacing yield with return in encode_response.\n",
+                EMBEDDING_API_EXAMPLE,
+            )
 
+    def setup(self, server: "LitServer"):
+        super().setup(server)
         print("OpenAI Embedding Spec is ready.")
 
-    def _check_lit_api(self, api):
-        from litserve import LitAPI
-
-        if isinstance(api.spec, OpenAIEmbeddingSpec):
-            if inspect.isgeneratorfunction(api.predict):
-                raise ValueError(
-                    "You are using yield in your predict method, which is used for streaming.",
-                    "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ",
-                    "is not a sequential operation.",
-                    "Please consider replacing yield with return in predict.\n",
-                    EMBEDDING_API_EXAMPLE,
-                )
-
-            is_encode_response_original = api.encode_response.__code__ is LitAPI.encode_response.__code__
-            if not is_encode_response_original and inspect.isgeneratorfunction(api.encode_response):
-                raise ValueError(
-                    "You are using yield in your encode_response method, which is used for streaming.",
-                    "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ",
-                    "is not a sequential operation.",
-                    "Please consider replacing yield with return in encode_response.\n",
-                    EMBEDDING_API_EXAMPLE,
-                )
-
     def decode_request(self, request: EmbeddingRequest, context_kwargs: Optional[dict] = None) -> List[str]:
         return request.input
 
```
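The is_encode_response_original guard in the new pre_setup compares code objects so that validation is skipped when the user never overrode encode_response. A standalone sketch of that pattern, independent of litserve:

```python
import inspect


class Base:
    def encode_response(self, output):
        return output


class Child(Base):
    def encode_response(self, output):  # overrides the base implementation
        yield output


# A function's __code__ object identifies its implementation, so an identity
# check reveals whether a subclass replaced the inherited method.
print(Child.encode_response.__code__ is Base.encode_response.__code__)  # False
print(inspect.isgeneratorfunction(Child.encode_response))  # True
print(inspect.isgeneratorfunction(Base.encode_response))  # False
```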

tests/unit/test_openai_embedding.py

Lines changed: 5 additions & 13 deletions
```diff
@@ -107,19 +107,11 @@ async def test_openai_embedding_spec_with_usage(openai_embedding_request_data):
 
 @pytest.mark.asyncio
 async def test_openai_embedding_spec_validation(openai_request_data):
-    server = ls.LitServer(TestEmbedAPIWithYieldPredict(), spec=OpenAIEmbeddingSpec())
-    with pytest.raises(ValueError, match="You are using yield in your predict method"), wrap_litserve_start(
-        server
-    ) as server:
-        async with LifespanManager(server.app):
-            pass
-
-    server = ls.LitServer(TestEmbedAPIWithYieldEncodeResponse(), spec=OpenAIEmbeddingSpec())
-    with pytest.raises(ValueError, match="You are using yield in your encode_response method"), wrap_litserve_start(
-        server
-    ) as server:
-        async with LifespanManager(server.app):
-            pass
+    with pytest.raises(ValueError, match="You are using yield in your predict method"):
+        ls.LitServer(TestEmbedAPIWithYieldPredict(), spec=OpenAIEmbeddingSpec())
+
+    with pytest.raises(ValueError, match="You are using yield in your encode_response method"):
+        ls.LitServer(TestEmbedAPIWithYieldEncodeResponse(), spec=OpenAIEmbeddingSpec())
 
 
 @pytest.mark.asyncio
```
