Merged

Commits (39)
51651aa
Backend enhancements for image query capabilities for MultimodalQnA
dmsuehir Nov 22, 2024
f83e2e1
Fix model name var
dmsuehir Nov 22, 2024
1a61cb5
Merge branch 'mmqna-phase2' of github.com:mhbuehler/GenAIComps into d…
dmsuehir Nov 25, 2024
1f0dfcd
Remove space at end of prompt
dmsuehir Nov 26, 2024
107680d
Merge branch 'mmqna-phase2' of github.com:mhbuehler/GenAIComps into d…
dmsuehir Dec 2, 2024
5b51771
Add env var for the max number of images sent to the LVM
dmsuehir Dec 2, 2024
242ee6f
README update for the MAX_IMAGES env var
dmsuehir Dec 2, 2024
8b21819
Merge branch 'dina/image_query' of github.com:mhbuehler/GenAIComps in…
dmsuehir Dec 2, 2024
5b41724
Remove prints
dmsuehir Dec 2, 2024
ae5437a
Audio query functionality to multimodal backend (#8)
okhleif-10 Dec 2, 2024
f4a7199
Merge branch 'mmqna-phase2' of github.com:mhbuehler/GenAIComps into d…
dmsuehir Dec 3, 2024
e1e5fde
Merge branch 'main' into mmqna-audio-query
mhbuehler Dec 4, 2024
70c54e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2024
6a71843
fixed role bug where i never was > 0
okhleif-10 Dec 4, 2024
411bfdf
Fix after merge
dmsuehir Dec 4, 2024
615459b
removed whitespace
okhleif-10 Dec 4, 2024
1753473
Merge pull request #13 from mhbuehler/omar/role-debug
mhbuehler Dec 4, 2024
dcafe8d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2024
e32bef4
Fix call to get role labels
dmsuehir Dec 4, 2024
63c08fe
Merge branch 'mmqna-audio-query' of github.com:mhbuehler/GenAIComps i…
dmsuehir Dec 4, 2024
db22c47
Gateway test updates images within the conversation
dmsuehir Dec 5, 2024
fa47959
Adds unit test coverage for audio query
mhbuehler Dec 5, 2024
02efc8a
Update test to check the returned b64 types
dmsuehir Dec 5, 2024
d74bb32
Update test since we don't expect images from the assistant
dmsuehir Dec 5, 2024
37826be
Port number fix
mhbuehler Dec 6, 2024
40d34db
Formatting
mhbuehler Dec 6, 2024
6f2a753
Merge pull request #14 from mhbuehler/melanie/add_test_coverage
mhbuehler Dec 6, 2024
a665c3c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2024
4a5c8ea
Merge branch 'main' into mmqna-audio-query
ashahba Dec 6, 2024
d9ab567
Fixed place where port number is set
mhbuehler Dec 6, 2024
75b135f
Merge pull request #15 from mhbuehler/melanie/port_placement
mhbuehler Dec 6, 2024
9a077c5
Remove old comment and added more accurate description
dmsuehir Dec 9, 2024
b21e575
add comment in code about MAX_IMAGES
dmsuehir Dec 9, 2024
a3abd8a
Add Gaudi support for image query
dmsuehir Dec 10, 2024
b8dbabf
Merge branch 'mmqna-audio-query' of github.com:mhbuehler/GenAIComps i…
dmsuehir Dec 10, 2024
c87504c
Merge branch 'mmqna-image-query' of github.com:mhbuehler/GenAIComps i…
dmsuehir Dec 12, 2024
723f0c3
Fix to pass the retrieved image last
dmsuehir Dec 12, 2024
b1205f4
Revert out gateway and gateway test code, due to its move to GenAIExa…
dmsuehir Dec 12, 2024
bac117a
Fix retriever test for checking for b64_img_str in the result
dmsuehir Dec 13, 2024
178 changes: 150 additions & 28 deletions comps/cores/mega/gateway.py
@@ -2,13 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

import base64
import json
import os
from io import BytesIO
from typing import List, Union

import requests
from fastapi import File, Request, UploadFile
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from PIL import Image

from ..proto.api_protocol import (
@@ -837,8 +838,12 @@ def parser_input(data, TypeClass, key):


class MultimodalQnAGateway(Gateway):
asr_port = int(os.getenv("ASR_SERVICE_PORT", 3001))
asr_endpoint = os.getenv("ASR_SERVICE_ENDPOINT", "http://0.0.0.0:{}/v1/audio/transcriptions".format(asr_port))

def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", port=9999):
self.lvm_megaservice = lvm_megaservice
self._role_labels = self._get_role_labels()
super().__init__(
multimodal_rag_megaservice,
host,
@@ -848,33 +853,73 @@ def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0",
ChatCompletionResponse,
)

def _get_role_labels(self):
"""
Returns a dictionary of role labels that are used in the chat prompt based on the LVM_MODEL_ID
environment variable. The function defines the role labels used by the llava-1.5, llava-v1.6-vicuna,
llava-v1.6-mistral, and llava-interleave models, and then defaults to use "USER:" and "ASSISTANT:" if the
LVM_MODEL_ID is not one of those.
"""
lvm_model = os.getenv("LVM_MODEL_ID", "")

# Default to labels used by llava-1.5 and llava-v1.6-vicuna models
role_labels = {
"user": "USER:",
"assistant": "ASSISTANT:"
}

if "llava-interleave" in lvm_model:
role_labels["user"] = "<|im_start|>user"
role_labels["assistant"] = "<|im_end|><|im_start|>assistant"
elif "llava-v1.6-mistral" in lvm_model:
role_labels["user"] = "[INST]"
role_labels["assistant"] = " [/INST]"
elif "llava-1.5" not in lvm_model and "llava-v1.6-vicuna" not in lvm_model:
print(f"[ MultimodalQnAGateway ] Using default role labels for prompt formatting: {role_labels}")

return role_labels
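# Illustrative examples (model IDs are hypothetical): with
# LVM_MODEL_ID="llava-hf/llava-interleave-qwen-0.5b-hf" this returns
# {"user": "<|im_start|>user", "assistant": "<|im_end|><|im_start|>assistant"};
# a model ID matching none of the llava variants above falls back to
# {"user": "USER:", "assistant": "ASSISTANT:"}.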

# This overrides the _handle_message method of Gateway
def _handle_message(self, messages):
images = []
audios = []
b64_types = {}
messages_dicts = []
decoded_audio_input = ""
if isinstance(messages, str):
prompt = messages
else:
messages_dict = {}
system_prompt = ""
prompt = ""
role_label_dict = self._role_labels
for message in messages:
msg_role = message["role"]
messages_dict = {}
if msg_role == "system":
system_prompt = message["content"]
elif msg_role == "user":
if type(message["content"]) == list:
# separate each media type and store accordingly
text = ""
text_list = [item["text"] for item in message["content"] if item["type"] == "text"]
text += "\n".join(text_list)
image_list = [
item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url"
]
if image_list:
messages_dict[msg_role] = (text, image_list)
else:
audios = [item["audio"] for item in message["content"] if item["type"] == "audio"]
if audios:
# translate audio to text. From this point forward, audio is treated like text
decoded_audio_input = self.convert_audio_to_text(audios)
b64_types["audio"] = decoded_audio_input

if text and not audios and not image_list:
messages_dict[msg_role] = text
elif audios and not text and not image_list:
messages_dict[msg_role] = decoded_audio_input
else:
messages_dict[msg_role] = (text, decoded_audio_input, image_list)

else:
messages_dict[msg_role] = message["content"]
messages_dicts.append(messages_dict)
Expand All @@ -883,23 +928,30 @@ def _handle_message(self, messages):
messages_dicts.append(messages_dict)
else:
raise ValueError(f"Unknown role: {msg_role}")

if system_prompt:
prompt = system_prompt + "\n"
for messages_dict in messages_dicts:
for i, (role, message) in enumerate(messages_dict.items()):
for i, messages_dict in enumerate(messages_dicts):
for role, message in messages_dict.items():
if isinstance(message, tuple):
text, image_list = message
text, decoded_audio_input, image_list = message
# Remove empty items from the image list
image_list = [x for x in image_list if x]
# Add image indicators within the conversation
image_tags = "<image>\n" * len(image_list)
if i == 0:
# do not add role for the very first message.
# this will be added by llava_server
if text:
prompt += text + "\n"
prompt += image_tags + text + "\n"
elif decoded_audio_input:
prompt += image_tags + decoded_audio_input + "\n"
else:
if text:
prompt += role.upper() + ": " + text + "\n"
prompt += role_label_dict[role] + image_tags + " " + text + "\n"
elif decoded_audio_input:
prompt += role_label_dict[role] + image_tags + " " + decoded_audio_input + "\n"
else:
prompt += role.upper() + ":"
prompt += role_label_dict[role] + image_tags
for img in image_list:
# URL
if img.startswith("http://") or img.startswith("https://"):
@@ -918,42 +970,106 @@ async def handle_request(self, request: Request):
else:
img_b64_str = img

images.append(img_b64_str)
else:
if image_list:
for img in image_list:
# URL
if img.startswith("http://") or img.startswith("https://"):
response = requests.get(img)
image = Image.open(BytesIO(response.content)).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Local Path
elif os.path.exists(img):
image = Image.open(img).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Bytes
else:
img_b64_str = img

images.append(img_b64_str)

elif isinstance(message, str):
if i == 0:
# do not add role for the very first message.
# this will be added by llava_server
if message:
prompt += role.upper() + ": " + message + "\n"
prompt += message + "\n"
else:
if message:
prompt += role.upper() + ": " + message + "\n"
prompt += role_label_dict[role] + " " + message + "\n"
else:
prompt += role.upper() + ":"
prompt += role_label_dict[role]

if images:
return prompt, images
b64_types["image"] = images

# If the query has multiple media types, return all types
if prompt and b64_types:
return prompt, b64_types
else:
return prompt

def convert_audio_to_text(self, audio):
# translate audio to text by passing in dictionary to ASR
if isinstance(audio, dict):
input_dict = {"byte_str": audio["audio"][0]}
else:
input_dict = {"byte_str": audio[0]}

response = requests.post(self.asr_endpoint, data=json.dumps(input_dict), proxies={"http": None})

if response.status_code != 200:
return JSONResponse(
status_code=503, content={"message": "Unable to convert audio to text. {}".format(response.text)}
)

response = response.json()
return response["query"]

async def handle_request(self, request: Request):
data = await request.json()
stream_opt = bool(data.get("stream", False))
if stream_opt:
print("[ MultimodalQnAGateway ] stream=True not used, this has not support streaming yet!")
stream_opt = False
chat_request = ChatCompletionRequest.model_validate(data)
num_messages = len(data["messages"]) if isinstance(data["messages"], list) else 1

# Multimodal RAG QnA With Videos does not yet accept images as input during QnA.
Owner: Can you remove this comment or update it for accuracy?
prompt_and_image = self._handle_message(chat_request.messages)
if isinstance(prompt_and_image, tuple):
# print(f"This request include image, thus it is a follow-up query. Using lvm megaservice")
prompt, images = prompt_and_image
messages = self._handle_message(chat_request.messages)
decoded_audio_input = ""

if num_messages > 1:
# This is a follow-up query, so go to the LVM megaservice
cur_megaservice = self.lvm_megaservice
initial_inputs = {"prompt": prompt, "image": images[0]}
if isinstance(messages, tuple):
prompt, b64_types = messages
if "audio" in b64_types:
# for metadata storage purposes
decoded_audio_input = b64_types["audio"]
if "image" in b64_types:
initial_inputs = {"prompt": prompt, "image": b64_types["image"]}
else:
initial_inputs = {"prompt": prompt, "image": ""}
else:
prompt = messages
initial_inputs = {"prompt": prompt, "image": ""}
else:
# print(f"This is the first query, requiring multimodal retrieval. Using multimodal rag megaservice")
prompt = prompt_and_image
# This is the first query. Ignore image input
cur_megaservice = self.megaservice
initial_inputs = {"text": prompt}
if isinstance(messages, tuple):
prompt, b64_types = messages
initial_inputs = {"text": prompt}
if "audio" in b64_types:
# for metadata storage purposes
decoded_audio_input = b64_types["audio"]
if "image" in b64_types and len(b64_types["image"]) > 0:
initial_inputs["image"] = {"base64_image": b64_types["image"][0]}
else:
initial_inputs = {"text": messages}

parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
@@ -985,18 +1101,24 @@ async def handle_request(self, request: Request):
if "text" in result_dict[last_node].keys():
response = result_dict[last_node]["text"]
else:
# text in not response message
# text is not in response message
# something wrong, for example due to empty retrieval results
if "detail" in result_dict[last_node].keys():
response = result_dict[last_node]["detail"]
else:
response = "The server fail to generate answer to your query!"
response = "The server failed to generate an answer to your query!"
if "metadata" in result_dict[last_node].keys():
# from retrieval results
metadata = result_dict[last_node]["metadata"]
if decoded_audio_input:
metadata["audio"] = decoded_audio_input
else:
# follow-up question, no retrieval
metadata = None
if decoded_audio_input:
metadata = {"audio": decoded_audio_input}
else:
metadata = None

choices = []
usage = UsageInfo()
choices.append(
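To exercise the new routing in handle_request, a follow-up request (more than one message) with mixed content goes straight to the LVM megaservice. A minimal sketch of such a request follows; the endpoint path and image URL are illustrative assumptions, and only the payload shape comes from the code above:

```python
import requests

# Hypothetical gateway address: port 9999 is the constructor default above,
# but the endpoint path is an assumption for this sketch.
GATEWAY_URL = "http://0.0.0.0:9999/v1/multimodalqna"

# Three messages => num_messages > 1, so the gateway skips retrieval and
# routes the prompt to the LVM megaservice.
payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What fruit is shown here?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/apple.png"}},
            ],
        },
        {"role": "assistant", "content": "It looks like an apple."},
        {"role": "user", "content": [{"type": "text", "text": "What color is it?"}]},
    ],
    "max_tokens": 128,
}

response = requests.post(GATEWAY_URL, json=payload)
print(response.json())
```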
2 changes: 1 addition & 1 deletion comps/cores/proto/docarray.py
@@ -278,7 +278,7 @@ class GraphDoc(BaseDoc):


class LVMDoc(BaseDoc):
image: str
image: Union[str, List[str]]
prompt: str
max_new_tokens: conint(ge=0, le=1024) = 512
top_k: int = 10
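With LVMDoc.image widened to Union[str, List[str]], callers can pass one image or several in a single document. A minimal sketch (field values are placeholders; the import path mirrors the file path above):

```python
from comps.cores.proto.docarray import LVMDoc

# Single image: the original behavior still validates.
single = LVMDoc(image="<b64 image>", prompt="Describe the image.")

# Multiple images: e.g., images collected across a conversation.
multi = LVMDoc(
    image=["<b64 image 1>", "<b64 image 2>"],
    prompt="What changed between these two images?",
)
```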
7 changes: 7 additions & 0 deletions comps/embeddings/multimodal/README.md
@@ -170,11 +170,18 @@ docker compose -f docker_compose_multimodal_embedding.yaml up -d

**Compute a joint embedding of an image-text pair**

The image can be passed as a URL:
```bash
curl -X POST http://0.0.0.0:6600/v1/embeddings \
-H "Content-Type: application/json" \
-d '{"text": {"text" : "This is some sample text."}, "image" : {"url": "https://github.com/docarray/docarray/blob/main/tests/toydata/image-data/apple.png?raw=true"}}'
```
Or as a base64-encoded string:
```bash
curl -X POST http://0.0.0.0:6600/v1/embeddings \
-H "Content-Type: application/json" \
-d '{"text": {"text" : "This is some sample text."}, "image" : {"base64_image": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC"}}'
```
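The same call can be made from Python; a sketch using requests with the payload from the curl example above:

```python
import requests

payload = {
    "text": {"text": "This is some sample text."},
    "image": {
        "base64_image": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC"
    },
}
response = requests.post("http://0.0.0.0:6600/v1/embeddings", json=payload)
print(response.json())
```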

**Compute an embedding of a text**

@@ -7,6 +7,7 @@

import requests
from fastapi.responses import JSONResponse
from typing import Union

from comps import (
CustomLogger,
@@ -38,7 +39,7 @@
output_datatype=EmbedMultimodalDoc,
)
@register_statistics(names=["opea_service@multimodal_embedding_mmei_langchain"])
def embedding(input: MultimodalDoc) -> EmbedDoc:
def embedding(input: MultimodalDoc) -> Union[EmbedDoc, EmbedMultimodalDoc]:
start = time.time()
if logflag:
logger.info(input)
@@ -48,9 +49,15 @@ def embedding(input: MultimodalDoc) -> EmbedDoc:
json["text"] = input.text
elif isinstance(input, TextImageDoc):
json["text"] = input.text.text
img_bytes = input.image.url.load_bytes()
base64_img = base64.b64encode(img_bytes).decode("utf-8")
json["img_b64_str"] = base64_img
base64_img = ""
if input.image.url:
img_bytes = input.image.url.load_bytes()
base64_img = base64.b64encode(img_bytes).decode("utf-8")
elif input.image.base64_image:
base64_img = input.image.base64_image

if base64_img:
json["img_b64_str"] = base64_img
else:
return JSONResponse(status_code=400, content={"message": "Bad request!"})

@@ -66,6 +73,9 @@ def embedding(input: MultimodalDoc) -> EmbedDoc:
res = EmbedDoc(text=input.text, embedding=embed_vector)
elif isinstance(input, TextImageDoc):
res = EmbedMultimodalDoc(text=input.text.text, url=input.image.url, embedding=embed_vector)

if base64_img:
res.base64_image = base64_img
except requests.exceptions.ConnectionError:
res = JSONResponse(status_code=503, content={"message": "Multimodal embedding endpoint not started!"})
statistics_dict["opea_service@multimodal_embedding_mmei_langchain"].append_latency(time.time() - start, None)
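The image-selection branching above (a URL takes precedence, then a caller-supplied base64 string, otherwise a 400 response) can be distilled into a standalone helper. This is an illustrative sketch, not code from the service:

```python
import base64
from typing import Optional


def select_base64_image(url_bytes: Optional[bytes], base64_image: Optional[str]) -> str:
    """Mirror the service's precedence: encode URL-loaded bytes first,
    fall back to a caller-supplied base64 string, else return ""."""
    if url_bytes:
        return base64.b64encode(url_bytes).decode("utf-8")
    if base64_image:
        return base64_image
    return ""


# An empty result corresponds to the service's 400 "Bad request!" path.
assert select_base64_image(None, None) == ""
# PNG bytes encode to the familiar "iVBO..." prefix seen in the README example.
assert select_base64_image(b"\x89PNG", None).startswith("iVBO")
```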