Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 83 additions & 28 deletions comps/retrievers/src/integrations/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@
# SPDX-License-Identifier: Apache-2.0


import asyncio
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Union

from langchain_community.vectorstores import Redis
from fastapi import HTTPException
from langchain.vectorstores import Redis
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings

from comps import (
CustomLogger,
Expand All @@ -18,10 +23,24 @@
)
from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingResponse, RetrievalRequest, RetrievalResponse

from .config import BRIDGE_TOWER_EMBEDDING, EMBED_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL, TEI_EMBEDDING_ENDPOINT
from .config import (
BRIDGE_TOWER_EMBEDDING,
EMBED_MODEL,
HUGGINGFACEHUB_API_TOKEN,
INDEX_NAME,
INDEX_SCHEMA,
REDIS_URL,
TEI_EMBEDDING_ENDPOINT,
)

logger = CustomLogger("redis_retrievers")
logflag = os.getenv("LOGFLAG", False)
# Shared thread pool used to offload blocking (synchronous) library calls so
# they do not stall the asyncio event loop.
executor = ThreadPoolExecutor()


async def run_in_thread(func, *args, **kwargs):
    """Execute a blocking callable in the shared thread pool.

    Args:
        func: Synchronous callable to run off the event loop.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func(*args, **kwargs)`` returns.
    """
    # get_running_loop() is the correct (non-deprecated) way to obtain the
    # loop from inside a coroutine; get_event_loop() emits a DeprecationWarning
    # on Python 3.10+ and is slated for removal in this context.
    loop = asyncio.get_running_loop()
    # run_in_executor only accepts positional args, so close over kwargs.
    return await loop.run_in_executor(executor, lambda: func(*args, **kwargs))


@OpeaComponentRegistry.register("OPEA_RETRIEVER_REDIS")
Expand All @@ -34,29 +53,46 @@ class OpeaRedisRetriever(OpeaComponent):

    def __init__(self, name: str, description: str, config: dict = None):
        """Initialize the Redis retriever component.

        Args:
            name: Component name passed to the OpeaComponent registry/base.
            description: Human-readable component description.
            config: Optional component configuration dict.
        """
        super().__init__(name, ServiceType.RETRIEVER.name.lower(), description, config)
        # The embedder and Redis client are built by async helpers; asyncio.run
        # is used because __init__ cannot be async. NOTE(review): this assumes
        # no event loop is already running when the component is constructed —
        # asyncio.run raises RuntimeError otherwise; confirm against callers.
        self.embeddings = asyncio.run(self._initialize_embedder())
        self.client = asyncio.run(self._initialize_client())
        # Probe Redis once at startup; a failure is logged but not fatal.
        health_status = self.check_health()
        if not health_status:
            logger.error("OpeaRedisRetriever health check failed.")

# Create embeddings
    async def _initialize_embedder(self):
        """Create and return the embedding backend, chosen by configuration.

        Selection order:
          1. ``TEI_EMBEDDING_ENDPOINT`` — remote TEI service (requires
             ``HUGGINGFACEHUB_API_TOKEN``).
          2. ``BRIDGE_TOWER_EMBEDDING`` — local BridgeTower multimodal embedder.
          3. Otherwise — local HuggingFace model named by ``EMBED_MODEL``.

        Returns:
            An embeddings object exposing the LangChain embeddings interface.

        Raises:
            HTTPException: 400 if the API token is missing or the TEI endpoint
                does not answer its ``/info`` probe.
        """
        if TEI_EMBEDDING_ENDPOINT:
            logger.info("use tei embedding")
            # The TEI path authenticates with a HF token; fail fast if absent.
            if not HUGGINGFACEHUB_API_TOKEN:
                raise HTTPException(
                    status_code=400,
                    detail="You MUST offer the `HUGGINGFACEHUB_API_TOKEN` when using `TEI_EMBEDDING_ENDPOINT`.",
                )

            import httpx

            # Probe the endpoint and discover which model it serves; the
            # reported model_id is then used to construct the embedder.
            # NOTE(review): assumes the /info response JSON always contains
            # "model_id" — a malformed response would raise KeyError here.
            async with httpx.AsyncClient() as client:
                response = await client.get(TEI_EMBEDDING_ENDPOINT + "/info")
                if response.status_code != 200:
                    raise HTTPException(
                        status_code=400, detail=f"TEI embedding endpoint {TEI_EMBEDDING_ENDPOINT} is not available."
                    )
                model_id = response.json()["model_id"]
            # create embeddings using TEI endpoint service
            embedder = HuggingFaceInferenceAPIEmbeddings(
                api_key=HUGGINGFACEHUB_API_TOKEN, model_name=model_id, api_url=TEI_EMBEDDING_ENDPOINT
            )
        elif BRIDGE_TOWER_EMBEDDING:
            logger.info("use bridge tower embedding")
            # Imported lazily so the BridgeTower dependency is only required
            # when this configuration is active.
            from comps.third_parties.bridgetower.src.bridgetower_embedding import BridgeTowerEmbedding

            embedder = BridgeTowerEmbedding()
        else:
            logger.info("use local embedding")
            # create embeddings using local embedding model
            embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
        return embedder

def _initialize_client(self) -> Redis:
async def _initialize_client(self) -> Redis:
"""Initializes the redis client."""
try:
if BRIDGE_TOWER_EMBEDDING:
Expand All @@ -80,7 +116,8 @@ def check_health(self) -> bool:
if logflag:
logger.info("[ health check ] start to check health of redis")
try:
if self.client.client.ping():
if self.client:
self.client.client.ping()
if logflag:
logger.info("[ health check ] Successfully connected to Redis!")
return True
Expand All @@ -102,40 +139,58 @@ async def invoke(
logger.info(input)

# check if the Redis index has data
if self.client.client.keys() == []:
try:
keys_exist = self.client.client.keys()
except Exception as e:
logger.error(f"Redis key check failed: {e}")
keys_exist = []

if not keys_exist:
if logflag:
logger.info("No data in Redis index, return []")
search_res = []
else:
if isinstance(input, EmbedDoc) or isinstance(input, EmbedMultimodalDoc):
embedding_data_input = input.embedding
else:
# for RetrievalRequest, ChatCompletionRequest
if isinstance(input.embedding, EmbeddingResponse):
embeddings = input.embedding.data
embedding_data_input = []
for emb in embeddings:
embedding_data_input.append(emb.embedding)
embedding_data_input = [emb.embedding for emb in input.embedding.data]

else:
embedding_data_input = input.embedding

# if the Redis index has data, perform the search
if input.search_type == "similarity":
search_res = await self.client.asimilarity_search_by_vector(embedding=embedding_data_input, k=input.k)
search_res = await run_in_thread(
self.client.similarity_search_by_vector, embedding=embedding_data_input, k=input.k
)
elif input.search_type == "similarity_distance_threshold":
if input.distance_threshold is None:
raise ValueError(
"distance_threshold must be provided for " + "similarity_distance_threshold retriever"
)
search_res = await self.client.asimilarity_search_by_vector(
embedding=input.embedding, k=input.k, distance_threshold=input.distance_threshold
search_res = await run_in_thread(
self.client.similarity_search_by_vector,
embedding=input.embedding,
k=input.k,
distance_threshold=input.distance_threshold,
)
elif input.search_type == "similarity_score_threshold":
docs_and_similarities = await self.client.asimilarity_search_with_relevance_scores(
query=input.text, k=input.k, score_threshold=input.score_threshold
docs_and_similarities = await run_in_thread(
self.client.similarity_search_with_relevance_scores,
query=input.text,
k=input.k,
score_threshold=input.score_threshold,
)
search_res = [doc for doc, _ in docs_and_similarities]
elif input.search_type == "mmr":
search_res = await self.client.amax_marginal_relevance_search(
query=input.text, k=input.k, fetch_k=input.fetch_k, lambda_mult=input.lambda_mult
search_res = await run_in_thread(
self.client.max_marginal_relevance_search,
query=input.text,
k=input.k,
fetch_k=input.fetch_k,
lambda_mult=input.lambda_mult,
)
else:
raise ValueError(f"{input.search_type} not valid")
Expand Down