13 changes: 9 additions & 4 deletions docs/source/models/supported_models.rst
@@ -330,11 +330,16 @@ Text Embedding
- :code:`BAAI/bge-multilingual-gemma2`, etc.
-
- ✅︎
* - :code:`MistralModel`
- Mistral-based
* - :code:`LlamaModel`, :code:`LlamaForCausalLM`, :code:`MistralModel`, etc.
- Llama-based
- :code:`intfloat/e5-mistral-7b-instruct`, etc.
- ✅︎
- ✅︎
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
- Qwen2-based
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc.
- ✅︎
- ✅︎

.. important::
Some model architectures support both generation and embedding tasks.
@@ -355,7 +360,7 @@ Reward Modeling
* - :code:`Qwen2ForRewardModel`
- Qwen2-based
- :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
-
- ✅︎
- ✅︎

.. note::
@@ -376,7 +381,7 @@ Classification
* - :code:`Qwen2ForSequenceClassification`
- Qwen2-based
- :code:`jason9693/Qwen2.5-1.5B-apeach`, etc.
-
- ✅︎
- ✅︎

.. note::
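For context, a minimal offline-inference sketch of the embedding usage the table above documents; the task flag name and the encode() output layout are assumptions about this vLLM release and may differ slightly between versions:

from vllm import LLM

# Hypothetical usage sketch: the "embedding" task name and the
# output.outputs.embedding field are assumed for this release.
llm = LLM(model="Alibaba-NLP/gte-Qwen2-1.5B-instruct", task="embedding")

outputs = llm.encode(["What is the capital of France?"])
for output in outputs:
    # Each prompt yields one pooled (last-token, normalized) vector.
    print(len(output.outputs.embedding))
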
2 changes: 2 additions & 0 deletions tests/models/embedding/language/test_embedding.py
@@ -11,6 +11,8 @@
"intfloat/e5-mistral-7b-instruct",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-multilingual-gemma2",
"ssmits/Qwen2-7B-Instruct-embed-base",
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
]

ENCODER_ONLY = [
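The test change above only adds the two Qwen2 checkpoints to the existing parameter list. As a rough illustration (not the actual test code) of the kind of check that harness runs, one could compare vLLM's pooled embeddings against sentence-transformers by cosine similarity:

import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from vllm import LLM

MODEL = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
PROMPTS = ["The capital of France is Paris."]

# Reference embeddings from sentence-transformers.
ref = torch.tensor(SentenceTransformer(MODEL).encode(PROMPTS))

# Embeddings from vLLM's pooling path (task/field names assumed, as above).
llm = LLM(model=MODEL, task="embedding")
out = torch.tensor([o.outputs.embedding for o in llm.encode(PROMPTS)])

# Both should be near-identical up to numerics; compare by cosine similarity.
sim = F.cosine_similarity(ref, out)
assert torch.all(sim > 0.99), f"embeddings diverge: {sim}"
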
116 changes: 104 additions & 12 deletions vllm/model_executor/models/qwen2.py
@@ -37,15 +37,17 @@
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
@@ -248,6 +250,19 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()

# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
raise ValueError("Sliding window for some but all layers is not "
"supported. This model uses sliding window "
"but `max_window_layers` = {} is less than "
"`num_hidden_layers` = {}. Please open an issue "
"to discuss this feature.".format(
config.max_window_layers,
config.num_hidden_layers,
))

self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
Expand Down Expand Up @@ -413,17 +428,7 @@ def __init__(
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
raise ValueError("Sliding window for some but all layers is not "
"supported. This model uses sliding window "
"but `max_window_layers` = {} is less than "
"`num_hidden_layers` = {}. Please open an issue "
"to discuss this feature.".format(
config.max_window_layers,
config.num_hidden_layers,
))
pooler_config = vllm_config.model_config.pooler_config

self.config = config
self.lora_config = lora_config
@@ -445,6 +450,15 @@ def __init__(

self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()

# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
Comment on lines +448 to +451

Member: Will this be able to be controlled by the pooling args we spoke about offline?

Member Author: Yes - these are the model's default values which can be overridden.


self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

@@ -477,10 +491,88 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
loader.load_weights(weights)


class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config

self.config = config
self.lora_config = lora_config

self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config)

self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
return self.model(input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors)

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."])
loader.load_weights(weights)
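
Picking up the review thread above: PoolingType.LAST and normalize=True are only the model's defaults passed to Pooler.from_config_with_defaults, and a user-supplied pooler config takes precedence. A hedged sketch of overriding them, assuming the override_pooler_config argument and PoolerConfig fields exposed by this release:

from vllm import LLM
from vllm.config import PoolerConfig

# Swap the default LAST-token pooling for MEAN pooling and skip
# normalization; the field names here are assumptions about PoolerConfig.
llm = LLM(
    model="ssmits/Qwen2-7B-Instruct-embed-base",
    task="embedding",
    override_pooler_config=PoolerConfig(pooling_type="MEAN", normalize=False),
)
embedding = llm.encode(["hello world"])[0].outputs.embedding
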
14 changes: 2 additions & 12 deletions vllm/model_executor/models/qwen2_cls.py
@@ -17,10 +17,11 @@
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsLoRA, SupportsPP
from .utils import AutoWeightsLoader


class Qwen2ForSequenceClassification(nn.Module):
class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -54,17 +55,6 @@ def __init__(
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
raise ValueError("Sliding window for some but all layers is not "
"supported. This model uses sliding window "
"but `max_window_layers` = {} is less than "
"`num_hidden_layers` = {}. Please open an issue "
"to discuss this feature.".format(
config.max_window_layers,
config.num_hidden_layers,
))

self.config = config
self.lora_config = lora_config
15 changes: 2 additions & 13 deletions vllm/model_executor/models/qwen2_rm.py
@@ -16,7 +16,7 @@
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsPP
from .interfaces import SupportsLoRA, SupportsPP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader

@@ -32,7 +32,7 @@ def forward(self, input):
return self.activation(input)


class Qwen2ForRewardModel(nn.Module, SupportsPP):
class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -66,17 +66,6 @@ def __init__(
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
raise ValueError("Sliding window for some but all layers is not "
"supported. This model uses sliding window "
"but `max_window_layers` = {} is less than "
"`num_hidden_layers` = {}. Please open an issue "
"to discuss this feature.".format(
config.max_window_layers,
config.num_hidden_layers,
))

self.config = config
self.lora_config = lora_config
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -104,6 +104,7 @@
},
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501
# [Multimodal]
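
With the registry entry above, a checkpoint whose config.json declares the Qwen2Model architecture resolves to Qwen2EmbeddingModel. A sketch of querying such a model through vLLM's OpenAI-compatible server, assuming it was started with something like "vllm serve ssmits/Qwen2-7B-Instruct-embed-base --task embedding":

from openai import OpenAI

# Point the standard OpenAI client at a locally running vLLM server
# (endpoint and port are the usual vllm serve defaults).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.embeddings.create(
    model="ssmits/Qwen2-7B-Instruct-embed-base",
    input=["vLLM now supports Qwen2-based embedding models."],
)
print(len(response.data[0].embedding))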