 import numpy as np
 import mxnet as mx
 import tempfile
-from gluonnlp.models.roberta import RobertaModel, \
+from gluonnlp.models.roberta import RobertaModel, RobertaForMLM, \
     list_pretrained_roberta, get_pretrained_roberta
 from gluonnlp.loss import LabelSmoothCrossEntropyLoss

@@ -19,12 +19,12 @@ def test_roberta(model_name):
     # test from pretrained
     assert len(list_pretrained_roberta()) > 0
     with tempfile.TemporaryDirectory() as root:
-        cfg, tokenizer, params_path = \
-            get_pretrained_roberta(model_name, root=root)
+        cfg, tokenizer, params_path, mlm_params_path = \
+            get_pretrained_roberta(model_name, load_backbone=True, load_mlm=True, root=root)
         assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
         roberta_model = RobertaModel.from_cfg(cfg)
         roberta_model.load_parameters(params_path)
-
+
         # test forward
         batch_size = 3
         seq_length = 32
@@ -45,12 +45,19 @@ def test_roberta(model_name):
             ),
             dtype=np.int32
         )
-        x = roberta_model(input_ids, valid_length)
+        contextual_embeddings, pooled_out = roberta_model(input_ids, valid_length)
         mx.npx.waitall()
         # test backward
         label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=vocab_size)
         with mx.autograd.record():
-            x = roberta_model(input_ids, valid_length)
-            loss = label_smooth_loss(x, input_ids)
+            contextual_embeddings, pooled_out = roberta_model(input_ids, valid_length)
+            loss = label_smooth_loss(contextual_embeddings, input_ids)
             loss.backward()
         mx.npx.waitall()
+
+        # test for mlm model
+        roberta_mlm_model = RobertaForMLM(cfg)
+        if mlm_params_path is not None:
+            roberta_mlm_model.load_parameters(mlm_params_path)
+        roberta_mlm_model = RobertaForMLM(cfg)
+        roberta_mlm_model.backbone_model.load_parameters(params_path)
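
Note that the final hunk constructs RobertaForMLM twice on purpose: the first instance loads the full MLM checkpoint when mlm_params_path is available, while the second checks that a fresh MLM head can be attached to backbone-only weights via backbone_model. Neither instance is run forward. A minimal forward smoke test of the MLM head, appended inside the same with-block, could look like the sketch below; it assumes (verify against the installed gluonnlp version) that RobertaForMLM's call signature is (inputs, valid_length, masked_positions) and that it returns (contextual_embeddings, pooled_out, mlm_scores) with mlm_scores shaped (batch_size, num_masked, vocab_size).

        # Hedged sketch, not part of this diff: exercise the MLM head.
        # Assumes forward(inputs, valid_length, masked_positions) returns
        # (contextual_embeddings, pooled_out, mlm_scores).
        num_masked = 3
        masked_positions = mx.np.random.randint(
            0,
            seq_length,
            (batch_size, num_masked),
            dtype=np.int32
        )
        contextual_embeddings, pooled_out, mlm_scores = \
            roberta_mlm_model(input_ids, valid_length, masked_positions)
        assert mlm_scores.shape == (batch_size, num_masked, cfg.MODEL.vocab_size)
        mx.npx.waitall()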