
Commit c87a86b

Yada Pruksachatkun, phu-pmh, HaokunLiu, pruksmhc, and DeepLearning VM authored
Adding Masked Language Modelling (#1030)
* misc run scripts
* sbatch
* sweep scripts
* update
* qa
* update
* update
* update
* update
* update
* sb file
* moving update_metrics to outside scope of dataparallel
* fixing micro_avg calculation
* undo debugging
* Fixing tests, moving update_metrics out of other tasks
* remove extraneous change
* MLM task
* Added MLM task
* update
* fix multiple choice dataparallel forward
* update
* add _mask_id to transformers
* Update
* MLM update
* adding update_metrics abstraction
* delete update_metrics_ notation
* fixed wrong index problem
* removed unrelated files
* removed unrelated files
* removed unrelated files
* fix PEP8
* Fixed get_pretained_lm_head for BERT and ALBERT
* spelling check
* black formatting
* fixing tests
* bug fix
* Adding batch_size constraints to multi-GPU setting
* adding documentation
* adding batch size test
* black correct version
* Fixing batch size assertion
* generalize batch size assertion for more than 2 GPU setting
* reducing label loops in code
* fixing span forward
* Fixing span prediction forward for multi-GPU
* fix commonsenseQA forward
* MLM
* adding function documentation
* resolving nits, fixing seq_gen forward
* remove nit
* fixing batch_size assert and SpanPrediction task
* Remove debugging
* Fix batch size mismatch multi-GPU test
* Fix order of assert checking for batch size mismatch
* mlm training
* update
* sbatch
* update
* data parallel
* update data parallel stuffs
* using sequencelabel, using 1 paragraph per example
* update label mapping
* adding exmaples-porportion-mixing
* changing dataloader to work with wikitext103
* weight sampling
* add early stopping only onb one task
* commit
* Cleaning up code
* Removing unecessarily tracked git folders
* Removing unnecesary changes
* revert README
* revert README.md again
* Making more general for Transformer-based embedders
* torch.uint8 -> torch.bool
* Fixing indexing issues
* get rid of unecessary changes
* black cleanup
* update
* Prevent updating update_metrics twice in one step
* update
* update
* add base_roberta
* update
* reverting CCG edit added for debugging
* refactor defaults.conf
* black formatting
* merge
* removed SOP task and mlm_manual_scaling
* Fixing label namespace vocabulary creation, mergeing from master
* Deleting MLM weight
* black formatting
* Adding early_stopping_method to defaults.conf
* Fixing MLM with preprocessed wikitext103
* Deleting intermediate class hierarchy for MLM
* Correcting black
* LanguageModelingTask -> AutoregressiveModelingTask
* code style
* fixing MaskedLanguageModelTask
* Fixing typo
* Fixing label namespace
* extracting out masking portion
* Revert "extracting out masking portion" This reverts commit f21165c.
* Code cleanup
* Adding tests for early_stpping_method
* Adding pretrain_stop_metric
* Reverting get_data_iter
* Reverting to get_data_iter
* Fixing get_pretrained_lm_head for all embedder types
* Extracting out MLM probability masking
* Move dynamic masking function to Task for easier testing
* Adding unit tests for MLM
* Adding change to MLM forward function to expose more intermediate steps for testing
* Fixing code style
* Adding more detailed instructions of how to generate Wikipedia data
* Adding rest of MLM data generation code
* Black style and remove comment
* black style
* updating repro code for MLM data

Co-authored-by: phu-pmh <[email protected]>
Co-authored-by: Haokun Liu <[email protected]>
Co-authored-by: pruksmhc <[email protected]>
Co-authored-by: DeepLearning VM <[email protected]>
1 parent c975afa commit c87a86b

15 files changed

Lines changed: 545 additions & 39 deletions


jiant/__main__.py

Lines changed: 28 additions & 2 deletions
@@ -536,6 +536,33 @@ def load_model_for_target_train_run(args, ckpt_path, model, strict, task, cuda_d
     return to_train


+def get_pretrain_stop_metric(early_stopping_method, pretrain_tasks):
+    """
+    Get stop_metric, which is used for early stopping.
+
+    Parameters
+    -------------------
+    early_stopping_method: str,
+    pretrain_tasks: List[Task]
+
+    Returns
+    -------------------
+    stop_metric: str
+
+    """
+    if early_stopping_method != "auto":
+        pretrain_names = [task.name for task in pretrain_tasks]
+        if early_stopping_method in pretrain_names:
+            index = pretrain_names.index(early_stopping_method)
+            stop_metric = pretrain_tasks[index].val_metric
+        else:
+            raise ValueError("args.early_stopping_method must be either 'auto' or a task name")
+
+    else:
+        stop_metric = pretrain_tasks[0].val_metric if len(pretrain_tasks) == 1 else "macro_avg"
+    return stop_metric
+
+
 def main(cl_arguments):
     """ Train a model for multitask-training."""
     cl_args = handle_arguments(cl_arguments)
@@ -551,7 +578,6 @@ def main(cl_arguments):
     tasks = sorted(set(pretrain_tasks + target_tasks), key=lambda x: x.name)
     log.info("\tFinished loading tasks in %.3fs", time.time() - start_time)
     log.info("\t Tasks: {}".format([task.name for task in tasks]))
-
     # Build model
     log.info("Building model...")
     start_time = time.time()
@@ -567,7 +593,7 @@ def main(cl_arguments):
     if args.do_pretrain:
         # Train on pretrain tasks
         log.info("Training...")
-        stop_metric = pretrain_tasks[0].val_metric if len(pretrain_tasks) == 1 else "macro_avg"
+        stop_metric = get_pretrain_stop_metric(args.early_stopping_method, pretrain_tasks)
         should_decrease = (
             pretrain_tasks[0].val_metric_decreases if len(pretrain_tasks) == 1 else False
         )
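As a quick illustration of the new helper's behavior (not part of the commit), the sketch below exercises get_pretrain_stop_metric with a stand-in task container. It assumes the helper is importable from jiant.__main__; the task names and metric strings are made up.

# Illustrative only: jiant's real Task objects carry much more state than this.
from collections import namedtuple

FakeTask = namedtuple("FakeTask", ["name", "val_metric"])

pretrain_tasks = [
    FakeTask(name="wikitext103_mlm", val_metric="wikitext103_mlm_perplexity"),  # hypothetical names
    FakeTask(name="sst", val_metric="sst_accuracy"),
]

# "auto" with more than one pretraining task falls back to the macro average.
assert get_pretrain_stop_metric("auto", pretrain_tasks) == "macro_avg"

# Naming a pretraining task stops on that task's own validation metric.
assert get_pretrain_stop_metric("sst", pretrain_tasks) == "sst_accuracy"

# Anything that is neither "auto" nor a pretraining task name raises a ValueError.
try:
    get_pretrain_stop_metric("not_a_task", pretrain_tasks)
except ValueError:
    pass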

jiant/config/base_mlm_roberta.conf

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+// Base config file for mlm experiments wit roberta
+include "defaults.conf"
+
+early_stopping_method=auto // Early stopping method. Options: task_name to only do early stopping based
+        // on a specific task, 'auto': use the macro_avg
+
+// Multi-task Training
+weighting_method = proportional // Weighting method for task sampling, relative to the number of
+        // training examples in each task:
+        // Options: uniform, power_<power>, softmax_<temp>
+        // proportional, proportional_log_batch, and
+        // proportional_log_example (plus the less-useful inverse,
+        // inverse_log_example, and inverse_log_batch).
+        // Additionally, we include the T5 method of examples-proportional-mixing.
+        // See relevant source code for details.
+scaling_method = uniform // Method for scaling loss:
+        // Options: uniform, max_power_<power>, max_proportional,
+        // max_proportional_log, max_inverse, max_inverse_log
+        // max_epoch_<E1_E2_..._En>

jiant/config/defaults.conf

Lines changed: 4 additions & 0 deletions
@@ -178,6 +178,8 @@ max_epochs = -1 // If positive, maximum number of epochs (full pass over a task'
         // especially if it's higher than one epoch's worth of steps, it's possible to
         // significantly overshoot the intended number of epochs.

+early_stopping_method=auto // Early stopping method. Options: task_name to only do early stopping based
+        // on a specific task, 'auto': use the macro_avg
 patience = 5 // Patience in early stopping. Training will stop if performance does not improve at
         // all in patience + 1 validations.
 keep_all_checkpoints = 0 // If set, keep checkpoint files from every validation. Otherwise, keep
@@ -196,6 +198,8 @@ weighting_method = proportional // Weighting method for task sampling, relative
         // proportional, proportional_log_batch, and
         // proportional_log_example (plus the less-useful inverse,
         // inverse_log_example, and inverse_log_batch).
+        // Additionally, we include the T5 method of examples_proportional_mixing.
+        // To use this, set weighting_method=examples_proportional_mixingK=104857
         // See relevant source code for details.
 scaling_method = uniform // Method for scaling loss:
         // Options: uniform, max_power_<power>, max_proportional,
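For context on the new examples_proportional_mixing option (not part of the diff): T5-style examples-proportional mixing samples each task in proportion to its number of training examples, with every task's count capped at an artificial limit K so a very large corpus cannot dominate the mixture. A rough Python sketch of that formula follows, using the K=104857 value mentioned in the comment above; jiant's actual sampling code may differ in detail.

def examples_proportional_mixing_rates(num_examples, K=104857):
    """Sketch of the T5 mixing rate: r_m = min(e_m, K) / sum_n min(e_n, K)."""
    capped = [min(e, K) for e in num_examples]
    total = sum(capped)
    return [c / total for c in capped]

# A huge MLM corpus is capped at K, so the smaller tasks still get sampled:
print(examples_proportional_mixing_rates([4_000_000, 67_000, 8_500]))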

jiant/huggingface_transformers_interface/modules.py

Lines changed: 13 additions & 8 deletions
@@ -39,6 +39,7 @@ def __init__(self, args):
         self._sep_id = None
         self._pad_id = None
         self._unk_id = None
+        self._mask_id = None

         # If set, treat these special tokens as part of input segments other than A/B.
         self._SEG_ID_CLS = None
@@ -270,6 +271,7 @@ def __init__(self, args):
         self._cls_id = self.tokenizer.convert_tokens_to_ids("[CLS]")
         self._pad_id = self.tokenizer.convert_tokens_to_ids("[PAD]")
         self._unk_id = self.tokenizer.convert_tokens_to_ids("[UNK]")
+        self._mask_id = self.tokenizer.convert_tokens_to_ids("[MASK]")

         self.parameter_setup(args)

@@ -305,7 +307,7 @@ def get_pretrained_lm_head(self):
         )
         lm_head = model_with_lm_head.cls
         lm_head.predictions.decoder.weight = self.model.embeddings.word_embeddings.weight
-        return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
+        return lm_head


 class RobertaEmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -327,6 +329,7 @@ def __init__(self, args):
         self._cls_id = self.tokenizer.convert_tokens_to_ids("<s>")
         self._pad_id = self.tokenizer.convert_tokens_to_ids("<pad>")
         self._unk_id = self.tokenizer.convert_tokens_to_ids("<unk>")
+        self._mask_id = self.tokenizer.convert_tokens_to_ids("<mask>")

         self.parameter_setup(args)

@@ -358,8 +361,8 @@ def get_pretrained_lm_head(self):
             self.input_module, cache_dir=self.cache_dir
         )
         lm_head = model_with_lm_head.lm_head
-        lm_head.predictions.decoder.weight = self.model.embeddings.word_embeddings.weight
-        return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
+        lm_head.decoder.weight = self.model.embeddings.word_embeddings.weight
+        return lm_head


 class AlbertEmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -381,6 +384,7 @@ def __init__(self, args):
         self._cls_id = self.tokenizer.convert_tokens_to_ids("[CLS]")
         self._pad_id = self.tokenizer.convert_tokens_to_ids("<pad>")
         self._unk_id = self.tokenizer.convert_tokens_to_ids("<unk>")
+        self._mask_id = self.tokenizer.convert_tokens_to_ids("[MASK]")

         self.parameter_setup(args)

@@ -416,7 +420,7 @@ def get_pretrained_lm_head(self):
         )
         lm_head = model_with_lm_head.predictions
         lm_head.decoder.weight = self.model.embeddings.word_embeddings.weight
-        return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
+        return lm_head


 class XLNetEmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -437,6 +441,7 @@ def __init__(self, args):
         self._cls_id = self.tokenizer.convert_tokens_to_ids("<cls>")
         self._pad_id = self.tokenizer.convert_tokens_to_ids("<pad>")
         self._unk_id = self.tokenizer.convert_tokens_to_ids("<unk>")
+        self._mask_id = self.tokenizer.convert_tokens_to_ids("<mask>")

         self.parameter_setup(args)

@@ -478,7 +483,7 @@ def get_pretrained_lm_head(self, args):
         )
         lm_head = model_with_lm_head.lm_loss
         lm_head.weight = self.model.word_embedding.weight
-        return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
+        return lm_head


 class OpenAIGPTEmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -541,7 +546,7 @@ def get_pretrained_lm_head(self, args):
         )
         lm_head = model_with_lm_head.lm_head
         lm_head.weight = self.model.tokens_embed.weight[: lm_head.weight.size()[0]]
-        return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
+        return lm_head


 class GPT2EmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -603,7 +608,7 @@ def get_pretrained_lm_head(self):
         )
         lm_head = model_with_lm_head.lm_head
         lm_head.weight = self.model.wte.weight[: lm_head.weight.size()[0]]
-        return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
+        return lm_head


 class TransfoXLEmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -724,4 +729,4 @@ def get_pretrained_lm_head(self):
         )
         lm_head = model_with_lm_head.pred_layer
         lm_head.proj.weight = self.model.embeddings.weight
-        return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
+        return lm_head
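A note on why the nn.LogSoftmax wrapper is dropped from every get_pretrained_lm_head: the new MLM forward pass (see jiant/models.py below) computes its loss with F.cross_entropy, which already combines log_softmax and negative log-likelihood internally, so the heads should hand back raw, unnormalized logits. A small sanity check of that equivalence, with purely illustrative tensor sizes:

import torch
import torch.nn.functional as F

vocab_size = 50265  # illustrative; roughly RoBERTa-base's vocabulary size
logits = torch.randn(4, 16, vocab_size)          # (batch, seq_len, vocab), raw logits
labels = torch.randint(0, vocab_size, (4, 16))

loss_from_logits = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
loss_manual = F.nll_loss(F.log_softmax(logits.view(-1, vocab_size), dim=-1), labels.view(-1))
assert torch.allclose(loss_from_logits, loss_manual)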

jiant/models.py

Lines changed: 56 additions & 14 deletions
@@ -45,7 +45,7 @@
 from jiant.modules.span_modules import SpanClassifierModule
 from jiant.huggingface_transformers_interface import input_module_uses_transformers
 from jiant.tasks.edge_probing import EdgeProbingTask
-from jiant.tasks.lm import LanguageModelingTask
+from jiant.tasks.lm import AutoregressiveLanguageModelingTask, MaskedLanguageModelingTask
 from jiant.tasks.lm_parsing import LanguageModelingParsingTask
 from jiant.tasks.qa import MultiRCTask, ReCoRDTask
 from jiant.tasks.seq2seq import Seq2SeqTask
@@ -76,6 +76,7 @@
     format_output,
     uses_cuda,
 )
+from jiant.utils.data_loaders import get_tokenizer

 # Elmo stuff
 # Look in $ELMO_SRC_DIR (e.g. /usr/share/jsalt/elmo) or download from web
@@ -158,18 +159,22 @@ def build_sent_encoder(args, vocab, d_emb, tasks, embedder, cove_layer):
         )
         d_sent = args.d_word
         log.info("Using PRPN sentence encoder!")
-    elif any(isinstance(task, LanguageModelingTask) for task in tasks) or args.sent_enc == "bilm":
+    elif (
+        any(isinstance(task, AutoregressiveLanguageModelingTask) for task in tasks)
+        or args.sent_enc == "bilm"
+    ):
         assert_for_log(args.sent_enc in ["rnn", "bilm"], "Only RNNLM supported!")
-        assert_for_log(
-            not (
-                args.input_module == "elmo"
-                or args.input_module.startswith("bert")
-                or args.input_module.startswith("xlnet")
-            ),
-            f"Using input_module = {args.input_module} for language modeling is probably not a "
-            "good idea, since it allows the language model to use information from the right-hand "
-            "context.",
-        )
+        if any(isinstance(task, AutoregressiveLanguageModelingTask) for task in tasks):
+            assert_for_log(
+                not (
+                    args.input_module == "elmo"
+                    or args.input_module.startswith("bert")
+                    or args.input_module.startswith("xlnet")
+                ),
+                f"Using input_module = {args.input_module} for language modeling is probably not a "
+                "good idea, since it allows the language model to use information from the right-hand "
+                "context.",
+            )
         bilm = BiLMEncoder(d_emb, args.d_hid, args.d_hid, args.n_layers_enc)
         sent_encoder = SentenceEncoder(
             vocab,
@@ -549,7 +554,10 @@ def build_task_specific_modules(task, model, d_sent, d_emb, vocab, embedder, arg
         hid2voc = build_lm(task, d_sent, args)
         setattr(model, "%s_hid2voc" % task.name, hid2voc)
         setattr(model, "%s_mdl" % task.name, hid2voc)
-    elif isinstance(task, LanguageModelingTask):
+    elif isinstance(task, MaskedLanguageModelingTask):
+        module = build_mlm(model.sent_encoder._text_field_embedder)
+        setattr(model, "%s_mdl" % task.name, module)
+    elif isinstance(task, AutoregressiveLanguageModelingTask):
         assert not input_module_uses_transformers(args.input_module), (
             "our LM Task does not support transformers, if you need them, try to update",
             "corresponding parts of the code. You may find get_pretrained_lm_head and",
@@ -746,6 +754,12 @@ def build_lm(task, d_inp, args):
     return hid2voc


+def build_mlm(embedder):
+    " Build MLM components "
+    lm_head = embedder.get_pretrained_lm_head()
+    return lm_head
+
+
 def build_span_classifier(task, d_sent, task_params):
     module = SpanClassifierModule(task, d_sent, task_params, num_spans=task.num_spans)
     return module
@@ -853,7 +867,9 @@ def forward(self, task, batch, predict=False):
             task, (PairClassificationTask, PairRegressionTask, PairOrdinalRegressionTask)
         ):
             out = self._pair_sentence_forward(batch, task, predict)
-        elif isinstance(task, LanguageModelingTask):
+        elif isinstance(task, MaskedLanguageModelingTask):
+            out = self._masked_lm_forward(batch, task, predict)
+        elif isinstance(task, AutoregressiveLanguageModelingTask):
             if isinstance(self.sent_encoder._phrase_layer, ONLSTMStack) or isinstance(
                 self.sent_encoder._phrase_layer, PRPN
             ):
@@ -1160,6 +1176,32 @@ def _lm_forward(self, batch, task, predict):
             pass
         return out

+    def _masked_lm_forward(self, batch, task, predict):
+        """
+        We currently only support RoBERTa-style dynamic masking, with the exact
+        setup and parameters as RoBERTa.
+        """
+        out = {}
+        tokenizer_name = self.sent_encoder._text_field_embedder.input_module
+        text_embedder = self.sent_encoder._text_field_embedder
+        vocab_size = text_embedder.model.embeddings.word_embeddings.num_embeddings
+        input_key = text_embedder.tokenizer_required
+        mask_idx = text_embedder._mask_id
+        b_size, seq_len = batch["targs"].size()
+        inputs = batch["input"][input_key]
+        labels = batch["targs"]
+        inputs, labels, _, _, _, _ = task.mlm_dynamic_masking(
+            inputs, labels, mask_idx, tokenizer_name, self.sent_encoder
+        )
+        batch["input"][input_key] = inputs
+        sent_embs, sent_mask = self.sent_encoder(batch["input"], task)
+        module = getattr(self, "%s_mdl" % task.name)
+        logits = module.forward(sent_embs)
+        out["logits"] = logits
+        out["loss"] = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
+        out["n_exs"] = format_output(b_size, self._cuda_device)
+        return out
+
     def _mc_forward(self, batch, task, predict):
         """ Forward for a multiple choice question answering task """
         out = {}
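_masked_lm_forward delegates the actual token corruption to task.mlm_dynamic_masking, which is not part of this diff. For orientation only, here is a minimal sketch of the BERT/RoBERTa-style dynamic masking it refers to (select roughly 15% of positions; of those, 80% become the mask token, 10% a random token, and 10% stay unchanged). This is an assumption-laden illustration, not jiant's implementation, and it ignores special tokens and padding.

import torch

def dynamic_mask(inputs, labels, mask_idx, vocab_size, mlm_prob=0.15, ignore_index=-100):
    """Sketch only: RoBERTa-style dynamic masking, re-sampled on every forward pass."""
    inputs, labels = inputs.clone(), labels.clone()
    # Select ~15% of positions as prediction targets; the rest are ignored by the loss.
    selected = torch.bernoulli(torch.full(labels.shape, mlm_prob)).bool()
    labels[~selected] = ignore_index
    # 80% of the selected positions are replaced by the mask token.
    masked = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & selected
    inputs[masked] = mask_idx
    # Half of the remainder (10% overall) become a random token; the rest stay as-is.
    randomized = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & selected & ~masked
    random_tokens = torch.randint(vocab_size, labels.shape, dtype=torch.long)
    inputs[randomized] = random_tokens[randomized]
    return inputs, labels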

jiant/preprocess.py

Lines changed: 2 additions & 4 deletions
@@ -54,6 +54,7 @@
 from jiant.tasks import REGISTRY as TASKS_REGISTRY
 from jiant.tasks.seq2seq import Seq2SeqTask
 from jiant.tasks.tasks import SequenceGenerationTask, Task
+from jiant.tasks.lm import MaskedLanguageModelingTask
 from jiant.utils import config, serialize, utils, options
 from jiant.utils.options import parse_task_list_arg

@@ -261,6 +262,7 @@ def _build_vocab(args: config.Params, tasks: List[Task], vocab_path: str):
     for task in tasks:  # add custom label namespaces
         # TODO: surface more docs for add_task_label_vocab:
         add_task_label_vocab(vocab, task)
+
     if args.force_include_wsj_vocabulary:
         # Add WSJ full vocabulary for PTB F1 parsing tasks.
         add_wsj_vocab(vocab, args.data_dir)
@@ -661,10 +663,6 @@ def add_task_label_vocab(vocab, task):
         return
     log.info("\tTask '%s': adding vocab namespace '%s'", task.name, namespace)

-    if isinstance(task, SequenceGenerationTask):
-        for special in SPECIALS:
-            vocab.add_token_to_namespace(special, namespace)
-
     for label in task.get_all_labels():
         vocab.add_token_to_namespace(label, namespace)
