Commit d5b7bb7

Refine the script.

1 parent ba2ea17 commit d5b7bb7

File tree

1 file changed: +32 −36 lines changed

fluid/machine_translation.py

Lines changed: 32 additions & 36 deletions
@@ -12,12 +12,11 @@
 import paddle.v2.fluid as fluid
 import paddle.v2.fluid.core as core
 import paddle.v2.fluid.framework as framework
-from paddle.v2.fluid.param_attr import ParamAttr
 from paddle.v2.fluid.executor import Executor
 
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
-    "--word_vector_dim",
+    "--embedding_dim",
     type=int,
     default=512,
     help="The dimension of embedding table. (default: %(default)d)")
@@ -35,15 +34,15 @@
     "--batch_size",
     type=int,
     default=16,
-    help="The sequence number of a batch data. (default: %(default)d)")
+    help="The number of sequences in a mini-batch. (default: %(default)d)")
 parser.add_argument(
     "--dict_size",
     type=int,
     default=30000,
     help="The dictionary capacity. Dictionaries of source sequence and "
     "target dictionary have same capacity. (default: %(default)d)")
 parser.add_argument(
-    "--pass_number",
+    "--pass_num",
     type=int,
     default=2,
     help="The pass number to train. (default: %(default)d)")
@@ -53,11 +52,7 @@
     default=0.0002,
     help="Learning rate used to train the model. (default: %(default)f)")
 parser.add_argument(
-    "--mode",
-    type=str,
-    default='train',
-    choices=['train', 'infer'],
-    help="Do training or inference. (default: %(default)s)")
+    "--infer_only", action='store_true', help="If set, run forward only.")
 parser.add_argument(
     "--beam_size",
     type=int,
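Note: the old --mode string option is collapsed into a single boolean flag, so inference is now opt-in and training is the default. A minimal standalone sketch of how the new flag parses; it mirrors the argument definition in the diff above but is illustrative, not part of the commit:

import argparse

# Mirrors the "--infer_only" definition from the diff above (illustrative).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--infer_only", action='store_true', help="If set, run forward only.")

print(parser.parse_args([]).infer_only)                # False -> train()
print(parser.parse_args(["--infer_only"]).infer_only)  # True  -> infer()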
@@ -67,12 +62,12 @@
     "--use_gpu",
     type=distutils.util.strtobool,
     default=True,
-    help="Whether use gpu. (default: %(default)d)")
+    help="Whether to use gpu. (default: %(default)d)")
 parser.add_argument(
     "--max_length",
     type=int,
     default=250,
-    help="The max length of sequence when doing generation. "
+    help="The maximum length of sequence when doing generation. "
     "(default: %(default)d)")
 
 
@@ -97,40 +92,39 @@ def linear(inputs):
     return hidden_t, cell_t
 
 
-def seq_to_seq_net(word_vector_dim,
-                   encoder_size,
-                   decoder_size,
-                   source_dict_dim,
-                   target_dict_dim,
-                   is_generating=False,
-                   beam_size=3,
-                   max_length=250):
+def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
+                   target_dict_dim, is_generating, beam_size, max_length):
     """Construct a seq2seq network."""
     feeding_list = ["source_sequence", "target_sequence", "label_sequence"]
 
-    def bi_lstm_encoder(input_seq, size):
+    def bi_lstm_encoder(input_seq, gate_size):
+        # The linear transformations for the input gate, output gate, forget
+        # gate, and cell activation vector need to be done outside of
+        # dynamic_lstm, so the output size is 4 times gate_size.
         input_forward_proj = fluid.layers.fc(input=input_seq,
-                                             size=size * 4,
-                                             act='tanh')
+                                             size=gate_size * 4,
+                                             act='tanh',
+                                             bias_attr=True)
         forward, _ = fluid.layers.dynamic_lstm(
-            input=input_forward_proj, size=size * 4)
+            input=input_forward_proj, size=gate_size * 4)
         input_reversed_proj = fluid.layers.fc(input=input_seq,
-                                              size=size * 4,
-                                              act='tanh')
+                                              size=gate_size * 4,
+                                              act='tanh',
+                                              bias_attr=True)
         reversed, _ = fluid.layers.dynamic_lstm(
-            input=input_reversed_proj, size=size * 4, is_reverse=True)
+            input=input_reversed_proj, size=gate_size * 4, is_reverse=True)
         return forward, reversed
 
     src_word_idx = fluid.layers.data(
         name=feeding_list[0], shape=[1], dtype='int64', lod_level=1)
 
     src_embedding = fluid.layers.embedding(
         input=src_word_idx,
-        size=[source_dict_dim, word_vector_dim],
+        size=[source_dict_dim, embedding_dim],
         dtype='float32')
 
     src_forward, src_reversed = bi_lstm_encoder(
-        input_seq=src_embedding, size=encoder_size)
+        input_seq=src_embedding, gate_size=encoder_size)
 
     encoded_vector = fluid.layers.concat(
         input=[src_forward, src_reversed], axis=1)
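The comment added to bi_lstm_encoder is the key point of this hunk: dynamic_lstm expects the input-to-hidden projections for all four gates to be precomputed by the preceding fc layer, hence the 4 * gate_size output width. A minimal numpy sketch of one LSTM step under that convention; gate order, names, and shapes here are illustrative assumptions, not Fluid's exact internals:

import numpy as np

def lstm_step(x_proj, h_prev, c_prev, w_h, size):
    # x_proj is the precomputed 4*size input projection, as dynamic_lstm
    # expects; w_h is the hidden-to-gates weight. Illustrative gate order.
    gates = x_proj + h_prev.dot(w_h)                # shape: (4 * size,)
    i, f, c_hat, o = np.split(gates, 4)             # one slice per gate
    i, f, o = (1.0 / (1.0 + np.exp(-g)) for g in (i, f, o))  # sigmoids
    c = f * c_prev + i * np.tanh(c_hat)             # new cell state
    h = o * np.tanh(c)                              # new hidden state
    return h, c

size = 8
h, c = lstm_step(np.random.randn(4 * size), np.zeros(size), np.zeros(size),
                 np.random.randn(size, 4 * size), size)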
@@ -151,13 +145,15 @@ def lstm_decoder_with_attention(target_embedding, encoder_vec, encoder_proj,
                                 decoder_boot, decoder_size):
     def simple_attention(encoder_vec, encoder_proj, decoder_state):
         decoder_state_proj = fluid.layers.fc(input=decoder_state,
-                                             size=decoder_size)
+                                             size=decoder_size,
+                                             bias_attr=False)
         decoder_state_expand = fluid.layers.sequence_expand(
             x=decoder_state_proj, y=encoder_proj)
         concated = fluid.layers.concat(
             input=[decoder_state_expand, encoder_proj], axis=1)
         attention_weights = fluid.layers.fc(input=concated,
                                             size=1,
+                                            act='tanh',
                                             bias_attr=False)
         attention_weights = fluid.layers.sequence_softmax(
             x=attention_weights)
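For context, simple_attention is an additive-style attention: the decoder state is projected (now explicitly without a bias), broadcast over the encoder sequence, concatenated with the projected encoder states, reduced to one tanh score per position (the newly added act='tanh'), and normalized with a sequence softmax. A rough single-sequence numpy equivalent; names and shapes are illustrative assumptions:

import numpy as np

def simple_attention_np(encoder_proj, decoder_state, w_d, w_a):
    # encoder_proj: (T, D) projected encoder states; decoder_state: (H,).
    state_proj = decoder_state.dot(w_d)           # fc with bias_attr=False
    state_expand = np.tile(state_proj, (len(encoder_proj), 1))  # sequence_expand
    concated = np.concatenate([state_expand, encoder_proj], axis=1)
    scores = np.tanh(concated.dot(w_a))           # fc with size=1, act='tanh'
    return np.exp(scores) / np.exp(scores).sum()  # sequence_softmax -> (T, 1)

T, H, D = 5, 16, 32
weights = simple_attention_np(np.random.randn(T, D), np.random.randn(H),
                              np.random.randn(H, D), np.random.randn(2 * D, 1))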
@@ -191,7 +187,7 @@ def simple_attention(encoder_vec, encoder_proj, decoder_state):
         rnn.update_memory(cell_mem, c)
         out = fluid.layers.fc(input=h,
                               size=target_dict_dim,
-                              bias_attr=ParamAttr(),
+                              bias_attr=True,
                               act='softmax')
         rnn.output(out)
     return rnn()
@@ -202,7 +198,7 @@ def simple_attention(encoder_vec, encoder_proj, decoder_state):
 
     trg_embedding = fluid.layers.embedding(
         input=trg_word_idx,
-        size=[target_dict_dim, word_vector_dim],
+        size=[target_dict_dim, embedding_dim],
         dtype='float32')
 
     prediction = lstm_decoder_with_attention(trg_embedding, encoded_vector,
@@ -242,7 +238,7 @@ def lodtensor_to_ndarray(lod_tensor):
 
 def train():
     avg_cost, feeding_list = seq_to_seq_net(
-        args.word_vector_dim,
+        args.embedding_dim,
         args.encoder_size,
         args.decoder_size,
         args.dict_size,
@@ -290,7 +286,7 @@ def do_validation():
 
         return total_loss / count
 
-    for pass_id in xrange(args.pass_number):
+    for pass_id in xrange(args.pass_num):
         pass_start_time = time.time()
         words_seen = 0
         for batch_id, data in enumerate(train_batch_generator()):
@@ -323,7 +319,7 @@ def infer():
 
 if __name__ == '__main__':
     args = parser.parse_args()
-    if args.mode == 'train':
-        train()
-    else:
+    if args.infer_only:
         infer()
+    else:
+        train()
