Commit 838be2a: horovod for squad
1 parent 1d374a2

3 files changed: 89 additions & 61 deletions

scripts/pretraining/run_electra.py (6 additions & 41 deletions)
@@ -15,7 +15,7 @@
 
 from sklearn import metrics
 from pretraining_utils import ElectraMasker, get_pretrain_data_npz, get_pretrain_data_text
-from gluonnlp.utils.misc import grouper, repeat, set_seed, naming_convention, logging_config
+from gluonnlp.utils.misc import grouper, repeat, set_seed, naming_convention, logging_config, init_comm
 from gluonnlp.initializer import TruncNorm
 from gluonnlp.models.electra import ElectraModel, ElectraForPretrain, get_pretrained_electra
 from gluonnlp.utils.parameter import clip_grad_global_norm
@@ -170,39 +170,6 @@ def get_pretraining_model(model_name, ctx_l,
                                             'corrupted_tokens'])


-def init_comm(backend, gpus):
-    """Init communication backend"""
-    # backend specific implementation
-    if backend == 'horovod':
-        try:
-            import horovod.mxnet as hvd  # pylint: disable=import-outside-toplevel
-        except ImportError:
-            logging.info('horovod must be installed.')
-            sys.exit(1)
-        hvd.init()
-        store = None
-        num_workers = hvd.size()
-        rank = hvd.rank()
-        local_rank = hvd.local_rank()
-        is_master_node = rank == local_rank
-        ctx_l = [mx.gpu(local_rank)]
-        logging.info('GPU communication supported by horovod')
-    else:
-        store = mx.kv.create(backend)
-        num_workers = store.num_workers
-        rank = store.rank
-        local_rank = 0
-        is_master_node = rank == local_rank
-        if gpus == '-1' or gpus == '':
-            ctx_l = [mx.cpu()]
-            logging.info('Running on CPU')
-        else:
-            ctx_l = [mx.gpu(int(x)) for x in gpus.split(',')]
-            logging.info('GPU communication supported by KVStore')
-
-    return store, num_workers, rank, local_rank, is_master_node, ctx_l
-
-
 def final_save(model, save_dir, tokenizer):
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
@@ -261,6 +228,9 @@ def states_option(step_num, trainer, ckpt_dir, local_rank=0, option='Saving'):
 def train(args):
     store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
         args.comm_backend, args.gpus)
+    logging.info('Training info: num_buckets: {}, '
+                 'num_workers: {}, rank: {}'.format(
+                     args.num_buckets, num_workers, rank))
     cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l,
                                                   args.max_seq_length,
                                                   args.hidden_dropout_prob,
@@ -269,9 +239,6 @@ def train(args):
                                                   args.generator_layers_scale)
     data_masker = ElectraMasker(
         tokenizer, args.max_seq_length, args.mask_prob)
-    logging.info('Training info: num_buckets: {}, '
-                 'num_workers: {}, rank: {}'.format(
-                     args.num_buckets, num_workers, rank))
     if args.from_raw_text:
         if args.cached_file_path and not os.path.exists(args.cached_file_path):
             os.mkdir(args.cached_file_path)
@@ -342,8 +309,6 @@ def train(args):
         'epsilon': 1e-6,
         'correct_bias': False,
     })
-    # TODO(zheyuye), absence of layer-wise decay, although the decay power
-    # is 1.0 in electra model
     if args.comm_backend == 'horovod':
         trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params)
     else:
@@ -448,9 +413,9 @@ def train(args):
             # We need to change the ratio to be
            # \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom
            total_norm, ratio, is_finite = clip_grad_global_norm(
-                    params, args.max_grad_norm * num_samples_per_update / loss_denom)
+                params, args.max_grad_norm * num_samples_per_update / loss_denom)
            total_norm = total_norm / (num_samples_per_update / loss_denom)
-            trainer.update(num_samples_per_update / loss_denom, ignore_stale_grad=True)
+            trainer.update(num_samples_per_update / loss_denom)
            step_num += 1
            if args.num_accumulated != 1:
                # set grad to zero for gradient accumulation
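Note on the rescaled clipping in the hunk above: gradients are accumulated as sum_n g_n / loss_denom, so both the clipping threshold and the reported norm are inflated by num_samples_per_update / loss_denom, and trainer.update() with that same factor divides the gradient back down to a per-sample mean. The sketch below is illustrative arithmetic mirroring those lines; the concrete numbers (2 GPUs, 4 accumulation steps, 64 samples per update) are made up, not from the script.

# Illustrative values: 2 GPUs, num_accumulated=4, 64 samples in one update.
max_grad_norm = 1.0
loss_denom = 2 * 4                              # len(ctx_l) * args.num_accumulated
num_samples_per_update = 64
rescale = num_samples_per_update / loss_denom   # gradient inflation factor = 8.0

# Stored gradients hold sum_n g_n / loss_denom, so clip against the
# equivalently inflated threshold ...
clip_threshold = max_grad_norm * rescale        # = 8.0

# ... report the norm back on the per-sample-mean scale ...
def reported_norm(raw_norm):
    return raw_norm / rescale

# ... and trainer.update(rescale) divides the gradient by `rescale`, so the
# optimizer consumes the mean per-sample gradient.
print(clip_threshold, reported_norm(16.0))      # 8.0 2.0

Dropping ignore_stale_grad=True is presumably intentional tightening: without it, Gluon complains when a parameter's gradient was not refreshed by backward() before the update, instead of silently skipping it.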

scripts/question_answering/run_squad.py (50 additions & 20 deletions)
@@ -21,9 +21,16 @@
 from eval_utils import squad_eval
 from squad_utils import SquadFeature, get_squad_examples, convert_squad_example_to_feature
 from gluonnlp.models import get_backbone
-from gluonnlp.utils.misc import grouper, repeat, set_seed, parse_ctx, logging_config, count_parameters
+from gluonnlp.utils.misc import repeat, grouper, set_seed, init_comm, \
+    parse_ctx, logging_config, count_parameters
 from gluonnlp.initializer import TruncNorm
-from gluonnlp.utils.parameter import clip_grad_global_norm, grad_global_norm
+from gluonnlp.data.sampler import SplitSampler
+from gluonnlp.utils.parameter import grad_global_norm, clip_grad_global_norm
+
+try:
+    import horovod.mxnet as hvd
+except ImportError:
+    pass

 mx.npx.set_np()

@@ -48,6 +55,10 @@ def parse_args():
     parser.add_argument('--output_dir', type=str, default='squad_out',
                         help='The output directory where the model params will be written.'
                              ' default is squad_out')
+    # Communication
+    parser.add_argument('--comm_backend', type=str, default='device',
+                        choices=['horovod', 'dist_sync_device', 'device'],
+                        help='Communication backend.')
     parser.add_argument('--gpus', type=str, default='0',
                         help='list of gpus to run, e.g. 0 or 0,2,5. -1 means using cpu.')
     # Training hyperparameters
@@ -384,8 +395,11 @@ def untune_params(model, untunable_depth, not_included=[]):
             continue
         value.grad_req = 'null'

+
 def train(args):
-    ctx_l = parse_ctx(args.gpus)
+    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
+        args.comm_backend, args.gpus)
+
     cfg, tokenizer, qa_net, use_segmentation = \
         get_network(args.model_name, ctx_l,
                     args.classifier_dropout,
@@ -439,12 +453,15 @@ def train(args):
                      sum([ele.is_impossible for ele in train_features])))
     logging.info('After Chunking, #Train Sample/Is Impossible = {}/{}'
                  .format(len(train_dataset), num_impossible))
+    sampler = SplitSampler(len(train_dataset), num_parts=num_workers,
+                           part_index=rank, even_size=True)
     train_dataloader = mx.gluon.data.DataLoader(
         train_dataset,
         batchify_fn=dataset_processor.BatchifyFunction,
         batch_size=args.batch_size,
-        num_workers=0,
-        shuffle=True)
+        num_workers=4,
+        shuffle=True,
+        sampler=sampler)
     # Freeze parameters
     if 'electra' in args.model_name:
         # does not work for albert model since parameters in all layers are shared
@@ -453,17 +470,24 @@ def train(args):
     if args.layerwise_decay > 0:
         qa_net.backbone.apply_layerwise_decay(args.layerwise_decay)

+    logging.info('Creating distributed trainer...')
+    # Collect differentiable parameters
+    param_dict = qa_net.collect_params()
     # Do not apply weight decay to LayerNorm parameters and biases
     for _, v in qa_net.collect_params('.*beta|.*gamma|.*bias').items():
         v.wd_mult = 0.0
-    # Collect differentiable parameters
-    params = [p for p in qa_net.collect_params().values() if p.grad_req != 'null']
+    params = [p for p in param_dict.values() if p.grad_req != 'null']
     # Set grad_req if gradient accumulation is required
     if args.num_accumulated > 1:
         logging.info('Using gradient accumulation. Effective global batch size = {}'
-                     .format(args.num_accumulated * args.batch_size * len(ctx_l)))
+                     .format(args.num_accumulated * args.batch_size * len(ctx_l) * num_workers))
         for p in params:
             p.grad_req = 'add'
+    # backend specific implementation
+    if args.comm_backend == 'horovod':
+        # Horovod: fetch and broadcast parameters
+        hvd.broadcast_parameters(param_dict, root_rank=0)
+
     epoch_size = (len(train_dataloader) + len(ctx_l) - 1) // len(ctx_l)
     if args.num_train_steps is not None:
         num_train_steps = args.num_train_steps
@@ -504,9 +528,12 @@ def train(args):
         'beta2': adam_betas[1],
         'epsilon': args.adam_epsilon,
     })
-    trainer = mx.gluon.Trainer(qa_net.collect_params(),
-                               args.optimizer, optimizer_params,
-                               update_on_kvstore=False)
+    if args.comm_backend == 'horovod':
+        trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params)
+    else:
+        trainer = mx.gluon.Trainer(param_dict, args.optimizer, optimizer_params,
+                                   update_on_kvstore=False)
+
     num_samples_per_update = 0
     loss_denom = float(len(ctx_l) * args.num_accumulated)

@@ -516,7 +543,7 @@ def train(args):
     log_sample_num = 0
     if args.num_accumulated != 1:
         # set grad to zero for gradient accumulation
-        qa_net.collect_params().zero_grad()
+        param_dict.zero_grad()

     # start training
     global_tic = time.time()
@@ -575,17 +602,18 @@ def train(args):
                # \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom
                total_norm, ratio, is_finite = clip_grad_global_norm(
                    params, args.max_grad_norm * num_samples_per_update / loss_denom)
-                total_norm = total_norm / (num_samples_per_update / loss_denom)
            else:
-                total_norm = grad_global_norm(parameters)
+                total_norm = grad_global_norm(params)

+            total_norm = total_norm / (num_samples_per_update / loss_denom)
            trainer.update(num_samples_per_update / loss_denom)
            if args.num_accumulated != 1:
                # set grad to zero for gradient accumulation
-                qa_net.collect_params().zero_grad()
+                param_dict.zero_grad()

            # saving
-            if (step_num + 1) % save_interval == 0 or (step_num + 1) >= num_train_steps:
+            if local_rank == 0 and is_master_node and (
+                    (step_num + 1) % save_interval == 0 or (step_num + 1) >= num_train_steps):
                version_prefix = 'squad' + args.version
                ckpt_name = '{}_{}_{}.params'.format(args.model_name,
                                                     version_prefix,
@@ -602,7 +630,7 @@ def train(args):
            logging.info('Params saved in: {}'.format(params_saved))

            # logging
-            if (step_num + 1) % log_interval == 0:
+            if local_rank == 0 and (step_num + 1) % log_interval == 0:
                log_span_loss /= log_sample_num
                log_answerable_loss /= log_sample_num
                log_total_loss /= log_sample_num
@@ -611,8 +639,8 @@ def train(args):
                    'Step: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},'
                    ' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f}, Throughput={:.2f} samples/s'
                    ' ETA={:.2f}h'.format((step_num + 1), num_train_steps, log_span_loss,
-                                          log_answerable_loss, log_total_loss, trainer.learning_rate, total_norm,
-                                          toc - tic, log_sample_num / (toc - tic),
+                                          log_answerable_loss, log_total_loss, trainer.learning_rate,
+                                          total_norm, toc - tic, log_sample_num / (toc - tic),
                                          (num_train_steps - (step_num + 1)) / ((step_num + 1) / (toc - global_tic)) / 3600))
                tic = time.time()
                log_span_loss = 0
@@ -622,7 +650,9 @@ def train(args):
                num_samples_per_update = 0

            if (step_num + 1) >= num_train_steps:
-                logging.info('Finish training step: %d', (step_num + 1))
+                logging.info(
+                    'Finished training step: {} within {:.2f} hours'.format(
+                        step_num + 1, (toc - global_tic) / 3600))
                break

    return params_saved
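For the data-sharding change in this file: each Horovod process (or KVStore worker) now iterates only its own slice of the training set via SplitSampler(len(train_dataset), num_parts=num_workers, part_index=rank, even_size=True). A rough sketch of the partitioning this relies on follows; the shuffling and padding details are assumptions for illustration, not copied from gluonnlp.data.sampler.

# Hedged sketch of index sharding as used by the DataLoader above.
import random

def split_indices(length, num_parts, part_index, even_size=True, seed=0):
    indices = list(range(length))
    random.Random(seed).shuffle(indices)      # same permutation on every worker
    part_len = (length + num_parts - 1) // num_parts if even_size else length // num_parts
    start = part_len * part_index
    shard = indices[start:start + part_len]
    if even_size and len(shard) < part_len:   # pad a short last shard by wrapping
        shard += indices[:part_len - len(shard)]
    return shard

# With 1000 samples and 4 workers, each rank sees 250 samples, so one pass of
# the DataLoader covers the dataset exactly once across all ranks.
print([len(split_indices(1000, 4, r)) for r in range(4)])  # [250, 250, 250, 250]

Under the horovod backend the script would be launched once per GPU with Horovod's standard launcher, e.g. horovodrun -np 4 python run_squad.py --comm_backend horovod ..., so that hvd.rank() and hvd.local_rank() distinguish the processes.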

src/gluonnlp/utils/misc.py (33 additions & 0 deletions)
@@ -545,3 +545,36 @@ def check_version(min_version: str,
         warnings.warn(msg)
     else:
         raise AssertionError(msg)
+
+def init_comm(backend, gpus):
+    """Init communication backend"""
+    # backend specific implementation
+    import mxnet as mx
+    if backend == 'horovod':
+        try:
+            import horovod.mxnet as hvd  # pylint: disable=import-outside-toplevel
+        except ImportError:
+            logging.info('horovod must be installed.')
+            sys.exit(1)
+        hvd.init()
+        store = None
+        num_workers = hvd.size()
+        rank = hvd.rank()
+        local_rank = hvd.local_rank()
+        is_master_node = rank == local_rank
+        ctx_l = [mx.gpu(local_rank)]
+        logging.info('GPU communication supported by horovod')
+    else:
+        store = mx.kv.create(backend)
+        num_workers = store.num_workers
+        rank = store.rank
+        local_rank = 0
+        is_master_node = rank == local_rank
+        if gpus == '-1' or gpus == '':
+            ctx_l = [mx.cpu()]
+            logging.info('Running on CPU')
+        else:
+            ctx_l = [mx.gpu(int(x)) for x in gpus.split(',')]
+            logging.info('GPU communication supported by KVStore')
+
+    return store, num_workers, rank, local_rank, is_master_node, ctx_l
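Since init_comm is now the shared entry point for both scripts, a minimal usage sketch may help. The 'device' values below follow from the code above (a local KVStore reports one worker at rank 0); the horovod path is summarized in comments. Contexts are lightweight handles, so nothing touches a GPU here.

import mxnet as mx
from gluonnlp.utils.misc import init_comm

# Local 'device' KVStore over two GPUs: one worker, rank 0, one context per
# id listed in --gpus.
store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
    backend='device', gpus='0,1')
assert (num_workers, rank, local_rank, is_master_node) == (1, 0, 0, True)
print(ctx_l)  # [gpu(0), gpu(1)]

# backend='horovod' instead returns store=None and per-process values from
# hvd.size()/hvd.rank()/hvd.local_rank(), with ctx_l = [mx.gpu(local_rank)],
# which is what run_electra.py and run_squad.py now rely on.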
