11# Copyright (C) 2024 Intel Corporation
22# SPDX-License-Identifier: Apache-2.0
33
4- import json
54import os
6- from typing import List , Union
75
6+ import aiohttp
87import requests
9- from huggingface_hub import AsyncInferenceClient
108
119from comps import CustomLogger , OpeaComponent , OpeaComponentRegistry , ServiceType
1210from comps .cores .mega .utils import get_access_token
@@ -32,24 +30,11 @@ class OpeaOVMSEmbedding(OpeaComponent):
3230 def __init__ (self , name : str , description : str , config : dict = None ):
3331 super ().__init__ (name , ServiceType .EMBEDDING .name .lower (), description , config )
3432 self .base_url = os .getenv ("OVMS_EMBEDDING_ENDPOINT" , "http://localhost:8080" )
35- self .client = self ._initialize_client ()
3633
3734 health_status = self .check_health ()
3835 if not health_status :
3936 logger .error ("OpeaOVMSEmbedding health check failed." )
4037
41- def _initialize_client (self ) -> AsyncInferenceClient :
42- """Initializes the AsyncInferenceClient."""
43- access_token = (
44- get_access_token (TOKEN_URL , CLIENTID , CLIENT_SECRET ) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
45- )
46- headers = {"Authorization" : f"Bearer { access_token } " } if access_token else {}
47- return AsyncInferenceClient (
48- model = MODEL_ID ,
49- token = os .getenv ("HUGGINGFACEHUB_API_TOKEN" ),
50- headers = headers ,
51- )
52-
5338 async def invoke (self , input : EmbeddingRequest ) -> EmbeddingResponse :
5439 """Invokes the embedding service to generate embeddings for the provided input.
5540
@@ -69,17 +54,31 @@ async def invoke(self, input: EmbeddingRequest) -> EmbeddingResponse:
6954 raise ValueError ("Invalid input format: Only string or list of strings are supported." )
7055 else :
7156 raise TypeError ("Unsupported input type: input must be a string or list of strings." )
72- response = await self .client .post (
73- json = {
74- "input" : texts ,
75- "encoding_format" : input .encoding_format ,
76- "model" : self .client .model ,
77- "user" : input .user ,
78- },
79- model = f"{ self .base_url } /v3/embeddings" ,
80- task = "text-embedding" ,
57+ # Build headers
58+ headers = {"Content-Type" : "application/json" }
59+ access_token = (
60+ get_access_token (TOKEN_URL , CLIENTID , CLIENT_SECRET ) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
8161 )
82- embeddings = json .loads (response .decode ())
62+ if access_token :
63+ headers ["Authorization" ] = f"Bearer { access_token } "
64+
65+ # Compose request
66+ payload = {
67+ "input" : texts ,
68+ "encoding_format" : input .encoding_format ,
69+ "model" : MODEL_ID ,
70+ "user" : input .user ,
71+ }
72+
73+ # Send async POST request using aiohttp
74+ url = f"{ self .base_url } /v3/embeddings"
75+ async with aiohttp .ClientSession () as session :
76+ async with session .post (url , headers = headers , json = payload ) as resp :
77+ if resp .status != 200 :
78+ logger .error (f"Embedding service error: { resp .status } - { await resp .text ()} " )
79+ raise RuntimeError (f"Failed to fetch embeddings: HTTP { resp .status } " )
80+ embeddings = await resp .json ()
81+
8382 return EmbeddingResponse (** embeddings )
8483
8584 def check_health (self ) -> bool :
0 commit comments