99import numpy as np
1010import torch
1111from gluonnlp .utils .misc import sha1sum , logging_config
12- from gluonnlp .models .roberta import RobertaModel as gluon_RobertaModel
12+ from gluonnlp .models .roberta import RobertaModel , RobertaForMLM
1313from gluonnlp .data .tokenizers import HuggingFaceByteBPETokenizer
1414from gluonnlp .data .vocab import Vocab as gluon_Vocab
1515from fairseq .models .roberta import RobertaModel as fairseq_RobertaModel
@@ -163,19 +163,22 @@ def convert_config(fairseq_cfg, vocab_size, cfg):
163163
164164def convert_params (fairseq_model ,
165165 gluon_cfg ,
166- gluon_model_cls ,
167166 ctx ,
167+ is_mlm = True
168168 gluon_prefix = 'robert_' ):
169169 print ('converting params' )
170170 fairseq_params = fairseq_model .state_dict ()
171171 fairseq_prefix = 'model.decoder.'
172- gluon_model = gluon_model_cls .from_cfg (
173- gluon_cfg ,
174- use_mlm = True ,
175- use_pooler = False ,
176- output_all_encodings = True ,
177- prefix = gluon_prefix
178- )
172+ if is_mlm :
173+ gluon_model = RobertaForMLM (backbone_cfg = gluon_cfg , prefix = gluon_prefix )
174+ gluon_model .backbone_model ._output_all_encodings = True
175+ else :
176+ gluon_model = RobertaForMLM .from_cfg (
177+ gluon_cfg ,
178+ use_pooler = True ,
179+ output_all_encodings = True ,
180+ prefix = gluon_prefix
181+ )
179182 gluon_model .initialize (ctx = ctx )
180183 gluon_model .hybridize ()
181184 gluon_params = gluon_model .collect_params ()
@@ -223,11 +226,6 @@ def convert_params(fairseq_model,
223226 ('sentence_encoder.embed_tokens.weight' , 'tokens_embed_weight' ),
224227 ('sentence_encoder.emb_layer_norm.weight' , 'embed_ln_gamma' ),
225228 ('sentence_encoder.emb_layer_norm.bias' , 'embed_ln_beta' ),
226- ('lm_head.dense.weight' , 'lm_dense1_weight' ),
227- ('lm_head.dense.bias' , 'lm_dense1_bias' ),
228- ('lm_head.layer_norm.weight' , 'lm_ln_gamma' ),
229- ('lm_head.layer_norm.bias' , 'lm_ln_beta' ),
230- ('lm_head.bias' , 'tokens_embed_bias' )
231229 ]:
232230 fs_name = fairseq_prefix + k
233231 gl_name = gluon_prefix + v
@@ -241,14 +239,26 @@ def convert_params(fairseq_model,
241239 gluon_params [gl_pos_embed_name ].set_data (
242240 fairseq_params [fs_pos_embed_name ].cpu ().numpy ()[padding_idx + 1 :,:])
243241
244- # assert untie=False
245- assert np .array_equal (
246- fairseq_params [fairseq_prefix + 'sentence_encoder.embed_tokens.weight' ].cpu ().numpy (),
247- fairseq_params [fairseq_prefix + 'lm_head.weight' ].cpu ().numpy ()
248- )
249-
242+ if is_mlm :
243+ for k , v in [
244+ ('lm_head.dense.weight' , 'mlm_proj_weight' ),
245+ ('lm_head.dense.bias' , 'mlm_proj_bias' ),
246+ ('lm_head.layer_norm.weight' , 'mlm_ln_gamma' ),
247+ ('lm_head.layer_norm.bias' , 'mlm_ln_beta' ),
248+ ('lm_head.bias' , 'tokens_embed_bias' )
249+ ]:
250+ fs_name = fairseq_prefix + k
251+ gl_name = gluon_prefix + v
252+ gluon_params [gl_name ].set_data (
253+ fairseq_params [fs_name ].cpu ().numpy ())
254+ # assert untie=False
255+ assert np .array_equal (
256+ fairseq_params [fairseq_prefix + 'sentence_encoder.embed_tokens.weight' ].cpu ().numpy (),
257+ fairseq_params [fairseq_prefix + 'lm_head.weight' ].cpu ().numpy ()
258+ )
250259 return gluon_model
251260
261+
252262def test_model (fairseq_model , gluon_model , gpu ):
253263 print ('testing model' )
254264 ctx = mx .gpu (gpu ) if gpu is not None else mx .cpu ()
@@ -278,16 +288,16 @@ def test_model(fairseq_model, gluon_model, gpu):
278288
279289 fairseq_model .model .eval ()
280290
281- gl_all_hiddens , gl_x = \
291+ gluon_all_hiddens , gluon_pooled , gluon_mlm_scores = \
282292 gluon_model (gl_input_ids , gl_valid_length )
283293
284- fs_x , fs_extra = \
294+ fairseq_mlm_scores , fs_extra = \
285295 fairseq_model .model .cuda (gpu )(fs_input_ids , return_all_hiddens = True )
286296 fs_all_hiddens = fs_extra ['inner_states' ]
287297
288298 num_layers = fairseq_model .args .encoder_layers
289299 for i in range (num_layers + 1 ):
290- gl_hidden = gl_all_hiddens [i ].asnumpy ()
300+ gl_hidden = gluon_all_hiddens [i ].asnumpy ()
291301 fs_hidden = fs_all_hiddens [i ]
292302 fs_hidden = fs_hidden .transpose (0 , 1 )
293303 fs_hidden = fs_hidden .detach ().cpu ().numpy ()
@@ -299,13 +309,13 @@ def test_model(fairseq_model, gluon_model, gpu):
299309 1E-3
300310 )
301311
302- gl_x = gl_x .asnumpy ()
303- fs_x = fs_x .transpose (0 , 1 )
304- fs_x = fs_x .detach ().cpu ().numpy ()
312+ gluon_mlm_scores = gluon_mlm_scores .asnumpy ()
313+ fairseq_mlm_scores = fairseq_mlm_scores .transpose (0 , 1 )
314+ fairseq_mlm_scores = fairseq_mlm_scores .detach ().cpu ().numpy ()
305315 for j in range (batch_size ):
306316 assert_allclose (
307- gl_x [j , :valid_length [j ], :],
308- fs_x [j , :valid_length [j ], :],
317+ gluon_mlm_scores [j , :valid_length [j ], :],
318+ fairseq_mlm_scores [j , :valid_length [j ], :],
309319 1E-3 ,
310320 1E-3
311321 )
@@ -337,24 +347,32 @@ def convert_fairseq_model(args):
337347 vocab_size = convert_vocab (args , fairseq_roberta )
338348
339349 gluon_cfg = convert_config (fairseq_roberta .args , vocab_size ,
340- gluon_RobertaModel .get_cfg ().clone ())
350+ RobertaModel .get_cfg ().clone ())
341351 with open (os .path .join (args .save_dir , 'model.yml' ), 'w' ) as of :
342352 of .write (gluon_cfg .dump ())
343353
344354 ctx = mx .gpu (args .gpu ) if args .gpu is not None else mx .cpu ()
345- gluon_roberta = convert_params (fairseq_roberta ,
346- gluon_cfg ,
347- gluon_RobertaModel ,
348- ctx ,
349- gluon_prefix = 'roberta_' )
350-
351- if args .test :
352- test_model (fairseq_roberta , gluon_roberta , args .gpu )
355+ for is_mlm in [False , True ]:
356+ gluon_roberta = convert_params (fairseq_roberta ,
357+ gluon_cfg ,
358+ ctx ,
359+ is_mlm = is_mlm ,
360+ gluon_prefix = 'roberta_' )
361+
362+ if is_mlm :
363+ if args .test :
364+ test_model (fairseq_roberta , gluon_roberta , args .gpu )
365+
366+ gluon_roberta .save_parameters (os .path .join (args .save_dir , 'model_mlm.params' ), deduplicate = True )
367+ logging .info ('Convert the RoBERTa MLM model in {} to {}' .
368+ format (os .path .join (args .fairseq_model_path , 'model.pt' ), \
369+ os .path .join (args .save_dir , 'model_mlm.params' )))
370+ else :
371+ gluon_roberta .save_parameters (os .path .join (args .save_dir , 'model.params' ), deduplicate = True )
372+ logging .info ('Convert the RoBERTa backbone model in {} to {}' .
373+ format (os .path .join (args .fairseq_model_path , 'model.pt' ), \
374+ os .path .join (args .save_dir , 'model.params' )))
353375
354- gluon_roberta .save_parameters (os .path .join (args .save_dir , 'model.params' ), deduplicate = True )
355- logging .info ('Convert the RoBERTa model in {} to {}' .
356- format (os .path .join (args .fairseq_model_path , 'model.pt' ), \
357- os .path .join (args .save_dir , 'model.params' )))
358376 logging .info ('Conversion finished!' )
359377 logging .info ('Statistics:' )
360378 rename (args .save_dir )
0 commit comments