Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 2b7f7a3

Browse files
committed
fix roberta
1 parent 86702fe commit 2b7f7a3

2 files changed

Lines changed: 84 additions & 44 deletions

File tree

scripts/conversion_toolkits/convert_fairseq_roberta.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def convert_vocab(args, fairseq_model):
3737
fairseq_vocab = fairseq_model.task.dictionary
3838
# bos_word attr missing in fairseq_vocab
3939
fairseq_vocab.bos_word = fairseq_vocab[fairseq_vocab.bos_index]
40-
40+
4141
assert os.path.exists(fairseq_dict_path), \
4242
'{} not found'.format(fairseq_dict_path)
4343
from mxnet.gluon.utils import download
@@ -63,7 +63,7 @@ def convert_vocab(args, fairseq_model):
6363
inter_vocab = list(inter_vocab.items())
6464
inter_vocab = sorted(inter_vocab, key=lambda x : x[1])
6565
tokens = [e[0] for e in inter_vocab]
66-
66+
6767
tail = [fairseq_vocab[-4],
6868
fairseq_vocab[-3],
6969
fairseq_vocab[-2],
@@ -84,15 +84,15 @@ def convert_vocab(args, fairseq_model):
8484
gluon_vocab.save(vocab_save_path)
8585
os.remove(temp_vocab_file)
8686
os.remove(temp_merges_file)
87-
87+
8888
gluon_tokenizer = HuggingFaceByteBPETokenizer(
8989
merges_save_path,
9090
vocab_save_path
9191
)
92-
92+
9393
if args.test:
9494
test_vocab(fairseq_model, gluon_tokenizer)
95-
95+
9696
vocab_size = len(fairseq_vocab)
9797
print('| converted dictionary: {} types'.format(vocab_size))
9898
return vocab_size
@@ -103,14 +103,14 @@ def test_vocab(fairseq_model, gluon_tokenizer, check_all_tokens=False):
103103
gluon_vocab = gluon_tokenizer.vocab
104104
assert len(fairseq_vocab) == \
105105
len(gluon_vocab)
106-
106+
107107
# assert all_tokens
108108
# roberta with gpt2 bytebpe bpe does not provide all tokens directly
109109
if check_all_tokens:
110110
for i in range(len(fairseq_vocab)):
111111
assert fairseq_vocab[i] == gluon_vocab.all_tokens[i], \
112112
'{}, {}, {}'.format(i, fairseq_vocab[i], gluon_vocab.all_tokens[i])
113-
113+
114114
# assert special tokens
115115
for special_tokens in ['unk', 'pad', 'eos', 'bos']:
116116
assert getattr(fairseq_vocab, special_tokens + '_index') == \
@@ -121,7 +121,7 @@ def test_vocab(fairseq_model, gluon_tokenizer, check_all_tokens=False):
121121
assert fairseq_vocab[-1] == \
122122
gluon_vocab.all_tokens[-1] == \
123123
'<mask>'
124-
124+
125125
sentence = "Hello, y'all! How are you Ⅷ 😁 😁 😁 ?" + \
126126
'GluonNLP is great!!!!!!' + \
127127
"GluonNLP-Amazon-Haibin-Leonard-Sheng-Shuai-Xingjian...../:!@# 'abc'"
@@ -131,7 +131,7 @@ def test_vocab(fairseq_model, gluon_tokenizer, check_all_tokens=False):
131131
# Notice: we may append bos and eos
132132
# manuually after tokenizing sentences
133133
assert fs_tokens.numpy().tolist()[1:-1] == gl_tokens
134-
134+
135135
# assert decode
136136
fs_sentence = fairseq_model.decode(fs_tokens)
137137
gl_sentence = gluon_tokenizer.decode(gl_tokens)
@@ -170,7 +170,8 @@ def convert_params(fairseq_model,
170170
fairseq_prefix = 'model.decoder.'
171171
gluon_model = gluon_model_cls.from_cfg(
172172
gluon_cfg,
173-
return_all_hiddens=True,
173+
use_pooler=False,
174+
output_all_encodings=True,
174175
prefix=gluon_prefix
175176
)
176177
gluon_model.initialize(ctx=ctx)
@@ -196,7 +197,7 @@ def convert_params(fairseq_model,
196197
np.concatenate([fs_q_weight, fs_k_weight, fs_v_weight], axis=0))
197198
gl_qkv_bias.set_data(
198199
np.concatenate([fs_q_bias, fs_k_bias, fs_v_bias], axis=0))
199-
200+
200201
for k, v in [
201202
('self_attn.out_proj.weight', 'proj_weight'),
202203
('self_attn.out_proj.bias', 'proj_bias'),
@@ -230,20 +231,20 @@ def convert_params(fairseq_model,
230231
gl_name = gluon_prefix + v
231232
gluon_params[gl_name].set_data(
232233
fairseq_params[fs_name].cpu().numpy())
233-
234+
234235
# position embed weight
235236
padding_idx = fairseq_model.task.dictionary.pad_index
236237
fs_pos_embed_name = fairseq_prefix + 'sentence_encoder.embed_positions.weight'
237238
gl_pos_embed_name = gluon_prefix + 'pos_embed_embed_weight'
238239
gluon_params[gl_pos_embed_name].set_data(
239240
fairseq_params[fs_pos_embed_name].cpu().numpy()[padding_idx + 1:,:])
240-
241+
241242
# assert untie=False
242243
assert np.array_equal(
243244
fairseq_params[fairseq_prefix + 'sentence_encoder.embed_tokens.weight'].cpu().numpy(),
244245
fairseq_params[fairseq_prefix + 'lm_head.weight'].cpu().numpy()
245246
)
246-
247+
247248
return gluon_model
248249

249250
def test_model(fairseq_model, gluon_model, gpu):
@@ -272,16 +273,16 @@ def test_model(fairseq_model, gluon_model, gpu):
272273
fs_input_ids = torch.from_numpy(input_ids).cuda(gpu)
273274
if gpu is not None:
274275
fs_input_ids = fs_input_ids.cuda(gpu)
275-
276+
276277
fairseq_model.model.eval()
277-
278-
gl_x, gl_all_hiddens = \
278+
279+
gl_all_hiddens, gl_x = \
279280
gluon_model(gl_input_ids, gl_valid_length)
280281

281282
fs_x, fs_extra = \
282283
fairseq_model.model.cuda(gpu)(fs_input_ids, return_all_hiddens=True)
283284
fs_all_hiddens = fs_extra['inner_states']
284-
285+
285286
num_layers = fairseq_model.args.encoder_layers
286287
for i in range(num_layers + 1):
287288
gl_hidden = gl_all_hiddens[i].asnumpy()
@@ -317,7 +318,7 @@ def rename(save_dir):
317318
new_name = '{file_prefix}-{short_hash}.{file_sufix}'.format(
318319
file_prefix=file_prefix,
319320
short_hash=long_hash[:8],
320-
file_sufix=file_sufix)
321+
file_sufix=file_sufix)
321322
new_path = os.path.join(save_dir, new_name)
322323
shutil.move(old_path, new_path)
323324
file_size = os.path.getsize(new_path)

src/gluonnlp/models/roberta.py

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __init__(self,
130130
use_mlm=True,
131131
untie_weight=False,
132132
encoder_normalize_before=True,
133-
return_all_hiddens=False,
133+
output_all_encodings=False,
134134
prefix=None,
135135
params=None):
136136
"""
@@ -160,7 +160,7 @@ def __init__(self,
160160
untie_weight
161161
Whether to untie weights between embeddings and classifiers
162162
encoder_normalize_before
163-
return_all_hiddens
163+
output_all_encodings
164164
prefix
165165
params
166166
"""
@@ -182,7 +182,7 @@ def __init__(self,
182182
self.use_mlm = use_mlm
183183
self.untie_weight = untie_weight
184184
self.encoder_normalize_before = encoder_normalize_before
185-
self.return_all_hiddens = return_all_hiddens
185+
self.output_all_encodings = output_all_encodings
186186
with self.name_scope():
187187
self.tokens_embed = nn.Embedding(
188188
input_dim=self.vocab_size,
@@ -220,7 +220,7 @@ def __init__(self,
220220
bias_initializer=bias_initializer,
221221
activation=self.activation,
222222
dtype=self.dtype,
223-
return_all_hiddens=self.return_all_hiddens
223+
output_all_encodings=self.output_all_encodings
224224
)
225225
self.encoder.hybridize()
226226

@@ -237,25 +237,63 @@ def __init__(self,
237237
bias_initializer=bias_initializer
238238
)
239239
self.lm_head.hybridize()
240-
# TODO support use_pooler
241240

242241
def hybrid_forward(self, F, tokens, valid_length):
243-
x = self.tokens_embed(tokens)
242+
outputs = []
243+
embedding = self.get_initial_embedding(F, tokens)
244+
245+
inner_states = self.encoder(x, valid_length)
246+
if self.output_all_encodings:
247+
contextual_embeddings = inner_states[-1]
248+
else:
249+
contextual_embeddings = inner_states
250+
outputs.append(contextual_embeddings)
251+
252+
if self.use_pooler:
253+
pooled_out = self.apply_pooling(contextual_embeddings)
254+
outputs.append(pooled_out)
255+
256+
if self.use_mlm:
257+
mlm_output = self.lm_head(contextual_embeddings)
258+
outputs.append(mlm_output)
259+
return tuple(outputs) if len(outputs) > 1 else outputs[0]
260+
261+
def get_initial_embedding(self, F, inputs):
262+
"""Get the initial token embeddings that considers the token type and positional embeddings
263+
264+
Parameters
265+
----------
266+
F
267+
inputs
268+
Shape (batch_size, seq_length)
269+
270+
Returns
271+
-------
272+
embedding
273+
The initial embedding that will be fed into the encoder
274+
"""
275+
embedding = self.tokens_embed(inputs)
244276
if self.pos_embed_type:
245-
positional_embedding = self.pos_embed(F.npx.arange_like(x, axis=1))
277+
positional_embedding = self.pos_embed(F.npx.arange_like(inputs, axis=1))
246278
positional_embedding = F.np.expand_dims(positional_embedding, axis=0)
247-
x = x + positional_embedding
279+
embedding = embedding + positional_embedding
248280
if self.embed_ln:
249-
x = self.embed_ln(x)
250-
x = self.embed_dropout(x)
251-
inner_states = self.encoder(x, valid_length)
252-
x = inner_states[-1]
253-
if self.use_mlm:
254-
x = self.lm_head(x)
255-
if self.return_all_hiddens:
256-
return x, inner_states
257-
else:
258-
return x
281+
embedding = self.embed_ln(embedding)
282+
embedding = self.embed_dropout(embedding)
283+
284+
def apply_pooling(self, sequence):
285+
"""Generate the representation given the inputs.
286+
287+
This is used for pre-training or fine-tuning a mobile bert model.
288+
Get the first token of the whole sequence which is [CLS]
289+
290+
sequence:
291+
Shape (batch_size, sequence_length, units)
292+
return:
293+
Shape (batch_size, units)
294+
"""
295+
outputs = sequence[:, 0, :]
296+
return outputs
259297

260298
@staticmethod
261299
def get_cfg(key=None):
@@ -271,7 +309,7 @@ def from_cfg(cls,
271309
use_mlm=True,
272310
untie_weight=False,
273311
encoder_normalize_before=True,
274-
return_all_hiddens=False,
312+
output_all_encodings=False,
275313
prefix=None,
276314
params=None):
277315
cfg = RobertaModel.get_cfg().clone_merge(cfg)
@@ -298,7 +336,7 @@ def from_cfg(cls,
298336
use_mlm=use_mlm,
299337
untie_weight=untie_weight,
300338
encoder_normalize_before=encoder_normalize_before,
301-
return_all_hiddens=return_all_hiddens,
339+
output_all_encodings=output_all_encodings,
302340
prefix=prefix,
303341
params=params)
304342

@@ -316,7 +354,7 @@ def __init__(self,
316354
bias_initializer='zeros',
317355
activation='gelu',
318356
dtype='float32',
319-
return_all_hiddens=False,
357+
output_all_encodings=False,
320358
prefix='encoder_',
321359
params=None):
322360
super(RobertaEncoder, self).__init__(prefix=prefix, params=params)
@@ -329,7 +367,7 @@ def __init__(self,
329367
self.layer_norm_eps = layer_norm_eps
330368
self.activation = activation
331369
self.dtype = dtype
332-
self.return_all_hiddens = return_all_hiddens
370+
self.output_all_encodings = output_all_encodings
333371
with self.name_scope():
334372
self.all_layers = nn.HybridSequential(prefix='layers_')
335373
with self.all_layers.name_scope():
@@ -358,8 +396,8 @@ def hybrid_forward(self, F, x, valid_length):
358396
layer = self.all_layers[layer_idx]
359397
x, _ = layer(x, atten_mask)
360398
inner_states.append(x)
361-
if not self.return_all_hiddens:
362-
inner_states = [x]
399+
if not self.output_all_encodings:
400+
inner_states = x
363401
return inner_states
364402

365403
@use_np
@@ -419,7 +457,8 @@ def list_pretrained_roberta():
419457

420458
def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
421459
root: str = get_model_zoo_home_dir(),
422-
load_backbone: bool = True) \
460+
load_backbone: bool = True,
461+
load_mlm: bool = False) \
423462
-> Tuple[CN, HuggingFaceByteBPETokenizer, str]:
424463
"""Get the pretrained RoBERTa weights
425464

0 commit comments

Comments (0)