Skip to content

Commit 2259b9c

Browse files
Fix huggingface_hub API upgrade issue (#1691)
* Fix huggingface_hub API upgrade issue. Signed-off-by: lvliang-intel <[email protected]>
1 parent 28a2820 commit 2259b9c

File tree

4 files changed

+39
-37
lines changed

4 files changed

+39
-37
lines changed

comps/cores/mega/orchestrator.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -274,7 +274,7 @@ async def execute(
274274
headers={"Content-type": "application/json", "Authorization": f"Bearer {access_token}"},
275275
proxies={"http": None},
276276
stream=True,
277-
timeout=1000,
277+
timeout=2000,
278278
)
279279
else:
280280
response = requests.post(
@@ -285,7 +285,7 @@ async def execute(
285285
},
286286
proxies={"http": None},
287287
stream=True,
288-
timeout=1000,
288+
timeout=2000,
289289
)
290290

291291
downstream = runtime_graph.downstream(cur_node)
@@ -317,6 +317,7 @@ def generate():
317317
"Authorization": f"Bearer {access_token}",
318318
},
319319
proxies={"http": None},
320+
timeout=2000,
320321
)
321322
else:
322323
res = requests.post(
@@ -326,6 +327,7 @@ def generate():
326327
"Content-type": "application/json",
327328
},
328329
proxies={"http": None},
330+
timeout=2000,
329331
)
330332
res_json = res.json()
331333
if "text" in res_json:

comps/embeddings/src/integrations/ovms.py

Lines changed: 25 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,10 @@
11
# Copyright (C) 2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4-
import json
54
import os
6-
from typing import List, Union
75

6+
import aiohttp
87
import requests
9-
from huggingface_hub import AsyncInferenceClient
108

119
from comps import CustomLogger, OpeaComponent, OpeaComponentRegistry, ServiceType
1210
from comps.cores.mega.utils import get_access_token
@@ -32,24 +30,11 @@ class OpeaOVMSEmbedding(OpeaComponent):
3230
def __init__(self, name: str, description: str, config: dict = None):
3331
super().__init__(name, ServiceType.EMBEDDING.name.lower(), description, config)
3432
self.base_url = os.getenv("OVMS_EMBEDDING_ENDPOINT", "http://localhost:8080")
35-
self.client = self._initialize_client()
3633

3734
health_status = self.check_health()
3835
if not health_status:
3936
logger.error("OpeaOVMSEmbedding health check failed.")
4037

41-
def _initialize_client(self) -> AsyncInferenceClient:
42-
"""Initializes the AsyncInferenceClient."""
43-
access_token = (
44-
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
45-
)
46-
headers = {"Authorization": f"Bearer {access_token}"} if access_token else {}
47-
return AsyncInferenceClient(
48-
model=MODEL_ID,
49-
token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
50-
headers=headers,
51-
)
52-
5338
async def invoke(self, input: EmbeddingRequest) -> EmbeddingResponse:
5439
"""Invokes the embedding service to generate embeddings for the provided input.
5540
@@ -69,17 +54,31 @@ async def invoke(self, input: EmbeddingRequest) -> EmbeddingResponse:
6954
raise ValueError("Invalid input format: Only string or list of strings are supported.")
7055
else:
7156
raise TypeError("Unsupported input type: input must be a string or list of strings.")
72-
response = await self.client.post(
73-
json={
74-
"input": texts,
75-
"encoding_format": input.encoding_format,
76-
"model": self.client.model,
77-
"user": input.user,
78-
},
79-
model=f"{self.base_url}/v3/embeddings",
80-
task="text-embedding",
57+
# Build headers
58+
headers = {"Content-Type": "application/json"}
59+
access_token = (
60+
get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
8161
)
82-
embeddings = json.loads(response.decode())
62+
if access_token:
63+
headers["Authorization"] = f"Bearer {access_token}"
64+
65+
# Compose request
66+
payload = {
67+
"input": texts,
68+
"encoding_format": input.encoding_format,
69+
"model": MODEL_ID,
70+
"user": input.user,
71+
}
72+
73+
# Send async POST request using aiohttp
74+
url = f"{self.base_url}/v3/embeddings"
75+
async with aiohttp.ClientSession() as session:
76+
async with session.post(url, headers=headers, json=payload) as resp:
77+
if resp.status != 200:
78+
logger.error(f"Embedding service error: {resp.status} - {await resp.text()}")
79+
raise RuntimeError(f"Failed to fetch embeddings: HTTP {resp.status}")
80+
embeddings = await resp.json()
81+
8382
return EmbeddingResponse(**embeddings)
8483

8584
def check_health(self) -> bool:

comps/embeddings/src/integrations/tei.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010

1111
from comps import CustomLogger, OpeaComponent, OpeaComponentRegistry, ServiceType
1212
from comps.cores.mega.utils import get_access_token
13-
from comps.cores.proto.api_protocol import EmbeddingRequest, EmbeddingResponse
13+
from comps.cores.proto.api_protocol import EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData
1414

1515
logger = CustomLogger("opea_tei_embedding")
1616
logflag = os.getenv("LOGFLAG", False)
@@ -44,7 +44,7 @@ def _initialize_client(self) -> AsyncInferenceClient:
4444
)
4545
headers = {"Authorization": f"Bearer {access_token}"} if access_token else {}
4646
return AsyncInferenceClient(
47-
model=f"{self.base_url}/v1/embeddings",
47+
model=f"{self.base_url}/embed",
4848
token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
4949
headers=headers,
5050
)
@@ -68,13 +68,13 @@ async def invoke(self, input: EmbeddingRequest) -> EmbeddingResponse:
6868
raise ValueError("Invalid input format: Only string or list of strings are supported.")
6969
else:
7070
raise TypeError("Unsupported input type: input must be a string or list of strings.")
71-
response = await self.client.post(
72-
json={"input": texts, "encoding_format": input.encoding_format, "model": input.model, "user": input.user},
73-
model=f"{self.base_url}/v1/embeddings",
74-
task="text-embedding",
75-
)
76-
embeddings = json.loads(response.decode())
77-
return EmbeddingResponse(**embeddings)
71+
# feature_extraction return np.ndarray
72+
response = await self.client.feature_extraction(text=texts, model=f"{self.base_url}/embed")
73+
# Convert np.ndarray to a list of lists (embedding)
74+
data = [EmbeddingResponseData(index=i, embedding=embedding.tolist()) for i, embedding in enumerate(response)]
75+
# Construct the EmbeddingResponse
76+
response = EmbeddingResponse(data=data)
77+
return response
7878

7979
def check_health(self) -> bool:
8080
"""Checks the health of the embedding service.

comps/rerankings/src/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ aiohttp
22
docarray[full]
33
fastapi
44
httpx
5+
huggingface-hub==0.30.2
56
opentelemetry-api
67
opentelemetry-exporter-otlp
78
opentelemetry-sdk

0 commit comments

Comments (0)