39 | 39 | TF_WEIGHTS_NAME = 'model.ckpt' |
40 | 40 |
41 | 41 |
| 42 | +try: |
| 43 | +    from torch.nn import Identity |
| 44 | +except ImportError: |
| 45 | +    # Older PyTorch compatibility: nn.Identity was only added in PyTorch 1.1.0 |
| 46 | +    class Identity(nn.Module): |
| 47 | +        r"""A placeholder identity operator that is argument-insensitive. |
| 48 | +        """ |
| 49 | +        def __init__(self, *args, **kwargs): |
| 50 | +            super(Identity, self).__init__() |
| 51 | + |
| 52 | +        def forward(self, input): |
| 53 | +            return input |
| 54 | + |
| 55 | + |
42 | 56 | if not six.PY2: |
43 | 57 |     def add_start_docstrings(*docstr): |
44 | 58 |         def docstring_decorator(fn): |
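A note on the shim above: `torch.nn.Identity` first shipped in PyTorch 1.1.0, so the `try`/`except ImportError` keeps the library importable on older installs. A minimal, self-contained sketch of the pattern (the tensor sizes and the `p=0.1` argument below are illustrative, not from this patch):

```python
import torch
import torch.nn as nn

try:
    from torch.nn import Identity
except ImportError:
    # Fallback mirroring the patch: for PyTorch < 1.1.0, define a module
    # that ignores its constructor arguments and returns its input unchanged.
    class Identity(nn.Module):
        def __init__(self, *args, **kwargs):
            super(Identity, self).__init__()

        def forward(self, input):
            return input

# Argument-insensitive: Identity(768, p=0.1) behaves exactly like Identity(),
# so it can stand in for nn.Linear, nn.Dropout, etc. without changing call sites.
layer = Identity(768, p=0.1)
x = torch.randn(2, 3)
assert torch.equal(layer(x), x)  # pure pass-through
```

The argument-swallowing constructor is what makes the class a safe drop-in default: the hunk below assigns it wherever a config flag may leave a slot unused.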
@@ -764,23 +778,23 @@ def __init__(self, config): |
764 | 778 |             # We can probably just use the multi-head attention module of PyTorch >=1.1.0 |
765 | 779 |             raise NotImplementedError |
766 | 780 |
767 | | -        self.summary = nn.Identity() |
| 781 | +        self.summary = Identity() |
768 | 782 |         if hasattr(config, 'summary_use_proj') and config.summary_use_proj: |
769 | 783 |             if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0: |
770 | 784 |                 num_classes = config.num_labels |
771 | 785 |             else: |
772 | 786 |                 num_classes = config.hidden_size |
773 | 787 |             self.summary = nn.Linear(config.hidden_size, num_classes) |
774 | 788 |
775 | | -        self.activation = nn.Identity() |
| 789 | +        self.activation = Identity() |
776 | 790 |         if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh': |
777 | 791 |             self.activation = nn.Tanh() |
778 | 792 |
779 | | -        self.first_dropout = nn.Identity() |
| 793 | +        self.first_dropout = Identity() |
780 | 794 |         if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0: |
781 | 795 |             self.first_dropout = nn.Dropout(config.summary_first_dropout) |
782 | 796 |
783 | | -        self.last_dropout = nn.Identity() |
| 797 | +        self.last_dropout = Identity() |
784 | 798 |         if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0: |
785 | 799 |             self.last_dropout = nn.Dropout(config.summary_last_dropout) |
786 | 800 |
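Why the four `Identity()` defaults above are worth the shim: every optional stage of this summary head (projection, activation, and the two dropouts) is guaranteed to be *some* module, so the forward pass can chain the slots without re-checking the config. The forward method is outside this hunk; the reduced `MiniSummary` below (a hypothetical name, with made-up default sizes) sketches that design under the assumption that the stages are applied in dropout → projection → activation → dropout order:

```python
import torch
import torch.nn as nn

class MiniSummary(nn.Module):
    """Hypothetical reduction of the head above: every slot defaults to
    an identity module, so forward() never branches on the config."""
    def __init__(self, hidden_size=8, use_proj=True, dropout=0.1):
        super(MiniSummary, self).__init__()
        self.summary = nn.Linear(hidden_size, hidden_size) if use_proj else nn.Identity()
        self.activation = nn.Tanh()
        self.first_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.last_dropout = nn.Identity()  # disabled stage: a no-op module

    def forward(self, hidden_states):
        # Unconditional chaining: disabled stages simply pass tensors through.
        output = self.first_dropout(hidden_states)
        output = self.summary(output)
        output = self.activation(output)
        return self.last_dropout(output)

head = MiniSummary()
print(head(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```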