"""
julia bilstm-tagger.jl                            # to use with default options on CPU
julia bilstm-tagger.jl --atype KnetArray{Float32} # to use with default options on GPU
julia bilstm-tagger.jl -h                         # to see all options with default values

This example implements a named-entity tagger built on top of a BiLSTM
neural network, similar to the model described in 'Bidirectional LSTM-CRF
Models for Sequence Tagging', Zhiheng Huang, Wei Xu, Kai Yu, arXiv technical
report 1508.01991, 2015. It was originally implemented for dynet-benchmarks.
* Paper url: https://arxiv.org/pdf/1508.01991.pdf
* DyNet report: https://arxiv.org/abs/1701.03980
* Benchmark repo: https://github.com/neulab/dynet-benchmark
"""
module Tagger
using Knet, CUDA, ArgParse, Dates, Random, Printf
using Knet: rnninit, rnnforw # old interface, use RNN in new code
include(Knet.dir("data","wikiner.jl"))
t00 = now()  # wall-clock start, used to report startup time below
function main(args)
s = ArgParseSettings()
s.description = "Bidirectional LSTM Tagger in Knet"
@add_arg_table s begin
("--atype"; default="$(Knet.array_type[])"; help="array type to use")
("--embed"; arg_type=Int; default=128; help="word embedding size")
("--hidden"; arg_type=Int; default=50; help="LSTM hidden size")
("--mlp"; arg_type=Int; default=32; help="MLP size")
("--timeout"; arg_type=Int; default=600; help="max timeout (in seconds)")
("--epochs"; arg_type=Int; default=100; help="number of training epochs")
("--minoccur"; arg_type=Int; default=6; help="word min occurence limit")
("--report"; arg_type=Int; default=500; help="report period in iters")
("--valid"; arg_type=Int; default=10000; help="valid period in iters")
("--seed"; arg_type=Int; default=-1; help="random seed")
end
isa(args, AbstractString) && (args=split(args))
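# `main` also accepts a plain string of options, e.g. from the REPL (a usage
# sketch, not in the original file): Tagger.main("--epochs 1 --timeout 60")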
o = parse_args(args, s; as_symbols=true)
o[:seed] > 0 && Knet.setseed(o[:seed])
atype = eval(Meta.parse(o[:atype]))
datadir = abspath(joinpath(@__DIR__, "../data/tags"))
datadir = isdir(datadir) ? datadir : WIKINER_DIR
# load WikiNER data
data = WikiNERData(datadir, o[:minoccur])
# build model
nwords = length(data.w2i); ntags = data.ntags
w, srnn = initweights(atype, o[:hidden], nwords, ntags, o[:embed], o[:mlp])
opt = optimizers(w, Adam)  # one Adam optimizer state per weight array
# train bilstm tagger
nwords = data.nwords; ntags = data.ntags
println("nwords=$nwords, ntags=$ntags"); flush(stdout)
println("startup time: ", Int((now()-t00).value)*0.001); flush(stdout)
t0 = now()
all_time = dev_time = all_tagged = this_tagged = this_loss = 0
o[:timeout] = o[:timeout] <= 0 ? Inf : o[:timeout]  # non-positive timeout disables the limit
for epoch = 1:o[:epochs]
shuffle!(data.trn)
for k = 1:length(data.trn)
iter = (epoch-1)*length(data.trn) + k
if o[:report] > 0 && iter % o[:report] == 0
@printf("%f\n", this_loss/this_tagged); flush(stdout)
all_tagged += this_tagged
this_loss = this_tagged = 0
all_time = Int((now()-t0).value)*0.001
end
if all_time > o[:timeout] || (o[:valid] > 0 && iter % o[:valid] == 0)
dev_start = now()
good_sent = bad_sent = good = bad = 0.0
for sent in data.dev
seq = make_input(sent, data.w2i)
nwords = length(sent)
ypred,_ = predict(w, seq, srnn)
# best-scoring tag per word: argmax over each column of the
# ntags×nwords score matrix, then map indices back to tag names
ypred = map(
x->data.i2t[x], mapslices(argmax,Array(ypred),dims=1))
ygold = map(x -> x[2], sent)
same = true
for (y1,y2) in zip(ypred, ygold)
if y1 == y2
good += 1
else
bad += 1
same = false
end
end
if same
good_sent += 1
else
bad_sent += 1
end
end
dev_time += Int((now()-dev_start).value)*0.001
train_time = Int((now()-t0).value)*0.001-dev_time
@printf(
"tag_acc=%.4f, sent_acc=%.4f, time=%.4f, word_per_sec=%.4f\n",
good/(good+bad), good_sent/(good_sent+bad_sent), train_time,
all_tagged/train_time); flush(stdout)
all_time > o[:timeout] && return
end
# train on minibatch
x = make_input(data.trn[k], data.w2i)
y = make_output(data.trn[k], data.t2i)
batch_loss = train!(w,x,y,srnn,opt)
this_loss += batch_loss
this_tagged += length(data.trn[k])
end
@printf("epoch %d finished\n", epoch-1); flush(stdout)
end
end
# Convert a sentence of (word,tag) pairs into a 1×nwords matrix of word ids;
# out-of-vocabulary words fall back to the UNK id.
function make_input(sample, w2i)
nwords = length(sample)
x = map(i->get(w2i, sample[i][1], w2i[UNK]), 1:nwords)
x = reshape(x,1,length(x))
x = convert(Array{Int64}, x)
end
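# Usage sketch with a toy vocabulary (hypothetical, not part of the original):
#   w2i = Dict("the"=>1, "dog"=>2, UNK=>3)
#   make_input([("the","O"),("cat","O")], w2i)  # => 1×2 Matrix{Int64}: [1 3]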
# Convert a sentence of (word,tag) pairs into a 1×nwords matrix of gold tag ids.
function make_output(sample, t2i)
nwords = length(sample)
y = map(i->t2i[sample[i][2]], 1:nwords)
y = reshape(y,1,length(y))
y = convert(Array{Int64}, y)
end
# w[1]   => packed weight/bias params for the bidirectional LSTM (from rnninit)
# w[2:3] => weight/bias params for the MLP layer
# w[4:5] => weight/bias params for the softmax (output) layer
# w[6]   => word embeddings
function initweights(atype, hidden, words, tags, embed, mlp, winit=0.01)
w = Array{Any}(undef,6)
input = embed
# bidirectional LSTM: rnninit returns the RNN config and its packed weights
srnn, wrnn = rnninit(input, hidden; bidirectional=true, atype=atype)
w[1] = wrnn
w[2] = convert(atype, winit*randn(mlp, 2*hidden))  # 2*hidden: forward+backward states
w[3] = convert(atype, zeros(mlp, 1))
w[4] = convert(atype, winit*randn(tags, mlp))
w[5] = convert(atype, winit*randn(tags, 1))
w[6] = convert(atype, winit*randn(embed, words))
return w, srnn
end
# loss function: mean negative log likelihood of the gold tags
function loss(w, x, ygold, srnn, h=nothing, c=nothing)
py, _ = predict(w,x,srnn,h,c)
return nll(py,ygold)
end
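# `nll` takes the ntags×nwords score matrix and the 1×nwords gold indices and
# returns the mean negative log likelihood over tokens.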
function predict(ws,xs,srnn,hx=nothing,cx=nothing)
wx = ws[6]                  # embedding matrix (embed×words)
r = srnn; wr = ws[1]        # RNN config and packed LSTM weights
wmlp = ws[2]; bmlp = ws[3]
wy = ws[4]; by = ws[5]
x = wx[:,xs]                # embedding lookup: embed×1×nwords input tensor
y, hy, cy = rnnforw(r,wr,x,hx,cx)  # forward the (optional) initial hidden/cell states
y2 = reshape(y,size(y,1),size(y,2)*size(y,3))  # flatten to 2*hidden×nwords
y3 = tanh.(wmlp * y2 .+ bmlp)  # tanh keeps the MLP from collapsing into the output layer
return wy*y3.+by, hy, cy    # ntags×nwords tag scores
end
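# Shape sketch for a 5-word sentence with the defaults (--embed 128 --hidden 50
# --mlp 32); an illustration, not part of the original:
#   xs :: 1×5 ids → x :: 128×1×5 → y :: 100×1×5 → y2 :: 100×5 → y3 :: 32×5
#   → scores :: ntags×5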
lossgradient = gradloss(loss)  # like grad, but also returns the loss value
function train!(w,x,y,srnn,opt,h=nothing,c=nothing)
gloss, lossval = lossgradient(w, x, y, srnn)
update!(w,gloss,opt)
return lossval*size(x,2)  # nll is a per-token mean; scale by nwords to report a sum
end
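# Training-step sketch (hypothetical values, not part of the original):
#   x = make_input(data.trn[1], data.w2i); y = make_output(data.trn[1], data.t2i)
#   train!(w, x, y, srnn, opt)  # returns the summed loss for the sentence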
splitdir(PROGRAM_FILE)[end] == "bilstm-tagger.jl" && main(ARGS)
end # module