+import itertools
 from typing import Iterable, List, Optional, Tuple

 import torch

 from .interfaces import SupportsCrossEncoding


+def roberta_task_weights_filter(
+    all_weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
+                                                              torch.Tensor]]]:
+    """
+    Separate task-specific weights that are applied on top
+    of the encoder-decoder bert base.
+    To do so, return two generators over the original iterator.
+    Also, remove the "roberta." prefix to make it loadable
+    from vanilla BertModel.
+    """
+    # Copy of a lazy iterator without in-memory overhead so both
+    # iterators can be iterated upon independently.
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def encoder_decoder_weights():
+        for name, weight in all_weights1:
+            if name.startswith("roberta."):
+                yield (name[len("roberta."):], weight)
+
+    return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
+                                       if not n.startswith("roberta."))
+
+
 class RobertaEmbedding(nn.Module):

     def __init__(self, config: RobertaConfig):
@@ -152,6 +177,18 @@ def _build_model(self,
                          prefix=prefix,
                          embedding_class=RobertaEmbedding)

+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        weights = self.hf_to_vllm_mapper.apply(weights)
+        # Separate weights in "roberta"-prefixed and all else (not in memory).
+        # For use with models like FacebookAI/roberta-base.
+        bert_weights, task_weights = roberta_task_weights_filter(weights)
+        loaded = self.model.load_weights(bert_weights)
+        if not len(loaded):
+            # Fix for models like `sentence-transformers/stsb-roberta-base-v2`
+            # which use the same architecture, but have no "roberta" prefix.
+            loaded = self.model.load_weights(task_weights)
+        assert len(loaded), "Unable to load RobertaEmbeddingModel"
+

 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.
@@ -181,20 +218,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

-        self_weights = []
-
-        def weight_filter():
-            for name, weight in weights:
-                if name.startswith("roberta."):
-                    yield (name[len("roberta."):], weight)
-                else:
-                    self_weights.append((name, weight))
-
-        self.roberta.load_weights(weight_filter())
+        bert_weights, task_weights = roberta_task_weights_filter(weights)
+        self.roberta.load_weights(bert_weights)

         params_dict = dict(self.named_parameters())

-        for name, loaded_weight in self_weights:
+        for name, loaded_weight in task_weights:
             if name.startswith("classifier"):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",