Skip to content

Commit 368670a

Browse files
authored
Merge pull request #866 from xanlsh/master
Rework how PreTrainedModel.from_pretrained handles its arguments
2 parents 6070b55 + 4fb56c7 commit 368670a

1 file changed

Lines changed: 45 additions & 12 deletions

File tree

pytorch_transformers/modeling_utils.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def save_pretrained(self, save_directory):
7878
self.to_json_file(output_config_file)
7979

8080
@classmethod
81-
def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
81+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
8282
r""" Instantiate a PretrainedConfig from a pre-trained model configuration.
8383
8484
Params:
@@ -91,20 +91,33 @@ def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
9191
**cache_dir**: (`optional`) string:
9292
Path to a directory in which a downloaded pre-trained model
9393
configuration should be cached if the standard cache should not be used.
94+
**return_unused_kwargs**: (`optional`) bool:
95+
- If False, then this function returns just the final configuration object.
96+
- If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs`
97+
is a dictionary consisting of the key/value pairs whose keys are not configuration attributes:
98+
ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
9499
**kwargs**: (`optional`) dict:
95-
Dictionnary of key, values to update the configuration object after loading.
96-
Can be used to override selected configuration parameters.
100+
Dictionary of key/value pairs with which to update the configuration object after loading.
101+
- The values in kwargs of any keys which are configuration attributes will be used
102+
to override the loaded values.
103+
- Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
104+
by the `return_unused_kwargs` keyword parameter.
97105
98106
Examples::
99107
100108
>>> config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
101109
>>> config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
102110
>>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
103-
>>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True)
111+
>>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
104112
>>> assert config.output_attention == True
113+
>>> config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
114+
>>> foo=False, return_unused_kwargs=True)
115+
>>> assert config.output_attention == True
116+
>>> assert unused_kwargs == {'foo': False}
105117
106118
"""
107119
cache_dir = kwargs.pop('cache_dir', None)
120+
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
108121

109122
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
110123
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
@@ -148,7 +161,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
148161
kwargs.pop(key, None)
149162

150163
logger.info("Model config %s", config)
151-
return config
164+
if return_unused_kwargs:
165+
return config, kwargs
166+
else:
167+
return config
152168

153169
@classmethod
154170
def from_dict(cls, json_object):
@@ -305,7 +321,7 @@ def save_pretrained(self, save_directory):
305321
torch.save(model_to_save.state_dict(), output_model_file)
306322

307323
@classmethod
308-
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
324+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
309325
r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
310326
311327
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
@@ -322,6 +338,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
322338
provided as `config` argument. This loading option is slower than converting the TensorFlow
323339
checkpoint in a PyTorch model using the provided conversion scripts and loading
324340
the PyTorch model afterwards.
341+
**model_args**: (`optional`) Sequence:
342+
All remaining positional arguments will be passed to the underlying model's __init__ function
325343
**config**: an optional configuration for the model to use instead of an automatically loaded configuration.
326344
Configuration can be automatically loaded when:
327345
- the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
@@ -337,8 +355,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
337355
**output_loading_info**: (`optional`) boolean:
338356
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
339357
**kwargs**: (`optional`) dict:
340-
Dictionnary of key, values to update the configuration object after loading.
341-
Can be used to override selected configuration parameters. E.g. ``output_attention=True``
358+
Dictionary of key, values to update the configuration object after loading.
359+
Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
360+
361+
- If a configuration is provided with `config`, **kwargs will be directly passed
362+
to the underlying model's __init__ method.
363+
- If a configuration is not provided, **kwargs will be first passed to the pretrained
364+
model configuration class loading function (`PretrainedConfig.from_pretrained`).
365+
Each key of **kwargs that corresponds to a configuration attribute
366+
will be used to override said attribute with the supplied **kwargs value.
367+
Remaining keys that do not correspond to any configuration attribute will
368+
be passed to the underlying model's __init__ function.
342369
343370
Examples::
344371
@@ -359,7 +386,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
359386

360387
# Load config
361388
if config is None:
362-
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
389+
config, model_kwargs = cls.config_class.from_pretrained(
390+
pretrained_model_name_or_path, *model_args,
391+
cache_dir=cache_dir, return_unused_kwargs=True,
392+
**kwargs
393+
)
394+
else:
395+
model_kwargs = kwargs
363396

364397
# Load model
365398
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
@@ -400,7 +433,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
400433
archive_file, resolved_archive_file))
401434

402435
# Instantiate model.
403-
model = cls(config)
436+
model = cls(config, *model_args, **model_kwargs)
404437

405438
if state_dict is None and not from_tf:
406439
state_dict = torch.load(resolved_archive_file, map_location='cpu')
@@ -530,7 +563,7 @@ def forward(self, hidden_states, start_states=None, start_positions=None, p_mask
530563
**start_states**: ``torch.LongTensor`` of shape identical to hidden_states
531564
hidden states of the first tokens for the labeled span.
532565
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
533-
position of the first token for the labeled span:
566+
position of the first token for the labeled span:
534567
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
535568
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
536569
1.0 means token should be masked.
@@ -717,7 +750,7 @@ class SequenceSummary(nn.Module):
717750
- 'attn' => Not implemented now, use multi-head attention
718751
summary_use_proj: Add a projection after the vector extraction
719752
summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
720-
summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
753+
summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
721754
summary_first_dropout: Add a dropout before the projection and activation
722755
summary_last_dropout: Add a dropout after the projection and activation
723756
"""

0 commit comments

Comments
 (0)