|
| 1 | +import sys |
| 2 | +from typing import List |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +from sklearn.cluster import KMeans, AffinityPropagation |
| 6 | +from sklearn.decomposition import PCA |
| 7 | +from sklearn.metrics.pairwise import cosine_similarity |
| 8 | + |
| 9 | +from harmony.matching.default_matcher import convert_texts_to_vector |
| 10 | + |
| 11 | +# import matplotlib, tkinter and networkx for the GUI |
| 12 | +try: |
| 13 | + import matplotlib.pyplot as plt |
| 14 | + from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg |
| 15 | + from matplotlib.axes import Axes |
| 16 | + import tkinter as tk |
| 17 | + import tkinter.simpledialog |
| 18 | + from tkinter import ttk |
| 19 | + import networkx as nx |
| 20 | + from networkx.algorithms import community |
| 21 | +except ImportError as e: |
| 22 | + print("Make sure matplotlib, tkinter and networkx are installed.") |
| 23 | + print(e.msg) |
| 24 | + sys.exit(1) |
| 25 | + |
| 26 | + |
def draw_cosine_similarity_matrix(questions: List[str], ax: Axes, canvas: FigureCanvasTkAgg):
    """
    Renders the pairwise cosine similarity of the questions as a heatmap.

    The questions are embedded with ``convert_texts_to_vector`` and the resulting
    similarity matrix is drawn on ``ax`` using the "Blues" colormap.

    Args:
        questions: List of question strings to visualize
        ax: Matplotlib Axes object to draw on
        canvas: Tkinter canvas for displaying the plot
    """
    pairwise_similarity = cosine_similarity(convert_texts_to_vector(questions))

    # Wipe whatever was drawn before and switch the axis decorations back on
    # (another view may have hidden them).
    ax.clear()
    ax.axis("on")
    tick_settings = {
        "axis": "both",
        "which": "both",
        "bottom": True,
        "left": True,
        "labelbottom": True,
        "labelleft": True,
    }
    ax.tick_params(**tick_settings)
    ax.set_title("Cosine Similarity Matrix")

    ax.imshow(pairwise_similarity, cmap="Blues", interpolation="nearest")
    # Flip the vertical axis direction after the image has been drawn.
    ax.invert_yaxis()
    canvas.draw()
| 54 | + |
| 55 | + |
def draw_clusters_scatter_plot(questions: List[str], ax: Axes, canvas: FigureCanvasTkAgg):
    """
    Draws a scatter plot based on the given questions.
    Uses K-Means clustering for small datasets (<30 questions) and Affinity Propagation clustering for larger ones.

    The embeddings are projected to two dimensions with PCA, coloured by cluster label,
    and each point is annotated with the index of its question.

    Args:
        questions: List of question strings to visualize
        ax: Matplotlib Axes object to draw on
        canvas: Tkinter canvas for displaying the plot
    """
    embedding_matrix = convert_texts_to_vector(questions)
    n_questions = len(questions)

    if n_questions < 30:
        # K-Means needs at least as many samples as clusters, so cap the
        # cluster count when fewer than five questions are available.
        clustering = KMeans(n_clusters=min(5, n_questions))
        labels = clustering.fit_predict(embedding_matrix)

        title = "K-Means Clustering"
    else:
        item_to_item_similarity_matrix = np.array(cosine_similarity(embedding_matrix)).astype(np.float64)

        clustering = AffinityPropagation(affinity="precomputed", damping=0.7, random_state=1, max_iter=200,
                                         convergence_iter=15)
        # The precomputed affinity is the absolute cosine similarity.
        clustering.fit(np.abs(item_to_item_similarity_matrix))
        labels = clustering.labels_

        title = "Affinity Propagation Clustering"

    ax.clear()
    ax.axis("on")
    ax.tick_params(
        axis="both",
        which="both",
        bottom=True,
        left=True,
        labelbottom=True,
        labelleft=True
    )
    ax.set_aspect("auto")
    ax.set_title(title)

    if n_questions >= 2:
        pca = PCA(n_components=2)
        reduced_embeddings = pca.fit_transform(embedding_matrix)
    else:
        # Two principal components cannot be fitted on a single sample;
        # place the lone question at the origin instead of failing.
        reduced_embeddings = np.zeros((n_questions, 2))

    ax.scatter(
        reduced_embeddings[:, 0],
        reduced_embeddings[:, 1],
        c=labels,
        cmap="viridis",
        s=100
    )

    # Label every point with its question index.
    for i, point in enumerate(reduced_embeddings):
        ax.annotate(
            str(i),
            xy=(point[0], point[1]),
            xytext=(8, -10),
            textcoords="offset points",
            fontsize=8,
            color="black",
            ha="center"
        )

    canvas.draw()
| 119 | + |
| 120 | + |
def draw_network_graph(questions: List[str], ax: Axes, canvas: FigureCanvasTkAgg):
    """
    Draws a network graph based on the given questions, where edges represent high similarity (>0.5).
    Communities are detected using greedy modularity optimization.

    Every question is a node (labelled with its index); questions without any
    sufficiently similar partner are shown as isolated nodes. Nodes are coloured
    by the community they belong to.

    Args:
        questions: List of question strings to visualize
        ax: Matplotlib Axes object to draw on
        canvas: Tkinter canvas for displaying the plot
    """
    embedding_matrix = convert_texts_to_vector(questions)
    similarity_matrix = cosine_similarity(embedding_matrix)

    ax.clear()
    ax.axis("off")
    ax.set_aspect("auto")
    ax.set_title("Network Cluster Graph")

    G = nx.Graph()
    n = similarity_matrix.shape[0]

    # Register all questions up front so that ones without a strong match
    # still show up (as isolated nodes) instead of silently disappearing.
    G.add_nodes_from(range(n))

    for i in range(n):
        for j in range(i + 1, n):
            if similarity_matrix[i, j] > 0.5:
                G.add_edge(i, j, weight=similarity_matrix[i, j])

    if G.number_of_edges() > 0:
        communities = list(community.greedy_modularity_communities(G))
    else:
        # Without edges there is nothing to optimise; treat each node as its own community.
        communities = [{node} for node in G.nodes()]

    # assign colors to nodes based on communities
    community_of = {}
    for comm_idx, comm in enumerate(communities):
        for node in comm:
            community_of[node] = comm_idx
    # Colours are applied in G.nodes() order, so look each node up explicitly
    # rather than relying on the order in which communities list their members.
    node_color = [community_of[node] for node in G.nodes()]

    # improve node positions using existing layouts
    # (the Kamada-Kawai seed is only computed when there is more than one node to arrange)
    initial_pos = nx.kamada_kawai_layout(G, weight="weight") if G.number_of_nodes() > 1 else None
    pos = nx.spring_layout(
        G,
        pos=initial_pos,
        k=2,
        scale=2.0,
        iterations=200
    )

    nx.draw_networkx_nodes(
        G, pos,
        ax=ax,
        node_size=300,
        node_color=node_color,
    )

    nx.draw_networkx_edges(
        G, pos,
        ax=ax,
        width=1.0,
        alpha=0.7
    )

    nx.draw_networkx_labels(
        G, pos,
        ax=ax,
        font_size=12
    )

    canvas.draw()
| 187 | + |
| 188 | + |
def setup_gui(questions: List[str]):
    """
    Sets up the Tkinter GUI and runs its main loop until the window is closed.

    The window shows a Matplotlib canvas on the left and, on the right, one button
    per visualization plus buttons for adding a question and listing all questions.

    Args:
        questions: List of question strings to visualize. Questions entered through
            the "Add Question" dialog are appended to this list in place.
    """

    def add_question():
        """Handles adding new questions through a simple dialog and updates the canvas"""
        question = tkinter.simpledialog.askstring("Add a New Question", "New Question:")
        if question:
            questions.append(question)
            # redraw cosine similarity matrix including newly added question
            draw_cosine_similarity_matrix(questions, ax, canvas)

    def set_root_disabled(disabled: bool):
        """Enables or disables the main window where the platform supports it"""
        try:
            root.attributes("-disabled", disabled)
        except tk.TclError:
            # "-disabled" is a Windows-only window attribute. On other platforms
            # the grab set on the dialog already keeps input away from the root.
            pass

    def display_questions():
        """Displays all questions in a scrollable dialog window"""
        dialog = tk.Toplevel(root)
        dialog.title("All Questions")
        dialog.geometry("400x600")

        # make the dialog window modal
        dialog.grab_set()
        dialog.focus_set()
        set_root_disabled(True)

        scrollbar = ttk.Scrollbar(dialog)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

        questions_text = tk.Text(dialog, height=8)
        questions_text.pack(side=tk.LEFT, expand=True, fill=tk.BOTH)

        # wire the text widget and the scrollbar to each other
        questions_text["yscrollcommand"] = scrollbar.set
        scrollbar.config(command=questions_text.yview)

        for i, question in enumerate(questions):
            questions_text.insert(tk.END, f"Q{i}: {question}\n")

        def close_dialog():
            """Cleanup when closing the dialog"""
            set_root_disabled(False)
            dialog.destroy()

        dialog.protocol("WM_DELETE_WINDOW", close_dialog)

        dialog.transient(root)
        # block here until the dialog has been closed
        dialog.wait_window()

    def close_app():
        """Closes all Matplotlib figures and destroys the main window"""
        plt.close("all")
        root.destroy()

    # main window
    root = tk.Tk()
    root.title("Harmony Visualizer")
    root.geometry("800x450")

    # main frame
    main_frame = tk.Frame(root)
    main_frame.pack(fill=tk.BOTH, expand=True)

    # left frame for graphs
    graph_frame = tk.Frame(main_frame, width=350, height=350, bg="white")
    graph_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
    graph_frame.pack_propagate(False)

    # upper right frame for graph buttons
    button_frame = tk.Frame(main_frame, width=200, bg="lightgray")
    button_frame.pack(side=tk.RIGHT, fill=tk.Y)
    # lower right frame with buttons for displaying and adding questions
    bottom_button_frame = tk.Frame(button_frame, bg="lightgray")
    bottom_button_frame.pack(side=tk.BOTTOM, fill=tk.X, pady=10)

    fig, ax = plt.subplots()
    ax.axis("off")  # hide placeholder chart until a button is pressed
    canvas = FigureCanvasTkAgg(fig, master=graph_frame)
    canvas_widget = canvas.get_tk_widget()
    canvas_widget.pack(fill=tk.BOTH, expand=True)

    # the graph buttons and their corresponding draw functions
    graph_buttons = [
        ("Cosine Similarity Matrix", draw_cosine_similarity_matrix),
        ("Cluster Scatter Plot", draw_clusters_scatter_plot),
        ("Network Graph", draw_network_graph),
    ]

    for button_text, function in graph_buttons:
        # bind the draw function as a default argument so each button keeps its own
        new_button = tk.Button(button_frame, text=button_text,
                               command=lambda func=function: func(questions, ax, canvas))
        new_button.pack(pady=8, padx=10, fill=tk.X)

    # buttons for adding and displaying questions
    add_question_button = tk.Button(bottom_button_frame, text="Add Question", command=add_question)
    display_questions_button = tk.Button(bottom_button_frame, text="See Questions", command=display_questions)
    add_question_button.pack(pady=8, padx=10, fill=tk.X)
    display_questions_button.pack(pady=8, padx=10, fill=tk.X)

    root.protocol("WM_DELETE_WINDOW", close_app)
    root.mainloop()
| 283 | + |
| 284 | + |
def visualize_questions(questions: List[str]):
    """
    Launches the visualizer GUI for the given questions.

    When no questions are supplied, a message is printed and the function
    exits through ``sys.exit(1)`` instead of opening the GUI.

    Args:
        questions: List of question strings to visualize
    """
    if questions:
        setup_gui(questions)
        return

    # Nothing to visualize: report it and stop with a non-zero status.
    print("No questions provided. Exiting...")
    sys.exit(1)
0 commit comments