File tree Expand file tree Collapse file tree 2 files changed +12
-4
lines changed
Expand file tree Collapse file tree 2 files changed +12
-4
lines changed Original file line number Diff line number Diff line change @@ -28,11 +28,15 @@ class Pooler(nn.Module):
2828 normalize: Whether to normalize the pooled data.
2929 """
3030
31- def __init__ (self , pooling_type : PoolingType , normalize : bool ):
31+ def __init__ (self ,
32+ pooling_type : PoolingType ,
33+ normalize : bool ,
34+ softmax : bool = False ):
3235 super ().__init__ ()
3336
3437 self .pooling_type = pooling_type
3538 self .normalize = normalize
39+ self .softmax = softmax
3640
3741 def forward (
3842 self ,
@@ -64,6 +68,9 @@ def forward(
6468 if self .normalize :
6569 pooled_data = nn .functional .normalize (pooled_data , p = 2 , dim = 1 )
6670
71+ if self .softmax :
72+ pooled_data = nn .functional .softmax (pooled_data , dim = - 1 )
73+
6774 pooled_outputs = [
6875 EmbeddingSequenceGroupOutput (data .tolist ()) for data in pooled_data
6976 ]
Original file line number Diff line number Diff line change @@ -77,7 +77,9 @@ def __init__(
7777 self .score = RowParallelLinear (config .hidden_size ,
7878 config .num_labels ,
7979 quant_config = quant_config )
80- self ._pooler = Pooler (pooling_type = PoolingType .LAST , normalize = False )
80+ self ._pooler = Pooler (pooling_type = PoolingType .LAST ,
81+ normalize = False ,
82+ softmax = True )
8183
8284 def forward (
8385 self ,
@@ -97,8 +99,7 @@ def pooler(
9799 hidden_states : torch .Tensor ,
98100 pooling_metadata : PoolingMetadata ,
99101 ) -> Optional [PoolerOutput ]:
100- pooled = self ._pooler (hidden_states , pooling_metadata )
101- return nn .functional .softmax (pooled , dim = - 1 )
102+ return self ._pooler (hidden_states , pooling_metadata )
102103
103104 def load_weights (self , weights : Iterable [Tuple [str , torch .Tensor ]]):
104105 loader = AutoWeightsLoader (self ,
You can’t perform that action at this time.
0 commit comments