-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel_baseline.py
More file actions
50 lines (37 loc) · 1.59 KB
/
model_baseline.py
File metadata and controls
50 lines (37 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import torch
import torchtext
from torch.utils import data
import ipdb
import torch.nn as nn
# use_cuda = torch.cuda.is_available()
# device = torch.device("cuda:0" if use_cuda else "cpu")
class RNN(nn.Module):
    """LSTM-based sequence classifier.

    Runs a (possibly multi-layer) LSTM over the input sequence and maps the
    hidden state of the last time step through a linear layer to class logits.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True: inputs/outputs are shaped (batch, seq_len, feature).
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, device):
        """Return class logits of shape (batch, num_classes).

        x is expected to be (batch, seq_len, input_size); device is where the
        zero initial states are allocated (should match x's device).
        """
        # Zero-initialized hidden and cell states, one per LSTM layer.
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # BUG FIX: the LSTM is stored as self.rnn, not self.lstm — the original
        # `self.lstm.flatten_parameters()` raised AttributeError on every call.
        # flatten_parameters() compacts weights for efficient (multi-GPU) use.
        self.rnn.flatten_parameters()
        out, _ = self.rnn(x, (h0, c0))  # out: (batch, seq_len, hidden_size)
        # Classify from the hidden state of the final time step only.
        # (Removed the per-call debug print of device/tensor sizes.)
        out = self.fc(out[:, -1, :])
        return out
def cal_loss(outputs, labels):
    """Cross-entropy loss between raw logits and integer class labels.

    outputs: (batch, num_classes) unnormalized scores; labels: (batch,) class
    indices. Returns a scalar loss tensor.
    """
    return nn.CrossEntropyLoss()(outputs, labels)
def optimizer():
    """Return the optimizer class (not an instance) used for training: Adam.

    Callers are expected to instantiate it themselves with model parameters
    and a learning rate.
    """
    adam_cls = torch.optim.Adam
    return adam_cls
def init_model(config_dic, num_classes):
    """Construct an RNN from the 'model_params' section of a config dict.

    config_dic['model_params'] must provide 'input_size', 'hidden_size',
    and 'num_layers'; num_classes sets the size of the output layer.
    """
    params = config_dic['model_params']
    return RNN(
        params['input_size'],
        params['hidden_size'],
        params['num_layers'],
        num_classes,
    )