
Commit 1b5fa7b

fix comment1
1 parent 6533601 commit 1b5fa7b

3 files changed: 29 additions & 49 deletions

File tree

scripts/conversion_toolkits/convert_fairseq_bart.py
src/gluonnlp/models/bart.py
src/gluonnlp/models/transformer.py

scripts/conversion_toolkits/convert_fairseq_bart.py

Lines changed: 25 additions & 30 deletions
@@ -40,7 +40,7 @@ def convert_config(fairseq_cfg, vocab_size, cfg):
     cfg.MODEL.shared_embed = fairseq_cfg.share_all_embeddings
     cfg.MODEL.scale_embed = not fairseq_cfg.no_scale_embedding
     cfg.MODEL.tie_weights = fairseq_cfg.share_decoder_input_output_embed
-    cfg.MODEL.layernorm_embedding = fairseq_cfg.layernorm_embedding
+    cfg.MODEL.data_norm = fairseq_cfg.layernorm_embedding
     cfg.MODEL.pooler_activation = fairseq_cfg.pooler_activation_fn
     cfg.MODEL.layer_norm_eps = 1E-5
     cfg.MODEL.dropout = fairseq_cfg.dropout
@@ -111,26 +111,6 @@ def convert_attention(num_layers,
             gl_qkv_bias.set_data(
                 np.concatenate([fs_q_bias, fs_k_bias, fs_v_bias], axis=0))

-    def convert_embeddings(fairseq_prefix, gluon_prefix):
-        for k, v in [
-            ('.embed_tokens.weight', '_embed_layer.weight'),
-            ('.layernorm_embedding.weight', '_embed_ln.gamma'),
-            ('.layernorm_embedding.bias', '_embed_ln.beta'),
-        ]:
-            fs_name = fairseq_prefix + k
-            gl_name = gluon_prefix + v
-            all_keys.remove(gl_name)
-            gluon_params[gl_name].set_data(
-                fairseq_params[fs_name].cpu().numpy())
-
-        # position embed weight
-        padding_idx = fairseq_model.task.dictionary.pad_index
-        fs_pos_embed_name = fairseq_prefix + '.embed_positions.weight'
-        gl_pos_embed_name = gluon_prefix + '_pos_embed_layer._embed.weight'
-        all_keys.remove(gl_pos_embed_name)
-        gluon_params[gl_pos_embed_name].set_data(
-            fairseq_params[fs_pos_embed_name].cpu().numpy()[padding_idx + 1:, :])
-
     def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
         # convert feed forward layer in encoder
         for layer_id in range(num_layers):
def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
135115
# convert feed forward layer in encoder
136116
for layer_id in range(num_layers):
@@ -150,11 +130,33 @@ def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
             gluon_params[gl_name].set_data(
                 fairseq_params[fs_name].cpu().numpy())

+    print('converting embedding params')
+    padding_idx = fairseq_model.task.dictionary.pad_index
+    for fs_name, gl_name in [
+        ('model.encoder.embed_tokens.weight', 'src_embed_layer.weight'),
+        ('model.encoder.embed_positions.weight', 'src_pos_embed_layer._embed.weight'),
+        ('model.encoder.layernorm_embedding.weight', 'encoder.ln_data.gamma'),
+        ('model.encoder.layernorm_embedding.bias', 'encoder.ln_data.beta'),
+        ('model.decoder.embed_tokens.weight', 'tgt_embed_layer.weight'),
+        ('model.decoder.embed_positions.weight', 'tgt_pos_embed_layer._embed.weight'),
+        ('model.decoder.layernorm_embedding.weight', 'decoder.ln_data.gamma'),
+        ('model.decoder.layernorm_embedding.bias', 'decoder.ln_data.beta'),
+        # final projection in decoder
+        ('model.decoder.output_projection.weight', 'tgt_final_layer.weight'),
+    ]:
+        all_keys.remove(gl_name)
+        if 'embed_positions' in fs_name:
+            # position embed weight
+            gluon_params[gl_name].set_data(
+                fairseq_params[fs_name].cpu().numpy()[padding_idx + 1:, :])
+        else:
+            gluon_params[gl_name].set_data(
+                fairseq_params[fs_name].cpu().numpy())
+
     print('converting encoder params')
     encoder_num_layers = gluon_cfg.MODEL.ENCODER.num_layers
     convert_attention(encoder_num_layers, 'model.encoder', 'encoder')
     convert_ffn(encoder_num_layers, 'model.encoder', 'encoder')
-    convert_embeddings('model.encoder', 'src')
     for layer_id in range(encoder_num_layers):
         for k, v in [
             ('self_attn.out_proj.weight', 'attention_proj.weight'),
@@ -170,6 +172,7 @@ def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
             gluon_params[gl_name].set_data(
                 fairseq_params[fs_name].cpu().numpy())

+    print('converting decoder params')
     decoder_num_layers = gluon_cfg.MODEL.DECODER.num_layers
     convert_attention(decoder_num_layers, 'model.decoder', 'decoder',
                       gluon_attn_prefix='attn_in_qkv')
@@ -201,14 +204,6 @@ def convert_ffn(num_layers, fairseq_prefix, gluon_prefix):
             gluon_params[gl_name].set_data(
                 fairseq_params[fs_name].cpu().numpy())

-    convert_embeddings('model.decoder', 'tgt')
-    # final projection in decoder
-    for fs_name, gl_name in [
-        ('model.decoder.output_projection.weight', 'tgt_final_layer.weight'),
-    ]:
-        all_keys.remove(gl_name)
-        gluon_params[gl_name].set_data(
-            fairseq_params[fs_name].cpu().numpy())
     assert len(all_keys) == 0, 'parameters missing from tensorflow checkpoint'

     # check parameters sharing if share_decoder_input_output_embed is true
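A minimal standalone sketch of the name-mapping loop added above, using plain dicts and NumPy arrays as stand-ins for the real fairseq/Gluon parameter objects (the stand-ins and example shapes are assumptions, not part of the commit). The one subtlety it illustrates is the positional-embedding copy: fairseq reserves the first padding_idx + 1 rows of embed_positions.weight, so only the remaining rows are written into the Gluon position-embedding table.

import numpy as np

padding_idx = 1  # assumed pad index; the real script reads it from fairseq_model.task.dictionary
fairseq_params = {  # stand-in for the fairseq state dict, with BART-base-like shapes
    'model.encoder.embed_tokens.weight': np.zeros((50265, 768)),
    'model.encoder.embed_positions.weight': np.zeros((1026, 768)),
}
gluon_params = {}  # stand-in for the Gluon parameter dict

for fs_name, gl_name in [
    ('model.encoder.embed_tokens.weight', 'src_embed_layer.weight'),
    ('model.encoder.embed_positions.weight', 'src_pos_embed_layer._embed.weight'),
]:
    weight = fairseq_params[fs_name]
    if 'embed_positions' in fs_name:
        # drop the rows fairseq reserves for the padding offset
        weight = weight[padding_idx + 1:, :]
    gluon_params[gl_name] = weight

print(gluon_params['src_pos_embed_layer._embed.weight'].shape)  # (1024, 768)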

src/gluonnlp/models/bart.py

Lines changed: 2 additions & 2 deletions
@@ -67,7 +67,7 @@ def bart_base():
     cfg.MODEL.dropout = 0.0
     cfg.MODEL.layer_norm_eps = 1E-5
     cfg.MODEL.pooler_activation = 'tanh'
-    cfg.MODEL.layernorm_embedding = True
+    cfg.MODEL.data_norm = True
     cfg.MODEL.layout = 'NT'
     cfg.MODEL.dtype = 'float32'

@@ -285,13 +285,13 @@ def from_cfg(cls, cfg,
                    pos_embed_type=cfg.MODEL.pos_embed_type,
                    shared_embed=cfg.MODEL.shared_embed,
                    tie_weights=cfg.MODEL.tie_weights,
+                   data_norm=cfg.MODEL.data_norm,
                    use_pooler=use_pooler,
                    attention_dropout=cfg.MODEL.attention_dropout,
                    activation_dropout=cfg.MODEL.activation_dropout,
                    dropout=cfg.MODEL.dropout,
                    pooler_activation=cfg.MODEL.pooler_activation,
                    layer_norm_eps=cfg.MODEL.layer_norm_eps,
-                   layernorm_embedding=cfg.MODEL.layernorm_embedding,
                    enc_num_layers=cfg.MODEL.ENCODER.num_layers,
                    enc_units=cfg.MODEL.ENCODER.units,
                    enc_num_heads=cfg.MODEL.ENCODER.num_heads,
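A hedged usage sketch of the renamed flag, assuming bart_base and a model class named BartModel are importable from gluonnlp.models.bart as in this file. It only shows that the embedding-LayerNorm switch now travels through the config as cfg.MODEL.data_norm, which from_cfg forwards to the model, instead of the old layernorm_embedding key.

from gluonnlp.models.bart import BartModel, bart_base

cfg = bart_base()
print(cfg.MODEL.data_norm)       # True for the base config after this change
model = BartModel.from_cfg(cfg)  # from_cfg now passes data_norm=cfg.MODEL.data_norm
model.initialize()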

src/gluonnlp/models/transformer.py

Lines changed: 2 additions & 17 deletions
@@ -440,8 +440,7 @@ def __init__(self, units: int = 512,
                                                   num_heads=num_heads,
                                                   attention_dropout=self._attention_dropout,
                                                   dtype=dtype,
-                                                  layout=attention_layout,
-                                                  layout='NTK')
+                                                  layout=attention_layout)
         self.proj_in = nn.Dense(units=units, in_units=units, flatten=False, use_bias=True,
                                 weight_initializer=weight_initializer,
                                 bias_initializer=bias_initializer,
@@ -914,7 +913,6 @@ def __init__(self, src_vocab_size: int,
                  max_tgt_length: Optional[int] = None,
                  scale_embed: bool = True,
                  pos_embed_type="sinusoidal",
-                 layernorm_embedding: bool = False,
                  shared_embed: bool = True,
                  tie_weights: bool = True,
                  activation_dropout: float = 0.0,
@@ -959,8 +957,6 @@ def __init__(self, src_vocab_size: int,
             Whether to multiply the src and dst embeddings by sqrt(units)
         pos_embed_type
             Type of the positional embedding
-        layernorm_embedding
-            Wether to layer normalize the embedding
         shared_embed
             Whether to share the embedding of the src and tgt language
         tie_weights
@@ -1027,7 +1023,6 @@ def __init__(self, src_vocab_size: int,
         self._tgt_vocab_size = tgt_vocab_size
         self.tie_weights = tie_weights
         self.pos_embed_type = pos_embed_type
-        self.layernorm_embedding = layernorm_embedding
         self.scaled_embed = scale_embed
         self.enc_units = enc_units
         self.dec_units = dec_units
@@ -1063,11 +1058,6 @@ def __init__(self, src_vocab_size: int,
                                                            max_length=max_tgt_length,
                                                            dtype=self._dtype,
                                                            method=pos_embed_type)
-        if layernorm_embedding:
-            self.src_embed_ln = nn.LayerNorm(epsilon=layer_norm_eps,
-                                             in_channels=enc_units)
-            self.tgt_embed_ln = nn.LayerNorm(epsilon=layer_norm_eps,
-                                             in_channels=dec_units)
         self.encoder = TransformerEncoder(num_layers=enc_num_layers,
                                           recurrent=enc_recurrent,
                                           units=enc_units,
@@ -1164,8 +1154,6 @@ def encode(self, F, src_data, src_valid_length):
         else:
             src_data = src_data + F.np.expand_dims(self.src_pos_embed_layer(
                 F.npx.arange_like(src_data, axis=0)), axis=1)
-        if self.layernorm_embedding:
-            src_data = self.src_embed_ln(src_data)

         enc_out = self.encoder(src_data, src_valid_length)
         return enc_out
@@ -1209,8 +1197,7 @@ def decode_seq(self, F, tgt_data, tgt_valid_length, mem_data, mem_valid_length):
         else:
             tgt_data = tgt_data + F.np.expand_dims(self.tgt_pos_embed_layer(
                 F.npx.arange_like(tgt_data, axis=0)), axis=1)
-        if self.layernorm_embedding:
-            tgt_data = self.tgt_embed_ln(tgt_data)
+
         dec_out = self.decoder(tgt_data, tgt_valid_length, mem_data, mem_valid_length)
         return dec_out

@@ -1403,8 +1390,6 @@ def hybrid_forward(self, F, step_data, states):
             step_data = step_data * np.sqrt(self.model.dec_units)
         if self.model.pos_embed_type is not None:
             step_data = step_data + self.model.tgt_pos_embed_layer(position)
-        if self.model.layernorm_embedding:
-            step_data = self.tgt_embed_ln(step_data)
         out, new_states =\
             self.model.decoder.incremental_decode(F, step_data, dec_states,
                                                   mem_data, mem_valid_length)
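The removals above, together with the encoder.ln_data / decoder.ln_data mapping in the conversion script, mean the embedding LayerNorm no longer lives in TransformerModel itself. A conceptual sketch of where it ends up, with assumed names (the data_norm flag and ln_data block) and written against plain MXNet Gluon rather than the real TransformerEncoder:

from mxnet import np, npx
from mxnet.gluon import nn
npx.set_np()

class EncoderInputNorm(nn.HybridBlock):
    """Sketch: the encoder/decoder own an `ln_data` LayerNorm when data_norm=True."""
    def __init__(self, units=768, data_norm=True, layer_norm_eps=1e-5):
        super().__init__()
        self._data_norm = data_norm
        if data_norm:
            self.ln_data = nn.LayerNorm(epsilon=layer_norm_eps, in_channels=units)

    def hybrid_forward(self, F, data):
        # normalize the (token + position) embeddings before the attention layers
        if self._data_norm:
            data = self.ln_data(data)
        return data

block = EncoderInputNorm()
block.initialize()
out = block(np.ones((2, 5, 768)))  # (batch, seq_len, units)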
