Skip to content

Commit 4fc564c

Browse files
committed
update hyper-parameters of adamw
1 parent 59cffbf commit 4fc564c

1 file changed

Lines changed: 14 additions & 3 deletions

File tree

scripts/question_answering/run_squad.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ def parse_args():
7070
help='Max gradient norm.')
7171
parser.add_argument('--optimizer', type=str, default='adamw',
7272
help='optimization algorithm. default is adamw')
73+
parser.add_argument('--adam_epsilon', type=float, default=1e-6,
74+
help='epsilon of AdamW optimizer')
75+
parser.add_argument('--adam_betas', default='(0.9, 0.999)', metavar='B',
76+
help='betas for Adam optimizer')
7377
parser.add_argument('--num_accumulated', type=int, default=1,
7478
help='The number of batches for gradients accumulation to '
7579
'simulate large batch size.')
@@ -476,6 +480,7 @@ def train(args):
476480
logging.info('#Total Training Steps={}, Warmup={}, Save Interval={}'
477481
.format(num_train_steps, warmup_steps, save_interval))
478482

483+
# set up optimization
479484
lr_scheduler = PolyScheduler(max_update=num_train_steps,
480485
base_lr=args.lr,
481486
warmup_begin_lr=0,
@@ -487,12 +492,18 @@ def train(args):
487492
'wd': args.wd,
488493
'lr_scheduler': lr_scheduler,
489494
}
495+
adam_betas = eval(args.adam_betas)
490496
if args.optimizer == 'adamw':
491-
optimizer_params.update({'beta1': 0.9,
492-
'beta2': 0.999,
493-
'epsilon': 1e-6,
497+
optimizer_params.update({'beta1': adam_betas[0],
498+
'beta2': adam_betas[1],
499+
'epsilon': args.adam_epsilon,
494500
'correct_bias': False,
495501
})
502+
elif args.optimizer == 'adam':
503+
optimizer_params.update({'beta1': adam_betas[0],
504+
'beta2': adam_betas[1],
505+
'epsilon': args.adam_epsilon,
506+
})
496507
trainer = mx.gluon.Trainer(qa_net.collect_params(),
497508
args.optimizer, optimizer_params,
498509
update_on_kvstore=False)

0 commit comments

Comments
 (0)