@@ -78,7 +78,7 @@ def save_pretrained(self, save_directory):
7878 self .to_json_file (output_config_file )
7979
8080 @classmethod
81- def from_pretrained (cls , pretrained_model_name_or_path , * input , * *kwargs ):
81+ def from_pretrained (cls , pretrained_model_name_or_path , ** kwargs ):
8282 r""" Instantiate a PretrainedConfig from a pre-trained model configuration.
8383
8484 Params:
@@ -91,20 +91,33 @@ def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
9191 **cache_dir**: (`optional`) string:
9292 Path to a directory in which a downloaded pre-trained model
9393 configuration should be cached if the standard cache should not be used.
94+ **return_unused_kwargs**: (`optional`) bool:
95+ - If False, then this function returns just the final configuration object.
95+ - If True, then this function returns a tuple `(config, unused_kwargs)` where `unused_kwargs`
97+ is a dictionary consisting of the key/value pairs whose keys are not configuration attributes:
98+ ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
9499 **kwargs**: (`optional`) dict:
95- Dictionnary of key, values to update the configuration object after loading.
96- Can be used to override selected configuration parameters.
100+ Dictionary of key/value pairs with which to update the configuration object after loading.
101+ - The values in kwargs of any keys which are configuration attributes will be used
102+ to override the loaded values.
103+ - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
104+ by the `return_unused_kwargs` keyword parameter.
97105
98106 Examples::
99107
100108 >>> config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
101109 >>> config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
102110 >>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
103- >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True)
111+ >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False )
104112 >>> assert config.output_attention == True
113+ >>> config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
114+ >>> foo=False, return_unused_kwargs=True)
115+ >>> assert config.output_attention == True
116+ >>> assert unused_kwargs == {'foo': False}
105117
106118 """
107119 cache_dir = kwargs .pop ('cache_dir' , None )
120+ return_unused_kwargs = kwargs .pop ('return_unused_kwargs' , False )
108121
109122 if pretrained_model_name_or_path in cls .pretrained_config_archive_map :
110123 config_file = cls .pretrained_config_archive_map [pretrained_model_name_or_path ]
@@ -148,7 +161,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
148161 kwargs .pop (key , None )
149162
150163 logger .info ("Model config %s" , config )
151- return config
164+ if return_unused_kwargs :
165+ return config , kwargs
166+ else :
167+ return config
152168
153169 @classmethod
154170 def from_dict (cls , json_object ):
@@ -305,7 +321,7 @@ def save_pretrained(self, save_directory):
305321 torch .save (model_to_save .state_dict (), output_model_file )
306322
307323 @classmethod
308- def from_pretrained (cls , pretrained_model_name_or_path , * inputs , ** kwargs ):
324+ def from_pretrained (cls , pretrained_model_name_or_path , * model_args , ** kwargs ):
309325 r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
310326
310327 The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
@@ -322,6 +338,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
322338 provided as `config` argument. This loading option is slower than converting the TensorFlow
323339 checkpoint in a PyTorch model using the provided conversion scripts and loading
324340 the PyTorch model afterwards.
341+ **model_args**: (`optional`) Sequence:
341+ All remaining positional arguments will be passed to the underlying model's __init__ function
325343 **config**: an optional configuration for the model to use instead of an automatically loaded configuration.
326344 Configuration can be automatically loaded when:
327345 - the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
@@ -337,8 +355,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
337355 **output_loading_info**: (`optional`) boolean:
338356 Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
339357 **kwargs**: (`optional`) dict:
340- Dictionnary of key, values to update the configuration object after loading.
341- Can be used to override selected configuration parameters. E.g. ``output_attention=True``
358+ Dictionary of key, values to update the configuration object after loading.
359+ Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
360+
361+ - If a configuration is provided with `config`, **kwargs will be directly passed
362+ to the underlying model's __init__ method.
363+ - If a configuration is not provided, **kwargs will be first passed to the pretrained
364+ model configuration class loading function (`PretrainedConfig.from_pretrained`).
365+ Each key of **kwargs that corresponds to a configuration attribute
366+ will be used to override said attribute with the supplied **kwargs value.
367+ Remaining keys that do not correspond to any configuration attribute will
368+ be passed to the underlying model's __init__ function.
342369
343370 Examples::
344371
@@ -359,7 +386,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
359386
360387 # Load config
361388 if config is None :
362- config = cls .config_class .from_pretrained (pretrained_model_name_or_path , * inputs , ** kwargs )
389+ config , model_kwargs = cls .config_class .from_pretrained (
390+ pretrained_model_name_or_path , * model_args ,
391+ cache_dir = cache_dir , return_unused_kwargs = True ,
392+ ** kwargs
393+ )
394+ else :
395+ model_kwargs = kwargs
363396
364397 # Load model
365398 if pretrained_model_name_or_path in cls .pretrained_model_archive_map :
@@ -400,7 +433,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
400433 archive_file , resolved_archive_file ))
401434
402435 # Instantiate model.
403- model = cls (config )
436+ model = cls (config , * model_args , ** model_kwargs )
404437
405438 if state_dict is None and not from_tf :
406439 state_dict = torch .load (resolved_archive_file , map_location = 'cpu' )
@@ -530,7 +563,7 @@ def forward(self, hidden_states, start_states=None, start_positions=None, p_mask
530563 **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
531564 hidden states of the first tokens for the labeled span.
532565 **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
533- position of the first token for the labeled span:
566+ position of the first token for the labeled span:
534567 **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
535568 Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
536569 1.0 means token should be masked.
@@ -717,7 +750,7 @@ class SequenceSummary(nn.Module):
717750 - 'attn' => Not implemented now, use multi-head attention
718751 summary_use_proj: Add a projection after the vector extraction
719752 summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
720- summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
753+ summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
721754 summary_first_dropout: Add a dropout before the projection and activation
722755 summary_last_dropout: Add a dropout after the projection and activation
723756 """
0 commit comments