Commit 930f2cc

DarkLight1337 authored and afeldman-nm committed
[Model] Support is_causal HF config field for Qwen2 model (vllm-project#10621)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Andrew Feldman <[email protected]>
1 parent 5dea1d5 commit 930f2cc

File tree: 5 files changed (+51, -13 lines)

docs/source/models/supported_models.rst (10 additions, 3 deletions)

@@ -342,7 +342,7 @@ Text Embedding
       - ✅︎
     * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
       - Qwen2-based
-      - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc.
+      - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
       - ✅︎
       - ✅︎
     * - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
@@ -363,6 +363,13 @@ Text Embedding
 .. tip::
     You can override the model's pooling method by passing :code:`--override-pooler-config`.

+.. note::
+    Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
+    You can set `--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
+
+    On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention
+    despite being described otherwise on its model card.
+
 Reward Modeling
 ---------------

@@ -606,10 +613,10 @@ Text Generation
 | :sup:`+` Multiple items can be inputted per text prompt for this modality.

 .. note::
-    vLLM currently only supports adding LoRA to the language backbone of multimodal models.
+    vLLM currently only supports adding LoRA to the language backbone of multimodal models.

 .. note::
-    For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
+    The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
     For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630

 Multimodal Embedding
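For context, here is a minimal offline-inference sketch of what the new note describes. It assumes that vLLM's LLM constructor forwards task= and hf_overrides= to the model config the same way vllm_runner does in the test change below; it is not part of this commit's diff.

from vllm import LLM

# Assumption: LLM(...) accepts task= and hf_overrides= like vllm_runner does
# in the updated test; shown only to illustrate the documented override.
llm = LLM(
    model="Alibaba-NLP/gte-Qwen2-7B-instruct",
    task="embedding",
    hf_overrides={"is_causal": False},  # switch Qwen2 to bidirectional attention
)

outputs = llm.encode(["What is the capital of France?"])
print(outputs[0].outputs.embedding[:8])  # first few embedding dimensions

On the server side, the equivalent is the --hf-overrides '{"is_causal": false}' flag quoted in the note.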

tests/models/embedding/language/test_embedding.py (10 additions, 2 deletions)

@@ -21,6 +21,7 @@
                      marks=[pytest.mark.core_model]),
         pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
         pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
+        pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
     ],
 )
 @pytest.mark.parametrize("dtype", ["half"])
@@ -31,6 +32,10 @@ def test_models(
     model,
     dtype: str,
 ) -> None:
+    vllm_extra_kwargs = {}
+    if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
+        vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}
+
     # The example_prompts has ending "\n", for example:
     # "Write a short story about a robot that dreams for the first time.\n"
     # sentence_transformers will strip the input texts, see:
@@ -43,8 +48,11 @@ def test_models(
                      is_sentence_transformer=True) as hf_model:
         hf_outputs = hf_model.encode(example_prompts)

-    with vllm_runner(model, task="embedding", dtype=dtype,
-                     max_model_len=None) as vllm_model:
+    with vllm_runner(model,
+                     task="embedding",
+                     dtype=dtype,
+                     max_model_len=None,
+                     **vllm_extra_kwargs) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)
     # This test is for verifying whether the model's extra_repr
     # can be printed correctly.

tests/models/embedding/utils.py (2 additions, 2 deletions)

@@ -24,7 +24,7 @@ def check_embeddings_close(
                                     dim=0)

         fail_msg = (f"Test{prompt_idx}:"
-                    f"\n{name_0}:\t{embeddings_0!r}"
-                    f"\n{name_1}:\t{embeddings_1!r}")
+                    f"\n{name_0}:\t{embeddings_0[:16]!r}"
+                    f"\n{name_1}:\t{embeddings_1[:16]!r}")

         assert sim >= 1 - tol, fail_msg
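For readers unfamiliar with this helper, the assertion compares a cosine similarity against a tolerance; the change above only truncates the embeddings printed in the failure message. A rough sketch of the check (the function name and the use of PyTorch here are illustrative, not the exact helper):

import torch
import torch.nn.functional as F

def embeddings_close(embeddings_0, embeddings_1, tol):
    # Cosine similarity of two 1-D embedding vectors; they count as close
    # when it is within tol of 1.0, mirroring `assert sim >= 1 - tol` above.
    sim = F.cosine_similarity(torch.tensor(embeddings_0),
                              torch.tensor(embeddings_1),
                              dim=0)
    return bool(sim >= 1 - tol)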

vllm/config.py (11 additions, 4 deletions)

@@ -27,7 +27,7 @@
                                      get_hf_text_config, get_pooling_config,
                                      get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        identity, print_warning_once, resolve_obj_by_qualname)
+                        print_warning_once, resolve_obj_by_qualname)

 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -183,7 +183,7 @@ def __init__(
             hf_overrides_fn = hf_overrides
         else:
             hf_overrides_kw = hf_overrides
-            hf_overrides_fn = identity
+            hf_overrides_fn = None

         if rope_scaling is not None:
             hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
@@ -212,8 +212,15 @@ def __init__(
         self.skip_tokenizer_init = skip_tokenizer_init

         hf_config = get_config(self.model, trust_remote_code, revision,
-                               code_revision, config_format, **hf_overrides_kw)
-        hf_config = hf_overrides_fn(hf_config)
+                               code_revision, config_format)
+
+        if hf_overrides_kw:
+            logger.info("Overriding HF config with %s", hf_overrides_kw)
+            hf_config.update(hf_overrides_kw)
+        if hf_overrides_fn:
+            logger.info("Overriding HF config with %s", hf_overrides_fn)
+            hf_config = hf_overrides_fn(hf_config)
+
         self.hf_config = hf_config

         self.hf_text_config = get_hf_text_config(self.hf_config)
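The net effect: keyword overrides are no longer splatted into get_config() but merged afterwards with PretrainedConfig.update(), and a callable override only runs when one was actually supplied. A sketch of the two user-facing forms (the variable names mirror the diff; the callable body is illustrative):

from transformers import PretrainedConfig

# Dict form, e.g. what --hf-overrides '{"is_causal": false}' supplies;
# merged into the loaded HF config via hf_config.update(...).
hf_overrides_kw = {"is_causal": False}

# Callable form: receives the loaded HF config and returns a (possibly new)
# config; invoked only if the user passed a callable.
def hf_overrides_fn(hf_config: PretrainedConfig) -> PretrainedConfig:
    hf_config.update({"is_causal": False})
    return hf_config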

vllm/model_executor/models/qwen2.py (18 additions, 2 deletions)

@@ -27,7 +27,7 @@
 from torch import nn
 from transformers import Qwen2Config

-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -164,11 +164,17 @@ def forward(
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        attn_type: str = AttentionType.DECODER,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q,
+                                k,
+                                v,
+                                kv_cache,
+                                attn_metadata,
+                                attn_type=attn_type)
         output, _ = self.o_proj(attn_output)
         return output

@@ -210,6 +216,15 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)

+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            self._attn_type = AttentionType.DECODER
+        else:
+            self._attn_type = AttentionType.ENCODER_ONLY
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -230,6 +245,7 @@ def forward(
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
+            attn_type=self._attn_type,
         )

         # Fully Connected
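To see why the default stays causal: is_causal is not a standard Qwen2Config field, so getattr(config, "is_causal", True) only flips when the user injects the field (e.g. via --hf-overrides). A standalone sketch of the selection logic added above, not taken from the commit itself:

from transformers import Qwen2Config
from vllm.attention import AttentionType

def select_attn_type(config) -> str:
    # Mirrors the decoder-layer __init__ logic in the diff above.
    return (AttentionType.DECODER
            if getattr(config, "is_causal", True)
            else AttentionType.ENCODER_ONLY)

config = Qwen2Config()                    # stock config: no is_causal field
assert select_attn_type(config) == AttentionType.DECODER

config.update({"is_causal": False})       # what the HF override injects
assert select_attn_type(config) == AttentionType.ENCODER_ONLY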
