Skip to content

Commit 3c185e8

Browse files
authored
feat(VisualReplayStrategy): compute image similarity to avoid unnecessary segmentation
* remove sct_image from Screenshot; fix typo * add Image.cropped_image * add experiments/imagesimilarity.py * bugfix: sct_image -> image * find_similar_image_segmentation * fix test_crop_active_window
1 parent 31ae83f commit 3c185e8

File tree

9 files changed

+572
-59
lines changed

9 files changed

+572
-59
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ poetry install
9797
poetry shell
9898
alembic upgrade head
9999
poetry run install-dashboard
100-
101100
pytest
102101
```
103102

experiments/imagesimilarity.py

Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,351 @@
1+
"""This module calculates image similarities using various methods."""
2+
3+
from typing import Callable
4+
import time
5+
6+
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
7+
from PIL import Image, ImageOps
8+
from skimage.metrics import structural_similarity as ssim
9+
from sklearn.manifold import MDS
10+
import imagehash
11+
import matplotlib.gridspec as gridspec
12+
import matplotlib.pyplot as plt
13+
import numpy as np
14+
15+
from openadapt.db import crud
16+
17+
18+
# Set to True to display a matplotlib figure for every SSIM comparison.
SHOW_SSIM = False

# Both inputs are converted to 8-bit grayscale, so the full dynamic range is
# fixed. Using the dtype range (rather than the min/max of one of the images)
# keeps the metric symmetric and well defined for constant images.
_SSIM_DATA_RANGE = 255


def _show_ssim_figure(
    im1: Image.Image,
    im2: Image.Image,
    grad: np.ndarray,
    S: np.ndarray,
) -> None:
    """Display both images next to the SSIM gradient and the full SSIM map."""
    # Normalize the gradient for visualization; a flat gradient has zero range
    # and would otherwise cause a division by zero.
    grad_range = grad.max() - grad.min()
    if grad_range > 0:
        grad_normalized = (grad - grad.min()) / grad_range
    else:
        grad_normalized = np.zeros_like(grad, dtype=float)
    im_grad = Image.fromarray((grad_normalized * 255).astype(np.uint8))

    # Local SSIM values may be negative; clip before casting to uint8 so they
    # do not wrap around to bright pixels.
    im_S = Image.fromarray((np.clip(S, 0, 1) * 255).astype(np.uint8))

    # 1 row, 4 columns
    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    panels = (
        (im1, "Image 1"),
        (im2, "Image 2"),
        (im_grad, "Gradient of SSIM"),
        (im_S, "SSIM Image"),
    )
    for ax, (panel, panel_title) in zip(axs, panels):
        ax.imshow(panel, cmap="gray")
        ax.set_title(panel_title)
        ax.axis("off")

    plt.show(block=False)


def calculate_ssim(im1: Image.Image, im2: Image.Image) -> float:
    """Calculate an SSIM-based distance between two images.

    Both images are resized to a common size, converted to 8-bit grayscale and
    compared using the Structural Similarity Index (SSIM). If ``SHOW_SSIM`` is
    set, a figure with the inputs, the SSIM gradient and the SSIM map is shown.

    Args:
        im1 (Image.Image): The first image.
        im2 (Image.Image): The second image.

    Returns:
        float: ``1 - SSIM``, i.e. a dissimilarity: 0.0 for identical images,
            growing as the images become less similar.
    """
    # Calculate aspect ratios
    aspect_ratio1 = im1.width / im1.height
    aspect_ratio2 = im2.width / im2.height
    # Use the smaller image as the base for resizing to maintain the aspect ratio
    if aspect_ratio1 < aspect_ratio2:
        base_width = min(im1.width, im2.width)
        base_height = int(base_width / aspect_ratio1)
    else:
        base_height = min(im1.height, im2.height)
        base_width = int(base_height * aspect_ratio2)

    # Resize images to a common base so they can be compared pixel by pixel
    im1 = im1.resize((base_width, base_height), Image.LANCZOS)
    im2 = im2.resize((base_width, base_height), Image.LANCZOS)

    # Convert images to grayscale
    im1_gray = np.array(im1.convert("L"))
    im2_gray = np.array(im2.convert("L"))

    if SHOW_SSIM:
        # The gradient and the full SSIM map are only needed for the figure.
        mssim, grad, S = ssim(
            im1_gray,
            im2_gray,
            data_range=_SSIM_DATA_RANGE,
            gradient=True,
            full=True,
        )
        _show_ssim_figure(im1, im2, grad, S)
    else:
        mssim = ssim(im1_gray, im2_gray, data_range=_SSIM_DATA_RANGE)

    return float(1 - mssim)
89+
90+
91+
def _absolute_pixel_difference(im1: Image.Image, im2: Image.Image) -> np.ndarray:
    """Return the element-wise absolute difference between two images."""
    # Widen before subtracting: 8-bit image arrays are unsigned, so a plain
    # ``arr1 - arr2`` wraps around (e.g. 10 - 20 == 246) instead of going negative.
    arr1 = np.asarray(im1).astype(np.float64)
    arr2 = np.asarray(im2).astype(np.float64)
    return np.abs(arr1 - arr2)


def _threshold_from_difference(diff: np.ndarray, k: float) -> float:
    """Return the mean of ``diff`` plus ``k`` times its standard deviation."""
    return float(np.mean(diff) + k * np.std(diff))


def calculate_dynamic_threshold(
    im1: Image.Image,
    im2: Image.Image,
    k: float = 1.0,
) -> float:
    """Calculate a dynamic threshold for image difference.

    Based on the mean and standard deviation of the absolute pixel differences.
    The two images must convert to arrays of the same shape.

    Args:
        im1 (Image.Image): The first image.
        im2 (Image.Image): The second image.
        k (float): The multiplier for the standard deviation to set the threshold.

    Returns:
        float: The dynamically calculated threshold (mean + k * std).
    """
    return _threshold_from_difference(_absolute_pixel_difference(im1, im2), k)


def thresholded_difference(im1: Image.Image, im2: Image.Image, k: float = 1.0) -> int:
    """Return number of pixels differing by at least a dynamically calculated threshold.

    Both images are first resized to the smaller width and the smaller height
    of the pair so that they can be compared element by element.

    Args:
        im1 (Image.Image): The first image.
        im2 (Image.Image): The second image.
        k (float): Multiplier for the standard deviation to set the dynamic threshold.

    Returns:
        int: The number of array elements differing by at least the dynamically
            calculated threshold. Identical images yield 0.
    """
    common_size = (min(im1.width, im2.width), min(im1.height, im2.height))
    im1 = im1.resize(common_size)
    im2 = im2.resize(common_size)

    # Compute the difference once and derive the threshold from it.
    diff = _absolute_pixel_difference(im1, im2)
    difference_threshold = _threshold_from_difference(diff, k)

    # Elements that do not differ at all never count; otherwise identical
    # images (threshold 0) would report every pixel as different.
    differing = (diff >= difference_threshold) & (diff > 0)

    return int(np.count_nonzero(differing))
155+
156+
157+
def prepare_image(
    img: Image.Image,
    size: tuple[int, int] = (128, 128),
    border: int = 2,
    color: str = "red",
) -> Image.Image:
    """Resize an image to a common size, add a border to it.

    Args:
        img (Image.Image): The original image to prepare.
        size (tuple[int, int]): The size to which the images should be resized.
        border (int): The width of the border around the image.
        color (str): The color of the border.

    Returns:
        Image.Image: The processed image.
    """
    # Image.ANTIALIAS was an alias of LANCZOS and has been removed from Pillow;
    # LANCZOS is the same filter and is what the rest of this module uses.
    img = img.resize(size, Image.LANCZOS)

    # Add border to the image
    return ImageOps.expand(img, border=border, fill=color)
181+
182+
183+
def plot_images_with_mds(
    images: list[Image.Image],
    distance_matrix: np.ndarray,
    title: str,
    hash_func: Callable,
) -> None:
    """Lay out image thumbnails in 2D according to their pairwise distances.

    The distance matrix is embedded in two dimensions with MDS and each image
    is drawn at its embedded position, labelled with its hash value.

    Args:
        images (list[Image.Image]): Images to plot.
        distance_matrix (np.ndarray): Precomputed pairwise image distances.
        title (str): Title of the plot.
        hash_func (Callable): Function used to compute the hash label of each
            image; a falsy value results in empty labels.

    Returns:
        None
    """
    # Bordered, uniformly sized versions of the images for display
    thumbnails = [prepare_image(image) for image in images]

    # One text label per image
    if hash_func:
        labels = [str(hash_func(image)) for image in images]
    else:
        labels = ["" for _ in images]

    # 2D embedding of the precomputed distances
    embedding = MDS(
        n_components=2,
        dissimilarity="precomputed",
        random_state=0,
    ).fit_transform(distance_matrix)

    # Invisible scatter points so the axes limits cover all positions
    _, axes = plt.subplots(figsize=(15, 10))
    axes.scatter(embedding[:, 0], embedding[:, 1], alpha=0)

    # Frame drawn around every thumbnail
    frame_style = {
        "boxstyle": "round,pad=0.3",
        "ec": "b",
        "lw": 2,
        "fc": "white",
        "alpha": 0.7,
    }

    for thumbnail, label, point in zip(thumbnails, labels, embedding):
        x, y = point
        annotation = AnnotationBbox(
            OffsetImage(np.array(thumbnail), zoom=0.5),
            (x, y),
            xycoords="data",
            frameon=True,
            bboxprops=frame_style,
        )
        axes.add_artist(annotation)
        # Hash label slightly below the image position
        axes.text(x, y - 0.05, label, fontsize=9, ha="center")

    # Tick values of the embedding carry no meaning
    axes.set_xticks([])
    axes.set_yticks([])

    plt.title(title)
    plt.show()
237+
238+
239+
def display_distance_matrix_with_images(
    distance_matrix: np.ndarray,
    images: list[Image.Image],
    func_name: str,
    thumbnail_size: tuple[int, int] = (32, 32),
) -> None:
    """Display the distance matrix as an image with thumbnails along the top and left.

    Args:
        distance_matrix (np.ndarray): A square matrix with distance values.
        images (list[Image.Image]): list of images corresponding to matrix rows/cols.
        func_name (str): Name of the similarity function, used as figure title.
        thumbnail_size (tuple[int, int]): Size to which thumbnails will be resized.

    Returns:
        None
    """
    # Number of images
    n = len(images)
    fig = plt.figure(figsize=(10, 10))
    # One extra row and column to hold the thumbnails
    gs = gridspec.GridSpec(n + 1, n + 1, figure=fig)

    # Place the distance matrix in everything except the first row and column
    ax_matrix = fig.add_subplot(gs[1:, 1:])
    ax_matrix.imshow(distance_matrix, cmap="viridis")
    ax_matrix.set_xticks([])
    ax_matrix.set_yticks([])

    # Annotate each cell with the distance value
    for (i, j), val in np.ndenumerate(distance_matrix):
        ax_matrix.text(j, i, f"{val:.4f}", ha="center", va="center", color="white")

    # Image.ANTIALIAS was an alias of LANCZOS and has been removed from Pillow;
    # LANCZOS is the same filter and is what the rest of this module uses.
    thumbnails = [img.resize(thumbnail_size, Image.LANCZOS) for img in images]

    # Each thumbnail appears twice: above its column and left of its row
    for i, img in enumerate(thumbnails):
        pixels = np.array(img)
        for cell in (gs[0, i + 1], gs[i + 1, 0]):
            ax_img = fig.add_subplot(cell)
            ax_img.imshow(pixels)
            ax_img.axis("off")  # Hide axes

    plt.suptitle(func_name)
    plt.show()
289+
290+
291+
def main() -> None:
    """Main function to process images and display similarity metrics.

    Loads the cropped screenshots of the latest recording, computes all
    pairwise distances with each similarity function, prints the distance
    matrix and the mean time per comparison, and plots the results.
    """
    recording = crud.get_latest_recording()
    action_events = recording.processed_action_events
    images = [action_event.screenshot.cropped_image for action_event in action_events]

    n = len(images)
    if n < 2:
        # Without at least one pair there is nothing to compare, and the mean
        # duration below would divide by zero.
        print(f"Need at least two images to compare, found {n}")
        return

    # Hash functions are also passed to the MDS plot to label each image.
    hash_funcs = {
        "average_hash": imagehash.average_hash,
        "dhash": imagehash.dhash,
        "phash": imagehash.phash,
        "crop_resistant_hash": imagehash.crop_resistant_hash,
        "colorhash": imagehash.colorhash,
        "whash": imagehash.whash,
    }

    def hash_distance(hash_func: Callable) -> Callable:
        """Return a function computing the difference between two image hashes."""
        return lambda im1, im2: hash_func(im1) - hash_func(im2)

    similarity_funcs = {
        "ssim": calculate_ssim,
        "thresholded_difference": thresholded_difference,
        **{name: hash_distance(hash_func) for name, hash_func in hash_funcs.items()},
    }

    # Process each similarity function
    for func_name, func in similarity_funcs.items():
        # None for the functions that are not hash based
        hash_func = hash_funcs.get(func_name)

        # Create a matrix to store all pairwise distances
        distance_matrix = np.zeros((n, n))
        durations = []
        for i in range(n):
            for j in range(i + 1, n):
                # perf_counter is monotonic and intended for measuring intervals
                start_time = time.perf_counter()
                distance = abs(func(images[i], images[j]))
                durations.append(time.perf_counter() - start_time)
                # Only the upper triangle is computed; mirror it
                distance_matrix[i, j] = distance
                distance_matrix[j, i] = distance
        mean_duration = sum(durations) / len(durations)
        print(f"{func_name=}")
        print(f"distance_matrix=\n{distance_matrix}")
        print(f"{mean_duration=}")
        display_distance_matrix_with_images(distance_matrix, images, func_name)
        plot_images_with_mds(
            images,
            distance_matrix,
            f"Image layout based on {func_name} ({mean_duration=:.4f}s)",
            hash_func,
        )


if __name__ == "__main__":
    main()

openadapt/events.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,13 +357,13 @@ def get_timestamp_mappings(
357357
"double_click_distance_pixels",
358358
utils.get_double_click_distance_pixels,
359359
)
360-
logger.info(f"{double_click_distance=}")
360+
logger.debug(f"{double_click_distance=}")
361361
double_click_interval = get_recording_attr(
362362
to_merge[0],
363363
"double_click_interval_seconds",
364364
utils.get_double_click_interval_seconds,
365365
)
366-
logger.info(f"{double_click_interval=}")
366+
logger.debug(f"{double_click_interval=}")
367367
press_to_press_t = {}
368368
press_to_release_t = {}
369369
prev_pressed_event = None
@@ -770,7 +770,7 @@ def discard_unused_events(
770770
]
771771
num_referred_events_after = len(referred_events)
772772
num_referred_events_removed = num_referred_events_before - num_referred_events_after
773-
logger.info(f"{referred_timestamp_key=} {num_referred_events_removed=}")
773+
logger.debug(f"{referred_timestamp_key=} {num_referred_events_removed=}")
774774
return referred_events
775775

776776

0 commit comments

Comments
 (0)