Skip to content
Merged
6 changes: 5 additions & 1 deletion open/text/embeddings/server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
if __name__ == "__main__":
    app = create_app()

    # Bind address and port come from the environment. os.getenv already
    # returns the default when the variable is unset, so no separate
    # None-check is needed (the previous two-step HOST lookup could only
    # ever produce the default in its fallback branch anyway).
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", 8000))

    uvicorn.run(app, host=host, port=port)
61 changes: 47 additions & 14 deletions open/text/embeddings/server/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from typing import List, Optional, Union
from starlette.concurrency import run_in_threadpool
from fastapi import FastAPI, APIRouter
Expand All @@ -12,6 +11,9 @@

from open.text.embeddings.server.gzip import GZipRequestMiddleware
from fastapi.middleware.gzip import GZipMiddleware
import pydantic
from transformers import AutoTokenizer

router = APIRouter()

DEFAULT_MODEL_NAME = "intfloat/e5-large-v2"
Expand Down Expand Up @@ -64,43 +66,58 @@ class CreateEmbeddingRequest(BaseModel):

class Embedding(BaseModel):
    """A single embedding entry in an OpenAI-compatible response."""

    # The embedding vector for one input string.
    embedding: List[float]
    # Object tag; set to "embedding" by _create_embedding.
    object: str
    # Position of the corresponding string in the request's input list.
    index: int


class Usage(BaseModel):
    """Token accounting for a request (OpenAI-compatible usage object)."""

    # Tokens counted over the input text(s) via the model tokenizer.
    prompt_tokens: int
    # Set equal to prompt_tokens by _create_embedding — embeddings
    # requests produce no completion tokens.
    total_tokens: int


class CreateEmbeddingResponse(BaseModel):
    """OpenAI-compatible /v1/embeddings response envelope."""

    # One Embedding per input string, in request order.
    data: List[Embedding]
    # Model identifier echoed back to the client (short name, i.e. the
    # last "/"-separated component of the configured model).
    model: str
    # Object tag; set to "list" by _create_embedding.
    object: str
    # Token usage for this request.
    usage: Usage


# Module-level singletons, populated by initialize_embeddings() (which
# declares them `global`) before requests are served.
embeddings = None  # LangChain HuggingFace*Embeddings backend

tokenizer = None  # AutoTokenizer used for usage token counting

def initialize_embeddings():
global embeddings
global tokenizer

if "DEVICE" in os.environ:
device = os.environ["DEVICE"]
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model_name = os.environ["MODEL"]
model_name = os.environ.get("MODEL")
if model_name is None:
model_name = DEFAULT_MODEL_NAME
print("Loading model:", model_name)
normalize_embeddings = bool(os.environ.get("NORMALIZE_EMBEDDINGS", ""))
normalize_embeddings = bool(os.environ.get("NORMALIZE_EMBEDDINGS", "1"))
encode_kwargs = {
"normalize_embeddings": normalize_embeddings
}
print("Normalize embeddings:", normalize_embeddings)
tokenizer=AutoTokenizer.from_pretrained(model_name)
if "e5" in model_name:
embeddings = HuggingFaceInstructEmbeddings(model_name=model_name,
embed_instruction=E5_EMBED_INSTRUCTION,
query_instruction=E5_QUERY_INSTRUCTION,
encode_kwargs=encode_kwargs,
model_kwargs={"device": device})
elif model_name.startswith("BAAI/bge-") and model_name.endswith("-en"):
elif "bge-" in model_name and "-en" in model_name:
embeddings = HuggingFaceBgeEmbeddings(model_name=model_name,
query_instruction=BGE_EN_QUERY_INSTRUCTION,
encode_kwargs=encode_kwargs,
model_kwargs={"device": device})
elif model_name.startswith("BAAI/bge-") and model_name.endswith("-zh"):
elif "bge-" in model_name and "-zh" in model_name:
embeddings = HuggingFaceBgeEmbeddings(model_name=model_name,
query_instruction=BGE_ZH_QUERY_INSTRUCTION,
encode_kwargs=encode_kwargs,
Expand All @@ -113,22 +130,38 @@ def initialize_embeddings():

def _create_embedding(input: Union[str, List[str]]):
    """Embed one string or a batch of strings.

    Returns an OpenAI-compatible CreateEmbeddingResponse whose data holds
    one Embedding per input (with object/index set) and whose usage counts
    tokens via the model tokenizer.
    """
    global embeddings

    # Echo back only the short model name (last path component,
    # e.g. "intfloat/e5-large-v2" -> "e5-large-v2").
    model_name = os.environ.get("MODEL")
    if model_name is None:
        model_name = DEFAULT_MODEL_NAME
    model_name_short = model_name.split("/")[-1]

    if isinstance(input, str):
        tokens = tokenizer.tokenize(input)
        return CreateEmbeddingResponse(
            data=[Embedding(embedding=embeddings.embed_query(input),
                            object="embedding", index=0)],
            model=model_name_short, object='list',
            usage=Usage(prompt_tokens=len(tokens), total_tokens=len(tokens)))
    else:
        # Batch path: embed all documents in ONE embed_documents() call
        # instead of one embed_query() per item. This batches the forward
        # pass and, for instruction-tuned backends (e5/bge), applies the
        # document *embed* instruction rather than the *query* instruction.
        vectors = embeddings.embed_documents(input)
        data = [Embedding(embedding=vector, object="embedding", index=i)
                for i, vector in enumerate(vectors)]
        total_tokens = sum(len(tokenizer.tokenize(text)) for text in input)
        return CreateEmbeddingResponse(
            data=data, model=model_name_short, object='list',
            usage=Usage(prompt_tokens=total_tokens, total_tokens=total_tokens))


@router.post(
    "/v1/embeddings",
    response_model=CreateEmbeddingResponse,
)
async def create_embedding(
    request: CreateEmbeddingRequest
):
    """OpenAI-compatible embeddings endpoint.

    Runs the (blocking) embedding computation in a worker thread so the
    event loop stays responsive.
    """
    # Pydantic v2 renamed BaseModel.dict() to model_dump(). Compare the
    # MAJOR version numerically: the previous lexicographic string
    # comparison (__version__ > '2.0.0') misorders e.g. "10.0.0" and
    # classifies exactly "2.0.0" as v1.
    if int(pydantic.__version__.split(".")[0]) >= 2:
        return await run_in_threadpool(
            _create_embedding, **request.model_dump(exclude={"user", "model", "model_config"})
        )
    else:
        return await run_in_threadpool(
            _create_embedding, **request.dict(exclude={"user", "model", "model_config"})
        )
3 changes: 2 additions & 1 deletion server-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ fastapi
mangum
sentence_transformers
langchain
InstructorEmbedding
InstructorEmbedding
uvicorn