Conversation

@xuechendi
Contributor

@xuechendi xuechendi commented Oct 14, 2025

Purpose

To support scenarios where prefill and decode each have their own preferred block_size. Ex: prefill on CUDA (block_size 16) and decode on Intel Gaudi (block_size 128).

More details are described in #26744

Current status:

  • Prefill BlockSize < DecodeBlockSize
    • NHD => accuracy is good
    • HND => accuracy is good
  • Prefill BlockSize > DecodeBlockSize
    • NHD => accuracy is good
    • HND => accuracy is good

CMD for test:

use case 1:

```shell
DECODE_BLOCK_SIZE=64 bash run_accuracy_test.sh
# expected 0.41
```

| Task | Metric | Value |
|---|---|---|
| gsm8k | exact_match,strict-match | 0.45703125 |
| gsm8k | exact_match_stderr,strict-match | 0.03119537 |
| gsm8k | exact_match,flexible-extract | 0.44921875 |
| gsm8k | exact_match_stderr,flexible-extract | 0.031149309 |

Accuracy is Ok.

```shell
DECODER_TP_SIZE=2 DECODE_BLOCK_SIZE=64 bash run_accuracy_test.sh
# expected 0.41
```

| Task | Metric | Value |
|---|---|---|
| gsm8k | exact_match,strict-match | 0.44921875 |
| gsm8k | exact_match_stderr,strict-match | 0.0311493 |
| gsm8k | exact_match,flexible-extract | 0.4375 |
| gsm8k | exact_match_stderr,flexible-extract | 0.031065 |

Accuracy is Ok.

use case 2:

```shell
PREFILL_BLOCK_SIZE=64 bash run_accuracy_test.sh
# expected 0.41
```

| Task | Metric | Value |
|---|---|---|
| gsm8k | exact_match,strict-match | 0.453125 |
| gsm8k | exact_match_stderr,strict-match | 0.031173 |
| gsm8k | exact_match,flexible-extract | 0.44921875 |
| gsm8k | exact_match_stderr,flexible-extract | 0.0311493 |

When setting gpu_memory_utilization=0.8, accuracy is good.
When setting gpu_memory_utilization=0.3, we might use up all block_ids, so we cannot use tail block_ids as a temporary buffer to store prefill blocks before the permute, and accuracy may be slightly impacted => for details, see [Prefill Block Size > Decode Block Size] - 3. get_finished() below.

```shell
DECODER_TP_SIZE=2 PREFILL_BLOCK_SIZE=64 bash run_accuracy_test.sh
# expected 0.41
```

| Task | Metric | Value |
|---|---|---|
| gsm8k | exact_match,strict-match | 0.4609375 |
| gsm8k | exact_match_stderr,strict-match | 0.031215 |
| gsm8k | exact_match,flexible-extract | 0.44921875 |
| gsm8k | exact_match_stderr,flexible-extract | 0.031149 |

Design doc:

Case 1: nP < nD

Prefill block size < decode block size (example: block_size_ratio = 0.25)

  1. add_remote_agent
    1.1 We register the remote address using the remote layout.
    1.2 We create a new local_xfer_handler using the remote block_len, so it can do a one-to-one remote-to-local copy.

```
# remote:       | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
# local origin: |          0|          1|          8|         12|
# local mapped: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
```

  2. read_blocks
    2.1 Remap local_block_ids: block [1, 2, 3, 4] => block [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
    -> In that case, local blocks and remote blocks are size- and count-aligned.

  3. get_finished()
    3.1 For HND, do a permute on the 4-block buffer.

```
# remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
# local is  |h0-b0..................|h1-b0..................|...
# The permute is:
# 1. view    => view remote as n_blocks * remote_shape (H, remoteN, D)
# 2. permute => (H, n_blocks, remoteN, D)
# 3. flatten => (H, localN, D)
```

    3.2 For NHD, no permute is needed.
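The block-id remap in 2.1 and the HND permute in 3.1 can be sketched in plain Python. Note this is illustrative only: `expand_block_ids` is not a name from the PR, and the exact id-offset convention is an assumption.

```python
def expand_block_ids(block_ids, ratio):
    """Map each local block id to `ratio` consecutive remote-sized block ids."""
    return [b * ratio + i for b in block_ids for i in range(ratio)]

# With 4 remote-sized blocks per local block, blocks [1, 2, 3] cover
# remote-sized ids 4..15 (cf. the remap example above).
print(expand_block_ids([1, 2, 3], 4))
# [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

# The HND permute as index bookkeeping: the copied buffer is ordered
# (block, head, token); the local HND layout wants (head, block, token).
H, NB, RN = 2, 4, 3  # heads, copied blocks, remote tokens per block
copied = [(b, h, n) for b in range(NB) for h in range(H) for n in range(RN)]
local = [(b, h, n) for h in range(H) for b in range(NB) for n in range(RN)]
# After the permute, each head's tokens are contiguous across all copied blocks.
print(local[:4])
# [(0, 0, 0), (0, 0, 1), (0, 0, 2), (1, 0, 0)]
```

In the actual connector this permute is done on tensors with view/permute/flatten; the index lists above only demonstrate the element ordering.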

Case 2: nP > nD

For Prefill Block Size > Decode Block Size (example: block_size_ratio = 4)

  1. add_remote_agent

```
# remote: |         0|          1|          8|         12|
# local:  | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
```

  2. read_blocks
    2.1 Remap remote_block_ids: block [0] => block [1, 2, 3, 4]
    -> In that case, local blocks and remote blocks are size- and count-aligned.
    2.2 If len(local_block_ids) < len(after_mapping_remote_block_ids)
    -> allocate block_ids from the end of the block allocator to local_block_ids
    -> For ex: [1, 2, pad, pad] => [1, 2, 12416, 12415] (the reason is in 3.1)

  3. get_finished()
    3.1 For HND, do a permute on the 4-block buffer.
    ///// marks unused tokens hidden in the copied buffer; the permute moves them to the tail.
    Ex: we don't need block 3, but with the larger remote block size it is part of every head, so we must first
    copy the entire buffer, then permute to move block 3 to the tail and discard it.
    That is also why we need a temporary local buffer large enough to store the remote buffer.

```
# remote is  |h0-b0............./////|h1-b0............./////|h2-b0............./////|h3-b0............./////|
# local is   |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|h0-b2|h1-b2|h2-b2|/////|/////|/////|/////|/////|
# 1. view    => view remote as (-1, H, n_blocks, localN, D)
# 2. permute => (-1, n_blocks, H, localN, D)
# 3. flatten => (-1, H, localN, D)
```

    3.2 For NHD, no permute is needed.
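The tail allocation in 2.2 can be sketched as follows. This is a minimal illustration: `pad_from_tail` is not a name from the PR, and the total block count of 12417 is assumed only to reproduce the ids in the example above.

```python
def pad_from_tail(local_block_ids, needed, num_total_blocks):
    """Pad local_block_ids up to `needed` ids, taking ids from the end of
    the block allocator to serve as a temporary buffer for the permute."""
    pad = needed - len(local_block_ids)
    return local_block_ids + [num_total_blocks - 1 - i for i in range(pad)]

# Matches the example above, assuming 12417 total blocks:
print(pad_from_tail([1, 2], 4, 12417))
# [1, 2, 12416, 12415]
```

Taking ids from the tail makes a collision with normally allocated blocks least likely, which matters when gpu_memory_utilization is low and the allocator is nearly exhausted.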

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Chendi Xue <[email protected]>
@mergify mergify bot added the kv-connector label Oct 14, 2025
@mergify

mergify bot commented Oct 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuechendi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 16, 2025
@mergify mergify bot removed the needs-rebase label Oct 21, 2025
@mergify mergify bot added the needs-rebase label Oct 21, 2025
@xuechendi xuechendi force-pushed the dev/nixl_heter_blocksize branch 3 times, most recently from 5431f65 to 9e5b623 Compare October 21, 2025 21:01
Current code only works for NHD

Signed-off-by: Chendi Xue <[email protected]>
@xuechendi xuechendi force-pushed the dev/nixl_heter_blocksize branch 2 times, most recently from f141077 to 0fc409c Compare October 23, 2025 14:50
@xuechendi xuechendi force-pushed the dev/nixl_heter_blocksize branch from 0fc409c to e6e3d92 Compare October 24, 2025 15:43
@xuechendi xuechendi marked this pull request as ready for review October 24, 2025 22:46
@mergify mergify bot removed the needs-rebase label Oct 24, 2025
@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


@mergify mergify bot added the v1 label Oct 30, 2025
@mergify mergify bot added the needs-rebase label Oct 30, 2025
@mergify mergify bot added the needs-rebase label Nov 12, 2025
@mergify mergify bot removed the needs-rebase label Nov 12, 2025
@xuechendi xuechendi force-pushed the dev/nixl_heter_blocksize branch from d7c17ea to 984637d Compare November 13, 2025 00:01
```shell
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-16}
DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-16}
```
Contributor Author

@xuechendi xuechendi Nov 13, 2025

I switched to 128 but then switched back to 16.
I tested with 128 and noticed that, even on origin/main, accuracy is not correct at the moment.
I will see if I can find the root cause in a separate PR.

Contributor Author

@xuechendi xuechendi Nov 13, 2025

I verified that setting block_size=64 works, but somehow using block_size=128 for CUDA produces NaN tensors.
@NickLucche, do you want me to set it to 64? Actually, I prefer to use the current default for CUDA, which is 16.

@xuechendi
Contributor Author

@NickLucche, the comments are mostly resolved; the only open one is about the default block_size, which I explained in another comment. I verified accuracy on both this branch and current main, and it is not correct on either, so I switched back to 16.

Collaborator

@NickLucche NickLucche left a comment

LGTM, thanks for your patience @xuechendi! Left two minor comments but it would be nice if you could address those quickly before we turn auto-merge on.

Comment on lines 699 to 707

```python
assert (
    self.block_size % remote_block_size == 0
    or remote_block_size % self.block_size == 0
), (
    f"Local block size {self.block_size} is not divisible "
    f"by remote block size {remote_block_size} or vice versa."
)
ret = self.block_size / remote_block_size
return ret if ret < 0 else int(ret)
```
Collaborator

I see your point, but I still think it could be confusing to readers.

Suggested change

```diff
-assert (
-    self.block_size % remote_block_size == 0
-    or remote_block_size % self.block_size == 0
-), (
-    f"Local block size {self.block_size} is not divisible "
-    f"by remote block size {remote_block_size} or vice versa."
-)
-ret = self.block_size / remote_block_size
-return ret if ret < 0 else int(ret)
+assert self.block_size % remote_block_size == 0, (
+    f"Local block size {self.block_size} is not divisible "
+    f"by remote block size {remote_block_size}."
+)
+return self.block_size // remote_block_size
```

Happy to add it back if/when we land the opposite case.
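The suggested helper behaves like this minimal standalone sketch, assuming (per the suggestion) that the local block size is a multiple of the remote one; the function name and standalone signature are illustrative, not the PR's actual method:

```python
def block_size_ratio(local_block_size: int, remote_block_size: int) -> int:
    """Return how many remote-sized blocks fit in one local block."""
    assert local_block_size % remote_block_size == 0, (
        f"Local block size {local_block_size} is not divisible "
        f"by remote block size {remote_block_size}."
    )
    return local_block_size // remote_block_size

print(block_size_ratio(128, 16))  # 8
```

Using integer floor division (`//`) avoids the float result of `/` in the original code, whose `ret if ret < 0 else int(ret)` branch could never trigger since a ratio of positive sizes is never negative.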

Signed-off-by: Chendi Xue <[email protected]>
Signed-off-by: Chendi Xue <[email protected]>
@xuechendi xuechendi force-pushed the dev/nixl_heter_blocksize branch from 1c85cfb to a6641c2 Compare November 13, 2025 17:27
@xuechendi
Contributor Author

Hi @NickLucche, thanks for the review. I have fixed the last two comments and tested with the test_nixl_connector.py UT, with small fixes to those tests as well.

Please help to review again

@NickLucche NickLucche added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 14, 2025
@vllm-bot vllm-bot merged commit c9e6658 into vllm-project:main Nov 15, 2025
47 of 49 checks passed
@DarkLight1337
Member

DarkLight1337 commented Nov 15, 2025

Test is failing on main, force merging

geodavic pushed a commit to geodavic/vllm that referenced this pull request Nov 16, 2025
Signed-off-by: Chendi Xue <[email protected]>
Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: Nicolò Lucchesi <[email protected]>
Signed-off-by: George D. Torres <[email protected]>
bwasti pushed a commit to bwasti/vllm that referenced this pull request Nov 17, 2025
Signed-off-by: Chendi Xue <[email protected]>
Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: Nicolò Lucchesi <[email protected]>
Signed-off-by: Bram Wasti <[email protected]>
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
Signed-off-by: Chendi Xue <[email protected]>
Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: Nicolò Lucchesi <[email protected]>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: Chendi Xue <[email protected]>
Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: Nicolò Lucchesi <[email protected]>
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
Signed-off-by: Chendi Xue <[email protected]>
Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: Nicolò Lucchesi <[email protected]>
charlotte12l pushed a commit to charlotte12l/vllm that referenced this pull request Dec 5, 2025
Signed-off-by: Chendi Xue <[email protected]>
Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: Nicolò Lucchesi <[email protected]>
Signed-off-by: Xingyu Liu <[email protected]>
Zhathw pushed a commit to Zhathw/vllm that referenced this pull request Dec 6, 2025
Signed-off-by: Chendi Xue <[email protected]>
Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: Nicolò Lucchesi <[email protected]>