Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions comps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TranslationGateway,
SearchQnAGateway,
AudioQnAGateway,
RetrievalToolGateway,
FaqGenGateway,
VisualQnAGateway,
)
Expand Down
1 change: 1 addition & 0 deletions comps/cores/mega/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class MegaServiceEndpoint(Enum):
DOC_SUMMARY = "/v1/docsum"
SEARCH_QNA = "/v1/searchqna"
TRANSLATION = "/v1/translation"
RETRIEVALTOOL = "/v1/retrievaltool"
FAQ_GEN = "/v1/faqgen"
# Follow OPENAI
EMBEDDINGS = "/v1/embeddings"
Expand Down
41 changes: 40 additions & 1 deletion comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import base64
import os
from io import BytesIO
from typing import Union

import requests
from fastapi import Request
Expand All @@ -16,9 +17,10 @@
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
EmbeddingRequest,
UsageInfo,
)
from ..proto.docarray import LLMParams
from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, TextDoc
from .constants import MegaServiceEndpoint, ServiceRoleType, ServiceType
from .micro_service import MicroService

Expand Down Expand Up @@ -529,3 +531,40 @@ async def handle_request(self, request: Request):
)
)
return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage)


class RetrievalToolGateway(Gateway):
"""embed+retrieve+rerank."""

def __init__(self, megaservice, host="0.0.0.0", port=8889):
super().__init__(
megaservice,
host,
port,
str(MegaServiceEndpoint.RETRIEVALTOOL),
Union[TextDoc, EmbeddingRequest, ChatCompletionRequest], # ChatCompletionRequest,
Union[RerankedDoc, LLMParamsDoc], # ChatCompletionResponse
)

async def handle_request(self, request: Request):
def parser_input(data, TypeClass, key):
try:
chat_request = TypeClass.parse_obj(data)
query = getattr(chat_request, key)
except:
query = None
return query

data = await request.json()
query = None
for key, TypeClass in zip(["text", "input", "input"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
query = parser_input(data, TypeClass, key)
if query is not None:
break
if query is None:
raise ValueError(f"Unknown request type: {data}")
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query})
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]
print("response is ", response)
return response