docs/source/models/supported_models.rst (13 changes: 10 additions & 3 deletions)
@@ -342,7 +342,7 @@ Text Embedding
- ✅︎
* - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
- Qwen2-based
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc.
- :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
- ✅︎
- ✅︎
* - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
@@ -363,6 +363,13 @@ Text Embedding
.. tip::
You can override the model's pooling method by passing :code:`--override-pooler-config`.

.. note::
Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.

On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention
despite being described otherwise on its model card.
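
For example, the same override can be applied offline (a minimal sketch based on the test added in this PR; the exact output layout of :code:`llm.encode` is an assumption):

.. code-block:: python

    from vllm import LLM

    # gte-Qwen2-7B-instruct expects bi-directional attention, so override
    # the HF config's is_causal flag (same effect as --hf-overrides on the CLI).
    llm = LLM(model="Alibaba-NLP/gte-Qwen2-7B-instruct",
              task="embedding",
              hf_overrides={"is_causal": False})

    outputs = llm.encode(["What is the capital of France?"])
    print(outputs[0].outputs.embedding[:8])  # assumed embedding output layout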

Reward Modeling
---------------

@@ -606,10 +613,10 @@ Text Generation
| :sup:`+` Multiple items can be inputted per text prompt for this modality.

.. note::
vLLM currently only supports adding LoRA to the language backbone of multimodal models.

.. note::
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630

Multimodal Embedding
tests/models/embedding/language/test_embedding.py (12 changes: 10 additions & 2 deletions)
@@ -21,6 +21,7 @@
marks=[pytest.mark.core_model]),
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
],
)
@pytest.mark.parametrize("dtype", ["half"])
@@ -31,6 +32,10 @@ def test_models(
model,
dtype: str,
) -> None:
vllm_extra_kwargs = {}
if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}

# The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see:
@@ -43,8 +48,11 @@
is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)

with vllm_runner(model, task="embedding", dtype=dtype,
max_model_len=None) as vllm_model:
with vllm_runner(model,
task="embedding",
dtype=dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
tests/models/embedding/utils.py (4 changes: 2 additions & 2 deletions)
@@ -24,7 +24,7 @@ def check_embeddings_close(
dim=0)

fail_msg = (f"Test{prompt_idx}:"
f"\n{name_0}:\t{embeddings_0!r}"
f"\n{name_1}:\t{embeddings_1!r}")
f"\n{name_0}:\t{embeddings_0[:16]!r}"
f"\n{name_1}:\t{embeddings_1[:16]!r}")

assert sim >= 1 - tol, fail_msg
vllm/config.py (15 changes: 11 additions & 4 deletions)
@@ -26,7 +26,7 @@
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
identity, print_warning_once, resolve_obj_by_qualname)
print_warning_once, resolve_obj_by_qualname)

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -182,7 +182,7 @@ def __init__(
hf_overrides_fn = hf_overrides
else:
hf_overrides_kw = hf_overrides
hf_overrides_fn = identity
hf_overrides_fn = None

if rope_scaling is not None:
hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
@@ -211,8 +211,15 @@ def __init__(
self.skip_tokenizer_init = skip_tokenizer_init

hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, config_format, **hf_overrides_kw)
hf_config = hf_overrides_fn(hf_config)
code_revision, config_format)

if hf_overrides_kw:
logger.info("Overriding HF config with %s", hf_overrides_kw)
hf_config.update(hf_overrides_kw)
if hf_overrides_fn:
logger.info("Overriding HF config with %s", hf_overrides_fn)
hf_config = hf_overrides_fn(hf_config)

self.hf_config = hf_config

self.hf_text_config = get_hf_text_config(self.hf_config)
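
A usage sketch of the two hf_overrides forms handled above: a plain dict is merged into the HF config via hf_config.update(...), while a callable receives the config and returns the modified one. The callable form is only available programmatically; the LLM(...) plumbing below is assumed from ModelConfig and the test added in this PR.

from vllm import LLM

# Dict form: merged into the HF config
# (same as the CLI flag --hf-overrides '{"is_causal": false}').
llm = LLM(model="Alibaba-NLP/gte-Qwen2-7B-instruct",
          task="embedding",
          hf_overrides={"is_causal": False})

# Callable form: receives the PretrainedConfig and returns the modified one.
def disable_causal(config):
    config.is_causal = False
    return config

llm = LLM(model="Alibaba-NLP/gte-Qwen2-7B-instruct",
          task="embedding",
          hf_overrides=disable_causal)
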
vllm/model_executor/models/qwen2.py (20 changes: 18 additions & 2 deletions)
@@ -27,7 +27,7 @@
from torch import nn
from transformers import Qwen2Config

from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -164,11 +164,17 @@ def forward(
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=attn_type)
output, _ = self.o_proj(attn_output)
return output

@@ -210,6 +216,15 @@ def __init__(
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if getattr(config, "is_causal", True):
self._attn_type = AttentionType.DECODER
else:
self._attn_type = AttentionType.ENCODER_ONLY

def forward(
self,
positions: torch.Tensor,
@@ -230,6 +245,7 @@ def forward(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
attn_type=self._attn_type,
)

# Fully Connected