Skip to content

Commit 80a1dd4

Browse files
DarkLight1337afeldman-nm
authored andcommitted
[Model] Enable optional prefix when loading embedding models (vllm-project#10639)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
1 parent 5e36a52 commit 80a1dd4

5 files changed

Lines changed: 20 additions & 13 deletions

File tree

vllm/model_executor/models/bert.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,17 @@
1414
RowParallelLinear)
1515
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
1616
PoolingType)
17-
from vllm.model_executor.layers.quantization.base_config import (
18-
QuantizationConfig)
17+
from vllm.model_executor.layers.quantization import QuantizationConfig
1918
from vllm.model_executor.layers.vocab_parallel_embedding import (
2019
VocabParallelEmbedding)
2120
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
22-
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
2321
from vllm.model_executor.pooling_metadata import PoolingMetadata
2422
from vllm.sequence import IntermediateTensors, PoolerOutput
2523
from vllm.transformers_utils.config import (
2624
get_cross_encoder_activation_function)
2725

28-
from .utils import maybe_prefix
26+
from .interfaces import SupportsCrossEncoding
27+
from .utils import WeightsMapper, maybe_prefix
2928

3029

3130
class BertEmbedding(nn.Module):
@@ -442,6 +441,8 @@ def pooler(
442441
return self._pooler(hidden_states, pooling_metadata)
443442

444443
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
444+
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
445+
weights = hf_to_vllm_mapper.apply(weights)
445446
self.model.load_weights(weights)
446447

447448
def _build_model(self,

vllm/model_executor/models/gemma2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from vllm.sequence import IntermediateTensors, PoolerOutput
4343

4444
from .interfaces import SupportsLoRA, SupportsPP
45-
from .utils import (AutoWeightsLoader, extract_layer_index,
45+
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
4646
is_pp_missing_parameter,
4747
make_empty_intermediate_tensors_factory, make_layers,
4848
maybe_prefix)
@@ -511,4 +511,6 @@ def pooler(
511511
return self._pooler(hidden_states, pooling_metadata)
512512

513513
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
514+
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
515+
weights = hf_to_vllm_mapper.apply(weights)
514516
self.model.load_weights(weights)

vllm/model_executor/models/llama.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@
5353
from vllm.sequence import IntermediateTensors, PoolerOutput
5454

5555
from .interfaces import SupportsLoRA, SupportsPP
56-
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
56+
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
57+
is_pp_missing_parameter,
5758
make_empty_intermediate_tensors_factory, make_layers,
5859
maybe_prefix)
5960

@@ -689,6 +690,8 @@ def pooler(
689690
return self._pooler(hidden_states, pooling_metadata)
690691

691692
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
693+
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
694+
weights = hf_to_vllm_mapper.apply(weights)
692695
self.model.load_weights(weights)
693696

694697
def load_kv_cache_scales(self, quantization_param_path: str) -> None:

vllm/model_executor/models/qwen2.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@
5050
from vllm.sequence import IntermediateTensors, PoolerOutput
5151

5252
from .interfaces import SupportsLoRA, SupportsPP
53-
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
53+
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
54+
is_pp_missing_parameter,
5455
make_empty_intermediate_tensors_factory, make_layers,
5556
maybe_prefix)
5657

@@ -585,8 +586,7 @@ def pooler(
585586
) -> Optional[PoolerOutput]:
586587
return self._pooler(hidden_states, pooling_metadata)
587588

588-
def load_weights(self, weights: Iterable[Tuple[str,
589-
torch.Tensor]]) -> Set[str]:
590-
loader = AutoWeightsLoader(self,
591-
ignore_unexpected_prefixes=["lm_head."])
592-
return loader.load_weights(weights)
589+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
590+
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
591+
weights = hf_to_vllm_mapper.apply(weights)
592+
self.model.load_weights(weights)

vllm/model_executor/models/roberta.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
VocabParallelEmbedding)
1212
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
1313
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
14-
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
1514
from vllm.model_executor.models.utils import maybe_prefix
1615
from vllm.model_executor.pooling_metadata import PoolingMetadata
1716
from vllm.sequence import IntermediateTensors, PoolerOutput
1817
from vllm.transformers_utils.config import (
1918
get_cross_encoder_activation_function)
2019

20+
from .interfaces import SupportsCrossEncoding
21+
2122

2223
class RobertaEmbedding(nn.Module):
2324

0 commit comments

Comments
 (0)