diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py
index 24121131a985..5e66fab85e62 100644
--- a/vllm/attention/ops/rocm_aiter_mla.py
+++ b/vllm/attention/ops/rocm_aiter_mla.py
@@ -35,6 +35,14 @@ def aiter_mla_decode_fwd(
     logit_cap: float = 0.0,
     num_kv_splits: int | None = None,
     num_kv_splits_indptr: torch.Tensor | None = None,
+    work_meta_data: torch.Tensor | None = None,
+    work_indptr: torch.Tensor | None = None,
+    work_info_set: torch.Tensor | None = None,
+    reduce_indptr: torch.Tensor | None = None,
+    reduce_final_map: torch.Tensor | None = None,
+    reduce_partial_map: torch.Tensor | None = None,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None,
 ):
     torch.ops.vllm.rocm_aiter_mla_decode_fwd(
         q,
@@ -49,6 +57,14 @@ def aiter_mla_decode_fwd(
         logit_cap=logit_cap,
         num_kv_splits=num_kv_splits,
         num_kv_splits_indptr=num_kv_splits_indptr,
+        work_meta_data=work_meta_data,
+        work_indptr=work_indptr,
+        work_info_set=work_info_set,
+        reduce_indptr=reduce_indptr,
+        reduce_final_map=reduce_final_map,
+        reduce_partial_map=reduce_partial_map,
+        q_scale=q_scale,
+        kv_scale=kv_scale,
     )


@@ -65,6 +81,14 @@ def mla_decode_fwd_impl(
     logit_cap: float = 0.0,
     num_kv_splits: int | None = None,
     num_kv_splits_indptr: torch.Tensor | None = None,
+    work_meta_data: torch.Tensor | None = None,
+    work_indptr: torch.Tensor | None = None,
+    work_info_set: torch.Tensor | None = None,
+    reduce_indptr: torch.Tensor | None = None,
+    reduce_final_map: torch.Tensor | None = None,
+    reduce_partial_map: torch.Tensor | None = None,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None,
 ) -> None:
     from aiter.mla import mla_decode_fwd

@@ -81,6 +105,14 @@ def mla_decode_fwd_impl(
         logit_cap=logit_cap,
         num_kv_splits=num_kv_splits,
         num_kv_splits_indptr=num_kv_splits_indptr,
+        work_meta_data=work_meta_data,
+        work_indptr=work_indptr,
+        work_info_set=work_info_set,
+        reduce_indptr=reduce_indptr,
+        reduce_final_map=reduce_final_map,
+        reduce_partial_map=reduce_partial_map,
+        q_scale=q_scale,
+        kv_scale=kv_scale,
     )


@@ -97,6 +129,14 @@ def mla_decode_fwd_fake(
     logit_cap: float = 0.0,
     num_kv_splits: int | None = None,
     num_kv_splits_indptr: torch.Tensor | None = None,
+    work_meta_data: torch.Tensor | None = None,
+    work_indptr: torch.Tensor | None = None,
+    work_info_set: torch.Tensor | None = None,
+    reduce_indptr: torch.Tensor | None = None,
+    reduce_final_map: torch.Tensor | None = None,
+    reduce_partial_map: torch.Tensor | None = None,
+    q_scale: torch.Tensor | None = None,
+    kv_scale: torch.Tensor | None = None,
 ) -> None:
     pass
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 9b60035f0db8..34664cf5234d 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import math
 from dataclasses import dataclass
 from typing import ClassVar

@@ -17,6 +18,7 @@
     MLACommonImpl,
     MLACommonMetadata,
     MLACommonMetadataBuilder,
+    QueryLenSupport,
 )
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -60,6 +62,20 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     # Number of kv splits
     num_kv_splits: int | None = 16

+    max_seqlen_qo: int = 1
+
+    work_metadata: torch.Tensor | None = None
+
+    work_info_set: torch.Tensor | None = None
+
+    work_indptr: torch.Tensor | None = None
+
+    reduce_indptr: torch.Tensor | None = None
+
+    reduce_final_map: torch.Tensor | None = None
+
+    reduce_partial_map: torch.Tensor | None = None
+

 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
     pass
@@ -68,9 +84,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     # TODO(luka, lucas): audit this as part of:
     # https://github.com/vllm-project/vllm/issues/22945
-    cudagraph_support: ClassVar[AttentionCGSupport] = (
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    )
+    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN

     def __init__(
         self,
@@ -83,6 +98,10 @@ def __init__(
             kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
         )

+        gpu = torch.cuda.current_device()
+        device_properties = torch.cuda.get_device_properties(gpu)
+        cu_num = device_properties.multi_processor_count
+
         self.compilation_config = vllm_config.compilation_config
         self.num_kv_splits = 16
         max_num_pages_per_req = cdiv(
@@ -91,6 +110,36 @@ def __init__(
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req

+        # num_mtp = vllm_config.speculative_config.num_speculative_tokens
+        # num_mtp = 1 if num_mtp is None else num_mtp
+        max_seqlen_qo = (
+            1
+            if vllm_config.speculative_config is None
+            else vllm_config.speculative_config.num_speculative_tokens
+        )
+
+        max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * self.num_heads / 128))
+        self.work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda")
+        self.work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda")
+        self.work_info_set = torch.empty(
+            [max_num_reqs * max_qo_tiles_per_batch * cu_num, 8],
+            dtype=torch.int32,
+            device="cuda",
+        ).fill_(-1)
+        self.reduce_indptr = torch.empty(
+            [max_num_reqs * max_qo_tiles_per_batch + 1],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        self.reduce_final_map = torch.empty(
+            [max_num_reqs * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda"
+        )
+        self.reduce_partial_map = torch.empty(
+            [max_num_reqs * max_qo_tiles_per_batch * cu_num],
+            dtype=torch.int32,
+            device="cuda",
+        )
+
         # Preparing persistent buffers
         # TODO: we can disambiguate between decode and mixed-prefill decode here
         # so we can only use the persistent buffer if a cudagraph is actually
@@ -169,6 +218,32 @@ def _build_decode(
                 seq_lens_device.cumsum(dim=0, dtype=torch.int32),
             ]
         )
+        kv_indptr = torch.zeros(
+            [query_start_loc_cpu.size(0)], dtype=torch.int32, device="cuda"
+        )
+        torch.cumsum(seq_lens_device, dim=0, out=kv_indptr[1:])
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        max_seqlen_qo = torch.max(query_lens).item()
+
+        import aiter
+
+        aiter.get_mla_metadata_v1(
+            query_start_loc_device,
+            kv_indptr,
+            self.num_heads // self.kv_cache_spec.num_kv_heads,
+            self.kv_cache_spec.num_kv_heads,
+            True,
+            self.work_metadata,
+            self.work_info_set,
+            self.work_indptr,
+            self.reduce_indptr,
+            self.reduce_final_map,
+            self.reduce_partial_map,
+            kv_granularity=max(self.kv_cache_spec.block_size, 16),
+            max_seqlen_qo=max_seqlen_qo,
+            uni_seqlen_qo=max_seqlen_qo,
+            fast_mode=True,
+        )

         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             num_actual_pages = paged_kv_indices.size(0)
@@ -223,6 +298,13 @@ def _build_decode(
             num_kv_splits_indptr=num_kv_splits_indptr,
             num_kv_splits=self.num_kv_splits,
             dcp_tot_seq_lens=dcp_tot_seq_lens_device,
+            max_seqlen_qo=max_seqlen_qo,
+            work_metadata=self.work_metadata,
+            work_info_set=self.work_info_set,
+            work_indptr=self.work_indptr,
+            reduce_indptr=self.reduce_indptr,
+            reduce_final_map=self.reduce_final_map,
+            reduce_partial_map=self.reduce_partial_map,
         )

         return attn_metadata
@@ -303,26 +385,33 @@ def _forward_decode(
         assert isinstance(q, torch.Tensor)
         B = q.shape[0]
         o = torch.zeros(
-            B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
+            B, self.num_heads, self.kv_lora_rank, dtype=torch.bfloat16, device=q.device
         )

         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

         # max_seqlen_qo must be 1 except for MTP
         # TODO: Find the best value for MTP
-        max_seqlen_qo = 1
         aiter_mla_decode_fwd(
             q,
             kv_buffer,
             o,
             self.scale,
             attn_metadata.decode.qo_indptr,
-            max_seqlen_qo,
+            attn_metadata.decode.max_seqlen_qo,
             attn_metadata.decode.paged_kv_indptr,
             attn_metadata.decode.paged_kv_indices,
             attn_metadata.decode.paged_kv_last_page_len,
             num_kv_splits=attn_metadata.decode.num_kv_splits,
             num_kv_splits_indptr=attn_metadata.decode.num_kv_splits_indptr,
+            work_meta_data=attn_metadata.decode.work_metadata,
+            work_indptr=attn_metadata.decode.work_indptr,
+            work_info_set=attn_metadata.decode.work_info_set,
+            reduce_indptr=attn_metadata.decode.reduce_indptr,
+            reduce_final_map=attn_metadata.decode.reduce_final_map,
+            reduce_partial_map=attn_metadata.decode.reduce_partial_map,
+            q_scale=layer._q_scale,
+            kv_scale=layer._k_scale,
        )

         return o, None