import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
+from six import string_types

def gelu(x):
    """Implementation of the gelu activation function.
@@ -34,6 +35,13 @@ def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


+def swish(x):
+    return x * torch.sigmoid(x)
+
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+
class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.
    """
@@ -60,7 +68,7 @@ def __init__(self,
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
-                encoder and pooler.
+                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
@@ -237,7 +245,8 @@ class BERTIntermediate(nn.Module):
    def __init__(self, config):
        super(BERTIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.intermediate_act_fn = gelu
+        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
+            if isinstance(config.hidden_act, string_types) else config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
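
A rough usage sketch, not part of the diff itself: with the ACT2FN mapping added above, `config.hidden_act` can be either one of the supported string keys or a callable. The snippet assumes the imports and definitions from the patched module (torch, swish, ACT2FN, string_types); the example value "swish" is illustrative.

# A string key is resolved through ACT2FN; a callable is used directly.
hidden_act = "swish"                                   # could also be e.g. the gelu function itself
act_fn = ACT2FN[hidden_act] if isinstance(hidden_act, string_types) else hidden_act
print(act_fn(torch.tensor([1.0])))                     # swish(1.0) = 1.0 * sigmoid(1.0) ≈ 0.7311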