
Commit 067923d

Merge pull request #873 from huggingface/identity_replacement
Add nn.Identity replacement for old PyTorch
Parents: 368670a + 1383c7b

File tree: 1 file changed (+18, -4 lines)


pytorch_transformers/modeling_utils.py

Lines changed: 18 additions & 4 deletions
@@ -39,6 +39,20 @@
 TF_WEIGHTS_NAME = 'model.ckpt'
 
 
+try:
+    from torch.nn import Identity
+except ImportError:
+    # Older PyTorch compatibility
+    class Identity(nn.Module):
+        r"""A placeholder identity operator that is argument-insensitive.
+        """
+        def __init__(self, *args, **kwargs):
+            super(Identity, self).__init__()
+
+        def forward(self, input):
+            return input
+
+
 if not six.PY2:
     def add_start_docstrings(*docstr):
         def docstring_decorator(fn):
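
The shim above mirrors torch.nn.Identity, which newer PyTorch releases provide out of the box: a module that ignores its constructor arguments and returns its input unchanged. A minimal, self-contained sketch (not part of the commit) of how either variant behaves:

import torch
import torch.nn as nn

try:
    from torch.nn import Identity  # available in newer PyTorch
except ImportError:
    # Fallback shim for older PyTorch, as added in this commit
    class Identity(nn.Module):
        r"""A placeholder identity operator that is argument-insensitive."""
        def __init__(self, *args, **kwargs):
            super(Identity, self).__init__()

        def forward(self, input):
            return input

x = torch.randn(2, 3)
layer = Identity(54, unused_kwarg=True)  # arguments are accepted and ignored
assert torch.equal(layer(x), x)          # output is the input, untouched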
@@ -764,23 +778,23 @@ def __init__(self, config):
             # We can probably just use the multi-head attention module of PyTorch >=1.1.0
             raise NotImplementedError
 
-        self.summary = nn.Identity()
+        self.summary = Identity()
         if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
             if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                 num_classes = config.num_labels
             else:
                 num_classes = config.hidden_size
             self.summary = nn.Linear(config.hidden_size, num_classes)
 
-        self.activation = nn.Identity()
+        self.activation = Identity()
         if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
             self.activation = nn.Tanh()
 
-        self.first_dropout = nn.Identity()
+        self.first_dropout = Identity()
         if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
             self.first_dropout = nn.Dropout(config.summary_first_dropout)
 
-        self.last_dropout = nn.Identity()
+        self.last_dropout = Identity()
         if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
             self.last_dropout = nn.Dropout(config.summary_last_dropout)
 
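The second hunk swaps the nn.Identity() defaults in SequenceSummary for the module-local Identity so the same no-op-default pattern works on old PyTorch: each component starts as a pass-through and is only replaced when the config asks for it. A minimal sketch of that pattern (the SimpleNamespace config is an illustration, not the library's config class; the summary_* field names come from the diff):

from types import SimpleNamespace
import torch
import torch.nn as nn

# Hypothetical config: only summary_activation is set, so dropout stays a no-op.
config = SimpleNamespace(hidden_size=8, summary_activation='tanh')

activation = nn.Identity()  # or the Identity shim on older PyTorch
if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
    activation = nn.Tanh()

first_dropout = nn.Identity()
if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
    first_dropout = nn.Dropout(config.summary_first_dropout)

x = torch.randn(2, config.hidden_size)
y = activation(first_dropout(x))      # dropout is a no-op; tanh is applied
assert torch.equal(y, torch.tanh(x))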