Skip to content
16 changes: 10 additions & 6 deletions open/text/embeddings/server/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from typing import List, Optional, Union
from starlette.concurrency import run_in_threadpool
from fastapi import FastAPI, APIRouter
Expand Down Expand Up @@ -66,10 +65,12 @@ class CreateEmbeddingRequest(BaseModel):
class Embedding(BaseModel):
    """A single embedding vector, used as one entry of a response's ``data`` list."""

    # The embedding vector itself, as a flat list of floats.
    embedding: List[float]


class Usage(BaseModel):
    """Token-usage figures attached to an embedding response."""

    # NOTE(review): the call sites visible in this file pass a constant 5 for
    # both fields rather than a measured count -- confirm whether real token
    # accounting is intended.
    prompt_tokens: int
    total_tokens: int


class CreateEmbeddingResponse(BaseModel):
data: List[Embedding]
model: str
Expand Down Expand Up @@ -127,25 +128,28 @@ def _create_embedding(input: Union[str, List[str]]):
model_name = DEFAULT_MODEL_NAME
model_name_short = model_name.split("/")[-1]
if isinstance(input, str):
return CreateEmbeddingResponse(data=[Embedding(embedding=embeddings.embed_query(input))],model=model_name_short,object='list',usage=Usage(prompt_tokens=5,total_tokens=5))
return CreateEmbeddingResponse(data=[Embedding(embedding=embeddings.embed_query(input))],
model=model_name_short, object='list',
usage=Usage(prompt_tokens=5, total_tokens=5))
else:
data = [Embedding(embedding=embedding)
for embedding in embeddings.embed_documents(input)]
return CreateEmbeddingResponse(data=[Embedding(embedding=embeddings.embed_query(input))],model=model_name_short,object='list',usage=Usage(prompt_tokens=5,total_tokens=5))
return CreateEmbeddingResponse(data=data, model=model_name_short, object='list',
usage=Usage(prompt_tokens=5, total_tokens=5))


@router.post(
    "/v1/embeddings",
    response_model=CreateEmbeddingResponse,
)
async def create_embedding(
    request: CreateEmbeddingRequest
):
    """Handle ``POST /v1/embeddings``.

    The request model is dumped to a plain dict (minus the ``user``,
    ``model`` and ``model_config`` keys) and the remaining fields are
    forwarded as keyword arguments to ``_create_embedding``, which is run
    through ``run_in_threadpool`` and awaited.

    Args:
        request: The validated request body.

    Returns:
        Whatever ``_create_embedding`` returns; FastAPI serializes it using
        ``CreateEmbeddingResponse`` as the response model.
    """
    excluded = {"user", "model", "model_config"}

    # Compare the major version numerically. A plain string comparison is
    # lexicographic: it would sort '10.0.0' below '2.0.0' and would send
    # pydantic exactly '2.0.0' down the v1 code path.
    pydantic_major = int(pydantic.__version__.split(".", 1)[0])

    if pydantic_major >= 2:
        # pydantic v2 API
        kwargs = request.model_dump(exclude=excluded)
    else:
        # pydantic v1 API
        kwargs = request.dict(exclude=excluded)

    return await run_in_threadpool(_create_embedding, **kwargs)