[Bugfix][Nixl] Fix kernel physical<>logical block_size issue #28677
Changes from 3 commits
@@ -49,6 +49,7 @@
 from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
 from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.worker.block_table import BlockTable

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -935,6 +936,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             attn_backend=backend,
         )
         self._use_pallas = self.kv_topo._use_pallas
+        self._physical_blocks_per_logical_kv_block = 1

     def _nixl_handshake(
         self,
@@ -1133,6 +1135,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             if base_addr in seen_base_addresses:
                 continue

+            # TODO (NickLucche): Get kernel_block_size in a cleaner way
+            # NHD default "view" for non-MLA cache
+            kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
+            if self.block_size != kernel_block_size:
Contributor:
I have a silly question: when will this scenario happen? What is the max kernel block size for CUDA? Where is it set?

Collaborator (Author):
It's backend-dependent. It happens whenever the supplied block_size is not one of the sizes supported by the backend (see https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/flash_attn.py#L63), in which case the kernel block size is used for the physical tensors and block_size becomes only logical.

Contributor:
Oh, I see.
+                logger.info_once(
+                    "User-specified logical block size (%s) does not match"
+                    " physical kernel block size (%s). Using the latter. ",
+                    self.block_size,
+                    kernel_block_size,
+                )
+                self._physical_blocks_per_logical_kv_block = (
+                    self.block_size // kernel_block_size
+                )
+                self.block_size = kernel_block_size
+
             seen_base_addresses.append(base_addr)
             curr_tensor_size_bytes = cache.numel() * cache.element_size()
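For context on the check above, here is a minimal standalone sketch of the size bookkeeping. The index logic (`shape[-2]` for MLA, `shape[-3]` for the non-MLA NHD view) mirrors the diff, but the concrete cache shape and block sizes are made-up values for illustration only:

```python
import torch

# Made-up shapes: a non-MLA cache viewed as
# (2, num_blocks, kernel_block_size, num_heads, head_dim).
use_mla = False
logical_block_size = 64  # user-specified (logical) block size
cache = torch.empty(2, 128, 16, 8, 64)

# Same selection as in the PR: last-but-one dim for MLA, last-but-two otherwise.
kernel_block_size = cache.shape[-2] if use_mla else cache.shape[-3]

if logical_block_size != kernel_block_size:
    # Each logical block is split into several physical kernel blocks.
    assert logical_block_size % kernel_block_size == 0
    physical_per_logical = logical_block_size // kernel_block_size
    print(kernel_block_size, physical_per_logical)  # 16 4
```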
@@ -1686,6 +1704,12 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         We check for these trnxs to complete in each step().
         """
         for req_id, meta in metadata.reqs_to_recv.items():
+            meta.local_block_ids = self._logical_to_kernel_block_ids(
+                meta.local_block_ids
+            )
+            meta.remote_block_ids = self._logical_to_kernel_block_ids(
+                meta.remote_block_ids
+            )
             remote_engine_id = meta.remote_engine_id
             logger.debug(
                 "start_load_kv for request %s from remote engine %s. "
@@ -1906,6 +1930,23 @@ def _get_block_descs_ids(
         descs_ids = region_ids * num_blocks + block_ids
         return descs_ids.flatten()

+    def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
+        """
+        Convert logical block ids to kernel physical block ids.
+        This is required when the logical block size (the one set by the user)
+        does not match the one required by the attn backend.
+        """
+        if self._physical_blocks_per_logical_kv_block == 1:
+            # Noop when physical and logical block sizes are the same
+            return block_ids
+        block_ids_np = np.array(block_ids)
+        block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
+            1, -1
+        )
+        return BlockTable.map_to_kernel_blocks(
+            block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
+        ).tolist()
+
     def get_backend_aware_kv_block_len(self, layer_idx: int):
         """
         Get the block length for one K/V element (K and V have the same size).
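As a rough illustration of what the helper above computes, here is a standalone sketch of the logical-to-kernel block id expansion. It paraphrases what `BlockTable.map_to_kernel_blocks` is expected to do with plain numpy; treat the expansion formula as an assumption for illustration, not the exact vLLM implementation:

```python
import numpy as np

def logical_to_kernel_block_ids(block_ids: list[int], blocks_per_logical: int) -> list[int]:
    """Expand each logical block id into its consecutive kernel block ids."""
    if blocks_per_logical == 1:
        # Noop when logical and kernel block sizes match.
        return block_ids
    ids = np.asarray(block_ids).reshape(-1, 1)
    offsets = np.arange(blocks_per_logical).reshape(1, -1)
    # Logical block i maps to kernel blocks [i*k, i*k + 1, ..., i*k + k - 1].
    return (ids * blocks_per_logical + offsets).reshape(-1).tolist()

# Logical block size 64 with kernel block size 16 -> 4 kernel blocks per logical block.
print(logical_to_kernel_block_ids([2, 5], 4))
# [8, 9, 10, 11, 20, 21, 22, 23]
```

With this expansion applied to both `meta.local_block_ids` and `meta.remote_block_ids` in `start_load_kv`, the NIXL descriptors are always built in units of the kernel's physical blocks, regardless of the user-specified logical block size.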