Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions comps/dataprep/src/integrations/arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
from fastapi import Body, File, Form, HTTPException, UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_arangodb import ArangoGraph
from langchain_community.embeddings import HuggingFaceHubEmbeddings
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpointEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import HTMLHeaderTextSplitter

Expand Down Expand Up @@ -200,8 +199,9 @@ def _initialize_embeddings(self):
"""Initialize the embeddings model."""

if TEI_EMBEDDING_ENDPOINT and HUGGINGFACEHUB_API_TOKEN:
self.embeddings = HuggingFaceHubEmbeddings(
self.embeddings = HuggingFaceEndpointEmbeddings(
model=TEI_EMBEDDING_ENDPOINT,
task="feature-extraction",
huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)
elif TEI_EMBED_MODEL:
Expand Down
9 changes: 6 additions & 3 deletions comps/retrievers/src/integrations/arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from arango.database import StandardDatabase
from fastapi import HTTPException
from langchain_arangodb import ArangoVector
from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceHubEmbeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_huggingface import HuggingFaceEndpointEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from comps import CustomLogger, EmbedDoc, OpeaComponent, OpeaComponentRegistry, ServiceType
Expand Down Expand Up @@ -421,8 +422,10 @@ async def invoke(
if OPENAI_API_KEY and OPENAI_EMBED_MODEL and OPENAI_EMBED_ENABLED:
embeddings = OpenAIEmbeddings(model=OPENAI_EMBED_MODEL, dimensions=dimension)
elif TEI_EMBEDDING_ENDPOINT and HUGGINGFACEHUB_API_TOKEN:
embeddings = HuggingFaceHubEmbeddings(
model=TEI_EMBEDDING_ENDPOINT, huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN
embeddings = HuggingFaceEndpointEmbeddings(
model=TEI_EMBEDDING_ENDPOINT,
task="feature-extraction",
huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)
else:
embeddings = HuggingFaceBgeEmbeddings(model_name=TEI_EMBED_MODEL)
Expand Down
Loading