Skip to content

Commit 588494c

Browse files
authored
feat(contrib.som, postinstall): add contrib.som; postinstall.py; other fixes (#679)
* add vision.get_similar_image_idxs and test_vision.py * refactor fixtures * add get_size_similarity and tests; update default params in adapters.ultralytics * get_image_similarity rgb * add caching; win_size * short_circuit_ssim; plot_similar_image_groups * refactor: adding plotting.py * fix test; typo * bug fixes; title_params * add tests/openadapt/adapters * bugfix: imports * modify prompt adapter API to accept images instead of base64 images * add experiments/handle_similar_segments.py * add highlight_masks; increase_contrast * add experiments/gpt4o_seg.py; normalize_positions; clean_data; filter_keys; modify get_scale_ratios to accept no arguments; modify crop_active_window to accept window event; fix bugs * ultralytics sam * add experiments/nms.py * add fastsamsom.py, visualizer.py * working fastsamsom.py * cleanup * black; flake8; postinstall.py; test_vision.py * restore utils.get_performance_plot_file_path/get_performance_plot_file_path * fix ActionEvent.from_dict * handle HTTPError in test_openai * utils.split_by_separators * other fixes * remove list python files step * consolidate install-dashboard and postinstall * install_dashboard from dashboard_dir * postinstall error logging
1 parent 4d357f0 commit 588494c

31 files changed

+3400
-612
lines changed

.flake8

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
[flake8]
2-
exclude =
3-
alembic/versions,
4-
.venv
2+
exclude = alembic,.venv,venv,contrib,.cache,.git
53
docstring-convention = google
64
max-line-length = 88
75
extend-ignore = ANN101, E203

.github/workflows/main.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
if: steps.cache-deps.outputs.cache-hit == 'true'
6161

6262
- name: Check formatting with Black
63-
run: poetry run black --preview --check . --exclude '/(alembic|\.venv)/'
63+
run: poetry run black --preview --check . --exclude '/(alembic|\.cache|\.venv|venv|contrib|__pycache__)/'
6464

6565
- name: Run Flake8
66-
run: poetry run flake8 --exclude=alembic,.venv,*/.cache
66+
run: poetry run flake8 --exclude=alembic,.venv,venv,contrib,.cache,.git

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ pip3 install poetry
9999
poetry install
100100
poetry shell
101101
poetry run install-dashboard
102+
poetry run postinstall
102103
cd openadapt && alembic upgrade head && cd ..
103104
pytest
104105
```

experiments/fastsamsom.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""SoM with Ultralytics FastSAM."""
2+
3+
from pprint import pformat
4+
5+
from loguru import logger
6+
from PIL import Image
7+
import numpy as np
8+
9+
from openadapt import adapters, config, contrib, utils, vision
10+
11+
12+
CONTRAST_FACTOR = 10000
13+
DEBUG = False
14+
15+
16+
def main() -> None:
17+
"""Main."""
18+
image_file_path = config.ROOT_DIR_PATH / "../tests/assets/excel.png"
19+
image = Image.open(image_file_path)
20+
if DEBUG:
21+
image.show()
22+
23+
image_contrasted = utils.increase_contrast(image, CONTRAST_FACTOR)
24+
if DEBUG:
25+
image_contrasted.show()
26+
27+
segmentation_adapter = adapters.get_default_segmentation_adapter()
28+
segmented_image = segmentation_adapter.fetch_segmented_image(image)
29+
if DEBUG:
30+
segmented_image.show()
31+
32+
masks = vision.get_masks_from_segmented_image(segmented_image, sort_by_area=True)
33+
# refined_masks = vision.refine_masks(masks)
34+
35+
image_arr = np.asarray(image)
36+
37+
# https://github.com/microsoft/SoM/blob/main/task_adapter/sam/tasks/inference_sam_m2m_auto.py
38+
# metadata = MetadataCatalog.get('coco_2017_train_panoptic')
39+
metadata = None
40+
visual = contrib.som.visualizer.Visualizer(image_arr, metadata=metadata)
41+
mask_map = np.zeros(image_arr.shape, dtype=np.uint8)
42+
label_mode = "1"
43+
alpha = 0.1
44+
anno_mode = [
45+
"Mask",
46+
# 'Mark',
47+
]
48+
for i, mask in enumerate(masks):
49+
label = i + 1
50+
demo = visual.draw_binary_mask_with_number(
51+
mask,
52+
text=str(label),
53+
label_mode=label_mode,
54+
alpha=alpha,
55+
anno_mode=anno_mode,
56+
)
57+
mask_map[mask == 1] = label
58+
59+
im = demo.get_image()
60+
image_som = Image.fromarray(im)
61+
image_som.show()
62+
63+
results = []
64+
65+
prompt_adapter = adapters.get_default_prompt_adapter()
66+
text = (
67+
"What are the values of the dates in the leftmost column? What about the"
68+
" horizontal column headings?"
69+
)
70+
output = prompt_adapter.prompt(
71+
text,
72+
images=[
73+
# no marks seem to perform just as well as with marks on spreadsheets
74+
# image_som,
75+
image,
76+
],
77+
)
78+
logger.info(output)
79+
results.append((text, output))
80+
81+
text = "\n".join(
82+
[
83+
(
84+
"Consider the dates along the leftmost column and the horizontal"
85+
" column headings:"
86+
),
87+
output,
88+
"What are the values in the corresponding cells?",
89+
]
90+
)
91+
output = prompt_adapter.prompt(text, images=[image_som])
92+
logger.info(output)
93+
results.append((text, output))
94+
95+
text = "What are the contents of cells A2, B2, and C2?"
96+
output = prompt_adapter.prompt(text, images=[image_som])
97+
logger.info(output)
98+
results.append((text, output))
99+
100+
logger.info(f"results=\n{pformat(results)}")
101+
102+
103+
if __name__ == "__main__":
104+
main()

experiments/gpt4o_seg.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""Generate segmentations directly with LLM."""
2+
3+
from pprint import pformat
4+
import os
5+
import sys
6+
import time
7+
8+
from loguru import logger
9+
from PIL import Image
10+
11+
from openadapt import cache, config, models, plotting, utils
12+
from openadapt.adapters import openai
13+
14+
15+
@cache.cache(force_refresh=False)
16+
def get_window_image(window_search_str: str) -> tuple:
17+
"""Get window image."""
18+
logger.info(f"Waiting for window with title containing {window_search_str=}...")
19+
while True:
20+
window_event = models.WindowEvent.get_active_window_event()
21+
window_title = window_event.title
22+
if window_search_str.lower() in window_title.lower():
23+
logger.info(f"found {window_title=}")
24+
break
25+
time.sleep(0.1)
26+
27+
screenshot = models.Screenshot.take_screenshot()
28+
image = screenshot.crop_active_window(window_event=window_event)
29+
return window_event, image
30+
31+
32+
def main(window_search_str: str | None) -> None:
33+
"""Main."""
34+
if window_search_str:
35+
window_event, image = get_window_image(window_search_str)
36+
window_dict = window_event.to_prompt_dict()
37+
window_dict = utils.normalize_positions(
38+
window_dict, -window_event.left, -window_event.top
39+
)
40+
else:
41+
image_file_path = os.path.join(
42+
config.ROOT_DIR_PATH, "../tests/assets/calculator.png"
43+
)
44+
image = Image.open(image_file_path)
45+
window_dict = None
46+
47+
system_prompt = utils.render_template_from_file(
48+
"prompts/system.j2",
49+
)
50+
51+
if window_dict:
52+
window_prompt = (
53+
f"Consider the corresponding window state:\n```{pformat(window_dict)}```"
54+
)
55+
else:
56+
window_prompt = ""
57+
58+
prompt = f"""You are a master GUI understander.
59+
Your task is to locate all interactable elements in the supplied screenshot.
60+
{window_prompt}
61+
Return JSON containing an array of segments with the following properties:
62+
- "name": a unique identifier
63+
- "description": enough context to be able to differentiate between similar segments
64+
- "top": top coordinate of bounding box
65+
- "left": left coordinate of bounding box
66+
- "width": width of bouding box
67+
- "height": height of bounding box
68+
Provide as much detail as possible. My career depends on this. Lives are at stake.
69+
Respond with JSON ONLY AND NOTHING ELSE.
70+
"""
71+
72+
result = openai.prompt(
73+
prompt,
74+
system_prompt,
75+
[image],
76+
)
77+
segment_dict = utils.parse_code_snippet(result)
78+
plotting.plot_segments(image, segment_dict)
79+
80+
window_dict = window_event.to_prompt_dict()
81+
import ipdb
82+
83+
ipdb.set_trace()
84+
85+
86+
if __name__ == "__main__":
87+
main(sys.argv[1])
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""Handle similar segments."""
2+
3+
import os
4+
5+
from PIL import Image
6+
from loguru import logger
7+
8+
from openadapt import adapters, cache, config, plotting, utils, vision
9+
10+
11+
DEBUG = True
12+
MIN_SEGMENT_SSIM = 0.95 # threshold for considering segments structurally similar
13+
MIN_SEGMENT_SIZE_SIM = 0.95 # threshold for considering segment sizes similar
14+
15+
16+
# TODO: consolidate with strategies.visual.get_window_segmentation
17+
@cache.cache(enabled=not DEBUG)
18+
def get_similar_segment_groups(
19+
image_file_path: str,
20+
min_segment_ssim: float = MIN_SEGMENT_SSIM,
21+
min_segment_size_sim: float = MIN_SEGMENT_SIZE_SIM,
22+
show_images: bool = DEBUG,
23+
contrast_factor: int = 10000,
24+
) -> tuple:
25+
"""Get similar segment groups."""
26+
image = Image.open(image_file_path)
27+
image.show()
28+
29+
if contrast_factor:
30+
image = utils.increase_contrast(image, contrast_factor)
31+
image.show()
32+
33+
segmentation_adapter = adapters.get_default_segmentation_adapter()
34+
segmented_image = segmentation_adapter.fetch_segmented_image(image)
35+
if show_images:
36+
segmented_image.show()
37+
38+
import ipdb
39+
40+
ipdb.set_trace()
41+
42+
masks = vision.get_masks_from_segmented_image(segmented_image)
43+
logger.info(f"{len(masks)=}")
44+
if show_images:
45+
plotting.display_binary_images_grid(masks)
46+
47+
refined_masks = vision.refine_masks(masks)
48+
logger.info(f"{len(refined_masks)=}")
49+
if show_images:
50+
plotting.display_binary_images_grid(refined_masks)
51+
52+
masked_images = vision.extract_masked_images(image, refined_masks)
53+
descriptions = ["" for _ in masked_images]
54+
if show_images:
55+
plotting.display_images_table_with_titles(masked_images, descriptions)
56+
57+
similar_idx_groups, ungrouped_idxs, ssim_matrix, _ = vision.get_similar_image_idxs(
58+
masked_images,
59+
min_segment_ssim,
60+
min_segment_size_sim,
61+
)
62+
logger.info(f"{len(similar_idx_groups)=}")
63+
64+
return (
65+
image,
66+
masked_images,
67+
refined_masks,
68+
similar_idx_groups,
69+
ungrouped_idxs,
70+
ssim_matrix,
71+
)
72+
73+
74+
def main() -> None:
75+
"""Main."""
76+
image_file_path = os.path.join(config.ROOT_DIR_PATH, "../tests/assets/excel.png")
77+
78+
MAX_GROUPS = 2
79+
80+
for min_segment_ssim in (MIN_SEGMENT_SSIM, MIN_SEGMENT_SSIM // 3):
81+
logger.info(f"{min_segment_ssim=}")
82+
image, masked_images, masks, similar_idx_groups, ungrouped_idxs, ssim_matrix = (
83+
get_similar_segment_groups(image_file_path)
84+
)
85+
similar_idx_groups = sorted(
86+
similar_idx_groups,
87+
key=lambda group: len(group),
88+
reverse=True,
89+
)
90+
if MAX_GROUPS:
91+
similar_idx_groups = similar_idx_groups[:MAX_GROUPS]
92+
plotting.plot_similar_image_groups(
93+
masked_images,
94+
similar_idx_groups,
95+
ssim_matrix,
96+
[
97+
f"min_ssim={MIN_SEGMENT_SSIM}",
98+
f"min_size_sim={MIN_SEGMENT_SIZE_SIM}",
99+
],
100+
)
101+
102+
"""
103+
- images:
104+
- original
105+
- one segment mask
106+
- multiple segment masks
107+
- original with one segment highlighted
108+
- original with multiple segments highlighted
109+
- original with one segment labelled
110+
- original with multiple segments labelled
111+
- original with one segment highlighted+labelled
112+
- original with multiple segments highlighted+labelled
113+
- individual segment
114+
- individual segment labelled
115+
- one or multiple segments per prompt
116+
"""
117+
for similar_idx_group in similar_idx_groups:
118+
similar_masks = [masks[idx] for idx in similar_idx_group]
119+
highlighted_image = plotting.highlight_masks(image, similar_masks)
120+
highlighted_image.show()
121+
122+
import ipdb
123+
124+
ipdb.set_trace()
125+
foo = 1 # noqa
126+
127+
128+
if __name__ == "__main__":
129+
main()

install/install_openadapt.ps1

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ RunAndCheck "git clone -q https://github.com/MLDSAI/OpenAdapt.git" "clone git re
360360
Set-Location .\OpenAdapt
361361
RunAndCheck "pip install poetry" "Run ``pip install poetry``"
362362
RunAndCheck "poetry install" "Run ``poetry install``"
363-
RunAndCheck "poetry run install-dashboard" "Install dashboard dependencies" -SkipCleanup:$true
363+
RunAndCheck "poetry run postinstall" "Install other dependencies" -SkipCleanup:$true
364364
RunAndCheck "cd openadapt"
365365
RunAndCheck "poetry run alembic upgrade head" "Run ``alembic upgrade head``" -SkipCleanup:$true
366366
RunAndCheck "cd .."

install/install_openadapt.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ RunAndCheck "git checkout $BRANCH" "Checkout branch $BRANCH"
157157

158158
RunAndCheck "pip3.10 install poetry" "Install Poetry"
159159
RunAndCheck "poetry install" "Install Python dependencies"
160-
RunAndCheck "poetry run install-dashboard" "Install dashboard dependencies"
160+
RunAndCheck "poetry run postinstall" "Install other dependencies"
161161
RunAndCheck "cd openadapt"
162162
RunAndCheck "poetry run alembic upgrade head" "Update database"
163163
RunAndCheck "cd .."

0 commit comments

Comments
 (0)