Skip to content
Merged
15 changes: 14 additions & 1 deletion DocIndexRetriever/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,22 @@ Example usage:
```python
url = "http://{host_ip}:{port}/v1/retrievaltool".format(host_ip=host_ip, port=port)
payload = {
"messages": query,
"messages": query, # must be a string, this is a required field
"k": 5, # retriever top k
"top_n": 2, # reranker top n
}
response = requests.post(url, json=payload)
```

**Note**: `messages` is the required field. You can also pass in parameters for the retriever and reranker in the request. The parameters that can changed are listed below.

1. retriever
* search_type: str = "similarity"
* k: int = 4
* distance_threshold: Optional[float] = None
* fetch_k: int = 20
* lambda_mult: float = 0.5
* score_threshold: float = 0.2

2. reranker
* top_n: int = 1
5 changes: 1 addition & 4 deletions DocIndexRetriever/docker_compose/intel/cpu/xeon/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ Retrieval from KnowledgeBase
curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
"messages": "Explain the OPEA project?"
}'

# expected output
{"id":"354e62c703caac8c547b3061433ec5e8","reranked_docs":[{"id":"06d5a5cefc06cf9a9e0b5fa74a9f233c","text":"Close SearchsearchMenu WikiNewsCommunity Daysx-twitter linkedin github searchStreamlining implementation of enterprise-grade Generative AIEfficiently integrate secure, performant, and cost-effective Generative AI workflows into business value.TODAYOPEA..."}],"initial_query":"Explain the OPEA project?"}
```

**Note**: `messages` is the required field. You can also pass in parameters for the retriever and reranker in the request. The parameters that can changed are listed below.
Expand Down Expand Up @@ -128,7 +125,7 @@ curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: applicati
# embedding microservice
curl http://${host_ip}:6000/v1/embeddings \
-X POST \
-d '{"text":"Explain the OPEA project"}' \
-d '{"messages":"Explain the OPEA project"}' \
-H 'Content-Type: application/json' > query
docker container logs embedding-server

Expand Down
13 changes: 5 additions & 8 deletions DocIndexRetriever/docker_compose/intel/cpu/xeon/compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ services:
dataprep-redis-service:
image: ${REGISTRY:-opea}/dataprep:${TAG:-latest}
container_name: dataprep-redis-server
# volumes:
# - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/user/comps
depends_on:
- redis-vector-db
redis-vector-db:
condition: service_started
tei-embedding-service:
condition: service_healthy
ports:
- "6007:5000"
- "6008:6008"
Expand All @@ -28,7 +29,7 @@ services:
REDIS_URL: ${REDIS_URL}
REDIS_HOST: ${REDIS_HOST}
INDEX_NAME: ${INDEX_NAME}
TEI_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
LOGFLAG: ${LOGFLAG}
tei-embedding-service:
Expand All @@ -54,8 +55,6 @@ services:
embedding:
image: ${REGISTRY:-opea}/embedding:${TAG:-latest}
container_name: embedding-server
# volumes:
# - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/comps
ports:
- "6000:6000"
ipc: host
Expand Down Expand Up @@ -114,8 +113,6 @@ services:
reranking:
image: ${REGISTRY:-opea}/reranking:${TAG:-latest}
container_name: reranking-tei-xeon-server
# volumes:
# - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/user/comps
depends_on:
tei-reranking-service:
condition: service_healthy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ Retrieval from KnowledgeBase
curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
"messages": "Explain the OPEA project?"
}'

# expected output
{"id":"354e62c703caac8c547b3061433ec5e8","reranked_docs":[{"id":"06d5a5cefc06cf9a9e0b5fa74a9f233c","text":"Close SearchsearchMenu WikiNewsCommunity Daysx-twitter linkedin github searchStreamlining implementation of enterprise-grade Generative AIEfficiently integrate secure, performant, and cost-effective Generative AI workflows into business value.TODAYOPEA..."}],"initial_query":"Explain the OPEA project?"}
```

**Note**: `messages` is the required field. You can also pass in parameters for the retriever and reranker in the request. The parameters that can changed are listed below.
Expand Down Expand Up @@ -118,7 +115,7 @@ curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: applicati
# embedding microservice
curl http://${host_ip}:6000/v1/embeddings \
-X POST \
-d '{"text":"Explain the OPEA project"}' \
-d '{"messages":"Explain the OPEA project"}' \
-H 'Content-Type: application/json' > query
docker container logs embedding-server

Expand Down
10 changes: 7 additions & 3 deletions DocIndexRetriever/docker_compose/intel/hpu/gaudi/compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ services:
image: ${REGISTRY:-opea}/dataprep:${TAG:-latest}
container_name: dataprep-redis-server
depends_on:
- redis-vector-db
- tei-embedding-service
redis-vector-db:
condition: service_started
tei-embedding-service:
condition: service_healthy
ports:
- "6007:5000"
environment:
Expand All @@ -25,7 +27,7 @@ services:
https_proxy: ${https_proxy}
REDIS_URL: ${REDIS_URL}
INDEX_NAME: ${INDEX_NAME}
TEI_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
tei-embedding-service:
image: ghcr.io/huggingface/tei-gaudi:1.5.0
Expand Down Expand Up @@ -87,6 +89,8 @@ services:
INDEX_NAME: ${INDEX_NAME}
LOGFLAG: ${LOGFLAG}
RETRIEVER_COMPONENT_NAME: "OPEA_RETRIEVER_REDIS"
TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
restart: unless-stopped
tei-reranking-service:
image: ghcr.io/huggingface/text-embeddings-inference:cpu-1.6
Expand Down
168 changes: 94 additions & 74 deletions DocIndexRetriever/retrieval_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@

from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingRequest
from comps.cores.proto.docarray import LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
from comps.cores.proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
from fastapi import Request
from fastapi.responses import StreamingResponse

MEGA_SERVICE_PORT = os.getenv("MEGA_SERVICE_PORT", 8889)
EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")
Expand All @@ -22,41 +21,75 @@


def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
print(f"Inputs to {cur_node}: {inputs}")
print(f"*** Inputs to {cur_node}:\n{inputs}")
print("--" * 50)
for key, value in kwargs.items():
print(f"{key}: {value}")
if self.services[cur_node].service_type == ServiceType.EMBEDDING:
inputs["input"] = inputs["text"]
del inputs["text"]
elif self.services[cur_node].service_type == ServiceType.RETRIEVER:
# input is EmbedDoc
"""Class EmbedDoc(BaseDoc):

text: Union[str, List[str]]
embedding: Union[conlist(float, min_length=0), List[conlist(float, min_length=0)]]
search_type: str = "similarity"
k: int = 4
distance_threshold: Optional[float] = None
fetch_k: int = 20
lambda_mult: float = 0.5
score_threshold: float = 0.2
constraints: Optional[Union[Dict[str, Any], List[Dict[str, Any]], None]] = None
index_name: Optional[str] = None
"""
# prepare the retriever params
retriever_parameters = kwargs.get("retriever_parameters", None)
if retriever_parameters:
inputs.update(retriever_parameters.dict())
elif self.services[cur_node].service_type == ServiceType.RERANK:
# input is SearchedDoc
"""Class SearchedDoc(BaseDoc):

retrieved_docs: DocList[TextDoc]
initial_query: str
top_n: int = 1
"""
# prepare the reranker params
reranker_parameters = kwargs.get("reranker_parameters", None)
if reranker_parameters:
inputs.update(reranker_parameters.dict())
print(f"*** Formatted Inputs to {cur_node}:\n{inputs}")
print("--" * 50)
return inputs


def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
next_data = {}
if self.services[cur_node].service_type == ServiceType.EMBEDDING:
# turn into chat completion request
# next_data = {"text": inputs["input"], "embedding": [item["embedding"] for item in data["data"]]}
print("Assembing output from Embedding for next node...")
print("Inputs to Embedding: ", inputs)
print("Keyword arguments: ")
for key, value in kwargs.items():
print(f"{key}: {value}")

next_data = {
"input": inputs["input"],
"messages": inputs["input"],
"embedding": [item["embedding"] for item in data["data"]],
"k": kwargs["k"] if "k" in kwargs else 4,
"search_type": kwargs["search_type"] if "search_type" in kwargs else "similarity",
"distance_threshold": kwargs["distance_threshold"] if "distance_threshold" in kwargs else None,
"fetch_k": kwargs["fetch_k"] if "fetch_k" in kwargs else 20,
"lambda_mult": kwargs["lambda_mult"] if "lambda_mult" in kwargs else 0.5,
"score_threshold": kwargs["score_threshold"] if "score_threshold" in kwargs else 0.2,
"top_n": kwargs["top_n"] if "top_n" in kwargs else 1,
}

print("Output from Embedding for next node:\n", next_data)
print(f"*** Direct Outputs from {cur_node}:\n{data}")
print("--" * 50)

if self.services[cur_node].service_type == ServiceType.EMBEDDING:
# direct output from Embedding microservice is EmbeddingResponse
"""
class EmbeddingResponse(BaseModel):
object: str = "list"
model: Optional[str] = None
data: List[EmbeddingResponseData]
usage: Optional[UsageInfo] = None

class EmbeddingResponseData(BaseModel):
index: int
object: str = "embedding"
embedding: Union[List[float], str]
"""
# turn it into EmbedDoc
assert isinstance(data["data"], list)
next_data = {"text": inputs["input"], "embedding": data["data"][0]["embedding"]} # EmbedDoc
else:
next_data = data

print(f"*** Formatted Output from {cur_node} for next node:\n", next_data)
print("--" * 50)
return next_data


Expand Down Expand Up @@ -100,54 +133,41 @@ def add_remote_service(self):
self.megaservice.flow_to(retriever, rerank)

async def handle_request(self, request: Request):
def parser_input(data, TypeClass, key):
chat_request = None
try:
chat_request = TypeClass.parse_obj(data)
query = getattr(chat_request, key)
except:
query = None
return query, chat_request

data = await request.json()
query = None
for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
query, chat_request = parser_input(data, TypeClass, key)
if query is not None:
break
if query is None:
raise ValueError(f"Unknown request type: {data}")
if chat_request is None:
raise ValueError(f"Unknown request type: {data}")

if isinstance(chat_request, ChatCompletionRequest):
initial_inputs = {
"messages": query,
"input": query, # has to be input due to embedding expects either input or text
"search_type": chat_request.search_type if chat_request.search_type else "similarity",
"k": chat_request.k if chat_request.k else 4,
"distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
"fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
"lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
"score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
"top_n": chat_request.top_n if chat_request.top_n else 1,
}

kwargs = {
"search_type": chat_request.search_type if chat_request.search_type else "similarity",
"k": chat_request.k if chat_request.k else 4,
"distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
"fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
"lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
"score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
"top_n": chat_request.top_n if chat_request.top_n else 1,
}
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs=initial_inputs,
**kwargs,
)
else:
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"input": query})
chat_request = ChatCompletionRequest.parse_obj(data)

prompt = chat_request.messages

# dummy llm params
parameters = LLMParams(
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
model=chat_request.model if chat_request.model else None,
)

retriever_parameters = RetrieverParms(
search_type=chat_request.search_type if chat_request.search_type else "similarity",
k=chat_request.k if chat_request.k else 4,
distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
)
reranker_parameters = RerankerParms(
top_n=chat_request.top_n if chat_request.top_n else 1,
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"text": prompt},
llm_parameters=parameters,
retriever_parameters=retriever_parameters,
reranker_parameters=reranker_parameters,
)

last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]
Expand Down
21 changes: 16 additions & 5 deletions DocIndexRetriever/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any

import requests


def search_knowledge_base(query: str) -> str:
def search_knowledge_base(query: str, args: Any) -> str:
"""Search the knowledge base for a specific query."""
url = os.environ.get("RETRIEVAL_TOOL_URL")
url = os.environ.get("RETRIEVAL_TOOL_URL", "http://localhost:8889/v1/retrievaltool")
print(url)
proxies = {"http": ""}
payload = {"messages": query, "k": 5, "top_n": 2}
payload = {"messages": query, "k": args.k, "top_n": args.top_n}
response = requests.post(url, json=payload, proxies=proxies)
print(response)
if "documents" in response.json():
Expand All @@ -33,6 +34,16 @@ def search_knowledge_base(query: str) -> str:


if __name__ == "__main__":
resp = search_knowledge_base("What is OPEA?")
# resp = search_knowledge_base("Thriller")
import argparse

parser = argparse.ArgumentParser(description="Test the knowledge base search.")
parser.add_argument("--k", type=int, default=5, help="retriever top k")
parser.add_argument("--top_n", type=int, default=2, help="reranker top n")
args = parser.parse_args()

resp = search_knowledge_base("What is OPEA?", args)

print(resp)

if not resp.startswith("Error"):
print("Test successful!")
4 changes: 2 additions & 2 deletions DocIndexRetriever/tests/test_compose_milvus_on_gaudi.sh
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ function validate_megaservice() {
fi

# Curl the Mega Service
echo "================Testing retriever service: Text Request ================"
echo "================Testing retriever service ================"
cd $WORKPATH/tests
local CONTENT=$(http_proxy="" curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
"text": "Explain the OPEA project?"
"messages": "Explain the OPEA project?"
}')

local EXIT_CODE=$(validate "$CONTENT" "OPEA" "doc-index-retriever-service-gaudi")
Expand Down
Loading