@@ -1,18 +1,20 @@
 import os
 import sys
-import argparse
+import json
 import shutil
 import logging
-import json
-from numpy.testing import assert_allclose
+import argparse
+
 import mxnet as mx
 import numpy as np
+from numpy.testing import assert_allclose
+
 import torch
+from gluonnlp.data.vocab import Vocab as gluon_Vocab
 from gluonnlp.utils.misc import sha1sum, logging_config
+from fairseq.models.roberta import RobertaModel as fairseq_RobertaModel
 from gluonnlp.models.roberta import RobertaModel, RobertaForMLM
 from gluonnlp.data.tokenizers import HuggingFaceByteBPETokenizer
-from gluonnlp.data.vocab import Vocab as gluon_Vocab
-from fairseq.models.roberta import RobertaModel as fairseq_RobertaModel
 
 mx.npx.set_np()
 
@@ -164,16 +166,18 @@ def convert_config(fairseq_cfg, vocab_size, cfg):
 def convert_params(fairseq_model,
                    gluon_cfg,
                    ctx,
-                   is_mlm=True
+                   is_mlm=True,
                    gluon_prefix='robert_'):
     print('converting params')
     fairseq_params = fairseq_model.state_dict()
     fairseq_prefix = 'model.decoder.'
     if is_mlm:
         gluon_model = RobertaForMLM(backbone_cfg=gluon_cfg, prefix=gluon_prefix)
+        # output all hidden states for testing
         gluon_model.backbone_model._output_all_encodings = True
+        gluon_model.backbone_model.encoder._output_all_encodings = True
     else:
-        gluon_model = RobertaForMLM.from_cfg(
+        gluon_model = RobertaModel.from_cfg(
             gluon_cfg,
             use_pooler=True,
             output_all_encodings=True,
@@ -223,7 +227,7 @@ def convert_params(fairseq_model,
             fairseq_params[fs_name].cpu().numpy())
 
     for k, v in [
-        ('sentence_encoder.embed_tokens.weight', 'tokens_embed_weight'),
+        ('sentence_encoder.embed_tokens.weight', 'word_embed_weight'),
         ('sentence_encoder.emb_layer_norm.weight', 'embed_ln_gamma'),
         ('sentence_encoder.emb_layer_norm.bias', 'embed_ln_beta'),
     ]:
@@ -245,7 +249,7 @@ def convert_params(fairseq_model,
             ('lm_head.dense.bias', 'mlm_proj_bias'),
             ('lm_head.layer_norm.weight', 'mlm_ln_gamma'),
             ('lm_head.layer_norm.bias', 'mlm_ln_beta'),
-            ('lm_head.bias', 'tokens_embed_bias')
+            ('lm_head.bias', 'word_embed_bias')
         ]:
             fs_name = fairseq_prefix + k
             gl_name = gluon_prefix + v
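
Each (fairseq_name, gluon_name) pair above feeds a plain tensor copy, as the `fairseq_params[fs_name].cpu().numpy())` tail at the top of the previous hunk suggests. A minimal sketch of one such copy, assuming `gluon_params = gluon_model.collect_params()` and the `ctx` passed into `convert_params`; the helper name `copy_param` is hypothetical, not part of the script:

import mxnet as mx
mx.npx.set_np()

def copy_param(fairseq_params, gluon_params, fs_name, gl_name, ctx):
    # move the torch tensor to host memory, then load it into the Gluon parameter
    data = fairseq_params[fs_name].cpu().numpy()
    gluon_params[gl_name].set_data(mx.np.array(data, ctx=ctx))
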
@@ -264,6 +268,7 @@ def test_model(fairseq_model, gluon_model, gpu):
     ctx = mx.gpu(gpu) if gpu is not None else mx.cpu()
     batch_size = 3
     seq_length = 32
+    num_mask = 5
     vocab_size = len(fairseq_model.task.dictionary)
     padding_id = fairseq_model.model.decoder.sentence_encoder.padding_idx
     input_ids = np.random.randint(  # skip padding_id
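
A note on the padding loop in the next hunk: fairseq derives its padding mask from the token ids themselves (comparing them against `padding_idx`), so every position past `valid_length` must literally hold `padding_id`. A small illustration with a made-up batch, not part of the script:

import numpy as np

input_ids = np.array([[5, 9, 2, 1, 1]])   # hypothetical ids, padding_id == 1
pad_mask = input_ids == 1                 # the mask fairseq reconstructs internally
assert pad_mask.tolist() == [[False, False, False, True, True]]
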
@@ -276,28 +281,37 @@ def test_model(fairseq_model, gluon_model, gpu):
         seq_length,
         (batch_size,)
     )
+    mlm_positions = np.random.randint(
+        0,
+        seq_length // 2,
+        (batch_size, num_mask)
+    )
     for i in range(batch_size):  # add padding, for fairseq padding mask
         input_ids[i, valid_length[i]:] = padding_id
 
     gl_input_ids = mx.np.array(input_ids, dtype=np.int32, ctx=ctx)
     gl_valid_length = mx.np.array(valid_length, dtype=np.int32, ctx=ctx)
+    gl_masked_positions = mx.np.array(mlm_positions, dtype=np.int32, ctx=ctx)
 
     fs_input_ids = torch.from_numpy(input_ids).cuda(gpu)
+    fs_masked_positions = torch.from_numpy(mlm_positions).cuda(gpu)
     if gpu is not None:
         fs_input_ids = fs_input_ids.cuda(gpu)
 
     fairseq_model.model.eval()
 
-    gluon_all_hiddens, gluon_pooled, gluon_mlm_scores = \
-        gluon_model(gl_input_ids, gl_valid_length)
-
-    fairseq_mlm_scores, fs_extra = \
-        fairseq_model.model.cuda(gpu)(fs_input_ids, return_all_hiddens=True)
+    gl_all_hiddens, gl_pooled, gl_mlm_scores = \
+        gluon_model(gl_input_ids, gl_valid_length, gl_masked_positions)
+    fs_mlm_scores, fs_extra = \
+        fairseq_model.model.cuda(gpu)(
+            fs_input_ids,
+            return_all_hiddens=True,
+            masked_tokens=fs_masked_positions)
     fs_all_hiddens = fs_extra['inner_states']
 
     num_layers = fairseq_model.args.encoder_layers
     for i in range(num_layers + 1):
-        gl_hidden = gluon_all_hiddens[i].asnumpy()
+        gl_hidden = gl_all_hiddens[i].asnumpy()
         fs_hidden = fs_all_hiddens[i]
         fs_hidden = fs_hidden.transpose(0, 1)
         fs_hidden = fs_hidden.detach().cpu().numpy()
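
Why the `transpose(0, 1)` just above: fairseq's `inner_states` come back time-major, `(seq_len, batch, hidden)`, while the GluonNLP encodings they are compared against are batch-major. A quick shape check, illustration only:

import torch

fs_hidden = torch.zeros(32, 3, 768)    # (seq_len, batch, hidden) as fairseq returns it
gl_layout = fs_hidden.transpose(0, 1)  # swap the first two dims
assert gl_layout.shape == (3, 32, 768) # (batch, seq_len, hidden)
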
@@ -309,13 +323,13 @@ def test_model(fairseq_model, gluon_model, gpu):
             1E-3
         )
 
-    gluon_mlm_scores = gluon_mlm_scores.asnumpy()
-    fairseq_mlm_scores = fairseq_mlm_scores.transpose(0, 1)
-    fairseq_mlm_scores = fairseq_mlm_scores.detach().cpu().numpy()
+    gl_mlm_scores = gl_mlm_scores.asnumpy()
+    fs_mlm_scores = fs_mlm_scores.transpose(0, 1)
+    fs_mlm_scores = fs_mlm_scores.detach().cpu().numpy()
     for j in range(batch_size):
         assert_allclose(
-            gluon_mlm_scores[j, :valid_length[j], :],
-            fairseq_mlm_scores[j, :valid_length[j], :],
+            gl_mlm_scores[j, :valid_length[j], :],
+            fs_mlm_scores[j, :valid_length[j], :],
             1E-3,
             1E-3
         )
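
For reference, the third and fourth positional arguments of `assert_allclose` are `rtol` and `atol`, so both comparisons above use a tolerance of 1E-3. A self-contained illustration of the same call pattern, with made-up shapes:

import numpy as np
from numpy.testing import assert_allclose

scores_a = np.random.rand(3, 5, 100).astype(np.float32)
scores_b = scores_a + 1e-5                       # small cross-framework discrepancy
assert_allclose(scores_a, scores_b, 1E-3, 1E-3)  # rtol=1E-3, atol=1E-3: passes
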
@@ -359,19 +373,19 @@ def convert_fairseq_model(args):
                                      is_mlm=is_mlm,
                                      gluon_prefix='roberta_')
 
-    if is_mlm:
-        if args.test:
-            test_model(fairseq_roberta, gluon_roberta, args.gpu)
-
-        gluon_roberta.save_parameters(os.path.join(args.save_dir, 'model_mlm.params'), deduplicate=True)
-        logging.info('Convert the RoBERTa MLM model in {} to {}'.
-                     format(os.path.join(args.fairseq_model_path, 'model.pt'), \
-                            os.path.join(args.save_dir, 'model_mlm.params')))
-    else:
-        gluon_roberta.save_parameters(os.path.join(args.save_dir, 'model.params'), deduplicate=True)
-        logging.info('Convert the RoBERTa backbone model in {} to {}'.
-                     format(os.path.join(args.fairseq_model_path, 'model.pt'), \
-                            os.path.join(args.save_dir, 'model.params')))
+    if is_mlm:
+        if args.test:
+            test_model(fairseq_roberta, gluon_roberta, args.gpu)
+
+        gluon_roberta.save_parameters(os.path.join(args.save_dir, 'model_mlm.params'), deduplicate=True)
+        logging.info('Convert the RoBERTa MLM model in {} to {}'.
+                     format(os.path.join(args.fairseq_model_path, 'model.pt'), \
+                            os.path.join(args.save_dir, 'model_mlm.params')))
+    else:
+        gluon_roberta.save_parameters(os.path.join(args.save_dir, 'model.params'), deduplicate=True)
+        logging.info('Convert the RoBERTa backbone model in {} to {}'.
+                     format(os.path.join(args.fairseq_model_path, 'model.pt'), \
+                            os.path.join(args.save_dir, 'model.params')))
 
     logging.info('Conversion finished!')
     logging.info('Statistics:')
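
For completeness, the whole conversion is driven by `convert_fairseq_model(args)`. A hedged sketch of invoking it programmatically; the attribute names mirror the `args.*` uses visible above, while the paths (and whatever flags the unshown argparse parser actually defines) are assumptions:

import argparse

# convert_fairseq_model is the function defined in this script
args = argparse.Namespace(
    fairseq_model_path='./fairseq_roberta_base',  # hypothetical dir containing model.pt
    save_dir='./gluon_roberta_base',              # hypothetical output dir
    test=True,                                    # run test_model() after conversion
    gpu=0,                                        # or None for CPU
)
convert_fairseq_model(args)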