Commit 71e5cd0

Update RoPE handling for Qwen2-VL to conform to Transformers
1 parent 07c11cf · commit 71e5cd0

File tree: 7 files changed (+49 -167 lines)

requirements-common.txt

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ numpy < 2.0.0
 requests >= 2.26.0
 tqdm
 py-cpuinfo
-transformers >= 4.45.0 # Required for Llama 3.2.
+transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL.
 tokenizers >= 0.19.1 # Required for Llama 3.
 protobuf # Required by LlamaTokenizer.
 fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
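A runtime sanity check for the bumped pin, using the stock importlib.metadata and packaging APIs (illustrative; this check is not part of the commit):

```python
from importlib.metadata import version

from packaging.version import Version

# requirements-common.txt now pins transformers >= 4.45.2 for Llama 3.2
# and Qwen2-VL support.
installed = Version(version("transformers"))
assert installed >= Version("4.45.2"), (
    f"transformers {installed} is too old; this revision expects >= 4.45.2")
```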

vllm/config.py

Lines changed: 17 additions & 9 deletions

@@ -1723,16 +1723,25 @@ def _get_and_verify_max_len(
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
+        # Backwards compatibility. Although HF prefers "rope_type", we still
+        # have code that accesses "type"
         if "type" in rope_scaling:
-            rope_type = rope_scaling["type"]
+            rope_type = rope_scaling["rope_type"] = rope_scaling["type"]
         elif "rope_type" in rope_scaling:
-            rope_type = rope_scaling["rope_type"]
+            rope_type = rope_scaling["type"] = rope_scaling["rope_type"]
         else:
             raise ValueError(
                 "rope_scaling must have a 'type' or 'rope_type' key.")
 
-        # The correct one should be "longrope", kept "su" here
-        # to be backward compatible
+        # Backwards compatibility
+        if rope_type == "su":
+            rope_scaling["rope_type"] = rope_type = "longrope"
+            logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
+        elif rope_type == "mrope":
+            assert "mrope_section" in rope_scaling
+            rope_scaling["rope_type"] = rope_type = "default"
+            logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
+
         if rope_type not in ("su", "longrope", "llama3"):
             if disable_sliding_window:
                 # TODO(robertgshaw): Find a model that supports rope_scaling
@@ -1742,11 +1751,10 @@ def _get_and_verify_max_len(
                     "with rope_scaling. Please raise an issue so we can "
                     "investigate.")
 
-            if rope_type == "mrope":
-                scaling_factor = 1
-            else:
-                assert "factor" in rope_scaling
-                scaling_factor = rope_scaling["factor"]
+            # NOTE: rope_type == "default" does not define factor
+            # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
+            scaling_factor = rope_scaling.get("factor", 1.0)
+
             if rope_type == "yarn":
                 derived_max_model_len = rope_scaling[
                     "original_max_position_embeddings"]

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 27 additions & 17 deletions

@@ -922,11 +922,9 @@ def get_rope(
     else:
         scaling_type = rope_scaling[
             "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
-        # The correct one should be "longrope" but keep "su" here
-        # for backward compatible
-        if scaling_type not in {"su", "longrope"}:
-            scaling_factor = rope_scaling.get("factor", 1.0)
+
         if scaling_type == "llama3":
+            scaling_factor = rope_scaling["factor"]
             low_freq_factor = rope_scaling["low_freq_factor"]
             high_freq_factor = rope_scaling["high_freq_factor"]
             original_max_position = rope_scaling[
@@ -937,16 +935,39 @@ def get_rope(
                 scaling_factor, low_freq_factor,
                 high_freq_factor,
                 original_max_position)
+        elif scaling_type == "default":
+            if "mrope_section" in rope_scaling:
+                rotary_emb = MRotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    max_position,
+                    base,
+                    is_neox_style,
+                    dtype,
+                    mrope_section=rope_scaling["mrope_section"],
+                )
+            else:
+                rotary_emb = RotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    max_position,
+                    base,
+                    is_neox_style,
+                    dtype,
+                )
         elif scaling_type == "linear":
+            scaling_factor = rope_scaling["factor"]
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,
                                                       scaling_factor, dtype)
         elif scaling_type == "dynamic":
+            scaling_factor = rope_scaling["factor"]
             rotary_emb = DynamicNTKScalingRotaryEmbedding(
                 head_size, rotary_dim, max_position, base, is_neox_style,
                 scaling_factor, dtype)
         elif scaling_type == "yarn":
+            scaling_factor = rope_scaling["factor"]
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
             extra_kwargs = {
@@ -961,6 +982,7 @@ def get_rope(
                 scaling_factor, dtype,
                 **extra_kwargs)
         elif scaling_type == "deepseek_yarn":
+            scaling_factor = rope_scaling["factor"]
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
             # assert max_position == original_max_position * scaling_factor
@@ -973,9 +995,7 @@ def get_rope(
             rotary_emb = DeepseekScalingRotaryEmbedding(
                 head_size, rotary_dim, original_max_position, base,
                 is_neox_style, scaling_factor, dtype, **extra_kwargs)
-        # The correct one should be "longrope" but keep "su" here
-        # for backward compatible
-        elif scaling_type == "su" or scaling_type == "longrope":
+        elif scaling_type == "longrope":
             short_factor = rope_scaling["short_factor"]
             long_factor = rope_scaling["long_factor"]
             original_max_position = rope_scaling[
@@ -989,16 +1009,6 @@ def get_rope(
                 head_size, rotary_dim, max_position, original_max_position,
                 base, is_neox_style, dtype, short_factor, long_factor,
                 **extra_kwargs)
-        elif scaling_type == "mrope":
-            rotary_emb = MRotaryEmbedding(
-                head_size,
-                rotary_dim,
-                max_position,
-                base,
-                is_neox_style,
-                dtype,
-                mrope_section=rope_scaling["mrope_section"],
-            )
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
         _ROPE_DICT[key] = rotary_emb
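With this refactor there is no standalone "mrope" scaling type anymore: after the config normalization above, Qwen2-VL arrives here as "default", MRotaryEmbedding is selected by the presence of "mrope_section", and every branch that needs "factor" now reads it itself. A stripped-down sketch of just that dispatch (the dataclasses are stand-ins for vLLM's real embedding classes; pick_rope_for_default is illustrative):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class RotaryEmbedding:
    """Stand-in for vLLM's RotaryEmbedding."""
    head_size: int


@dataclass
class MRotaryEmbedding(RotaryEmbedding):
    """Stand-in for vLLM's multimodal MRotaryEmbedding."""
    mrope_section: List[int] = field(default_factory=list)


def pick_rope_for_default(head_size: int,
                          rope_scaling: dict) -> RotaryEmbedding:
    # Under rope_type "default", the multimodal rotary embedding is used
    # only when the config carries an "mrope_section"; otherwise the plain
    # rotary embedding is returned (mirrors the new elif branch above).
    if "mrope_section" in rope_scaling:
        return MRotaryEmbedding(head_size,
                                mrope_section=rope_scaling["mrope_section"])
    return RotaryEmbedding(head_size)


print(pick_rope_for_default(128, {"rope_type": "default",
                                  "mrope_section": [16, 24, 24]}))
print(pick_rope_for_default(128, {"rope_type": "default"}))
```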

vllm/model_executor/models/qwen2_vl.py

Lines changed: 2 additions & 2 deletions

@@ -34,6 +34,8 @@
 from transformers.image_utils import (get_image_size,
                                       infer_channel_dimension_format,
                                       to_numpy_array)
+from transformers.models.qwen2_vl.configuration_qwen2_vl import (
+    Qwen2VLConfig, Qwen2VLVisionConfig)
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
     make_batched_images, make_batched_videos, smart_resize)
 
@@ -62,8 +64,6 @@
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceData
-from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
-                                                     Qwen2VLVisionConfig)
 from vllm.transformers_utils.processor import get_processor
 from vllm.utils import is_cpu
 
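Since transformers >= 4.45.2 ships these config classes upstream, the model file can import them from Transformers and drop vLLM's vendored copies. A quick check of the upstream import path (illustrative; the checkpoint id is only an example):

```python
from transformers import AutoConfig
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
    Qwen2VLConfig, Qwen2VLVisionConfig)

# With transformers >= 4.45.2, AutoConfig resolves Qwen2-VL checkpoints to
# the upstream classes, so the vendored vLLM copies are redundant.
config = AutoConfig.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
assert isinstance(config, Qwen2VLConfig)
assert isinstance(config.vision_config, Qwen2VLVisionConfig)
```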

vllm/transformers_utils/config.py

Lines changed: 2 additions & 3 deletions

@@ -23,8 +23,8 @@
                                              MedusaConfig, MllamaConfig,
                                              MLPSpeculatorConfig, MPTConfig,
                                              NemotronConfig, NVLM_D_Config,
-                                             Qwen2VLConfig, RWConfig,
-                                             SolarConfig, UltravoxConfig)
+                                             RWConfig, SolarConfig,
+                                             UltravoxConfig)
 # yapf: enable
 from vllm.transformers_utils.utils import check_gguf_file
 
@@ -57,7 +57,6 @@
     "NVLM_D": NVLM_D_Config,
     "solar": SolarConfig,
     "ultravox": UltravoxConfig,
-    "qwen2_vl": Qwen2VLConfig,
     **_CONFIG_REGISTRY_OVERRIDE_HF
 }
 
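With the "qwen2_vl" registry entry gone, config resolution for Qwen2-VL falls through to Hugging Face's AutoConfig. A simplified sketch of that lookup pattern (get_config here is a rough stand-in, not vLLM's exact signature):

```python
from typing import Dict, Optional, Type

from transformers import AutoConfig, PretrainedConfig

# Simplified: the override registry keeps only configs that Transformers
# lacks; "qwen2_vl" is no longer among them.
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}


def get_config(model: str,
               trust_remote_code: bool = False) -> PretrainedConfig:
    # Default path: let Transformers resolve the config (this now covers
    # Qwen2-VL), then swap in a vLLM-specific class only for registered
    # model types.
    config = AutoConfig.from_pretrained(model,
                                        trust_remote_code=trust_remote_code)
    config_cls: Optional[Type[PretrainedConfig]] = _CONFIG_REGISTRY.get(
        config.model_type)
    if config_cls is not None:
        config = config_cls.from_pretrained(model)
    return config
```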

vllm/transformers_utils/configs/__init__.py

Lines changed: 0 additions & 4 deletions

@@ -14,8 +14,6 @@
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.nemotron import NemotronConfig
 from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
-from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
-                                                     Qwen2VLVisionConfig)
 from vllm.transformers_utils.configs.solar import SolarConfig
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 
@@ -35,6 +33,4 @@
     "NVLM_D_Config",
     "SolarConfig",
     "UltravoxConfig",
-    "Qwen2VLConfig",
-    "Qwen2VLVisionConfig",
 ]

vllm/transformers_utils/configs/qwen2vl.py

Lines changed: 0 additions & 131 deletions

This file was deleted.
