1212import paddle .v2 .fluid as fluid
1313import paddle .v2 .fluid .core as core
1414import paddle .v2 .fluid .framework as framework
15- from paddle .v2 .fluid .param_attr import ParamAttr
1615from paddle .v2 .fluid .executor import Executor
1716
1817parser = argparse .ArgumentParser (description = __doc__ )
1918parser .add_argument (
20- "--word_vector_dim " ,
19+ "--embedding_dim " ,
2120 type = int ,
2221 default = 512 ,
2322 help = "The dimension of embedding table. (default: %(default)d)" )
3534 "--batch_size" ,
3635 type = int ,
3736 default = 16 ,
38- help = "The sequence number of a batch data. (default: %(default)d)" )
37+ help = "The sequence number of a mini- batch data. (default: %(default)d)" )
3938parser .add_argument (
4039 "--dict_size" ,
4140 type = int ,
4241 default = 30000 ,
4342 help = "The dictionary capacity. Dictionaries of source sequence and "
4443 "target dictionary have same capacity. (default: %(default)d)" )
4544parser .add_argument (
46- "--pass_number " ,
45+ "--pass_num " ,
4746 type = int ,
4847 default = 2 ,
4948 help = "The pass number to train. (default: %(default)d)" )
5352 default = 0.0002 ,
5453 help = "Learning rate used to train the model. (default: %(default)f)" )
5554parser .add_argument (
56- "--mode" ,
57- type = str ,
58- default = 'train' ,
59- choices = ['train' , 'infer' ],
60- help = "Do training or inference. (default: %(default)s)" )
55+ "--infer_only" , action = 'store_true' , help = "If set, run forward only." )
6156parser .add_argument (
6257 "--beam_size" ,
6358 type = int ,
6762 "--use_gpu" ,
6863 type = distutils .util .strtobool ,
6964 default = True ,
70- help = "Whether use gpu. (default: %(default)d)" )
65+ help = "Whether to use gpu. (default: %(default)d)" )
7166parser .add_argument (
7267 "--max_length" ,
7368 type = int ,
7469 default = 250 ,
75- help = "The max length of sequence when doing generation. "
70+ help = "The maximum length of sequence when doing generation. "
7671 "(default: %(default)d)" )
7772
7873
@@ -97,40 +92,39 @@ def linear(inputs):
9792 return hidden_t , cell_t
9893
9994
100- def seq_to_seq_net (word_vector_dim ,
101- encoder_size ,
102- decoder_size ,
103- source_dict_dim ,
104- target_dict_dim ,
105- is_generating = False ,
106- beam_size = 3 ,
107- max_length = 250 ):
95+ def seq_to_seq_net (embedding_dim , encoder_size , decoder_size , source_dict_dim ,
96+ target_dict_dim , is_generating , beam_size , max_length ):
10897 """Construct a seq2seq network."""
10998 feeding_list = ["source_sequence" , "target_sequence" , "label_sequence" ]
11099
111- def bi_lstm_encoder (input_seq , size ):
100+ def bi_lstm_encoder (input_seq , gate_size ):
101+ # Linear transformation part for input gate, output gate, forget gate
102+ # and cell activation vectors need be done outside of dynamic_lstm.
103+ # So the output size is 4 times of gate_size.
112104 input_forward_proj = fluid .layers .fc (input = input_seq ,
113- size = size * 4 ,
114- act = 'tanh' )
105+ size = gate_size * 4 ,
106+ act = 'tanh' ,
107+ bias_attr = True )
115108 forward , _ = fluid .layers .dynamic_lstm (
116- input = input_forward_proj , size = size * 4 )
109+ input = input_forward_proj , size = gate_size * 4 )
117110 input_reversed_proj = fluid .layers .fc (input = input_seq ,
118- size = size * 4 ,
119- act = 'tanh' )
111+ size = gate_size * 4 ,
112+ act = 'tanh' ,
113+ bias_attr = True )
120114 reversed , _ = fluid .layers .dynamic_lstm (
121- input = input_reversed_proj , size = size * 4 , is_reverse = True )
115+ input = input_reversed_proj , size = gate_size * 4 , is_reverse = True )
122116 return forward , reversed
123117
124118 src_word_idx = fluid .layers .data (
125119 name = feeding_list [0 ], shape = [1 ], dtype = 'int64' , lod_level = 1 )
126120
127121 src_embedding = fluid .layers .embedding (
128122 input = src_word_idx ,
129- size = [source_dict_dim , word_vector_dim ],
123+ size = [source_dict_dim , embedding_dim ],
130124 dtype = 'float32' )
131125
132126 src_forward , src_reversed = bi_lstm_encoder (
133- input_seq = src_embedding , size = encoder_size )
127+ input_seq = src_embedding , gate_size = encoder_size )
134128
135129 encoded_vector = fluid .layers .concat (
136130 input = [src_forward , src_reversed ], axis = 1 )
@@ -151,13 +145,15 @@ def lstm_decoder_with_attention(target_embedding, encoder_vec, encoder_proj,
151145 decoder_boot , decoder_size ):
152146 def simple_attention (encoder_vec , encoder_proj , decoder_state ):
153147 decoder_state_proj = fluid .layers .fc (input = decoder_state ,
154- size = decoder_size )
148+ size = decoder_size ,
149+ bias_attr = False )
155150 decoder_state_expand = fluid .layers .sequence_expand (
156151 x = decoder_state_proj , y = encoder_proj )
157152 concated = fluid .layers .concat (
158153 input = [decoder_state_expand , encoder_proj ], axis = 1 )
159154 attention_weights = fluid .layers .fc (input = concated ,
160155 size = 1 ,
156+ act = 'tanh' ,
161157 bias_attr = False )
162158 attention_weights = fluid .layers .sequence_softmax (
163159 x = attention_weights )
@@ -191,7 +187,7 @@ def simple_attention(encoder_vec, encoder_proj, decoder_state):
191187 rnn .update_memory (cell_mem , c )
192188 out = fluid .layers .fc (input = h ,
193189 size = target_dict_dim ,
194- bias_attr = ParamAttr () ,
190+ bias_attr = True ,
195191 act = 'softmax' )
196192 rnn .output (out )
197193 return rnn ()
@@ -202,7 +198,7 @@ def simple_attention(encoder_vec, encoder_proj, decoder_state):
202198
203199 trg_embedding = fluid .layers .embedding (
204200 input = trg_word_idx ,
205- size = [target_dict_dim , word_vector_dim ],
201+ size = [target_dict_dim , embedding_dim ],
206202 dtype = 'float32' )
207203
208204 prediction = lstm_decoder_with_attention (trg_embedding , encoded_vector ,
@@ -242,7 +238,7 @@ def lodtensor_to_ndarray(lod_tensor):
242238
243239def train ():
244240 avg_cost , feeding_list = seq_to_seq_net (
245- args .word_vector_dim ,
241+ args .embedding_dim ,
246242 args .encoder_size ,
247243 args .decoder_size ,
248244 args .dict_size ,
@@ -290,7 +286,7 @@ def do_validation():
290286
291287 return total_loss / count
292288
293- for pass_id in xrange (args .pass_number ):
289+ for pass_id in xrange (args .pass_num ):
294290 pass_start_time = time .time ()
295291 words_seen = 0
296292 for batch_id , data in enumerate (train_batch_generator ()):
@@ -323,7 +319,7 @@ def infer():
323319
324320if __name__ == '__main__' :
325321 args = parser .parse_args ()
326- if args .mode == 'train' :
327- train ()
328- else :
322+ if args .infer_only :
329323 infer ()
324+ else :
325+ train ()
0 commit comments