@@ -27,7 +27,17 @@ def parse_args():
2727 parser .add_argument (
2828 '--batch_size' , type = int , default = 32 , help = 'The minibatch size.' )
2929 parser .add_argument (
30- '--iterations' , type = int , default = 35 , help = 'The number of minibatches.' )
30+ '--use_fake_data' ,
31+ action = 'store_true' ,
32+ help = 'use real data or fake data' )
33+ parser .add_argument (
34+ '--skip_batch_num' ,
35+ type = int ,
36+ default = 5 ,
37+ help = 'The first num of minibatch num to skip, for better performance test'
38+ )
39+ parser .add_argument (
40+ '--iterations' , type = int , default = 80 , help = 'The number of minibatches.' )
3141 parser .add_argument (
3242 '--pass_num' , type = int , default = 100 , help = 'The number of passes.' )
3343 parser .add_argument (
@@ -392,7 +402,6 @@ def resnet(depth, class_dim, data_format):
392402
393403def run_benchmark (args , data_format = 'channels_last' , device = '/cpu:0' ):
394404 """Our model_fn for ResNet to be used with our Estimator."""
395- start_time = time .time ()
396405
397406 class_dim = 102
398407 dshape = (None , 224 , 224 , 3 )
@@ -431,17 +440,25 @@ def run_benchmark(args, data_format='channels_last', device='/cpu:0'):
431440 sess .run (init_g )
432441 sess .run (init_l )
433442
443+ if args .use_fake_data :
444+ data = train_reader ().next ()
445+ images_data = np .array (
446+ map (lambda x : np .transpose (x [0 ].reshape ([3 , 224 , 224 ]), axes = [1 , 2 , 0 ]), data )).astype ("float32" )
447+ labels_data = np .array (map (lambda x : x [1 ], data )).astype ('int64' )
434448 iter = 0
435449 for pass_id in range (args .pass_num ):
436450 if iter == args .iterations :
437451 break
438452 for batch_id , data in enumerate (train_reader ()):
453+ if iter == args .skip_batch_num :
454+ start_time = time .time ()
439455 if iter == args .iterations :
440456 break
441- images_data = np .array (
442- map (lambda x : np .transpose (x [0 ].reshape ([3 , 224 , 224 ]), axes = [1 , 2 , 0 ]), data )).astype ("float32" )
443- labels_data = np .array (map (lambda x : x [1 ], data )).astype (
444- 'int64' )
457+ if not args .use_fake_data :
458+ images_data = np .array (
459+ map (lambda x : np .transpose (x [0 ].reshape ([3 , 224 , 224 ]), axes = [1 , 2 , 0 ]), data )).astype ("float32" )
460+ labels_data = np .array (map (lambda x : x [1 ], data )).astype (
461+ 'int64' )
445462 _ , loss , acc , g_acc = sess .run (
446463 [train_op , avg_cost , accuracy , g_accuracy ],
447464 feed_dict = {images : images_data ,
@@ -451,12 +468,12 @@ def run_benchmark(args, data_format='channels_last', device='/cpu:0'):
451468 iter += 1
452469
453470 duration = time .time () - start_time
454- examples_per_sec = args .iterations * args .batch_size / duration
455- sec_per_batch = duration / args .batch_size
471+ img_num = (iter - args .skip_batch_num ) * args .batch_size
472+ examples_per_sec = img_num / duration
473+ sec_per_batch = duration / (iter - args .skip_batch_num )
456474
457- print ('\n Total examples: %d, total time: %.5f' %
458- (iter * args .batch_size , duration ))
459- print ('%.5f examples/sec, %.5f sec/batch \n ' %
475+ print ('Total examples: %d, total time: %.5f' % (img_num , duration ))
476+ print ('%.5f examples/sec, %.5f sec/batch' %
460477 (examples_per_sec , sec_per_batch ))
461478
462479
0 commit comments