Merged
Commits (changes shown from 27 of 73 commits)
a62c2df
Rename `rope_scaling` -> `rope_parameters` in `get_rope`
hmellor Nov 12, 2025
f42b03d
Patch rope parameters to new name, `rope_parameters`
hmellor Nov 12, 2025
a2a9437
Update models where it's a simple rename
hmellor Nov 12, 2025
fba5bf5
Fix model config overrides
hmellor Nov 12, 2025
ee5cf66
Update examples
hmellor Nov 12, 2025
080530d
Update benchmarks
hmellor Nov 12, 2025
889b900
More renaming in transformers utils
hmellor Nov 12, 2025
50b1a87
Fix `patch_rope_parameters` for when `rope_scaling` was explicitly `N…
hmellor Nov 12, 2025
bd182e0
Update Gemma3 and Gemma3n
hmellor Nov 12, 2025
4c61e2e
Merge branch 'main' into update-rope-config
hmellor Nov 13, 2025
65c8658
Get `rope_theta` from the new location too
hmellor Nov 13, 2025
5d65739
Fix condition for non gemma3 models
hmellor Nov 13, 2025
b4e1967
Make Transformers backend torch compile check work with new rope params
hmellor Nov 13, 2025
ee77bd7
Re-enable a load of Transformers nightly tests which are now fixed
hmellor Nov 13, 2025
df4c007
Update the custom configs
hmellor Nov 13, 2025
325ff8d
Make sure scaling factor always exists
hmellor Nov 13, 2025
11c23a7
A couple more models that now init on v5
hmellor Nov 13, 2025
4ea113c
Update Commandr
hmellor Nov 13, 2025
59b0f27
Update Qwen3Next
hmellor Nov 13, 2025
064441b
Update Olmo2
hmellor Nov 13, 2025
bdd0e6c
rope_parameters always present because of rope_theta
hmellor Nov 13, 2025
f224ef4
Update LFM2MoE
hmellor Nov 13, 2025
19dcc18
Update LFM2
hmellor Nov 13, 2025
2eecd31
Update the rest
hmellor Nov 13, 2025
e95ccd4
update tests
hmellor Nov 13, 2025
f2bac15
Update configs
hmellor Nov 13, 2025
36e8a1f
Missed 2
hmellor Nov 13, 2025
dfa75cf
Improve comment about what `rope_parameters` is
hmellor Nov 13, 2025
708ea0c
Move scaling factor out of loop
hmellor Nov 13, 2025
4a28512
Early exit `patch_rope_parameters` if no rope params present
hmellor Nov 13, 2025
dfb476f
Be more explicit about v4 vs v5 behaviour
hmellor Nov 13, 2025
97bb339
Update a few models to not pass `base` outside of `rope_parameters`
hmellor Nov 13, 2025
97766f5
Update some more models
hmellor Nov 13, 2025
783962b
Update some more models
hmellor Nov 13, 2025
797fbea
Add back `type` -> `rope_type` for legacy custom models
hmellor Nov 13, 2025
b780892
More models
hmellor Nov 13, 2025
ad9dff2
Fix docs build
hmellor Nov 13, 2025
461ff94
Update some more models
hmellor Nov 13, 2025
fa2cced
Update some more models
hmellor Nov 13, 2025
4127d54
Remove last references to `base` arg of `get_rope`
hmellor Nov 13, 2025
1ebd0e4
Update mrope test
hmellor Nov 13, 2025
ec30fef
Check everything
hmellor Nov 13, 2025
6368078
fix
hmellor Nov 13, 2025
482f378
Merge branch 'main' into update-rope-config
hmellor Nov 14, 2025
d4b2fbb
Don't delete the legacy attributes when still using v4
hmellor Nov 14, 2025
1e68d27
Fix typo in commandr
hmellor Nov 14, 2025
db6a880
Fix typo in deepseek v2
hmellor Nov 14, 2025
26a51d4
Handle multimodal models where vision model uses RoPE
hmellor Nov 14, 2025
dd69244
Use new default value of rope_parameters in kernels test
hmellor Nov 14, 2025
132dc4b
Use `rope_parameters` instead of `base` in compile test
hmellor Nov 14, 2025
d7a6ded
Don't overwrite main config for v4 style Gemma 3
hmellor Nov 14, 2025
8ceffd6
Only raise for `disable_sliding_window` if the model actually has `sl…
hmellor Nov 14, 2025
08126a9
Fix arctic config docstring for docs
hmellor Nov 14, 2025
f1c3c33
Fix typo in gpt-oss
hmellor Nov 14, 2025
a2601ce
Remove disable_sliding_window errors
hmellor Nov 14, 2025
03d50e0
Fix olmo2
hmellor Nov 14, 2025
93827b6
Fix custom code mm models
hmellor Nov 14, 2025
3b3c233
Fix models with no rope info at all in their `config.json`
hmellor Nov 14, 2025
3f9ce07
Fix unaccounted for style of config
hmellor Nov 14, 2025
f1714ac
Hopefully final fix for multimodal rope overrides
hmellor Nov 15, 2025
981aac4
Fix condition for raising error
hmellor Nov 15, 2025
5c2f394
Only override `rope_type` to `deepseek_yarn` if it was not `default`
hmellor Nov 15, 2025
6c64ba5
Make 10000 the default base for `get_rope` if `rope_parameters == None`
hmellor Nov 15, 2025
6beee2b
Set all model defaults which are not 10000
hmellor Nov 15, 2025
002fb90
Update models which can default to 10000
hmellor Nov 15, 2025
99c5d47
Fix nemotron config
hmellor Nov 18, 2025
c38e8bb
Fix ernie 4.5 vl
hmellor Nov 18, 2025
eebe73c
Fix benchmarks/tests where `get_rope` is called with positional argum…
hmellor Nov 18, 2025
540a46b
Merge branch 'main' into update-rope-config
hmellor Nov 18, 2025
a60b5ec
Fix get_rope kwargs in vision transformers
hmellor Nov 18, 2025
00f2853
Update new model
hmellor Nov 18, 2025
717a704
Missed positional args
hmellor Nov 18, 2025
a9fa3b0
Fix nemotron config validation
hmellor Nov 18, 2025
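Taken together, the commits fold the legacy `rope_scaling` dict and the standalone `rope_theta` attribute into a single `rope_parameters` dict, both on the HF config and in the `get_rope` call. The snippet below is a minimal before/after sketch, not code from this PR: the keyword names follow the hunks further down where they appear, and the model name and dtype are illustrative assumptions.

```python
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.transformers_utils.config import get_config

# Illustrative model; any RoPE-based checkpoint is handled the same way.
config = get_config("Qwen/Qwen2.5-VL-3B-Instruct", False).get_text_config()

head_dim = getattr(
    config, "head_dim", config.hidden_size // config.num_attention_heads
)

# Before: rope_theta and rope_scaling were separate config attributes.
#   rope_theta = config.rope_theta
#   rope = get_rope(..., base=rope_theta, rope_scaling=config.rope_scaling, ...)

# After: both travel inside config.rope_parameters.
rope_theta = config.rope_parameters["rope_theta"]
rope = get_rope(
    head_size=head_dim,
    rotary_dim=head_dim,
    max_position=config.max_position_embeddings,
    base=rope_theta,  # still passed alongside rope_parameters in the hunks below
    is_neox_style=True,
    rope_parameters=config.rope_parameters,
    dtype=torch.bfloat16,
)
```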
6 changes: 3 additions & 3 deletions .buildkite/test-pipeline.yaml
@@ -874,12 +874,12 @@ steps:
optional: true
commands:
- pip install --upgrade git+https://github.com/huggingface/transformers
- pytest -v -s tests/models/test_initialization.py -k 'not (Gemma3 or ModernBert or Qwen2_5_VL or Qwen2_5vl or Qwen2VL or TransformersMultiModalEmbeddingModel or TransformersMultiModalForSequenceClassification or Ultravox or Phi4Multimodal or LlavaNextVideo or MiniCPMO or Lfm2Moe or PaliGemma or RobertaForSequenceClassification or Ovis2_5 or Fuyu or DeepseekOCR or KimiVL)'
- pytest -v -s tests/models/test_initialization.py -k 'not (Ultravox or Phi4Multimodal or MiniCPMO or Lfm2Moe or RobertaForSequenceClassification or Ovis2_5 or DeepseekOCR or KimiVL)'
- pytest -v -s tests/models/test_transformers.py
# - pytest -v -s tests/models/multimodal/processing/
- pytest -v -s tests/models/multimodal/test_mapping.py -k 'not (Gemma3 or Qwen2VL or Qwen2_5_VL)'
- pytest -v -s tests/models/multimodal/test_mapping.py
- python3 examples/offline_inference/basic/chat.py
# - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
- python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
# Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper

18 changes: 9 additions & 9 deletions benchmarks/kernels/benchmark_mrope.py
@@ -6,9 +6,9 @@
#
# The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup
# rope_theta, is_neox_style, rope_parameters, dtype, torch_mean, torch_median,
# torch_p99, torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min,
# triton_max, speedup
#
# == Usage Examples ==
#
@@ -88,7 +88,7 @@ def benchmark_mrope(
max_position: int = 8192,
rope_theta: float = 10000,
is_neox_style: bool = True,
rope_scaling: dict[str, Any] = None,
rope_parameters: dict[str, Any] = None,
dtype: torch.dtype = torch.bfloat16,
seed: int = 0,
warmup_iter: int = 10,
@@ -104,7 +104,7 @@ def benchmark_mrope(
max_position=max_position,
base=rope_theta,
is_neox_style=is_neox_style,
rope_scaling=rope_scaling,
rope_parameters=rope_parameters,
dtype=dtype,
).to(device=device)

@@ -205,7 +205,7 @@ def benchmark_mrope(
max_position,
rope_theta,
is_neox_style,
str(rope_scaling),
str(rope_parameters),
str(dtype).split(".")[-1],
torch_stats["mean"],
torch_stats["median"],
@@ -257,7 +257,7 @@ def benchmark_mrope(
"max_position",
"rope_theta",
"is_neox_style",
"rope_scaling",
"rope_parameters",
"dtype",
"torch_mean",
"torch_median",
@@ -303,7 +303,7 @@ def benchmark_mrope(
q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim
is_neox_style = True
rope_theta = config.rope_theta
rope_theta = config.rope_parameters["rope_theta"]
max_position = config.max_position_embeddings

for num_tokens in num_tokens_list:
@@ -317,7 +317,7 @@ def benchmark_mrope(
max_position=max_position,
rope_theta=rope_theta,
is_neox_style=is_neox_style,
rope_scaling=config.rope_scaling,
rope_parameters=config.rope_parameters,
dtype=getattr(torch, args.dtype),
seed=args.seed,
warmup_iter=args.warmup_iter,
4 changes: 2 additions & 2 deletions examples/offline_inference/context_extension.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script demonstrates how to extend the context length
of a Qwen model using the YARN method (rope_scaling)
of a Qwen model using the YARN method (rope_parameters)
and run a simple chat example.

Usage:
@@ -20,7 +20,7 @@ def create_llm():
# Use yarn to extend context
hf_overrides = {
"rope_theta": rope_theta,
"rope_scaling": {
"rope_parameters": {
"rope_type": "yarn",
"factor": factor,
"original_max_position_embeddings": original_max_position_embeddings,
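For reference, a self-contained sketch of the updated override end to end; the model name and numeric values are illustrative assumptions rather than what the example script actually uses.

```python
from vllm import LLM, SamplingParams

# Assumed values for illustration only.
rope_theta = 1_000_000
factor = 4.0
original_max_position_embeddings = 32_768

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # assumed model
    hf_overrides={
        "rope_theta": rope_theta,
        "rope_parameters": {
            "rope_type": "yarn",
            "factor": factor,
            "original_max_position_embeddings": original_max_position_embeddings,
        },
    },
    max_model_len=int(original_max_position_embeddings * factor),
)

outputs = llm.chat(
    [{"role": "user", "content": "Give me a one-line summary of YARN scaling."}],
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```

The only change relative to the old script is the key name, `rope_scaling` becoming `rope_parameters`; the inner YARN fields stay the same.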
16 changes: 7 additions & 9 deletions tests/kernels/core/test_mrope.py
@@ -5,11 +5,11 @@
import pytest
import torch
from packaging.version import Version
from transformers import AutoConfig
from transformers import __version__ as TRANSFORMERS_VERSION

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -98,8 +98,7 @@ def test_mrope(
atol = model_info.atol
rtol = model_info.rtol

config = AutoConfig.from_pretrained(model_name)
config = config.get_text_config()
config = get_config(model_name, False).get_text_config()

# get the model config
total_num_kv_heads = config.num_key_value_heads
@@ -113,7 +112,7 @@
)
is_neox_style = True

rope_theta = config.rope_theta
rope_theta = config.rope_parameters["rope_theta"]
max_position = config.max_position_embeddings
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
rotary_dim = int(head_dim * partial_rotary_factor)
@@ -124,7 +123,7 @@
max_position=max_position,
base=rope_theta,
is_neox_style=is_neox_style,
rope_scaling=config.rope_scaling,
rope_parameters=config.rope_parameters,
dtype=dtype,
).to(device=device)

@@ -173,8 +172,7 @@ def test_mrope_torch_compile_tracing(
atol = model_info.atol
rtol = model_info.rtol

config = AutoConfig.from_pretrained(model_name)
config = config.get_text_config()
config = get_config(model_name, False).get_text_config()

# get the model config
total_num_kv_heads = config.num_key_value_heads
@@ -187,7 +185,7 @@
else config.hidden_size // total_num_heads
)
is_neox_style = True
rope_theta = config.rope_theta
rope_theta = config.rope_parameters["rope_theta"]
max_position = config.max_position_embeddings
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
rotary_dim = int(head_dim * partial_rotary_factor)
@@ -198,7 +196,7 @@
max_position=max_position,
base=rope_theta,
is_neox_style=is_neox_style,
rope_scaling=config.rope_scaling,
rope_parameters=config.rope_parameters,
dtype=dtype,
).to(device=device)

20 changes: 10 additions & 10 deletions tests/kernels/core/test_pos_encoding.py
@@ -121,7 +121,7 @@ def test_rotary_embedding(
def test_rope_module_cache():
MAX_POSITIONS = [123, 1234]
BASES = [10000, 1000000]
ROPE_SCALINGS = (
ROPE_PARAMETERS = (
None,
{"rope_type": "linear", "factor": (1,)},
{"rope_type": "dynamic", "factor": 1},
@@ -132,7 +132,7 @@ def test_rope_module_cache():
MAX_POSITIONS,
BASES,
IS_NEOX_STYLE,
ROPE_SCALINGS,
ROPE_PARAMETERS,
DTYPES,
)
rope_setting_id_map: dict[str, int] = {}
@@ -142,8 +142,8 @@ def test_rope_module_cache():
rotary_dim,
max_position,
base,
is_neox_stype,
rope_scaling,
is_neox_style,
rope_parameters,
dtype,
) = setting
if rotary_dim is None:
@@ -153,8 +153,8 @@ def test_rope_module_cache():
rotary_dim,
max_position,
base,
is_neox_stype,
rope_scaling,
is_neox_style,
rope_parameters,
dtype,
)
# different settings cannot share the same rope module
@@ -169,8 +169,8 @@ def test_rope_module_cache():
rotary_dim,
max_position,
base,
is_neox_stype,
rope_scaling,
is_neox_style,
rope_parameters,
dtype,
) = setting
if rotary_dim is None:
@@ -180,8 +180,8 @@ def test_rope_module_cache():
rotary_dim,
max_position,
base,
is_neox_stype,
rope_scaling,
is_neox_style,
rope_parameters,
dtype,
)
# check if cache take effect
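The comments in the hunk above ("different settings cannot share the same rope module", "check if cache take effect") describe the invariant under test. A condensed sketch of that invariant, assuming the keyword names used in the other hunks of this PR:

```python
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope

settings = dict(
    head_size=64,
    rotary_dim=64,
    max_position=1234,
    base=10000,
    is_neox_style=True,
    rope_parameters={"rope_type": "dynamic", "factor": 1},
    dtype=torch.float16,
)

rope_a = get_rope(**settings)
rope_b = get_rope(**settings)
# Identical settings should hit the module cache and reuse the instance.
assert rope_a is rope_b

# Any differing setting (here a different base) must build a new module.
rope_c = get_rope(**{**settings, "base": 1_000_000})
assert rope_c is not rope_a
```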
2 changes: 1 addition & 1 deletion tests/kernels/moe/test_gpt_oss_triton_kernels.py
@@ -201,7 +201,7 @@ class ModelConfig:
sliding_window: int = 128
initial_context_length: int = 4096
rope_theta: float = 150000.0
rope_scaling_factor: float = 32.0
rope_parameters_factor: float = 32.0
rope_ntk_alpha: float = 1.0
rope_ntk_beta: float = 32.0

6 changes: 3 additions & 3 deletions tests/models/language/pooling/test_nomic_max_model_len.py
@@ -80,7 +80,7 @@ def test_set_max_model_len_illegal(model_info, vllm_runner):
def test_use_rope_scaling_legal(model_info, vllm_runner):
hf_overrides = {
"rope_theta": rope_theta,
"rope_scaling": {
"rope_parameters": {
"rope_type": "yarn",
"factor": factor,
"original_max_position_embeddings": original_max_position_embeddings,
@@ -98,7 +98,7 @@ def test_use_rope_scaling_illegal(model_info, vllm_runner):
def test_use_rope_scaling_illegal(model_info, vllm_runner):
hf_overrides = {
"rope_theta": rope_theta,
"rope_scaling": {
"rope_parameters": {
"rope_type": "yarn",
"factor": factor,
"original_max_position_embeddings": original_max_position_embeddings,
@@ -116,7 +116,7 @@ def test_use_rope_scaling_illegal(model_info, vllm_runner):

hf_overrides = {
"rope_theta": rope_theta,
"rope_scaling": {
"rope_parameters": {
"rope_type": "yarn",
"factor": factor,
"original_max_position_embeddings": original_max_position_embeddings,
13 changes: 7 additions & 6 deletions tests/test_config.py
@@ -254,39 +254,40 @@ def test_rope_customization():
LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}

llama_model_config = ModelConfig("meta-llama/Meta-Llama-3-8B-Instruct")
assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
assert getattr(llama_model_config.hf_config, "rope_parameters", None) is None
assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
assert llama_model_config.max_model_len == 8192

llama_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct",
hf_overrides={
"rope_scaling": TEST_ROPE_SCALING,
"rope_parameters": TEST_ROPE_SCALING,
"rope_theta": TEST_ROPE_THETA,
},
)
assert (
getattr(llama_model_config.hf_config, "rope_scaling", None) == TEST_ROPE_SCALING
getattr(llama_model_config.hf_config, "rope_parameters", None)
== TEST_ROPE_SCALING
)
assert getattr(llama_model_config.hf_config, "rope_theta", None) == TEST_ROPE_THETA
assert llama_model_config.max_model_len == 16384

longchat_model_config = ModelConfig("lmsys/longchat-13b-16k")
# Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config
assert all(
longchat_model_config.hf_config.rope_scaling.get(key) == value
longchat_model_config.hf_config.rope_parameters.get(key) == value
for key, value in LONGCHAT_ROPE_SCALING.items()
)
assert longchat_model_config.max_model_len == 16384

longchat_model_config = ModelConfig(
"lmsys/longchat-13b-16k",
hf_overrides={
"rope_scaling": TEST_ROPE_SCALING,
"rope_parameters": TEST_ROPE_SCALING,
},
)
assert (
getattr(longchat_model_config.hf_config, "rope_scaling", None)
getattr(longchat_model_config.hf_config, "rope_parameters", None)
== TEST_ROPE_SCALING
)
assert longchat_model_config.max_model_len == 4096
49 changes: 28 additions & 21 deletions vllm/config/model.py
@@ -13,6 +13,7 @@
from pydantic import ConfigDict, SkipValidation, field_validator, model_validator
from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers.configuration_utils import ALLOWED_LAYER_TYPES

import vllm.envs as envs
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
@@ -2068,30 +2069,34 @@ def _get_and_verify_max_len(
)
derived_max_model_len = default_max_len

rope_scaling = getattr(hf_config, "rope_scaling", None)
rope_parameters = getattr(hf_config, "rope_parameters", None)
# NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
# scaling, so we skip applying the scaling factor again.
if rope_scaling is not None and "gemma3" not in hf_config.model_type:
# No need to consider "type" key because of patch_rope_scaling when
# loading HF config
rope_type = rope_scaling["rope_type"]

if rope_type not in ("su", "longrope", "llama3"):
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"with rope_scaling. Please raise an issue so we can "
"investigate."
)
if rope_parameters is not None and "gemma3" not in hf_config.model_type:
# In Transformers v5 this could be RopeParameters or dict[str, RopeParameters]
# To simplify, we convert any RopeParameters to dict[str, RopeParameters]
if not set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
rope_parameters = {"": rope_parameters}
for rp in rope_parameters.values():
rope_type = rp["rope_type"]
scaling_factor = 1.0

if rope_type not in ("su", "longrope", "llama3"):
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_parameters
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models with "
"rope_parameters. Please raise an issue so we can investigate."
)

# NOTE: rope_type == "default" does not define factor
# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
scaling_factor = rope_scaling.get("factor", 1.0)
# NOTE: rope_type == "default" does not define factor
# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py
scaling_factor = rp.get("factor", 1.0)

if rope_type == "yarn":
derived_max_model_len = rope_scaling["original_max_position_embeddings"]
if rope_type == "yarn":
derived_max_model_len = rp["original_max_position_embeddings"]
# Do this outside loop since all layers should have the same scaling
derived_max_model_len *= scaling_factor

if encoder_config and "max_seq_length" in encoder_config:
@@ -2102,7 +2107,9 @@
if max_model_len is None:
# For LongRoPE, default to original_max_position_embeddings to avoid
# performance degradation for shorter sequences
if rope_scaling is not None and rope_scaling["rope_type"] == "longrope":
if rope_parameters is not None and any(
rp["rope_type"] == "longrope" for rp in rope_parameters.values()
):
max_model_len = int(
getattr(
hf_config, "original_max_position_embeddings", derived_max_model_len
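Pulled out of `_get_and_verify_max_len`, the new handling reduces to roughly the helper below. This is a simplified sketch of the hunk above, not the exact code path, and `ALLOWED_LAYER_TYPES` here is only a stand-in for the constant imported from `transformers.configuration_utils`.

```python
# Stand-in value; the real tuple comes from transformers.configuration_utils.
ALLOWED_LAYER_TYPES = ("full_attention", "sliding_attention")


def derive_scaled_max_len(rope_parameters: dict, derived_max_model_len: float) -> float:
    # v5 configs may hold a single RopeParameters dict or a mapping of
    # layer type -> RopeParameters; normalize to the mapping form.
    if not set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
        rope_parameters = {"": rope_parameters}

    scaling_factor = 1.0
    for rp in rope_parameters.values():
        if rp["rope_type"] in ("su", "longrope", "llama3"):
            continue
        # "default" rope does not define a factor, so fall back to 1.0.
        scaling_factor = rp.get("factor", 1.0)
        if rp["rope_type"] == "yarn":
            # YARN reports the pre-scaling context length.
            derived_max_model_len = rp["original_max_position_embeddings"]

    # Applied once outside the loop: all layer types share the same scaling.
    return derived_max_model_len * scaling_factor
```

For example, a YARN entry with `factor=4.0` and `original_max_position_embeddings=32768` derives a scaled length of 131072.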