Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
import collections
import gc
import inspect
import json
import os
import re
Expand Down Expand Up @@ -932,6 +933,15 @@ def floating_point_ops(
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


class BackboneBaseModel(nn.Module):
    """
    Mixin-style base class that lets a module be called with a superset of the keyword arguments its `forward`
    method accepts.
    """

    def forward_with_filtered_kwargs(self, *args, **kwargs):
        """
        Call the module, forwarding only the keyword arguments that `self.forward` can receive.

        Args:
            *args: Positional arguments, passed through to the module call unchanged.
            **kwargs: Keyword arguments. Entries that `self.forward` cannot accept by keyword are dropped. If
                `self.forward` declares a `**kwargs` catch-all, every entry is forwarded.

        Returns:
            Whatever calling the module (`self(...)`) returns.
        """
        parameters = inspect.signature(self.forward).parameters.values()

        # A `**kwargs` catch-all means `forward` accepts any keyword, so there is nothing to filter out.
        if any(param.kind is inspect.Parameter.VAR_KEYWORD for param in parameters):
            return self(*args, **kwargs)

        # Only these kinds can be supplied by keyword; the names of `*args` and positional-only parameters
        # must not let a same-named keyword through, as that would raise a `TypeError` in the call.
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        accepted_names = {param.name for param in parameters if param.kind in keyword_kinds}
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in accepted_names}

        return self(*args, **filtered_kwargs)


class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
r"""
Base class for all models.
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/bit/modeling_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import BackboneBaseModel, PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
Expand Down Expand Up @@ -842,7 +842,7 @@ def forward(
""",
BIT_START_DOCSTRING,
)
class BitBackbone(BitPreTrainedModel):
class BitBackbone(BitPreTrainedModel, BackboneBaseModel):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not really convenient that every backbone will need to inherit from two classes. Can we incorporate this directly into PreTrainedModel?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM, but let's wait for @sgugger to confirm.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see where the problem lies. That's the most Pythonic way of doing this (compared to having all models inherit from GenerationMixin for instance, whereas lots of them shouldn't). You give the type to objects that need it (and only those).

def __init__(self, config):
super().__init__(config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ...activations import ACT2FN
from ...file_utils import ModelOutput
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import BackboneBaseModel, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from .configuration_maskformer_swin import MaskFormerSwinConfig

Expand Down Expand Up @@ -837,7 +837,7 @@ def forward(
)


class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel):
class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneBaseModel):
"""
MaskFormerSwin backbone, designed especially for the MaskFormer framework.

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/resnet/modeling_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import BackboneBaseModel, PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
Expand Down Expand Up @@ -431,7 +431,7 @@ def forward(
""",
RESNET_START_DOCSTRING,
)
class ResNetBackbone(ResNetPreTrainedModel):
class ResNetBackbone(ResNetPreTrainedModel, BackboneBaseModel):
def __init__(self, config):
super().__init__(config)

Expand Down