-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathrun.py
More file actions
125 lines (109 loc) · 4.4 KB
/
run.py
File metadata and controls
125 lines (109 loc) · 4.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import time
import argparse
import pickle
from MSN import MSN
import os
# Expose only GPU 0 to CUDA-aware libraries.
# NOTE(review): this runs after `from MSN import MSN`; it only takes effect
# if that import has not already initialised CUDA — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Single source of truth for per-task settings:
#   task name -> (data directory containing the *.pkl files, default batch size)
_TASK_SETTINGS = (
    ("ubuntu", './dataset/ubuntu_data/', 200),
    ("douban", './dataset/DoubanConversaionCorpus/', 150),
    ("alime", './dataset/E_commerce/', 200),
)

# Task name -> data directory (trailing slash included; file names are
# appended by string concatenation elsewhere in this script).
task_dic = dict((name, directory) for name, directory, _ in _TASK_SETTINGS)

# Task name -> batch size used when --batch_size is not given.
data_batch_size = dict((name, size) for name, _, size in _TASK_SETTINGS)
# ---------------------------------------------------------------------------
# Command-line arguments. Every option has a default; none is required.
# ---------------------------------------------------------------------------


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string -- including ``"False"`` -- so an explicit parser is
    required.

    Args:
        value: The raw argument string (or an already-parsed ``bool``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got %r" % (value,)
    )


parser = argparse.ArgumentParser()
parser.add_argument("--task",
                    default='ubuntu',
                    type=str,
                    # Reject unknown tasks up front instead of failing later
                    # with a bare KeyError on task_dic / data_batch_size.
                    choices=sorted(task_dic),
                    help="The dataset used for training and test.")
parser.add_argument("--is_training",
                    default=False,
                    type=_str2bool,
                    # `--is_training` alone means True; `--is_training False`
                    # now really means False.
                    nargs="?",
                    const=True,
                    help="Training model or evaluating model? "
                         "Accepts true/false; a bare flag means true.")
parser.add_argument("--max_utterances",
                    default=10,
                    type=int,
                    help="The maximum number of utterances.")
parser.add_argument("--max_words",
                    default=50,
                    type=int,
                    help="The maximum number of words for each utterance.")
parser.add_argument("--batch_size",
                    default=0,
                    type=int,
                    help="The batch size (0 or less = use the per-task default).")
parser.add_argument("--gru_hidden",
                    default=300,
                    type=int,
                    help="The hidden size of GRU in layer 1")
parser.add_argument("--learning_rate",
                    default=1e-3,
                    type=float,
                    help="The initial learning rate for Adam.")
parser.add_argument("--l2_reg",
                    default=0.0,
                    type=float,
                    help="The l2 regularization.")
parser.add_argument("--epochs",
                    default=5,
                    type=float,
                    help="Total number of training epochs to perform.")
parser.add_argument("--save_path",
                    default="./checkpoint/",
                    type=str,
                    help="The path to save model.")
parser.add_argument("--score_file_path",
                    default="score_file.txt",
                    type=str,
                    help="Score file name, resolved relative to the task's "
                         "data directory.")
args = parser.parse_args()

# Use the per-task batch size only when the user did not supply one; the
# original code overwrote --batch_size unconditionally, ignoring the flag.
if args.batch_size <= 0:
    args.batch_size = data_batch_size[args.task]
# Checkpoint file is "<save_path><task>.<ModelClassName>.pt".
args.save_path += args.task + '.' + MSN.__name__ + ".pt"
# Score file lives in the task's data directory.
args.score_file_path = task_dic[args.task] + args.score_file_path
print(args)
print("Task: ", args.task)
def train_model():
    """Train an MSN model on the task selected by ``--task``.

    Loads ``train.pkl``, ``test.pkl`` and ``vocab_and_embeddings.pkl`` from the
    task's data directory, builds ``MSN`` from the word embeddings and the
    parsed ``args``, and calls ``MSN.fit`` with the train and dev arrays.
    """
    path = task_dic[args.task]
    # NOTE: pickle.load can execute arbitrary code; only load dataset files
    # from a trusted source.
    with open(path + "train.pkl", 'rb') as f:
        X_train_utterances, X_train_responses, y_train = pickle.load(f)
    # NOTE(review): the "dev" arrays come from test.pkl, so whatever fit()
    # does with dev data uses the test split — confirm this is intended.
    with open(path + "test.pkl", 'rb') as f:
        X_dev_utterances, X_dev_responses, y_dev = pickle.load(f)
    with open(path + "vocab_and_embeddings.pkl", 'rb') as f:
        _vocab, word_embeddings = pickle.load(f)  # vocab itself is unused here
    model = MSN(word_embeddings, args=args)
    model.fit(
        X_train_utterances, X_train_responses, y_train,
        X_dev_utterances, X_dev_responses, y_dev
    )
def test_model():
    """Evaluate a saved MSN checkpoint on the task's ``test.pkl`` split.

    Rebuilds ``MSN`` from ``vocab_and_embeddings.pkl``, loads weights from
    ``args.save_path`` and calls ``evaluate(..., is_test=True)``.
    """
    path = task_dic[args.task]
    # NOTE: pickle.load can execute arbitrary code; only load dataset files
    # from a trusted source.
    with open(path + "test.pkl", 'rb') as f:
        X_test_utterances, X_test_responses, y_test = pickle.load(f)
    with open(path + "vocab_and_embeddings.pkl", 'rb') as f:
        _vocab, word_embeddings = pickle.load(f)  # vocab itself is unused here
    model = MSN(word_embeddings, args=args)
    model.load_model(args.save_path)
    model.evaluate(X_test_utterances, X_test_responses, y_test, is_test=True)
def test_adversarial():
    """Evaluate a saved MSN checkpoint on the adversarial test sets k = 1..3.

    For each k, loads ``test_adversarial_k_<k>.pkl`` from the task's data
    directory, prints a header line and calls ``evaluate(..., is_test=True)``.
    The model is built and its checkpoint loaded once, before the loop.
    """
    path = task_dic[args.task]
    # NOTE: pickle.load can execute arbitrary code; only load dataset files
    # from a trusted source.
    with open(path + "vocab_and_embeddings.pkl", 'rb') as f:
        _vocab, word_embeddings = pickle.load(f)  # vocab itself is unused here
    model = MSN(word_embeddings, args=args)
    model.load_model(args.save_path)
    for k in (1, 2, 3):
        print(f"adversarial test set (k={k}): ")
        with open(path + f"test_adversarial_k_{k}.pkl", 'rb') as f:
            X_test_utterances, X_test_responses, y_test = pickle.load(f)
        model.evaluate(X_test_utterances, X_test_responses, y_test, is_test=True)
if __name__ == '__main__':
    # perf_counter() is monotonic, so the elapsed time is not skewed by
    # system clock adjustments the way time.time() can be.
    start = time.perf_counter()
    if args.is_training:
        train_model()
    # Test-set evaluation runs in both modes (after training when
    # --is_training is set).
    test_model()
    # To also evaluate on the adversarial test sets, call test_adversarial()
    # here.
    end = time.perf_counter()
    print("use time: ", (end-start)/60, " min")