Skip to content

Commit 0ec9262

Browse files
authored
Merge pull request #13 from jacquesqiao/add-more-test-flags
Add more test flags
2 parents 8122e4e + 77dbb14 commit 0ec9262

File tree

2 files changed

+55
-17
lines changed

2 files changed

+55
-17
lines changed

paddle/resnet.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ def parse_args():
2424
help='The model architecture.')
2525
parser.add_argument(
2626
'--batch_size', type=int, default=32, help='The minibatch size.')
27+
parser.add_argument(
28+
'--use_fake_data',
29+
action='store_true',
30+
help='use real data or fake data')
31+
parser.add_argument(
32+
'--skip_batch_num',
33+
type=int,
34+
default=5,
35+
help='The first num of minibatch num to skip, for better performance test'
36+
)
2737
parser.add_argument(
2838
'--iterations', type=int, default=80, help='The number of minibatches.')
2939
parser.add_argument(
@@ -148,19 +158,29 @@ def run_benchmark(model, args):
148158
exe = fluid.Executor(place)
149159
exe.run(fluid.default_startup_program())
150160

161+
if args.use_fake_data:
162+
data = train_reader().next()
163+
image = np.array(map(lambda x: x[0].reshape(dshape), data)).astype(
164+
'float32')
165+
label = np.array(map(lambda x: x[1], data)).astype('int64')
166+
label = label.reshape([-1, 1])
167+
151168
iter = 0
152169
im_num = 0
153170
for pass_id in range(args.pass_num):
154171
accuracy.reset(exe)
155172
if iter == args.iterations:
156173
break
157-
for data in train_reader():
174+
for batch_id, data in enumerate(train_reader()):
175+
if iter == args.skip_batch_num:
176+
start_time = time.time()
158177
if iter == args.iterations:
159178
break
160-
image = np.array(map(lambda x: x[0].reshape(dshape), data)).astype(
161-
'float32')
162-
label = np.array(map(lambda x: x[1], data)).astype('int64')
163-
label = label.reshape([-1, 1])
179+
if not args.use_fake_data:
180+
image = np.array(map(lambda x: x[0].reshape(dshape),
181+
data)).astype('float32')
182+
label = np.array(map(lambda x: x[1], data)).astype('int64')
183+
label = label.reshape([-1, 1])
164184
loss, acc = exe.run(fluid.default_main_program(),
165185
feed={'data': image,
166186
'label': label},
@@ -172,8 +192,9 @@ def run_benchmark(model, args):
172192
im_num += label.shape[0]
173193

174194
duration = time.time() - start_time
195+
im_num = im_num - args.skip_batch_num * args.batch_size
175196
examples_per_sec = im_num / duration
176-
sec_per_batch = duration / iter
197+
sec_per_batch = duration / (iter - args.skip_batch_num)
177198

178199
print('\nTotal examples: %d, total time: %.5f' % (im_num, duration))
179200
print('%.5f examples/sec, %.5f sec/batch \n' %

tensorflow/resnet.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,17 @@ def parse_args():
2727
parser.add_argument(
2828
'--batch_size', type=int, default=32, help='The minibatch size.')
2929
parser.add_argument(
30-
'--iterations', type=int, default=35, help='The number of minibatches.')
30+
'--use_fake_data',
31+
action='store_true',
32+
help='use real data or fake data')
33+
parser.add_argument(
34+
'--skip_batch_num',
35+
type=int,
36+
default=5,
37+
help='The first num of minibatch num to skip, for better performance test'
38+
)
39+
parser.add_argument(
40+
'--iterations', type=int, default=80, help='The number of minibatches.')
3141
parser.add_argument(
3242
'--pass_num', type=int, default=100, help='The number of passes.')
3343
parser.add_argument(
@@ -392,7 +402,6 @@ def resnet(depth, class_dim, data_format):
392402

393403
def run_benchmark(args, data_format='channels_last', device='/cpu:0'):
394404
"""Our model_fn for ResNet to be used with our Estimator."""
395-
start_time = time.time()
396405

397406
class_dim = 102
398407
dshape = (None, 224, 224, 3)
@@ -431,17 +440,25 @@ def run_benchmark(args, data_format='channels_last', device='/cpu:0'):
431440
sess.run(init_g)
432441
sess.run(init_l)
433442

443+
if args.use_fake_data:
444+
data = train_reader().next()
445+
images_data = np.array(
446+
map(lambda x: np.transpose(x[0].reshape([3, 224, 224]), axes=[1, 2, 0]), data)).astype("float32")
447+
labels_data = np.array(map(lambda x: x[1], data)).astype('int64')
434448
iter = 0
435449
for pass_id in range(args.pass_num):
436450
if iter == args.iterations:
437451
break
438452
for batch_id, data in enumerate(train_reader()):
453+
if iter == args.skip_batch_num:
454+
start_time = time.time()
439455
if iter == args.iterations:
440456
break
441-
images_data = np.array(
442-
map(lambda x: np.transpose(x[0].reshape([3, 224, 224]), axes=[1, 2, 0]), data)).astype("float32")
443-
labels_data = np.array(map(lambda x: x[1], data)).astype(
444-
'int64')
457+
if not args.use_fake_data:
458+
images_data = np.array(
459+
map(lambda x: np.transpose(x[0].reshape([3, 224, 224]), axes=[1, 2, 0]), data)).astype("float32")
460+
labels_data = np.array(map(lambda x: x[1], data)).astype(
461+
'int64')
445462
_, loss, acc, g_acc = sess.run(
446463
[train_op, avg_cost, accuracy, g_accuracy],
447464
feed_dict={images: images_data,
@@ -451,12 +468,12 @@ def run_benchmark(args, data_format='channels_last', device='/cpu:0'):
451468
iter += 1
452469

453470
duration = time.time() - start_time
454-
examples_per_sec = args.iterations * args.batch_size / duration
455-
sec_per_batch = duration / args.batch_size
471+
img_num = (iter - args.skip_batch_num) * args.batch_size
472+
examples_per_sec = img_num / duration
473+
sec_per_batch = duration / (iter - args.skip_batch_num)
456474

457-
print('\nTotal examples: %d, total time: %.5f' %
458-
(iter * args.batch_size, duration))
459-
print('%.5f examples/sec, %.5f sec/batch \n' %
475+
print('Total examples: %d, total time: %.5f' % (img_num, duration))
476+
print('%.5f examples/sec, %.5f sec/batch' %
460477
(examples_per_sec, sec_per_batch))
461478

462479

0 commit comments

Comments
 (0)