8 changes: 4 additions & 4 deletions src/transformers/modeling_tf_pytorch_utils.py
@@ -164,9 +164,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
             if allow_missing_keys:
                 missing_keys.append(name)
                 continue
-            elif tf_model.authorized_missing_keys is not None:
+            elif tf_model._keys_to_ignore_on_load is not None:
                 # authorized missing keys don't have to be loaded
-                if any(re.search(pat, name) is not None for pat in tf_model.authorized_missing_keys):
+                if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load):
                     continue

             raise AttributeError("{} not found in PyTorch model".format(name))
@@ -209,8 +209,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a

     unexpected_keys = list(all_pytorch_weights)

-    if tf_model.authorized_missing_keys is not None:
-        for pat in tf_model.authorized_missing_keys:
+    if tf_model._keys_to_ignore_on_load is not None:
+        for pat in tf_model._keys_to_ignore_on_load:
             missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
     if tf_model.authorized_unexpected_keys is not None:
         for pat in tf_model.authorized_unexpected_keys:
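A note on the load path above: each entry of `_keys_to_ignore_on_load` is applied with `re.search`, so it matches anywhere inside a weight name, and matching names are silently dropped from the warning lists. A minimal standalone sketch of the idiom (the key names below are invented for illustration):

```python
import re

# Invented weight names, shaped like real state_dict keys.
missing_keys = [
    "bert.embeddings.position_ids",
    "cls.predictions.decoder.bias",
    "bert.pooler.dense.weight",
]
keys_to_ignore = [r"position_ids", r"predictions.decoder.bias"]

# Same idiom as the diff: keep only keys that no ignore-pattern matches.
for pat in keys_to_ignore:
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

print(missing_keys)  # ['bert.pooler.dense.weight'] is still reported as missing
```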
8 changes: 4 additions & 4 deletions src/transformers/modeling_tf_utils.py
@@ -343,14 +343,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
           :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore
+        - **_keys_to_ignore_on_load** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore
           from the model when loading the model weights (and avoid unnecessary warnings).
         - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to
           ignore from the weights when loading the model weights (and avoid unnecessary warnings).
     """
     config_class = None
     base_model_prefix = ""
-    authorized_missing_keys = None
+    _keys_to_ignore_on_load = None
     authorized_unexpected_keys = None

     @property
@@ -742,8 +742,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

         model(model.dummy_inputs, training=False)  # Make sure restore ops are run

-        if cls.authorized_missing_keys is not None:
-            for pat in cls.authorized_missing_keys:
+        if cls._keys_to_ignore_on_load is not None:
+            for pat in cls._keys_to_ignore_on_load:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

         if cls.authorized_unexpected_keys is not None:
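As the updated docstring says, the attribute is declared on the model class rather than passed at call time. A hedged sketch of a derived class opting its pooler weights out of load-time warnings (the class name is made up; the pattern mirrors the TF task models below):

```python
from transformers import TFBertPreTrainedModel

class TFMyTaskModel(TFBertPreTrainedModel):  # hypothetical subclass
    # Suppress "missing key" warnings for any weight whose name contains
    # "pooler", e.g. for a task head that never uses the pooler output.
    _keys_to_ignore_on_load = [r"pooler"]
```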
18 changes: 9 additions & 9 deletions src/transformers/modeling_utils.py
@@ -404,17 +404,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):

         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
+        - **_keys_to_ignore_on_load** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
           when loading the model (and avoid unnecessary warnings).
-        - **keys_to_never_save** (:obj:`Optional[List[str]]`) -- A list of of tensor names to ignore when saving the
-          model (useful for keys that aren't trained, but which are deterministic)
+        - **_keys_to_ignore_on_save** (:obj:`Optional[List[str]]`) -- A list of of tensor names to ignore when saving
+          the model (useful for keys that aren't trained, but which are deterministic)

     """
     config_class = None
     base_model_prefix = ""
-    authorized_missing_keys = None
+    _keys_to_ignore_on_load = None
     authorized_unexpected_keys = None
-    keys_to_never_save = None
+    _keys_to_ignore_on_save = None

     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -719,8 +719,8 @@ def save_pretrained(self, save_directory):

         state_dict = model_to_save.state_dict()

         # Handle the case where some state_dict keys shouldn't be saved
-        if self.keys_to_never_save is not None:
-            state_dict = {k: v for k, v in state_dict.items() if k not in self.keys_to_never_save}
+        if self._keys_to_ignore_on_save is not None:
+            state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}

         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
@@ -1034,8 +1034,8 @@ def load(module: nn.Module, prefix=""):

         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
-        if cls.authorized_missing_keys is not None:
-            for pat in cls.authorized_missing_keys:
+        if cls._keys_to_ignore_on_load is not None:
+            for pat in cls._keys_to_ignore_on_load:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

         if cls.authorized_unexpected_keys is not None:
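One asymmetry worth flagging in `save_pretrained` above: the save path filters with plain membership (`k not in self._keys_to_ignore_on_save`), so its entries must be exact state_dict key names, while the load path treats entries as regex patterns. A small sketch of the difference, using invented keys and placeholder values:

```python
import re

state_dict = {
    "model.encoder.embed_positions.weight": "tensor-0",  # placeholders, not real tensors
    "model.encoder.layers.0.fc1.weight": "tensor-1",
}

# Save path: exact-name membership test, no regex involved.
ignore_on_save = ["model.encoder.embed_positions.weight"]
to_save = {k: v for k, v in state_dict.items() if k not in ignore_on_save}
assert list(to_save) == ["model.encoder.layers.0.fc1.weight"]

# Load path: regex search, so a substring pattern suffices.
missing_keys = list(state_dict)
for pat in [r"embed_positions"]:
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
assert missing_keys == ["model.encoder.layers.0.fc1.weight"]
```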
2 changes: 1 addition & 1 deletion src/transformers/models/albert/modeling_albert.py
@@ -459,7 +459,7 @@ class AlbertPreTrainedModel(PreTrainedModel):

     config_class = AlbertConfig
     base_model_prefix = "albert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights."""
6 changes: 3 additions & 3 deletions src/transformers/models/albert/modeling_tf_albert.py
@@ -843,7 +843,7 @@ def call(self, pooled_output, training: bool):
 @add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
 class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1013,7 +1013,7 @@ def call(
     )
 class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1100,7 +1100,7 @@ def call(
     )
 class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
2 changes: 1 addition & 1 deletion src/transformers/models/bart/modeling_bart.py
@@ -946,7 +946,7 @@ def get_output_embeddings(self):
     )
 class BartForConditionalGeneration(PretrainedBartModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
+    _keys_to_ignore_on_load = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]

     def __init__(self, config: BartConfig):
         super().__init__(config)
2 changes: 1 addition & 1 deletion src/transformers/models/bart/modeling_tf_bart.py
@@ -1020,7 +1020,7 @@ def get_output_embeddings(self):
     )
 class TFBartForConditionalGeneration(TFPretrainedBartModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load = [
         r"final_logits_bias",
     ]
     authorized_unexpected_keys = [
6 changes: 3 additions & 3 deletions src/transformers/models/bert/modeling_bert.py
@@ -598,7 +598,7 @@ class BertPreTrainedModel(PreTrainedModel):
     config_class = BertConfig
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
@@ -970,7 +970,7 @@ def forward(
 class BertLMHeadModel(BertPreTrainedModel):

     authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load = [r"position_ids", r"predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1088,7 +1088,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_
 class BertForMaskedLM(BertPreTrainedModel):

     authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load = [r"position_ids", r"predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
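A small observation on the BERT patterns: the dots in `r"predictions.decoder.bias"` are unescaped, so as regexes they match any character. They only ever meet literal dots in real key names, so the looser pattern is harmless in practice. A quick check with one plausible and one contrived key:

```python
import re

loose = r"predictions.decoder.bias"        # unescaped "." matches any character
strict = r"predictions\.decoder\.bias"     # escaped variant matches literal dots only

key = "cls.predictions.decoder.bias"       # plausible BERT state_dict key
assert re.search(loose, key) and re.search(strict, key)

# Only the loose pattern matches a contrived key with no dots at all.
assert re.search(loose, "predictionsXdecoderXbias")
assert not re.search(strict, "predictionsXdecoderXbias")
```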
8 changes: 4 additions & 4 deletions src/transformers/models/bert/modeling_tf_bert.py
@@ -939,7 +939,7 @@ def call(
 class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):

     authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1024,7 +1024,7 @@ def call(
 class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):

     authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1417,7 +1417,7 @@ def call(
 class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):

     authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1503,7 +1503,7 @@ def call(
 class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):

     authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
2 changes: 1 addition & 1 deletion src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -173,7 +173,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):

     config_class = BertGenerationConfig
     base_model_prefix = "bert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
2 changes: 1 addition & 1 deletion src/transformers/models/deberta/modeling_deberta.py
@@ -756,7 +756,7 @@ class DebertaPreTrainedModel(PreTrainedModel):

     config_class = DebertaConfig
     base_model_prefix = "deberta"
-    authorized_missing_keys = ["position_ids"]
+    _keys_to_ignore_on_load = ["position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
6 changes: 3 additions & 3 deletions src/transformers/models/dpr/modeling_dpr.py
@@ -279,7 +279,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "ctx_encoder"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load = [r"position_ids"]

     def init_weights(self):
         self.ctx_encoder.init_weights()
@@ -294,7 +294,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "question_encoder"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load = [r"position_ids"]

     def init_weights(self):
         self.question_encoder.init_weights()
@@ -309,7 +309,7 @@ class DPRPretrainedReader(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "span_predictor"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load = [r"position_ids"]

     def init_weights(self):
         self.span_predictor.encoder.init_weights()
2 changes: 1 addition & 1 deletion src/transformers/models/electra/modeling_electra.py
@@ -544,7 +544,7 @@ class ElectraPreTrainedModel(PreTrainedModel):
     config_class = ElectraConfig
     load_tf_weights = load_tf_weights_in_electra
     base_model_prefix = "electra"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load = [r"position_ids"]
     authorized_unexpected_keys = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]

     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
4 changes: 2 additions & 2 deletions src/transformers/models/fsmt/modeling_fsmt.py
@@ -1005,11 +1005,11 @@ def set_output_embeddings(self, value):
     )
 class FSMTForConditionalGeneration(PretrainedFSMTModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
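FSMT (like Marian and mBART below) lists the same `embed_positions.weight` entries for both load and save. The rationale matches the base-class docstring: these embeddings are deterministic, not trained, so they can be rebuilt at init instead of stored in the checkpoint. A rough sketch of why that round-trips safely, assuming the standard sinusoidal construction (the actual FSMT implementation may differ in detail):

```python
import math
import torch

def sinusoidal_table(n_positions: int, dim: int) -> torch.Tensor:
    # Classic sin/cos position table: a pure function of (position, dim).
    position = torch.arange(n_positions, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim)
    )
    table = torch.zeros(n_positions, dim)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table

# Two independent "fresh inits" agree exactly, so the tensor need not be saved.
assert torch.equal(sinusoidal_table(128, 512), sinusoidal_table(128, 512))
```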
4 changes: 2 additions & 2 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -685,7 +685,7 @@ def custom_forward(*inputs):
     GPT2_START_DOCSTRING,
 )
 class GPT2LMHeadModel(GPT2PreTrainedModel):
-    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

     def __init__(self, config):
         super().__init__(config)
@@ -975,7 +975,7 @@ def forward(
     GPT2_START_DOCSTRING,
 )
 class GPT2ForSequenceClassification(GPT2PreTrainedModel):
-    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

     def __init__(self, config):
         super().__init__(config)
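The GPT-2 entry shows why these lists hold regexes rather than literal names: `r"h\.\d+\.attn\.masked_bias"` covers the buffer in every block with a single pattern. A quick demonstration over invented layer keys:

```python
import re

pat = r"h\.\d+\.attn\.masked_bias"
keys = [f"h.{i}.attn.masked_bias" for i in (0, 5, 11)] + ["h.0.attn.c_attn.weight"]

ignored = [k for k in keys if re.search(pat, k)]
assert ignored == ["h.0.attn.masked_bias", "h.5.attn.masked_bias", "h.11.attn.masked_bias"]
```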
2 changes: 1 addition & 1 deletion src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -509,7 +509,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):

     config_class = LayoutLMConfig
     base_model_prefix = "layoutlm"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
2 changes: 1 addition & 1 deletion src/transformers/models/longformer/modeling_longformer.py
@@ -1303,7 +1303,7 @@ class LongformerPreTrainedModel(PreTrainedModel):

     config_class = LongformerConfig
     base_model_prefix = "longformer"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
8 changes: 4 additions & 4 deletions src/transformers/models/longformer/modeling_tf_longformer.py
@@ -1961,7 +1961,7 @@ def call(self, inputs, **kwargs):
     )
 class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -2048,7 +2048,7 @@ def call(
     )
 class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -2199,7 +2199,7 @@ def call(self, hidden_states, training=False):
     )
 class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -2443,7 +2443,7 @@ def call(
     )
 class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
4 changes: 2 additions & 2 deletions src/transformers/models/marian/modeling_marian.py
@@ -47,11 +47,11 @@ class MarianMTModel(BartForConditionalGeneration):

     """
     config_class = MarianConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
2 changes: 1 addition & 1 deletion src/transformers/models/marian/modeling_tf_marian.py
@@ -37,7 +37,7 @@

 @add_start_docstrings("Marian model for machine translation", START_DOCSTRING)
 class TFMarianMTModel(TFBartForConditionalGeneration):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load = [
         r"model.encoder.embed_positions.weight",
         r"model.decoder.embed_positions.weight",
     ]
4 changes: 2 additions & 2 deletions src/transformers/models/mbart/modeling_mbart.py
@@ -29,11 +29,11 @@ class MBartForConditionalGeneration(BartForConditionalGeneration):
     """
     model_type = "mbart"
     config_class = MBartConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
2 changes: 1 addition & 1 deletion src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -677,7 +677,7 @@ class MobileBertPreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
     load_tf_weights = load_tf_weights_in_mobilebert
     base_model_prefix = "mobilebert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """