@@ -130,7 +130,7 @@ def __init__(self,
130130 use_mlm = True ,
131131 untie_weight = False ,
132132 encoder_normalize_before = True ,
133- return_all_hiddens = False ,
133+ output_all_encodings = False ,
134134 prefix = None ,
135135 params = None ):
136136 """
@@ -160,7 +160,7 @@ def __init__(self,
160160 untie_weight
161161 Whether to untie weights between embeddings and classifiers
162162 encoder_normalize_before
163- return_all_hiddens
163+ output_all_encodings
164164 prefix
165165 params
166166 """
@@ -182,7 +182,7 @@ def __init__(self,
182182 self .use_mlm = use_mlm
183183 self .untie_weight = untie_weight
184184 self .encoder_normalize_before = encoder_normalize_before
185- self .return_all_hiddens = return_all_hiddens
185+ self .output_all_encodings = output_all_encodings
186186 with self .name_scope ():
187187 self .tokens_embed = nn .Embedding (
188188 input_dim = self .vocab_size ,
@@ -220,7 +220,7 @@ def __init__(self,
220220 bias_initializer = bias_initializer ,
221221 activation = self .activation ,
222222 dtype = self .dtype ,
223- return_all_hiddens = self .return_all_hiddens
223+ output_all_encodings = self .output_all_encodings
224224 )
225225 self .encoder .hybridize ()
226226
@@ -237,25 +237,63 @@ def __init__(self,
237237 bias_initializer = bias_initializer
238238 )
239239 self .lm_head .hybridize ()
240- # TODO support use_pooler
241240
def hybrid_forward(self, F, tokens, valid_length):
    """Run the RoBERTa backbone and collect the requested outputs.

    Parameters
    ----------
    F
        The mxnet symbol/ndarray namespace handle supplied by Gluon.
    tokens
        Shape (batch_size, seq_length)
    valid_length
        Shape (batch_size,)

    Returns
    -------
    contextual_embeddings alone, or a tuple
    (contextual_embeddings[, pooled_out][, mlm_output]) depending on
    use_pooler / use_mlm.
    """
    outputs = []
    embedding = self.get_initial_embedding(F, tokens)

    # Bug fix: the encoder must consume the embedding computed above;
    # the original referenced an undefined name `x` (NameError at runtime).
    inner_states = self.encoder(embedding, valid_length)
    if self.output_all_encodings:
        # The encoder returns the list of all per-layer states in this
        # mode; the last entry is the final contextual embedding.
        contextual_embeddings = inner_states[-1]
    else:
        contextual_embeddings = inner_states
    outputs.append(contextual_embeddings)

    # NOTE(review): self.use_pooler is not assigned in the visible
    # __init__ hunks — confirm it is set elsewhere in the class.
    if self.use_pooler:
        pooled_out = self.apply_pooling(contextual_embeddings)
        outputs.append(pooled_out)

    if self.use_mlm:
        mlm_output = self.lm_head(contextual_embeddings)
        outputs.append(mlm_output)
    return tuple(outputs) if len(outputs) > 1 else outputs[0]
260+
def get_initial_embedding(self, F, inputs):
    """Get the initial token embeddings, adding positional embeddings
    and applying layer norm / dropout when configured.

    Parameters
    ----------
    F
        The mxnet symbol/ndarray namespace handle supplied by Gluon.
    inputs
        Shape (batch_size, seq_length)

    Returns
    -------
    embedding
        The initial embedding that will be fed into the encoder
    """
    embedding = self.tokens_embed(inputs)
    if self.pos_embed_type:
        positional_embedding = self.pos_embed(F.npx.arange_like(inputs, axis=1))
        positional_embedding = F.np.expand_dims(positional_embedding, axis=0)
        embedding = embedding + positional_embedding
    if self.embed_ln:
        embedding = self.embed_ln(embedding)
    embedding = self.embed_dropout(embedding)
    # Bug fix: the original fell off the end and returned None even
    # though the docstring promises the embedding and the caller
    # (hybrid_forward) feeds the result to the encoder.
    return embedding
def apply_pooling(self, sequence):
    """Pool a sequence of hidden states down to a single vector.

    Selects the hidden state of the first token (the [CLS] position),
    as used when pre-training or fine-tuning the model.

    sequence:
        Shape (batch_size, sequence_length, units)
    return:
        Shape (batch_size, units)
    """
    return sequence[:, 0, :]
259297
260298 @staticmethod
261299 def get_cfg (key = None ):
@@ -271,7 +309,7 @@ def from_cfg(cls,
271309 use_mlm = True ,
272310 untie_weight = False ,
273311 encoder_normalize_before = True ,
274- return_all_hiddens = False ,
312+ output_all_encodings = False ,
275313 prefix = None ,
276314 params = None ):
277315 cfg = RobertaModel .get_cfg ().clone_merge (cfg )
@@ -298,7 +336,7 @@ def from_cfg(cls,
298336 use_mlm = use_mlm ,
299337 untie_weight = untie_weight ,
300338 encoder_normalize_before = encoder_normalize_before ,
301- return_all_hiddens = return_all_hiddens ,
339+ output_all_encodings = output_all_encodings ,
302340 prefix = prefix ,
303341 params = params )
304342
@@ -316,7 +354,7 @@ def __init__(self,
316354 bias_initializer = 'zeros' ,
317355 activation = 'gelu' ,
318356 dtype = 'float32' ,
319- return_all_hiddens = False ,
357+ output_all_encodings = False ,
320358 prefix = 'encoder_' ,
321359 params = None ):
322360 super (RobertaEncoder , self ).__init__ (prefix = prefix , params = params )
@@ -329,7 +367,7 @@ def __init__(self,
329367 self .layer_norm_eps = layer_norm_eps
330368 self .activation = activation
331369 self .dtype = dtype
332- self .return_all_hiddens = return_all_hiddens
370+ self .output_all_encodings = output_all_encodings
333371 with self .name_scope ():
334372 self .all_layers = nn .HybridSequential (prefix = 'layers_' )
335373 with self .all_layers .name_scope ():
@@ -358,8 +396,8 @@ def hybrid_forward(self, F, x, valid_length):
358396 layer = self .all_layers [layer_idx ]
359397 x , _ = layer (x , atten_mask )
360398 inner_states .append (x )
361- if not self .return_all_hiddens :
362- inner_states = [ x ]
399+ if not self .output_all_encodings :
400+ inner_states = x
363401 return inner_states
364402
365403@use_np
@@ -419,7 +457,8 @@ def list_pretrained_roberta():
419457
420458def get_pretrained_roberta (model_name : str = 'fairseq_roberta_base' ,
421459 root : str = get_model_zoo_home_dir (),
422- load_backbone : bool = True ) \
460+ load_backbone : bool = True ,
461+ load_mlm : bool = False ) \
423462 -> Tuple [CN , HuggingFaceByteBPETokenizer , str ]:
424463 """Get the pretrained RoBERTa weights
425464
0 commit comments