Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions comps/cores/mega/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ async def execute(
headers={"Content-type": "application/json", "Authorization": f"Bearer {access_token}"},
proxies={"http": None},
stream=True,
timeout=1000,
timeout=2000,
)
else:
response = requests.post(
Expand All @@ -285,7 +285,7 @@ async def execute(
},
proxies={"http": None},
stream=True,
timeout=1000,
timeout=2000,
)

downstream = runtime_graph.downstream(cur_node)
Expand Down Expand Up @@ -317,6 +317,7 @@ def generate():
"Authorization": f"Bearer {access_token}",
},
proxies={"http": None},
timeout=2000,
)
else:
res = requests.post(
Expand All @@ -326,6 +327,7 @@ def generate():
"Content-type": "application/json",
},
proxies={"http": None},
timeout=2000,
)
res_json = res.json()
if "text" in res_json:
Expand Down
51 changes: 25 additions & 26 deletions comps/embeddings/src/integrations/ovms.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os
from typing import List, Union

import aiohttp
import requests
from huggingface_hub import AsyncInferenceClient

from comps import CustomLogger, OpeaComponent, OpeaComponentRegistry, ServiceType
from comps.cores.mega.utils import get_access_token
Expand All @@ -32,24 +30,11 @@ class OpeaOVMSEmbedding(OpeaComponent):
def __init__(self, name: str, description: str, config: dict = None):
super().__init__(name, ServiceType.EMBEDDING.name.lower(), description, config)
self.base_url = os.getenv("OVMS_EMBEDDING_ENDPOINT", "http://localhost:8080")
self.client = self._initialize_client()

health_status = self.check_health()
if not health_status:
logger.error("OpeaOVMSEmbedding health check failed.")

def _initialize_client(self) -> AsyncInferenceClient:
"""Initializes the AsyncInferenceClient."""
access_token = (
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
)
headers = {"Authorization": f"Bearer {access_token}"} if access_token else {}
return AsyncInferenceClient(
model=MODEL_ID,
token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
headers=headers,
)

async def invoke(self, input: EmbeddingRequest) -> EmbeddingResponse:
"""Invokes the embedding service to generate embeddings for the provided input.

Expand All @@ -69,17 +54,31 @@ async def invoke(self, input: EmbeddingRequest) -> EmbeddingResponse:
raise ValueError("Invalid input format: Only string or list of strings are supported.")
else:
raise TypeError("Unsupported input type: input must be a string or list of strings.")
response = await self.client.post(
json={
"input": texts,
"encoding_format": input.encoding_format,
"model": self.client.model,
"user": input.user,
},
model=f"{self.base_url}/v3/embeddings",
task="text-embedding",
# Build headers
headers = {"Content-Type": "application/json"}
access_token = (
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
)
embeddings = json.loads(response.decode())
if access_token:
headers["Authorization"] = f"Bearer {access_token}"

# Compose request
payload = {
"input": texts,
"encoding_format": input.encoding_format,
"model": MODEL_ID,
"user": input.user,
}

# Send async POST request using aiohttp
url = f"{self.base_url}/v3/embeddings"
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=payload) as resp:
if resp.status != 200:
logger.error(f"Embedding service error: {resp.status} - {await resp.text()}")
raise RuntimeError(f"Failed to fetch embeddings: HTTP {resp.status}")
embeddings = await resp.json()

return EmbeddingResponse(**embeddings)

def check_health(self) -> bool:
Expand Down
18 changes: 9 additions & 9 deletions comps/embeddings/src/integrations/tei.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from comps import CustomLogger, OpeaComponent, OpeaComponentRegistry, ServiceType
from comps.cores.mega.utils import get_access_token
from comps.cores.proto.api_protocol import EmbeddingRequest, EmbeddingResponse
from comps.cores.proto.api_protocol import EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData

logger = CustomLogger("opea_tei_embedding")
logflag = os.getenv("LOGFLAG", False)
Expand Down Expand Up @@ -44,7 +44,7 @@ def _initialize_client(self) -> AsyncInferenceClient:
)
headers = {"Authorization": f"Bearer {access_token}"} if access_token else {}
return AsyncInferenceClient(
model=f"{self.base_url}/v1/embeddings",
model=f"{self.base_url}/embed",
token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
headers=headers,
)
Expand All @@ -68,13 +68,13 @@ async def invoke(self, input: EmbeddingRequest) -> EmbeddingResponse:
raise ValueError("Invalid input format: Only string or list of strings are supported.")
else:
raise TypeError("Unsupported input type: input must be a string or list of strings.")
response = await self.client.post(
json={"input": texts, "encoding_format": input.encoding_format, "model": input.model, "user": input.user},
model=f"{self.base_url}/v1/embeddings",
task="text-embedding",
)
embeddings = json.loads(response.decode())
return EmbeddingResponse(**embeddings)
# feature_extraction return np.ndarray
response = await self.client.feature_extraction(text=texts, model=f"{self.base_url}/embed")
# Convert np.ndarray to a list of lists (embedding)
data = [EmbeddingResponseData(index=i, embedding=embedding.tolist()) for i, embedding in enumerate(response)]
# Construct the EmbeddingResponse
response = EmbeddingResponse(data=data)
return response

def check_health(self) -> bool:
"""Checks the health of the embedding service.
Expand Down
1 change: 1 addition & 0 deletions comps/rerankings/src/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ aiohttp
docarray[full]
fastapi
httpx
huggingface-hub==0.30.2
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-sdk
Expand Down
Loading