diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py
index e72bd43ff56e..5a598dcab718 100644
--- a/tests/v1/worker/test_gpu_input_batch.py
+++ b/tests/v1/worker/test_gpu_input_batch.py
@@ -241,6 +241,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
         block_sizes=[1],
+        kernel_block_sizes=[1],
     )
     reqs: list[CachedRequestState] = []
     req_id_reqs = {}
@@ -335,6 +336,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
         block_sizes=[1],
+        kernel_block_sizes=[1],
     )
     ref_input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
@@ -344,6 +346,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
         block_sizes=[1],
+        kernel_block_sizes=[1],
     )

     reqs: list[CachedRequestState] = []
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index ef2956bd3ec2..208b889e8788 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -68,6 +68,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
         pin_memory=runner.pin_memory,
         vocab_size=runner.model_config.get_vocab_size(),
         block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
+        kernel_block_sizes=[
+            kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
+        ],
     )

     runner.initialize_attn_backend(kv_cache_config)
@@ -816,42 +819,231 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape

     # assert we are using FlashInfer
-    assert attn_shape[0] == num_blocks
+    assert attn_shape[0] % num_blocks == 0
+    block_split_ratio = attn_shape[0] // num_blocks
+
+    # use small blocks for testing to avoid memory issues
+    test_block_size = min(2, len(blocks0), len(blocks1))
+
+    # use non-overlapping blocks to avoid data contamination
+    # Split kernel blocks: first half for attention, second half for mamba
+    mid_point = num_blocks // 2
+
+    # attention uses kernel blocks from first half (mapped to logical blocks)
+    kv_blocks_for_attention = np.array([0, 1])[:test_block_size]
+
+    # mamba uses kernel blocks from second half
+    kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]
+
+    # create small constant tensors for testing with corrected shapes
+    # attention: [block_size, ...] starting from dimension 2
+    attn_constant_shape = attn_shape[2:]
+    conv_constant_shape = conv_shape[1:]
+    ssm_constant_shape = ssm_shape[1:]

     attn_blocks_constant = torch.full(
-        (len(blocks0), *attn_shape[1:]), device=DEVICE, fill_value=3.33
+        (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
     )
     conv_blocks_constant = torch.full(
-        (len(blocks1), *conv_shape[1:]), device=DEVICE, fill_value=6.66
+        (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
     )
     ssm_blocks_constant = torch.full(
-        (len(blocks1), *ssm_shape[1:]), device=DEVICE, fill_value=9.99
+        (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
     )

-    # fill all attention blocks with constant
+    # Fill attention blocks with constants using kv block indices
+    kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
+
     for layer in [layer_0, layer_1]:
-        vllm_ctx[layer].kv_cache[0][blocks0, :] = (
-            attn_blocks_constant.detach().clone()
-        )
+        # attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
+        for i, kernel_block in enumerate(kernel_blocks_for_attention):
+            vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]

-    # fill all mamba blocks with constant
+    # fill mamba blocks with constants using kernel block indices
     for layer in [layer_2, layer_3, layer_4, layer_5]:
-        vllm_ctx[layer].kv_cache[0][0][blocks1, :] = (
-            conv_blocks_constant.detach().clone()
-        )
-        vllm_ctx[layer].kv_cache[0][1][blocks1, :] = (
-            ssm_blocks_constant.detach().clone()
-        )
+        # mamba: kv_cache[0][component][kernel_block_idx, ...]
+        for i, kv_block in enumerate(kv_blocks_for_mamba):
+            vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
+            vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]

     # verify attention and mamba contents are correct
     for layer in [layer_0, layer_1]:
-        assert torch.equal(
-            vllm_ctx[layer].kv_cache[0][blocks0, :], attn_blocks_constant
-        )
+        for i, kernel_block in enumerate(kernel_blocks_for_attention):
+            actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
+            expected = attn_blocks_constant[i]
+
+            # Check K and V separately
+            assert torch.equal(actual_kv[0], expected)
+            assert torch.equal(actual_kv[1], expected)
+
     for layer in [layer_2, layer_3, layer_4, layer_5]:
-        assert torch.equal(
-            vllm_ctx[layer].kv_cache[0][0][blocks1, :], conv_blocks_constant
-        )
-        assert torch.equal(
-            vllm_ctx[layer].kv_cache[0][1][blocks1, :], ssm_blocks_constant
-        )
+        for i, kv_block in enumerate(kv_blocks_for_mamba):
+            actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
+            actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
+            expected_conv = conv_blocks_constant[i]
+            expected_ssm = ssm_blocks_constant[i]
+
+            assert torch.equal(actual_conv, expected_conv)
+            assert torch.equal(actual_ssm, expected_ssm)
+
+
+def test_hybrid_block_table_initialization():
+    """Test hybrid block table with different kernel and kvcache_manager block
+    sizes."""
+    from vllm.v1.worker.block_table import BlockTable
+
+    # Test configuration: kvcache_manager block size = 32,
+    # kernel block size = 16
+    block_size = 32
+    kernel_block_sizes = [16]
+    max_num_reqs = 10
+    max_num_blocks_per_req = 20
+    max_num_batched_tokens = 512
+
+    block_table = BlockTable(
+        block_size=block_size,
+        max_num_reqs=max_num_reqs,
+        max_num_blocks_per_req=max_num_blocks_per_req,
+        max_num_batched_tokens=max_num_batched_tokens,
+        pin_memory=False,
+        device=torch.device(DEVICE),
+        kernel_block_size=kernel_block_sizes[0],
+    )
+
+    # Verify hybrid block configuration
+    assert block_table.use_hybrid_blocks is True
+    assert block_table.block_size == kernel_block_sizes[0]
+    assert block_table.blocks_per_kv_block == block_size // kernel_block_sizes[0]
+
+    # Test block table conversion logic
+    # One kvcache_manager block should map to multiple kernel blocks
+    kvcache_manager_blocks = [0, 1, 2]
+
+    # Verify that kvcache_manager blocks can be converted to kernel blocks
+    # and that block table operations work correctly.
+    req_index = 0
+    block_table.append_row(kvcache_manager_blocks, req_index)
+    # Get expected kernel blocks from the implementation for verification.
+    expected_kernel_blocks = block_table._map_to_kernel_blocks(
+        np.array(kvcache_manager_blocks)
+    )
+    # Verify block table state
+    assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)
+    assert np.array_equal(
+        block_table.block_table.np[req_index, : len(expected_kernel_blocks)],
+        expected_kernel_blocks,
+    )
+
+
+def test_input_batch_with_kernel_block_sizes():
+    """Test InputBatch initialization with kernel_block_sizes parameter."""
+    max_num_reqs = 10
+    max_model_len = 512
+    max_num_batched_tokens = 512
+    device = torch.device(DEVICE)
+    pin_memory = False
+    vocab_size = 50272
+
+    # Test with different kernel block sizes
+    block_sizes = [32, 64]
+    kernel_block_sizes = [16, 32]
+
+    input_batch = InputBatch(
+        max_num_reqs=max_num_reqs,
+        max_model_len=max_model_len,
+        max_num_batched_tokens=max_num_batched_tokens,
+        device=device,
+        pin_memory=pin_memory,
+        vocab_size=vocab_size,
+        block_sizes=block_sizes,
+        kernel_block_sizes=kernel_block_sizes,
+    )
+
+    # Verify that block tables were created with kernel block sizes
+    assert len(input_batch.block_table.block_tables) == len(block_sizes)
+
+    for i, (kv_size, kernel_size) in enumerate(zip(block_sizes, kernel_block_sizes)):
+        block_table = input_batch.block_table.block_tables[i]
+        if kv_size != kernel_size:
+            assert block_table.use_hybrid_blocks is True
+            assert block_table.block_size == kernel_size
+        else:
+            assert block_table.use_hybrid_blocks is False
+            assert block_table.block_size == kernel_size
+
+
+def test_hybrid_cache_integration(model_runner, dist_init):
+    """Test hybrid cache architecture integration with GPUModelRunner."""
+    # Create a new model runner with hybrid cache configuration
+    vllm_config = get_vllm_config()
+
+    # Configure hybrid cache with different kvcache_manager block size
+    vllm_config.cache_config.block_size = 32
+
+    model_config = vllm_config.model_config
+    num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
+    head_size = model_config.get_head_size()
+    vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
+        num_heads, head_size, 0.1
+    )
+
+    runner = GPUModelRunner(vllm_config, DEVICE)
+
+    # Initialize KV cache with configuration
+    attn_spec = FullAttentionSpec(
+        block_size=16,  # Use kernel block size directly
+        num_kv_heads=runner.model_config.get_num_kv_heads(runner.parallel_config),
+        head_size=runner.model_config.get_head_size(),
+        dtype=runner.kv_cache_dtype,
+    )
+    tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
+    kv_cache_config = KVCacheConfig(
+        num_blocks=NUM_BLOCKS,
+        kv_cache_tensors=[
+            KVCacheTensor(size=tensor_size, shared_by=["layer.0"]),
+        ],
+        kv_cache_groups=[
+            KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
+        ],
+    )
+    runner.kv_cache_config = kv_cache_config
+
+    # Initialize input batch with kernel block sizes
+    runner.input_batch = InputBatch(
+        max_num_reqs=runner.max_num_reqs,
+        max_model_len=runner.max_model_len,
+        max_num_batched_tokens=runner.max_num_tokens,
+        device=runner.device,
+        pin_memory=runner.pin_memory,
+        vocab_size=runner.model_config.get_vocab_size(),
+        block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
+        kernel_block_sizes=[16],  # Use kernel block size
+    )
+
+    runner.initialize_attn_backend(kv_cache_config)
+
+    # Verify hybrid block table configuration
+    block_table = runner.input_batch.block_table.block_tables[0]
+    assert block_table.block_size == (
+        kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
+    )
+
+    # Test request processing with hybrid blocks
+    req_id = "hybrid_req_0"
+    scheduler_output = _schedule_new_request(req_id)
+
+    # Update states should work with hybrid blocks
+    runner._update_states(scheduler_output)
+    assert _is_req_scheduled(runner, req_id)
+    assert _is_req_state_block_table_match(runner, req_id)
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 697d134f2018..3f23d4ef7d2c 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from abc import ABC, abstractmethod
-from typing import Generic, Optional, Protocol, TypeVar
+from typing import Generic, Optional, Protocol, TypeVar, Union

 import torch

@@ -26,6 +26,13 @@ class AttentionType:
     """Attention between dec. Q and enc. K/V for encoder-decoder."""


+class MultipleOf:
+    base: int
+
+    def __init__(self, base: int):
+        self.base = base
+
+
 class AttentionBackend(ABC):
     """Abstract class for attention backends."""

@@ -57,6 +64,10 @@ def get_impl_cls() -> type["AttentionImpl"]:
     def get_metadata_cls() -> type["AttentionMetadata"]:
         raise NotImplementedError

+    @classmethod
+    def get_supported_kernel_block_size(cls) -> list[Union[int, MultipleOf]]:
+        return cls.get_impl_cls().get_supported_kernel_block_size()
+
     @classmethod
     def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
         return cls.get_metadata_cls()(*args, **kwargs)
@@ -157,6 +168,11 @@ def __init__(
     ) -> None:
         raise NotImplementedError

+    @staticmethod
+    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
+        # TODO: implement this function for all backends.
+        return [MultipleOf(1)]
+
     @abstractmethod
     def forward(
         self,
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index caf481f5aec6..ee6a3ba773bb 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -365,6 +365,23 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             block_size=model_config.max_model_len,
         ).page_size_bytes

+        # The model may be marked as hybrid, but its mamba layers
+        # may be skipped via config; in that case there is nothing
+        # to adjust here.
+        if mamba_page_size == 0:
+            return
+
+        # Attention backend constraints:
+        # - FlashAttention (FA) requires the block size to be a multiple of 16
+        # - MLA (Multi-head Latent Attention) requires larger alignment:
+        #   * CUTLASS_MLA backend: block size must be a multiple of 128
+        #   * other MLA backends: block size must be a multiple of 64
+        if model_config.use_mla:
+            use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
+            kernel_block_alignment_size = 128 if use_cutlass_mla else 64
+        else:
+            kernel_block_alignment_size = 16
+
         if cache_config.enable_prefix_caching:
             # With prefix caching, select attention block size to
             # optimize for mamba kernel performance
@@ -381,19 +398,28 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             # TODO(tdoublep): this constraint can be relaxed fairly
             # easily by changing the way we layout chunks in the
             # mamba2 kernels.
-            chunk_size = model_config.get_mamba_chunk_size()
+
+            from math import gcd
+
+            def lcm(a, b):
+                return a * b // gcd(a, b)
+
+            base_chunk_size = model_config.get_mamba_chunk_size()
             attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
+
+            chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
             attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
             cache_config.mamba_block_size = attn_block_size
         else:
             # Without prefix caching, select minimum valid attention block size
             # to minimize mamba state padding
-            # some attention backends (e.g. FA) only support setting
-            # block size to multiple of 16, so let's suggest a value
-            # that would work (note: FA is currently not compatible
-            # with mamba layers, use FlashInfer instead).
-            attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)
+            # Calculate the minimum attention block size that satisfies both:
+            # 1. Backend alignment requirements (kernel_block_alignment_size)
+            # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
+            attn_block_size = kernel_block_alignment_size * cdiv(
+                mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
+            )

         # override attention block size if either (a) the
         # user has not set it or (b) the user has set it
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 8a4565b4d1a0..e0f832b43114 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -118,7 +118,15 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         # TODO(lucas): handle this more gracefully
         # Note: model_config may be None during testing
-        if model_config is not None and model_config.use_mla:
+        # Note: block_size is initialized in
+        # HybridAttentionMambaModelConfig.verify_and_update_config
+        # for models with both attention and mamba,
+        # and doesn't need to be reinitialized here
+        if (
+            model_config is not None
+            and model_config.use_mla
+            and cache_config.block_size is not None
+        ):
             use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
             # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
             # then we default to FlashMLA backend for non-blackwell GPUs,
@@ -151,18 +159,22 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             if (
                 use_flashmla
                 and is_flashmla_dense_supported()[0]
-                and cache_config.block_size != 64
+                and cache_config.block_size % 64 != 0
             ):
                 cache_config.block_size = 64
                 logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")

-            if use_cutlass_mla and cache_config.block_size != 128:
+            if use_cutlass_mla and cache_config.block_size % 128 != 0:
                 cache_config.block_size = 128
                 logger.info(
                     "Forcing kv cache block size to 128 for CUTLASS_MLA backend."
                 )

-            if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
+            if (
+                use_flashinfer_mla
+                and cache_config.block_size != 32
+                and cache_config.block_size % 64 != 0
+            ):
                 cache_config.block_size = 64
                 logger.info(
                     "Forcing kv cache block size to 64 for FlashInferMLA backend."
@@ -269,12 +281,12 @@ def get_attn_backend_cls(
         use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
             selected_backend is None
             and cls.is_device_capability(100)
-            and block_size == 128
+            and block_size % 128 == 0
         )
         use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
             selected_backend is None
            and cls.is_device_capability(100)
-            and block_size in [32, 64]
+            and (block_size == 32 or block_size % 64 == 0)
         )
         use_flashmla = selected_backend == _Backend.FLASHMLA or (
             selected_backend is None and is_flashmla_dense_supported()[0]
@@ -298,7 +310,7 @@ def get_attn_backend_cls(
                 "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
             )
         if use_flashmla:
-            if block_size != 64:
+            if block_size % 64 != 0:
                 logger.warning(
                     "FlashMLA backend is not supported for block size %d"
                     " (currently only supports block size 64).",
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 1f6b7e41b37e..a71e51471905 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -3,7 +3,7 @@
 """Attention layer with FlashAttention."""

 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union

 import numpy as np
 import torch
@@ -14,6 +14,7 @@
     AttentionImpl,
     AttentionMetadata,
     AttentionType,
+    MultipleOf,
     is_quantized_kv_cache,
 )
 from vllm.attention.layer import Attention
@@ -57,6 +58,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]:
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]

+    @staticmethod
+    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
+        return [MultipleOf(16)]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         supported_head_sizes = cls.get_supported_head_sizes()
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 55186e2938c3..ff91ecd2aaef 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -23,6 +23,7 @@
     AttentionBackend,
     AttentionImpl,
     AttentionType,
+    MultipleOf,
 )
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
@@ -165,6 +166,13 @@ def get_supported_head_sizes(cls) -> list[int]:
         # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
         return [64, 128, 256]

+    @staticmethod
+    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
+        # Note: not verified on all platforms; on Blackwell,
+        # only page sizes of 16, 32, and 64 are supported
+        # by the FlashInfer kernels.
+        return [16, 32, 64]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         supported_head_sizes = cls.get_supported_head_sizes()
diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index a3c677ca2108..11e06cc6daac 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -10,6 +10,7 @@
 from vllm.attention.backends.abstract import (
     AttentionLayer,
     AttentionType,
+    MultipleOf,
     is_quantized_kv_cache,
 )
 from vllm.logger import init_logger
@@ -44,6 +45,10 @@ def get_impl_cls() -> type["CutlassMLAImpl"]:
     def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
         return CutlassMLAMetadataBuilder

+    @staticmethod
+    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
+        return [128]
+

 class SM100Workspace:
     def __init__(self, initial_workspace_size):
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 6ba2c682760c..f4f82f1cce91 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -6,7 +6,7 @@
 import torch

-from vllm.attention.backends.abstract import AttentionLayer, AttentionType
+from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
 from vllm.attention.ops.flashmla import (
     flash_mla_with_kvcache,
     get_mla_metadata,
@@ -44,6 +44,10 @@ def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
     def get_impl_cls() -> type["FlashMLAImpl"]:
         return FlashMLAImpl

+    @staticmethod
+    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
+        return [64]
+

 @dataclass
 class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 348eca55eefb..82505f6281c0 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -3,7 +3,7 @@
 """Attention layer with AiterFlashAttention."""

 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union

 import torch

@@ -12,6 +12,7 @@
     AttentionImpl,
     AttentionMetadata,
     AttentionType,
+    MultipleOf,
 )
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -359,6 +360,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]:
     def get_supported_head_sizes(cls) -> list[int]:
         return [64, 128, 256]

+    @staticmethod
+    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
+        return [MultipleOf(16)]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         supported_head_sizes = cls.get_supported_head_sizes()
diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index a209bb79580c..669dbe31810b 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -4,7 +4,7 @@
 import ast
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union

 import torch

@@ -14,6 +14,7 @@
     AttentionImpl,
     AttentionMetadata,
     AttentionType,
+    MultipleOf,
 )
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.config import VllmConfig
@@ -39,6 +40,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]:
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]

+    @staticmethod
+    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
+        return [MultipleOf(16)]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         supported_head_sizes = cls.get_supported_head_sizes()
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 9997ed16bed1..878634c7f521 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -3,7 +3,7 @@
 """High-Performance Triton-only Attention layer."""

 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import ClassVar, Optional, Union

 import torch

@@ -12,6 +12,7 @@
     AttentionImpl,
     AttentionMetadata,
     AttentionType,
+    MultipleOf,
 )
 from vllm.attention.ops.triton_reshape_and_cache_flash import (
     triton_reshape_and_cache_flash,
@@ -157,6 +158,10 @@ class TritonAttentionBackend(AttentionBackend):
     def get_supported_dtypes(cls) -> list[torch.dtype]:
         return [torch.float16, torch.bfloat16, torch.float32]

+    @staticmethod
+    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
+        return [MultipleOf(16)]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         # Triton Attention supports any head size above 32
diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py
index b21562fac741..eb1fcc2c024d 100644
--- a/vllm/v1/attention/backends/xformers.py
+++ b/vllm/v1/attention/backends/xformers.py
@@ -3,7 +3,7 @@
 """Attention layer with XFormersAttention."""

 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union

 import torch

@@ -12,6 +12,7 @@
     AttentionImpl,
     AttentionMetadata,
     AttentionType,
+    MultipleOf,
 )
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.config import VllmConfig
@@ -80,6 +81,10 @@ def get_supported_head_sizes(cls) -> list[int]:
             256,
         ]

+    @staticmethod
+    def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
+        return [MultipleOf(16)]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         supported_head_sizes = cls.get_supported_head_sizes()
diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index 4d3688453cb9..0c44834b5505 100644
--- a/vllm/v1/worker/block_table.py
+++ b/vllm/v1/worker/block_table.py
@@ -22,22 +22,64 @@ def __init__(
         max_num_batched_tokens: int,
         pin_memory: bool,
         device: torch.device,
+        kernel_block_size: int,
     ):
-        self.block_size = block_size
+        """
+        Args:
+            block_size: Block size used for KV cache memory allocation.
+            max_num_reqs: Maximum number of concurrent requests supported.
+            max_num_blocks_per_req: Maximum number of blocks per request.
+            max_num_batched_tokens: Maximum number of tokens in a batch.
+            pin_memory: Whether to pin memory for faster GPU transfers.
+            device: Target device for the block table.
+            kernel_block_size: The block size of the underlying attention
+                kernel. Equal to `block_size` when `block_size` is already
+                supported by the attention kernel.
+ """ self.max_num_reqs = max_num_reqs - self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device + if kernel_block_size == block_size: + # Standard case: allocation and computation use same block size + # No block splitting needed, direct mapping + self.block_size = block_size + self.blocks_per_kv_block = 1 + self.use_hybrid_blocks = False + else: + # Hybrid case: allocation block size differs from kernel block size + # Memory blocks are subdivided to match kernel requirements + # Example: 32-token memory blocks with 16-token kernel blocks + # → Each memory block corresponds to 2 kernel blocks + if block_size % kernel_block_size != 0: + raise ValueError( + f"kernel_block_size {kernel_block_size} must divide " + f"kv_manager_block_size size {block_size} evenly" + ) + + self.block_size = kernel_block_size + self.blocks_per_kv_block = block_size // kernel_block_size + self.use_hybrid_blocks = True + + self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block + self.block_table = self._make_buffer( - max_num_reqs, max_num_blocks_per_req, dtype=torch.int32 + self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32 ) self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) self.slot_mapping = self._make_buffer( self.max_num_batched_tokens, dtype=torch.int64 ) + + if self.use_hybrid_blocks: + self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape( + 1, -1 + ) + else: + self._kernel_block_arange = None + try: self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group @@ -53,6 +95,10 @@ def append_row( ) -> None: if not block_ids: return + + if self.use_hybrid_blocks: + block_ids = self._map_to_kernel_blocks(np.array(block_ids)) + num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] self.num_blocks_per_row[row_idx] += num_blocks @@ -94,6 +140,7 @@ def compute_slot_mapping( req_indices * self.max_num_blocks_per_req + positions // virtual_block_size ) + block_numbers = self.block_table.np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local # tokens. @@ -111,6 +158,7 @@ def compute_slot_mapping( block_table_indices = ( req_indices * self.max_num_blocks_per_req + positions // self.block_size ) + block_numbers = self.block_table.np.ravel()[block_table_indices] block_offsets = positions % self.block_size np.add( @@ -129,6 +177,31 @@ def clear(self) -> None: self.block_table.gpu.fill_(0) self.block_table.cpu.fill_(0) + def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray: + """Convert kv_manager_block_id IDs to kernel block IDs. 
+
+        Example:
+            # kv_manager block size: 32 tokens
+            # kernel block size: 16 tokens
+            # blocks_per_kv_block = 2
+            >>> kv_manager_block_ids = np.array([0, 1, 2])
+            # Result: [0, 1, 2, 3, 4, 5]
+
+            # Each kv_manager block ID maps to 2 kernel block IDs:
+            # kv_manager block 0 → kernel blocks [0, 1]
+            # kv_manager block 1 → kernel blocks [2, 3]
+            # kv_manager block 2 → kernel blocks [4, 5]
+        """
+        if not self.use_hybrid_blocks:
+            return kv_manager_block_ids
+
+        kernel_block_ids = (
+            kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
+            + self._kernel_block_arange
+        )
+
+        return kernel_block_ids.reshape(-1)
+
     def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
         """Returns the device tensor of the block table."""
         return self.block_table.gpu[:num_reqs]
@@ -160,6 +233,7 @@ def __init__(
         pin_memory: bool,
         device: torch.device,
         block_sizes: list[int],
+        kernel_block_sizes: list[int],
         num_speculative_tokens: int = 0,
     ) -> None:
         # Note(hc): each dcp rank only store
@@ -172,6 +246,12 @@ def __init__(
             # DCP might not be initialized in testing
             dcp_world_size = 1

+        if len(kernel_block_sizes) != len(block_sizes):
+            raise ValueError(
+                f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
+                f"must match block_sizes length ({len(block_sizes)})"
+            )
+
         self.block_tables = [
             BlockTable(
                 block_size,
@@ -183,8 +263,9 @@ def __init__(
                 max_num_batched_tokens,
                 pin_memory,
                 device,
+                kernel_block_size,
             )
-            for block_size in block_sizes
+            for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
         ]

     def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 22f5c6f7e683..6d7473d8f44b 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -78,6 +78,7 @@ def __init__(
         pin_memory: bool,
         vocab_size: int,
         block_sizes: list[int],  # The block_size of each kv cache group
+        kernel_block_sizes: list[int],
         logitsprocs: Optional[LogitsProcessors] = None,
         is_spec_decode: bool = False,
         is_pooling_model: bool = False,
@@ -132,6 +133,7 @@ def __init__(
             pin_memory=pin_memory,
             device=device,
             block_sizes=block_sizes,
+            kernel_block_sizes=kernel_block_sizes,
             num_speculative_tokens=num_speculative_tokens,
         )

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index cbac67d9e24e..ea3b18b447f3 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -19,7 +19,7 @@
 import vllm.envs as envs
 from vllm.attention import Attention, AttentionType
-from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
 from vllm.attention.layer import MLAAttention
 from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.counter import compilation_counter
@@ -359,6 +359,7 @@ def __init__(
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
             block_sizes=[self.cache_config.block_size],
+            kernel_block_sizes=[self.cache_config.block_size],
             is_spec_decode=bool(self.vllm_config.speculative_config),
             logitsprocs=build_logitsprocs(
                 self.vllm_config,
@@ -4050,6 +4051,86 @@ def calculate_reorder_batch_threshold(self) -> None:
         else:
             self.reorder_batch_threshold = reorder_batch_threshold_i

+    def _find_compatible_block_sizes(
+        self,
+        kv_manager_block_size: int,
+        backend_cls: type[AttentionBackend],
+        return_all: bool = False,
+    ) -> list[int]:
+        """
+        Find compatible block sizes for a backend.
+
+        Args:
+            kv_manager_block_size: Physical block size of the KV cache
+            backend_cls: Attention backend class
+            return_all: Return all compatible sizes if True, max size if False
+
+        Returns:
+            Compatible block size(s) based on the return_all parameter
+
+        Raises:
+            ValueError: If no compatible block size is found
+        """
+        supported_block_size = backend_cls.get_supported_kernel_block_size()
+        compatible_sizes = []
+
+        for block_size in supported_block_size:
+            if isinstance(block_size, int):
+                if kv_manager_block_size % block_size == 0:
+                    compatible_sizes.append(block_size)
+            elif (
+                isinstance(block_size, MultipleOf)
+                and kv_manager_block_size % block_size.base == 0
+            ):
+                compatible_sizes.append(kv_manager_block_size)
+
+        if not compatible_sizes:
+            raise ValueError(f"No compatible block size for {kv_manager_block_size}")
+
+        return compatible_sizes if return_all else [max(compatible_sizes)]
+
+    def _select_common_block_size(
+        self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
+    ) -> int:
+        """
+        Select a common block size supported by all backends.
+
+        Args:
+            kv_manager_block_size: Block size of the KV cache
+            attn_groups: List of attention groups
+
+        Returns:
+            Block size supported by all backends,
+            prioritizing cache_config.block_size
+
+        Raises:
+            ValueError: If no common block size is found
+        """
+        all_backend_supports = []
+
+        for attn_group in attn_groups:
+            compatible_sizes = self._find_compatible_block_sizes(
+                kv_manager_block_size, attn_group.backend, return_all=True
+            )
+            supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
+            all_backend_supports.append(set(supported_sizes))
+
+        common_supported_sizes = set.intersection(*all_backend_supports)
+
+        if not common_supported_sizes:
+            error_msg = f"No common block size for {kv_manager_block_size}. "
+            for i, attn_group in enumerate(attn_groups):
+                supported = all_backend_supports[i]
+                error_msg += (
+                    f"Backend {attn_group.backend} supports: {sorted(supported)}. "
+                )
+            raise ValueError(error_msg)
+
+        if self.cache_config.block_size in common_supported_sizes:
+            return self.cache_config.block_size
+
+        return max(common_supported_sizes)
+
     def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Re-initialize the input batch if the block sizes are different from
@@ -4062,8 +4143,15 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
         block_sizes = [
             kv_cache_group.kv_cache_spec.block_size
             for kv_cache_group in kv_cache_config.kv_cache_groups
+            if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
         ]
-        if block_sizes != [self.cache_config.block_size]:
+
+        # Generate kernel_block_sizes matching each block_size
+        kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
+
+        if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
+            self.cache_config.block_size
+        ]:
             assert self.cache_config.cpu_offload_gb == 0, (
                 "Cannot re-initialize the input batch when CPU weight "
                 "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
@@ -4077,6 +4165,7 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
             block_sizes=block_sizes,
+            kernel_block_sizes=kernel_block_sizes,
             is_spec_decode=bool(self.vllm_config.speculative_config),
             logitsprocs=self.input_batch.logitsprocs,
             is_pooling_model=self.is_pooling_model,
@@ -4128,6 +4217,46 @@ def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
         for attn_groups in self.attn_groups:
             yield from attn_groups

+    def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]:
+        """
+        Generate kernel_block_sizes matching each block_size.
+
+        For attention backends that support virtual block splitting,
+        use the supported block sizes from the backend.
+        For other backends (like Mamba), use the same block size (no splitting).
+
+        Args:
+            kv_cache_config: The KV cache configuration.
+
+        Returns:
+            list[int]: List of kernel block sizes for each cache group.
+        """
+        kernel_block_sizes = []
+        for kv_cache_group_id, kv_cache_group in enumerate(
+            kv_cache_config.kv_cache_groups
+        ):
+            if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
+                continue
+            elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
+                # This is an attention backend that supports virtual
+                # block splitting. Get the supported block sizes from
+                # all backends in the group.
+                attn_groups = self.attn_groups[kv_cache_group_id]
+                kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
+                selected_kernel_size = self._select_common_block_size(
+                    kv_manager_block_size, attn_groups
+                )
+                kernel_block_sizes.append(selected_kernel_size)
+            elif isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
+                # Mamba (or another non-attention) cache:
+                # keep the block size as-is, no splitting.
+                kernel_block_sizes.append(kv_cache_group.kv_cache_spec.block_size)
+            else:
+                raise NotImplementedError(
+                    f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
+                )
+        return kernel_block_sizes
+
     def _reshape_kv_cache_tensors(
         self,
         kv_cache_config: KVCacheConfig,
@@ -4157,16 +4286,24 @@ def _reshape_kv_cache_tensors(
             num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
             if isinstance(kv_cache_spec, AttentionSpec):
                 has_attn = True
+                kv_manager_block_size = kv_cache_spec.block_size
+                kernel_size_list = self._find_compatible_block_sizes(
+                    kv_manager_block_size, attn_backend, return_all=False
+                )
+                kernel_size = kernel_size_list[0]
+                num_blocks_per_kv_block = kv_manager_block_size // kernel_size
+                kernel_num_blocks = num_blocks * num_blocks_per_kv_block
+
                 kv_cache_shape = attn_backend.get_kv_cache_shape(
-                    num_blocks,
-                    kv_cache_spec.block_size,
+                    kernel_num_blocks,
+                    kernel_size,
                     kv_cache_spec.num_kv_heads,
                     kv_cache_spec.head_size,
                     cache_dtype_str=self.cache_config.cache_dtype,
                 )
                 dtype = kv_cache_spec.dtype
                 try:
-                    kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
+                    kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()  # noqa: E501
                     assert len(kv_cache_stride_order) == len(kv_cache_shape)
                 except (AttributeError, NotImplementedError):
                     kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
@@ -4320,10 +4457,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
-        self.may_reinitialize_input_batch(kv_cache_config)
         self.may_add_encoder_only_layers_to_kv_cache_config()
         self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
         self.initialize_attn_backend(kv_cache_config)
+        # Reinitialization must happen after initialize_attn_backend
+        self.may_reinitialize_input_batch(kv_cache_config)
         kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)

         if self.speculative_config and self.speculative_config.use_eagle():
diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py
index 34fed8f96467..ef115ade09ab 100644
--- a/vllm/v1/worker/tpu_input_batch.py
+++ b/vllm/v1/worker/tpu_input_batch.py
@@ -27,6 +27,7 @@ def __init__(
         pin_memory: bool,
         vocab_size: int,
         block_sizes: list[int],  # The block_size of each kv cache group
+        kernel_block_sizes: list[int],
     ):
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
@@ -68,6 +69,7 @@ def __init__(
             pin_memory=pin_memory,
             device=device,
             block_sizes=block_sizes,
+            kernel_block_sizes=kernel_block_sizes,
         )

         # Sampling-related.
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 7877f288c2ec..f8c1ec850b1b 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -259,6 +259,7 @@ def __init__(
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
             block_sizes=[self.block_size],
+            kernel_block_sizes=[self.cache_config.block_size],
         )

         # Cached torch/numpy tensor
@@ -1788,6 +1789,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             block_sizes=[
                 kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
             ],
+            kernel_block_sizes=[
+                kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
+            ],
         )
         # Verify dtype compatibility between block_table_cpu and input_batch
         assert (