From 51651aab9a0fb08a2a0aec8e46b1d7e7c40ab7c0 Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Fri, 22 Nov 2024 15:00:28 -0800 Subject: [PATCH 01/27] Backend enhancements for image query capabilities for MultimodalQnA --- comps/cores/mega/gateway.py | 64 ++++++++++---- comps/cores/proto/docarray.py | 2 +- comps/embeddings/multimodal/README.md | 7 ++ .../multimodal_langchain/mm_embedding_mmei.py | 18 +++- comps/lvms/llava/README.md | 4 + comps/lvms/llava/dependency/llava_server.py | 83 +++++++++++++++---- .../redis/langchain/retriever_redis.py | 6 ++ .../embeddings/test_embeddings_multimodal.sh | 16 ++++ tests/lvms/test_lvms_llava.sh | 23 +++++ ...t_retrievers_multimodal_redis_langchain.sh | 26 ++++++ 10 files changed, 213 insertions(+), 36 deletions(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 29642eea55..b3f3c6e16c 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -848,6 +848,32 @@ def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", ChatCompletionResponse, ) + def _get_role_labels(self): + """ + Returns a dictionary of role labels that are used in the chat prompt based on the LVM_MODEL_ID + environment variable. The function defines the role labels used by the llava-1.5, llava-v1.6-vicuna, + llava-v1.6-mistral, and llava-interleave models, and then defaults to use "USER:" and "ASSISTANT:" if the + LVM_MODEL_ID is not one of those. 
+ """ + lvm_model = os.getenv("LVM_MODEL_ID", "") + + # Default to labels used by llava-1.5 and llava-v1.6-vicuna models + role_labels = { + "user": "USER:", + "assistant": "ASSISTANT:" + } + + if "llava-interleave" in lvm_model: + role_labels["user"] = "<|im_start|>user" + role_labels["assistant"] = "<|im_end|><|im_start|>assistant" + elif "llava-v1.6-mistral" in model_name: + role_labels["user"] = "[INST]" + role_labels["assistant"] = " [/INST]" + elif "llava-1.5" not in model_name and "llava-v1.6-vicuna" not in model_name: + print(f"[ MultimodalQnAGateway ] Using default role labels for prompt formatting: {role_labels}") + + return role_labels + # this overrides _handle_message method of Gateway def _handle_message(self, messages): images = [] @@ -858,6 +884,7 @@ def _handle_message(self, messages): messages_dict = {} system_prompt = "" prompt = "" + role_label_dict = self._get_role_labels() for message in messages: msg_role = message["role"] messages_dict = {} @@ -890,16 +917,18 @@ def _handle_message(self, messages): for i, (role, message) in enumerate(messages_dict.items()): if isinstance(message, tuple): text, image_list = message + # Add image indicators within the conversation + image_tags = "\n" * len(image_list) if i == 0: # do not add role for the very first message. # this will be added by llava_server if text: - prompt += text + "\n" + prompt += image_tags + text + "\n" else: if text: - prompt += role.upper() + ": " + text + "\n" + prompt += role_label_dict[role] + image_tags + " " + text + "\n" else: - prompt += role.upper() + ":" + prompt += role_label_dict[role] + image_tags for img in image_list: # URL if img.startswith("http://") or img.startswith("https://"): @@ -924,16 +953,13 @@ def _handle_message(self, messages): # do not add role for the very first message. 
# this will be added by llava_server if message: - prompt += role.upper() + ": " + message + "\n" + prompt += message + "\n" else: if message: - prompt += role.upper() + ": " + message + "\n" + prompt += role_label_dict[role] + " " + message + "\n" else: - prompt += role.upper() + ":" - if images: - return prompt, images - else: - return prompt + prompt += role_label_dict[role] + return prompt, images async def handle_request(self, request: Request): data = await request.json() @@ -942,18 +968,24 @@ async def handle_request(self, request: Request): print("[ MultimodalQnAGateway ] stream=True not used, this has not support streaming yet!") stream_opt = False chat_request = ChatCompletionRequest.model_validate(data) + num_messages = len(data["messages"]) if isinstance(data["messages"], list) else 1 + # Multimodal RAG QnA With Videos has not yet accepts image as input during QnA. - prompt_and_image = self._handle_message(chat_request.messages) - if isinstance(prompt_and_image, tuple): + prompt, images = self._handle_message(chat_request.messages) + if num_messages > 1: # print(f"This request include image, thus it is a follow-up query. Using lvm megaservice") - prompt, images = prompt_and_image cur_megaservice = self.lvm_megaservice - initial_inputs = {"prompt": prompt, "image": images[0]} + initial_inputs = {"prompt": prompt, "image": images} else: # print(f"This is the first query, requiring multimodal retrieval. 
Using multimodal rag megaservice") - prompt = prompt_and_image + if images and len(images) > 0: + # Formatting as a TextImageDoc + initial_inputs = {"text": {"text": prompt}, "image": {"base64_image": images[0]}} + else: + # Formatting as a TextDoc + initial_inputs = {"text": prompt} + cur_megaservice = self.megaservice - initial_inputs = {"text": prompt} parameters = LLMParams( max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py index 8c71086f58..56de4a8c60 100644 --- a/comps/cores/proto/docarray.py +++ b/comps/cores/proto/docarray.py @@ -278,7 +278,7 @@ class GraphDoc(BaseDoc): class LVMDoc(BaseDoc): - image: str + image: Union[str, List[str]] prompt: str max_new_tokens: conint(ge=0, le=1024) = 512 top_k: int = 10 diff --git a/comps/embeddings/multimodal/README.md b/comps/embeddings/multimodal/README.md index c75a60f12a..c839365bcd 100644 --- a/comps/embeddings/multimodal/README.md +++ b/comps/embeddings/multimodal/README.md @@ -170,11 +170,18 @@ docker compose -f docker_compose_multimodal_embedding.yaml up -d **Compute a joint embedding of an image-text pair** +The image can be passed as a URL: ```bash curl -X POST http://0.0.0.0:6600/v1/embeddings \ -H "Content-Type: application/json" \ -d '{"text": {"text" : "This is some sample text."}, "image" : {"url": "https://github.com/docarray/docarray/blob/main/tests/toydata/image-data/apple.png?raw=true"}}' ``` +Or as a base64 encoded string: +```bash +curl -X POST http://0.0.0.0:6600/v1/embeddings \ + -H "Content-Type: application/json" \ + -d '{"text": {"text" : "This is some sample text."}, "image" : {"base64_image": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC"}}' +``` **Compute an embedding of a text** diff --git a/comps/embeddings/multimodal/multimodal_langchain/mm_embedding_mmei.py 
b/comps/embeddings/multimodal/multimodal_langchain/mm_embedding_mmei.py index fbd972a202..cd052fc288 100644 --- a/comps/embeddings/multimodal/multimodal_langchain/mm_embedding_mmei.py +++ b/comps/embeddings/multimodal/multimodal_langchain/mm_embedding_mmei.py @@ -7,6 +7,7 @@ import requests from fastapi.responses import JSONResponse +from typing import Union from comps import ( CustomLogger, @@ -38,7 +39,7 @@ output_datatype=EmbedMultimodalDoc, ) @register_statistics(names=["opea_service@multimodal_embedding_mmei_langchain"]) -def embedding(input: MultimodalDoc) -> EmbedDoc: +def embedding(input: MultimodalDoc) -> Union[EmbedDoc, EmbedMultimodalDoc]: start = time.time() if logflag: logger.info(input) @@ -48,9 +49,15 @@ def embedding(input: MultimodalDoc) -> EmbedDoc: json["text"] = input.text elif isinstance(input, TextImageDoc): json["text"] = input.text.text - img_bytes = input.image.url.load_bytes() - base64_img = base64.b64encode(img_bytes).decode("utf-8") - json["img_b64_str"] = base64_img + base64_img = "" + if input.image.url: + img_bytes = input.image.url.load_bytes() + base64_img = base64.b64encode(img_bytes).decode("utf-8") + elif input.image.base64_image: + base64_img = input.image.base64_image + + if base64_img: + json["img_b64_str"] = base64_img else: return JSONResponse(status_code=400, content={"message": "Bad request!"}) @@ -66,6 +73,9 @@ def embedding(input: MultimodalDoc) -> EmbedDoc: res = EmbedDoc(text=input.text, embedding=embed_vector) elif isinstance(input, TextImageDoc): res = EmbedMultimodalDoc(text=input.text.text, url=input.image.url, embedding=embed_vector) + + if base64_img: + res.base64_image = base64_img except requests.exceptions.ConnectionError: res = JSONResponse(status_code=503, content={"message": "Multimodal embedding endpoint not started!"}) statistics_dict["opea_service@multimodal_embedding_mmei_langchain"].append_latency(time.time() - start, None) diff --git a/comps/lvms/llava/README.md b/comps/lvms/llava/README.md index 
998eb4b664..ae19a221b8 100644 --- a/comps/lvms/llava/README.md +++ b/comps/lvms/llava/README.md @@ -106,6 +106,10 @@ docker run -p 9399:9399 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$htt # curl with an image and a prompt http_proxy="" curl http://localhost:9399/v1/lvm -XPOST -d '{"image": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC", "prompt":"What is this?"}' -H 'Content-Type: application/json' +# curl with multiple images and a prompt + +http_proxy="" curl http://localhost:9399/v1/lvm -XPOST -d '{"image": ["iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC", "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC"], "prompt":"What is in these images?"}' -H 'Content-Type: application/json' + # curl with a prompt only (no image) http_proxy="" curl http://localhost:9399/v1/lvm -XPOST -d '{"image": "", "prompt":"What is deep learning?"}' -H 'Content-Type: application/json' diff --git a/comps/lvms/llava/dependency/llava_server.py b/comps/lvms/llava/dependency/llava_server.py index 644e15a82e..b238185b5f 100644 --- a/comps/lvms/llava/dependency/llava_server.py +++ b/comps/lvms/llava/dependency/llava_server.py @@ -13,6 +13,7 @@ import uvicorn from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response +from transformers import AutoProcessor from transformers import pipeline from transformers.image_utils import load_image @@ -33,9 +34,16 @@ def pipeline_preprocess(self, image, prompt=None, timeout=None): The original transformers image-to-text pipeline preprocess function requires that an image is passed in, and will fail if the image parameter is null/empty. In order to support multimodal use cases with the same pipeline, this preprocess function handles the case where there is no image with the prompt. 
+ Also, the image-to-text pipeline typically treats multiple images passed in as a list as a batch (where it iterates + over the image inputs for generation). For that reason, the original pipeline_preprocess code would only get a + single image at a time. To support multiple images, the pipeline call is updated to send a list of lists for the + images (so that when iterated, we still get multiple images) and this pipeline_preprocess function has been updated + to handle a list of images in addition to single images. """ - if image: + if isinstance(image, list): + image = [load_image(i, timeout=timeout) for i in image] + elif image: image = load_image(image, timeout=timeout) if prompt is not None: @@ -46,6 +54,7 @@ def pipeline_preprocess(self, image, prompt=None, timeout=None): ) model_type = self.model.config.model_type + print("Model type: " + model_type) if model_type == "git": if image: @@ -114,23 +123,51 @@ async def health() -> Response: @app.post("/generate") -async def generate(request: Request) -> Response: # FIXME batch_size=1 for now, only accept single image +async def generate(request: Request) -> Response: # FIXME batch_size=1 for now print("LLaVA generation begin.") request_dict = await request.json() prompt = request_dict.pop("prompt") - img_b64_str = request_dict.pop("img_b64_str") + img_b64_str = request_dict.pop("img_b64_str") # String or list of strings max_new_tokens = request_dict.pop("max_new_tokens", 100) + # Determine the format of the role labels based on the model name + model_name = generator.model.name_or_path + user_label = "USER:" + assistant_label = "ASSISTANT:" + image_tag = "\n" + + # This is the role label that we see in the results from the pipeline. This is used to split the output. 
+ output_assistant_label = "ASSISTANT: " + + if "llava-interleave" in model_name: + user_label = "<|im_start|>user" + assistant_label = "<|im_end|><|im_start|>assistant" + output_assistant_label = "assistant " + elif "llava-v1.6-mistral" in model_name: + user_label = "[INST]" + assistant_label = " [/INST]" + output_assistant_label = "[/INST] " + if img_b64_str: - # Decode and Resize the image - image = PIL.Image.open(BytesIO(base64.b64decode(img_b64_str))) - image = process_image(image) - # format the prompt with an image - prompt = f"\nUSER: {prompt}\nASSISTANT:" + if isinstance(img_b64_str, str): + img_b64_str = [img_b64_str] + + # Decode and Resize the images + images = [] + for img_b64 in img_b64_str: + image = PIL.Image.open(BytesIO(base64.b64decode(img_b64))) + image = process_image(image) + images.append(image) + + # If the prompt provided does not have all the image tags, format the prompt with images + num_images = len(images) + num_image_tags = prompt.count(image_tag) + image_tags = image_tag * (num_images - num_image_tags) if num_images > num_image_tags else "" + prompt = f"{user_label}{image_tags} {prompt}{assistant_label} " else: - image = None + images = None # format the prompt with text only - prompt = f"USER: {prompt}\nASSISTANT:" + prompt = f"{user_label} {prompt}\n{assistant_label} " if args.device == "hpu": generate_kwargs = { @@ -149,12 +186,13 @@ async def generate(request: Request) -> Response: # FIXME batch_size=1 for now, # Override the pipeline preprocessing generator.preprocess = pipeline_preprocess.__get__(generator, type(generator)) - result = generator(image, prompt=prompt, batch_size=1, generate_kwargs=generate_kwargs) + result = generator([images], prompt=prompt, batch_size=1, generate_kwargs=generate_kwargs) end = time.time() - result = result[0]["generated_text"].split("ASSISTANT: ")[-1] + result = result[0][0]["generated_text"].split(output_assistant_label)[-1] print(f"LLaVA result = {result}, time = {(end-start) * 1000 }ms") - if 
image: - image.close() + if images: + for i in images: + i.close() ret = {"text": result} return JSONResponse(ret) @@ -191,6 +229,8 @@ async def generate(request: Request) -> Response: # FIXME batch_size=1 for now, device=args.device, ) + processor = AutoProcessor.from_pretrained(model_name_or_path) + # warmup print("LLaVA warmup...") if args.device == "hpu": @@ -214,10 +254,23 @@ async def generate(request: Request) -> Response: # FIXME batch_size=1 for now, images = [] for image_path in image_paths: images.append(PIL.Image.open(requests.get(image_path, stream=True, timeout=3000).raw)) + + # Generate a text prompt to use for warm up + conversation = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What's the content of the image?"}, + ], + }, + ] + text_prompt = processor.apply_chat_template(conversation) + for i in range(args.warmup): generator( images, - prompt="\nUSER: What's the content of the image?\nASSISTANT:", + prompt=text_prompt, batch_size=1, generate_kwargs=generate_kwargs, ) diff --git a/comps/retrievers/multimodal/redis/langchain/retriever_redis.py b/comps/retrievers/multimodal/redis/langchain/retriever_redis.py index a01b3e20c4..363c54a516 100644 --- a/comps/retrievers/multimodal/redis/langchain/retriever_redis.py +++ b/comps/retrievers/multimodal/redis/langchain/retriever_redis.py @@ -69,6 +69,12 @@ async def retrieve( if isinstance(input, EmbedMultimodalDoc): metadata_list = [] for r in search_res: + # If the input had an image, pass that through in the metadata along with the search result image + if input.base64_image: + if r.metadata["b64_img_str"]: + r.metadata["b64_img_str"] = [r.metadata["b64_img_str"], input.base64_image] + else: + r.metadata["b64_img_str"] = input.base64_image metadata_list.append(r.metadata) retrieved_docs.append(TextDoc(text=r.page_content)) result = SearchedMultimodalDoc(retrieved_docs=retrieved_docs, initial_query=input.text, metadata=metadata_list) diff --git 
a/tests/embeddings/test_embeddings_multimodal.sh b/tests/embeddings/test_embeddings_multimodal.sh index bd2ca93b70..5bb2fd9f93 100644 --- a/tests/embeddings/test_embeddings_multimodal.sh +++ b/tests/embeddings/test_embeddings_multimodal.sh @@ -85,6 +85,22 @@ function validate_microservice_image_text_pair_embedding() { fi } +function validate_microservice_b64_image_text_pair_embedding() { + result=$(http_proxy="" curl http://${ip_address}:$MM_EMBEDDING_PORT_MICROSERVICE/v1/embeddings \ + -X POST \ + -H "Content-Type: application/json" \ + -d '{"text": {"text" : "This is some sample text."}, "image" : {"base64_image": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC"}}') + + if [[ $result == *"embedding"* ]]; then + echo "Result correct." + else + echo "Result wrong. Received was $result" + docker logs embedding-multimodal-bridgetower + docker logs embedding-multimodal + exit 1 + fi +} + function validate_microservice() { validate_microservice_text_embedding validate_microservice_image_text_pair_embedding diff --git a/tests/lvms/test_lvms_llava.sh b/tests/lvms/test_lvms_llava.sh index 4627ec6ee7..bd2fc3950f 100644 --- a/tests/lvms/test_lvms_llava.sh +++ b/tests/lvms/test_lvms_llava.sh @@ -48,6 +48,29 @@ function validate_microservice() { exit 1 fi + # Test sending two images with a text prompt + result=$(http_proxy="" curl http://localhost:$lvm_port/v1/lvm -XPOST -d '{"image": ["iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC", "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC"], "prompt":"What are in these images?"}' -H 'Content-Type: application/json') + if [[ $result == *"blue"* ]] && [[ $result == *"green"* ]]; then + echo "Result correct." + else + echo "Result wrong." 
+ docker logs test-comps-lvm-llava >> ${LOG_PATH}/llava-dependency.log + docker logs test-comps-lvm-llava-svc >> ${LOG_PATH}/llava-server.log + exit 1 + fi + + # Test sending two images with a text prompt where the prompt has only one image tag + # (the LVM microservice should add the second one) + result=$(http_proxy="" curl http://localhost:$lvm_port/v1/lvm -XPOST -d '{"image": ["iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC", "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC"], "prompt":"\nWhat are in these images?"}' -H 'Content-Type: application/json') + if [[ $result == *"blue"* ]] && [[ $result == *"green"* ]]; then + echo "Result correct." + else + echo "Result wrong." + docker logs test-comps-lvm-llava >> ${LOG_PATH}/llava-dependency.log + docker logs test-comps-lvm-llava-svc >> ${LOG_PATH}/llava-server.log + exit 1 + fi + result=$(http_proxy="" curl http://localhost:$lvm_port/v1/lvm -XPOST -d '{"retrieved_docs": [], "initial_query": "What is this?", "top_n": 1, "metadata": [{"b64_img_str": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC", "transcript_for_inference": "yellow image", "video_id": "8c7461df-b373-4a00-8696-9a2234359fe0", "time_of_frame_ms":"37000000", "source_video":"WeAreGoingOnBullrun_8c7461df-b373-4a00-8696-9a2234359fe0.mp4"}]}' -H 'Content-Type: application/json') if [[ $result == *"yellow"* ]]; then echo "Result correct." 
diff --git a/tests/retrievers/test_retrievers_multimodal_redis_langchain.sh b/tests/retrievers/test_retrievers_multimodal_redis_langchain.sh index 873516ddc5..bd256e6e05 100644 --- a/tests/retrievers/test_retrievers_multimodal_redis_langchain.sh +++ b/tests/retrievers/test_retrievers_multimodal_redis_langchain.sh @@ -58,6 +58,32 @@ function validate_microservice() { docker logs test-comps-retriever-multimodal-redis >> ${LOG_PATH}/retriever.log exit 1 fi + + # Test the retriever with a b64 image that should be passed through + HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X POST -d "{\"text\":\"test\",\"embedding\":${test_embedding},\"img_b64_str\":\"iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC\"}" -H 'Content-Type: application/json' "$URL") + if [ "$HTTP_STATUS" -eq 200 ]; then + echo "[ retriever ] HTTP status is 200. Checking content..." + local CONTENT=$(curl -s -X POST -d "{\"text\":\"test\",\"embedding\":${test_embedding}}" -H 'Content-Type: application/json' "$URL" | tee ${LOG_PATH}/retriever.log) + + if echo "$CONTENT" | grep -q "retrieved_docs"; then + echo "[ retriever ] Content has retrieved_docs as expected." + if echo "$CONTENT" | grep -q "retrieved_docs"; then + echo "[ retriever ] Content has img_b64_str as expected." + else + echo "[ retriever ] Content does not include the img_b64_str: $CONTENT" + docker logs test-comps-retriever-multimodal-redis >> ${LOG_PATH}/retriever.log + exit 1 + fi + else + echo "[ retriever ] Content does not match the expected result: $CONTENT" + docker logs test-comps-retriever-multimodal-redis >> ${LOG_PATH}/retriever.log + exit 1 + fi + else + echo "[ retriever ] HTTP status is not 200. 
Received status was $HTTP_STATUS" + docker logs test-comps-retriever-multimodal-redis >> ${LOG_PATH}/retriever.log + exit 1 + fi } function stop_docker() { From f83e2e1f4d1a9fe2b74b605d655ff5031b4f6538 Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Fri, 22 Nov 2024 15:47:10 -0800 Subject: [PATCH 02/27] Fix model name var Signed-off-by: dmsuehir --- comps/cores/mega/gateway.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index b3f3c6e16c..48b242c976 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -866,10 +866,10 @@ def _get_role_labels(self): if "llava-interleave" in lvm_model: role_labels["user"] = "<|im_start|>user" role_labels["assistant"] = "<|im_end|><|im_start|>assistant" - elif "llava-v1.6-mistral" in model_name: + elif "llava-v1.6-mistral" in lvm_model: role_labels["user"] = "[INST]" role_labels["assistant"] = " [/INST]" - elif "llava-1.5" not in model_name and "llava-v1.6-vicuna" not in model_name: + elif "llava-1.5" not in lvm_model and "llava-v1.6-vicuna" not in lvm_model: print(f"[ MultimodalQnAGateway ] Using default role labels for prompt formatting: {role_labels}") return role_labels From 1f0dfcd43df71b772224e415172a69b1a1087136 Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Mon, 25 Nov 2024 17:16:50 -0800 Subject: [PATCH 03/27] Remove space at end of prompt Signed-off-by: dmsuehir --- comps/lvms/llava/dependency/llava_server.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/comps/lvms/llava/dependency/llava_server.py b/comps/lvms/llava/dependency/llava_server.py index b238185b5f..c21561701f 100644 --- a/comps/lvms/llava/dependency/llava_server.py +++ b/comps/lvms/llava/dependency/llava_server.py @@ -155,19 +155,20 @@ async def generate(request: Request) -> Response: # FIXME batch_size=1 for now # Decode and Resize the images images = [] for img_b64 in img_b64_str: - image = 
PIL.Image.open(BytesIO(base64.b64decode(img_b64))) - image = process_image(image) - images.append(image) + if img_b64: + image = PIL.Image.open(BytesIO(base64.b64decode(img_b64))) + image = process_image(image) + images.append(image) # If the prompt provided does not have all the image tags, format the prompt with images num_images = len(images) num_image_tags = prompt.count(image_tag) image_tags = image_tag * (num_images - num_image_tags) if num_images > num_image_tags else "" - prompt = f"{user_label}{image_tags} {prompt}{assistant_label} " + prompt = f"{user_label}{image_tags} {prompt}{assistant_label}" else: images = None # format the prompt with text only - prompt = f"{user_label} {prompt}\n{assistant_label} " + prompt = f"{user_label} {prompt}\n{assistant_label}" if args.device == "hpu": generate_kwargs = { From 5b51771773813d89894699896ed8d7aa3ed644f1 Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Mon, 2 Dec 2024 14:07:59 -0800 Subject: [PATCH 04/27] Add env var for the max number of images sent to the LVM Signed-off-by: dmsuehir --- comps/cores/mega/gateway.py | 9 ++++---- comps/lvms/llava/dependency/llava_server.py | 4 +++- comps/lvms/llava/lvm.py | 15 +++++++++++++ tests/lvms/test_lvms_llava.sh | 25 ++++++++++++++++----- 4 files changed, 42 insertions(+), 11 deletions(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 48b242c976..28dca91702 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -910,13 +910,14 @@ def _handle_message(self, messages): messages_dicts.append(messages_dict) else: raise ValueError(f"Unknown role: {msg_role}") - if system_prompt: prompt = system_prompt + "\n" - for messages_dict in messages_dicts: - for i, (role, message) in enumerate(messages_dict.items()): + for i, messages_dict in enumerate(messages_dicts): + for role, message in messages_dict.items(): if isinstance(message, tuple): text, image_list = message + # Remove empty items from the image list + image_list = [x for x 
in image_list if x] # Add image indicators within the conversation image_tags = "\n" * len(image_list) if i == 0: @@ -973,7 +974,7 @@ async def handle_request(self, request: Request): # Multimodal RAG QnA With Videos has not yet accepts image as input during QnA. prompt, images = self._handle_message(chat_request.messages) if num_messages > 1: - # print(f"This request include image, thus it is a follow-up query. Using lvm megaservice") + # print(f"There is more than one message, thus it is a follow-up query. Using lvm megaservice") cur_megaservice = self.lvm_megaservice initial_inputs = {"prompt": prompt, "image": images} else: diff --git a/comps/lvms/llava/dependency/llava_server.py b/comps/lvms/llava/dependency/llava_server.py index c21561701f..3b8b3be7da 100644 --- a/comps/lvms/llava/dependency/llava_server.py +++ b/comps/lvms/llava/dependency/llava_server.py @@ -170,6 +170,8 @@ async def generate(request: Request) -> Response: # FIXME batch_size=1 for now # format the prompt with text only prompt = f"{user_label} {prompt}\n{assistant_label}" + print(repr(prompt)) + if args.device == "hpu": generate_kwargs = { "lazy_mode": True, @@ -189,7 +191,7 @@ async def generate(request: Request) -> Response: # FIXME batch_size=1 for now result = generator([images], prompt=prompt, batch_size=1, generate_kwargs=generate_kwargs) end = time.time() - result = result[0][0]["generated_text"].split(output_assistant_label)[-1] + result = result[0][0]["generated_text"].split(output_assistant_label.strip())[-1].strip() print(f"LLaVA result = {result}, time = {(end-start) * 1000 }ms") if images: for i in images: diff --git a/comps/lvms/llava/lvm.py b/comps/lvms/llava/lvm.py index 897f7cbbe4..425576debd 100644 --- a/comps/lvms/llava/lvm.py +++ b/comps/lvms/llava/lvm.py @@ -27,6 +27,7 @@ logger = CustomLogger("lvm") logflag = os.getenv("LOGFLAG", False) +max_images = int(os.getenv("MAX_IMAGES", 1)) @register_microservice( @@ -76,6 +77,17 @@ async def lvm(request: Union[LVMDoc, 
LVMSearchedMultimodalDoc]) -> Union[TextDoc prompt = request.prompt max_new_tokens = request.max_new_tokens + # Limit the number of images being sent to the LVM + if isinstance(img_b64_str, list) and len(img_b64_str) > max_images: + img_b64_str=img_b64_str[-max_images:] + + # Adjust the number of images tags in the prompt + image_tag = "\n" + num_tags_in_prompt = prompt.count(image_tag) + + if len(img_b64_str) < num_tags_in_prompt: + prompt = prompt.replace(image_tag, "", num_tags_in_prompt - len(img_b64_str)) + inputs = {"img_b64_str": img_b64_str, "prompt": prompt, "max_new_tokens": max_new_tokens} # forward to the LLaVA server response = requests.post(url=f"{lvm_endpoint}/generate", data=json.dumps(inputs), proxies={"http": None}) @@ -99,5 +111,8 @@ async def lvm(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> Union[TextDoc if __name__ == "__main__": lvm_endpoint = os.getenv("LVM_ENDPOINT", "http://localhost:8399") + if logflag: + logger.info(f"MAX_IMAGES: {max_images}") + logger.info("[LVM] LVM initialized.") opea_microservices["opea_service@lvm"].start() diff --git a/tests/lvms/test_lvms_llava.sh b/tests/lvms/test_lvms_llava.sh index bd2fc3950f..8558fa5e3d 100644 --- a/tests/lvms/test_lvms_llava.sh +++ b/tests/lvms/test_lvms_llava.sh @@ -48,9 +48,22 @@ function validate_microservice() { exit 1 fi - # Test sending two images with a text prompt + # Test sending two images with a text prompt with one image tag in the prompt. + # The first image is green and the second image is blue. Since the default MAX_IMAGES is 1, only the blue image should be sent to the LVM. 
+ result=$(http_proxy="" curl http://localhost:$lvm_port/v1/lvm -XPOST -d '{"image": ["iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC", "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC"], "prompt":"\nWhat are in these images?"}' -H 'Content-Type: application/json') + if [[ $result == *"blue"* ]]; then + echo "Result correct." + else + echo "Result wrong." + docker logs test-comps-lvm-llava >> ${LOG_PATH}/llava-dependency.log + docker logs test-comps-lvm-llava-svc >> ${LOG_PATH}/llava-server.log + exit 1 + fi + + # Test sending two images with a text prompt without any image tags. + # The first image is blue and the second image is green. Since the default MAX_IMAGES is 1, only the green image should be sent to the LVM. result=$(http_proxy="" curl http://localhost:$lvm_port/v1/lvm -XPOST -d '{"image": ["iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC", "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC"], "prompt":"What are in these images?"}' -H 'Content-Type: application/json') - if [[ $result == *"blue"* ]] && [[ $result == *"green"* ]]; then + if [[ $result == *"green"* ]]; then echo "Result correct." else echo "Result wrong." 
@@ -59,10 +72,10 @@ function validate_microservice() {
         exit 1
     fi

-    # Test sending two images with a text prompt where the prompt has only one image tag
-    # (the LVM microservice should add the second one)
-    result=$(http_proxy="" curl http://localhost:$lvm_port/v1/lvm -XPOST -d '{"image": ["iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC", "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC"], "prompt":"<image>\nWhat are in these images?"}' -H 'Content-Type: application/json')
-    if [[ $result == *"blue"* ]] && [[ $result == *"green"* ]]; then
+    # Same test as above, except including two image tags with the prompt to ensure the number of image tags is reconciled.
+    # The first image is blue and the second image is green. Since the default MAX_IMAGES is 1, only the green image should be sent to the LVM.
+    result=$(http_proxy="" curl http://localhost:$lvm_port/v1/lvm -XPOST -d '{"image": ["iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC", "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC"], "prompt":"<image>\n<image>\nWhat are in these images?"}' -H 'Content-Type: application/json')
+    if [[ $result == *"green"* ]]; then
         echo "Result correct."
     else
         echo "Result wrong."
From 242ee6f70be2d5504e16d8a02335dcfa0b1939d9 Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Mon, 2 Dec 2024 14:21:04 -0800 Subject: [PATCH 05/27] README update for the MAX_IMAGES env var Signed-off-by: dmsuehir --- comps/lvms/llava/README.md | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/comps/lvms/llava/README.md b/comps/lvms/llava/README.md index ae19a221b8..74e1de706f 100644 --- a/comps/lvms/llava/README.md +++ b/comps/lvms/llava/README.md @@ -1,6 +1,6 @@ # LVM Microservice -Visual Question and Answering is one of the multimodal tasks empowered by LVMs (Large Visual Models). This microservice supports visual Q&A by using LLaVA as the base large visual model. It accepts two inputs: a prompt and an image. It outputs the answer to the prompt about the image. +Visual Question and Answering is one of the multimodal tasks empowered by LVMs (Large Visual Models). This microservice supports visual Q&A by using LLaVA as the base large visual model. It accepts two inputs: a prompt and images. It outputs the answer to the prompt about the images. ## 🚀1. Start Microservice with Python (Option 1) @@ -92,10 +92,15 @@ docker run -p 8399:8399 --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_M #### 2.2.2 Start LVM service +> Note: The `MAX_IMAGES` environment variable is used to specify the maximum number of images that will be sent from the LVM service to the LLaVA server. +> If an image list longer than `MAX_IMAGES` is sent to the LVM server, a shortened image list will be sent to the LLaVA service. If the image list +> needs to be shortened, the most recent images (the ones at the end of the list) are prioritized to send to the LLaVA service. Some LLaVA models have not +> been trained with multiple images and may lead to inaccurate results. If `MAX_IMAGES` is not set, it will default to `1`. 
+ ```bash ip_address=$(hostname -I | awk '{print $1}') -docker run -p 9399:9399 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e LVM_ENDPOINT=http://$ip_address:8399 opea/lvm-llava-svc:latest +docker run -p 9399:9399 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e LVM_ENDPOINT=http://$ip_address:8399 -e MAX_IMAGES=1 opea/lvm-llava-svc:latest ``` #### 2.2.3 Test @@ -106,8 +111,7 @@ docker run -p 9399:9399 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$htt # curl with an image and a prompt http_proxy="" curl http://localhost:9399/v1/lvm -XPOST -d '{"image": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC", "prompt":"What is this?"}' -H 'Content-Type: application/json' -# curl with multiple images and a prompt - +# curl with multiple images and a prompt (Note that depending on your MAX_IMAGES value, both images may not be sent to the LLaVA model) http_proxy="" curl http://localhost:9399/v1/lvm -XPOST -d '{"image": ["iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC", "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC"], "prompt":"What is in these images?"}' -H 'Content-Type: application/json' # curl with a prompt only (no image) From 5b41724f834f30a4cc02ad6dc17f59b039c33e5f Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Mon, 2 Dec 2024 15:08:30 -0800 Subject: [PATCH 06/27] Remove prints Signed-off-by: dmsuehir --- comps/lvms/llava/dependency/llava_server.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/comps/lvms/llava/dependency/llava_server.py b/comps/lvms/llava/dependency/llava_server.py index 3b8b3be7da..4fc0043805 100644 --- a/comps/lvms/llava/dependency/llava_server.py +++ b/comps/lvms/llava/dependency/llava_server.py @@ -54,7 +54,6 @@ def pipeline_preprocess(self, image, prompt=None, timeout=None): ) model_type = self.model.config.model_type - 
print("Model type: " + model_type) if model_type == "git": if image: @@ -170,8 +169,6 @@ async def generate(request: Request) -> Response: # FIXME batch_size=1 for now # format the prompt with text only prompt = f"{user_label} {prompt}\n{assistant_label}" - print(repr(prompt)) - if args.device == "hpu": generate_kwargs = { "lazy_mode": True, From ae5437ad30bbb768b09c1c0105ef7ccc71efe3e7 Mon Sep 17 00:00:00 2001 From: Omar Khleif Date: Mon, 2 Dec 2024 15:39:59 -0800 Subject: [PATCH 07/27] Audio query functionality to multimodal backend (#8) Signed-off-by: okhleif-IL * added in audio dict creation Signed-off-by: okhleif-IL * separated audio from prompt Signed-off-by: okhleif-IL * added ASR endpoint Signed-off-by: okhleif-IL * removed ASR endpoints from mm embedding Signed-off-by: okhleif-IL * edited return logic, fixed function call Signed-off-by: okhleif-IL * added megaservice to elif Signed-off-by: okhleif-IL * reworked helper func Signed-off-by: okhleif-IL * Append audio to prompt Signed-off-by: okhleif-IL * Reworked handle messages, added metadata Signed-off-by: okhleif-IL * Moved dictionary logic to right place Signed-off-by: okhleif-IL * changed logic to rely on message len Signed-off-by: okhleif-IL * list --> empty str Signed-off-by: okhleif-IL --------- Signed-off-by: Melanie Buehler Signed-off-by: okhleif-IL Signed-off-by: dmsuehir --- comps/cores/mega/gateway.py | 147 ++++++++++++++++++++++++++---------- 1 file changed, 109 insertions(+), 38 deletions(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 29642eea55..30245be633 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -2,13 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 import base64 +import json import os from io import BytesIO from typing import List, Union import requests from fastapi import File, Request, UploadFile -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse from PIL import 
Image from ..proto.api_protocol import ( @@ -837,6 +838,8 @@ def parser_input(data, TypeClass, key): class MultimodalQnAGateway(Gateway): + asr_port = int(os.getenv("ASR_SERVICE_PORT", 3001)) + asr_endpoint = os.getenv("ASR_SERVICE_ENDPOINT", "http://0.0.0.0:{}/v1/audio/transcriptions".format(asr_port)) def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", port=9999): self.lvm_megaservice = lvm_megaservice super().__init__( @@ -847,11 +850,13 @@ def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", ChatCompletionRequest, ChatCompletionResponse, ) - # this overrides _handle_message method of Gateway def _handle_message(self, messages): images = [] + audios = [] + b64_types = {} messages_dicts = [] + decoded_audio_input = "" if isinstance(messages, str): prompt = messages else: @@ -865,16 +870,28 @@ def _handle_message(self, messages): system_prompt = message["content"] elif msg_role == "user": if type(message["content"]) == list: + # separate each media type and store accordingly text = "" text_list = [item["text"] for item in message["content"] if item["type"] == "text"] text += "\n".join(text_list) image_list = [ item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url" ] - if image_list: - messages_dict[msg_role] = (text, image_list) - else: + audios = [ + item["audio"] for item in message["content"] if item["type"] == "audio" + ] + if audios: + # translate audio to text. 
From this point forward, audio is treated like text + decoded_audio_input = self.convert_audio_to_text(audios) + b64_types["audio"] = decoded_audio_input + + if text and not audios and not image_list: messages_dict[msg_role] = text + elif audios and not text and not image_list: + messages_dict[msg_role] = decoded_audio_input + else: + messages_dict[msg_role] = (text, decoded_audio_input, image_list) + else: messages_dict[msg_role] = message["content"] messages_dicts.append(messages_dict) @@ -889,51 +906,80 @@ def _handle_message(self, messages): for messages_dict in messages_dicts: for i, (role, message) in enumerate(messages_dict.items()): if isinstance(message, tuple): - text, image_list = message + text, decoded_audio_input, image_list = message if i == 0: # do not add role for the very first message. # this will be added by llava_server if text: prompt += text + "\n" + elif decoded_audio_input: + prompt += decoded_audio_input + "\n" else: if text: prompt += role.upper() + ": " + text + "\n" + elif decoded_audio_input: + prompt += role.upper() + ": " + decoded_audio_input + "\n" else: prompt += role.upper() + ":" - for img in image_list: - # URL - if img.startswith("http://") or img.startswith("https://"): - response = requests.get(img) - image = Image.open(BytesIO(response.content)).convert("RGBA") - image_bytes = BytesIO() - image.save(image_bytes, format="PNG") - img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() - # Local Path - elif os.path.exists(img): - image = Image.open(img).convert("RGBA") - image_bytes = BytesIO() - image.save(image_bytes, format="PNG") - img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() - # Bytes - else: - img_b64_str = img - images.append(img_b64_str) - else: + if image_list: + for img in image_list: + # URL + if img.startswith("http://") or img.startswith("https://"): + response = requests.get(img) + image = Image.open(BytesIO(response.content)).convert("RGBA") + image_bytes = BytesIO() + 
image.save(image_bytes, format="PNG") + img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() + # Local Path + elif os.path.exists(img): + image = Image.open(img).convert("RGBA") + image_bytes = BytesIO() + image.save(image_bytes, format="PNG") + img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() + # Bytes + else: + img_b64_str = img + + images.append(img_b64_str) + + elif isinstance(message, str): if i == 0: # do not add role for the very first message. # this will be added by llava_server if message: - prompt += role.upper() + ": " + message + "\n" + prompt += message + "\n" else: if message: prompt += role.upper() + ": " + message + "\n" else: - prompt += role.upper() + ":" + prompt += role.upper() + ":" + + if images: - return prompt, images + b64_types["image"] = images + + # If the query has multiple media types, return all types + if prompt and b64_types: + return prompt, b64_types else: return prompt + + def convert_audio_to_text(self, audio): + # translate audio to text by passing in dictionary to ASR + if isinstance(audio, dict): + input_dict = {"byte_str": audio["audio"][0]} + else: + input_dict = {"byte_str": audio[0]} + + response = requests.post(self.asr_endpoint, data=json.dumps(input_dict), proxies={"http": None}) + + if response.status_code != 200: + return JSONResponse(status_code=503, content={"message": "Unable to convert audio to text. {}".format( + response.text)}) + + response = response.json() + return response["query"] async def handle_request(self, request: Request): data = await request.json() @@ -943,16 +989,35 @@ async def handle_request(self, request: Request): stream_opt = False chat_request = ChatCompletionRequest.model_validate(data) # Multimodal RAG QnA With Videos has not yet accepts image as input during QnA. - prompt_and_image = self._handle_message(chat_request.messages) - if isinstance(prompt_and_image, tuple): - # print(f"This request include image, thus it is a follow-up query. 
Using lvm megaservice") - prompt, images = prompt_and_image + num_messages = len(data["messages"]) if isinstance(data["messages"], list) else 1 + messages = self._handle_message(chat_request.messages) + decoded_audio_input = "" + + if num_messages > 1: + # This is a follow up query, go to LVM cur_megaservice = self.lvm_megaservice - initial_inputs = {"prompt": prompt, "image": images[0]} + if isinstance(messages, tuple): + prompt, b64_types = messages + if "audio" in b64_types: + # for metadata storage purposes + decoded_audio_input = b64_types["audio"] + if "image" in b64_types: + initial_inputs = {"prompt": prompt, "image": b64_types["image"][0]} + else: + initial_inputs = {"prompt": prompt, "image": ""} + else: + prompt = messages + initial_inputs = {"prompt": prompt, "image": ""} else: - # print(f"This is the first query, requiring multimodal retrieval. Using multimodal rag megaservice") - prompt = prompt_and_image + # This is the first query. Ignore image input cur_megaservice = self.megaservice + if isinstance(messages, tuple): + prompt, b64_types = messages + if "audio" in b64_types: + # for metadata storage purposes + decoded_audio_input = b64_types["audio"] + else: + prompt = messages initial_inputs = {"text": prompt} parameters = LLMParams( @@ -985,18 +1050,24 @@ async def handle_request(self, request: Request): if "text" in result_dict[last_node].keys(): response = result_dict[last_node]["text"] else: - # text in not response message + # text is not in response message # something wrong, for example due to empty retrieval results if "detail" in result_dict[last_node].keys(): response = result_dict[last_node]["detail"] else: - response = "The server fail to generate answer to your query!" + response = "The server failed to generate an answer to your query!" 
if "metadata" in result_dict[last_node].keys(): # from retrieval results metadata = result_dict[last_node]["metadata"] + if decoded_audio_input: + metadata["audio"] = decoded_audio_input else: # follow-up question, no retrieval - metadata = None + if decoded_audio_input: + metadata = {"audio": decoded_audio_input} + else: + metadata = None + choices = [] usage = UsageInfo() choices.append( From 70c54e19d8b913a9d25bda514ac2091051c6c996 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Dec 2024 19:02:40 +0000 Subject: [PATCH 08/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- comps/cores/mega/gateway.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 30245be633..f4e7518852 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -9,7 +9,7 @@ import requests from fastapi import File, Request, UploadFile -from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.responses import JSONResponse, StreamingResponse from PIL import Image from ..proto.api_protocol import ( @@ -840,6 +840,7 @@ def parser_input(data, TypeClass, key): class MultimodalQnAGateway(Gateway): asr_port = int(os.getenv("ASR_SERVICE_PORT", 3001)) asr_endpoint = os.getenv("ASR_SERVICE_ENDPOINT", "http://0.0.0.0:{}/v1/audio/transcriptions".format(asr_port)) + def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", port=9999): self.lvm_megaservice = lvm_megaservice super().__init__( @@ -850,6 +851,7 @@ def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", ChatCompletionRequest, ChatCompletionResponse, ) + # this overrides _handle_message method of Gateway def _handle_message(self, messages): images = [] @@ -877,14 +879,12 @@ def _handle_message(self, messages): image_list = [ 
item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url" ] - audios = [ - item["audio"] for item in message["content"] if item["type"] == "audio" - ] + audios = [item["audio"] for item in message["content"] if item["type"] == "audio"] if audios: # translate audio to text. From this point forward, audio is treated like text decoded_audio_input = self.convert_audio_to_text(audios) - b64_types["audio"] = decoded_audio_input - + b64_types["audio"] = decoded_audio_input + if text and not audios and not image_list: messages_dict[msg_role] = text elif audios and not text and not image_list: @@ -953,18 +953,17 @@ def _handle_message(self, messages): if message: prompt += role.upper() + ": " + message + "\n" else: - prompt += role.upper() + ":" - + prompt += role.upper() + ":" if images: - b64_types["image"] = images + b64_types["image"] = images - # If the query has multiple media types, return all types + # If the query has multiple media types, return all types if prompt and b64_types: return prompt, b64_types else: return prompt - + def convert_audio_to_text(self, audio): # translate audio to text by passing in dictionary to ASR if isinstance(audio, dict): @@ -973,10 +972,11 @@ def convert_audio_to_text(self, audio): input_dict = {"byte_str": audio[0]} response = requests.post(self.asr_endpoint, data=json.dumps(input_dict), proxies={"http": None}) - + if response.status_code != 200: - return JSONResponse(status_code=503, content={"message": "Unable to convert audio to text. {}".format( - response.text)}) + return JSONResponse( + status_code=503, content={"message": "Unable to convert audio to text. 
{}".format(response.text)} + ) response = response.json() return response["query"] @@ -992,7 +992,7 @@ async def handle_request(self, request: Request): num_messages = len(data["messages"]) if isinstance(data["messages"], list) else 1 messages = self._handle_message(chat_request.messages) decoded_audio_input = "" - + if num_messages > 1: # This is a follow up query, go to LVM cur_megaservice = self.lvm_megaservice @@ -1060,7 +1060,7 @@ async def handle_request(self, request: Request): # from retrieval results metadata = result_dict[last_node]["metadata"] if decoded_audio_input: - metadata["audio"] = decoded_audio_input + metadata["audio"] = decoded_audio_input else: # follow-up question, no retrieval if decoded_audio_input: From 6a718430ce0f5d46190f93fc6a096d25be0712af Mon Sep 17 00:00:00 2001 From: okhleif-IL Date: Wed, 4 Dec 2024 14:37:59 -0800 Subject: [PATCH 09/27] fixed role bug where i never was > 0 Signed-off-by: okhleif-IL --- comps/cores/mega/gateway.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index f4e7518852..04e145d209 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -903,8 +903,8 @@ def _handle_message(self, messages): if system_prompt: prompt = system_prompt + "\n" - for messages_dict in messages_dicts: - for i, (role, message) in enumerate(messages_dict.items()): + for i, messages_dict in enumerate(messages_dicts): + for (role, message) in messages_dict.items(): if isinstance(message, tuple): text, decoded_audio_input, image_list = message if i == 0: From 411bfdfa77f2d77ed24c06c832d1ada888797792 Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Wed, 4 Dec 2024 14:58:56 -0800 Subject: [PATCH 10/27] Fix after merge Signed-off-by: dmsuehir --- comps/cores/mega/gateway.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 5d2f9f7901..ee0f783e1a 100644 --- 
a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -850,6 +850,7 @@ def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", ChatCompletionRequest, ChatCompletionResponse, ) + self._role_labels = _get_role_labels() def _get_role_labels(self): """ @@ -890,7 +891,7 @@ def _handle_message(self, messages): messages_dict = {} system_prompt = "" prompt = "" - role_label_dict = self._get_role_labels() + role_label_dict = self._role_labels for message in messages: msg_role = message["role"] messages_dict = {} @@ -1038,7 +1039,7 @@ async def handle_request(self, request: Request): num_messages = len(data["messages"]) if isinstance(data["messages"], list) else 1 # Multimodal RAG QnA With Videos has not yet accepts image as input during QnA. - prompt, images = self._handle_message(chat_request.messages) + messages = self._handle_message(chat_request.messages) decoded_audio_input = "" if num_messages > 1: From 615459b00bdfef2bd1d7f81df523156eabb826de Mon Sep 17 00:00:00 2001 From: okhleif-IL Date: Wed, 4 Dec 2024 15:14:28 -0800 Subject: [PATCH 11/27] removed whitespace Signed-off-by: okhleif-IL --- comps/cores/mega/gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 04e145d209..6865a49ae1 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -903,7 +903,7 @@ def _handle_message(self, messages): if system_prompt: prompt = system_prompt + "\n" - for i, messages_dict in enumerate(messages_dicts): + for i, messages_dict in enumerate(messages_dicts): for (role, message) in messages_dict.items(): if isinstance(message, tuple): text, decoded_audio_input, image_list = message From dcafe8decc34c3857d9d872e195cd9e221db6754 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Dec 2024 23:20:57 +0000 Subject: [PATCH 12/27] [pre-commit.ci] auto fixes from pre-commit.com 
hooks for more information, see https://pre-commit.ci --- comps/cores/mega/gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 6865a49ae1..cadc6d102d 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -904,7 +904,7 @@ def _handle_message(self, messages): if system_prompt: prompt = system_prompt + "\n" for i, messages_dict in enumerate(messages_dicts): - for (role, message) in messages_dict.items(): + for role, message in messages_dict.items(): if isinstance(message, tuple): text, decoded_audio_input, image_list = message if i == 0: From e32bef4858223e6e188e25d3d76777e4c2ad26fd Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Wed, 4 Dec 2024 15:49:37 -0800 Subject: [PATCH 13/27] Fix call to get role labels Signed-off-by: dmsuehir --- comps/cores/mega/gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index ee0f783e1a..3c52305247 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -842,6 +842,7 @@ class MultimodalQnAGateway(Gateway): asr_endpoint = os.getenv("ASR_SERVICE_ENDPOINT", "http://0.0.0.0:{}/v1/audio/transcriptions".format(asr_port)) def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", port=9999): self.lvm_megaservice = lvm_megaservice + self._role_labels = self._get_role_labels() super().__init__( multimodal_rag_megaservice, host, @@ -850,7 +851,6 @@ def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", ChatCompletionRequest, ChatCompletionResponse, ) - self._role_labels = _get_role_labels() def _get_role_labels(self): """ From db22c47d09a6faa9cc62e28ab7ff7db0caf5521e Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Thu, 5 Dec 2024 14:34:30 -0800 Subject: [PATCH 14/27] Gateway test updates images within the conversation Signed-off-by: dmsuehir --- 
.../cores/mega/test_multimodalqna_gateway.py | 95 +++++++++++++++++-- 1 file changed, 89 insertions(+), 6 deletions(-) diff --git a/tests/cores/mega/test_multimodalqna_gateway.py b/tests/cores/mega/test_multimodalqna_gateway.py index c05bf57bdd..035a508e82 100644 --- a/tests/cores/mega/test_multimodalqna_gateway.py +++ b/tests/cores/mega/test_multimodalqna_gateway.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import os import unittest from typing import Union @@ -65,7 +66,23 @@ async def lvm_add(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc: else: print("request is from user.") text = req_dict["prompt"] - text = f"\nUSER: {text}\nASSISTANT:" + image_tag = "" + + # There may already be image tags interleaved within the prompt. The LVM service checks that and + # adds image tag(s) if they are needed. + if "image" in req_dict.keys(): + num_tags_in_prompt = text.count("\n") + if isinstance(req_dict["image"], list): + image_list = req_dict["image"] + else: + image_list = [req_dict["image"]] + num_images = len(image_list) + + # Add more image tags, if needed + if num_images > num_tags_in_prompt: + image_tag = "\n" * (num_images - num_tags_in_prompt) + + text = f"USER: {image_tag}{text}\nASSISTANT:" res = {} res["text"] = text @@ -111,7 +128,7 @@ async def test_follow_up_query_service_builder_schedule(self): initial_inputs={"prompt": "chao, ", "image": "some image"} ) # print(result_dict) - self.assertEqual(result_dict[self.lvm.name]["text"], "\nUSER: chao, \nASSISTANT:") + self.assertEqual(result_dict[self.lvm.name]["text"], "USER: \nchao, \nASSISTANT:") def test_MultimodalQnAGateway_gateway(self): json_data = {"messages": "hello, "} @@ -141,7 +158,7 @@ def test_follow_up_MultimodalQnAGateway_gateway(self): response = response.json() self.assertEqual( response["choices"][-1]["message"]["content"], - "\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:", + "USER: \nhello, \nASSISTANT: opea project! 
\nUSER: chao, \n\nASSISTANT:", ) def test_handle_message(self): @@ -160,7 +177,7 @@ def test_handle_message(self): {"role": "user", "content": "chao, "}, ] prompt, images = self.gateway._handle_message(messages) - self.assertEqual(prompt, "hello, \nASSISTANT: opea project! \nUSER: chao, \n") + self.assertEqual(prompt, "\nhello, \nASSISTANT: opea project! \nUSER: chao, \n") def test_handle_message_with_system_prompt(self): messages = [ @@ -179,7 +196,7 @@ def test_handle_message_with_system_prompt(self): {"role": "user", "content": "chao, "}, ] prompt, images = self.gateway._handle_message(messages) - self.assertEqual(prompt, "System Prompt\nhello, \nASSISTANT: opea project! \nUSER: chao, \n") + self.assertEqual(prompt, "System Prompt\n\nhello, \nASSISTANT: opea project! \nUSER: chao, \n") async def test_handle_request(self): json_data = { @@ -205,9 +222,75 @@ async def test_handle_request(self): res = json.loads(res.json()) self.assertEqual( res["choices"][-1]["message"]["content"], - "\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:", + "USER: \nhello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:", ) + def test_interleaved_image_handle_message(self): + """ + This tests a back and forth conversation with images interleaved with different models that have different prompt + formats than the default LLaVA 1.5 model. 
+ """ + + # Models to test and their expected prompts + model_names = ["llava-hf/llava-interleave-qwen-7b-hf", + "llava-hf/llava-v1.6-mistral-7b-hf", + "llava-hf/llava-v1.6-vicuna-7b-hf"] + expected_prompts = ["\nDescribe the image.\n<|im_end|><|im_start|>assistant \nIt is an image of a red apple with a green leaf\n<|im_start|>user \nIs this the same type of fruit?\n", + "\nDescribe the image.\n [/INST] \nIt is an image of a red apple with a green leaf\n[INST] \nIs this the same type of fruit?\n", + "\nDescribe the image.\nASSISTANT: \nIt is an image of a red apple with a green leaf\nUSER: \nIs this the same type of fruit?\n"] + gateway_port = 9988 + + for model_name, expected_prompt in zip(model_names, expected_prompts): + # Simulate running gateway with the specified model + lvm_model = os.environ["LVM_MODEL_ID"] = model_name + test_gateway = MultimodalQnAGateway(self.service_builder, self.follow_up_query_service_builder, port=gateway_port) + gateway_port += 1 + + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", "text": "Describe the image." + }, + { + "type": "image_url", + "image_url": {"url": "https://raw.githubusercontent.com/docarray/docarray/refs/heads/main/tests/toydata/image-data/apple.png"}, + }, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", "text": "It is an image of a red apple with a green leaf" + }, + { + "type": "image_url", + "image_url": {"url": "https://raw.githubusercontent.com/docarray/docarray/refs/heads/main/tests/toydata/image-data/apple.png"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", "text": "Is this the same type of fruit?" 
+ }, + { + "type": "image_url", + "image_url": {"url": "http://images.cocodataset.org/test-stuff2017/000000004248.jpg"}, + }, + ], + }, + ] + try: + prompt, images = test_gateway._handle_message(messages) + self.assertEqual(prompt, expected_prompt, + "The generated prompt does not match the expected prompt for {} \nActual:\n{}\nExpected:\n{}".format(model_name, repr(prompt), repr(expected_prompt))) + finally: + test_gateway.stop() + if __name__ == "__main__": unittest.main() From fa4795986291159996089003bcb7e5aff617bf55 Mon Sep 17 00:00:00 2001 From: Melanie Buehler Date: Thu, 5 Dec 2024 14:45:35 -0800 Subject: [PATCH 15/27] Adds unit test coverage for audio query Signed-off-by: Melanie Buehler --- .../cores/mega/test_multimodalqna_gateway.py | 41 ++++++++++++++++++- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/tests/cores/mega/test_multimodalqna_gateway.py b/tests/cores/mega/test_multimodalqna_gateway.py index c05bf57bdd..56b7dfee39 100644 --- a/tests/cores/mega/test_multimodalqna_gateway.py +++ b/tests/cores/mega/test_multimodalqna_gateway.py @@ -9,8 +9,10 @@ from fastapi import Request from comps import ( + Base64ByteStrDoc, EmbedDoc, EmbedMultimodalDoc, + LLMParamsDoc, LVMDoc, LVMSearchedMultimodalDoc, MultimodalDoc, @@ -72,21 +74,30 @@ async def lvm_add(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc: return res +@register_microservice(name="asr", host="0.0.0.0", port=3001, endpoint="/v1/audio/transcriptions") +async def asr_add(request: Base64ByteStrDoc) -> LLMParamsDoc: + req = request.model_dump_json() + res = {} + res['query'] = 'you' + return res + + class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase): @classmethod def setUpClass(cls): cls.mm_embedding = opea_microservices["mm_embedding"] cls.mm_retriever = opea_microservices["mm_retriever"] cls.lvm = opea_microservices["lvm"] + cls.asr = opea_microservices["asr"] cls.mm_embedding.start() cls.mm_retriever.start() cls.lvm.start() + cls.asr.start() 
cls.service_builder = ServiceOrchestrator() cls.service_builder.add(opea_microservices["mm_embedding"]).add(opea_microservices["mm_retriever"]).add( - opea_microservices["lvm"] - ) + opea_microservices["lvm"]).add(opea_microservices["asr"]) cls.service_builder.flow_to(cls.mm_embedding, cls.mm_retriever) cls.service_builder.flow_to(cls.mm_retriever, cls.lvm) @@ -100,6 +111,7 @@ def tearDownClass(cls): cls.mm_embedding.stop() cls.mm_retriever.stop() cls.lvm.stop() + cls.asr.stop() cls.gateway.stop() async def test_service_builder_schedule(self): @@ -181,6 +193,31 @@ def test_handle_message_with_system_prompt(self): prompt, images = self.gateway._handle_message(messages) self.assertEqual(prompt, "System Prompt\nhello, \nASSISTANT: opea project! \nUSER: chao, \n") + def test_handle_message_with_audio(self): + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "hello, " + } + ] + }, + {"role": "assistant", "content": "opea project! "}, + { + "role": "user", + "content": [ + { + "type": "audio", + "audio": "UklGRigAAABXQVZFZm10IBIAAAABAAEARKwAAIhYAQACABAAAABkYXRhAgAAAAEA" + } + ] + } + ] + prompt, images = self.gateway._handle_message(messages) + self.assertEqual(prompt, "hello, \nASSISTANT: opea project! 
\nUSER: you\n") + async def test_handle_request(self): json_data = { "messages": [ From 02efc8a5fe5ae3bdf2f4cbff101894751649a71e Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Thu, 5 Dec 2024 15:14:41 -0800 Subject: [PATCH 16/27] Update test to check the returned b64 types Signed-off-by: dmsuehir --- tests/cores/mega/test_multimodalqna_gateway.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/cores/mega/test_multimodalqna_gateway.py b/tests/cores/mega/test_multimodalqna_gateway.py index 035a508e82..e88be450de 100644 --- a/tests/cores/mega/test_multimodalqna_gateway.py +++ b/tests/cores/mega/test_multimodalqna_gateway.py @@ -267,7 +267,7 @@ def test_interleaved_image_handle_message(self): }, { "type": "image_url", - "image_url": {"url": "https://raw.githubusercontent.com/docarray/docarray/refs/heads/main/tests/toydata/image-data/apple.png"}, + "image_url": {"url": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC"}, }, ], }, @@ -285,9 +285,12 @@ def test_interleaved_image_handle_message(self): }, ] try: - prompt, images = test_gateway._handle_message(messages) + prompt, b64_types = test_gateway._handle_message(messages) self.assertEqual(prompt, expected_prompt, "The generated prompt does not match the expected prompt for {} \nActual:\n{}\nExpected:\n{}".format(model_name, repr(prompt), repr(expected_prompt))) + self.assertTrue("image" in b64_types.keys()) + self.assertFalse("audio" in b64_types.keys()) + self.assertEqual(len(b64_types["image"]), 3) finally: test_gateway.stop() From d74bb320a201102fe6cd381244af3e7136d27bab Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Thu, 5 Dec 2024 15:28:51 -0800 Subject: [PATCH 17/27] Update test since we don't expect images from the assistant Signed-off-by: dmsuehir --- comps/cores/mega/gateway.py | 6 +++--- tests/cores/mega/test_multimodalqna_gateway.py | 18 +++++------------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git 
a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 7a22799e19..61fd3f1b10 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -947,11 +947,11 @@ def _handle_message(self, messages): prompt += image_tags + decoded_audio_input + "\n" else: if text: - prompt += role_label_dict[role] + image_tags + " " + text + "\n" + prompt += role_label_dict[role] + " " + image_tags + text + "\n" elif decoded_audio_input: - prompt += role_label_dict[role] + image_tags + " " + decoded_audio_input + "\n" + prompt += role_label_dict[role] + " " + image_tags + decoded_audio_input + "\n" else: - prompt += role_label_dict[role] + image_tags + prompt += role_label_dict[role] + " " + image_tags for img in image_list: # URL if img.startswith("http://") or img.startswith("https://"): diff --git a/tests/cores/mega/test_multimodalqna_gateway.py b/tests/cores/mega/test_multimodalqna_gateway.py index e88be450de..4f604c52a6 100644 --- a/tests/cores/mega/test_multimodalqna_gateway.py +++ b/tests/cores/mega/test_multimodalqna_gateway.py @@ -235,9 +235,9 @@ def test_interleaved_image_handle_message(self): model_names = ["llava-hf/llava-interleave-qwen-7b-hf", "llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/llava-v1.6-vicuna-7b-hf"] - expected_prompts = ["\nDescribe the image.\n<|im_end|><|im_start|>assistant \nIt is an image of a red apple with a green leaf\n<|im_start|>user \nIs this the same type of fruit?\n", - "\nDescribe the image.\n [/INST] \nIt is an image of a red apple with a green leaf\n[INST] \nIs this the same type of fruit?\n", - "\nDescribe the image.\nASSISTANT: \nIt is an image of a red apple with a green leaf\nUSER: \nIs this the same type of fruit?\n"] + expected_prompts = ["\nDescribe the image.\n<|im_end|><|im_start|>assistant It is an image of a red apple with a green leaf\n<|im_start|>user \nIs this the same type of fruit?\n", + "\nDescribe the image.\n [/INST] It is an image of a red apple with a green leaf\n[INST] \nIs this the same type 
of fruit?\n", + "\nDescribe the image.\nASSISTANT: It is an image of a red apple with a green leaf\nUSER: \nIs this the same type of fruit?\n"] gateway_port = 9988 for model_name, expected_prompt in zip(model_names, expected_prompts): @@ -261,15 +261,7 @@ def test_interleaved_image_handle_message(self): }, { "role": "assistant", - "content": [ - { - "type": "text", "text": "It is an image of a red apple with a green leaf" - }, - { - "type": "image_url", - "image_url": {"url": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC"}, - }, - ], + "content": "It is an image of a red apple with a green leaf" }, { "role": "user", @@ -290,7 +282,7 @@ def test_interleaved_image_handle_message(self): "The generated prompt does not match the expected prompt for {} \nActual:\n{}\nExpected:\n{}".format(model_name, repr(prompt), repr(expected_prompt))) self.assertTrue("image" in b64_types.keys()) self.assertFalse("audio" in b64_types.keys()) - self.assertEqual(len(b64_types["image"]), 3) + self.assertEqual(len(b64_types["image"]), 2) finally: test_gateway.stop() From 37826bea96107d3dde4d8f5254875757b8aef0fa Mon Sep 17 00:00:00 2001 From: Melanie Buehler Date: Fri, 6 Dec 2024 09:12:45 -0800 Subject: [PATCH 18/27] Port number fix Signed-off-by: Melanie Buehler --- tests/cores/mega/test_multimodalqna_gateway.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/cores/mega/test_multimodalqna_gateway.py b/tests/cores/mega/test_multimodalqna_gateway.py index 56b7dfee39..935c5c3683 100644 --- a/tests/cores/mega/test_multimodalqna_gateway.py +++ b/tests/cores/mega/test_multimodalqna_gateway.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import os import unittest from typing import Union @@ -74,7 +75,7 @@ async def lvm_add(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc: return res -@register_microservice(name="asr", host="0.0.0.0", port=3001, 
endpoint="/v1/audio/transcriptions") +@register_microservice(name="asr", host="0.0.0.0", port=8086, endpoint="/v1/audio/transcriptions") async def asr_add(request: Base64ByteStrDoc) -> LLMParamsDoc: req = request.model_dump_json() res = {} @@ -97,13 +98,14 @@ def setUpClass(cls): cls.service_builder = ServiceOrchestrator() cls.service_builder.add(opea_microservices["mm_embedding"]).add(opea_microservices["mm_retriever"]).add( - opea_microservices["lvm"]).add(opea_microservices["asr"]) + opea_microservices["lvm"]) cls.service_builder.flow_to(cls.mm_embedding, cls.mm_retriever) cls.service_builder.flow_to(cls.mm_retriever, cls.lvm) cls.follow_up_query_service_builder = ServiceOrchestrator() cls.follow_up_query_service_builder.add(cls.lvm) + os.environ["ASR_SERVICE_PORT"] = "8086" cls.gateway = MultimodalQnAGateway(cls.service_builder, cls.follow_up_query_service_builder, port=9898) @classmethod @@ -215,8 +217,9 @@ def test_handle_message_with_audio(self): ] } ] - prompt, images = self.gateway._handle_message(messages) + prompt, b64_types = self.gateway._handle_message(messages) self.assertEqual(prompt, "hello, \nASSISTANT: opea project! 
\nUSER: you\n") + self.assertEqual(b64_types, {"audio": "you"}) async def test_handle_request(self): json_data = { From 40d34db000ea936cdcb9a4b16e05a51095fab5b9 Mon Sep 17 00:00:00 2001 From: Melanie Buehler Date: Fri, 6 Dec 2024 09:31:20 -0800 Subject: [PATCH 19/27] Formatting Signed-off-by: Melanie Buehler --- tests/cores/mega/test_multimodalqna_gateway.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/cores/mega/test_multimodalqna_gateway.py b/tests/cores/mega/test_multimodalqna_gateway.py index 935c5c3683..23918f2540 100644 --- a/tests/cores/mega/test_multimodalqna_gateway.py +++ b/tests/cores/mega/test_multimodalqna_gateway.py @@ -98,7 +98,8 @@ def setUpClass(cls): cls.service_builder = ServiceOrchestrator() cls.service_builder.add(opea_microservices["mm_embedding"]).add(opea_microservices["mm_retriever"]).add( - opea_microservices["lvm"]) + opea_microservices["lvm"] + ) cls.service_builder.flow_to(cls.mm_embedding, cls.mm_retriever) cls.service_builder.flow_to(cls.mm_retriever, cls.lvm) From a665c3c32f96ee04ef5463fc677d9281816ed8db Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:34:37 +0000 Subject: [PATCH 20/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../cores/mega/test_multimodalqna_gateway.py | 21 +++++-------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/tests/cores/mega/test_multimodalqna_gateway.py b/tests/cores/mega/test_multimodalqna_gateway.py index 23918f2540..42dd1973b0 100644 --- a/tests/cores/mega/test_multimodalqna_gateway.py +++ b/tests/cores/mega/test_multimodalqna_gateway.py @@ -79,7 +79,7 @@ async def lvm_add(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc: async def asr_add(request: Base64ByteStrDoc) -> LLMParamsDoc: req = request.model_dump_json() res = {} - res['query'] = 'you' + res["query"] = "you" return res @@ -198,25 +198,14 
@@ def test_handle_message_with_system_prompt(self): def test_handle_message_with_audio(self): messages = [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "hello, " - } - ] - }, + {"role": "user", "content": [{"type": "text", "text": "hello, "}]}, {"role": "assistant", "content": "opea project! "}, { "role": "user", "content": [ - { - "type": "audio", - "audio": "UklGRigAAABXQVZFZm10IBIAAAABAAEARKwAAIhYAQACABAAAABkYXRhAgAAAAEA" - } - ] - } + {"type": "audio", "audio": "UklGRigAAABXQVZFZm10IBIAAAABAAEARKwAAIhYAQACABAAAABkYXRhAgAAAAEA"} + ], + }, ] prompt, b64_types = self.gateway._handle_message(messages) self.assertEqual(prompt, "hello, \nASSISTANT: opea project! \nUSER: you\n") From d9ab567ef2a71dd2ad457a2e26aec65537216196 Mon Sep 17 00:00:00 2001 From: Melanie Buehler Date: Fri, 6 Dec 2024 10:09:48 -0800 Subject: [PATCH 21/27] Fixed place where port number is set Signed-off-by: Melanie Buehler --- tests/cores/mega/test_multimodalqna_gateway.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/cores/mega/test_multimodalqna_gateway.py b/tests/cores/mega/test_multimodalqna_gateway.py index 42dd1973b0..9329a7ac9d 100644 --- a/tests/cores/mega/test_multimodalqna_gateway.py +++ b/tests/cores/mega/test_multimodalqna_gateway.py @@ -9,6 +9,8 @@ import requests from fastapi import Request +os.environ["ASR_SERVICE_PORT"] = "8086" + from comps import ( Base64ByteStrDoc, EmbedDoc, @@ -106,7 +108,6 @@ def setUpClass(cls): cls.follow_up_query_service_builder = ServiceOrchestrator() cls.follow_up_query_service_builder.add(cls.lvm) - os.environ["ASR_SERVICE_PORT"] = "8086" cls.gateway = MultimodalQnAGateway(cls.service_builder, cls.follow_up_query_service_builder, port=9898) @classmethod From 9a077c5ebd0cc8fe29e0b31d85ec34af8c4a7056 Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Mon, 9 Dec 2024 11:41:23 -0800 Subject: [PATCH 22/27] Remove old comment and added more accurate description Signed-off-by: dmsuehir --- 
comps/cores/mega/gateway.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 61fd3f1b10..68a7c0e014 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -1030,6 +1030,17 @@ def convert_audio_to_text(self, audio): return response["query"] async def handle_request(self, request: Request): + """ + MultimodalQnA accepts input queries as text, images, and/or audio. The messages in the request can be a single + message (which would be assumed to be a first query from the user) or back and forth conversation between the + user and the assistant. + Audio queries are converted to text before being sent to the megaservice and the translated text is returned + as part of the metadata in the response. + First queries are sent to the full Multimodal megaserivce, which includes using the embedding microservice and + retriever, in order to get relevant information from the vector store to send to the LVM along with the user's + query. Follow up queries are sent directly to the LVM without searching for more similar information from the + vector store. + """ data = await request.json() stream_opt = bool(data.get("stream", False)) if stream_opt: @@ -1037,8 +1048,6 @@ async def handle_request(self, request: Request): stream_opt = False chat_request = ChatCompletionRequest.model_validate(data) num_messages = len(data["messages"]) if isinstance(data["messages"], list) else 1 - - # Multimodal RAG QnA With Videos has not yet accepts image as input during QnA. 
messages = self._handle_message(chat_request.messages) decoded_audio_input = "" From b21e57572de1ee40f05ee33a02c4c191e839e39c Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Mon, 9 Dec 2024 22:17:39 +0000 Subject: [PATCH 23/27] add comment in code about MAX_IMAGES Signed-off-by: dmsuehir --- comps/lvms/llava/lvm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comps/lvms/llava/lvm.py b/comps/lvms/llava/lvm.py index 425576debd..9d7bde0f90 100644 --- a/comps/lvms/llava/lvm.py +++ b/comps/lvms/llava/lvm.py @@ -27,6 +27,8 @@ logger = CustomLogger("lvm") logflag = os.getenv("LOGFLAG", False) + +# The maximum number of images that should be sent to the LVM max_images = int(os.getenv("MAX_IMAGES", 1)) From a3abd8a743bfe89f3c8c07afe6d866bfefd43b15 Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Tue, 10 Dec 2024 23:42:40 +0000 Subject: [PATCH 24/27] Add Gaudi support for image query Signed-off-by: dmsuehir --- comps/lvms/tgi-llava/lvm_tgi.py | 47 +++++++++++++++---- .../lvms/test_lvms_tgi-llava_on_intel_hpu.sh | 35 ++++++++++++++ 2 files changed, 74 insertions(+), 8 deletions(-) diff --git a/comps/lvms/tgi-llava/lvm_tgi.py b/comps/lvms/tgi-llava/lvm_tgi.py index 38b492c395..04ceee400c 100644 --- a/comps/lvms/tgi-llava/lvm_tgi.py +++ b/comps/lvms/tgi-llava/lvm_tgi.py @@ -27,6 +27,9 @@ logger = CustomLogger("lvm_tgi") logflag = os.getenv("LOGFLAG", False) +# The maximum number of images that should be sent to the LVM +max_images = int(os.getenv("MAX_IMAGES", 1)) + @register_microservice( name="opea_service@lvm_tgi", @@ -88,15 +91,41 @@ async def lvm(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> Union[TextDoc top_k = request.top_k top_p = request.top_p - if not img_b64_str: - # Work around an issue where LLaVA-NeXT is not providing good responses when prompted without an image. - # Provide an image and then instruct the model to ignore the image. 
The base64 string below is the encoded png: - # https://raw.githubusercontent.com/opea-project/GenAIExamples/refs/tags/v1.0/AudioQnA/ui/svelte/src/lib/assets/icons/png/audio1.png - img_b64_str = "iVBORw0KGgoAAAANSUhEUgAAADUAAAAlCAYAAADiMKHrAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAKPSURBVHgB7Zl/btowFMefnUTqf+MAHYMTjN4gvcGOABpM+8E0doLSE4xpsE3rKuAG3KC5Ad0J6MYOkP07YnvvhR9y0lVzupTIVT5SwDjB9fd97WfsMkCef1rUXM8dY9HHK4hWUevzi/oVWAqnF8fzLmAtiPA3Aq0lFsVA1fRKxlgNLIbDPaQUZQuu6YO98aIipHOiFGtIqaYfn1UnUCDds6WPyeANlTFbv9WztbFTK+HNUVAPiz7nbPzq7HsPCoKWIBREGfsJXZit5xT07X0jp6iRdIbEHOnjyyD97OvzH00lVS2K5OS2ax11cBXxJgYxlEIE6XZclzdTX6n8XjkkcEIfbj2nMO0/SNd1vy4vsCNjYPyEovfyy88GZIQCSKOCMf6ORgStoboLJuSWKDYCfK2q4jjrMZ+GOh7Pib/gek5DHxVUJtcgA7mJ4kwZRbN7viQXFzQn0Nl52gXG4Fo7DKAYp0yI3VHQ16oaWV0wYa+iGE8nG+wAdx5DzpS/KGyhFGULpShbKEXZQinqLlBK/IKc2asoh4sZvoXJWhlAzuxV1KBVD3HrfYTFAK8ZHgu0hu36DHLG+Izinw250WUkXHJht02QUnxLP7fZxR7f1I6S7Ir2GgmYvIQM5OYUuYBdainATq2ZjTqPBlnbGXYeBrg9Od18DKmc1U0jpw4OIIwEJFxQSl2b4MN2lf74fw8nFNbHt/5N9xWKTZvJ2S6YZk6RC3j2cKpVhSIShZ0mea6caCOCAjyNHd5gPPxGncMBTvI6hunYdaJ6kf8VoSCP2odxX6RkR6NOtanfj13EswKVqEQrPzzFL1lK+YvCFraiEqs8TrwQLGYraqpX4kr/Hixml+63Z+CoM9DTo438AUmP+KyMWT+tAAAAAElFTkSuQmCC" - prompt = f"Please disregard the image and answer the question. {prompt}" + # Make img_b64_str into a list of strings (if it's not already a list) + if not isinstance(img_b64_str, list): + if img_b64_str: + img_b64_str = [img_b64_str] + else: + # If img_b64_str was an empty string, which means we have just have a text prompt. + # Work around an issue where LLaVA-NeXT is not providing good responses when prompted without an image. + # Provide an image and then instruct the model to ignore the image. 
The base64 string below is the encoded png: + # https://raw.githubusercontent.com/opea-project/GenAIExamples/refs/tags/v1.0/AudioQnA/ui/svelte/src/lib/assets/icons/png/audio1.png + img_b64_str = ["iVBORw0KGgoAAAANSUhEUgAAADUAAAAlCAYAAADiMKHrAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAKPSURBVHgB7Zl/btowFMefnUTqf+MAHYMTjN4gvcGOABpM+8E0doLSE4xpsE3rKuAG3KC5Ad0J6MYOkP07YnvvhR9y0lVzupTIVT5SwDjB9fd97WfsMkCef1rUXM8dY9HHK4hWUevzi/oVWAqnF8fzLmAtiPA3Aq0lFsVA1fRKxlgNLIbDPaQUZQuu6YO98aIipHOiFGtIqaYfn1UnUCDds6WPyeANlTFbv9WztbFTK+HNUVAPiz7nbPzq7HsPCoKWIBREGfsJXZit5xT07X0jp6iRdIbEHOnjyyD97OvzH00lVS2K5OS2ax11cBXxJgYxlEIE6XZclzdTX6n8XjkkcEIfbj2nMO0/SNd1vy4vsCNjYPyEovfyy88GZIQCSKOCMf6ORgStoboLJuSWKDYCfK2q4jjrMZ+GOh7Pib/gek5DHxVUJtcgA7mJ4kwZRbN7viQXFzQn0Nl52gXG4Fo7DKAYp0yI3VHQ16oaWV0wYa+iGE8nG+wAdx5DzpS/KGyhFGULpShbKEXZQinqLlBK/IKc2asoh4sZvoXJWhlAzuxV1KBVD3HrfYTFAK8ZHgu0hu36DHLG+Izinw250WUkXHJht02QUnxLP7fZxR7f1I6S7Ir2GgmYvIQM5OYUuYBdainATq2ZjTqPBlnbGXYeBrg9Od18DKmc1U0jpw4OIIwEJFxQSl2b4MN2lf74fw8nFNbHt/5N9xWKTZvJ2S6YZk6RC3j2cKpVhSIShZ0mea6caCOCAjyNHd5gPPxGncMBTvI6hunYdaJ6kf8VoSCP2odxX6RkR6NOtanfj13EswKVqEQrPzzFL1lK+YvCFraiEqs8TrwQLGYraqpX4kr/Hixml+63Z+CoM9DTo438AUmP+KyMWT+tAAAAAElFTkSuQmCC"] + prompt = f"Please disregard the image and answer the question. 
{prompt}" + + # Truncate the list of images if we have too many, only sending the most recent ones at the end of the list + if len(img_b64_str) > max_images: + img_b64_str=img_b64_str[-max_images:] - image = f"data:image/png;base64,{img_b64_str}" - image_prompt = f"![]({image})\n{prompt}\nASSISTANT:" + # Check the number of image tags in the prompt and adjust them to match the number of images that we have + image_tag = "\n" + num_tags_in_prompt = prompt.count(image_tag) + + # We have too many image tags in the prompt replace the first x instance of the tag with an empty string + if len(img_b64_str) < num_tags_in_prompt: + prompt = prompt.replace(image_tag, "", num_tags_in_prompt - len(img_b64_str)) + + # We don't have enough image tags in the prompt, add them + if len(img_b64_str) > num_tags_in_prompt: + num_tags_to_add = len(img_b64_str) - num_tags_in_prompt + tags_to_add = image_tag * num_tags_to_add + prompt = f"{tags_to_add}{prompt}" + + # Replace image tags with the data + for i in img_b64_str: + formatted_image_str = f"![](data:image/png;base64,{i})\n" + prompt = prompt.replace(image_tag, formatted_image_str, 1) + image_prompt = f"{prompt}\nASSISTANT:" if streaming: @@ -152,4 +181,6 @@ async def stream_generator(): lvm_endpoint = os.getenv("LVM_ENDPOINT", "http://localhost:8399") lvm_client = AsyncInferenceClient(lvm_endpoint) logger.info("[LVM] LVM initialized.") + if logflag: + logger.info(f"MAX_IMAGES: {max_images}") opea_microservices["opea_service@lvm_tgi"].start() diff --git a/tests/lvms/test_lvms_tgi-llava_on_intel_hpu.sh b/tests/lvms/test_lvms_tgi-llava_on_intel_hpu.sh index 1fa0155266..9d1a69a7ae 100644 --- a/tests/lvms/test_lvms_tgi-llava_on_intel_hpu.sh +++ b/tests/lvms/test_lvms_tgi-llava_on_intel_hpu.sh @@ -54,6 +54,41 @@ function validate_microservice() { echo "LVM prompt without image - HTTP status (successful)" fi + # Test sending two images with a text prompt with one image tag in the prompt. 
+ # The first image is green and the second image is blue. Since the default MAX_IMAGES is 1, only the blue image should be sent to the LVM. + result=$(http_proxy="" curl http://localhost:$lvm_port/v1/lvm -XPOST -d '{"image": ["iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC", "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC"], "prompt":"\nWhat are in these images?"}' -H 'Content-Type: application/json') + if [[ $result == *"blue"* ]]; then + echo "Result correct." + else + echo "Result wrong." + docker logs test-comps-lvm-llava >> ${LOG_PATH}/llava-dependency.log + docker logs test-comps-lvm-llava-svc >> ${LOG_PATH}/llava-server.log + exit 1 + fi + + # Test sending two images with a text prompt without any image tags. + # The first image is blue and the second image is green. Since the default MAX_IMAGES is 1, only the green image should be sent to the LVM. + result=$(http_proxy="" curl http://localhost:$lvm_port/v1/lvm -XPOST -d '{"image": ["iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC", "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC"], "prompt":"What are in these images?"}' -H 'Content-Type: application/json') + if [[ $result == *"green"* ]]; then + echo "Result correct." + else + echo "Result wrong." + docker logs test-comps-lvm-llava >> ${LOG_PATH}/llava-dependency.log + docker logs test-comps-lvm-llava-svc >> ${LOG_PATH}/llava-server.log + exit 1 + fi + + # Same test as above, except including two image tags with the prompt to ensure the number of image tags is reconciled. + # The first image is blue and the second image is green. Since the default MAX_IMAGES is 1, only the green image should be sent to the LVM. 
+ result=$(http_proxy="" curl http://localhost:$lvm_port/v1/lvm -XPOST -d '{"image": ["iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC", "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC"], "prompt":"\n\nWhat are in these images?"}' -H 'Content-Type: application/json') + if [[ $result == *"green"* ]]; then + echo "Result correct." + else + echo "Result wrong." + docker logs test-comps-lvm-llava >> ${LOG_PATH}/llava-dependency.log + docker logs test-comps-lvm-llava-svc >> ${LOG_PATH}/llava-server.log + exit 1 + fi } function stop_docker() { From 723f0c3136e020d196615c5a4500a7de72b5bdf1 Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Thu, 12 Dec 2024 14:55:53 -0800 Subject: [PATCH 25/27] Fix to pass the retrieved image last Signed-off-by: dmsuehir --- comps/retrievers/multimodal/redis/langchain/retriever_redis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comps/retrievers/multimodal/redis/langchain/retriever_redis.py b/comps/retrievers/multimodal/redis/langchain/retriever_redis.py index 363c54a516..a92d59aba2 100644 --- a/comps/retrievers/multimodal/redis/langchain/retriever_redis.py +++ b/comps/retrievers/multimodal/redis/langchain/retriever_redis.py @@ -72,7 +72,7 @@ async def retrieve( # If the input had an image, pass that through in the metadata along with the search result image if input.base64_image: if r.metadata["b64_img_str"]: - r.metadata["b64_img_str"] = [r.metadata["b64_img_str"], input.base64_image] + r.metadata["b64_img_str"] = [input.base64_image, r.metadata["b64_img_str"]] else: r.metadata["b64_img_str"] = input.base64_image metadata_list.append(r.metadata) From b1205f432d56cd50589bd9742b84c6f0a87589d0 Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Thu, 12 Dec 2024 15:01:40 -0800 Subject: [PATCH 26/27] Revert out gateway and gateway test code, due to its move to GenAIExamples Signed-off-by: dmsuehir --- 
comps/cores/mega/gateway.py | 189 +++--------------- .../cores/mega/test_multimodalqna_gateway.py | 120 +---------- 2 files changed, 35 insertions(+), 274 deletions(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 68a7c0e014..29642eea55 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -2,14 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import base64 -import json import os from io import BytesIO from typing import List, Union import requests from fastapi import File, Request, UploadFile -from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.responses import StreamingResponse from PIL import Image from ..proto.api_protocol import ( @@ -838,12 +837,8 @@ def parser_input(data, TypeClass, key): class MultimodalQnAGateway(Gateway): - asr_port = int(os.getenv("ASR_SERVICE_PORT", 3001)) - asr_endpoint = os.getenv("ASR_SERVICE_ENDPOINT", "http://0.0.0.0:{}/v1/audio/transcriptions".format(asr_port)) - def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", port=9999): self.lvm_megaservice = lvm_megaservice - self._role_labels = self._get_role_labels() super().__init__( multimodal_rag_megaservice, host, @@ -853,46 +848,16 @@ def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", ChatCompletionResponse, ) - def _get_role_labels(self): - """ - Returns a dictionary of role labels that are used in the chat prompt based on the LVM_MODEL_ID - environment variable. The function defines the role labels used by the llava-1.5, llava-v1.6-vicuna, - llava-v1.6-mistral, and llava-interleave models, and then defaults to use "USER:" and "ASSISTANT:" if the - LVM_MODEL_ID is not one of those. 
- """ - lvm_model = os.getenv("LVM_MODEL_ID", "") - - # Default to labels used by llava-1.5 and llava-v1.6-vicuna models - role_labels = { - "user": "USER:", - "assistant": "ASSISTANT:" - } - - if "llava-interleave" in lvm_model: - role_labels["user"] = "<|im_start|>user" - role_labels["assistant"] = "<|im_end|><|im_start|>assistant" - elif "llava-v1.6-mistral" in lvm_model: - role_labels["user"] = "[INST]" - role_labels["assistant"] = " [/INST]" - elif "llava-1.5" not in lvm_model and "llava-v1.6-vicuna" not in lvm_model: - print(f"[ MultimodalQnAGateway ] Using default role labels for prompt formatting: {role_labels}") - - return role_labels - # this overrides _handle_message method of Gateway def _handle_message(self, messages): images = [] - audios = [] - b64_types = {} messages_dicts = [] - decoded_audio_input = "" if isinstance(messages, str): prompt = messages else: messages_dict = {} system_prompt = "" prompt = "" - role_label_dict = self._role_labels for message in messages: msg_role = message["role"] messages_dict = {} @@ -900,26 +865,16 @@ def _handle_message(self, messages): system_prompt = message["content"] elif msg_role == "user": if type(message["content"]) == list: - # separate each media type and store accordingly text = "" text_list = [item["text"] for item in message["content"] if item["type"] == "text"] text += "\n".join(text_list) image_list = [ item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url" ] - audios = [item["audio"] for item in message["content"] if item["type"] == "audio"] - if audios: - # translate audio to text. 
From this point forward, audio is treated like text - decoded_audio_input = self.convert_audio_to_text(audios) - b64_types["audio"] = decoded_audio_input - - if text and not audios and not image_list: - messages_dict[msg_role] = text - elif audios and not text and not image_list: - messages_dict[msg_role] = decoded_audio_input + if image_list: + messages_dict[msg_role] = (text, image_list) else: - messages_dict[msg_role] = (text, decoded_audio_input, image_list) - + messages_dict[msg_role] = text else: messages_dict[msg_role] = message["content"] messages_dicts.append(messages_dict) @@ -928,30 +883,23 @@ def _handle_message(self, messages): messages_dicts.append(messages_dict) else: raise ValueError(f"Unknown role: {msg_role}") + if system_prompt: prompt = system_prompt + "\n" - for i, messages_dict in enumerate(messages_dicts): - for role, message in messages_dict.items(): + for messages_dict in messages_dicts: + for i, (role, message) in enumerate(messages_dict.items()): if isinstance(message, tuple): - text, decoded_audio_input, image_list = message - # Remove empty items from the image list - image_list = [x for x in image_list if x] - # Add image indicators within the conversation - image_tags = "\n" * len(image_list) + text, image_list = message if i == 0: # do not add role for the very first message. 
# this will be added by llava_server if text: - prompt += image_tags + text + "\n" - elif decoded_audio_input: - prompt += image_tags + decoded_audio_input + "\n" + prompt += text + "\n" else: if text: - prompt += role_label_dict[role] + " " + image_tags + text + "\n" - elif decoded_audio_input: - prompt += role_label_dict[role] + " " + image_tags + decoded_audio_input + "\n" + prompt += role.upper() + ": " + text + "\n" else: - prompt += role_label_dict[role] + " " + image_tags + prompt += role.upper() + ":" for img in image_list: # URL if img.startswith("http://") or img.startswith("https://"): @@ -970,115 +918,42 @@ def _handle_message(self, messages): else: img_b64_str = img - if image_list: - for img in image_list: - # URL - if img.startswith("http://") or img.startswith("https://"): - response = requests.get(img) - image = Image.open(BytesIO(response.content)).convert("RGBA") - image_bytes = BytesIO() - image.save(image_bytes, format="PNG") - img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() - # Local Path - elif os.path.exists(img): - image = Image.open(img).convert("RGBA") - image_bytes = BytesIO() - image.save(image_bytes, format="PNG") - img_b64_str = base64.b64encode(image_bytes.getvalue()).decode() - # Bytes - else: - img_b64_str = img - - images.append(img_b64_str) - - elif isinstance(message, str): + images.append(img_b64_str) + else: if i == 0: # do not add role for the very first message. 
# this will be added by llava_server if message: - prompt += message + "\n" + prompt += role.upper() + ": " + message + "\n" else: if message: - prompt += role_label_dict[role] + " " + message + "\n" + prompt += role.upper() + ": " + message + "\n" else: - prompt += role_label_dict[role] - + prompt += role.upper() + ":" if images: - b64_types["image"] = images - - # If the query has multiple media types, return all types - if prompt and b64_types: - return prompt, b64_types + return prompt, images else: return prompt - def convert_audio_to_text(self, audio): - # translate audio to text by passing in dictionary to ASR - if isinstance(audio, dict): - input_dict = {"byte_str": audio["audio"][0]} - else: - input_dict = {"byte_str": audio[0]} - - response = requests.post(self.asr_endpoint, data=json.dumps(input_dict), proxies={"http": None}) - - if response.status_code != 200: - return JSONResponse( - status_code=503, content={"message": "Unable to convert audio to text. {}".format(response.text)} - ) - - response = response.json() - return response["query"] - async def handle_request(self, request: Request): - """ - MultimodalQnA accepts input queries as text, images, and/or audio. The messages in the request can be a single - message (which would be assumed to be a first query from the user) or back and forth conversation between the - user and the assistant. - Audio queries are converted to text before being sent to the megaservice and the translated text is returned - as part of the metadata in the response. - First queries are sent to the full Multimodal megaserivce, which includes using the embedding microservice and - retriever, in order to get relevant information from the vector store to send to the LVM along with the user's - query. Follow up queries are sent directly to the LVM without searching for more similar information from the - vector store. 
- """ data = await request.json() stream_opt = bool(data.get("stream", False)) if stream_opt: print("[ MultimodalQnAGateway ] stream=True not used, this has not support streaming yet!") stream_opt = False chat_request = ChatCompletionRequest.model_validate(data) - num_messages = len(data["messages"]) if isinstance(data["messages"], list) else 1 - messages = self._handle_message(chat_request.messages) - decoded_audio_input = "" - - if num_messages > 1: - # This is a follow up query, go to LVM + # Multimodal RAG QnA With Videos has not yet accepts image as input during QnA. + prompt_and_image = self._handle_message(chat_request.messages) + if isinstance(prompt_and_image, tuple): + # print(f"This request include image, thus it is a follow-up query. Using lvm megaservice") + prompt, images = prompt_and_image cur_megaservice = self.lvm_megaservice - if isinstance(messages, tuple): - prompt, b64_types = messages - if "audio" in b64_types: - # for metadata storage purposes - decoded_audio_input = b64_types["audio"] - if "image" in b64_types: - initial_inputs = {"prompt": prompt, "image": b64_types["image"]} - else: - initial_inputs = {"prompt": prompt, "image": ""} - else: - prompt = messages - initial_inputs = {"prompt": prompt, "image": ""} + initial_inputs = {"prompt": prompt, "image": images[0]} else: - # This is the first query. Ignore image input + # print(f"This is the first query, requiring multimodal retrieval. 
Using multimodal rag megaservice") + prompt = prompt_and_image cur_megaservice = self.megaservice - if isinstance(messages, tuple): - prompt, b64_types = messages - initial_inputs = {"text": prompt} - if "audio" in b64_types: - # for metadata storage purposes - decoded_audio_input = b64_types["audio"] - if "image" in b64_types and len(b64_types["image"]) > 0: - initial_inputs["image"] = {"base64_image": b64_types["image"][0]} - else: - initial_inputs = {"text": messages} + initial_inputs = {"text": prompt} parameters = LLMParams( max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, @@ -1110,24 +985,18 @@ async def handle_request(self, request: Request): if "text" in result_dict[last_node].keys(): response = result_dict[last_node]["text"] else: - # text is not in response message + # text in not response message # something wrong, for example due to empty retrieval results if "detail" in result_dict[last_node].keys(): response = result_dict[last_node]["detail"] else: - response = "The server failed to generate an answer to your query!" + response = "The server fail to generate answer to your query!" 
if "metadata" in result_dict[last_node].keys(): # from retrieval results metadata = result_dict[last_node]["metadata"] - if decoded_audio_input: - metadata["audio"] = decoded_audio_input else: # follow-up question, no retrieval - if decoded_audio_input: - metadata = {"audio": decoded_audio_input} - else: - metadata = None - + metadata = None choices = [] usage = UsageInfo() choices.append( diff --git a/tests/cores/mega/test_multimodalqna_gateway.py b/tests/cores/mega/test_multimodalqna_gateway.py index f3714d0d23..c05bf57bdd 100644 --- a/tests/cores/mega/test_multimodalqna_gateway.py +++ b/tests/cores/mega/test_multimodalqna_gateway.py @@ -2,20 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import json -import os import unittest from typing import Union import requests from fastapi import Request -os.environ["ASR_SERVICE_PORT"] = "8086" - from comps import ( - Base64ByteStrDoc, EmbedDoc, EmbedMultimodalDoc, - LLMParamsDoc, LVMDoc, LVMSearchedMultimodalDoc, MultimodalDoc, @@ -70,48 +65,22 @@ async def lvm_add(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc: else: print("request is from user.") text = req_dict["prompt"] - image_tag = "" - - # There may already be image tags interleaved within the prompt. The LVM service checks that and - # adds image tag(s) if they are needed. 
- if "image" in req_dict.keys(): - num_tags_in_prompt = text.count("\n") - if isinstance(req_dict["image"], list): - image_list = req_dict["image"] - else: - image_list = [req_dict["image"]] - num_images = len(image_list) - - # Add more image tags, if needed - if num_images > num_tags_in_prompt: - image_tag = "\n" * (num_images - num_tags_in_prompt) - - text = f"USER: {image_tag}{text}\nASSISTANT:" + text = f"\nUSER: {text}\nASSISTANT:" res = {} res["text"] = text return res -@register_microservice(name="asr", host="0.0.0.0", port=8086, endpoint="/v1/audio/transcriptions") -async def asr_add(request: Base64ByteStrDoc) -> LLMParamsDoc: - req = request.model_dump_json() - res = {} - res["query"] = "you" - return res - - class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase): @classmethod def setUpClass(cls): cls.mm_embedding = opea_microservices["mm_embedding"] cls.mm_retriever = opea_microservices["mm_retriever"] cls.lvm = opea_microservices["lvm"] - cls.asr = opea_microservices["asr"] cls.mm_embedding.start() cls.mm_retriever.start() cls.lvm.start() - cls.asr.start() cls.service_builder = ServiceOrchestrator() @@ -131,7 +100,6 @@ def tearDownClass(cls): cls.mm_embedding.stop() cls.mm_retriever.stop() cls.lvm.stop() - cls.asr.stop() cls.gateway.stop() async def test_service_builder_schedule(self): @@ -143,7 +111,7 @@ async def test_follow_up_query_service_builder_schedule(self): initial_inputs={"prompt": "chao, ", "image": "some image"} ) # print(result_dict) - self.assertEqual(result_dict[self.lvm.name]["text"], "USER: \nchao, \nASSISTANT:") + self.assertEqual(result_dict[self.lvm.name]["text"], "\nUSER: chao, \nASSISTANT:") def test_MultimodalQnAGateway_gateway(self): json_data = {"messages": "hello, "} @@ -173,7 +141,7 @@ def test_follow_up_MultimodalQnAGateway_gateway(self): response = response.json() self.assertEqual( response["choices"][-1]["message"]["content"], - "USER: \nhello, \nASSISTANT: opea project! 
\nUSER: chao, \n\nASSISTANT:", + "\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:", ) def test_handle_message(self): @@ -192,7 +160,7 @@ def test_handle_message(self): {"role": "user", "content": "chao, "}, ] prompt, images = self.gateway._handle_message(messages) - self.assertEqual(prompt, "\nhello, \nASSISTANT: opea project! \nUSER: chao, \n") + self.assertEqual(prompt, "hello, \nASSISTANT: opea project! \nUSER: chao, \n") def test_handle_message_with_system_prompt(self): messages = [ @@ -211,22 +179,7 @@ def test_handle_message_with_system_prompt(self): {"role": "user", "content": "chao, "}, ] prompt, images = self.gateway._handle_message(messages) - self.assertEqual(prompt, "System Prompt\n\nhello, \nASSISTANT: opea project! \nUSER: chao, \n") - - def test_handle_message_with_audio(self): - messages = [ - {"role": "user", "content": [{"type": "text", "text": "hello, "}]}, - {"role": "assistant", "content": "opea project! "}, - { - "role": "user", - "content": [ - {"type": "audio", "audio": "UklGRigAAABXQVZFZm10IBIAAAABAAEARKwAAIhYAQACABAAAABkYXRhAgAAAAEA"} - ], - }, - ] - prompt, b64_types = self.gateway._handle_message(messages) - self.assertEqual(prompt, "hello, \nASSISTANT: opea project! \nUSER: you\n") - self.assertEqual(b64_types, {"audio": "you"}) + self.assertEqual(prompt, "System Prompt\nhello, \nASSISTANT: opea project! \nUSER: chao, \n") async def test_handle_request(self): json_data = { @@ -252,70 +205,9 @@ async def test_handle_request(self): res = json.loads(res.json()) self.assertEqual( res["choices"][-1]["message"]["content"], - "USER: \nhello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:", + "\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:", ) - def test_interleaved_image_handle_message(self): - """ - This tests a back and forth conversation with images interleaved with different models that have different prompt - formats than the default LLaVA 1.5 model. 
- """ - - # Models to test and their expected prompts - model_names = ["llava-hf/llava-interleave-qwen-7b-hf", - "llava-hf/llava-v1.6-mistral-7b-hf", - "llava-hf/llava-v1.6-vicuna-7b-hf"] - expected_prompts = ["\nDescribe the image.\n<|im_end|><|im_start|>assistant It is an image of a red apple with a green leaf\n<|im_start|>user \nIs this the same type of fruit?\n", - "\nDescribe the image.\n [/INST] It is an image of a red apple with a green leaf\n[INST] \nIs this the same type of fruit?\n", - "\nDescribe the image.\nASSISTANT: It is an image of a red apple with a green leaf\nUSER: \nIs this the same type of fruit?\n"] - gateway_port = 9988 - - for model_name, expected_prompt in zip(model_names, expected_prompts): - # Simulate running gateway with the specified model - lvm_model = os.environ["LVM_MODEL_ID"] = model_name - test_gateway = MultimodalQnAGateway(self.service_builder, self.follow_up_query_service_builder, port=gateway_port) - gateway_port += 1 - - messages = [ - { - "role": "user", - "content": [ - { - "type": "text", "text": "Describe the image." - }, - { - "type": "image_url", - "image_url": {"url": "https://raw.githubusercontent.com/docarray/docarray/refs/heads/main/tests/toydata/image-data/apple.png"}, - }, - ], - }, - { - "role": "assistant", - "content": "It is an image of a red apple with a green leaf" - }, - { - "role": "user", - "content": [ - { - "type": "text", "text": "Is this the same type of fruit?" 
- }, - { - "type": "image_url", - "image_url": {"url": "http://images.cocodataset.org/test-stuff2017/000000004248.jpg"}, - }, - ], - }, - ] - try: - prompt, b64_types = test_gateway._handle_message(messages) - self.assertEqual(prompt, expected_prompt, - "The generated prompt does not match the expected prompt for {} \nActual:\n{}\nExpected:\n{}".format(model_name, repr(prompt), repr(expected_prompt))) - self.assertTrue("image" in b64_types.keys()) - self.assertFalse("audio" in b64_types.keys()) - self.assertEqual(len(b64_types["image"]), 2) - finally: - test_gateway.stop() - if __name__ == "__main__": unittest.main() From bac117a01cf7a964dea86d9553182c848969185e Mon Sep 17 00:00:00 2001 From: dmsuehir Date: Fri, 13 Dec 2024 11:28:04 -0800 Subject: [PATCH 27/27] Fix retriever test for checking for b64_img_str in the result Signed-off-by: dmsuehir --- .../test_retrievers_multimodal_redis_langchain.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/retrievers/test_retrievers_multimodal_redis_langchain.sh b/tests/retrievers/test_retrievers_multimodal_redis_langchain.sh index bd256e6e05..06fecec69d 100644 --- a/tests/retrievers/test_retrievers_multimodal_redis_langchain.sh +++ b/tests/retrievers/test_retrievers_multimodal_redis_langchain.sh @@ -67,10 +67,10 @@ function validate_microservice() { if echo "$CONTENT" | grep -q "retrieved_docs"; then echo "[ retriever ] Content has retrieved_docs as expected." - if echo "$CONTENT" | grep -q "retrieved_docs"; then - echo "[ retriever ] Content has img_b64_str as expected." + if echo "$CONTENT" | grep -q "b64_img_str"; then + echo "[ retriever ] Content has b64_img_str as expected." else - echo "[ retriever ] Content does not include the img_b64_str: $CONTENT" + echo "[ retriever ] Content does not include the b64_img_str: $CONTENT" docker logs test-comps-retriever-multimodal-redis >> ${LOG_PATH}/retriever.log exit 1 fi