45 changes: 28 additions & 17 deletions src/litserve/specs/openai_embedding.py
@@ -135,27 +135,38 @@ def setup(self, server: "LitServer"):
         super().setup(server)
 
         lit_api = server.lit_api
-        if inspect.isgeneratorfunction(lit_api.predict):
-            raise ValueError(
-                "You are using yield in your predict method, which is used for streaming.",
-                "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ",
-                "is not a sequential operation.",
-                "Please consider replacing yield with return in predict.\n",
-                EMBEDDING_API_EXAMPLE,
-            )
-
-        is_encode_response_original = lit_api.encode_response.__code__ is LitAPI.encode_response.__code__
-        if not is_encode_response_original and inspect.isgeneratorfunction(lit_api.encode_response):
-            raise ValueError(
-                "You are using yield in your encode_response method, which is used for streaming.",
-                "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ",
-                "is not a sequential operation.",
-                "Please consider replacing yield with return in encode_response.\n",
-                EMBEDDING_API_EXAMPLE,
-            )
+        if isinstance(lit_api, LitAPI):
+            self._check_lit_api(lit_api)
+        elif isinstance(lit_api, list):
+            for api in lit_api:
+                self._check_lit_api(api)
 
         print("OpenAI Embedding Spec is ready.")
 
+    def _check_lit_api(self, api):
+        from litserve import LitAPI
+
+        if isinstance(api.spec, OpenAIEmbeddingSpec):
+            if inspect.isgeneratorfunction(api.predict):
+                raise ValueError(
+                    "You are using yield in your predict method, which is used for streaming.",
+                    "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ",
+                    "is not a sequential operation.",
+                    "Please consider replacing yield with return in predict.\n",
+                    EMBEDDING_API_EXAMPLE,
+                )
+
+            is_encode_response_original = api.encode_response.__code__ is LitAPI.encode_response.__code__
+            if not is_encode_response_original and inspect.isgeneratorfunction(api.encode_response):
+                raise ValueError(
+                    "You are using yield in your encode_response method, which is used for streaming.",
+                    "OpenAIEmbeddingSpec doesn't support streaming because producing embeddings ",
+                    "is not a sequential operation.",
+                    "Please consider replacing yield with return in encode_response.\n",
+                    EMBEDDING_API_EXAMPLE,
+                )
+
     def decode_request(self, request: EmbeddingRequest, context_kwargs: Optional[dict] = None) -> List[str]:
         return request.input
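
Note: the checks above boil down to requiring a single-shot predict. A minimal sketch of an API that passes them, assuming a stand-in model (EmbedAPI and the 768-dim lambda are hypothetical; the canonical reference is the EMBEDDING_API_EXAMPLE string printed with the errors):

# Hypothetical sketch, not part of this PR: the shape of a predict method
# that OpenAIEmbeddingSpec accepts. The model is a stand-in.
import numpy as np
import litserve as ls
from litserve.specs.openai_embedding import OpenAIEmbeddingSpec


class EmbedAPI(ls.LitAPI):
    def setup(self, device):
        # Stand-in for a real embedding model returning (n, 768) vectors.
        self.model = lambda texts: np.random.rand(len(texts), 768)

    def predict(self, inputs):
        # return, not yield: the spec raises ValueError for generator-based
        # predict/encode_response methods, as the diff above shows.
        return self.model(inputs)


if __name__ == "__main__":
    server = ls.LitServer(EmbedAPI(spec=OpenAIEmbeddingSpec()))
    server.run()
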
17 changes: 17 additions & 0 deletions tests/unit/test_openai_embedding.py
@@ -51,6 +51,23 @@ async def test_openai_embedding_spec_with_single_input(openai_embedding_request_data):
     assert len(resp.json()["data"][0]["embedding"]) == 768, "Embedding length should be 768"
 
 
+@pytest.mark.asyncio
+async def test_openai_embedding_spec_with_multi_endpoint(openai_embedding_request_data):
+    server = ls.LitServer([
+        TestEmbedAPI(spec=OpenAIEmbeddingSpec()),
+    ])
+    with wrap_litserve_start(server) as server:
+        async with LifespanManager(server.app) as manager, AsyncClient(
+            transport=ASGITransport(app=manager.app), base_url="http://test"
+        ) as ac:
+            resp = await ac.post("/v2/embeddings", json=openai_embedding_request_data, timeout=10)
+            assert resp.status_code == 200, "Status code should be 200"
+            assert resp.json()["object"] == "list", "Object should be list"
+            assert resp.json()["data"][0]["index"] == 0, "Index should be 0"
+            assert len(resp.json()["data"]) == 1, "Length of data should be 1"
+            assert len(resp.json()["data"][0]["embedding"]) == 768, "Embedding length should be 768"
+
+
 @pytest.mark.asyncio
 async def test_openai_embedding_spec_with_multiple_inputs(openai_embedding_request_data_array):
     spec = OpenAIEmbeddingSpec()
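
For manual verification outside the test suite, an equivalent request against a locally running server could look like the sketch below. The /v2/embeddings path and the assertions mirror the test above; the host, port, and the model field are assumptions based on the OpenAI embeddings schema:

import requests

# Mirrors the new unit test against a hypothetical local server.
resp = requests.post(
    "http://localhost:8000/v2/embeddings",
    json={"input": "hello world", "model": "my-embedding-model"},
    timeout=10,
)
resp.raise_for_status()
body = resp.json()
assert body["object"] == "list"
assert body["data"][0]["index"] == 0
assert len(body["data"][0]["embedding"]) == 768
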