-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathnn_visualizer.py
More file actions
519 lines (431 loc) · 21 KB
/
nn_visualizer.py
File metadata and controls
519 lines (431 loc) · 21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
import pygame
from pygame.locals import *
import math
import random
from collections import deque
# --- Display setup ---------------------------------------------------------
# pygame.init() has to run before the window is created.
pygame.init()

WIDTH = 1400
HEIGHT = 900
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Neural Network Architecture Visualizer")

# --- Palette (RGB tuples) --------------------------------------------------
BG_COLOR = (10, 15, 25)            # window background
INPUT_COLOR = (0, 200, 255)        # cyan: input layer
HIDDEN_COLOR = (255, 100, 150)     # pink: hidden layer
OUTPUT_COLOR = (150, 255, 100)     # green: output layer
MEMORY_COLOR = (255, 150, 50)      # orange: N-step memory
SOCIAL_COLOR = (200, 100, 255)     # purple: social signals
CONNECTION_COLOR = (80, 100, 150)  # muted blue
TEXT_COLOR = (200, 220, 255)       # regular UI text
HIGHLIGHT_COLOR = (255, 255, 100)  # titles and "ON" settings
class NeuralNetworkVisualizer:
    """Interactive, rotating 3D diagram of a small feed-forward network.

    Nodes are stored as dicts in model space: inputs at x=-300, hidden
    units at x=0, outputs at x=300, and (optionally) N-step memory nodes
    below the network.  Every frame they are rotated by
    ``angle_x/angle_y/angle_z`` and perspective-projected onto the
    module-level ``screen`` surface.

    Input-layer ordering, as counted by :meth:`update_architecture`:

    * ``[0, 24)`` -- base inputs (food/water/agent sectors, internal state)
    * next 5      -- social signal inputs, only when social signals are on
    * the rest    -- memory inputs, 8 per remembered time step, only when
      N-step memory is on
    """

    N_BASE_INPUTS = 24       # always-present inputs
    N_SOCIAL_INPUTS = 5      # one social signal input per sector
    N_MEMORY_UNITS = 8       # hidden units stored per memory time step
    N_HIDDEN = 8
    N_BASE_OUTPUTS = 6
    N_SOCIAL_OUTPUTS = 1     # extra output added when social signals are on
    MIN_MEMORY_DEPTH = 1
    MAX_MEMORY_DEPTH = 5
    # Smallest perspective denominator allowed.  Tall input columns can
    # rotate onto/behind the camera plane; without a floor that divides by
    # zero or mirrors the point through the screen centre.
    MIN_PROJECTION_DEPTH = 1.0

    def __init__(self):
        # Feature toggles.
        self.n_step_memory_enabled = False
        self.n_step_memory_depth = 2
        self.social_signals_enabled = False

        # Animation / camera parameters.
        self.time_counter = 0
        self.auto_rotate = True
        self.angle_x = 0.2
        self.angle_y = 0.0
        self.angle_z = 0.0
        self.auto_rotate_y = 0.005  # radians added to angle_y per update()
        self.auto_rotate_x = 0.002  # radians added to angle_x per update()
        self.camera_distance = 800

        # pygame Font objects keyed by (size, bold); see _font().
        self._font_cache = {}

        # Node lists, filled by calculate_positions().
        self.input_nodes = []
        self.hidden_nodes = []
        self.output_nodes = []
        self.memory_nodes = []  # N-step memory nodes

        # Derives layer sizes and builds all node lists.
        self.update_architecture()

    def _font(self, size, bold=False):
        """Return a cached Arial ``SysFont`` of the given size/weight.

        Building a Font loads it from the system font files, so it is done
        once per (size, bold) instead of on every frame.
        """
        key = (size, bold)
        font = self._font_cache.get(key)
        if font is None:
            font = pygame.font.SysFont('Arial', size, bold=bold)
            self._font_cache[key] = font
        return font

    def update_architecture(self):
        """Recompute layer sizes from the current settings and rebuild nodes."""
        self.n_inputs = self.N_BASE_INPUTS
        if self.social_signals_enabled:
            self.n_inputs += self.N_SOCIAL_INPUTS
        if self.n_step_memory_enabled:
            self.n_inputs += self.n_step_memory_depth * self.N_MEMORY_UNITS

        self.n_hidden = self.N_HIDDEN
        self.n_outputs = self.N_BASE_OUTPUTS
        if self.social_signals_enabled:
            self.n_outputs += self.N_SOCIAL_OUTPUTS

        self.calculate_positions()

    def _memory_input_start(self):
        """Return the input-layer index of the first memory input.

        Memory inputs come after the base inputs and, when enabled, after
        the social signal inputs -- the same order update_architecture()
        counts them in.
        """
        start = self.N_BASE_INPUTS
        if self.social_signals_enabled:
            start += self.N_SOCIAL_INPUTS
        return start

    @staticmethod
    def _column_positions(count, spacing):
        """Return ``count`` y-coordinates ``spacing`` apart, centred on 0."""
        start = -((count - 1) * spacing) / 2
        return [start + i * spacing for i in range(count)]

    @staticmethod
    def _make_node(layer, index, x, y, **extra):
        """Build a node dict with a random activation and pulse phase.

        ``extra`` adds layer-specific keys (e.g. ``step`` or ``social``).
        """
        node = {
            'x': x,
            'y': y,
            'z': 0,
            'layer': layer,
            'index': index,
            'activation': random.random(),
            'pulse_offset': random.random() * math.pi * 2,
        }
        node.update(extra)
        return node

    def calculate_positions(self):
        """Rebuild every node list from the current layer sizes.

        Each layer is a vertical column centred on y=0 at a fixed x.
        Activations and pulse phases are re-randomised.  Input and output
        nodes carry a boolean ``social`` flag marking social-signal slots.
        """
        # Inputs keep at least 30 units between neighbours, so a large
        # input layer grows taller than the 600-unit hidden/output columns.
        input_spacing = max(30, 600 / (self.n_inputs - 1)) if self.n_inputs > 1 else 0
        hidden_spacing = 600 / (self.n_hidden - 1) if self.n_hidden > 1 else 0
        output_spacing = 600 / (self.n_outputs - 1) if self.n_outputs > 1 else 0

        social_on = self.social_signals_enabled
        first_social = self.N_BASE_INPUTS
        end_social = first_social + self.N_SOCIAL_INPUTS  # exclusive

        self.input_nodes = [
            self._make_node('input', i, -300, y,
                            social=social_on and first_social <= i < end_social)
            for i, y in enumerate(self._column_positions(self.n_inputs, input_spacing))
        ]
        self.hidden_nodes = [
            self._make_node('hidden', i, 0, y)
            for i, y in enumerate(self._column_positions(self.n_hidden, hidden_spacing))
        ]
        self.output_nodes = [
            self._make_node('output', i, 300, y,
                            social=social_on and i >= self.N_BASE_OUTPUTS)
            for i, y in enumerate(self._column_positions(self.n_outputs, output_spacing))
        ]

        # Memory: one 400-unit column per time step, shifted right by 80
        # per step and lowered by 200 so it sits below the main network.
        self.memory_nodes = []
        if self.n_step_memory_enabled:
            units = self.N_MEMORY_UNITS
            memory_ys = self._column_positions(units, 400 / max(1, units - 1))
            for step in range(self.n_step_memory_depth):
                for i, y in enumerate(memory_ys):
                    self.memory_nodes.append(
                        self._make_node('memory', i, -150 + step * 80, y + 200, step=step))

    def rotate_3d(self, x, y, z, ax, ay, az):
        """Rotate (x, y, z) about the origin, in radians.

        The rotations are applied in order: X axis by ``ax``, then Y axis
        by ``ay``, then Z axis by ``az``.  Returns the rotated (x, y, z).
        """
        cos_a, sin_a = math.cos(ax), math.sin(ax)
        y, z = y * cos_a - z * sin_a, y * sin_a + z * cos_a

        cos_a, sin_a = math.cos(ay), math.sin(ay)
        x, z = x * cos_a + z * sin_a, -x * sin_a + z * cos_a

        cos_a, sin_a = math.cos(az), math.sin(az)
        x, y = x * cos_a - y * sin_a, x * sin_a + y * cos_a
        return x, y, z

    def project_3d_to_2d(self, x, y, z):
        """Perspective-project a rotated point to integer screen coordinates.

        Returns ``(px, py, factor)``.  ``factor`` is the perspective scale:
        it shrinks as ``z`` grows, so larger ``z`` means farther away.
        """
        # +200 pulls the scene toward the camera.  The floor keeps points at
        # or behind the camera plane from dividing by zero or flipping sides.
        depth = max(self.MIN_PROJECTION_DEPTH, self.camera_distance + z + 200)
        factor = self.camera_distance / depth
        x_2d = x * factor + WIDTH / 2
        y_2d = y * factor + HEIGHT / 2
        return int(x_2d), int(y_2d), factor

    def _rotated_depth(self, node):
        """Return ``node``'s z after rotation by the current view angles."""
        return self.rotate_3d(node['x'], node['y'], node['z'],
                              self.angle_x, self.angle_y, self.angle_z)[2]

    def _project_node(self, node):
        """Rotate ``node`` by the current view angles and project it."""
        x, y, z = self.rotate_3d(node['x'], node['y'], node['z'],
                                 self.angle_x, self.angle_y, self.angle_z)
        return self.project_3d_to_2d(x, y, z)

    def draw_node(self, node_data, time):
        """Draw one node as a glowing, pulsing disc.

        The radius scales with perspective depth and the node's
        ``activation``.  ``time`` (seconds) and ``pulse_offset`` drive the
        pulse.  Social-signal nodes use SOCIAL_COLOR; all other nodes are
        coloured by layer.
        """
        px, py, factor = self._project_node(node_data)
        activation = node_data['activation']

        depth_factor = max(0.3, min(1.5, factor))
        base_radius = 8
        pulse = math.sin(time * 3 + node_data['pulse_offset']) * 0.3 + 0.7
        radius = int(base_radius * depth_factor * (0.8 + activation * 0.4 * pulse))

        if node_data.get('social'):
            color = SOCIAL_COLOR
        else:
            color = {
                'input': INPUT_COLOR,
                'hidden': HIDDEN_COLOR,
                'output': OUTPUT_COLOR,
                'memory': MEMORY_COLOR,
            }.get(node_data['layer'], (200, 200, 200))

        # Three translucent rings behind the node body form the glow.
        for i in range(3):
            glow_radius = radius + (3 - i) * 2
            glow_alpha = int(60 * depth_factor * (activation + 0.3) / (i + 1))
            glow_surface = pygame.Surface((glow_radius * 2, glow_radius * 2), pygame.SRCALPHA)
            pygame.draw.circle(glow_surface, (*color, glow_alpha),
                               (glow_radius, glow_radius), glow_radius)
            screen.blit(glow_surface, (px - glow_radius, py - glow_radius))

        # Body brightens with activation; border is a lighter outline.
        main_color = tuple(min(255, int(c * (0.8 + activation * 0.4))) for c in color)
        pygame.draw.circle(screen, main_color, (px, py), radius)
        border_color = tuple(min(255, c + 50) for c in main_color)
        pygame.draw.circle(screen, border_color, (px, py), radius, 1)

    def draw_connections(self):
        """Draw the layer-to-layer edges.

        Input->hidden and hidden->output are fully connected.  When N-step
        memory is on, dashed edges are also drawn from memory nodes to
        their input slots and between consecutive memory steps.
        """
        for input_node in self.input_nodes:
            for hidden_node in self.hidden_nodes:
                self.draw_connection(input_node, hidden_node, (100, 120, 180), 1)
        for hidden_node in self.hidden_nodes:
            for output_node in self.output_nodes:
                self.draw_connection(hidden_node, output_node, (150, 180, 220), 1)

        if not (self.n_step_memory_enabled and self.memory_nodes):
            return

        units = self.N_MEMORY_UNITS
        # Each stored hidden unit is fed back through its own input slot.
        # Those slots start after the base inputs and any social inputs.
        first_memory_input = self._memory_input_start()
        for mem_node in self.memory_nodes:
            input_idx = first_memory_input + mem_node['step'] * units + mem_node['index']
            if input_idx < len(self.input_nodes):
                self.draw_connection(mem_node, self.input_nodes[input_idx],
                                     MEMORY_COLOR, 1, dashed=True)

        # Link each unit to the same unit in the next memory step.
        # memory_nodes is step-major: index = step * units + unit.
        for step in range(self.n_step_memory_depth - 1):
            for i in range(units):
                curr_idx = step * units + i
                next_idx = (step + 1) * units + i
                if next_idx < len(self.memory_nodes):
                    self.draw_connection(self.memory_nodes[curr_idx],
                                         self.memory_nodes[next_idx],
                                         (200, 150, 100), 1, dashed=True)

    @staticmethod
    def _dash_segments(p1, p2, dash_length=6):
        """Split the segment ``p1 -> p2`` into evenly spaced dashes.

        Returns a list of ``(start, end)`` point pairs.  Each dash period
        covers ``1 / num_dashes`` of the segment: its first half is drawn
        and its second half is the gap.  Every dash therefore lies on the
        segment and is roughly ``dash_length`` long.  Segments shorter
        than one dash+gap produce no dashes.
        """
        (x1, y1), (x2, y2) = p1, p2
        dx, dy = x2 - x1, y2 - y1
        num_dashes = int(math.hypot(dx, dy) / (dash_length * 2))
        segments = []
        for i in range(num_dashes):
            t0 = i / num_dashes
            t1 = (i + 0.5) / num_dashes
            segments.append(((x1 + dx * t0, y1 + dy * t0),
                             (x1 + dx * t1, y1 + dy * t1)))
        return segments

    def draw_connection(self, node1, node2, color, width, dashed=False):
        """Draw a straight line, optionally dashed, between two nodes."""
        px1, py1, _ = self._project_node(node1)
        px2, py2, _ = self._project_node(node2)
        if dashed:
            for start, end in self._dash_segments((px1, py1), (px2, py2)):
                pygame.draw.line(screen, color, start, end, width)
        else:
            pygame.draw.line(screen, color, (px1, py1), (px2, py2), width)

    def _draw_node_labels(self, nodes, texts, color, font):
        """Blit ``texts[i]`` just right of ``nodes[i]``.

        Extra texts beyond ``len(nodes)`` are ignored.  Social nodes are
        labelled in SOCIAL_COLOR instead of ``color``.
        """
        for node, text in zip(nodes, texts):
            px, py, _ = self._project_node(node)
            label_color = SOCIAL_COLOR if node.get('social') else color
            screen.blit(font.render(text, True, label_color), (px + 15, py - 8))

    def draw_labels(self):
        """Draw layer titles, per-node input/output labels and memory labels."""
        font = self._font(20, bold=True)
        small_font = self._font(14)

        # Layer titles along the top of the window.
        for text, color, x in (
            (f"INPUTS ({self.n_inputs})", INPUT_COLOR, WIDTH//2 - 400),
            (f"HIDDEN ({self.n_hidden})", HIDDEN_COLOR, WIDTH//2 - 50),
            (f"OUTPUTS ({self.n_outputs})", OUTPUT_COLOR, WIDTH//2 + 250),
        ):
            screen.blit(font.render(text, True, color), (x, 30))

        # The first 24 inputs are always the base inputs; social inputs follow.
        input_labels = [
            "Food S0", "Food S1", "Food S2", "Food S3", "Food S4",
            "Water S0", "Water S1", "Water S2", "Water S3", "Water S4",
            "Agent S0", "Agent S1", "Agent S2", "Agent S3", "Agent S4",
            "Energy", "Hydration", "Age Ratio", "Stress", "Health",
            "Vel Forward", "Vel Lateral", "Own Size", "Own Speed"
        ]
        if self.social_signals_enabled:
            input_labels += [f"Social S{i}" for i in range(self.N_SOCIAL_INPUTS)]
        self._draw_node_labels(self.input_nodes, input_labels, INPUT_COLOR, small_font)

        output_labels = ["Move Fwd", "Turn", "Avoid", "Attack", "Mate", "Effort"]
        if self.social_signals_enabled:
            output_labels.append("Social Sig")
        self._draw_node_labels(self.output_nodes, output_labels, OUTPUT_COLOR, small_font)

        if self.n_step_memory_enabled:
            memory_label = font.render(f"MEMORY (Depth: {self.n_step_memory_depth})",
                                       True, MEMORY_COLOR)
            screen.blit(memory_label, (WIDTH//2 - 200, HEIGHT - 80))
            for step in range(self.n_step_memory_depth):
                step_label = small_font.render(f"T-{step}", True, MEMORY_COLOR)
                # Fixed screen position, roughly under each memory column.
                screen.blit(step_label, (WIDTH//2 - 250 + step * 80, HEIGHT - 50))

    def draw_settings_panel(self):
        """Draw the top-right panel showing toggles and the architecture."""
        font = self._font(16)
        title_font = self._font(20, bold=True)

        panel_rect = pygame.Rect(WIDTH - 300, 20, 280, 200)
        pygame.draw.rect(screen, (30, 40, 60), panel_rect)
        pygame.draw.rect(screen, (100, 150, 200), panel_rect, 2)

        title = title_font.render("NETWORK SETTINGS", True, HIGHLIGHT_COLOR)
        screen.blit(title, (WIDTH - 290, 30))

        mem_text = f"N-Step Memory: {'ON' if self.n_step_memory_enabled else 'OFF'}"
        mem_color = HIGHLIGHT_COLOR if self.n_step_memory_enabled else TEXT_COLOR
        screen.blit(font.render(mem_text, True, mem_color), (WIDTH - 290, 70))

        depth_text = f"Memory Depth: {self.n_step_memory_depth}"
        screen.blit(font.render(depth_text, True, TEXT_COLOR), (WIDTH - 290, 100))

        social_text = f"Social Signals: {'ON' if self.social_signals_enabled else 'OFF'}"
        social_color = HIGHLIGHT_COLOR if self.social_signals_enabled else TEXT_COLOR
        screen.blit(font.render(social_text, True, social_color), (WIDTH - 290, 130))

        arch_text = f"Arc: {self.n_inputs}->{self.n_hidden}->{self.n_outputs}"
        screen.blit(font.render(arch_text, True, (200, 255, 200)), (WIDTH - 290, 170))

    def draw_instructions(self):
        """Draw the keyboard/mouse help text in the top-left corner."""
        font = self._font(14)
        instructions = [
            "CONTROLS:",
            "M: Toggle N-Step Memory",
            "S: Toggle Social Signals",
            "UP/DOWN: Change Memory Depth",
            "R: Reset Rotation",
            "ARROW KEYS: Manual Rotation",
            "CLICK + DRAG: Rotate View",
            "SPACE: Toggle Auto-Rotate"
        ]
        for row, text in enumerate(instructions):
            screen.blit(font.render(text, True, TEXT_COLOR), (20, 20 + row * 25))

    def toggle_n_step_memory(self):
        """Toggle N-step memory on/off and rebuild the network."""
        self.n_step_memory_enabled = not self.n_step_memory_enabled
        self.update_architecture()

    def toggle_social_signals(self):
        """Toggle social signals on/off and rebuild the network."""
        self.social_signals_enabled = not self.social_signals_enabled
        self.update_architecture()

    def increase_memory_depth(self):
        """Add one memory step, up to MAX_MEMORY_DEPTH; no-op while memory is off."""
        if self.n_step_memory_enabled:
            self.n_step_memory_depth = min(self.MAX_MEMORY_DEPTH, self.n_step_memory_depth + 1)
            self.update_architecture()

    def decrease_memory_depth(self):
        """Remove one memory step, down to MIN_MEMORY_DEPTH; no-op while memory is off."""
        if self.n_step_memory_enabled:
            self.n_step_memory_depth = max(self.MIN_MEMORY_DEPTH, self.n_step_memory_depth - 1)
            self.update_architecture()

    def _all_nodes(self):
        """Return every node in input, hidden, output, memory order."""
        return self.input_nodes + self.hidden_nodes + self.output_nodes + self.memory_nodes

    def update(self, dt):
        """Advance the animation by ``dt`` seconds.

        Each node's activation follows a sine wave within [0.3, 0.7],
        phase-shifted by its ``pulse_offset``.  When auto-rotate is on, the
        view angles advance by a fixed amount per call.
        """
        self.time_counter += dt
        for node in self._all_nodes():
            wave = math.sin(self.time_counter * 2 + node['pulse_offset'])
            node['activation'] = 0.3 + 0.4 * ((wave + 1) / 2)

        if self.auto_rotate:
            self.angle_y += self.auto_rotate_y
            self.angle_x += self.auto_rotate_x

    def _nodes_back_to_front(self):
        """Return all nodes ordered farthest-first for painter's-algorithm drawing.

        The projection factor ``camera_distance / (camera_distance + z + 200)``
        shrinks as the rotated ``z`` grows, so larger ``z`` is farther away.
        Sorting by descending ``z`` paints near nodes last, on top.
        """
        return sorted(self._all_nodes(), key=self._rotated_depth, reverse=True)

    def draw(self):
        """Render one full frame onto ``screen``, without flipping the display."""
        screen.fill(BG_COLOR)
        # Connections first, so nodes are drawn on top of them.
        self.draw_connections()
        for node in self._nodes_back_to_front():
            self.draw_node(node, self.time_counter)
        self.draw_labels()
        self.draw_settings_panel()
        self.draw_instructions()
def main():
    """Run the visualizer's event/update/draw loop until the window is closed.

    Controls:
      * Left-click drag rotates the view.
      * M / S toggle N-step memory and social signals.
      * UP/DOWN change the memory depth while memory is on, and tilt the
        view otherwise (depth changes are no-ops with memory off).
      * LEFT/RIGHT spin the view.
      * R resets rotation.
      * SPACE toggles auto-rotation.

    ``pygame.quit()`` always runs, even if the loop raises.
    """
    visualizer = NeuralNetworkVisualizer()
    clock = pygame.time.Clock()
    running = True
    mouse_down = False
    last_mouse_pos = None

    try:
        while running:
            # Seconds since the previous frame; tick(60) caps at 60 FPS.
            dt = clock.tick(60) / 1000.0

            for event in pygame.event.get():
                if event.type == QUIT:
                    running = False
                elif event.type == MOUSEBUTTONDOWN:
                    if event.button == 1:  # Left click
                        mouse_down = True
                        last_mouse_pos = pygame.mouse.get_pos()
                elif event.type == MOUSEBUTTONUP:
                    if event.button == 1:  # Left release
                        mouse_down = False
                elif event.type == MOUSEMOTION and mouse_down:
                    if last_mouse_pos:
                        dx = event.pos[0] - last_mouse_pos[0]
                        dy = event.pos[1] - last_mouse_pos[1]
                        visualizer.angle_y += dx * 0.005
                        visualizer.angle_x += dy * 0.005
                    last_mouse_pos = event.pos
                elif event.type == KEYDOWN:
                    if event.key == K_m:
                        visualizer.toggle_n_step_memory()
                    elif event.key == K_s:
                        visualizer.toggle_social_signals()
                    elif event.key == K_UP:
                        # Memory on: UP deepens memory.  Memory off: UP tilts
                        # the view, since depth changes would be no-ops.
                        if visualizer.n_step_memory_enabled:
                            visualizer.increase_memory_depth()
                        else:
                            visualizer.angle_x -= 0.1
                    elif event.key == K_DOWN:
                        if visualizer.n_step_memory_enabled:
                            visualizer.decrease_memory_depth()
                        else:
                            visualizer.angle_x += 0.1
                    elif event.key == K_r:
                        visualizer.angle_x = 0.2
                        visualizer.angle_y = 0.0
                        visualizer.angle_z = 0.0
                    elif event.key == K_LEFT:
                        visualizer.angle_y -= 0.1
                    elif event.key == K_RIGHT:
                        visualizer.angle_y += 0.1
                    elif event.key == K_SPACE:
                        visualizer.auto_rotate = not visualizer.auto_rotate

            visualizer.update(dt)
            visualizer.draw()
            pygame.display.flip()
    finally:
        pygame.quit()


if __name__ == "__main__":
    main()