Skip to content

Commit e9c7f0b

Browse files
markmcdevpatelio
authored andcommitted
[KV Connector] Make KVCacheConfig an explicit constructor argument (vllm-project#27887)
Signed-off-by: Mark McLoughlin <[email protected]>
1 parent ab9bb3a commit e9c7f0b

File tree

14 files changed

+410
-43
lines changed

14 files changed

+410
-43
lines changed
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
Unit tests for backwards compatibility with external KV connector implementations.
5+
6+
This test ensures that external connectors (loaded via kv_connector_module_path)
7+
implemented with the old signature continue to work:
8+
- Old signature: __init__(self, vllm_config, role)
9+
- New signature: __init__(self, vllm_config, role, kv_cache_config)
10+
"""
11+
12+
from typing import TYPE_CHECKING
13+
from unittest.mock import patch
14+
15+
import pytest
16+
17+
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
18+
from vllm.distributed.kv_transfer.kv_connector.v1 import (
19+
KVConnectorBase_V1,
20+
KVConnectorRole,
21+
)
22+
from vllm.v1.core.sched.output import SchedulerOutput
23+
24+
from .utils import create_scheduler, create_vllm_config
25+
26+
if TYPE_CHECKING:
27+
from vllm.attention.backends.abstract import AttentionMetadata
28+
from vllm.config import VllmConfig
29+
from vllm.forward_context import ForwardContext
30+
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
31+
from vllm.v1.kv_cache_interface import KVCacheConfig
32+
from vllm.v1.request import Request
33+
34+
35+
class OldStyleTestConnector(KVConnectorBase_V1):
36+
"""
37+
Test connector using the old signature with 2 required arguments.
38+
This simulates external connectors that haven't been updated yet.
39+
"""
40+
41+
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
42+
# Old-style call to super().__init__ with only 2 arguments
43+
super().__init__(vllm_config=vllm_config, role=role)
44+
45+
def get_num_new_matched_tokens(
46+
self, request: "Request", num_computed_tokens: int
47+
) -> tuple[int | None, bool]:
48+
return 0, False
49+
50+
def update_state_after_alloc(
51+
self,
52+
request: "Request",
53+
blocks: "KVCacheBlocks",
54+
num_external_tokens: int,
55+
):
56+
pass
57+
58+
def build_connector_meta(self, scheduler_output: SchedulerOutput):
59+
return None
60+
61+
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
62+
pass
63+
64+
def wait_for_layer_load(self, layer_name: str) -> None:
65+
pass
66+
67+
def save_kv_layer(
68+
self,
69+
layer_name: str,
70+
kv_layer,
71+
attn_metadata: "AttentionMetadata",
72+
**kwargs,
73+
) -> None:
74+
pass
75+
76+
def wait_for_save(self):
77+
pass
78+
79+
80+
class NewStyleTestConnector(KVConnectorBase_V1):
81+
"""
82+
Test connector using the new signature with 3 required arguments.
83+
"""
84+
85+
def __init__(
86+
self,
87+
vllm_config: "VllmConfig",
88+
role: KVConnectorRole,
89+
kv_cache_config: "KVCacheConfig",
90+
):
91+
# New-style call to super().__init__ with all 3 arguments
92+
super().__init__(
93+
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
94+
)
95+
96+
def get_num_new_matched_tokens(
97+
self, request: "Request", num_computed_tokens: int
98+
) -> tuple[int | None, bool]:
99+
return 0, False
100+
101+
def update_state_after_alloc(
102+
self,
103+
request: "Request",
104+
blocks: "KVCacheBlocks",
105+
num_external_tokens: int,
106+
):
107+
pass
108+
109+
def build_connector_meta(self, scheduler_output: SchedulerOutput):
110+
return None
111+
112+
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
113+
pass
114+
115+
def wait_for_layer_load(self, layer_name: str) -> None:
116+
pass
117+
118+
def save_kv_layer(
119+
self,
120+
layer_name: str,
121+
kv_layer,
122+
attn_metadata: "AttentionMetadata",
123+
**kwargs,
124+
) -> None:
125+
pass
126+
127+
def wait_for_save(self):
128+
pass
129+
130+
131+
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
132+
def test_external_old_signature_factory_instantiation(role):
133+
"""
134+
Test that external connectors with old signature (2 required args) loaded
135+
via kv_connector_module_path are correctly instantiated with backwards
136+
compatibility support.
137+
"""
138+
vllm_config = create_vllm_config()
139+
vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector"
140+
vllm_config.kv_transfer_config.kv_connector_module_path = (
141+
"tests.v1.kv_connector.unit.test_backwards_compatibility"
142+
)
143+
144+
scheduler = create_scheduler(vllm_config)
145+
kv_cache_config = scheduler.kv_cache_config
146+
147+
connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)
148+
149+
assert connector is not None
150+
assert isinstance(connector, OldStyleTestConnector)
151+
assert connector.role == role
152+
assert connector._kv_cache_config is None
153+
154+
155+
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
156+
def test_external_new_signature_factory_instantiation(role):
157+
"""
158+
Test that external connectors with new signature (3 required args) loaded
159+
via kv_connector_module_path are correctly instantiated.
160+
"""
161+
vllm_config = create_vllm_config()
162+
vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector"
163+
vllm_config.kv_transfer_config.kv_connector_module_path = (
164+
"tests.v1.kv_connector.unit.test_backwards_compatibility"
165+
)
166+
167+
scheduler = create_scheduler(vllm_config)
168+
kv_cache_config = scheduler.kv_cache_config
169+
170+
connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)
171+
172+
assert connector is not None
173+
assert isinstance(connector, NewStyleTestConnector)
174+
assert connector.role == role
175+
assert connector._kv_cache_config is not None
176+
assert connector._kv_cache_config == kv_cache_config
177+
178+
179+
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
180+
def test_old_signature_super_init(role):
181+
"""
182+
Test that old-style connectors can call super().__init__() without
183+
kv_cache_config parameter.
184+
"""
185+
vllm_config = create_vllm_config()
186+
187+
connector = OldStyleTestConnector(vllm_config, role)
188+
189+
assert connector is not None
190+
assert connector.role == role
191+
assert connector._kv_cache_config is None
192+
193+
194+
def test_old_signature_super_init_with_kwargs():
195+
"""
196+
Test that old-style connectors can call super().__init__() with keyword
197+
arguments in different orders.
198+
"""
199+
vllm_config = create_vllm_config()
200+
201+
# Test with vllm_config= and role= kwargs
202+
connector1 = OldStyleTestConnector(
203+
vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER
204+
)
205+
assert connector1 is not None
206+
assert connector1._kv_cache_config is None
207+
208+
# Test with role= and vllm_config= in reversed order
209+
connector2 = OldStyleTestConnector(
210+
role=KVConnectorRole.WORKER, vllm_config=vllm_config
211+
)
212+
assert connector2 is not None
213+
assert connector2._kv_cache_config is None
214+
215+
216+
def test_internal_connector_uses_new_signature():
217+
"""
218+
Test that internal connectors (registered in factory) always use the new
219+
signature and get kv_cache_config.
220+
"""
221+
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
222+
SharedStorageConnector,
223+
)
224+
225+
vllm_config = create_vllm_config()
226+
vllm_config.kv_transfer_config.kv_connector = "SharedStorageConnector"
227+
228+
scheduler = create_scheduler(vllm_config)
229+
kv_cache_config = scheduler.kv_cache_config
230+
231+
connector = KVConnectorFactory.create_connector(
232+
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
233+
)
234+
235+
assert connector is not None
236+
assert isinstance(connector, SharedStorageConnector)
237+
assert connector._kv_cache_config is not None
238+
assert connector._kv_cache_config == kv_cache_config
239+
240+
241+
def test_signature_detection_with_mocking():
242+
"""
243+
Test that the factory correctly applies compat_sig flag returned from
244+
_get_connector_class_with_compat.
245+
"""
246+
vllm_config = create_vllm_config()
247+
scheduler = create_scheduler(vllm_config)
248+
kv_cache_config = scheduler.kv_cache_config
249+
250+
# Mock _get_connector_class_with_compat to return old-style connector
251+
with patch.object(
252+
KVConnectorFactory,
253+
"_get_connector_class_with_compat",
254+
return_value=(OldStyleTestConnector, True),
255+
):
256+
old_connector = KVConnectorFactory.create_connector(
257+
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
258+
)
259+
assert old_connector is not None
260+
assert isinstance(old_connector, OldStyleTestConnector)
261+
assert old_connector._kv_cache_config is None
262+
263+
# Mock _get_connector_class_with_compat to return new-style connector
264+
with patch.object(
265+
KVConnectorFactory,
266+
"_get_connector_class_with_compat",
267+
return_value=(NewStyleTestConnector, False),
268+
):
269+
new_connector = KVConnectorFactory.create_connector(
270+
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
271+
)
272+
assert new_connector is not None
273+
assert isinstance(new_connector, NewStyleTestConnector)
274+
assert new_connector._kv_cache_config is not None
275+
assert new_connector._kv_cache_config == kv_cache_config

tests/v1/kv_connector/unit/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def create_model_runner_output(
254254

255255

256256
class TestSharedStorageConnector(SharedStorageConnector):
257-
def __init__(self, config: VllmConfig, role):
257+
def __init__(self, config: VllmConfig, role, kv_cache_config):
258258
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
259259
self._connector = SharedStorageConnector(config, role)
260260
self.call_record: dict[str, int] = defaultdict(int)

vllm/distributed/kv_transfer/kv_connector/factory.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33

44
import importlib
55
from collections.abc import Callable
6-
from typing import TYPE_CHECKING, cast
6+
from typing import TYPE_CHECKING, Optional, cast
77

88
import vllm.envs as envs
9-
from vllm.config import VllmConfig
109
from vllm.distributed.kv_transfer.kv_connector.base import (
1110
KVConnectorBase,
1211
KVConnectorBaseType,
@@ -16,9 +15,12 @@
1615
supports_hma,
1716
)
1817
from vllm.logger import init_logger
18+
from vllm.utils.func_utils import supports_kw
1919

2020
if TYPE_CHECKING:
21+
from vllm.config import VllmConfig
2122
from vllm.config.kv_transfer import KVTransferConfig
23+
from vllm.v1.kv_cache_interface import KVCacheConfig
2224

2325
logger = init_logger(__name__)
2426

@@ -41,8 +43,9 @@ def loader() -> type[KVConnectorBase]:
4143
@classmethod
4244
def create_connector(
4345
cls,
44-
config: VllmConfig,
46+
config: "VllmConfig",
4547
role: KVConnectorRole,
48+
kv_cache_config: Optional["KVCacheConfig"] = None,
4649
) -> KVConnectorBase:
4750
if not envs.VLLM_USE_V1:
4851
raise ValueError(
@@ -53,7 +56,9 @@ def create_connector(
5356
kv_transfer_config = config.kv_transfer_config
5457
if kv_transfer_config is None:
5558
raise ValueError("kv_transfer_config must be set to create a connector")
56-
connector_cls = cls.get_connector_class(kv_transfer_config)
59+
connector_cls, compat_sig = cls._get_connector_class_with_compat(
60+
kv_transfer_config
61+
)
5762

5863
# check if the connector supports HMA
5964
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
@@ -76,7 +81,12 @@ def create_connector(
7681
# - Co-locate with worker process
7782
# - Should only be used inside the forward context & attention layer
7883
# We build separately to enforce strict separation
79-
return connector_cls(config, role)
84+
if compat_sig:
85+
# Old signature: __init__(self, vllm_config, role)
86+
return connector_cls(config, role)
87+
else:
88+
# New signature: __init__(self, vllm_config, role, kv_cache_config)
89+
return connector_cls(config, role, kv_cache_config)
8090

8191
@classmethod
8292
def get_connector_class_by_name(
@@ -97,13 +107,13 @@ def get_connector_class_by_name(
97107
return cls._registry[connector_name]()
98108

99109
@classmethod
100-
def get_connector_class(
110+
def _get_connector_class_with_compat(
101111
cls, kv_transfer_config: "KVTransferConfig"
102-
) -> type[KVConnectorBaseType]:
103-
"""Get the connector class by name."""
112+
) -> tuple[type[KVConnectorBaseType], bool]:
104113
connector_name = kv_transfer_config.kv_connector
105114
if connector_name is None:
106115
raise ValueError("Connector name is not set in KVTransferConfig")
116+
compat_sig = False
107117
if connector_name in cls._registry:
108118
connector_cls = cls._registry[connector_name]()
109119
else:
@@ -118,6 +128,21 @@ def get_connector_class(
118128
f"Class {connector_name} not found in {connector_module_path}"
119129
) from e
120130
connector_cls = cast(type[KVConnectorBaseType], connector_cls)
131+
if not supports_kw(connector_cls, "kv_cache_config"):
132+
compat_sig = True
133+
logger.warning(
134+
"Connector %s uses deprecated signature with 2 required arguments. "
135+
"Please update to include kv_cache_config as the second argument.",
136+
connector_cls.__name__,
137+
)
138+
return connector_cls, compat_sig
139+
140+
@classmethod
141+
def get_connector_class(
142+
cls, kv_transfer_config: "KVTransferConfig"
143+
) -> type[KVConnectorBaseType]:
144+
"""Get the connector class by name."""
145+
connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
121146
return connector_cls
122147

123148

0 commit comments

Comments
 (0)