Skip to content

Commit 3eab57b

Browse files
authored
Merge pull request #113 from KonstantinosKorovesis/main
feat: add question visualization GUI
2 parents 9ba891c + 633d0bc commit 3eab57b

File tree

2 files changed

+344
-0
lines changed

2 files changed

+344
-0
lines changed
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
import sys
2+
from typing import List
3+
4+
import numpy as np
5+
from sklearn.cluster import KMeans, AffinityPropagation
6+
from sklearn.decomposition import PCA
7+
from sklearn.metrics.pairwise import cosine_similarity
8+
9+
from harmony.matching.default_matcher import convert_texts_to_vector
10+
11+
# import matplotlib, tkinter and networkx for the GUI
12+
try:
13+
import matplotlib.pyplot as plt
14+
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
15+
from matplotlib.axes import Axes
16+
import tkinter as tk
17+
import tkinter.simpledialog
18+
from tkinter import ttk
19+
import networkx as nx
20+
from networkx.algorithms import community
21+
except ImportError as e:
22+
print("Make sure matplotlib, tkinter and networkx are installed.")
23+
print(e.msg)
24+
sys.exit(1)
25+
26+
27+
def draw_cosine_similarity_matrix(questions: List[str], ax: Axes, canvas: FigureCanvasTkAgg):
    """
    Render the pairwise cosine similarity of the given questions as a heatmap.

    Args:
        questions: Question strings to embed and compare with each other
        ax: Matplotlib Axes object that is cleared and drawn on
        canvas: Tkinter canvas that is redrawn once the heatmap is in place
    """
    pairwise_similarity = cosine_similarity(convert_texts_to_vector(questions))

    # start from an empty Axes and switch the axis, ticks and tick labels back on
    # (the network graph view and the initial placeholder turn the axis off)
    tick_settings = {
        "axis": "both",
        "which": "both",
        "bottom": True,
        "left": True,
        "labelbottom": True,
        "labelleft": True,
    }
    ax.clear()
    ax.axis("on")
    ax.tick_params(**tick_settings)
    ax.set_title("Cosine Similarity Matrix")

    ax.imshow(pairwise_similarity, cmap="Blues", interpolation="nearest")
    # flip the vertical axis of the image that imshow just produced
    ax.invert_yaxis()

    canvas.draw()
54+
55+
56+
def draw_clusters_scatter_plot(questions: List[str], ax: Axes, canvas: FigureCanvasTkAgg):
    """
    Draws a scatter plot based on the given questions.
    Uses K-Means clustering for small datasets (<30 questions) and Affinity Propagation clustering for larger ones.

    K-Means uses at most 5 clusters, and never more clusters than there are questions.
    The embeddings are projected to two dimensions with PCA for plotting; a single
    question is plotted at the origin because PCA cannot extract two components from one sample.

    Args:
        questions: List of question strings to visualize
        ax: Matplotlib Axes object to draw on
        canvas: Tkinter canvas for displaying the plot
    """
    embedding_matrix = convert_texts_to_vector(questions)
    n_questions = len(questions)

    if n_questions < 30:
        # KMeans raises a ValueError when n_clusters exceeds the number of samples,
        # so cap the cluster count at the number of questions
        clustering = KMeans(n_clusters=min(5, n_questions))
        labels = clustering.fit_predict(embedding_matrix)

        title = "K-Means Clustering"
    else:
        item_to_item_similarity_matrix = np.array(cosine_similarity(embedding_matrix)).astype(np.float64)

        clustering = AffinityPropagation(affinity="precomputed", damping=0.7, random_state=1, max_iter=200,
                                         convergence_iter=15)
        clustering.fit(np.abs(item_to_item_similarity_matrix))
        labels = clustering.labels_

        title = "Affinity Propagation Clustering"

    ax.clear()
    ax.axis("on")
    ax.tick_params(
        axis="both",
        which="both",
        bottom=True,
        left=True,
        labelbottom=True,
        labelleft=True
    )
    ax.set_aspect("auto")
    ax.set_title(title)

    if n_questions >= 2:
        pca = PCA(n_components=2)
        reduced_embeddings = pca.fit_transform(embedding_matrix)
    else:
        # PCA(n_components=2) needs at least two samples; place a lone question at the origin
        reduced_embeddings = np.zeros((n_questions, 2))

    ax.scatter(
        reduced_embeddings[:, 0],
        reduced_embeddings[:, 1],
        c=labels,
        cmap="viridis",
        s=100
    )

    # label every point with the index of its question
    for i, point in enumerate(reduced_embeddings):
        ax.annotate(
            str(i),
            xy=(point[0], point[1]),
            xytext=(8, -10),
            textcoords="offset points",
            fontsize=8,
            color="black",
            ha="center"
        )

    canvas.draw()
119+
120+
121+
def draw_network_graph(questions: List[str], ax: Axes, canvas: FigureCanvasTkAgg):
    """
    Draws a network graph based on the given questions, where edges represent high similarity (>0.5).
    Communities are detected using greedy modularity optimization.

    Every question is a node (labelled with its index), including questions that
    have no sufficiently similar partner. Nodes are colored by community.

    Args:
        questions: List of question strings to visualize
        ax: Matplotlib Axes object to draw on
        canvas: Tkinter canvas for displaying the plot
    """
    embedding_matrix = convert_texts_to_vector(questions)
    similarity_matrix = cosine_similarity(embedding_matrix)

    ax.clear()
    ax.axis("off")
    ax.set_aspect("auto")
    ax.set_title("Network Cluster Graph")

    n = similarity_matrix.shape[0]

    G = nx.Graph()
    # add all questions up front so that questions without any edge are still drawn
    G.add_nodes_from(range(n))

    for i in range(n):
        for j in range(i + 1, n):
            if similarity_matrix[i, j] > 0.5:
                G.add_edge(i, j, weight=similarity_matrix[i, j])

    if G.number_of_edges() > 0:
        communities = list(community.greedy_modularity_communities(G))
    else:
        # nothing to optimize without edges: treat every node as its own community
        communities = [{node} for node in G]

    # assign colors to nodes based on communities; the color list must follow the
    # order of the drawn node list, not the order in which communities are iterated
    community_of_node = {
        node: comm_idx
        for comm_idx, comm in enumerate(communities)
        for node in comm
    }
    nodelist = list(G.nodes())
    node_color = [community_of_node[node] for node in nodelist]

    # improve node positions using existing layouts
    pos = nx.kamada_kawai_layout(G, weight="weight")
    pos = nx.spring_layout(
        G,
        pos=pos,
        k=2,
        scale=2.0,
        iterations=200
    )

    nx.draw_networkx_nodes(
        G, pos,
        ax=ax,
        nodelist=nodelist,
        node_size=300,
        node_color=node_color,
    )

    nx.draw_networkx_edges(
        G, pos,
        ax=ax,
        width=1.0,
        alpha=0.7
    )

    nx.draw_networkx_labels(
        G, pos,
        ax=ax,
        font_size=12
    )

    canvas.draw()
187+
188+
189+
def setup_gui(questions: List[str]):
    """
    Sets up the Tkinter GUI and runs its main loop until the main window is closed.

    Args:
        questions: List of question strings to visualize. The list is extended
            in place when a question is added through the GUI.
    """

    def add_question(questions: List[str], ax: Axes, canvas: FigureCanvasTkAgg):
        """Handles adding new questions through a simple dialog and updates the canvas"""
        question = tkinter.simpledialog.askstring("Add a New Question", "New Question:")
        if question:
            questions.append(question)
            # redraw cosine similarity matrix including newly added question
            draw_cosine_similarity_matrix(questions, ax, canvas)

    def set_root_disabled(disabled: bool):
        """Toggles the "-disabled" attribute of the main window where Tk supports it"""
        try:
            root.attributes("-disabled", disabled)
        except tk.TclError:
            # "-disabled" is a Windows-only window attribute; on other platforms
            # the dialog's grab_set() alone keeps it modal
            pass

    def display_questions():
        """Displays all questions in a scrollable dialog window"""
        dialog = tk.Toplevel(root)
        dialog.title("All Questions")
        dialog.geometry("400x600")

        # make the dialog window modal
        dialog.grab_set()
        dialog.focus_set()
        set_root_disabled(True)

        scrollbar = ttk.Scrollbar(dialog)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

        questions_text = tk.Text(dialog, height=8)
        questions_text.pack(side=tk.LEFT, expand=True, fill=tk.BOTH)

        # keep the text widget and the scrollbar in sync
        questions_text["yscrollcommand"] = scrollbar.set
        scrollbar.config(command=questions_text.yview)

        for i, question in enumerate(questions):
            questions_text.insert(tk.END, f"Q{i}: {question}\n")

        def close_dialog():
            """Cleanup when closing the dialog"""
            set_root_disabled(False)
            dialog.destroy()

        dialog.protocol("WM_DELETE_WINDOW", close_dialog)

        dialog.transient(root)
        # block here until the dialog has been destroyed
        dialog.wait_window()

    # main window
    root = tk.Tk()
    root.title("Harmony Visualizer")
    root.geometry("800x450")

    # main frame
    main_frame = tk.Frame(root)
    main_frame.pack(fill=tk.BOTH, expand=True)

    # left frame for graphs
    graph_frame = tk.Frame(main_frame, width=350, height=350, bg="white")
    graph_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
    graph_frame.pack_propagate(False)

    # upper right frame for graph buttons
    button_frame = tk.Frame(main_frame, width=200, bg="lightgray")
    button_frame.pack(side=tk.RIGHT, fill=tk.Y)
    # lower right frame with buttons for displaying and adding questions
    bottom_button_frame = tk.Frame(button_frame, bg="lightgray")
    bottom_button_frame.pack(side=tk.BOTTOM, fill=tk.X, pady=10)

    fig, ax = plt.subplots()
    ax.axis("off")  # hide placeholder chart until a button is pressed
    canvas = FigureCanvasTkAgg(fig, master=graph_frame)
    canvas_widget = canvas.get_tk_widget()
    canvas_widget.pack(fill=tk.BOTH, expand=True)

    # the graph buttons and their corresponding draw functions
    button_texts = ["Cosine Similarity Matrix", "Cluster Scatter Plot", "Network Graph"]
    button_functions = [draw_cosine_similarity_matrix, draw_clusters_scatter_plot, draw_network_graph]

    for button_text, function in zip(button_texts, button_functions):
        # bind the current function as a default argument so each button keeps its own callback
        new_button = tk.Button(button_frame, text=button_text,
                               command=lambda func=function: func(questions, ax, canvas))
        new_button.pack(pady=8, padx=10, fill=tk.X)

    # buttons for adding and displaying questions
    add_question_button = tk.Button(bottom_button_frame, text="Add Question",
                                    command=lambda: add_question(questions, ax, canvas))
    display_questions_button = tk.Button(bottom_button_frame, text="See Questions", command=display_questions)
    add_question_button.pack(pady=8, padx=10, fill=tk.X)
    display_questions_button.pack(pady=8, padx=10, fill=tk.X)

    # close all matplotlib figures together with the window
    root.protocol("WM_DELETE_WINDOW", lambda: (plt.close("all"), root.destroy()))
    root.mainloop()
283+
284+
285+
def visualize_questions(questions: List[str]):
    """
    Launch the question visualization GUI.

    Prints a message and exits the process with status 1 when no questions are given.

    Args:
        questions: List of question strings to visualize
    """
    if questions:
        setup_gui(questions)
        return

    print("No questions provided. Exiting...")
    sys.exit(1)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import unittest
2+
from unittest.mock import patch, MagicMock
3+
from harmony.matching.visualize_questions_gui import (
4+
draw_cosine_similarity_matrix,
5+
draw_clusters_scatter_plot,
6+
draw_network_graph,
7+
visualize_questions
8+
)
9+
10+
11+
class TestHarmonyBasic(unittest.TestCase):
    """Smoke tests for the drawing functions and the entry point of the visualization GUI."""

    def setUp(self):
        self.questions = ["Q1", "Q2", "Q3", "Q4", "Q5"]

        # mock the embedding function to return dummy data (one vector per question).
        # The GUI module imports convert_texts_to_vector with "from ... import ...",
        # so the name has to be patched in the GUI module's own namespace.
        self.patcher = patch(
            'harmony.matching.visualize_questions_gui.convert_texts_to_vector',
            return_value=[
                [1.0, 0.0, 0.0],
                [0.9, 0.1, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.9, 0.1],
                [0.0, 0.0, 1.0],
            ]
        )
        self.mock_convert = self.patcher.start()

        # simple mock objects for the Axes and Canvas objects
        self.mock_ax = MagicMock()
        self.mock_canvas = MagicMock()

    def tearDown(self):
        self.patcher.stop()

    def test_draw_cosine_similarity_matrix(self):
        """Check that the heatmap is drawn from the mocked embeddings and the canvas is refreshed"""
        draw_cosine_similarity_matrix(self.questions, self.mock_ax, self.mock_canvas)
        self.mock_convert.assert_called_once_with(self.questions)
        self.mock_ax.imshow.assert_called_once()
        self.mock_canvas.draw.assert_called_once()

    def test_draw_clusters_scatter_plot(self):
        """Check that the scatter plot is drawn from the mocked embeddings and the canvas is refreshed"""
        draw_clusters_scatter_plot(self.questions, self.mock_ax, self.mock_canvas)
        self.mock_convert.assert_called_once_with(self.questions)
        self.mock_ax.scatter.assert_called_once()
        self.mock_canvas.draw.assert_called_once()

    def test_draw_network_graph(self):
        """Check that the network graph is built from the mocked embeddings and the canvas is refreshed"""
        draw_network_graph(self.questions, self.mock_ax, self.mock_canvas)
        self.mock_convert.assert_called_once_with(self.questions)
        self.mock_canvas.draw.assert_called_once()

    def test_empty_questions(self):
        """Check empty input exits correctly"""
        with self.assertRaises(SystemExit) as se:
            visualize_questions([])
        self.assertEqual(se.exception.code, 1)
47+
48+
# run the test suite when this file is executed directly
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)