-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
[Hybrid] A simpler algorithm to find kernel_block_size #26476
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
ca6ec4c
3eaf50a
3b2792b
802bba4
03c6a0c
6bd4c24
cb9ccdf
70b4d48
2501f4e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3881,6 +3881,7 @@ def get_attn_backends_for_group( | |||||||||||||
|
|
||||||||||||||
| def create_attn_groups( | ||||||||||||||
| attn_backends_map: dict[AttentionGroupKey, list[str]], | ||||||||||||||
| kv_cache_group_id: int, | ||||||||||||||
| ) -> list[AttentionGroup]: | ||||||||||||||
| attn_groups: list[AttentionGroup] = [] | ||||||||||||||
| for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): | ||||||||||||||
|
|
@@ -3890,6 +3891,7 @@ def create_attn_groups( | |||||||||||||
| kv_cache_spec, | ||||||||||||||
| self.vllm_config, | ||||||||||||||
| self.device, | ||||||||||||||
| kv_cache_group_id, | ||||||||||||||
| num_metadata_builders=1 | ||||||||||||||
| if not self.parallel_config.enable_dbo | ||||||||||||||
| else 2, | ||||||||||||||
|
|
@@ -3898,9 +3900,9 @@ def create_attn_groups( | |||||||||||||
| attn_groups.append(attn_group) | ||||||||||||||
| return attn_groups | ||||||||||||||
|
|
||||||||||||||
| for kv_cache_group_spec in kv_cache_config.kv_cache_groups: | ||||||||||||||
| for i, kv_cache_group_spec in enumerate(kv_cache_config.kv_cache_groups): | ||||||||||||||
| attn_backends = get_attn_backends_for_group(kv_cache_group_spec) | ||||||||||||||
| self.attn_groups.append(create_attn_groups(attn_backends)) | ||||||||||||||
| self.attn_groups.append(create_attn_groups(attn_backends, i)) | ||||||||||||||
|
|
||||||||||||||
| # Calculate reorder batch threshold (if needed) | ||||||||||||||
| self.calculate_reorder_batch_threshold() | ||||||||||||||
|
|
@@ -4051,104 +4053,93 @@ def calculate_reorder_batch_threshold(self) -> None: | |||||||||||||
| else: | ||||||||||||||
| self.reorder_batch_threshold = reorder_batch_threshold_i | ||||||||||||||
|
|
||||||||||||||
| def _find_compatible_block_sizes( | ||||||||||||||
| self, | ||||||||||||||
| kv_manager_block_size: int, | ||||||||||||||
| backend_cls: type[AttentionBackend], | ||||||||||||||
| return_all: bool = False, | ||||||||||||||
| ) -> list[int]: | ||||||||||||||
| """ | ||||||||||||||
| Find compatible block sizes for a backend. | ||||||||||||||
|
|
||||||||||||||
| Args: | ||||||||||||||
| kv_manager_block_size: Physical block size of KV cache | ||||||||||||||
| backend_cls: Attention backend class | ||||||||||||||
| return_all: Return all compatible sizes if True, max size if False | ||||||||||||||
|
|
||||||||||||||
| Returns: | ||||||||||||||
| Compatible block size(s) based on return_all parameter | ||||||||||||||
|
|
||||||||||||||
| Raises: | ||||||||||||||
| ValueError: If no compatible block size found | ||||||||||||||
| """ | ||||||||||||||
| supported_block_size = backend_cls.get_supported_kernel_block_size() | ||||||||||||||
| compatible_sizes = [] | ||||||||||||||
|
|
||||||||||||||
| for block_size in supported_block_size: | ||||||||||||||
| if isinstance(block_size, int): | ||||||||||||||
| if kv_manager_block_size % block_size == 0: | ||||||||||||||
| compatible_sizes.append(block_size) | ||||||||||||||
| elif ( | ||||||||||||||
| isinstance(block_size, MultipleOf) | ||||||||||||||
| and kv_manager_block_size % block_size.base == 0 | ||||||||||||||
| ): | ||||||||||||||
| compatible_sizes.append(kv_manager_block_size) | ||||||||||||||
|
|
||||||||||||||
| if not compatible_sizes: | ||||||||||||||
| raise ValueError(f"No compatible block size for {kv_manager_block_size}") | ||||||||||||||
|
|
||||||||||||||
| return compatible_sizes if return_all else [max(compatible_sizes)] | ||||||||||||||
|
|
||||||||||||||
| def _select_common_block_size( | ||||||||||||||
| self, kv_manager_block_size: int, attn_groups: list[AttentionGroup] | ||||||||||||||
| ) -> int: | ||||||||||||||
| """ | ||||||||||||||
| Select common block size for all backends. | ||||||||||||||
| Select a block size that is supported by all backends and is divisible by | ||||||||||||||
| kv_manager_block_size. | ||||||||||||||
|
|
||||||||||||||
| If kv_manager_block_size is supported by all backends, return it directly. | ||||||||||||||
| Otherwise, return the max supported size. | ||||||||||||||
|
|
||||||||||||||
| Args: | ||||||||||||||
| kv_manager_block_size: Block size of KV cache | ||||||||||||||
| attn_groups: List of attention groups | ||||||||||||||
|
|
||||||||||||||
| Returns: | ||||||||||||||
| Block size supported by all backends, | ||||||||||||||
| prioritizing cache_config.block_size | ||||||||||||||
| The selected block size | ||||||||||||||
|
|
||||||||||||||
| Raises: | ||||||||||||||
| ValueError: If no common block size found | ||||||||||||||
| ValueError: If no valid block size found | ||||||||||||||
| """ | ||||||||||||||
| all_backend_supports = [] | ||||||||||||||
|
|
||||||||||||||
| for attn_group in attn_groups: | ||||||||||||||
| compatible_sizes = self._find_compatible_block_sizes( | ||||||||||||||
| kv_manager_block_size, attn_group.backend, return_all=True | ||||||||||||||
| ) | ||||||||||||||
| supported_sizes = sorted(list(set(compatible_sizes)), reverse=True) | ||||||||||||||
| all_backend_supports.append(set(supported_sizes)) | ||||||||||||||
|
|
||||||||||||||
| common_supported_sizes = set.intersection(*all_backend_supports) | ||||||||||||||
|
|
||||||||||||||
| if not common_supported_sizes: | ||||||||||||||
| error_msg = f"No common block size for {kv_manager_block_size}. " | ||||||||||||||
| for i, attn_group in enumerate(attn_groups): | ||||||||||||||
| supported = all_backend_supports[i] | ||||||||||||||
| error_msg += ( | ||||||||||||||
| f"Backend {attn_group.backend} supports: {sorted(supported)}. " | ||||||||||||||
| ) | ||||||||||||||
| raise ValueError(error_msg) | ||||||||||||||
|
|
||||||||||||||
| if self.cache_config.block_size in common_supported_sizes: | ||||||||||||||
| return self.cache_config.block_size | ||||||||||||||
| def block_size_is_supported( | ||||||||||||||
| backends: list[type[AttentionBackend]], block_size: int | ||||||||||||||
| ) -> bool: | ||||||||||||||
| """ | ||||||||||||||
| Check if the block size is supported by all backends. | ||||||||||||||
| """ | ||||||||||||||
| for backend in backends: | ||||||||||||||
| is_supported = False | ||||||||||||||
| for supported_size in backend.get_supported_kernel_block_size(): | ||||||||||||||
| if isinstance(supported_size, int): | ||||||||||||||
| if block_size == supported_size: | ||||||||||||||
| is_supported = True | ||||||||||||||
| elif isinstance(supported_size, MultipleOf): | ||||||||||||||
| if block_size % supported_size.base == 0: | ||||||||||||||
| is_supported = True | ||||||||||||||
| else: | ||||||||||||||
| raise ValueError(f"Unknown supported size: {supported_size}") | ||||||||||||||
| if not is_supported: | ||||||||||||||
| return False | ||||||||||||||
| return True | ||||||||||||||
|
|
||||||||||||||
| backends = [group.backend for group in attn_groups] | ||||||||||||||
|
|
||||||||||||||
| # Case 1: if the block_size of kv cache manager is supported by all backends, | ||||||||||||||
| # return it directly | ||||||||||||||
| if block_size_is_supported(backends, kv_manager_block_size): | ||||||||||||||
| return kv_manager_block_size | ||||||||||||||
|
|
||||||||||||||
| # Case 2: otherwise, the block_size must be a `int`-format supported size of | ||||||||||||||
| # at lease one backend. Iterate over all `int`-format supported sizes in | ||||||||||||||
| # descending order and return the first one that is supported by all backends. | ||||||||||||||
| # Simple proof: | ||||||||||||||
| # If the supported size b is in MultipleOf(x_i) format for all attention | ||||||||||||||
| # backends i, and b is divisible by kv_manager_block_size, then | ||||||||||||||
tdoublep marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||
| # kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will | ||||||||||||||
| # return kv_manager_block_size in case 1. | ||||||||||||||
| all_int_supported_sizes = set( | ||||||||||||||
| supported_size | ||||||||||||||
| for backend in backends | ||||||||||||||
| for supported_size in backend.get_supported_kernel_block_size() | ||||||||||||||
| if isinstance(supported_size, int) | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| return max(common_supported_sizes) | ||||||||||||||
| for supported_size in sorted(all_int_supported_sizes, reverse=True): | ||||||||||||||
| if block_size_is_supported(backends, supported_size): | ||||||||||||||
| return supported_size | ||||||||||||||
| raise ValueError(f"No common block size for {kv_manager_block_size}. ") | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new implementation raises a less informative error message compared to the previous version when no common block size is found. The old error message listed the supported block sizes for each backend, which is very helpful for debugging. It would be great to restore that level of detail in the error message.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: | ||||||||||||||
| def may_reinitialize_input_batch( | ||||||||||||||
| self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] | ||||||||||||||
| ) -> None: | ||||||||||||||
| """ | ||||||||||||||
| Re-initialize the input batch if the block sizes are different from | ||||||||||||||
| `[self.cache_config.block_size]`. This usually happens when there | ||||||||||||||
| are multiple KV cache groups. | ||||||||||||||
|
|
||||||||||||||
| Args: | ||||||||||||||
| kv_cache_config: The KV cache configuration. | ||||||||||||||
| kernel_block_sizes: The kernel block sizes for each attention backend. | ||||||||||||||
tdoublep marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||
| """ | ||||||||||||||
| block_sizes = [ | ||||||||||||||
| kv_cache_group.kv_cache_spec.block_size | ||||||||||||||
| for kv_cache_group in kv_cache_config.kv_cache_groups | ||||||||||||||
| if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) | ||||||||||||||
| ] | ||||||||||||||
|
|
||||||||||||||
| # Generate kernel_block_sizes that matches each block_size | ||||||||||||||
| kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) | ||||||||||||||
|
|
||||||||||||||
| if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ | ||||||||||||||
| self.cache_config.block_size | ||||||||||||||
| ]: | ||||||||||||||
|
|
@@ -4261,6 +4252,7 @@ def _reshape_kv_cache_tensors( | |||||||||||||
| self, | ||||||||||||||
| kv_cache_config: KVCacheConfig, | ||||||||||||||
| kv_cache_raw_tensors: dict[str, torch.Tensor], | ||||||||||||||
| kernel_block_sizes: list[int], | ||||||||||||||
| ) -> dict[str, torch.Tensor]: | ||||||||||||||
| """ | ||||||||||||||
| Reshape the KV cache tensors to the desired shape and dtype. | ||||||||||||||
|
|
@@ -4269,6 +4261,7 @@ def _reshape_kv_cache_tensors( | |||||||||||||
| kv_cache_config: The KV cache config | ||||||||||||||
| kv_cache_raw_tensors: The KV cache buffer of each layer, with | ||||||||||||||
| correct size but uninitialized shape. | ||||||||||||||
| kernel_block_sizes: The kernel block sizes for attention backends. | ||||||||||||||
tdoublep marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||
| Returns: | ||||||||||||||
| Dict[str, torch.Tensor]: A map between layer names to their | ||||||||||||||
| corresponding memory buffer for KV cache. | ||||||||||||||
|
|
@@ -4278,6 +4271,7 @@ def _reshape_kv_cache_tensors( | |||||||||||||
| for group in self._kv_cache_spec_attn_group_iterator(): | ||||||||||||||
| kv_cache_spec = group.kv_cache_spec | ||||||||||||||
| attn_backend = group.backend | ||||||||||||||
| kernel_block_size = kernel_block_sizes[group.kv_cache_group_id] | ||||||||||||||
| for layer_name in group.layer_names: | ||||||||||||||
|
Comment on lines
4385
to
4392
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Encoder-only cache groups are appended to Useful? React with 👍 / 👎. |
||||||||||||||
| if layer_name in self.runner_only_attn_layers: | ||||||||||||||
| continue | ||||||||||||||
|
|
@@ -4286,24 +4280,21 @@ def _reshape_kv_cache_tensors( | |||||||||||||
| num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes | ||||||||||||||
| if isinstance(kv_cache_spec, AttentionSpec): | ||||||||||||||
| has_attn = True | ||||||||||||||
| kv_manager_block_size = kv_cache_spec.block_size | ||||||||||||||
| kernel_size_list = self._find_compatible_block_sizes( | ||||||||||||||
| kv_manager_block_size, attn_backend, return_all=False | ||||||||||||||
| num_blocks_per_kv_block = ( | ||||||||||||||
| kv_cache_spec.block_size // kernel_block_size | ||||||||||||||
| ) | ||||||||||||||
| kernel_size = kernel_size_list[0] | ||||||||||||||
| num_blocks_per_kv_block = kv_manager_block_size // kernel_size | ||||||||||||||
| kernel_num_blocks = num_blocks * num_blocks_per_kv_block | ||||||||||||||
|
|
||||||||||||||
| kv_cache_shape = attn_backend.get_kv_cache_shape( | ||||||||||||||
| kernel_num_blocks, | ||||||||||||||
| kernel_size, | ||||||||||||||
| kernel_block_size, | ||||||||||||||
| kv_cache_spec.num_kv_heads, | ||||||||||||||
| kv_cache_spec.head_size, | ||||||||||||||
| cache_dtype_str=self.cache_config.cache_dtype, | ||||||||||||||
| ) | ||||||||||||||
| dtype = kv_cache_spec.dtype | ||||||||||||||
| try: | ||||||||||||||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() # noqa: E501 | ||||||||||||||
| kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() | ||||||||||||||
| assert len(kv_cache_stride_order) == len(kv_cache_shape) | ||||||||||||||
| except (AttributeError, NotImplementedError): | ||||||||||||||
| kv_cache_stride_order = tuple(range(len(kv_cache_shape))) | ||||||||||||||
|
|
@@ -4386,13 +4377,15 @@ def _update_hybrid_attention_mamba_layout( | |||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| def initialize_kv_cache_tensors( | ||||||||||||||
| self, kv_cache_config: KVCacheConfig | ||||||||||||||
| self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] | ||||||||||||||
| ) -> dict[str, torch.Tensor]: | ||||||||||||||
| """ | ||||||||||||||
| Initialize the memory buffer for KV cache. | ||||||||||||||
|
|
||||||||||||||
| Args: | ||||||||||||||
| kv_cache_config: The KV cache config | ||||||||||||||
| kernel_block_sizes: The kernel block sizes for each attention backend. | ||||||||||||||
tdoublep marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||
|
|
||||||||||||||
| Returns: | ||||||||||||||
| Dict[str, torch.Tensor]: A map between layer names to their | ||||||||||||||
| corresponding memory buffer for KV cache. | ||||||||||||||
|
|
@@ -4401,7 +4394,7 @@ def initialize_kv_cache_tensors( | |||||||||||||
| kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) | ||||||||||||||
| # Change the memory buffer to the desired shape | ||||||||||||||
| kv_caches = self._reshape_kv_cache_tensors( | ||||||||||||||
| kv_cache_config, kv_cache_raw_tensors | ||||||||||||||
| kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| # Set up cross-layer KV cache sharing | ||||||||||||||
|
|
@@ -4460,9 +4453,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: | |||||||||||||
| self.may_add_encoder_only_layers_to_kv_cache_config() | ||||||||||||||
| self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) | ||||||||||||||
| self.initialize_attn_backend(kv_cache_config) | ||||||||||||||
| # The block size of attention backends. For example, if kv_cache_manager uses | ||||||||||||||
| # block_size 256, but attention backends only supports block_size 64, we will | ||||||||||||||
| # return kernel_block_size 64 and split the 256-token-block to 4 blocks with 64 | ||||||||||||||
| # tokens each. | ||||||||||||||
tdoublep marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||
| kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) | ||||||||||||||
| # Reinitialize need to after initialize_attn_backend | ||||||||||||||
| self.may_reinitialize_input_batch(kv_cache_config) | ||||||||||||||
| kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) | ||||||||||||||
| self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes) | ||||||||||||||
| kv_caches = self.initialize_kv_cache_tensors( | ||||||||||||||
| kv_cache_config, kernel_block_sizes | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| if self.speculative_config and self.speculative_config.use_eagle(): | ||||||||||||||
| assert isinstance(self.drafter, EagleProposer) | ||||||||||||||
|
|
||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.