-
-
Notifications
You must be signed in to change notification settings - Fork 13k
[Feature] Add support for naver/splade-v3 (BERT-based sparse embedding model) #26339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
3827c27
657860b
693c658
415137d
0c22312
3276ca4
b220766
706a735
06392ab
893b570
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -573,6 +573,287 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: | |||||||||||||||||||||||||||||||||||
| return token_type_ids | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class BertMLMHead(nn.Module): | ||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||
| self, hidden_size: int, vocab_size: int, layer_norm_eps: float = 1e-12 | ||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||
| self.dense = nn.Linear(hidden_size, hidden_size) | ||||||||||||||||||||||||||||||||||||
| self.activation = nn.GELU() | ||||||||||||||||||||||||||||||||||||
| self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) | ||||||||||||||||||||||||||||||||||||
| self.decoder = nn.Linear(hidden_size, vocab_size, bias=True) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def tie_weights_with_embeddings(self, embeddings_weight: torch.Tensor): | ||||||||||||||||||||||||||||||||||||
| self.decoder.weight = embeddings_weight | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||
| x = self.dense(hidden_states) | ||||||||||||||||||||||||||||||||||||
| x = self.activation(x) | ||||||||||||||||||||||||||||||||||||
| x = self.layer_norm(x) | ||||||||||||||||||||||||||||||||||||
| logits = self.decoder(x) | ||||||||||||||||||||||||||||||||||||
| return logits | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| class SPLADESparsePooler(Pooler): | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| SPLADE sparse pooling: | ||||||||||||||||||||||||||||||||||||
| logits = mlm_head(hidden_states) | ||||||||||||||||||||||||||||||||||||
| -> log1p(relu(logits)) | ||||||||||||||||||||||||||||||||||||
| -> (max|sum over L) | ||||||||||||||||||||||||||||||||||||
| -> [V] | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Padding is masked with an attention mask, | ||||||||||||||||||||||||||||||||||||
| [CLS]/[SEP] is removed (selected), | ||||||||||||||||||||||||||||||||||||
| and then pooled. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||
| mlm_head: nn.Module, | ||||||||||||||||||||||||||||||||||||
| cls_token_id: Optional[int] = 101, | ||||||||||||||||||||||||||||||||||||
| sep_token_id: Optional[int] = 102, | ||||||||||||||||||||||||||||||||||||
| pooling: str = "max", | ||||||||||||||||||||||||||||||||||||
| remove_cls_sep: bool = True, | ||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||
| assert pooling in ("max", "sum") | ||||||||||||||||||||||||||||||||||||
| self.mlm_head = mlm_head | ||||||||||||||||||||||||||||||||||||
| self.cls_token_id = cls_token_id | ||||||||||||||||||||||||||||||||||||
| self.sep_token_id = sep_token_id | ||||||||||||||||||||||||||||||||||||
| self.pooling = pooling | ||||||||||||||||||||||||||||||||||||
| self.remove_cls_sep = remove_cls_sep | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def get_supported_tasks(self) -> Set[PoolingTask]: | ||||||||||||||||||||||||||||||||||||
| return {"embed"} | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: | ||||||||||||||||||||||||||||||||||||
| return PoolingParamsUpdate(requires_token_ids=True) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| @torch.no_grad() | ||||||||||||||||||||||||||||||||||||
| def forward( | ||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||
| hidden_states: Union[torch.Tensor, list[torch.Tensor]], | ||||||||||||||||||||||||||||||||||||
| pooling_metadata: PoolingMetadata, | ||||||||||||||||||||||||||||||||||||
| ) -> Union[torch.Tensor, list[torch.Tensor]]: | ||||||||||||||||||||||||||||||||||||
| if isinstance(hidden_states, torch.Tensor): | ||||||||||||||||||||||||||||||||||||
| hs_list = [hidden_states] | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| hs_list = list(hidden_states) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| for i, hs in enumerate(hs_list): | ||||||||||||||||||||||||||||||||||||
| if hs.dim() == 3 and hs.size(0) == 1: | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| hs_list[i] = hs.squeeze(0) # [L, H] | ||||||||||||||||||||||||||||||||||||
| elif hs.dim() != 2: | ||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Expected [L,H] or [1,L,H], got {tuple(hs.shape)}") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| B = len(hs_list) | ||||||||||||||||||||||||||||||||||||
| H = hs_list[0].size(-1) | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| raw_lens = getattr(pooling_metadata, "prompt_lens", None) | ||||||||||||||||||||||||||||||||||||
DarkLight1337 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def _fallback_lens_from_hs(): | ||||||||||||||||||||||||||||||||||||
| return [int(h.size(0)) for h in hs_list] | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if raw_lens is None: | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| lens = _fallback_lens_from_hs() | ||||||||||||||||||||||||||||||||||||
| elif isinstance(raw_lens, int): | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| lens = [int(raw_lens)] * B | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| tmp = list(raw_lens) | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| if len(tmp) == B: | ||||||||||||||||||||||||||||||||||||
| lens = [int(x) for x in tmp] | ||||||||||||||||||||||||||||||||||||
| elif len(tmp) == 1: | ||||||||||||||||||||||||||||||||||||
| lens = [int(tmp[0])] * B | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| lens = _fallback_lens_from_hs() | ||||||||||||||||||||||||||||||||||||
| except TypeError: | ||||||||||||||||||||||||||||||||||||
| lens = _fallback_lens_from_hs() | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| max_len = max(int(h.size(0)) for h in hs_list) | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| device = hs_list[0].device | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # pad to [B, T, H] | ||||||||||||||||||||||||||||||||||||
| padded = hs_list[0].new_zeros((B, max_len, H)) # zeros | ||||||||||||||||||||||||||||||||||||
| attn_mask = torch.zeros((B, max_len), dtype=torch.bool, device=device) | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| for i, (hs, L) in enumerate(zip(hs_list, lens)): | ||||||||||||||||||||||||||||||||||||
| L = int(L) | ||||||||||||||||||||||||||||||||||||
| L = min(L, max_len) | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| padded[i, :L] = hs[:L] | ||||||||||||||||||||||||||||||||||||
| attn_mask[i, :L] = True | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
| # [CLS]/[SEP] remove(Optional) | ||||||||||||||||||||||||||||||||||||
| token_ids = getattr(pooling_metadata, "prompt_token_ids", None) # [B,T] or None | ||||||||||||||||||||||||||||||||||||
| if self.remove_cls_sep and token_ids is not None: | ||||||||||||||||||||||||||||||||||||
| for i, L in enumerate(lens): | ||||||||||||||||||||||||||||||||||||
| L = int(min(L, max_len)) | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| if L <= 0: | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||
| self.cls_token_id is not None | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| and int(token_ids[i, 0].item()) == self.cls_token_id | ||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||
| attn_mask[i, 0] = False | ||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||
| self.sep_token_id is not None | ||||||||||||||||||||||||||||||||||||
| and int(token_ids[i, L - 1].item()) == self.sep_token_id | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||
| attn_mask[i, L - 1] = False | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # MLM logits: [B, T, V] | ||||||||||||||||||||||||||||||||||||
| B, T, _ = padded.shape | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| flat = padded.reshape(B * T, H) # [B*T, H] | ||||||||||||||||||||||||||||||||||||
| logits = self.mlm_head(flat) # [B*T, V] | ||||||||||||||||||||||||||||||||||||
| V = int(logits.size(-1)) | ||||||||||||||||||||||||||||||||||||
| logits = logits.view(B, T, V) # [B, T, V] | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # SPLADE activation | ||||||||||||||||||||||||||||||||||||
| scores = torch.log1p(torch.relu(logits)) # [B, T, V] | ||||||||||||||||||||||||||||||||||||
gjgjos marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # pooling after masking | ||||||||||||||||||||||||||||||||||||
| if self.pooling == "sum": | ||||||||||||||||||||||||||||||||||||
| pooled = (scores * attn_mask.float().unsqueeze(-1)).sum(dim=1) # [B, V] | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| neg_inf = torch.tensor( | ||||||||||||||||||||||||||||||||||||
| float("-inf"), device=scores.device, dtype=scores.dtype | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| masked = scores.masked_fill(~attn_mask.unsqueeze(-1), neg_inf) # [B, T, V] | ||||||||||||||||||||||||||||||||||||
| pooled = masked.max(dim=1).values # [B, V] | ||||||||||||||||||||||||||||||||||||
| pooled = torch.where( | ||||||||||||||||||||||||||||||||||||
| torch.isneginf(pooled), torch.zeros_like(pooled), pooled | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| outs: list[torch.Tensor] = [] | ||||||||||||||||||||||||||||||||||||
| for i in range(B): | ||||||||||||||||||||||||||||||||||||
| vec = pooled[i].to(torch.float32).contiguous().view(-1) # [V] | ||||||||||||||||||||||||||||||||||||
| outs.append(vec) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| return outs | ||||||||||||||||||||||||||||||||||||
DarkLight1337 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| @default_pooling_type("CLS") | ||||||||||||||||||||||||||||||||||||
| class BertSpladeSparseEmbeddingModel(BertEmbeddingModel): | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| BertEmbeddingModel + SPLADE sparse embedding. | ||||||||||||||||||||||||||||||||||||
| - Make logits by self.mlm_head | ||||||||||||||||||||||||||||||||||||
| - pooler: SPLADESparsePooler(mlm_head...) | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||
| self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max" | ||||||||||||||||||||||||||||||||||||
DarkLight1337 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||
| super().__init__(vllm_config=vllm_config, prefix=prefix) | ||||||||||||||||||||||||||||||||||||
| cfg = vllm_config.model_config.hf_config | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # MLM head | ||||||||||||||||||||||||||||||||||||
| self.mlm_head = BertMLMHead( | ||||||||||||||||||||||||||||||||||||
| hidden_size=cfg.hidden_size, | ||||||||||||||||||||||||||||||||||||
| vocab_size=cfg.vocab_size, | ||||||||||||||||||||||||||||||||||||
| layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12), | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| self._splade_pooling = splade_pooling | ||||||||||||||||||||||||||||||||||||
| pooler_config = vllm_config.model_config.pooler_config | ||||||||||||||||||||||||||||||||||||
| assert pooler_config is not None | ||||||||||||||||||||||||||||||||||||
| self.pooler = self._build_pooler(pooler_config) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: | ||||||||||||||||||||||||||||||||||||
| cfg = self.model.config | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if not hasattr(self, "mlm_head"): | ||||||||||||||||||||||||||||||||||||
| self.mlm_head = BertMLMHead( | ||||||||||||||||||||||||||||||||||||
| hidden_size=cfg.hidden_size, | ||||||||||||||||||||||||||||||||||||
| vocab_size=cfg.vocab_size, | ||||||||||||||||||||||||||||||||||||
| layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12), | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| pooling_mode = getattr(self, "_splade_pooling", "max") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| cls_id = getattr(cfg, "cls_token_id", None) | ||||||||||||||||||||||||||||||||||||
| sep_id = getattr(cfg, "sep_token_id", None) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| return DispatchPooler( | ||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||
| "encode": Pooler.for_encode(pooler_config), | ||||||||||||||||||||||||||||||||||||
| "embed": SPLADESparsePooler( | ||||||||||||||||||||||||||||||||||||
| mlm_head=self.mlm_head, | ||||||||||||||||||||||||||||||||||||
| cls_token_id=cls_id, | ||||||||||||||||||||||||||||||||||||
| sep_token_id=sep_id, | ||||||||||||||||||||||||||||||||||||
| pooling=pooling_mode, # "max" or "sum" | ||||||||||||||||||||||||||||||||||||
| remove_cls_sep=True, | ||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): | ||||||||||||||||||||||||||||||||||||
| if not hasattr(self, "mlm_head"): | ||||||||||||||||||||||||||||||||||||
| cfg = self.model.config | ||||||||||||||||||||||||||||||||||||
| self.mlm_head = BertMLMHead( | ||||||||||||||||||||||||||||||||||||
| hidden_size=cfg.hidden_size, | ||||||||||||||||||||||||||||||||||||
| vocab_size=cfg.vocab_size, | ||||||||||||||||||||||||||||||||||||
| layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12), | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| weights_list = list(weights) | ||||||||||||||||||||||||||||||||||||
| loaded: set[str] = set() | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| model_side: list[tuple[str, torch.Tensor]] = [] | ||||||||||||||||||||||||||||||||||||
| for k, w in weights_list: | ||||||||||||||||||||||||||||||||||||
| if k.startswith("cls.predictions."): | ||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||
| name = k | ||||||||||||||||||||||||||||||||||||
| if name.startswith("model."): | ||||||||||||||||||||||||||||||||||||
| name = name[len("model.") :] | ||||||||||||||||||||||||||||||||||||
| if name.startswith("bert."): | ||||||||||||||||||||||||||||||||||||
| name = name[len("bert.") :] | ||||||||||||||||||||||||||||||||||||
| model_side.append((name, w)) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| other, stacked = self.model._load_weights(model_side) | ||||||||||||||||||||||||||||||||||||
| loaded.update({"model." + n for n in stacked}) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| other_prefixed = [("model." + n, w) for (n, w) in other] | ||||||||||||||||||||||||||||||||||||
| loader_top = AutoWeightsLoader( | ||||||||||||||||||||||||||||||||||||
| self, skip_prefixes=["pooler.", "mlm_head.", "lm_head."] | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| loaded_other = loader_top.load_weights(other_prefixed) | ||||||||||||||||||||||||||||||||||||
| loaded.update(loaded_other) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| name_map = { | ||||||||||||||||||||||||||||||||||||
| "cls.predictions.transform.dense.weight": "mlm_head.dense.weight", | ||||||||||||||||||||||||||||||||||||
| "cls.predictions.transform.dense.bias": "mlm_head.dense.bias", | ||||||||||||||||||||||||||||||||||||
| "cls.predictions.transform.LayerNorm.weight": "mlm_head.layer_norm.weight", | ||||||||||||||||||||||||||||||||||||
| "cls.predictions.transform.LayerNorm.bias": "mlm_head.layer_norm.bias", | ||||||||||||||||||||||||||||||||||||
| "cls.predictions.decoder.weight": "mlm_head.decoder.weight", | ||||||||||||||||||||||||||||||||||||
| "cls.predictions.decoder.bias": "mlm_head.decoder.bias", | ||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||
| extras: list[tuple[str, torch.Tensor]] = [] | ||||||||||||||||||||||||||||||||||||
| for k, w in weights_list: | ||||||||||||||||||||||||||||||||||||
| name = k | ||||||||||||||||||||||||||||||||||||
| if name.startswith("model."): | ||||||||||||||||||||||||||||||||||||
| name = name[len("model.") :] | ||||||||||||||||||||||||||||||||||||
| if name.startswith("bert."): | ||||||||||||||||||||||||||||||||||||
| name = name[len("bert.") :] | ||||||||||||||||||||||||||||||||||||
| tgt = name_map.get(name) | ||||||||||||||||||||||||||||||||||||
| if tgt is not None: | ||||||||||||||||||||||||||||||||||||
| extras.append((tgt, w)) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if extras: | ||||||||||||||||||||||||||||||||||||
| mlm_loader = AutoWeightsLoader(self) | ||||||||||||||||||||||||||||||||||||
| loaded_mlm = mlm_loader.load_weights(extras) | ||||||||||||||||||||||||||||||||||||
| loaded.update(loaded_mlm) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| emb_w = self.model.embeddings.word_embeddings.weight | ||||||||||||||||||||||||||||||||||||
| dec_w = self.mlm_head.decoder.weight | ||||||||||||||||||||||||||||||||||||
| if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr(): | ||||||||||||||||||||||||||||||||||||
| self.mlm_head.decoder.weight = emb_w | ||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
| try: | |
| emb_w = self.model.embeddings.word_embeddings.weight | |
| dec_w = self.mlm_head.decoder.weight | |
| if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr(): | |
| self.mlm_head.decoder.weight = emb_w | |
| except Exception: | |
| pass | |
| try: | |
| emb_w = self.model.embeddings.word_embeddings.weight | |
| dec_w = self.mlm_head.decoder.weight | |
| if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr(): | |
| self.mlm_head.decoder.weight = emb_w | |
| except AttributeError: | |
| # It's possible that some BERT variants may not have this structure. | |
| # If we can't find the weights to tie, it's not a critical | |
| # error, as the model can still function with untied weights. | |
| pass |
Uh oh!
There was an error while loading. Please reload this page.