@@ -70,6 +70,10 @@ def parse_args():
7070 help = 'Max gradient norm.' )
7171 parser .add_argument ('--optimizer' , type = str , default = 'adamw' ,
7272 help = 'optimization algorithm. default is adamw' )
73+ parser .add_argument ('--adam_epsilon' , type = float , default = 1e-6 ,
74+ help = 'epsilon of AdamW optimizer' )
75+ parser .add_argument ('--adam_betas' , default = '(0.9, 0.999)' , metavar = 'B' ,
76+ help = 'betas for Adam optimizer' )
7377 parser .add_argument ('--num_accumulated' , type = int , default = 1 ,
7478 help = 'The number of batches for gradients accumulation to '
7579 'simulate large batch size.' )
@@ -476,6 +480,7 @@ def train(args):
476480 logging .info ('#Total Training Steps={}, Warmup={}, Save Interval={}'
477481 .format (num_train_steps , warmup_steps , save_interval ))
478482
483+ # set up optimization
479484 lr_scheduler = PolyScheduler (max_update = num_train_steps ,
480485 base_lr = args .lr ,
481486 warmup_begin_lr = 0 ,
@@ -487,12 +492,18 @@ def train(args):
487492 'wd' : args .wd ,
488493 'lr_scheduler' : lr_scheduler ,
489494 }
495+ adam_betas = eval (args .adam_betas )
490496 if args .optimizer == 'adamw' :
491- optimizer_params .update ({'beta1' : 0.9 ,
492- 'beta2' : 0.999 ,
493- 'epsilon' : 1e-6 ,
497+ optimizer_params .update ({'beta1' : adam_betas [ 0 ] ,
498+ 'beta2' : adam_betas [ 1 ] ,
499+ 'epsilon' : args . adam_epsilon ,
494500 'correct_bias' : False ,
495501 })
502+ elif args .optimizer == 'adam' :
503+ optimizer_params .update ({'beta1' : adam_betas [0 ],
504+ 'beta2' : adam_betas [1 ],
505+ 'epsilon' : args .adam_epsilon ,
506+ })
496507 trainer = mx .gluon .Trainer (qa_net .collect_params (),
497508 args .optimizer , optimizer_params ,
498509 update_on_kvstore = False )
0 commit comments