From c96c1a478debbf7c63c0b00a1cf39d9f713e922c Mon Sep 17 00:00:00 2001 From: akshaypardhanani Date: Wed, 13 Aug 2025 20:52:28 +0100 Subject: [PATCH 1/2] fix: parsing vlm output when it contains multiple json objects --- amadeusgpt/analysis_objects/llm.py | 30 ++++++++++++++++++++----- amadeusgpt/system_prompts/visual_llm.py | 9 ++++++++ pyproject.toml | 1 + tests/test_superanimal.py | 2 +- 4 files changed, 36 insertions(+), 6 deletions(-) diff --git a/amadeusgpt/analysis_objects/llm.py b/amadeusgpt/analysis_objects/llm.py index a7421aa..2d7d021 100644 --- a/amadeusgpt/analysis_objects/llm.py +++ b/amadeusgpt/analysis_objects/llm.py @@ -10,8 +10,10 @@ import numpy as np import openai from openai import OpenAI +from pydantic import ValidationError from amadeusgpt.programs.sandbox import Sandbox +from amadeusgpt.system_prompts.visual_llm import VlmInferenceOutput from amadeusgpt.utils import AmadeusLogger, QA_Message, create_qa_message from amadeusgpt.utils.openai_adapter import OpenAIAdapter @@ -267,13 +269,31 @@ def speak(self, sandbox: Sandbox, image: np.ndarray): print("description of the image frame provided") print(text) + thinking_pattern = r'.*?' + output_text = re.sub(thinking_pattern, '', text, flags=re.DOTALL) + + print(f"output text after removing thinking: {output_text}") + pattern = r"```json(.*?)```" - if len(re.findall(pattern, text, re.DOTALL)) == 0: - raise ValueError("can't parse the json string correctly", text) + if len(re.findall(pattern, output_text, re.DOTALL)) == 0: + raise ValueError("can't parse the json string correctly", output_text) else: - json_string = re.findall(pattern, text, re.DOTALL)[0] - json_obj = json.loads(json_string) - return json_obj + results = [] + for response_json in re.findall(pattern, output_text, re.DOTALL): + try: + json_obj = json.loads(response_json) + VlmInferenceOutput.model_validate(json_obj) + results.append(json_obj) + except ValidationError as val_err: + print(f"Couldn't validate the json string correctly for {response_json}", val_err) + except Exception as e: + print(f"Couldn't parse the json string correctly for {response_json}", e) + raise e + if len(results) == 0: + raise ValueError("can't parse the json string correctly", output_text) + elif len(results) > 1: + print("WARNING!! Found multiple json strings. Returning only the first", results) + return results[0] class CodeGenerationLLM(LLM): diff --git a/amadeusgpt/system_prompts/visual_llm.py b/amadeusgpt/system_prompts/visual_llm.py index 7637a7c..a1945b7 100644 --- a/amadeusgpt/system_prompts/visual_llm.py +++ b/amadeusgpt/system_prompts/visual_llm.py @@ -1,3 +1,12 @@ +from pydantic import BaseModel +from typing import List, Literal + +class VlmInferenceOutput(BaseModel): + description: str + individuals: int + species: Literal["topview_mouse", "sideview_quadruped", "others"] + background_objects: List[str] + def _get_system_prompt(): system_prompt = """ Describe what you see in the image and fill in the following json string: diff --git a/pyproject.toml b/pyproject.toml index e9f44dc..fb6ac70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "matplotlib<3.9", "openai>=1.0", "opencv-python-headless>=4.11.0.86", + "pydantic>=2.11.7", "pyyaml>=6.0.2", "sentence-transformers>=5.1.0", "streamlit>=1.26.0", diff --git a/tests/test_superanimal.py b/tests/test_superanimal.py index 03a738d..d05a4e6 100644 --- a/tests/test_superanimal.py +++ b/tests/test_superanimal.py @@ -7,7 +7,7 @@ def test_superanimal(): # the dummy video only contains 6 frames. kwargs = { 'video_info.scene_frame_number': 1, - 'llm_info.gpt_model': "gpt-4o" + 'llm_info.gpt_model': "moonshotai/kimi-vl-a3b-thinking:free" } data_folder = "examples/DummyVideo" result_folder = "temp_result_folder" From 7f6ac93945acc7877d3bf227cac955d83a5d86ef Mon Sep 17 00:00:00 2001 From: akshaypardhanani Date: Wed, 13 Aug 2025 21:24:05 +0100 Subject: [PATCH 2/2] fix: change model for superanimal since kimi activates too few parameters to detect the mouse --- tests/test_superanimal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_superanimal.py b/tests/test_superanimal.py index d05a4e6..c2cc9ff 100644 --- a/tests/test_superanimal.py +++ b/tests/test_superanimal.py @@ -7,7 +7,7 @@ def test_superanimal(): # the dummy video only contains 6 frames. kwargs = { 'video_info.scene_frame_number': 1, - 'llm_info.gpt_model': "moonshotai/kimi-vl-a3b-thinking:free" + 'llm_info.gpt_model': "qwen/qwen2.5-vl-72b-instruct:free" } data_folder = "examples/DummyVideo" result_folder = "temp_result_folder"