
Conversation

@KuntaiDu (Collaborator) commented Sep 25, 2025

Purpose

Refactor of #25363. This PR enables the combination of the hybrid allocator and a KV cache connector in a backward-compatible way.

Test Script



import os

# Set token chunk size to 256
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Enable CPU memory backend
os.environ["LMCACHE_LOCAL_CPU"] = "True"
# Set CPU memory limit to 20GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "20.0"
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
os.environ["LMCACHE_USE_LAYERWISE"] = "True"


from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Configure KV cache transfer to use LMCache
ktc = KVTransferConfig(
    kv_connector="LMCacheConnectorV1",
    kv_role="kv_both",
)

# Initialize LLM with LMCache configuration
# Adjust gpu_memory_utilization based on your GPU memory
llm = LLM(model="google/gemma-3-4b-it",
          kv_transfer_config=ktc,
          max_model_len=75000,
          gpu_memory_utilization=0.18,
          enforce_eager=True)

# Define sampling parameters
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

# Run inference
outputs = llm.generate("hi" * 70000 + "\nhow are you?", sampling_params)
generated_text = outputs[0].outputs[0].text
print(f"Generated text: {generated_text!r}")

# This requires loading the KV cache and will succeed
outputs = llm.generate("hi" * 10000 + "\nTell me a story.", sampling_params)
generated_text = outputs[0].outputs[0].text
print(f"Generated text: {generated_text!r}")

# Flush out the prefix cache on the GPU
outputs = llm.generate("1" + "hi" * 70000 + "\nhow are you?", sampling_params)
generated_text = outputs[0].outputs[0].text
print(f"Generated text: {generated_text!r}")

# This requires loading the KV cache, but the request cannot be executed
# because vLLM cannot allocate GPU blocks for the long prefix stored by LMCache
outputs = llm.generate("hi" * 70000 + "\nTell me a story.", sampling_params)
generated_text = outputs[0].outputs[0].text
print(f"Generated text: {generated_text!r}")

Test Result

Success.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request successfully enables the use of the hybrid allocator with the KV cache connector by removing the explicit restriction and adding the necessary logic to handle multiple KV cache groups. The changes are well-structured, introducing a SupportsHMA interface to check for compatibility. My review focuses on improving code quality and performance. I've identified an opportunity to refactor duplicated code for better maintainability and two instances where an expensive deepcopy operation can be replaced with a more efficient shallow copy, which should improve initialization performance.
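
As a rough illustration of the capability-check idea the review mentions, here is a hedged Python sketch of how a SupportsHMA-style marker interface could gate the hybrid allocator + connector combination. The names SupportsHMA, MyLayerwiseConnector, and ensure_hma_supported are assumptions for the sketch, not the actual vLLM API.

# Hedged sketch (illustrative names, not the actual vLLM code): connectors
# opt in to hybrid-memory-allocator support by inheriting from a marker class.

class SupportsHMA:
    """Marker interface: connectors deriving from this declare that they can
    handle the multiple KV cache groups created by the hybrid allocator."""


class MyLayerwiseConnector(SupportsHMA):
    """Hypothetical connector that opts in to hybrid allocator support."""


def ensure_hma_supported(connector: object) -> None:
    # The capability is expressed through the class hierarchy, so a plain
    # isinstance check keeps old connectors working unchanged.
    if not isinstance(connector, SupportsHMA):
        raise ValueError("connector does not support the hybrid allocator")


ensure_hma_supported(MyLayerwiseConnector())
print("hybrid allocator + KV connector combination allowed")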

@KuntaiDu (Collaborator, Author) commented:

@NickLucche @njhill This is the refactored version of #25363, PTAL.


mergify bot commented Oct 1, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @KuntaiDu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 1, 2025
@KuntaiDu (Collaborator, Author) commented:

Just noticed this:

def _update_requests_with_invalid_blocks():
    ...
    # TODO (davidb): add support for hybrid memory allocator
    (req_block_ids,) = self.kv_cache_manager.get_block_ids(req_id)

from the PR (#19330), which adds logic to retry requests locally on the decode (D) instance if KV fetching from the prefill (P) instance fails.

Got it. I'd prefer to fix it in a future PR, though.
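
For context, a hedged sketch of why the single-element unpacking in that snippet assumes exactly one KV cache group and breaks once the hybrid allocator creates several. The get_block_ids function below is a stand-in, not the real KVCacheManager API.

# Stand-in for KVCacheManager.get_block_ids(): one list of block IDs per
# KV cache group.
def get_block_ids(num_groups: int) -> tuple[list[int], ...]:
    return tuple([2 * i, 2 * i + 1] for i in range(num_groups))


# Works in the non-hybrid case with a single group:
(req_block_ids,) = get_block_ids(num_groups=1)
print(req_block_ids)  # [0, 1]

# Breaks when the hybrid allocator creates several groups
# (e.g. full-attention + sliding-window layers):
try:
    (req_block_ids,) = get_block_ids(num_groups=2)
except ValueError as exc:
    print(f"unpacking fails with multiple groups: {exc}")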


mergify bot commented Oct 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @KuntaiDu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 22, 2025
@mergify mergify bot removed the needs-rebase label Oct 23, 2025
@simon-mo simon-mo disabled auto-merge October 25, 2025 06:34
@simon-mo simon-mo merged commit b853540 into vllm-project:main Oct 25, 2025
50 of 52 checks passed
rohin-garg pushed a commit to rohin-garg/vllm that referenced this pull request Oct 25, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
… KV cache connector (vllm-project#25712)

Signed-off-by: KuntaiDu <[email protected]>
Signed-off-by: Kuntai Du <[email protected]>
Signed-off-by: 0xrushi <[email protected]>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
… KV cache connector (vllm-project#25712)

Signed-off-by: KuntaiDu <[email protected]>
Signed-off-by: Kuntai Du <[email protected]>
Signed-off-by: 0xrushi <[email protected]>
markmc added a commit to markmc/vllm that referenced this pull request Oct 31, 2025
Follow on from vllm-project#25712

`VllmConfig` is explicitly designed as a dataclass containing
user-provided configuration and model metadata. It is a global
configuration object that lives throughout the entire engine lifetime
and is meant to be immutable after `__post_init__()`.

`KVCacheConfig` is worker-specific, runtime-computed state. It has
limited lifetime, and its purpose is limited to initializing the KV
Cache in the model runner.

Even if we add KV cache hints to model config.json in future, this
would be parsed into `ModelConfig`, used as input to the
`get_kv_cache_configs()` computation, and the resulting
`KVCacheConfig` would still be runtime state.

We are currently creating per-worker copies of VllmConfig in order
to attach the runtime `KVCacheConfig` state. But instead we should
just explicitly pass `KVCacheConfig` to the connector.

Make sure to handle backwards compatibility for external connector
implementations (loaded via module path) that have the old style
constructor signature.

Signed-off-by: Mark McLoughlin <[email protected]>
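
For illustration, a hedged sketch of the backward-compatibility approach described in the commit message above: inspect the constructor signature of an externally loaded connector class and only pass the new kv_cache_config argument when the class accepts it. All class and helper names here are assumptions for the sketch, not the actual vLLM factory code.

import inspect


class OldStyleConnector:
    # Legacy external connector: constructor predates kv_cache_config.
    def __init__(self, vllm_config, role):
        self.vllm_config, self.role = vllm_config, role


class NewStyleConnector:
    # Updated connector that receives the runtime KVCacheConfig directly.
    def __init__(self, vllm_config, role, kv_cache_config):
        self.vllm_config, self.role = vllm_config, role
        self.kv_cache_config = kv_cache_config


def create_connector(cls, vllm_config, role, kv_cache_config):
    params = inspect.signature(cls.__init__).parameters
    if "kv_cache_config" in params:
        return cls(vllm_config, role, kv_cache_config=kv_cache_config)
    # Fall back to the legacy call for old-style external connectors.
    return cls(vllm_config, role)


print(type(create_connector(OldStyleConnector, {}, "kv_both", {})).__name__)
print(type(create_connector(NewStyleConnector, {}, "kv_both", {})).__name__)
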
markmc added a commit to markmc/vllm that referenced this pull request Oct 31, 2025
markmc added a commit to markmc/vllm that referenced this pull request Oct 31, 2025
ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025

Labels

kv-connector, ready (ONLY add when PR is ready to merge/full CI is needed), v1


7 participants