
Conversation

@NickLucche
Collaborator

@NickLucche NickLucche commented Nov 13, 2025

Fix #28661.
PD deployments with block_size=128 on FA (as well as FI) are currently broken on main.

This is pretty bad because 128 is the suggested block_size, as it maximizes the xfer window.
This is due to the changes introduced in #24486, which adjust block_size at runtime based on backend constraints.
This updated value is not picked up by NixlConnector, which ends up using the one defined by the user (to be precise, the updated value is not stored anywhere).

Example

kv_cache.shape = torch.Size([3636, 64, 8, 128])  # physical kernel block size is 64
nixl_worker.block_size = 128  # user-configured (logical) block size

Furthermore, kv_cache_manager.get_block_ids() returns manager (logical) block IDs, but the connector (in particular nixl_connector) needs kernel-space block IDs.
Currently there is no clean, accessible way to map logical to physical block_ids from anywhere in the code. It is also not clear, in terms of the Connector interface contract, whether the blocks supplied by the Scheduler should be logical or already physical.
For the sake of getting this bug fixed, I have implemented the mapping logic within the connector worker.
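For reference, once the kernel block size is known the mapping itself is simple: assuming each logical block is backed by ratio = logical_block_size // kernel_block_size consecutive kernel blocks, a minimal sketch (illustrative names, not the actual connector code) looks like:

def logical_to_kernel_block_ids(
    logical_block_ids: list[int],
    logical_block_size: int,
    kernel_block_size: int,
) -> list[int]:
    """Expand manager (logical) block IDs into kernel (physical) block IDs.

    Assumes the logical block size is a multiple of the kernel block size,
    e.g. logical 128 -> two consecutive physical blocks of 64 each.
    """
    assert logical_block_size % kernel_block_size == 0
    ratio = logical_block_size // kernel_block_size
    return [
        block_id * ratio + offset
        for block_id in logical_block_ids
        for offset in range(ratio)
    ]

# Example: logical blocks [3, 7] with logical block_size 128 over a kernel
# block_size of 64 expand to kernel blocks [6, 7, 14, 15].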

The bug was spotted after #27753 landed, as it changed the default FA backend so that it no longer supports 128.

Test with

Same procedure as described in issue #28661:

bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
...
(EngineCore_DP0 pid=74640) INFO 11-13 17:19:11 [distributed/.../v1/nixl_connector.py:1142] User-specified logical block size (128) does not match physical kernel block size (64). Using the latter.
...
(APIServer pid=497251) INFO:     Application shutdown complete.
+ echo 'All tests completed!'
All tests completed!
++ jobs -pr
+ kill 498809

cc @heheda12345 @njhill @robertgshaw2-redhat

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a critical bug where the NixlConnector fails to use the correct physical kernel block_size, particularly affecting deployments with block_size=128. The changes correctly detect the kernel block size at runtime and update the connector's state. It also introduces a necessary mapping from logical to physical block IDs to ensure correct block access.

My review identifies a critical issue in the logic used to determine the kernel_block_size. The current implementation is fragile and may fail for several attention backends, potentially leading to memory corruption. I've provided a more robust code suggestion to handle different KV cache layouts correctly. The rest of the changes for logical-to-physical block mapping appear to be well-implemented.
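To illustrate the layout concern (this is only a sketch of the failure mode the review points at, not the suggestion it refers to; the layout names and assumed shapes are illustrative):

def infer_kernel_block_size(
    cache_shape: tuple[int, ...], kv_layout: str, use_mla: bool
) -> int:
    """Pick the block dimension based on the declared KV cache layout.

    Assumed (illustrative) shapes:
      MLA: (..., block_size, head_dim)
      NHD: (..., block_size, num_heads, head_dim)
      HND: (..., num_heads, block_size, head_dim)
    """
    if use_mla or kv_layout == "HND":
        return cache_shape[-2]
    if kv_layout == "NHD":
        return cache_shape[-3]
    raise ValueError(f"Unsupported KV cache layout: {kv_layout}")

Reading a hard-coded shape index only holds while every non-MLA backend keeps the same view, which is what the review flags as fragile.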

@NickLucche NickLucche changed the title [Bugfix][Nixl] Fix physical kernel block_size issue [Bugfix][Nixl] Fix kernel physical<>logical block_size issue Nov 13, 2025

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

# NHD default "view" for non-MLA cache
kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]

if self.block_size != kernel_block_size:
Contributor


I have a silly question: when will this scenario happen? What is the max kernel block size for CUDA, and where is it set?
if self.block_size != kernel_block_size
@jikunshang, could you check?

Collaborator Author


It's backend-dependent; it happens every time the supplied block_size is not one of https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/flash_attn.py#L63 (so the kernel one is used for the physical tensors and block_size becomes only logical).
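Roughly, the runtime adjustment means the physical cache is allocated with a backend-supported kernel block size, while the configured value survives only as a logical block size. A hedged sketch of that selection (the supported sizes and the largest-divisor policy below are assumptions for illustration, not the exact logic from #24486):

# Illustrative only: the supported sizes vary per backend/build.
SUPPORTED_KERNEL_BLOCK_SIZES = (16, 32, 64)

def pick_kernel_block_size(user_block_size: int) -> int:
    """Return a supported kernel block size for the physical KV cache.

    Prefers the largest supported size that evenly divides the user-facing
    (logical) block size, so each logical block maps onto whole kernel blocks.
    """
    candidates = [
        size for size in SUPPORTED_KERNEL_BLOCK_SIZES if user_block_size % size == 0
    ]
    if not candidates:
        raise ValueError(f"No supported kernel block size divides {user_block_size}")
    return max(candidates)

# e.g. user block_size=128 -> kernel block_size=64, matching the shapes above.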

Contributor


Oh, I see

Contributor

@xuechendi xuechendi left a comment


Thanks for the PR, looks good to me

@mergify mergify bot added the v1 label Nov 14, 2025
@NickLucche
Collaborator Author

@codex review

@NickLucche
Collaborator Author

NickLucche commented Nov 14, 2025

@xuechendi I moved up the logical<>physical conversion so that permutation works as intended. Tested with DECODER_KV_LAYOUT=NHD bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh at block_size=128

@NickLucche NickLucche added the ready (ONLY add when PR is ready to merge/full CI is needed) label Nov 14, 2025

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Signed-off-by: NickLucche <[email protected]>
@NickLucche
Collaborator Author

@codex review


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

def save_kv_to_host(self, metadata: NixlConnectorMetadata):
    """copy kv from device to host buffer."""
    assert self.use_host_buffer
    assert self.copy_blocks is not None
    for req_id, meta in metadata.reqs_to_save.items():
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "save_load_kv for request[%s] to host xfer buffer."
                "local_block_ids: %s. ",
                req_id,
                ",".join(map(str, meta.local_block_ids)),
            )
        # blocking
        self.copy_blocks(
            self.device_kv_caches,
            self.host_xfer_buffers,
            meta.local_block_ids,
            meta.local_block_ids,
            "d2h",
        )

P1 Badge Convert logical block IDs before saving to host buffer

The new _logical_to_kernel_block_ids conversion is applied when pulling remote blocks, but save_kv_to_host still uses the unconverted meta.local_block_ids when copying device KV into the host staging buffer. When the runtime shrinks the kernel block size (e.g. logical 128 → physical 64), those IDs are still logical, so the host buffer ends up containing different physical blocks than the ones requested later by _read_blocks, which now uses converted indices. Any deployment that relies on use_host_buffer will therefore transfer mismatched blocks after this change. The host-buffer path should run the same logical→physical mapping before calling copy_blocks.


@NickLucche
Collaborator Author

Further separated logical<>physical blocks in ReqMeta so that failures can still be reported using logical blocks ("user/engine-space" rather than "kernel-space"). cc @wseaton

Also, tested hd2dh use-case with:
KV_BUFFER_DEVICE=cpu bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
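A hedged sketch of what carrying both ID spaces in the request metadata might look like (field names are illustrative, not the actual ReqMeta definition):

from dataclasses import dataclass

@dataclass
class ReqBlockMeta:
    """Illustrative request metadata keeping both block-ID spaces.

    Logical IDs are what the engine/user sees and what failures should be
    reported with; kernel IDs are what the NIXL transfer descriptors use.
    """

    req_id: str
    logical_block_ids: list[int]
    kernel_block_ids: list[int]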

@NickLucche
Collaborator Author

@codex review

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. Swish!


@DarkLight1337 DarkLight1337 merged commit 96b23b8 into vllm-project:main Nov 14, 2025
49 checks passed
@njhill njhill added this to the v0.11.1 milestone Nov 14, 2025
geodavic pushed a commit to geodavic/vllm that referenced this pull request Nov 16, 2025
khluu pushed a commit that referenced this pull request Nov 16, 2025
@nvpohanh
Contributor

cc @nvjullin

bwasti pushed a commit to bwasti/vllm that referenced this pull request Nov 17, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025

Labels

kv-connector, ready (ONLY add when PR is ready to merge/full CI is needed), v1

Development

Successfully merging this pull request may close these issues.

[Bug]: NIXL run_accuracy_test.sh is broken for block_size=128

6 participants