Skip to content

Commit 7157b98

Browse files
tlrmchlsmth and njhill
authored and committed
[Bugfix] Fix 2 precommit issues - (mamba_block_size, kv_cache_config) (vllm-project#27811)
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 797bb1f commit 7157b98

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

vllm/model_executor/models/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
406406
# easily by changing the way we layout chunks in the
407407
# mamba2 kernels.
408408

409-
base_chunk_size = model_config.get_mamba_chunk_size()
409+
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
410410
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
411411
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
412412
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)

vllm/v1/core/sched/scheduler.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from vllm.distributed.kv_transfer.kv_connector.v1 import (
1414
KVConnectorBase_V1,
1515
KVConnectorRole,
16-
supports_hma,
16+
SupportsHMA,
1717
)
1818
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
1919
from vllm.logger import init_logger
@@ -93,7 +93,11 @@ def __init__(
9393
)
9494

9595
connector_vllm_config = copy.copy(self.vllm_config)
96-
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
96+
97+
# We're dynamically inserting a kv_cache_config variable into the
98+
# connector_vllm_config. This is distinct from the cache_config
99+
# that is already in there.
100+
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) # type: ignore[attr-defined]
97101
self.connector = KVConnectorFactory.create_connector(
98102
config=connector_vllm_config, role=KVConnectorRole.SCHEDULER
99103
)
@@ -1327,15 +1331,15 @@ def _connector_finished(
13271331

13281332
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
13291333

1330-
if not supports_hma(self.connector):
1334+
if not isinstance(self.connector, SupportsHMA):
13311335
# NOTE(Kuntai): We should deprecate this code path after we enforce
13321336
# all connectors to support HMA.
13331337
# Hybrid memory allocator should be already turned off for this
13341338
# code path, but let's double-check here.
13351339
assert len(self.kv_cache_config.kv_cache_groups) == 1
13361340
return self.connector.request_finished(request, block_ids[0])
1337-
else:
1338-
return self.connector.request_finished(request, block_ids)
1341+
1342+
return self.connector.request_finished_all_groups(request, block_ids)
13391343

13401344
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
13411345
"""

0 commit comments

Comments
 (0)