|
| 1 | +"""This module calculates image similarities using various methods.""" |
| 2 | + |
| 3 | +from typing import Callable |
| 4 | +import time |
| 5 | + |
| 6 | +from matplotlib.offsetbox import OffsetImage, AnnotationBbox |
| 7 | +from PIL import Image, ImageOps |
| 8 | +from skimage.metrics import structural_similarity as ssim |
| 9 | +from sklearn.manifold import MDS |
| 10 | +import imagehash |
| 11 | +import matplotlib.gridspec as gridspec |
| 12 | +import matplotlib.pyplot as plt |
| 13 | +import numpy as np |
| 14 | + |
| 15 | +from openadapt.db import crud |
| 16 | + |
| 17 | + |
# Debug switch: when True, calculate_ssim also opens a matplotlib figure showing
# the two resized images, the SSIM gradient and the per-pixel SSIM map.
SHOW_SSIM = False
| 19 | + |
| 20 | + |
def calculate_ssim(im1: Image.Image, im2: Image.Image) -> float:
    """Calculate an SSIM-based distance between two images.

    Both images are resized to one common size, converted to 8-bit grayscale
    and compared with the Structural Similarity Index (SSIM). When the
    module-level ``SHOW_SSIM`` flag is set, the inputs, the SSIM gradient and
    the per-pixel SSIM map are also displayed in a matplotlib figure.

    Args:
        im1 (Image.Image): The first image.
        im2 (Image.Image): The second image.

    Returns:
        float: ``1 - mean SSIM``, i.e. 0.0 for identical images and larger
            values for less similar images.
    """
    # Calculate aspect ratios
    aspect_ratio1 = im1.width / im1.height
    aspect_ratio2 = im2.width / im2.height

    # Derive one common size from the smaller dimension of the two images.
    # Both images are resized to exactly this size, so the image whose aspect
    # ratio was not chosen is stretched to fit.
    if aspect_ratio1 < aspect_ratio2:
        base_width = min(im1.width, im2.width)
        base_height = int(base_width / aspect_ratio1)
    else:
        base_height = min(im1.height, im2.height)
        base_width = int(base_height * aspect_ratio2)

    im1 = im1.resize((base_width, base_height), Image.LANCZOS)
    im2 = im2.resize((base_width, base_height), Image.LANCZOS)

    # Mode "L" yields 8-bit grayscale, so pixel values lie in [0, 255].
    im1_gray = np.array(im1.convert("L"))
    im2_gray = np.array(im2.convert("L"))

    # data_range must be the span of *possible* pixel values. Deriving it from
    # the observed min/max of one image makes the metric asymmetric and
    # collapses to 0 for a flat image.
    data_range = 255

    if not SHOW_SSIM:
        # Skip the gradient and full SSIM map when nothing will display them.
        mssim = ssim(im1_gray, im2_gray, data_range=data_range)
        return float(1 - mssim)

    mssim, grad, S = ssim(
        im1_gray,
        im2_gray,
        data_range=data_range,
        gradient=True,
        full=True,
    )

    # Normalize the gradient to [0, 1] for visualization; a constant gradient
    # has zero span and would otherwise divide by zero.
    grad_span = grad.max() - grad.min()
    if grad_span > 0:
        grad_normalized = (grad - grad.min()) / grad_span
    else:
        grad_normalized = np.zeros_like(grad, dtype=float)
    im_grad = Image.fromarray((grad_normalized * 255).astype(np.uint8))

    # Local SSIM values can be negative; clip before the uint8 cast so they
    # do not wrap around to bright pixels.
    im_S = Image.fromarray((np.clip(S, 0.0, 1.0) * 255).astype(np.uint8))

    # Create a figure to display the images: 1 row, 4 columns
    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    panels = (
        (im1, "Image 1"),
        (im2, "Image 2"),
        (im_grad, "Gradient of SSIM"),
        (im_S, "SSIM Image"),
    )
    for ax, (panel_image, panel_title) in zip(axs, panels):
        ax.imshow(panel_image, cmap="gray")
        ax.set_title(panel_title)
        ax.axis("off")

    plt.show(block=False)

    return float(1 - mssim)
| 89 | + |
| 90 | + |
def calculate_dynamic_threshold(
    im1: Image.Image,
    im2: Image.Image,
    k: float = 1.0,
) -> float:
    """Calculate a dynamic threshold for image difference.

    The threshold is the mean of the absolute per-element differences plus
    ``k`` times their standard deviation.

    Args:
        im1 (Image.Image): The first image.
        im2 (Image.Image): The second image; its array form must have the same
            shape as (or broadcast against) that of ``im1``.
        k (float): The multiplier for the standard deviation to set the threshold.

    Returns:
        float: The dynamically calculated threshold.
    """
    # Subtract in floating point: image data is typically uint8, and unsigned
    # subtraction wraps around (0 - 10 becomes 246) instead of going negative,
    # which would inflate both the mean and the standard deviation.
    arr1 = np.array(im1).astype(np.float64)
    arr2 = np.array(im2).astype(np.float64)

    # Absolute difference between the images
    diff = np.abs(arr1 - arr2)

    # Threshold is the mean plus k times the standard deviation
    return float(diff.mean() + k * diff.std())
| 123 | + |
| 124 | + |
def thresholded_difference(im1: Image.Image, im2: Image.Image, k: float = 1.0) -> int:
    """Return number of pixels differing by at least a dynamically calculated threshold.

    Both images are resized to the smaller width and the smaller height of the
    pair before being compared element by element.

    Args:
        im1 (Image.Image): The first image.
        im2 (Image.Image): The second image.
        k (float): Multiplier for the standard deviation to set the dynamic threshold.

    Returns:
        int: The number of array elements whose absolute difference is non-zero
            and at least the dynamically calculated threshold.
    """
    common_size = (min(im1.width, im2.width), min(im1.height, im2.height))

    # Subtract in floating point: uint8 subtraction wraps around (0 - 10
    # becomes 246), which would report small changes as huge differences.
    arr1 = np.array(im1.resize(common_size)).astype(np.float64)
    arr2 = np.array(im2.resize(common_size)).astype(np.float64)

    # Absolute difference, computed once and reused for threshold and count
    diff = np.abs(arr1 - arr2)

    # Dynamic threshold: mean plus k standard deviations of the differences
    # (the same formula calculate_dynamic_threshold uses).
    difference_threshold = diff.mean() + k * diff.std()

    # Require an actual difference as well: for identical images the threshold
    # is 0 and `diff >= 0` alone would count every single pixel.
    differing = (diff >= difference_threshold) & (diff > 0)

    return int(np.count_nonzero(differing))
| 155 | + |
| 156 | + |
def prepare_image(
    img: Image.Image,
    size: tuple[int, int] = (128, 128),
    border: int = 2,
    color: str = "red",
) -> Image.Image:
    """Resize an image to a common size, add a border to it.

    Args:
        img (Image.Image): The original image to prepare.
        size (tuple[int, int]): The size to which the images should be resized.
        border (int): The width of the border around the image.
        color (str): The color of the border.

    Returns:
        Image.Image: The processed image.
    """
    # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10;
    # LANCZOS selects the same filter and works on old and new Pillow alike.
    img = img.resize(size, Image.LANCZOS)

    # Add border to the image
    img_with_border = ImageOps.expand(img, border=border, fill=color)

    return img_with_border
| 181 | + |
| 182 | + |
def plot_images_with_mds(
    images: list[Image.Image],
    distance_matrix: np.ndarray,
    title: str,
    hash_func: Callable,
) -> None:
    """Lay out image thumbnails in 2D according to a pairwise distance matrix.

    The distance matrix is embedded into two dimensions with MDS, and each
    image is drawn as a bordered thumbnail at its embedded position, with its
    hash value as a text label.

    Args:
        images (list[Image.Image]): Images to plot.
        distance_matrix (np.ndarray): Precomputed pairwise distances between
            the images.
        title (str): Title of the plot.
        hash_func (Callable): Function producing a hash for an image; when it
            is falsy (e.g. None) the labels are left empty.

    Returns:
        None
    """
    # Bordered, uniformly sized thumbnails for display
    thumbnails = [prepare_image(image) for image in images]

    # Text label per image: the stringified hash, or nothing without a hasher
    labels = [str(hash_func(image)) if hash_func else "" for image in images]

    # Embed the precomputed distances into two dimensions
    embedder = MDS(n_components=2, dissimilarity="precomputed", random_state=0)
    coords = embedder.fit_transform(distance_matrix)

    fig, ax = plt.subplots(figsize=(15, 10))
    # Fully transparent markers: the thumbnails added below are what is seen
    ax.scatter(coords[:, 0], coords[:, 1], alpha=0)

    # Frame drawn around every thumbnail
    frame_style = {
        "boxstyle": "round,pad=0.3",
        "ec": "b",
        "lw": 2,
        "fc": "white",
        "alpha": 0.7,
    }

    for thumbnail, label, point in zip(thumbnails, labels, coords):
        x, y = point
        picture = OffsetImage(np.array(thumbnail), zoom=0.5)
        annotation = AnnotationBbox(
            picture,
            (x, y),
            xycoords="data",
            frameon=True,
            bboxprops=frame_style,
        )
        ax.add_artist(annotation)
        # Hash label, placed slightly below the thumbnail's anchor point
        ax.text(x, y - 0.05, label, fontsize=9, ha="center")

    # The embedded coordinates carry no meaning of their own, so hide ticks
    ax.set_xticks([])
    ax.set_yticks([])

    plt.title(title)
    plt.show()
| 237 | + |
| 238 | + |
def display_distance_matrix_with_images(
    distance_matrix: np.ndarray,
    images: list[Image.Image],
    func_name: str,
    thumbnail_size: tuple[int, int] = (32, 32),
) -> None:
    """Display the distance matrix as an image with thumbnails along the top and left.

    Args:
        distance_matrix (np.ndarray): A square matrix with distance values.
        images (list[Image.Image]): list of images corresponding to matrix rows/cols.
        func_name (str): Name of the similarity function, used as figure title.
        thumbnail_size (tuple[int, int]): Size to which thumbnails will be resized.

    Returns:
        None
    """
    # Number of images
    n = len(images)
    # Create a figure with subplots
    fig = plt.figure(figsize=(10, 10))
    # (n + 1) x (n + 1) grid: row 0 and column 0 hold the thumbnails, the
    # remaining n x n cells hold the distance matrix.
    gs = gridspec.GridSpec(n + 1, n + 1, figure=fig)

    # Place the distance matrix
    ax_matrix = fig.add_subplot(gs[1:, 1:])
    ax_matrix.imshow(distance_matrix, cmap="viridis")
    ax_matrix.set_xticks([])
    ax_matrix.set_yticks([])

    # Annotate each cell with the distance value (text takes x=column, y=row)
    for (i, j), val in np.ndenumerate(distance_matrix):
        ax_matrix.text(j, i, f"{val:.4f}", ha="center", va="center", color="white")

    # Resize images to thumbnails. Image.ANTIALIAS was an alias of LANCZOS and
    # was removed in Pillow 10, so name the filter directly.
    thumbnails = [img.resize(thumbnail_size, Image.LANCZOS) for img in images]

    # Plot each thumbnail twice: above its column and left of its row
    for i, img in enumerate(thumbnails):
        thumbnail_pixels = np.array(img)

        ax_img_top = fig.add_subplot(gs[0, i + 1])
        ax_img_top.imshow(thumbnail_pixels)
        ax_img_top.axis("off")  # Hide axes

        ax_img_left = fig.add_subplot(gs[i + 1, 0])
        ax_img_left.imshow(thumbnail_pixels)
        ax_img_left.axis("off")  # Hide axes

    plt.suptitle(func_name)
    plt.show()
| 289 | + |
| 290 | + |
def main() -> None:
    """Compare the latest recording's screenshots with each similarity metric.

    For every metric, the pairwise distances between the cropped screenshots
    of the recording's processed action events are computed and timed, then
    printed and plotted as an annotated distance matrix and as an MDS layout.
    """
    recording = crud.get_latest_recording()
    action_events = recording.processed_action_events
    images = [action_event.screenshot.cropped_image for action_event in action_events]

    n = len(images)
    if n < 2:
        # Without at least one pair there are no durations to average (the
        # mean would divide by zero) and nothing meaningful to plot.
        print(f"Need at least two images to compare, got {n}.")
        return

    # Perceptual hash functions; also used to label images in the MDS plot.
    hash_funcs = {
        "average_hash": imagehash.average_hash,
        "dhash": imagehash.dhash,
        "phash": imagehash.phash,
        "crop_resistant_hash": imagehash.crop_resistant_hash,
        "colorhash": imagehash.colorhash,
        "whash": imagehash.whash,
    }

    def make_hash_distance(hash_func: Callable) -> Callable:
        """Return a function giving the difference of two images' hashes."""

        def hash_distance(im1: Image.Image, im2: Image.Image) -> float:
            return hash_func(im1) - hash_func(im2)

        return hash_distance

    similarity_funcs = {
        "ssim": calculate_ssim,
        "thresholded_difference": thresholded_difference,
        **{
            name: make_hash_distance(hash_func)
            for name, hash_func in hash_funcs.items()
        },
    }

    # Process each similarity function
    for func_name, func in similarity_funcs.items():
        # None for the metrics that are not hash based
        hash_func = hash_funcs.get(func_name)

        # Symmetric matrix of all pairwise distances; the diagonal stays 0
        distance_matrix = np.zeros((n, n))
        durations = []
        for i in range(n):
            for j in range(i + 1, n):
                # perf_counter is monotonic and high resolution, unlike the
                # wall clock, so it is the right tool for measuring durations
                start_time = time.perf_counter()
                distance = abs(func(images[i], images[j]))
                durations.append(time.perf_counter() - start_time)
                distance_matrix[i, j] = distance
                distance_matrix[j, i] = distance
        mean_duration = sum(durations) / len(durations)
        print(f"{func_name=}")
        print(f"distance_matrix=\n{distance_matrix}")
        print(f"{mean_duration=}")
        display_distance_matrix_with_images(distance_matrix, images, func_name)
        plot_images_with_mds(
            images,
            distance_matrix,
            f"Image layout based on {func_name} ({mean_duration=:.4f}s)",
            hash_func,
        )
| 348 | + |
| 349 | + |
# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()
0 commit comments