-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy path attn_visualize.py
More file actions
30 lines (24 loc) · 848 Bytes
/
attn_visualize.py
File metadata and controls
30 lines (24 loc) · 848 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Index -> label for the fields of each loaded record; not referenced
# elsewhere in this script.
key_map = {0: 'p', 1: 'q', 2: 'v', 3: 'atten'}

# Read every object that was pickled back-to-back into 'atten_data.pickle',
# in file order. Each pickle.load consumes exactly one object; the EOFError
# raised once the file is exhausted is the normal loop exit.
# NOTE(review): unpickling can execute arbitrary code -- only load a file
# you produced yourself.
atten_data = []
with open('atten_data.pickle', 'rb') as pickle_file:
    try:
        while True:
            atten_data.append(pickle.load(pickle_file))
    except EOFError:
        pass  # end of stream reached
def convert_to_numpy(x):
    """Return ``x`` as a NumPy array, detached from autograd and on the CPU.

    Assumes ``x`` is a ``torch.Tensor`` (or any object exposing the same
    ``detach`` / ``cpu`` / ``numpy`` methods) -- TODO confirm at call sites.
    For a tensor that already lives on the CPU the returned array may share
    memory with ``x``.
    """
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()
def cross_bs_atten_map(key=3, num_layers=32, out_dir='./atten_visualize'):
    """Save one axis-free heatmap per leading record of ``atten_data``.

    For each of the first ``num_layers`` records in the module-level
    ``atten_data`` list, the tensor at ``record[key]`` is averaged over
    dim 0 twice (collapsing its two leading axes). It is then converted to
    NumPy, rescaled relative to its maximum, and written to
    ``<out_dir>/layer_<i>_head_mean_bs_mean_re.jpg``.

    Args:
        key: Index selecting which field of each record to plot. In
            ``key_map``, index 3 is labelled 'atten'.
        num_layers: Number of leading records to plot. The default 32
            matches the value previously hard-coded.
        out_dir: Output directory; created if it does not exist.

    Raises:
        IndexError: If ``atten_data`` holds fewer than ``num_layers``
            records. This is checked before any image is written.
    """
    if len(atten_data) < num_layers:
        raise IndexError(
            'atten_data has %d records but num_layers=%d'
            % (len(atten_data), num_layers)
        )
    # savefig does not create missing directories; without this the first
    # save raises FileNotFoundError when the directory is absent.
    os.makedirs(out_dir, exist_ok=True)
    for i in range(num_layers):
        # Presumably the two collapsed axes are batch and heads (per the
        # output file name) -- TODO confirm against the pickle's producer.
        data_np = convert_to_numpy(atten_data[i][key].mean(dim=0).mean(dim=0))
        max_val = data_np.max()
        if max_val == 0:
            # An all-zero map would otherwise become 0/0 = NaN everywhere;
            # dividing by 1 renders it as the constant offset instead.
            max_val = 1
        # For non-negative data this maps [0, max] onto [50, 1300]. With
        # vmax=250, every entry >= 16% of the maximum saturates the colormap,
        # presumably a deliberate contrast boost.
        scaled = data_np / max_val * 1250 + 50
        # Draw on an explicit figure so each layer is rendered and closed
        # independently of pyplot's global "current figure" state.
        fig, ax = plt.subplots()
        sns.heatmap(scaled, ax=ax, cbar=False, cmap='Blues',
                    vmin=0, vmax=250, center=110)
        ax.set_xticks([])
        ax.set_yticks([])
        fig.savefig(
            os.path.join(out_dir, 'layer_%s_head_mean_bs_mean_re.jpg' % i),
            bbox_inches='tight',
            pad_inches=0.,
        )
        plt.close(fig)
# Script entry: render the heatmaps for record field 3 (the default key).
cross_bs_atten_map(key=3)