
Commit 4defc7a

update testings
1 parent 2719e81 commit 4defc7a

9 files changed

Lines changed: 39 additions & 21 deletions

src/gluonnlp/models/albert.py

Lines changed: 2 additions & 1 deletion
@@ -607,7 +607,8 @@ def list_pretrained_albert():
 
 def get_pretrained_albert(model_name: str = 'google_albert_base_v2',
                           root: str = get_model_zoo_home_dir(),
-                          load_backbone=True, load_mlm=False)\
+                          load_backbone: bool = True,
+                          load_mlm: bool = False)\
         -> Tuple[CN, SentencepieceTokenizer, str, str]:
     """Get the pretrained Albert weights
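
For context, a minimal sketch of how the updated loader is exercised (mirroring tests/test_models_albert.py; the import path and network access to the model zoo are assumptions). When load_mlm is False or no MLM checkpoint is published, the returned mlm_params_path can be None, which is why the tests guard the MLM load:

# Sketch only: mirrors the call pattern used in tests/test_models_albert.py.
from gluonnlp.models.albert import AlbertModel, AlbertForMLM, get_pretrained_albert

cfg, tokenizer, backbone_params_path, mlm_params_path = \
    get_pretrained_albert('google_albert_base_v2', load_backbone=True, load_mlm=True)

albert_model = AlbertModel.from_cfg(cfg)
albert_model.load_parameters(backbone_params_path)

albert_mlm_model = AlbertForMLM(cfg)
if mlm_params_path is not None:  # MLM weights may not exist for every checkpoint
    albert_mlm_model.load_parameters(mlm_params_path)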

src/gluonnlp/models/bert.py

Lines changed: 2 additions & 1 deletion
@@ -598,7 +598,8 @@ def list_pretrained_bert():
 
 def get_pretrained_bert(model_name: str = 'google_en_cased_bert_base',
                         root: str = get_model_zoo_home_dir(),
-                        load_backbone=True, load_mlm=False)\
+                        load_backbone: bool = True,
+                        load_mlm: bool = False)\
         -> Tuple[CN, HuggingFaceWordPieceTokenizer, str, str]:
     """Get the pretrained bert weights

src/gluonnlp/models/mobilebert.py

Lines changed: 2 additions & 1 deletion
@@ -909,7 +909,8 @@ def list_pretrained_mobilebert():
 
 def get_pretrained_mobilebert(model_name: str = 'google_uncased_mobilebert',
                               root: str = get_model_zoo_home_dir(),
-                              load_backbone=True, load_mlm=True)\
+                              load_backbone: bool = True,
+                              load_mlm: bool = False)\
         -> Tuple[CN, HuggingFaceWordPieceTokenizer, str, str]:
     """Get the pretrained mobile bert weights

src/gluonnlp/models/roberta.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@
         'merges': 'fairseq_roberta_base/gpt2-396d4d8e.merges',
         'vocab': 'fairseq_roberta_base/gpt2-f1335494.vocab',
         'params': 'fairseq_roberta_base/model-09a1520a.params',
-        'mlm_params': 'google_uncased_mobilebert/model_mlm-29889e2b.params',
+        'mlm_params': 'fairseq_roberta_base/model_mlm-29889e2b.params',
     },
     'fairseq_roberta_large': {
         'cfg': 'fairseq_roberta_large/model-6e66dc4a.yml',
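
The old entry pointed into the google_uncased_mobilebert directory, so loading the RoBERTa MLM weights would have fetched the wrong file; after the fix all fairseq_roberta_base assets share one prefix. A tiny sanity check, assuming the dictionary shown here is the module-level PRETRAINED_URL (the name is an assumption):

from gluonnlp.models.roberta import PRETRAINED_URL  # dict name is an assumption

# The MLM weights should now live next to the other fairseq_roberta_base assets.
assert PRETRAINED_URL['fairseq_roberta_base']['mlm_params'].startswith('fairseq_roberta_base/')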

tests/test_models_albert.py

Lines changed: 3 additions & 2 deletions
@@ -108,12 +108,13 @@ def test_albert_get_pretrained(model_name):
     assert len(list_pretrained_albert()) > 0
     with tempfile.TemporaryDirectory() as root:
         cfg, tokenizer, backbone_params_path, mlm_params_path =\
-            get_pretrained_albert(model_name, root=root, load_mlm=True)
+            get_pretrained_albert(model_name, load_backbone=True, load_mlm=True, root=root)
         assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
         albert_model = AlbertModel.from_cfg(cfg)
         albert_model.load_parameters(backbone_params_path)
         albert_mlm_model = AlbertForMLM(cfg)
-        albert_mlm_model.load_parameters(mlm_params_path)
+        if mlm_params_path is not None:
+            albert_mlm_model.load_parameters(mlm_params_path)
         # Just load the backbone
         albert_mlm_model = AlbertForMLM(cfg)
         albert_mlm_model.backbone_model.load_parameters(backbone_params_path)

tests/test_models_bert.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def test_bert_get_pretrained(model_name):
     assert len(list_pretrained_bert()) > 0
     with tempfile.TemporaryDirectory() as root:
         cfg, tokenizer, backbone_params_path, mlm_params_path =\
-            get_pretrained_bert(model_name, root=root)
+            get_pretrained_bert(model_name, load_backbone=True, load_mlm=True, root=root)
         assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
         bert_model = BertModel.from_cfg(cfg)
         bert_model.load_parameters(backbone_params_path)

tests/test_models_mobilebert.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ def test_list_pretrained_mobilebert():
 def test_bert_get_pretrained(model_name):
     with tempfile.TemporaryDirectory() as root:
         cfg, tokenizer, backbone_params_path, mlm_params_path =\
-            get_pretrained_mobilebert(model_name, root=root)
+            get_pretrained_mobilebert(model_name, load_backbone=True, load_mlm=True, root=root)
         assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
         mobilebert_model = MobileBertModel.from_cfg(cfg)
         mobilebert_model.load_parameters(backbone_params_path)

tests/test_models_roberta.py

Lines changed: 14 additions & 7 deletions
@@ -2,7 +2,7 @@
 import numpy as np
 import mxnet as mx
 import tempfile
-from gluonnlp.models.roberta import RobertaModel,\
+from gluonnlp.models.roberta import RobertaModel, RobertaForMLM, \
     list_pretrained_roberta, get_pretrained_roberta
 from gluonnlp.loss import LabelSmoothCrossEntropyLoss
 
@@ -19,12 +19,12 @@ def test_roberta(model_name):
     # test from pretrained
     assert len(list_pretrained_roberta()) > 0
     with tempfile.TemporaryDirectory() as root:
-        cfg, tokenizer, params_path =\
-            get_pretrained_roberta(model_name, root=root)
+        cfg, tokenizer, params_path, mlm_params_path =\
+            get_pretrained_roberta(model_name, load_backbone=True, load_mlm=True, root=root)
         assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
         roberta_model = RobertaModel.from_cfg(cfg)
         roberta_model.load_parameters(params_path)
-
+
         # test forward
         batch_size = 3
         seq_length = 32
@@ -45,12 +45,19 @@ def test_roberta(model_name):
             ),
             dtype=np.int32
         )
-        x = roberta_model(input_ids, valid_length)
+        contextual_embeddings, pooled_out = roberta_model(input_ids, valid_length)
         mx.npx.waitall()
         # test backward
         label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=vocab_size)
         with mx.autograd.record():
-            x = roberta_model(input_ids, valid_length)
-            loss = label_smooth_loss(x, input_ids)
+            contextual_embeddings, pooled_out = roberta_model(input_ids, valid_length)
+            loss = label_smooth_loss(contextual_embeddings, input_ids)
             loss.backward()
         mx.npx.waitall()
+
+        # test for mlm model
+        roberta_mlm_model = RobertaForMLM(cfg)
+        if mlm_params_path is not None:
+            roberta_mlm_model.load_parameters(mlm_params_path)
+        roberta_mlm_model = RobertaForMLM(cfg)
+        roberta_mlm_model.backbone_model.load_parameters(params_path)
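
The backbone call now returns a pair (per-token contextual embeddings and a pooled vector) rather than a single tensor, and the MLM head is loaded only when its parameter file exists. A standalone sketch of the same pattern, assuming network access to the model zoo; the shapes in the comments assume the default 'NT' layout and are not taken from the diff:

import mxnet as mx
import numpy as np
from gluonnlp.models.roberta import RobertaModel, RobertaForMLM, get_pretrained_roberta

cfg, tokenizer, params_path, mlm_params_path = \
    get_pretrained_roberta('fairseq_roberta_base', load_backbone=True, load_mlm=True)

roberta_model = RobertaModel.from_cfg(cfg)
roberta_model.load_parameters(params_path)

batch_size, seq_length = 2, 16
input_ids = mx.np.array(
    np.random.randint(2, cfg.MODEL.vocab_size, (batch_size, seq_length)), dtype=np.int32)
valid_length = mx.np.array([seq_length] * batch_size, dtype=np.int32)

# contextual_embeddings: roughly (batch_size, seq_length, units); pooled_out: (batch_size, units)
contextual_embeddings, pooled_out = roberta_model(input_ids, valid_length)

# The MLM head wraps the same backbone; fall back to backbone-only weights if no MLM file exists.
roberta_mlm_model = RobertaForMLM(cfg)
if mlm_params_path is not None:
    roberta_mlm_model.load_parameters(mlm_params_path)
else:
    roberta_mlm_model.backbone_model.load_parameters(params_path)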

tests/test_models_xlmr.py

Lines changed: 13 additions & 6 deletions
@@ -2,7 +2,7 @@
 import numpy as np
 import mxnet as mx
 import tempfile
-from gluonnlp.models.xlmr import XLMRModel,\
+from gluonnlp.models.xlmr import XLMRModel, XLMRForMLM, \
     list_pretrained_xlmr, get_pretrained_xlmr
 from gluonnlp.loss import LabelSmoothCrossEntropyLoss
 
@@ -19,8 +19,8 @@ def test_xlmr():
     assert len(list_pretrained_xlmr()) > 0
     for model_name in list_pretrained_xlmr():
         with tempfile.TemporaryDirectory() as root:
-            cfg, tokenizer, params_path =\
-                get_pretrained_xlmr(model_name, root=root)
+            cfg, tokenizer, params_path, mlm_params_path =\
+                get_pretrained_xlmr(model_name, load_backbone=True, load_mlm=True, root=root)
             assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
             xlmr_model = XLMRModel.from_cfg(cfg)
             xlmr_model.load_parameters(params_path)
@@ -44,12 +44,19 @@ def test_xlmr():
                 ),
                 dtype=np.int32
             )
-            x = xlmr_model(input_ids, valid_length)
+            contextual_embeddings, pooled_out = xlmr_model(input_ids, valid_length)
             mx.npx.waitall()
             # test backward
             label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=vocab_size)
             with mx.autograd.record():
-                x = xlmr_model(input_ids, valid_length)
-                loss = label_smooth_loss(x, input_ids)
+                contextual_embeddings, pooled_out = xlmr_model(input_ids, valid_length)
+                loss = label_smooth_loss(contextual_embeddings, input_ids)
                 loss.backward()
             mx.npx.waitall()
+
+            # test for mlm model
+            xlmr = XLMRForMLM(cfg)
+            if mlm_params_path is not None:
+                xlmr.load_parameters(mlm_params_path)
+            xlmr = XLMRForMLM(cfg)
+            xlmr.backbone_model.load_parameters(params_path)
