
Commit 8f40794

Wojciech Uss authored and committed
Merge pull request PaddlePaddle#31 from AIPG/wojtuss/bilstm
bilstm topology added to text_classification model (Senta)
2 parents 9a4a7ef + 14fee20 commit 8f40794

File tree

10 files changed: 43,218 additions and 12 deletions


fluid/text_classification/data/test_data/corpus.test

Lines changed: 200 additions & 0 deletions (large diff not rendered)

fluid/text_classification/data/train.vocab

Lines changed: 32896 additions & 0 deletions (large diff not rendered)

fluid/text_classification/data/train_data/corpus.train

Lines changed: 10000 additions & 0 deletions (large diff not rendered)

fluid/text_classification/infer.py

Lines changed: 9 additions & 5 deletions
@@ -1,7 +1,4 @@
-import sys
 import time
-import unittest
-import contextlib
 import numpy as np
 import argparse
 
@@ -17,7 +14,13 @@ def parse_args():
         '--batch_size',
         type=int,
         default=128,
-        help='The size of a batch. (default: %(default)d, usually: 128 for "bow" and "gru", 4 for "cnn" and "lstm")')
+        help='The size of a batch. (default: %(default)d, usually: 128 for "bow" and "gru", 4 for "cnn", "lstm" and "bilstm").')
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default='imdb',
+        choices=['imdb', 'data'],
+        help="Dataset to be used: 'imdb' or 'data' (from 'data' subdirectory).")
     parser.add_argument(
         '--device',
         type=str,
@@ -77,7 +80,7 @@ def infer(args):
     wpses = [0] * total_passes
     acces = [0] * total_passes
     word_dict, train_reader, test_reader = utils.prepare_data(
-        "imdb", self_dict=False, batch_size=args.batch_size,
+        args.dataset, self_dict=False, batch_size=args.batch_size,
         buf_size=50000)
     pass_acc = 0.0
     pass_data_len = 0
@@ -100,6 +103,7 @@ def infer(args):
             fetch_list=fetch_targets,
             return_numpy=True)
         batch_time = time.time() - start
+        # TODO: add class accuracy measurement as in Senta
         word_count = len([w for d in data for w in d[0]])
         batch_times[pass_id] += batch_time
         word_counts[pass_id] += word_count
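With the new --dataset flag, inference can read the bundled Senta-style corpus from the data/ subdirectory instead of the built-in imdb dataset. A minimal sketch of the resulting prepare_data call (batch size 4 follows the help text's suggestion for "bilstm"; the actual values used in infer() come from the command line):

    import utils

    # Hedged sketch: roughly what infer() does when run with --dataset data --batch_size 4.
    word_dict, train_reader, test_reader = utils.prepare_data(
        "data", self_dict=False, batch_size=4, buf_size=50000)
    print(len(word_dict))  # vocabulary size loaded from data/train.vocab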

fluid/text_classification/nets.py

Lines changed: 49 additions & 0 deletions
@@ -93,6 +93,55 @@ def lstm_net(data,
     return avg_cost, acc, prediction
 
 
+def bilstm_net(data,
+               label,
+               dict_dim,
+               emb_dim=128,
+               hid_dim=128,
+               hid_dim2=96,
+               class_dim=2,
+               emb_lr=30.0):
+    """
+    Bi-Lstm net
+    """
+    # embedding layer
+    emb = fluid.layers.embedding(
+        input=data,
+        size=[dict_dim, emb_dim],
+        param_attr=fluid.ParamAttr(learning_rate=emb_lr))
+
+    # bi-lstm layer
+    fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
+
+    rfc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
+
+    lstm_h, c = fluid.layers.dynamic_lstm(
+        input=fc0, size=hid_dim * 4, is_reverse=False)
+
+    rlstm_h, c = fluid.layers.dynamic_lstm(
+        input=rfc0, size=hid_dim * 4, is_reverse=True)
+
+    # extract last layer
+    lstm_last = fluid.layers.sequence_last_step(input=lstm_h)
+    rlstm_last = fluid.layers.sequence_last_step(input=rlstm_h)
+
+    lstm_last_tanh = fluid.layers.tanh(lstm_last)
+    rlstm_last_tanh = fluid.layers.tanh(rlstm_last)
+
+    # concat layer
+    lstm_concat = fluid.layers.concat(input=[lstm_last, rlstm_last], axis=1)
+
+    # full connect layer
+    fc1 = fluid.layers.fc(input=lstm_concat, size=hid_dim2, act='tanh')
+    # softmax layer
+    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
+    cost = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+    acc = fluid.layers.accuracy(input=prediction, label=label)
+
+    return avg_cost, acc, prediction
+
+
 def gru_net(data,
             label,
             dict_dim,
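The new bilstm_net keeps the same interface as the other topologies in nets.py (it returns avg_cost, acc, prediction). It projects the embedding through two fc layers of width hid_dim * 4, since dynamic_lstm expects its input width to be four times its hidden size, runs one forward and one reverse LSTM, and concatenates the last step of each direction before the classifier. A minimal, hypothetical wiring sketch, using the feed names ("words", "label") that data2tensor in utils.py produces; the real graph construction happens in train.py:

    import paddle.fluid as fluid
    from nets import bilstm_net

    vocab_size = 10000  # placeholder only; train.py derives this from the word dictionary

    # LoD (variable-length) sequence of word ids and its class label
    data = fluid.layers.data(name="words", shape=[1], dtype="int64", lod_level=1)
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")

    avg_cost, acc, prediction = bilstm_net(data, label, dict_dim=vocab_size)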

fluid/text_classification/scripts/README.md

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@ train_bow.sh
 train_cnn.sh
 train_gru.sh
 train_lstm.sh
+train_bilstm.sh
 ```
 
 ## Inference
@@ -35,4 +36,5 @@ infer_bow.sh
 infer_cnn.sh
 infer_gru.sh
 infer_lstm.sh
+infer_bilstm.sh
 ```
fluid/text_classification/scripts/infer_bilstm.sh

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+#!/bin/bash
+time python ../infer.py \
+    --device CPU \
+    --model_path bilstm_model/epoch0 \
+    --num_passes 100 \
+    --profile
+
fluid/text_classification/scripts/train_bilstm.sh

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+#!/bin/bash
+time python ../train.py \
+    --device CPU \
+    --model_save_dir bilstm_model \
+    --num_passes 1 \
+    bilstm
+

fluid/text_classification/train.py

Lines changed: 13 additions & 5 deletions
@@ -10,24 +10,32 @@
 from nets import bow_net
 from nets import cnn_net
 from nets import lstm_net
+from nets import bilstm_net
 from nets import gru_net
 
-nets = {'bow': bow_net, 'cnn': cnn_net, 'lstm': lstm_net, 'gru': gru_net}
+nets = {'bow': bow_net, 'cnn': cnn_net, 'lstm': lstm_net, 'bilstm': bilstm_net,
+        'gru': gru_net}
 # learning rates
-lrs = {'bow': 0.002, 'cnn': 0.01, 'lstm': 0.05, 'gru': 0.05}
+lrs = {'bow': 0.002, 'cnn': 0.01, 'lstm': 0.05, 'bilstm':0.002, 'gru': 0.05}
 
 def parse_args():
     parser = argparse.ArgumentParser("Run inference.")
     parser.add_argument(
         'topology',
         type=str,
-        choices=['bow', 'cnn', 'lstm', 'gru'],
+        choices=['bow', 'cnn', 'lstm', 'bilstm', 'gru'],
         help='Topology used for the model (bow/cnn/lstm/gru).')
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default='imdb',
+        choices=['imdb', 'data'],
+        help="Dataset to be used: 'imdb' or 'data' (from 'data' subdirectory).")
     parser.add_argument(
         '--batch_size',
         type=int,
         default=128,
-        help='The size of a batch. (default: %(default)d, usually: 128 for "bow" and "gru", 4 for "cnn" and "lstm")')
+        help='The size of a batch. (default: %(default)d, usually: 128 for "bow" and "gru", 4 for "cnn", "lstm" and "bilstm").')
     parser.add_argument(
         '--device',
         type=str,
@@ -122,7 +130,7 @@ def train(train_reader,
 
 def train_net(args):
     word_dict, train_reader, test_reader = utils.prepare_data(
-        "imdb", self_dict=False, batch_size=128, buf_size=50000)
+        args.dataset, self_dict=False, batch_size=128, buf_size=50000)
 
     train(
         train_reader,
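A hedged sketch of how the extended nets/lrs dictionaries are presumably consumed further down in train.py (the lookup itself is outside the hunks shown above):

    # Hypothetical illustration only; names follow the nets/lrs dictionaries above.
    args = parse_args()                 # e.g. invoked as: python train.py bilstm --dataset data
    network = nets[args.topology]       # resolves to bilstm_net for 'bilstm'
    learning_rate = lrs[args.topology]  # 0.002 for 'bilstm'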

fluid/text_classification/utils.py

Lines changed: 35 additions & 2 deletions
@@ -1,6 +1,8 @@
 import sys
 import time
 import numpy as np
+import random
+import os
 
 import paddle
 import paddle.fluid as fluid
@@ -48,18 +50,43 @@ def data2tensor(data, place):
     return {"words": input_seq, "label": y_data}
 
 
+def data_reader(file_path, word_dict, is_shuffle=True):
+    """
+    Convert word sequence into slot
+    """
+    unk_id = len(word_dict)
+    all_data = []
+    with open(file_path, "r") as fin:
+        for line in fin:
+            cols = line.strip().split("\t")
+            label = int(cols[0])
+            wids = [word_dict[x] if x in word_dict else unk_id
+                    for x in cols[1].split(" ")]
+            all_data.append((wids, label))
+    if is_shuffle:
+        random.shuffle(all_data)
+
+    def reader():
+        for doc, label in all_data:
+            yield doc, label
+    return reader
+
+
 def prepare_data(data_type="imdb",
                  self_dict=False,
                  batch_size=128,
                  buf_size=50000):
     """
     prepare data
     """
+    script_path = os.path.dirname(__file__)
     if self_dict:
         word_dict = load_vocab(data_type + ".vocab")
     else:
         if data_type == "imdb":
             word_dict = paddle.dataset.imdb.word_dict()
+        elif data_type == "data":
+            word_dict = load_vocab(script_path + "/data/train.vocab")
         else:
             raise RuntimeError("No such dataset")
 
@@ -68,12 +95,18 @@ def prepare_data(data_type="imdb",
             paddle.reader.shuffle(
                 paddle.dataset.imdb.train(word_dict), buf_size=buf_size),
             batch_size=batch_size)
-
         test_reader = paddle.batch(
             paddle.reader.shuffle(
                 paddle.dataset.imdb.test(word_dict), buf_size=buf_size),
             batch_size=batch_size)
+    elif data_type == "data":
+        train_reader = paddle.batch(
+            data_reader(script_path + "/data/train_data/corpus.train", word_dict, True),
+            batch_size=batch_size)
+        test_reader = paddle.batch(
+            data_reader(script_path + "/data/test_data/corpus.test", word_dict, False),
+            batch_size=batch_size)
     else:
-        raise RuntimeError("no such dataset")
+        raise RuntimeError("No such dataset")
 
     return word_dict, train_reader, test_reader
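data_reader() expects each corpus line to carry an integer label, a tab, and a space-separated token sequence; tokens missing from the vocabulary are mapped to unk_id = len(word_dict). A tiny illustration with a made-up dictionary and line (the real vocabulary is loaded from data/train.vocab):

    word_dict = {"this": 0, "movie": 1, "is": 2}   # hypothetical 3-word vocabulary
    unk_id = len(word_dict)                        # 3
    line = "1\tthis movie is great"                # "<label>\t<space-separated tokens>"
    cols = line.strip().split("\t")
    label = int(cols[0])                           # 1
    wids = [word_dict[x] if x in word_dict else unk_id for x in cols[1].split(" ")]
    # wids == [0, 1, 2, 3]; "great" is out of vocabulary, so it becomes unk_id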
