-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDKVMNModel.py
More file actions
62 lines (49 loc) · 2.42 KB
/
DKVMNModel.py
File metadata and controls
62 lines (49 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import torch.nn as nn
import torch.nn.functional as F
class DKVMNModel(nn.Module):
def __init__(self, input_dim, memory_size, memory_dim, output_dim, device):
super(DKVMNModel, self).__init__()
self.device = device
self.memory_size = memory_size
self.memory_dim = memory_dim # dimension of each memory slot
self.output_dim = output_dim
# Key memory (fixed)
self.memory_key = nn.Parameter(torch.Tensor(memory_size, memory_dim))
nn.init.xavier_uniform_(self.memory_key)
# Value memory (updated over time)
self.memory_value = nn.Parameter(torch.Tensor(memory_size, memory_dim))
nn.init.xavier_uniform_(self.memory_value)
# Attention, read, and write mechanisms
self.read_linear = nn.Linear(input_dim, memory_dim)
self.write_linear = nn.Linear(input_dim, memory_dim)
self.predict_layer = nn.Linear(memory_dim, output_dim)
self.sigmoid = nn.Sigmoid()
def forward(self, x): # x: [batch_size, seq_len, input_dim]
batch_size, seq_len, input_dim = x.size()
predictions = []
# Initialize value memory for the batch
memory_value = self.memory_value.unsqueeze(0).repeat(
batch_size, 1, 1
) # [batch_size, memory_size, memory_dim]
for t in range(seq_len):
xt = x[:, t, :] # [batch_size, input_dim]
# Attention weights over memory slots
k_t = self.read_linear(xt) # [batch_size, memory_dim]
w_t = F.softmax(
torch.matmul(k_t.unsqueeze(1), self.memory_key.T),
dim=-1
) # [batch_size, 1, memory_size]
# Read: weighted sum of memory slots
r_t = torch.matmul(w_t, memory_value).squeeze(1) # [batch_size, memory_dim]
y_t = self.sigmoid(self.predict_layer(r_t)) # [batch_size, output_dim]
predictions.append(y_t.unsqueeze(1))
# Write: update value memory
erase = torch.sigmoid(self.write_linear(xt)).unsqueeze(1) # [batch_size, 1, memory_dim]
add = torch.tanh(self.write_linear(xt)).unsqueeze(1)
w_t_trans = w_t.transpose(1, 2) # [batch_size, memory_size, 1]
memory_value = (
memory_value * (1 - w_t_trans * erase)
+ w_t_trans * add
)
return torch.cat(predictions, dim=1) # [batch_size, seq_len, output_dim]