2121from eval_utils import squad_eval
2222from squad_utils import SquadFeature , get_squad_examples , convert_squad_example_to_feature
2323from gluonnlp .models import get_backbone
24- from gluonnlp .utils .misc import grouper , repeat , set_seed , parse_ctx , logging_config , count_parameters
24+ from gluonnlp .utils .misc import repeat , grouper , set_seed , init_comm , \
25+ parse_ctx , logging_config , count_parameters
2526from gluonnlp .initializer import TruncNorm
26- from gluonnlp .utils .parameter import clip_grad_global_norm , grad_global_norm
27+ from gluonnlp .data .sampler import SplitSampler
28+ from gluonnlp .utils .parameter import grad_global_norm , clip_grad_global_norm
29+
30+ try :
31+ import horovod .mxnet as hvd
32+ except ImportError :
33+ pass
2734
2835mx .npx .set_np ()
2936
@@ -48,6 +55,10 @@ def parse_args():
4855 parser .add_argument ('--output_dir' , type = str , default = 'squad_out' ,
4956 help = 'The output directory where the model params will be written.'
5057 ' default is squad_out' )
58+ # Communication
59+ parser .add_argument ('--comm_backend' , type = str , default = 'device' ,
60+ choices = ['horovod' , 'dist_sync_device' , 'device' ],
61+ help = 'Communication backend.' )
5162 parser .add_argument ('--gpus' , type = str , default = '0' ,
5263 help = 'list of gpus to run, e.g. 0 or 0,2,5. -1 means using cpu.' )
5364 # Training hyperparameters
@@ -384,8 +395,11 @@ def untune_params(model, untunable_depth, not_included=[]):
384395 continue
385396 value .grad_req = 'null'
386397
398+
387399def train (args ):
388- ctx_l = parse_ctx (args .gpus )
400+ store , num_workers , rank , local_rank , is_master_node , ctx_l = init_comm (
401+ args .comm_backend , args .gpus )
402+
389403 cfg , tokenizer , qa_net , use_segmentation = \
390404 get_network (args .model_name , ctx_l ,
391405 args .classifier_dropout ,
@@ -439,12 +453,15 @@ def train(args):
439453 sum ([ele .is_impossible for ele in train_features ])))
440454 logging .info ('After Chunking, #Train Sample/Is Impossible = {}/{}'
441455 .format (len (train_dataset ), num_impossible ))
456+ sampler = SplitSampler (len (train_dataset ), num_parts = num_workers ,
457+ part_index = rank , even_size = True )
442458 train_dataloader = mx .gluon .data .DataLoader (
443459 train_dataset ,
444460 batchify_fn = dataset_processor .BatchifyFunction ,
445461 batch_size = args .batch_size ,
446- num_workers = 0 ,
447- shuffle = True )
462+ num_workers = 4 ,
463+ # shuffle must not be passed together with sampler (DataLoader raises
464+ # ValueError); per-worker shuffling is handled by SplitSampler itself.
465+ sampler = sampler )
448465 # Froze parameters
449466 if 'electra' in args .model_name :
450467 # does not work for albert model since parameters in all layers are shared
@@ -453,17 +470,24 @@ def train(args):
453470 if args .layerwise_decay > 0 :
454471 qa_net .backbone .apply_layerwise_decay (args .layerwise_decay )
455472
473+ logging .info ('Creating distributed trainer...' )
474+ # Collect differentiable parameters
475+ param_dict = qa_net .collect_params ()
456476 # Do not apply weight decay to all the LayerNorm and bias
457477 for _ , v in qa_net .collect_params ('.*beta|.*gamma|.*bias' ).items ():
458478 v .wd_mult = 0.0
459- # Collect differentiable parameters
460- params = [p for p in qa_net .collect_params ().values () if p .grad_req != 'null' ]
479+ params = [p for p in param_dict .values () if p .grad_req != 'null' ]
461480 # Set grad_req if gradient accumulation is required
462481 if args .num_accumulated > 1 :
463482 logging .info ('Using gradient accumulation. Effective global batch size = {}'
464- .format (args .num_accumulated * args .batch_size * len (ctx_l )))
483+ .format (args .num_accumulated * args .batch_size * len (ctx_l ) * num_workers ))
465484 for p in params :
466485 p .grad_req = 'add'
486+ # backend specific implementation
487+ if args .comm_backend == 'horovod' :
488+ # Horovod: fetch and broadcast parameters
489+ hvd .broadcast_parameters (param_dict , root_rank = 0 )
490+
467491 epoch_size = (len (train_dataloader ) + len (ctx_l ) - 1 ) // len (ctx_l )
468492 if args .num_train_steps is not None :
469493 num_train_steps = args .num_train_steps
@@ -504,9 +528,12 @@ def train(args):
504528 'beta2' : adam_betas [1 ],
505529 'epsilon' : args .adam_epsilon ,
506530 })
507- trainer = mx .gluon .Trainer (qa_net .collect_params (),
508- args .optimizer , optimizer_params ,
509- update_on_kvstore = False )
531+ if args .comm_backend == 'horovod' :
532+ trainer = hvd .DistributedTrainer (param_dict , args .optimizer , optimizer_params )
533+ else :
534+ trainer = mx .gluon .Trainer (param_dict , args .optimizer , optimizer_params ,
535+ update_on_kvstore = False )
536+
510537 num_samples_per_update = 0
511538 loss_denom = float (len (ctx_l ) * args .num_accumulated )
512539
@@ -516,7 +543,7 @@ def train(args):
516543 log_sample_num = 0
517544 if args .num_accumulated != 1 :
518545 # set grad to zero for gradient accumulation
519- qa_net . collect_params () .zero_grad ()
546+ param_dict .zero_grad ()
520547
521548 # start training
522549 global_tic = time .time ()
@@ -575,17 +602,18 @@ def train(args):
575602 # \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom
576603 total_norm , ratio , is_finite = clip_grad_global_norm (
577604 params , args .max_grad_norm * num_samples_per_update / loss_denom )
578- total_norm = total_norm / (num_samples_per_update / loss_denom )
579605 else :
580- total_norm = grad_global_norm (parameters )
606+ total_norm = grad_global_norm (params )
581607
608+ total_norm = total_norm / (num_samples_per_update / loss_denom )
582609 trainer .update (num_samples_per_update / loss_denom )
583610 if args .num_accumulated != 1 :
584611 # set grad to zero for gradient accumulation
585- qa_net . collect_params () .zero_grad ()
612+ param_dict .zero_grad ()
586613
587614 # saving
588- if (step_num + 1 ) % save_interval == 0 or (step_num + 1 ) >= num_train_steps :
615+ if local_rank == 0 and is_master_node and (
616+ (step_num + 1 ) % save_interval == 0 or (step_num + 1 ) >= num_train_steps ):
589617 version_prefix = 'squad' + args .version
590618 ckpt_name = '{}_{}_{}.params' .format (args .model_name ,
591619 version_prefix ,
@@ -602,7 +630,7 @@ def train(args):
602630 logging .info ('Params saved in: {}' .format (params_saved ))
603631
604632 # logging
605- if (step_num + 1 ) % log_interval == 0 :
633+ if local_rank == 0 and (step_num + 1 ) % log_interval == 0 :
606634 log_span_loss /= log_sample_num
607635 log_answerable_loss /= log_sample_num
608636 log_total_loss /= log_sample_num
@@ -611,8 +639,8 @@ def train(args):
611639 'Step: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},'
612640 ' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f}, Throughput={:.2f} samples/s'
613641 ' ETA={:.2f}h' .format ((step_num + 1 ), num_train_steps , log_span_loss ,
614- log_answerable_loss , log_total_loss , trainer .learning_rate , total_norm ,
615- toc - tic , log_sample_num / (toc - tic ),
642+ log_answerable_loss , log_total_loss , trainer .learning_rate ,
643+ total_norm , toc - tic , log_sample_num / (toc - tic ),
616644 (num_train_steps - (step_num + 1 )) / ((step_num + 1 ) / (toc - global_tic )) / 3600 ))
617645 tic = time .time ()
618646 log_span_loss = 0
@@ -622,7 +650,9 @@ def train(args):
622650 num_samples_per_update = 0
623651
624652 if (step_num + 1 ) >= num_train_steps :
625- logging .info ('Finish training step: %d' , (step_num + 1 ))
653+ logging .info (
654+ 'Finish training step: {} within {} hours' .format (
655+ step_num + 1 , (toc - global_tic ) / 3600 ))
626656 break
627657
628658 return params_saved
0 commit comments