
Commit 5811d40

committed: update

1 parent 44a09a3 · commit 5811d40

3 files changed: 120 additions & 93 deletions


scripts/conversion_toolkits/convert_fairseq_roberta.py

Lines changed: 47 additions & 33 deletions
@@ -1,18 +1,20 @@
 import os
 import sys
-import argparse
+import json
 import shutil
 import logging
-import json
-from numpy.testing import assert_allclose
+import argparse
+
 import mxnet as mx
 import numpy as np
+from numpy.testing import assert_allclose
+
 import torch
+from gluonnlp.data.vocab import Vocab as gluon_Vocab
 from gluonnlp.utils.misc import sha1sum, logging_config
+from fairseq.models.roberta import RobertaModel as fairseq_RobertaModel
 from gluonnlp.models.roberta import RobertaModel, RobertaForMLM
 from gluonnlp.data.tokenizers import HuggingFaceByteBPETokenizer
-from gluonnlp.data.vocab import Vocab as gluon_Vocab
-from fairseq.models.roberta import RobertaModel as fairseq_RobertaModel
 
 mx.npx.set_np()
 
@@ -164,16 +166,18 @@ def convert_config(fairseq_cfg, vocab_size, cfg):
 def convert_params(fairseq_model,
                    gluon_cfg,
                    ctx,
-                   is_mlm=True
+                   is_mlm=True,
                    gluon_prefix='robert_'):
     print('converting params')
     fairseq_params = fairseq_model.state_dict()
     fairseq_prefix = 'model.decoder.'
     if is_mlm:
         gluon_model = RobertaForMLM(backbone_cfg=gluon_cfg, prefix=gluon_prefix)
+        # output all hidden states for testing
         gluon_model.backbone_model._output_all_encodings = True
+        gluon_model.backbone_model.encoder._output_all_encodings = True
     else:
-        gluon_model = RobertaForMLM.from_cfg(
+        gluon_model = RobertaModel.from_cfg(
             gluon_cfg,
             use_pooler=True,
             output_all_encodings=True,
@@ -223,7 +227,7 @@ def convert_params(fairseq_model,
                 fairseq_params[fs_name].cpu().numpy())
 
     for k, v in [
-        ('sentence_encoder.embed_tokens.weight', 'tokens_embed_weight'),
+        ('sentence_encoder.embed_tokens.weight', 'word_embed_weight'),
        ('sentence_encoder.emb_layer_norm.weight', 'embed_ln_gamma'),
        ('sentence_encoder.emb_layer_norm.bias', 'embed_ln_beta'),
     ]:
@@ -245,7 +249,7 @@ def convert_params(fairseq_model,
            ('lm_head.dense.bias', 'mlm_proj_bias'),
            ('lm_head.layer_norm.weight', 'mlm_ln_gamma'),
            ('lm_head.layer_norm.bias', 'mlm_ln_beta'),
-           ('lm_head.bias', 'tokens_embed_bias')
+           ('lm_head.bias', 'word_embed_bias')
        ]:
            fs_name = fairseq_prefix + k
            gl_name = gluon_prefix + v
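
The two hunks above only rename the Gluon-side targets of the fairseq-to-Gluon name mapping ('tokens_embed_*' becomes 'word_embed_*'). As a minimal sketch of the mapping pattern itself, with toy NumPy dicts and made-up shapes standing in for the real fairseq_model.state_dict() tensors and Gluon parameters:

# Hedged sketch of the fairseq -> Gluon name-mapping loop; the dicts and
# shapes here are illustrative only, not the actual convert_params code.
import numpy as np

fairseq_params = {
    'model.decoder.sentence_encoder.embed_tokens.weight': np.random.rand(8, 4),
    'model.decoder.lm_head.bias': np.random.rand(8),
}
gluon_params = {
    'roberta_word_embed_weight': np.empty((8, 4)),
    'roberta_word_embed_bias': np.empty(8),
}

fairseq_prefix = 'model.decoder.'
gluon_prefix = 'roberta_'
for k, v in [
    ('sentence_encoder.embed_tokens.weight', 'word_embed_weight'),
    ('lm_head.bias', 'word_embed_bias'),
]:
    fs_name = fairseq_prefix + k
    gl_name = gluon_prefix + v
    gluon_params[gl_name][:] = fairseq_params[fs_name]  # copy values in place

assert np.allclose(gluon_params['roberta_word_embed_bias'],
                   fairseq_params['model.decoder.lm_head.bias'])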
@@ -264,6 +268,7 @@ def test_model(fairseq_model, gluon_model, gpu):
     ctx = mx.gpu(gpu) if gpu is not None else mx.cpu()
     batch_size = 3
     seq_length = 32
+    num_mask = 5
     vocab_size = len(fairseq_model.task.dictionary)
     padding_id = fairseq_model.model.decoder.sentence_encoder.padding_idx
     input_ids = np.random.randint( # skip padding_id
@@ -276,28 +281,37 @@ def test_model(fairseq_model, gluon_model, gpu):
         seq_length,
         (batch_size,)
     )
+    mlm_positions = np.random.randint(
+        0,
+        seq_length // 2,
+        (batch_size, num_mask)
+    )
     for i in range(batch_size): # add padding, for fairseq padding mask
         input_ids[i,valid_length[i]:] = padding_id
 
     gl_input_ids = mx.np.array(input_ids, dtype=np.int32, ctx=ctx)
     gl_valid_length = mx.np.array(valid_length, dtype=np.int32, ctx=ctx)
+    gl_masked_positions = mx.np.array(mlm_positions, dtype=np.int32, ctx=ctx)
 
     fs_input_ids = torch.from_numpy(input_ids).cuda(gpu)
+    fs_masked_positions = torch.from_numpy(mlm_positions).cuda(gpu)
     if gpu is not None:
         fs_input_ids = fs_input_ids.cuda(gpu)
 
     fairseq_model.model.eval()
 
-    gluon_all_hiddens, gluon_pooled, gluon_mlm_scores = \
-        gluon_model(gl_input_ids, gl_valid_length)
-
-    fairseq_mlm_scores, fs_extra = \
-        fairseq_model.model.cuda(gpu)(fs_input_ids, return_all_hiddens=True)
+    gl_all_hiddens, gl_pooled, gl_mlm_scores = \
+        gluon_model(gl_input_ids, gl_valid_length, gl_masked_positions)
+    fs_mlm_scores, fs_extra = \
+        fairseq_model.model.cuda(gpu)(
+            fs_input_ids,
+            return_all_hiddens=True,
+            masked_tokens=fs_masked_positions)
     fs_all_hiddens = fs_extra['inner_states']
 
     num_layers = fairseq_model.args.encoder_layers
     for i in range(num_layers + 1):
-        gl_hidden = gluon_all_hiddens[i].asnumpy()
+        gl_hidden = gl_all_hiddens[i].asnumpy()
         fs_hidden = fs_all_hiddens[i]
         fs_hidden = fs_hidden.transpose(0, 1)
         fs_hidden = fs_hidden.detach().cpu().numpy()
@@ -309,13 +323,13 @@ def test_model(fairseq_model, gluon_model, gpu):
             1E-3
         )
 
-    gluon_mlm_scores = gluon_mlm_scores.asnumpy()
-    fairseq_mlm_scores = fairseq_mlm_scores.transpose(0, 1)
-    fairseq_mlm_scores = fairseq_mlm_scores.detach().cpu().numpy()
+    gl_mlm_scores = gl_mlm_scores.asnumpy()
+    fs_mlm_scores = fs_mlm_scores.transpose(0, 1)
+    fs_mlm_scores = fs_mlm_scores.detach().cpu().numpy()
     for j in range(batch_size):
         assert_allclose(
-            gluon_mlm_scores[j, :valid_length[j], :],
-            fairseq_mlm_scores[j, :valid_length[j], :],
+            gl_mlm_scores[j, :valid_length[j], :],
+            fs_mlm_scores[j, :valid_length[j], :],
             1E-3,
             1E-3
         )
@@ -359,19 +373,19 @@ def convert_fairseq_model(args):
                                    is_mlm=is_mlm,
                                    gluon_prefix='roberta_')
 
-    if is_mlm:
-        if args.test:
-            test_model(fairseq_roberta, gluon_roberta, args.gpu)
-
-        gluon_roberta.save_parameters(os.path.join(args.save_dir, 'model_mlm.params'), deduplicate=True)
-        logging.info('Convert the RoBERTa MLM model in {} to {}'.
-                     format(os.path.join(args.fairseq_model_path, 'model.pt'), \
-                            os.path.join(args.save_dir, 'model_mlm.params')))
-    else:
-        gluon_roberta.save_parameters(os.path.join(args.save_dir, 'model.params'), deduplicate=True)
-        logging.info('Convert the RoBERTa backbone model in {} to {}'.
-                     format(os.path.join(args.fairseq_model_path, 'model.pt'), \
-                            os.path.join(args.save_dir, 'model.params')))
+    if is_mlm:
+        if args.test:
+            test_model(fairseq_roberta, gluon_roberta, args.gpu)
+
+        gluon_roberta.save_parameters(os.path.join(args.save_dir, 'model_mlm.params'), deduplicate=True)
+        logging.info('Convert the RoBERTa MLM model in {} to {}'.
+                     format(os.path.join(args.fairseq_model_path, 'model.pt'), \
+                            os.path.join(args.save_dir, 'model_mlm.params')))
+    else:
+        gluon_roberta.save_parameters(os.path.join(args.save_dir, 'model.params'), deduplicate=True)
+        logging.info('Convert the RoBERTa backbone model in {} to {}'.
+                     format(os.path.join(args.fairseq_model_path, 'model.pt'), \
+                            os.path.join(args.save_dir, 'model.params')))
 
     logging.info('Conversion finished!')
     logging.info('Statistics:')
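
A note on the num_mask / mlm_positions additions in test_model above: the test now samples a few positions per sequence and passes them to both models, so the MLM scores can be compared at those positions. Below is a hedged, NumPy-only sketch of that gather-and-compare idea; the array shapes are assumptions for illustration, not the exact outputs of RobertaForMLM or the fairseq decoder.

# Illustrative only: compare two models' per-position vocabulary scores at a
# handful of sampled mask positions (shapes are assumptions, not the real ones).
import numpy as np
from numpy.testing import assert_allclose

batch_size, seq_length, num_mask, vocab_size = 3, 32, 5, 11
scores_a = np.random.rand(batch_size, seq_length, vocab_size)
scores_b = scores_a.copy()  # stand-ins for the two models' outputs

mlm_positions = np.random.randint(0, seq_length // 2, (batch_size, num_mask))

# gather (batch, num_mask, vocab) score slices at the sampled positions
gathered_a = np.take_along_axis(scores_a, mlm_positions[:, :, None], axis=1)
gathered_b = np.take_along_axis(scores_b, mlm_positions[:, :, None], axis=1)

assert_allclose(gathered_a, gathered_b, 1E-3, 1E-3)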

scripts/conversion_toolkits/convert_fairseq_xlmr.py

Lines changed: 8 additions & 11 deletions
@@ -1,20 +1,17 @@
 import os
-import argparse
-import logging
 import copy
+import logging
+import argparse
+
 import mxnet as mx
-from gluonnlp.third_party import sentencepiece_model_pb2
+
 from gluonnlp.utils.misc import logging_config
 from gluonnlp.models.xlmr import XLMRModel as gluon_XLMRModel
-from gluonnlp.data.tokenizers import SentencepieceTokenizer
+from gluonnlp.third_party import sentencepiece_model_pb2
 from fairseq.models.roberta import XLMRModel as fairseq_XLMRModel
-from convert_fairseq_roberta import (
-    convert_config,
-    convert_params,
-    test_model,
-    test_vocab,
-    rename
-)
+from convert_fairseq_roberta import rename, test_model, test_vocab, convert_config, convert_params
+from gluonnlp.data.tokenizers import SentencepieceTokenizer
+
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Convert the fairseq XLM-R Model to Gluon.')
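
The hunk above is only an import reorganization, but it shows that the XLM-R converter reuses the shared RoBERTa helpers (rename, test_model, test_vocab, convert_config, convert_params). A rough, hypothetical outline of how those pieces could be wired together for an XLM-R checkpoint is sketched below; the convert_params and test_model signatures come from the RoBERTa diff above, while the checkpoint path, get_cfg()/clone() usage, and argument choices are assumptions rather than the script's actual code.

# Hypothetical outline only -- not the actual convert_fairseq_xlmr.py flow.
import mxnet as mx
from fairseq.models.roberta import XLMRModel as fairseq_XLMRModel
from gluonnlp.models.xlmr import XLMRModel as gluon_XLMRModel
from convert_fairseq_roberta import convert_config, convert_params, test_model

# assumption: a downloaded fairseq XLM-R checkpoint lives in ./xlmr.base/model.pt
fairseq_xlmr = fairseq_XLMRModel.from_pretrained('./xlmr.base', checkpoint_file='model.pt')
ctx = mx.gpu(0)

# assumption: gluon_XLMRModel.get_cfg().clone() provides the base config expected
# by convert_config(fairseq_cfg, vocab_size, cfg)
gluon_cfg = convert_config(fairseq_xlmr.args,
                           len(fairseq_xlmr.task.dictionary),
                           gluon_XLMRModel.get_cfg().clone())

gluon_xlmr = convert_params(fairseq_xlmr, gluon_cfg, ctx,
                            is_mlm=True, gluon_prefix='roberta_')
test_model(fairseq_xlmr, gluon_xlmr, 0)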
