[Bugfix][Nixl] Fix kernel physical<>logical block_size issue #28677
```diff
@@ -49,6 +49,7 @@
 from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
 from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.worker.block_table import BlockTable

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
```
```diff
@@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
 @dataclass
 class ReqMeta:
     local_block_ids: list[int]
+    # To be used when logical block size does not match the kernel block size
+    local_physical_block_ids: list[int]
     remote_block_ids: list[int]
     remote_host: str
     remote_port: int
```
```diff
@@ -139,6 +142,7 @@ def add_new_req(
         assert load_remote_cache ^ save_to_host
         _req = ReqMeta(
             local_block_ids=local_block_ids,
+            local_physical_block_ids=local_block_ids,
             remote_block_ids=kv_transfer_params["remote_block_ids"],
             remote_engine_id=kv_transfer_params["remote_engine_id"],
             remote_host=kv_transfer_params["remote_host"],
```
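The scheduler side has no notion of the kernel block size yet, so `local_physical_block_ids` starts out as a plain copy of the logical ids and is only remapped later on the worker (see `start_load_kv` below). A condensed sketch of that lifecycle, with the field list trimmed and an assumed expansion factor of 2:

```python
from dataclasses import dataclass


# Trimmed stand-in for ReqMeta, keeping only the two id lists of interest.
@dataclass
class ReqMetaSketch:
    local_block_ids: list[int]           # logical ids, assigned by the scheduler
    local_physical_block_ids: list[int]  # kernel-sized ids, filled in on the worker


# At creation time the physical ids simply mirror the logical ids.
meta = ReqMetaSketch(local_block_ids=[3, 7], local_physical_block_ids=[3, 7])

# On the worker, once the kernel block size is known (here: 2 physical
# blocks per logical block), the physical ids are expanded in place.
factor = 2
meta.local_physical_block_ids = [
    b * factor + off for b in meta.local_block_ids for off in range(factor)
]
print(meta.local_physical_block_ids)  # [6, 7, 14, 15]
```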
```diff
@@ -935,6 +939,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             attn_backend=backend,
         )
         self._use_pallas = self.kv_topo._use_pallas
+        self._physical_blocks_per_logical_kv_block = 1

     def _nixl_handshake(
         self,
```
```diff
@@ -1133,6 +1138,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             if base_addr in seen_base_addresses:
                 continue

+            # TODO (NickLucche): Get kernel_block_size in a cleaner way
+            # NHD default "view" for non-MLA cache
+            kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
+
+            if self.block_size != kernel_block_size:
```
**Contributor:** I have a silly question: when will this scenario happen? What is the max kernel block size for CUDA, and where is it set?

**Collaborator (Author):** It's backend-dependent. It happens every time the supplied block_size is not one of https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/flash_attn.py#L63 (so the kernel block size is used for the physical tensors, and block_size becomes logical only).

**Contributor:** Oh, I see.
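To make the backend-dependence concrete, a small sketch with purely illustrative numbers (the actual supported sizes are backend-specific, per the flash_attn link above):

```python
# Illustrative numbers only; the real supported sizes are backend-specific.
supported_kernel_block_sizes = [16, 32, 64]

user_block_size = 128  # the logical block size supplied via configuration

if user_block_size not in supported_kernel_block_sizes:
    # The backend allocates its physical KV tensors with a supported kernel
    # block size, and the user's block_size becomes purely logical.
    kernel_block_size = 64  # assume the backend picks this one
    physical_per_logical = user_block_size // kernel_block_size
    print(physical_per_logical)  # 2: each logical block spans 2 kernel blocks
```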
```diff
+                logger.info_once(
+                    "User-specified logical block size (%s) does not match"
+                    " physical kernel block size (%s). Using the latter. ",
+                    self.block_size,
+                    kernel_block_size,
+                )
+                self._physical_blocks_per_logical_kv_block = (
+                    self.block_size // kernel_block_size
+                )
+                self.block_size = kernel_block_size
+
             seen_base_addresses.append(base_addr)
             curr_tensor_size_bytes = cache.numel() * cache.element_size()
```
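A rough standalone illustration of the shape-based detection in the hunk above; the tensor shape here is an assumed toy value, and the real layout is whatever view the attention backend registers:

```python
import torch

use_mla = False

# Toy non-MLA cache in the default NHD view (assumed for illustration):
# (2, num_blocks, block_size, num_kv_heads, head_dim).
cache = torch.empty(2, 10, 64, 8, 128)

# Dim -3 holds the per-block token count in the NHD view; for an MLA cache
# of shape (num_blocks, block_size, head_dim) it sits at dim -2 instead.
kernel_block_size = cache.shape[-2] if use_mla else cache.shape[-3]
print(kernel_block_size)  # 64
```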
```diff
@@ -1479,7 +1500,7 @@ def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
         assert self.use_host_buffer
         assert self.copy_blocks is not None

-        local_block_ids = meta.local_block_ids
+        local_block_ids = meta.local_physical_block_ids
         self.copy_blocks(
             self.host_xfer_buffers,
             self.device_kv_caches,
```
```diff
@@ -1492,7 +1513,7 @@ def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
                 "synced recved kv of request[%s] to device kv buffer,"
                 "local_block_ids: %s. ",
                 req_id,
-                ",".join(map(str, meta.local_block_ids)),
+                ",".join(map(str, local_block_ids)),
             )

     def save_kv_to_host(self, metadata: NixlConnectorMetadata):
```
```diff
@@ -1501,19 +1522,22 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata):
         assert self.copy_blocks is not None

         for req_id, meta in metadata.reqs_to_save.items():
+            meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
+                meta.local_block_ids
+            )
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(
                     "save_load_kv for request[%s] to host xfer buffer."
                     "local_block_ids: %s. ",
                     req_id,
-                    ",".join(map(str, meta.local_block_ids)),
+                    ",".join(map(str, meta.local_physical_block_ids)),
                 )
             # blocking
             self.copy_blocks(
                 self.device_kv_caches,
                 self.host_xfer_buffers,
-                meta.local_block_ids,
-                meta.local_block_ids,
+                meta.local_physical_block_ids,
+                meta.local_physical_block_ids,
                 "d2h",
             )
```
```diff
@@ -1582,7 +1606,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
             if self.use_host_buffer:
                 self.sync_recved_kv_to_device(req_id, meta)
             if self.enable_permute_local_kv:
-                block_ids_to_permute += meta.local_block_ids
+                block_ids_to_permute += meta.local_physical_block_ids
         if len(block_ids_to_permute) > 0:
             self.permute_device_kv(block_ids_to_permute)
```
```diff
@@ -1669,7 +1693,7 @@ def _pop_done_transfers(
                     req_id,
                     xfer_state,
                 )
-                # mark all blocks for this request as invalid
+                # mark all (logical) blocks for this request as invalid
                 if meta := self._recving_metadata.pop(req_id, None):
                     self._invalid_block_ids.update(meta.local_block_ids)
             self._recving_metadata.pop(req_id, None)
```
```diff
@@ -1686,13 +1710,19 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         We check for these trnxs to complete in each step().
         """
         for req_id, meta in metadata.reqs_to_recv.items():
+            meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
+                meta.local_block_ids
+            )
+            meta.remote_block_ids = self._logical_to_kernel_block_ids(
+                meta.remote_block_ids
+            )
             remote_engine_id = meta.remote_engine_id
             logger.debug(
                 "start_load_kv for request %s from remote engine %s. "
                 "Num local_block_ids: %s. Num remote_block_ids: %s. ",
                 req_id,
                 remote_engine_id,
-                len(meta.local_block_ids),
+                len(meta.local_physical_block_ids),
                 len(meta.remote_block_ids),
             )
             # always store metadata for failure recovery
```
```diff
@@ -1740,7 +1770,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         self._read_blocks(
             request_id=req_id,
             dst_engine_id=meta.remote_engine_id,
-            local_block_ids=meta.local_block_ids,
+            local_block_ids=meta.local_physical_block_ids,
             remote_block_ids=meta.remote_block_ids,
         )
```
```diff
@@ -1867,7 +1897,7 @@ def _read_blocks(
                 "Marking blocks as invalid.",
                 request_id,
             )
-            # mark all blocks for this request as invalid
+            # mark all (logical) blocks for this request as invalid
             if meta := self._recving_metadata.get(request_id):
                 self._invalid_block_ids.update(meta.local_block_ids)
             self.xfer_stats.record_failed_transfer()
```
```diff
@@ -1906,6 +1936,23 @@ def _get_block_descs_ids(
         descs_ids = region_ids * num_blocks + block_ids
         return descs_ids.flatten()

+    def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
+        """
+        Convert logical block ids to kernel physical block ids.
+        This is required when the logical block size (the one set by the user)
+        does not match the one required by the attn backend.
+        """
+        if self._physical_blocks_per_logical_kv_block == 1:
+            # Noop when physical and logical block sizes are the same
+            return block_ids
+        block_ids_np = np.array(block_ids)
+        block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
+            1, -1
+        )
+        return BlockTable.map_to_kernel_blocks(
+            block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
+        ).tolist()
+
     def get_backend_aware_kv_block_len(self, layer_idx: int):
         """
         Get the block length for one K/V element (K and V have the same size).
```
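For reference, a self-contained sketch of the expansion this helper performs. It reimplements the arithmetic in plain NumPy rather than calling `BlockTable.map_to_kernel_blocks`, assuming that helper does the standard contiguous expansion of each logical id into its backing kernel ids:

```python
import numpy as np

blocks_per_logical = 4  # assumed expansion factor, for illustration only

block_ids_np = np.array([0, 2, 5])  # logical block ids
block_arange = np.arange(0, blocks_per_logical).reshape(1, -1)

# Logical id i expands to the contiguous kernel ids i*4 .. i*4 + 3.
kernel_ids = (block_ids_np.reshape(-1, 1) * blocks_per_logical + block_arange).ravel()
print(kernel_ids.tolist())  # [0, 1, 2, 3, 8, 9, 10, 11, 20, 21, 22, 23]
```

Broadcasting a column of logical ids against a row of offsets keeps the conversion vectorized, which matters since this runs on every save/load request.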