Merged
281 changes: 281 additions & 0 deletions vllm/model_executor/models/bert.py
@@ -573,6 +573,287 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    return token_type_ids


class BertMLMHead(nn.Module):
    def __init__(
        self, hidden_size: int, vocab_size: int, layer_norm_eps: float = 1e-12
    ):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=True)

    def tie_weights_with_embeddings(self, embeddings_weight: torch.Tensor):
        self.decoder.weight = embeddings_weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.dense(hidden_states)
        x = self.activation(x)
        x = self.layer_norm(x)
        logits = self.decoder(x)
        return logits
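A minimal shape-check sketch for the head above (the 768/30522 sizes are illustrative bert-base defaults, not taken from this PR):

import torch

head = BertMLMHead(hidden_size=768, vocab_size=30522)
hidden = torch.randn(5, 768)   # [L, H]: five token representations
logits = head(hidden)          # [L, V]: per-token vocabulary logits
assert logits.shape == (5, 30522)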


class SPLADESparsePooler(Pooler):
    """SPLADE sparse pooling.

    logits = mlm_head(hidden_states)           # [L, V]
    scores = log1p(relu(logits))
    pooled = max (or sum) over the token axis  # -> [V]

    Padding positions are masked out via the attention mask, [CLS]/[SEP]
    tokens are optionally excluded, and the scores are then pooled.
    """

    def __init__(
        self,
        mlm_head: nn.Module,
        cls_token_id: Optional[int] = 101,
        sep_token_id: Optional[int] = 102,
        pooling: str = "max",
        remove_cls_sep: bool = True,
    ):
        super().__init__()
        assert pooling in ("max", "sum")
        self.mlm_head = mlm_head
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.pooling = pooling
        self.remove_cls_sep = remove_cls_sep

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return PoolingParamsUpdate(requires_token_ids=True)

    @torch.no_grad()
    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        # Normalize the input to a list of [L, H] tensors.
        if isinstance(hidden_states, torch.Tensor):
            hs_list = [hidden_states]
        else:
            hs_list = list(hidden_states)

        for i, hs in enumerate(hs_list):
            if hs.dim() == 3 and hs.size(0) == 1:
                hs_list[i] = hs.squeeze(0)  # [1, L, H] -> [L, H]
            elif hs.dim() != 2:
                raise ValueError(f"Expected [L,H] or [1,L,H], got {tuple(hs.shape)}")

        B = len(hs_list)
        H = hs_list[0].size(-1)

        # Resolve per-prompt lengths, falling back to the hidden-state
        # lengths when the metadata is missing or inconsistent.
        raw_lens = getattr(pooling_metadata, "prompt_lens", None)

        def _fallback_lens_from_hs():
            return [int(h.size(0)) for h in hs_list]

        if raw_lens is None:
            lens = _fallback_lens_from_hs()
        elif isinstance(raw_lens, int):
            lens = [int(raw_lens)] * B
        else:
            try:
                tmp = list(raw_lens)
                if len(tmp) == B:
                    lens = [int(x) for x in tmp]
                elif len(tmp) == 1:
                    lens = [int(tmp[0])] * B
                else:
                    lens = _fallback_lens_from_hs()
            except TypeError:
                lens = _fallback_lens_from_hs()

        max_len = max(int(h.size(0)) for h in hs_list)
        device = hs_list[0].device

        # Right-pad every sequence to [B, T, H] and build the attention mask.
        padded = hs_list[0].new_zeros((B, max_len, H))
        attn_mask = torch.zeros((B, max_len), dtype=torch.bool, device=device)

        for i, (hs, L) in enumerate(zip(hs_list, lens)):
            L = min(int(L), max_len)
            padded[i, :L] = hs[:L]
            attn_mask[i, :L] = True

Review comment (P1): Pooler ignores batching layout and drops extra requests

The new SPLADESparsePooler.forward wraps the incoming hidden_states tensor into a single item whenever it is a 2‑D tensor (lines 638‑649) and never consults the pooling_metadata.pooling_cursor that encodes how multiple requests are concatenated. In the vLLM runner, embeddings are pooled from a single [total_tokens, hidden] tensor containing all prompts in a batch. With the current logic only the first prompt in the batch is padded and pooled while the remaining prompts are silently ignored, causing incorrect or missing embeddings whenever more than one request is processed together. The pooler should use pooling_cursor (as done in SimplePooler) to split the tensor per request before applying the MLM head.

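A hedged sketch of the fix the reviewer describes: split the flat [total_tokens, H] tensor into per-request chunks before padding. The helper below uses per-prompt lengths for illustration only; the actual implementation should read the request boundaries from pooling_metadata.pooling_cursor, as SimplePooler does.

import torch

def split_per_request(flat: torch.Tensor,
                      prompt_lens: list[int]) -> list[torch.Tensor]:
    # flat holds all prompts concatenated along dim 0; each chunk becomes
    # one [L_i, H] tensor that the MLM head can score independently.
    return list(torch.split(flat, prompt_lens, dim=0))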

        # Optionally mask out [CLS]/[SEP] so they do not contribute to pooling.
        token_ids = getattr(pooling_metadata, "prompt_token_ids", None)  # [B, T] or None
        if self.remove_cls_sep and token_ids is not None:
            for i, L in enumerate(lens):
                L = int(min(L, max_len))
                if L <= 0:
                    continue
                if (
                    self.cls_token_id is not None
                    and int(token_ids[i, 0].item()) == self.cls_token_id
                ):
                    attn_mask[i, 0] = False
                if (
                    self.sep_token_id is not None
                    and int(token_ids[i, L - 1].item()) == self.sep_token_id
                ):
                    attn_mask[i, L - 1] = False

        # MLM logits: [B, T, V]
        B, T, _ = padded.shape
        flat = padded.reshape(B * T, H)  # [B*T, H]
        logits = self.mlm_head(flat)     # [B*T, V]
        V = int(logits.size(-1))
        logits = logits.view(B, T, V)    # [B, T, V]

        # SPLADE activation: non-negative, log-saturated term weights.
        scores = torch.log1p(torch.relu(logits))  # [B, T, V]

        # Pool over the token axis, honoring the attention mask.
        if self.pooling == "sum":
            pooled = (scores * attn_mask.float().unsqueeze(-1)).sum(dim=1)  # [B, V]
        else:
            neg_inf = torch.tensor(
                float("-inf"), device=scores.device, dtype=scores.dtype
            )
            masked = scores.masked_fill(~attn_mask.unsqueeze(-1), neg_inf)  # [B, T, V]
            pooled = masked.max(dim=1).values  # [B, V]
            # Fully masked rows come out as -inf; zero them instead.
            pooled = torch.where(
                torch.isneginf(pooled), torch.zeros_like(pooled), pooled
            )

        outs: list[torch.Tensor] = []
        for i in range(B):
            vec = pooled[i].to(torch.float32).contiguous().view(-1)  # [V]
            outs.append(vec)

        return outs
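To make the reduction concrete, a toy trace of the activation and "max" pooling above (two tokens, vocabulary of four):

import torch

logits = torch.tensor([[1.0, -2.0, 0.5, 0.0],
                       [3.0,  0.2, -1.0, 0.0]])  # [L=2, V=4]
scores = torch.log1p(torch.relu(logits))  # negatives clamp to 0, then log1p
pooled = scores.max(dim=0).values         # max over the token axis
# pooled == [log1p(3.0), log1p(0.2), log1p(0.5), 0.0] -> a sparse [V] vector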


@default_pooling_type("CLS")
class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
    """BertEmbeddingModel + SPLADE sparse embedding.

    - Produces vocabulary logits via self.mlm_head.
    - Pools them with SPLADESparsePooler(mlm_head, ...).
    """

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max"
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        cfg = vllm_config.model_config.hf_config

        # MLM head (shared with the pooler built below).
        self.mlm_head = BertMLMHead(
            hidden_size=cfg.hidden_size,
            vocab_size=cfg.vocab_size,
            layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
        )

        self._splade_pooling = splade_pooling
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler = self._build_pooler(pooler_config)

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        cfg = self.model.config

        # This may be invoked from the base-class __init__ before
        # self.mlm_head is assigned, so create the head on demand.
        if not hasattr(self, "mlm_head"):
            self.mlm_head = BertMLMHead(
                hidden_size=cfg.hidden_size,
                vocab_size=cfg.vocab_size,
                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
            )

        pooling_mode = getattr(self, "_splade_pooling", "max")

        cls_id = getattr(cfg, "cls_token_id", None)
        sep_id = getattr(cfg, "sep_token_id", None)

        return DispatchPooler(
            {
                "encode": Pooler.for_encode(pooler_config),
                "embed": SPLADESparsePooler(
                    mlm_head=self.mlm_head,
                    cls_token_id=cls_id,
                    sep_token_id=sep_id,
                    pooling=pooling_mode,  # "max" or "sum"
                    remove_cls_sep=True,
                ),
            }
        )
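For unit-testing the SPLADE path in isolation, the pooler can also be wired by hand, mirroring _build_pooler above (sizes illustrative):

head = BertMLMHead(hidden_size=768, vocab_size=30522)
pooler = SPLADESparsePooler(mlm_head=head, pooling="max", remove_cls_sep=True)
assert pooler.get_supported_tasks() == {"embed"}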

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        if not hasattr(self, "mlm_head"):
            cfg = self.model.config
            self.mlm_head = BertMLMHead(
                hidden_size=cfg.hidden_size,
                vocab_size=cfg.vocab_size,
                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
            )

        weights_list = list(weights)
        loaded: set[str] = set()

        # Backbone weights: strip "model."/"bert." prefixes and skip the
        # HF MLM-head entries, which are handled separately below.
        model_side: list[tuple[str, torch.Tensor]] = []
        for k, w in weights_list:
            if k.startswith("cls.predictions."):
                continue
            name = k
            if name.startswith("model."):
                name = name[len("model.") :]
            if name.startswith("bert."):
                name = name[len("bert.") :]
            model_side.append((name, w))

        other, stacked = self.model._load_weights(model_side)
        loaded.update({"model." + n for n in stacked})

        other_prefixed = [("model." + n, w) for (n, w) in other]
        loader_top = AutoWeightsLoader(
            self, skip_prefixes=["pooler.", "mlm_head.", "lm_head."]
        )
        loaded_other = loader_top.load_weights(other_prefixed)
        loaded.update(loaded_other)

        # Map HF MLM-head names ("cls.predictions.*") onto our module names.
        name_map = {
            "cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
            "cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
            "cls.predictions.transform.LayerNorm.weight": "mlm_head.layer_norm.weight",
            "cls.predictions.transform.LayerNorm.bias": "mlm_head.layer_norm.bias",
            "cls.predictions.decoder.weight": "mlm_head.decoder.weight",
            "cls.predictions.decoder.bias": "mlm_head.decoder.bias",
        }
        extras: list[tuple[str, torch.Tensor]] = []
        for k, w in weights_list:
            name = k
            if name.startswith("model."):
                name = name[len("model.") :]
            if name.startswith("bert."):
                name = name[len("bert.") :]
            tgt = name_map.get(name)
            if tgt is not None:
                extras.append((tgt, w))

        if extras:
            mlm_loader = AutoWeightsLoader(self)
            loaded_mlm = mlm_loader.load_weights(extras)
            loaded.update(loaded_mlm)

        # Tie the decoder weight to the input embeddings when shapes match.
        try:
            emb_w = self.model.embeddings.word_embeddings.weight
            dec_w = self.mlm_head.decoder.weight
            if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
                self.mlm_head.decoder.weight = emb_w
        except Exception:
            pass
Review comment (severity: high)

The try...except Exception: pass block is too broad and can hide important errors during weight loading. For instance, if self.model.embeddings or other attributes do not exist due to a model structure mismatch, an AttributeError would be silently ignored, making debugging difficult. This could lead to weights not being tied when they should be, resulting in incorrect model behavior. It's better to catch more specific exceptions, like AttributeError, or at least log a warning if an exception occurs.

Suggested change:

-        try:
-            emb_w = self.model.embeddings.word_embeddings.weight
-            dec_w = self.mlm_head.decoder.weight
-            if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
-                self.mlm_head.decoder.weight = emb_w
-        except Exception:
-            pass
+        try:
+            emb_w = self.model.embeddings.word_embeddings.weight
+            dec_w = self.mlm_head.decoder.weight
+            if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
+                self.mlm_head.decoder.weight = emb_w
+        except AttributeError:
+            # It's possible that some BERT variants may not have this structure.
+            # If we can't find the weights to tie, it's not a critical
+            # error, as the model can still function with untied weights.
+            pass


        return loaded
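A quick post-load sanity check (a sketch, assuming `model` is a loaded BertSpladeSparseEmbeddingModel) that the decoder actually shares storage with the input embeddings:

emb_w = model.model.embeddings.word_embeddings.weight
dec_w = model.mlm_head.decoder.weight
assert dec_w.data_ptr() == emb_w.data_ptr(), "decoder not tied to embeddings"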


@default_pooling_type("CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -170,6 +170,7 @@
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
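With the registry entry in place, a SPLADE checkpoint can be served through vLLM's embedding API. A hedged usage sketch — the model name is illustrative, and the exact flags needed to select this architecture may differ:

from vllm import LLM

llm = LLM(model="naver/splade-cocondenser-ensembledistil", task="embed")
out = llm.embed(["what is a sparse retrieval model?"])
vec = out[0].outputs.embedding  # vocab-sized sparse vector from the pooler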