|
45 | 45 | from jiant.modules.span_modules import SpanClassifierModule |
46 | 46 | from jiant.huggingface_transformers_interface import input_module_uses_transformers |
47 | 47 | from jiant.tasks.edge_probing import EdgeProbingTask |
48 | | -from jiant.tasks.lm import LanguageModelingTask |
| 48 | +from jiant.tasks.lm import AutoregressiveLanguageModelingTask, MaskedLanguageModelingTask |
49 | 49 | from jiant.tasks.lm_parsing import LanguageModelingParsingTask |
50 | 50 | from jiant.tasks.qa import MultiRCTask, ReCoRDTask |
51 | 51 | from jiant.tasks.seq2seq import Seq2SeqTask |
|
76 | 76 | format_output, |
77 | 77 | uses_cuda, |
78 | 78 | ) |
| 79 | +from jiant.utils.data_loaders import get_tokenizer |
79 | 80 |
|
80 | 81 | # Elmo stuff |
81 | 82 | # Look in $ELMO_SRC_DIR (e.g. /usr/share/jsalt/elmo) or download from web |
@@ -158,18 +159,22 @@ def build_sent_encoder(args, vocab, d_emb, tasks, embedder, cove_layer): |
158 | 159 | ) |
159 | 160 | d_sent = args.d_word |
160 | 161 | log.info("Using PRPN sentence encoder!") |
161 | | - elif any(isinstance(task, LanguageModelingTask) for task in tasks) or args.sent_enc == "bilm": |
| 162 | + elif ( |
| 163 | + any(isinstance(task, AutoregressiveLanguageModelingTask) for task in tasks) |
| 164 | + or args.sent_enc == "bilm" |
| 165 | + ): |
162 | 166 | assert_for_log(args.sent_enc in ["rnn", "bilm"], "Only RNNLM supported!") |
163 | | - assert_for_log( |
164 | | - not ( |
165 | | - args.input_module == "elmo" |
166 | | - or args.input_module.startswith("bert") |
167 | | - or args.input_module.startswith("xlnet") |
168 | | - ), |
169 | | - f"Using input_module = {args.input_module} for language modeling is probably not a " |
170 | | - "good idea, since it allows the language model to use information from the right-hand " |
171 | | - "context.", |
172 | | - ) |
| 167 | + if any(isinstance(task, AutoregressiveLanguageModelingTask) for task in tasks): |
| 168 | + assert_for_log( |
| 169 | + not ( |
| 170 | + args.input_module == "elmo" |
| 171 | + or args.input_module.startswith("bert") |
| 172 | + or args.input_module.startswith("xlnet") |
| 173 | + ), |
| 174 | + f"Using input_module = {args.input_module} for language modeling is probably not a " |
| 175 | + "good idea, since it allows the language model to use information from the right-hand " |
| 176 | + "context.", |
| 177 | + ) |
173 | 178 | bilm = BiLMEncoder(d_emb, args.d_hid, args.d_hid, args.n_layers_enc) |
174 | 179 | sent_encoder = SentenceEncoder( |
175 | 180 | vocab, |
@@ -549,7 +554,10 @@ def build_task_specific_modules(task, model, d_sent, d_emb, vocab, embedder, arg |
549 | 554 | hid2voc = build_lm(task, d_sent, args) |
550 | 555 | setattr(model, "%s_hid2voc" % task.name, hid2voc) |
551 | 556 | setattr(model, "%s_mdl" % task.name, hid2voc) |
552 | | - elif isinstance(task, LanguageModelingTask): |
| 557 | + elif isinstance(task, MaskedLanguageModelingTask): |
| 558 | + module = build_mlm(model.sent_encoder._text_field_embedder) |
| 559 | + setattr(model, "%s_mdl" % task.name, module) |
| 560 | + elif isinstance(task, AutoregressiveLanguageModelingTask): |
553 | 561 | assert not input_module_uses_transformers(args.input_module), ( |
554 | 562 | "our LM Task does not support transformers, if you need them, try to update", |
555 | 563 | "corresponding parts of the code. You may find get_pretrained_lm_head and", |
@@ -746,6 +754,12 @@ def build_lm(task, d_inp, args): |
746 | 754 | return hid2voc |
747 | 755 |
|
748 | 756 |
|
| 757 | +def build_mlm(embedder): |
| 758 | + """ Build the MLM output head (pretrained LM head) from the embedder. """ |
| 759 | + lm_head = embedder.get_pretrained_lm_head() |
| 760 | + return lm_head |
| 761 | + |
| 762 | + |
749 | 763 | def build_span_classifier(task, d_sent, task_params): |
750 | 764 | module = SpanClassifierModule(task, d_sent, task_params, num_spans=task.num_spans) |
751 | 765 | return module |
@@ -853,7 +867,9 @@ def forward(self, task, batch, predict=False): |
853 | 867 | task, (PairClassificationTask, PairRegressionTask, PairOrdinalRegressionTask) |
854 | 868 | ): |
855 | 869 | out = self._pair_sentence_forward(batch, task, predict) |
856 | | - elif isinstance(task, LanguageModelingTask): |
| 870 | + elif isinstance(task, MaskedLanguageModelingTask): |
| 871 | + out = self._masked_lm_forward(batch, task, predict) |
| 872 | + elif isinstance(task, AutoregressiveLanguageModelingTask): |
857 | 873 | if isinstance(self.sent_encoder._phrase_layer, ONLSTMStack) or isinstance( |
858 | 874 | self.sent_encoder._phrase_layer, PRPN |
859 | 875 | ): |
@@ -1160,6 +1176,32 @@ def _lm_forward(self, batch, task, predict): |
1160 | 1176 | pass |
1161 | 1177 | return out |
1162 | 1178 |
|
| 1179 | + def _masked_lm_forward(self, batch, task, predict): |
| 1180 | + """ |
| 1181 | + We currently support only RoBERTa-style dynamic masking, using the same |
| 1182 | + setup and parameters as RoBERTa. |
| 1183 | + """ |
| 1184 | + out = {} |
| 1185 | + tokenizer_name = self.sent_encoder._text_field_embedder.input_module |
| 1186 | + text_embedder = self.sent_encoder._text_field_embedder |
| 1187 | + vocab_size = text_embedder.model.embeddings.word_embeddings.num_embeddings |
| 1188 | + input_key = text_embedder.tokenizer_required |
| 1189 | + mask_idx = text_embedder._mask_id |
| 1190 | + b_size, seq_len = batch["targs"].size() |
| 1191 | + inputs = batch["input"][input_key] |
| 1192 | + labels = batch["targs"] |
| 1193 | + inputs, labels, _, _, _, _ = task.mlm_dynamic_masking( |
| 1194 | + inputs, labels, mask_idx, tokenizer_name, self.sent_encoder |
| 1195 | + ) |
| 1196 | + batch["input"][input_key] = inputs |
| 1197 | + sent_embs, sent_mask = self.sent_encoder(batch["input"], task) |
| 1198 | + module = getattr(self, "%s_mdl" % task.name) |
| 1199 | + logits = module.forward(sent_embs) |
| 1200 | + out["logits"] = logits |
| 1201 | + out["loss"] = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1)) |
| 1202 | + out["n_exs"] = format_output(b_size, self._cuda_device) |
| 1203 | + return out |
| 1204 | + |
1163 | 1205 | def _mc_forward(self, batch, task, predict): |
1164 | 1206 | """ Forward for a multiple choice question answering task """ |
1165 | 1207 | out = {} |
|
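
`build_mlm` above simply returns `embedder.get_pretrained_lm_head()`, and `_masked_lm_forward` applies that module to the encoder output to obtain vocabulary logits. Purely as an illustration (this is not jiant's or the underlying library's code; the class and argument names are hypothetical), a RoBERTa-style pretrained LM head amounts to roughly the following: a dense projection with an activation and layer norm, followed by a decoder whose weight is tied to the input embedding matrix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RobertaStyleLMHead(nn.Module):
    """Hypothetical sketch of a RoBERTa-style masked-LM head (not jiant's code)."""

    def __init__(self, hidden_size: int, vocab_size: int, token_embedding: nn.Embedding):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size)
        # Tie the output projection to the input embedding matrix, as RoBERTa does.
        self.decoder.weight = token_embedding.weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = F.gelu(self.dense(hidden_states))
        x = self.layer_norm(x)
        return self.decoder(x)  # [batch, seq_len, vocab_size] logits
```

This is why `_masked_lm_forward` can flatten the module's output straight into a cross-entropy over `vocab_size` classes.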
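`_masked_lm_forward` delegates input corruption to `task.mlm_dynamic_masking(...)`, which the docstring describes as RoBERTa-style dynamic masking. For reference, a minimal sketch of that standard recipe is below; it is not the task's actual implementation, and `pad_idx`, `mlm_prob`, and the helper name are assumptions. The usual scheme re-samples the mask every time a batch is seen, selects ~15% of non-padding positions, and of those replaces 80% with the mask token, 10% with a random token, and keeps 10% unchanged; unselected positions get label -100, which the default `ignore_index` of `F.cross_entropy` skips.

```python
import torch


def dynamic_mask(inputs: torch.Tensor, mask_idx: int, vocab_size: int,
                 pad_idx: int = 0, mlm_prob: float = 0.15):
    """Hedged sketch of BERT/RoBERTa-style dynamic masking (not task.mlm_dynamic_masking)."""
    labels = inputs.clone()

    # Select ~15% of non-padding positions to predict; the selection is
    # re-drawn on every pass over the data, hence "dynamic" masking.
    prob = torch.full(inputs.shape, mlm_prob)
    prob[inputs == pad_idx] = 0.0
    selected = torch.bernoulli(prob).bool()
    labels[~selected] = -100  # ignored by F.cross_entropy's default ignore_index

    corrupted = inputs.clone()
    # 80% of the selected positions become the mask token ...
    to_mask = torch.bernoulli(torch.full(inputs.shape, 0.8)).bool() & selected
    corrupted[to_mask] = mask_idx
    # ... 10% become a random vocabulary token ...
    to_rand = torch.bernoulli(torch.full(inputs.shape, 0.5)).bool() & selected & ~to_mask
    corrupted[to_rand] = torch.randint(vocab_size, inputs.shape)[to_rand]
    # ... and the remaining 10% keep their original token.
    return corrupted, labels
```

In the forward path above, the corrupted ids are written back into `batch["input"][input_key]` before the sentence encoder is called, and the labels drive the cross-entropy over the full vocabulary.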