Skip to content

Commit a82fe7b

Browse files
committed
revert softmax inside the pooledr
Signed-off-by: Kevin-Yang <[email protected]>
1 parent 3acf28b commit a82fe7b

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

vllm/model_executor/layers/pooler.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,15 @@ class Pooler(nn.Module):
2828
normalize: Whether to normalize the pooled data.
2929
"""
3030

31-
def __init__(self, pooling_type: PoolingType, normalize: bool):
31+
def __init__(self,
32+
pooling_type: PoolingType,
33+
normalize: bool,
34+
softmax: bool = False):
3235
super().__init__()
3336

3437
self.pooling_type = pooling_type
3538
self.normalize = normalize
39+
self.softmax = softmax
3640

3741
def forward(
3842
self,
@@ -64,6 +68,9 @@ def forward(
6468
if self.normalize:
6569
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
6670

71+
if self.softmax:
72+
pooled_data = nn.functional.softmax(pooled_data, dim=-1)
73+
6774
pooled_outputs = [
6875
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
6976
]

vllm/model_executor/models/qwen2_cls.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ def __init__(
7777
self.score = RowParallelLinear(config.hidden_size,
7878
config.num_labels,
7979
quant_config=quant_config)
80-
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
80+
self._pooler = Pooler(pooling_type=PoolingType.LAST,
81+
normalize=False,
82+
softmax=True)
8183

8284
def forward(
8385
self,
@@ -97,8 +99,7 @@ def pooler(
9799
hidden_states: torch.Tensor,
98100
pooling_metadata: PoolingMetadata,
99101
) -> Optional[PoolerOutput]:
100-
pooled = self._pooler(hidden_states, pooling_metadata)
101-
return nn.functional.softmax(pooled, dim=-1)
102+
return self._pooler(hidden_states, pooling_metadata)
102103

103104
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
104105
loader = AutoWeightsLoader(self,

0 commit comments

Comments
 (0)