
Conversation

@ivanium
Contributor

@ivanium ivanium commented Dec 6, 2025

Purpose

This is a continuation of the work in PR #23624 to support the hybrid KV cache manager together with a KV cache connector.

Design doc with details drafted by @KuntaiDu: link

In short, the current hybrid KV cache manager allocates KV cache for all tokens of sliding window layers, just as it does for full attention layers, and then, in the next scheduling step, frees the unneeded tokens (those outside the sliding window) and turns them into prefix cache in GPU memory. This PR instead allocates KV cache only for the tokens inside the sliding window for sliding window layers; a rough sketch of the block-count difference follows the list below. This addresses two issues:

  1. When used with an external KV cache layer (e.g., LMCache), over-allocating all prefix tokens for sliding window layers incurs high memory pressure and can fail when the remaining GPU memory is insufficient;
  2. When used with P/D disaggregation connectors, the allocate-then-free pattern can cause data contention: the connector may still be copying some KV cache blocks for one request in the background while the manager frees and reuses them for another request.
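
A rough sketch of the per-layer block-count difference (illustrative only; cdiv, the window handling, and the function names below are simplifications, not vLLM's actual allocator code):

# Illustrative sketch, not vLLM's allocator: per-layer block counts for a
# sliding-window layer under the old vs. new allocation strategy.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def blocks_old(num_tokens: int, block_size: int) -> int:
    # Old behavior: allocate blocks for every prefix token, then free the
    # out-of-window ones in the next scheduling step.
    return cdiv(num_tokens, block_size)


def blocks_new(num_tokens: int, sliding_window: int, block_size: int) -> int:
    # New behavior: only tokens inside the sliding window get real blocks;
    # the leading, skipped part of the prefix is padded with null blocks.
    return cdiv(min(num_tokens, sliding_window), block_size)


# e.g. a 70k-token prompt, an assumed 1024-token window, and block size 16:
# blocks_old(70_000, 16) == 4375 vs. blocks_new(70_000, 1024, 16) == 64 per layer.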

This PR currently supports only the LMCache connector. Support for other connectors will be added in follow-up PRs.

cc @KuntaiDu @heheda12345

Test Plan

The test script is a modified version of the one in PR #25712.

The script should be run with LMCache-side support: LMCache/LMCache#1436.

Caution

Please apply the following patch to LMCache if you get import errors for cdiv:

Patch

diff --git a/lmcache/integration/vllm/vllm_v1_adapter.py b/lmcache/integration/vllm/vllm_v1_adapter.py
index a849097..4db64df 100644
--- a/lmcache/integration/vllm/vllm_v1_adapter.py
+++ b/lmcache/integration/vllm/vllm_v1_adapter.py
@@ -18,7 +18,10 @@ from vllm.distributed.parallel_state import (
     get_tp_group,
 )
 from vllm.sampling_params import SamplingParams
-from vllm.utils import cdiv
+try:
+    from vllm.utils import cdiv
+except ImportError:
+    from vllm.utils.math_utils import cdiv

To run this script on an H100, save the following code as test_connector_w_hybrid_kv_allocator.py and run python test_connector_w_hybrid_kv_allocator.py.

`test_connector_w_hybrid_kv_allocator.py`

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

# Set token chunk size to 256
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Enable CPU memory backend
os.environ["LMCACHE_LOCAL_CPU"] = "True"
# Set CPU memory limit to 20GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "20.0"
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
os.environ["LMCACHE_USE_LAYERWISE"] = "True"


from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Configure KV cache transfer to use LMCache
ktc = KVTransferConfig(
    kv_connector="LMCacheConnectorV1",
    kv_role="kv_both",
)

# Initialize LLM with LMCache configuration
# Adjust gpu_memory_utilization based on your GPU memory
# Parameters below are for 80GB GPUs
llm = LLM(
    model="google/gemma-3-4b-it",
    kv_transfer_config=ktc,
    max_model_len=75000,
    gpu_memory_utilization=0.28,
    # gpu_memory_utilization=0.4,
    # gpu_memory_utilization=0.8,
    max_num_seqs=16,
    enforce_eager=True,
)

# Define sampling parameters
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

# Run inference
print("Generate request 1. This will store long prefix in LMCache.")
outputs = llm.generate("hi" * 70000 + "\nhow are you?", sampling_params)
generated_text = outputs[0].outputs[0].text
print(f"Generated text: {generated_text!r}")

# This requires loading KV cache and will succeed
print("Generate request 2. This will load prefix from LMCache and succeed.")
outputs = llm.generate("hi" * 10000 + "\nTell me a story.", sampling_params)
generated_text = outputs[0].outputs[0].text
print(f"Generated text: {generated_text!r}")

# flush out prefix cache in GPU
print("Generate request 3. This will evict prefix cache in GPU.")
outputs = llm.generate("1" + "hi" * 70000 + "\nhow are you?", sampling_params)
generated_text = outputs[0].outputs[0].text
print(f"Generated text: {generated_text!r}")

# This requires loading KV cache
# but this request cannot be executed as vLLM cannot allocate for long prefix
# stored by LMCache
print("Generate request 4. This will attempt to load long prefix from LMCache.")
outputs = llm.generate("hi" * 70000 + "\nTell me a story.", sampling_params)
generated_text = outputs[0].outputs[0].text
print(f"Generated text: {generated_text!r}")

print("All requests finished.")

Test Result

Previously, KV cache could not be allocated for the last request (Reqid 3, "request 4" in the script), which needs to allocate a long prefix and load external KV cache even for sliding window layers. With this PR, that request allocates only the KV cache needed for the sliding window layers, so it can be scheduled and finishes with correct results.

Detailed output

Generate request 1. This will store long prefix in LMCache.
Adding requests: 100%|██████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.21it/s]
Processed prompts:   0%|                                    | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s][2025-12-05 16:50:01,689] LMCache INFO: Reqid: 0, Total tokens 70006, LMCache hit tokens: 0, need to load: 0 (vllm_v1_adapter.py:1262:lmcache.integration.vllm.vllm_v1_adapter)
[2025-12-05 16:50:01,713] LMCache INFO: Post-initializing LMCacheEngine (cache_engine.py:176:lmcache.v1.cache_engine)
[2025-12-05 16:50:01,748] LMCache INFO: Storing KV cache for 16384 out of 16384 tokens (skip_leading_tokens=0) for request 0 (vllm_v1_adapter.py:1059:lmcache.integration.vllm.vllm_v1_adapter)
[2025-12-05 16:50:01,754] LMCache INFO: Lazily initializing GPU buffer. (gpu_connector.py:1075:lmcache.v1.gpu_connector)
[2025-12-05 16:50:01,754] LMCache INFO: Lazily initializing GPU buffer (max tokens=355120). (gpu_connector.py:1098:lmcache.v1.gpu_connector)
[2025-12-05 16:50:03,694] LMCache INFO: Storing KV cache for 16384 out of 32768 tokens (skip_leading_tokens=16384) for request 0 (vllm_v1_adapter.py:1059:lmcache.integration.vllm.vllm_v1_adapter)
[2025-12-05 16:50:07,326] LMCache INFO: Storing KV cache for 16384 out of 49152 tokens (skip_leading_tokens=32768) for request 0 (vllm_v1_adapter.py:1059:lmcache.integration.vllm.vllm_v1_adapter)
[2025-12-05 16:50:12,649] LMCache INFO: Storing KV cache for 16384 out of 65536 tokens (skip_leading_tokens=49152) for request 0 (vllm_v1_adapter.py:1059:lmcache.integration.vllm.vllm_v1_adapter)
[2025-12-05 16:50:19,642] LMCache INFO: Storing KV cache for 4470 out of 70006 tokens (skip_leading_tokens=65536) for request 0 (vllm_v1_adapter.py:1059:lmcache.integration.vllm.vllm_v1_adapter)
[rank0]:W1205 16:50:21.852000 3690043 .venv/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
[rank0]:W1205 16:50:21.852000 3690043 .venv/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py:1358] [0/8]    function: 'forward_static' (/data/yifanqiao/code/vllm/vllm/model_executor/layers/layernorm.py:274)
[rank0]:W1205 16:50:21.852000 3690043 .venv/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py:1358] [0/8]    last reason: 0/7: expected type of 'residual' to be a tensor type, ' but found <class 'NoneType'>
[rank0]:W1205 16:50:21.852000 3690043 .venv/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[rank0]:W1205 16:50:21.852000 3690043 .venv/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
Processed prompts: 100%|█████████████████████████| 1/1 [00:20<00:00, 20.59s/it, est. speed input: 3400.30 toks/s, output: 0.49 toks/s]
Generated text: '\nI am doing well, thank you for asking'
Generate request 2. This will load prefix from LMCache and succeed.
Adding requests: 100%|█████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 104.00it/s]
Processed prompts:   0%|                                    | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s][2025-12-05 16:50:22,290] LMCache INFO: Reqid: 1, Total tokens 10007, LMCache hit tokens: 9984, need to load: 9984 (vllm_v1_adapter.py:1262:lmcache.integration.vllm.vllm_v1_adapter)
[2025-12-05 16:50:22,361] LMCache INFO: Retrieved 9984 out of 9984 out of total 9984 tokens (cache_engine.py:645:lmcache.v1.cache_engine)
[2025-12-05 16:50:22,361] LMCache INFO: Retrieved 9984 tokens (vllm_v1_adapter.py:978:lmcache.integration.vllm.vllm_v1_adapter)
Processed prompts: 100%|███████████████████████| 1/1 [00:00<00:00,  3.21it/s, est. speed input: 32128.69 toks/s, output: 32.11 toks/s]
Generated text: "\nOkay, here's a story for you"
Generate request 3. This will evict prefix cache in GPU.
Adding requests: 100%|██████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 15.21it/s]
Processed prompts:   0%|                                    | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s][2025-12-05 16:50:22,665] LMCache INFO: Reqid: 2, Total tokens 70007, LMCache hit tokens: 0, need to load: 0 (vllm_v1_adapter.py:1262:lmcache.integration.vllm.vllm_v1_adapter)
[2025-12-05 16:50:22,707] LMCache INFO: Storing KV cache for 16384 out of 16384 tokens (skip_leading_tokens=0) for request 2 (vllm_v1_adapter.py:1059:lmcache.integration.vllm.vllm_v1_adapter)
[2025-12-05 16:50:24,647] LMCache INFO: Storing KV cache for 16384 out of 32768 tokens (skip_leading_tokens=16384) for request 2 (vllm_v1_adapter.py:1059:lmcache.integration.vllm.vllm_v1_adapter)
[2025-12-05 16:50:28,280] LMCache INFO: Storing KV cache for 16384 out of 49152 tokens (skip_leading_tokens=32768) for request 2 (vllm_v1_adapter.py:1059:lmcache.integration.vllm.vllm_v1_adapter)
[2025-12-05 16:50:33,595] LMCache INFO: Storing KV cache for 16384 out of 65536 tokens (skip_leading_tokens=49152) for request 2 (vllm_v1_adapter.py:1059:lmcache.integration.vllm.vllm_v1_adapter)
[2025-12-05 16:50:40,588] LMCache INFO: Storing KV cache for 4471 out of 70007 tokens (skip_leading_tokens=65536) for request 2 (vllm_v1_adapter.py:1059:lmcache.integration.vllm.vllm_v1_adapter)
Processed prompts: 100%|█████████████████████████| 1/1 [00:20<00:00, 20.54s/it, est. speed input: 3408.18 toks/s, output: 0.49 toks/s]
Generated text: '\n\nI am doing well, thank you for asking'
Generate request 4. This will attempt to load long prefix from LMCache.
Adding requests: 100%|██████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.73it/s]
Processed prompts:   0%|                                    | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s][2025-12-05 16:50:43,298] LMCache INFO: Reqid: 3, Total tokens 70007, LMCache hit tokens: 69888, need to load: 69888 (vllm_v1_adapter.py:1262:lmcache.integration.vllm.vllm_v1_adapter)
[2025-12-05 16:50:43,530] LMCache INFO: Retrieved 69888 out of 69888 out of total 69888 tokens (cache_engine.py:645:lmcache.v1.cache_engine)
[2025-12-05 16:50:43,530] LMCache INFO: Retrieved 69888 tokens (vllm_v1_adapter.py:978:lmcache.integration.vllm.vllm_v1_adapter)
Processed prompts: 100%|██████████████████████| 1/1 [00:00<00:00,  1.47it/s, est. speed input: 102891.47 toks/s, output: 14.70 toks/s]
Generated text: '\nOkay, here’s a story for you'
All requests finished.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added v1 tpu Related to Google TPUs kv-connector labels Dec 6, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for using the hybrid KV cache allocator with a KV cache connector, which is a significant enhancement for models with sliding window attention. The goal is to reduce memory pressure and prevent data contention by allocating KV cache blocks more precisely. The changes are extensive, modifying the core allocation logic in SingleTypeKVCacheManager and propagating these changes up to the KVCacheCoordinator and Scheduler. While the overall approach is sound, the implementation contains several temporary workarounds and comments marked as "REMOVE BEFORE MERGE", which are critical to address. I've identified issues in the KV connector factory, the LMCache connector implementation, and potential bugs or data correctness concerns in single_type_kv_cache_manager.py and block_pool.py. These must be resolved to ensure the stability and correctness of the new functionality.

Comment on lines +59 to 65
## REMOVE BEFORE MERGE (YIFAN): Revert this warning back to raising
# an ValueError.
logger.warning(
"Connector %s does not support HMA but HMA is enabled. Please set "
"--disable-hybrid-kv-cache-manager to disable HMA.",
connector_cls.__name__,
)
Contributor

critical

This change from raising a ValueError to a logger.warning is marked with a "REMOVE BEFORE MERGE" comment. Using a connector that does not support Hybrid Memory Allocation (HMA) when HMA is enabled can lead to incorrect behavior or hard-to-debug runtime errors. It is much safer to fail fast with an exception. This change should be reverted to raise ValueError before merging to prevent potential issues in production.

            raise ValueError(
                f"Connector {connector_cls.__name__} does not support HMA but "
                "HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`."
            )

Comment on lines 37 to 39
## REMOVE BEFORE MERGE (YIFAN): this is temporary workaround to work with
# LMCache. Remove this once having LMCache-side support for new interfaces.
vllm_config.kv_cache_config = kv_cache_config
Contributor

critical

This block contains a "REMOVE BEFORE MERGE" comment, indicating a temporary workaround. Directly modifying vllm_config by assigning to kv_cache_config is a side effect that can lead to unexpected behavior elsewhere in the system. This workaround should be removed, and a proper solution that avoids mutating the config object should be implemented as noted in the comment.

Comment on lines +224 to 234
## REMOVE BEFORE MERGE (YIFAN): this is temporary workaround to work with
# LMCache. Remove this once having LMCache-side support for new interfaces.
def request_finished_all_groups(
    self,
    request: "Request",
    block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
    # NOTE: LMCache overloads request_finished so `block_ids` here can be
    # either list[int] or tuple[list[int], ...]. This could be changed in
    # the future to separate these two methods.
    return self._lmcache_engine.request_finished(request, block_ids)
Contributor

critical

The request_finished_all_groups method is marked as a temporary workaround with a "REMOVE BEFORE MERGE" comment. It appears to be a shim for a new interface required by the hybrid allocator. This temporary implementation should be replaced with a proper solution, and the dependency on this fix in LMCache should be resolved before this pull request is merged.

Comment on lines +166 to +168
# REMOVE BEFORE MERGE (YIFAN): why len(new_computed_blocks)
# rather than len(req_blocks)?
self.num_cached_block[request_id] = len(new_computed_blocks)
Contributor

critical

This line is flagged with a "REMOVE BEFORE MERGE" comment that questions the logic. Setting self.num_cached_block[request_id] to len(new_computed_blocks) seems incorrect, as it doesn't account for previously existing blocks for the request. This could lead to an incorrect count of cached blocks, potentially causing issues in subsequent caching logic. It should likely be set to len(req_blocks) to reflect the total number of blocks for the request. Please clarify or fix this.

Suggested change
# REMOVE BEFORE MERGE (YIFAN): why len(new_computed_blocks)
# rather than len(req_blocks)?
self.num_cached_block[request_id] = len(new_computed_blocks)
self.num_cached_block[request_id] = len(req_blocks)

Comment on lines +294 to +295
## TODO(Yifan): here token_ids may be over-estimated for
## sliding window layers
Contributor

high

The TODO comment indicates that token_ids might be over-estimated for sliding window layers. This could lead to incorrect data in BlockStored events, which could be problematic for external systems consuming these events. If external systems rely on exact token IDs for correctness, this over-estimation could be a significant issue. This should be addressed to ensure data integrity for event consumers.

@mergify

mergify bot commented Dec 6, 2025

Hi @ivanium, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

ivanium and others added 9 commits December 5, 2025 22:03
Signed-off-by: Yifan Qiao <[email protected]>
Co-authored-by: KuntaiDu <[email protected]>
…indow, and leading padding with null blocks

Signed-off-by: Yifan Qiao <[email protected]>

fixes

Signed-off-by: Yifan Qiao <[email protected]>

fix get_num_blocks_to_allocate

Signed-off-by: Yifan Qiao <[email protected]>
Signed-off-by: Yifan Qiao <[email protected]>
…ll blocks inside the single_type_block_manager

Signed-off-by: Yifan Qiao <[email protected]>
Signed-off-by: Yifan Qiao <[email protected]>
Signed-off-by: Yifan Qiao <[email protected]>
Signed-off-by: Yifan Qiao <[email protected]>
Signed-off-by: Yifan Qiao <[email protected]>
@ivanium ivanium force-pushed the feat/partial_ext_token_hit branch from 223fb4d to fa53140 Compare December 6, 2025 06:03
@KuntaiDu
Collaborator

KuntaiDu commented Dec 7, 2025

Good work! In terms of landing this PR, @heheda12345 previously suggested that I split the work into smaller PRs, and I would prefer the same for this one.

Example:
PR 1: don't change the allocation logic at all; simply introduce num_connector_tokens into the allocation API suite and change the functions accordingly.
PR 2: build abstractions (for example, get_num_skipped_tokens).
PR 3: make the estimation of the number of blocks accurate.
PR 4: change the allocation logic.

Comment on lines +257 to +267
# Some blocks may be null blocks when enabling sparse attention or sliding
# window attention. For now, we only have sliding window attention, and
# null blocks must be at the beginning.
first_non_null_blk_idx = 0
for i, blk in enumerate(new_full_blocks):
    if not blk.is_null:
        first_non_null_blk_idx = i
        break

for i, blk in enumerate(new_full_blocks[first_non_null_blk_idx:]):
    assert not blk.is_null
Collaborator

Suggested change
# Some blocks may be null blocks when enabling sparse attention or sliding
# window attention. For now, we only have sliding window attention, and
# null blocks must be at the beginning.
first_non_null_blk_idx = 0
for i, blk in enumerate(new_full_blocks):
    if not blk.is_null:
        first_non_null_blk_idx = i
        break
for i, blk in enumerate(new_full_blocks[first_non_null_blk_idx:]):
    assert not blk.is_null
for i, blk in enumerate(new_full_blocks[first_non_null_blk_idx:]):
    if blk.is_null:
        continue

what about this?

BlockStored(
block_hashes=new_hashes,
parent_block_hash=parent_block_hash,
## TODO(Yifan): here token_ids may be over-estimated for
Collaborator

kv cache event + hybrid allocator is not supported now.

num_evictable_blocks_to_allocate.append(
num_evictable_blocks_to_allocate_single_group
)
return num_new_blocks_to_allocate, num_evictable_blocks_to_allocate
Collaborator

why do we need num_evictable_blocks_to_allocate here?

Collaborator

Same question here.
I guess num_evictable_blocks_to_allocate refers to the blocks that have ref_cnt == 0 and num_new_blocks_to_allocate refers to unallocated blocks, but I'm not sure whether the logic for processing them should be different.

Contributor Author

We need num_new_blocks_to_allocate for the actual allocation, but we also need num_new_blocks_to_allocate + num_evictable_blocks_to_allocate to check whether the remaining available memory is sufficient to allocate for this request. So num_evictable_blocks_to_allocate is returned so that we can compute the total additional memory needed to admit this request.
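
To make this concrete, here is a minimal sketch of that admission check (hypothetical names; the free-block accounting is an assumption, not vLLM's exact implementation):

def has_enough_blocks(
    num_new_blocks_to_allocate: int,
    num_evictable_blocks_to_allocate: int,
    num_free_blocks: int,
) -> bool:
    # New blocks consume free memory directly. Evictable cached blocks
    # (ref_cnt == 0) need no fresh memory, but once this request touches
    # them they can no longer be evicted to make room for others, so they
    # count against the free pool as well.
    return (
        num_new_blocks_to_allocate + num_evictable_blocks_to_allocate
        <= num_free_blocks
    )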

# it as needed to be allocated.
num_evictable_computed_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks
num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
Collaborator

do we have to call this function in every step?

num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)

# Fast-path: nothing is skipped.
if num_skipped_tokens <= 0:
Collaborator

something like:

num_evictable_computed_blocks = xxxx
max(num_evictable_computed_blocks + (num_computed_token - len(new_computed_blocks) / block_size, (num_computed_tokens-num_skipped_tokens)/block_size)

---------------------------------------------------------------------
| < to be allocated > |
---------------------------------------------------------------------
| < to be cached > |
Collaborator

new is not to be cached. It includes unverified spec decode tokens

---------------------------------------------------------------------
| < to be cached > |
---------------------------------------------------------------------
| Prefix-cached tokens from both vLLM |
Collaborator

ext are not cached yet. They will be cached after all tokens are loaded.

Comment on lines +259 to +265
---------------------------------------------------------------------
| ref_cnt |
| increased|
---------------------------------------------------------------------
| ref_cnt not |
| increased yet|
---------------------------------------------------------------------
Collaborator

merge to one line?

Comment on lines +252 to +259
---------------------------------------------------------------------
| not cached by |
| vLLM, but |
| cached by |
| connector |
---------------------------------------------------------------------
| < cached by vLLM > |
---------------------------------------------------------------------
Collaborator

merge to one line?

num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
(
num_new_blocks_to_allocate,
num_evictable_blocks_to_allocate,
Collaborator

I think we shouldn't expand the concept of num_evictable_blocks_to_allocate to more places

Comment on lines +147 to +148
num_blocks_to_allocate_per_group: list[int],
num_tokens: int,
Collaborator

I think it will be difficult for people to understand why these two are different.

if num_null_blocks_to_pad > 0:
req_blocks.extend([self._null_block] * num_null_blocks_to_pad)
# Add the remaining computed blocks.
req_blocks.extend(new_computed_blocks[num_null_blocks_to_pad:])
Collaborator

note that blocks in new_computed_blocks are already touched.
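
For context, a rough sketch of what "touch" does to a cached block (illustrative field names, not vLLM's exact block-pool code), which is why re-touching the already-touched new_computed_blocks would double-count them:

class Block:
    def __init__(self) -> None:
        self.ref_cnt = 0  # number of requests currently using this block


def touch(block: Block, free_queue: list[Block]) -> None:
    # A cached block with ref_cnt == 0 sits in the free queue and is
    # evictable; touching it removes it from that queue (pinning it for
    # the request) and bumps its reference count.
    if block.ref_cnt == 0 and block in free_queue:
        free_queue.remove(block)
    block.ref_cnt += 1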

Comment on lines +216 to +221
# Only allocate real new blocks; cached hits should already be present
# in req_blocks via save_new_computed_blocks.
num_blocks_to_padding = num_new_blocks - num_blocks_to_allocate
assert num_blocks_to_padding >= 0, (
f"Invalid padding: need {num_new_blocks}, allocate {num_blocks_to_allocate}"
)
Collaborator

I don't think we need padding in this function. You can add the padding to either save_new_computed_blocks (if allocating for local computed and external computed in one shot) or allocate_for_connector (if in two shots).

else:
# A running request. Should not have new computed blocks.
assert len(new_computed_blocks) == 0
num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
Collaborator

what about changing this function to handle both local and external computed:

total_computed_tokens = num_local_computed_tokens + num_external_computed_tokens
if request_id not in self.num_cached_block:
    # A running request. Should not have new computed blocks.
    assert len(new_computed_blocks) == 0
    return
req_blocks = self.req_to_blocks[request_id]
assert len(req_blocks) == 0
num_skipped_blocks = self.get_num_skipped_tokens(total_computed_tokens) // block_size
if num_skipped_blocks > 0:
    # sparse like sliding window
    req_blocks.extend([null_block] * num_skipped_blocks)
    new_computed_blocks = new_computed_blocks[num_skipped_blocks:]
self.block_pool.touch(new_computed_blocks)
req_blocks.extend(new_computed_blocks)
self.num_cached_block[request_id] = len(req_blocks)
if num_external_computed_tokens > 0:
    # happens when an external connector is used
    req_blocks.extend(block_pool.get_new_blocks(cdiv(total_computed_tokens, block_size) - len(req_blocks)))

Comment on lines +249 to +250
| Prefix-cached tokens from both vLLM |
| and connector. Can be safely removed if |
Collaborator

Suggested change
| Prefix-cached tokens from both vLLM |
| and connector. Can be safely removed if |
| Prefix-cached tokens from either vLLM |
| or connector. Can be safely removed if |

Comment on lines +295 to +296
and num_lookahead_tokens == 0
and num_external_computed_tokens == 0
Collaborator

still assert num_new_tokens == 0

blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks
num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)

# Fast-path: nothing is skipped.
Collaborator

fast path for decode
