|
| 1 | +# |
| 2 | +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. |
| 3 | +# This file is a part of the vllm-ascend project. |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | + |
| 18 | +from dataclasses import dataclass |
| 19 | +from typing import Any, Dict, List, Optional, Tuple, Type |
| 20 | + |
| 21 | +import torch |
| 22 | +import torch_npu |
| 23 | +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, |
| 24 | + AttentionLayer, AttentionType) |
| 25 | +from vllm.attention.backends.utils import CommonAttentionState |
| 26 | + |
| 27 | + |
| 28 | +class AscendAttentionBackend(AttentionBackend): |
| 29 | + |
| 30 | + @staticmethod |
| 31 | + def get_name() -> str: |
| 32 | + return "ASCEND" |
| 33 | + |
| 34 | + @staticmethod |
| 35 | + def get_impl_cls() -> Type["AscendAttentionBackendImpl"]: |
| 36 | + return AscendAttentionBackendImpl |
| 37 | + |
| 38 | + @staticmethod |
| 39 | + def get_metadata_cls() -> Type["AscendMetadata"]: |
| 40 | + return AscendMetadata |
| 41 | + |
| 42 | + @staticmethod |
| 43 | + def get_state_cls() -> Type["CommonAttentionState"]: |
| 44 | + return CommonAttentionState |
| 45 | + |
| 46 | + @staticmethod |
| 47 | + def get_kv_cache_shape( |
| 48 | + num_blocks: int, |
| 49 | + block_size: int, |
| 50 | + num_kv_heads: int, |
| 51 | + head_size: int, |
| 52 | + ) -> Tuple[int, ...]: |
| 53 | + return (2, num_blocks, block_size, num_kv_heads * head_size) |
| 54 | + |
| 55 | + @staticmethod |
| 56 | + def swap_blocks( |
| 57 | + src_kv_cache: List[torch.Tensor], |
| 58 | + dst_kv_cache: List[torch.Tensor], |
| 59 | + src_to_dst: torch.Tensor, |
| 60 | + ) -> None: |
| 61 | + src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1] |
| 62 | + dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1] |
| 63 | + src_indices = src_to_dst[:, 0] |
| 64 | + dst_indices = src_to_dst[:, 1] |
| 65 | + |
| 66 | + dst_key_cache[dst_indices] = src_key_cache[src_indices].to( |
| 67 | + dst_key_cache.device) |
| 68 | + dst_value_cache[dst_indices] = src_value_cache[src_indices].to( |
| 69 | + dst_key_cache.device) |
| 70 | + |
| 71 | + @staticmethod |
| 72 | + def copy_blocks( |
| 73 | + kv_caches: List[torch.Tensor], |
| 74 | + src_to_dists: torch.Tensor, |
| 75 | + ) -> None: |
| 76 | + src_indices = src_to_dists[:, 0] |
| 77 | + dst_indices = src_to_dists[:, 1] |
| 78 | + |
| 79 | + for kv_cache in kv_caches: |
| 80 | + key_caches = kv_cache[0] |
| 81 | + value_caches = kv_cache[1] |
| 82 | + key_caches[dst_indices] = key_caches[src_indices] |
| 83 | + value_caches[dst_indices] = value_caches[src_indices] |
| 84 | + |
| 85 | + |
| 86 | +@dataclass |
| 87 | +class AscendMetadata: |
| 88 | + # (batch_size, max_blocks_per_seq). |
| 89 | + # Block addresses per sequence. (Seq id -> list of physical block) |
| 90 | + block_tables: Optional[torch.Tensor] |
| 91 | + # (batch_size,). The sequence length per sequence. Sequence length means |
| 92 | + # the computed tokens + new tokens None if it is a decoding. |
| 93 | + seq_lens: Optional[List[int]] = None |
| 94 | + context_lens: Optional[List[int]] = None |
| 95 | + # Maximum query length in the batch. None for decoding. |
| 96 | + max_query_len: Optional[int] = None |
| 97 | + # (num_tokens,). The indices of the token slots that input tokens will be |
| 98 | + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size |
| 99 | + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot |
| 100 | + # in block 0, and 1st slot in block 1, respectively. |
| 101 | + slot_mapping: torch.Tensor = None |
| 102 | + # TODO: Indicates whether there are only prefill requests. |
| 103 | + # FlashAttention can be used when there are only prefill requests. |
| 104 | + # FlashAttention has better performance than PageAtttention, |
| 105 | + # but it does not support decode requests. |
| 106 | + is_only_prefill: bool = False |
| 107 | + |
| 108 | + attn_mask: Optional[torch.Tensor] = None |
| 109 | + |
| 110 | + |
| 111 | +class AscendAttentionBackendImpl(AttentionImpl): |
| 112 | + |
| 113 | + def __init__( |
| 114 | + self, |
| 115 | + num_heads: int, |
| 116 | + head_size: int, |
| 117 | + scale: float, |
| 118 | + num_kv_heads: int, |
| 119 | + alibi_slopes: Optional[List[float]], |
| 120 | + sliding_window: Optional[int], |
| 121 | + kv_cache_dtype: str, |
| 122 | + blocksparse_params: Optional[Dict[str, Any]] = None, |
| 123 | + logits_soft_cap: Optional[float] = None, |
| 124 | + attn_type: str = AttentionType.DECODER, |
| 125 | + ) -> None: |
| 126 | + self.num_heads = num_heads |
| 127 | + self.head_size = head_size |
| 128 | + self.scale = float(scale) |
| 129 | + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads |
| 130 | + self.hidden_size = self.num_heads * self.head_size |
| 131 | + self.kv_cache_dtype = kv_cache_dtype |
| 132 | + self.sliding_window = sliding_window |
| 133 | + if alibi_slopes is not None: |
| 134 | + alibi_slopes = torch.tensor(alibi_slopes, |
| 135 | + dtype=torch.float32, |
| 136 | + device="npu") |
| 137 | + self.alibi_slopes = alibi_slopes |
| 138 | + self.attn_type = attn_type |
| 139 | + |
| 140 | + assert self.num_heads % self.num_kv_heads == 0 |
| 141 | + self.num_queries_per_kv = self.num_heads // self.num_kv_heads |
| 142 | + self.seq_len_cpu_tensor = None |
| 143 | + |
| 144 | + def forward( |
| 145 | + self, |
| 146 | + layer: AttentionLayer, |
| 147 | + query: torch.Tensor, |
| 148 | + key: torch.Tensor, |
| 149 | + value: torch.Tensor, |
| 150 | + kv_cache: torch.Tensor, |
| 151 | + attn_metadata: AscendMetadata, |
| 152 | + output: Optional[torch.Tensor] = None, |
| 153 | + ) -> torch.Tensor: |
| 154 | + """Forward pass with Ascend attention. |
| 155 | + Args: |
| 156 | + query: shape = [batch_size, seq_len, num_heads * head_size] |
| 157 | + key: shape = [batch_size, seq_len, num_kv_heads * head_size] |
| 158 | + value: shape = [batch_size, seq_len, num_kv_heads * head_size] |
| 159 | + kv_cache: shape = [2, num_blocks, block_size, |
| 160 | + num_kv_heads * head_size] |
| 161 | + key_cache = [num_blocks, block_size, |
| 162 | + num_kv_heads * head_size] |
| 163 | + value_cache = [num_blocks, block_size, |
| 164 | + num_kv_heads * head_size] |
| 165 | + attn_metadata: Metadata for attention. |
| 166 | + Returns: |
| 167 | + shape = [batch_size * seq_len, num_heads, head_size] |
| 168 | + """ |
| 169 | + num_tokens = query.shape[0] |
| 170 | + output = torch.empty(num_tokens, |
| 171 | + self.num_heads, |
| 172 | + self.head_size, |
| 173 | + dtype=query.dtype, |
| 174 | + device=query.device) |
| 175 | + |
| 176 | + if attn_metadata is None: |
| 177 | + # Profiling run. |
| 178 | + return output.view(num_tokens, self.hidden_size) |
| 179 | + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 |
| 180 | + attn_type = self.attn_type |
| 181 | + if attn_type != AttentionType.DECODER: |
| 182 | + raise NotImplementedError("Encoder self-attention and " |
| 183 | + "encoder/decoder cross-attention " |
| 184 | + "are not implemented for " |
| 185 | + "PallasAttentionBackendImpl") |
| 186 | + # View q k v to BSH. |
| 187 | + query = query.view(-1, self.num_heads, self.head_size) |
| 188 | + key = key.view(-1, self.num_kv_heads, self.head_size) |
| 189 | + value = value.view(-1, self.num_kv_heads, self.head_size) |
| 190 | + # TODO: Remove this contiguous in the future. |
| 191 | + value = value.contiguous() |
| 192 | + |
| 193 | + if hasattr(layer, 'quant_method'): |
| 194 | + # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata |
| 195 | + pass |
| 196 | + else: |
| 197 | + if kv_cache.numel() > 0: |
| 198 | + key_cache, value_cache = kv_cache[0], kv_cache[1] |
| 199 | + num_blocks, block_size, _ = key_cache.shape |
| 200 | + key_cache = key_cache.view(num_blocks, block_size, |
| 201 | + self.num_kv_heads, self.head_size) |
| 202 | + value_cache = value_cache.view(num_blocks, block_size, |
| 203 | + self.num_kv_heads, |
| 204 | + self.head_size) |
| 205 | + slots = attn_metadata.slot_mapping |
| 206 | + torch_npu._npu_reshape_and_cache(key=key, |
| 207 | + value=value, |
| 208 | + key_cache=key_cache, |
| 209 | + value_cache=value_cache, |
| 210 | + slot_indices=slots) |
| 211 | + |
| 212 | + # use paged attention |
| 213 | + torch_npu._npu_paged_attention_splitfuse( |
| 214 | + query=query, |
| 215 | + key_cache=key_cache, |
| 216 | + value_cache=value_cache, |
| 217 | + mask=attn_metadata.attn_mask, |
| 218 | + block_table=attn_metadata.block_tables, |
| 219 | + seq_len=attn_metadata.seq_lens, |
| 220 | + context_lens=attn_metadata.context_lens, |
| 221 | + num_kv_heads=self.num_kv_heads, |
| 222 | + num_heads=self.num_heads, |
| 223 | + scale_value=self.scale, |
| 224 | + out=output) |
| 225 | + return output.view(num_tokens, self.hidden_size) |
0 commit comments