@@ -54,51 +54,59 @@ def to_lodtensor(data, place):
5454 return res
5555
5656
def chop_data(data, chop_len=80, batch_size=50):
    """Truncate samples to a fixed length and cap the number of samples.

    Each sample is indexed as ``x[0]`` (a sliceable sequence) and ``x[1]``
    (passed through untouched).  Samples whose sequence is shorter than
    ``chop_len`` are dropped; the remaining ones are cut to exactly
    ``chop_len`` items so every surviving sequence has the same length.

    Args:
        data: iterable of samples as described above.
        chop_len: required (and resulting) sequence length.
        batch_size: maximum number of samples to return.

    Returns:
        A list of ``(x[0][:chop_len], x[1])`` tuples holding at most
        ``batch_size`` entries.  May be empty when no sample is long enough.
    """
    data = [(x[0][:chop_len], x[1]) for x in data if len(x[0]) >= chop_len]

    return data[:batch_size]
6161
6262
def prepare_feed_data(data, place):
    """Build the two feed tensors for one batch of samples.

    Args:
        data: sequence of samples; ``sample[0]`` is handed to
            ``to_lodtensor`` and ``sample[1]`` is used as the label.
        place: device place forwarded to ``to_lodtensor`` and
            ``LoDTensor.set``.

    Returns:
        ``(tensor_words, tensor_label)`` where ``tensor_label`` wraps an
        int64 array of shape ``(len(data), 1)``.
    """
    # Use real lists rather than map(): on Python 3 map() returns an
    # iterator, which np.array() would wrap as a 0-d object array.
    words = [sample[0] for sample in data]
    tensor_words = to_lodtensor(words, place)

    labels = [sample[1] for sample in data]
    label = np.array(labels).astype("int64")
    # One label per row; sized from the batch itself, not a fixed constant.
    label = label.reshape([len(label), 1])
    tensor_label = core.LoDTensor()
    tensor_label.set(label, place)

    return tensor_words, tensor_label
7272
7373
def main():
    """Train the LSTM network on IMDB and exit with a pass/fail status.

    Runs up to ``PASS_NUM`` passes over shuffled mini-batches.  The process
    exits with status 0 as soon as a batch reports accuracy above 0.7, and
    with status 1 if that never happens.
    """
    import sys  # local import: the file's import block is outside this view

    BATCH_SIZE = 100
    PASS_NUM = 5

    word_dict = paddle.dataset.imdb.word_dict()
    print("load word dict successfully")
    dict_dim = len(word_dict)
    class_dim = 2

    cost, acc = lstm_net(dict_dim=dict_dim, class_dim=class_dim)

    train_data = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.imdb.train(word_dict), buf_size=BATCH_SIZE * 10),
        batch_size=BATCH_SIZE)
    place = core.CPUPlace()
    exe = Executor(place)

    exe.run(framework.default_startup_program())

    for _ in range(PASS_NUM):
        for data in train_data():
            # NOTE(review): chop_data() is called with its defaults, so each
            # BATCH_SIZE-sample batch is capped at 50 samples -- confirm that
            # this is intended.
            chopped_data = chop_data(data)
            if not chopped_data:
                # Every sample in this batch was shorter than the chop
                # length; there is nothing to feed.
                continue
            tensor_words, tensor_label = prepare_feed_data(chopped_data, place)

            outs = exe.run(framework.default_main_program(),
                           feed={"words": tensor_words,
                                 "label": tensor_label},
                           fetch_list=[cost, acc])
            cost_val = np.array(outs[0])
            acc_val = np.array(outs[1])

            print("cost=" + str(cost_val) + " acc=" + str(acc_val))
            if acc_val > 0.7:
                sys.exit(0)
    # Accuracy threshold never reached within PASS_NUM passes.
    sys.exit(1)
102110
103111
104112if __name__ == '__main__' :
0 commit comments