
Commit 8513741

Merge pull request #17 from lukovnikov/master
activation function in BERTIntermediate
2 parents 5cd8d7a + 470076e commit 8513741

File tree

1 file changed: +11, -2 lines


modeling.py

Lines changed: 11 additions & 2 deletions
@@ -25,6 +25,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
+from six import string_types

 def gelu(x):
     """Implementation of the gelu activation function.
@@ -34,6 +35,13 @@ def gelu(x):
     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


+def swish(x):
+    return x * torch.sigmoid(x)
+
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+
 class BertConfig(object):
     """Configuration class to store the configuration of a `BertModel`.
     """
@@ -60,7 +68,7 @@ def __init__(self,
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
-                encoder and pooler.
+                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probabilitiy for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
@@ -237,7 +245,8 @@ class BERTIntermediate(nn.Module):
     def __init__(self, config):
         super(BERTIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.intermediate_act_fn = gelu
+        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
+            if isinstance(config.hidden_act, string_types) else config.hidden_act

     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
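
With this change, hidden_act can be either a string key into ACT2FN ("gelu", "relu" or "swish") or a callable that is used directly. Below is a minimal, self-contained sketch of that dispatch; TinyIntermediate is a hypothetical stand-in for BERTIntermediate, and it checks against Python 3's built-in str instead of six.string_types:

    import math

    import torch
    import torch.nn as nn


    def gelu(x):
        # Same gelu as modeling.py: x * 0.5 * (1 + erf(x / sqrt(2)))
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


    def swish(x):
        # The activation added by this commit.
        return x * torch.sigmoid(x)


    ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


    class TinyIntermediate(nn.Module):
        # Hypothetical stand-in for BERTIntermediate, reduced to the act-fn dispatch.
        def __init__(self, hidden_size, intermediate_size, hidden_act="gelu"):
            super(TinyIntermediate, self).__init__()
            self.dense = nn.Linear(hidden_size, intermediate_size)
            # A string key is looked up in ACT2FN; a callable is used as-is.
            # (The patch uses six.string_types for Python 2 compatibility;
            # plain str is enough on Python 3.)
            self.act_fn = ACT2FN[hidden_act] if isinstance(hidden_act, str) else hidden_act

        def forward(self, hidden_states):
            return self.act_fn(self.dense(hidden_states))


    # Selecting the new activation by name (768/3072 are the usual BERT-base sizes):
    layer = TinyIntermediate(hidden_size=768, intermediate_size=3072, hidden_act="swish")
    print(layer(torch.randn(2, 5, 768)).shape)  # torch.Size([2, 5, 3072])

    # Passing a callable directly also works:
    layer = TinyIntermediate(768, 3072, hidden_act=torch.tanh)

Keeping the lookup table at module level means a new activation only needs a one-line registration in ACT2FN, while configs that pass a function keep working unchanged.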

0 commit comments
