Commit 9921806

add support for V1 engine on v0.7.3
Co-authored-by: didongli182 <[email protected]>
Signed-off-by: shen-shanshan <[email protected]>
1 parent de8cdea commit 9921806
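
For orientation, the platform change below gates the new code path on vLLM's VLLM_USE_V1 environment variable (read through vllm.envs). A minimal usage sketch, assuming vllm v0.7.3 and vllm-ascend are installed on an Ascend host; the model name and prompt are placeholders, not part of this commit:

    import os

    os.environ["VLLM_USE_V1"] = "1"  # read as vllm.envs.VLLM_USE_V1 by NPUPlatform

    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)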

6 files changed: +1345 -13 lines changed
File renamed without changes.
Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState


class AscendAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "ASCEND"

    @staticmethod
    def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
        return AscendAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AscendMetadata"]:
        return AscendMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size, num_kv_heads * head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
        dst_kv_cache: List[torch.Tensor],
        src_to_dst: torch.Tensor,
    ) -> None:
        src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
        dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
        src_indices = src_to_dst[:, 0]
        dst_indices = src_to_dst[:, 1]

        dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
            dst_key_cache.device)
        dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
            dst_key_cache.device)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        src_indices = src_to_dists[:, 0]
        dst_indices = src_to_dists[:, 1]

        for kv_cache in kv_caches:
            key_caches = kv_cache[0]
            value_caches = kv_cache[1]
            key_caches[dst_indices] = key_caches[src_indices]
            value_caches[dst_indices] = value_caches[src_indices]


@dataclass
class AscendMetadata:
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence (seq id -> list of physical blocks).
    block_tables: Optional[torch.Tensor]
    # (batch_size,). The sequence length per sequence, i.e. the number of
    # computed tokens plus new tokens. None if it is a decoding batch.
    seq_lens: Optional[List[int]] = None
    context_lens: Optional[List[int]] = None
    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored at offset 3 of block 2, offset 2 of
    # block 0, and offset 1 of block 1, respectively.
    slot_mapping: torch.Tensor = None
    # TODO: Indicates whether there are only prefill requests.
    # FlashAttention can be used when there are only prefill requests.
    # FlashAttention has better performance than PagedAttention,
    # but it does not support decode requests.
    is_only_prefill: bool = False

    attn_mask: Optional[torch.Tensor] = None


class AscendAttentionBackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.hidden_size = self.num_heads * self.head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes,
                                        dtype=torch.float32,
                                        device="npu")
        self.alibi_slopes = alibi_slopes
        self.attn_type = attn_type

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.seq_len_cpu_tensor = None

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Ascend attention.
        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size,
                               num_kv_heads * head_size]
                key_cache = [num_blocks, block_size,
                             num_kv_heads * head_size]
                value_cache = [num_blocks, block_size,
                               num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size * seq_len, num_heads, head_size]
        """
        num_tokens = query.shape[0]
        output = torch.empty(num_tokens,
                             self.num_heads,
                             self.head_size,
                             dtype=query.dtype,
                             device=query.device)

        if attn_metadata is None:
            # Profiling run.
            return output.view(num_tokens, self.hidden_size)
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        attn_type = self.attn_type
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "AscendAttentionBackendImpl")
        # View q, k, v as [num_tokens, num_heads, head_size].
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        # TODO: Remove this contiguous in the future.
        value = value.contiguous()

        if hasattr(layer, 'quant_method'):
            # TODO: Add attrs (num_prefills, prefill_metadata,
            # decode_metadata) to AscendMetadata.
            pass
        else:
            if kv_cache.numel() > 0:
                key_cache, value_cache = kv_cache[0], kv_cache[1]
                num_blocks, block_size, _ = key_cache.shape
                key_cache = key_cache.view(num_blocks, block_size,
                                           self.num_kv_heads, self.head_size)
                value_cache = value_cache.view(num_blocks, block_size,
                                               self.num_kv_heads,
                                               self.head_size)
                slots = attn_metadata.slot_mapping
                torch_npu._npu_reshape_and_cache(key=key,
                                                 value=value,
                                                 key_cache=key_cache,
                                                 value_cache=value_cache,
                                                 slot_indices=slots)

            # Use paged attention.
            torch_npu._npu_paged_attention_splitfuse(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                mask=attn_metadata.attn_mask,
                block_table=attn_metadata.block_tables,
                seq_len=attn_metadata.seq_lens,
                context_lens=attn_metadata.context_lens,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                out=output)
        return output.view(num_tokens, self.hidden_size)
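
As a side note on the metadata above: slot_mapping flattens (block, offset) pairs into absolute slot indices, and get_kv_cache_shape packs the key and value caches into one tensor. A small standalone sketch of that arithmetic (plain Python, no NPU required; block_size and the example slots come from the AscendMetadata comment, while num_blocks, num_kv_heads and head_size are illustrative values, not from this commit):

    # Slot-mapping arithmetic: absolute slot -> (block index, offset in block).
    block_size = 16
    slot_mapping = [35, 2, 17]

    for slot in slot_mapping:
        block_id, offset = divmod(slot, block_size)
        print(f"slot {slot} -> block {block_id}, offset {offset}")
    # slot 35 -> block 2, offset 3
    # slot 2  -> block 0, offset 2
    # slot 17 -> block 1, offset 1

    # KV-cache shape as returned by AscendAttentionBackend.get_kv_cache_shape:
    # index 0 holds the key cache, index 1 the value cache.
    num_blocks, num_kv_heads, head_size = 4, 8, 128  # illustrative values
    kv_cache_shape = (2, num_blocks, block_size, num_kv_heads * head_size)
    print(kv_cache_shape)  # (2, 4, 16, 1024)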

vllm_ascend/platform.py

Lines changed: 32 additions & 12 deletions
@@ -19,13 +19,10 @@
 from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch
-
-try:
-    import torch_npu  # noqa: F401
-except ImportError:
-    print("Failed to import torch_npu.")
-
-from vllm.config import VllmConfig
+import torch_npu  # noqa: F401
+import vllm.envs as envs
+from vllm.config import CompilationLevel, VllmConfig
+from vllm.logger import init_logger
 from vllm.platforms import Platform, PlatformEnum
 
 if TYPE_CHECKING:
@@ -35,6 +32,8 @@
 
 os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"
 
+logger = init_logger(__name__)
+
 
 def _device_id_to_physical_device_id(device_id: int) -> int:
     if "ASCEND_RT_VISIBLE_DEVICES" in os.environ:
@@ -54,7 +53,7 @@ class NPUPlatform(Platform):
     _enum = PlatformEnum.OOT
     device_name: str = "npu"
     device_type: str = "npu"
-    simple_compile_backend: str = "npu"
+    simple_compile_backend: str = "eager"  # Disable torch.compile()
     ray_device_key: str = "NPU"
     device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
     dispatch_key: str = "PrivateUse1"
@@ -106,9 +105,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # RayWorkerWrapper monkey patch when setup
         from vllm_ascend.patch import ray_patch  # noqa: F401
 
+        compilation_config = vllm_config.compilation_config
+        if compilation_config.level != CompilationLevel.NO_COMPILATION:
+            logger.warning(
+                "Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION",
+                compilation_config.level)
+            compilation_config.level = CompilationLevel.NO_COMPILATION
+
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
-            if vllm_config.speculative_config:
+            if envs.VLLM_USE_V1:
+                parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
+            elif vllm_config.speculative_config:
                 parallel_config.worker_cls = "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                 parallel_config.sd_worker_cls = "vllm_ascend.worker.worker.NPUWorker"
             elif vllm_config.scheduler_config.is_multi_step:
@@ -128,12 +136,20 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             # Ascend attention quant uses int8 dtype.
             cache_config.cache_dtype = 'int8'
 
+        if envs.VLLM_USE_V1 and cache_config.enable_prefix_caching:
+            logger.warning(
+                "Prefix caching is not supported for V1 now, disable prefix caching"
+            )
+            cache_config.enable_prefix_caching = False
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla):
-        if use_mla:
-            return "vllm_ascend.attention.AscendMLAAttentionBackend"
-        return "vllm_ascend.attention.AscendAttentionBackend"
+        if use_v1:
+            return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
+        elif use_mla:
+            return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
+        return "vllm_ascend.attention.attention.AscendAttentionBackend"
 
     @classmethod
     def get_current_memory_usage(cls,
@@ -145,3 +161,7 @@ def get_current_memory_usage(cls,
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm_ascend.communicator.NPUCommunicator"
+
+    @classmethod
+    def is_pin_memory_available(cls):
+        return True
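
To make the new dispatch easier to read, here is the branch logic of get_attn_backend_cls restated as a standalone sketch; the real method is a classmethod on NPUPlatform and also receives head_size, dtype and other arguments that are not used for the selection:

    def select_ascend_attn_backend(use_v1: bool, use_mla: bool) -> str:
        # Mirrors NPUPlatform.get_attn_backend_cls after this commit.
        if use_v1:
            # New V1-engine backend added by this commit.
            return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
        elif use_mla:
            return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
        return "vllm_ascend.attention.attention.AscendAttentionBackend"


    print(select_ascend_attn_backend(use_v1=True, use_mla=False))
    # vllm_ascend.attention.attention_v1.AscendAttentionBackend
    print(select_ascend_attn_backend(use_v1=False, use_mla=True))
    # vllm_ascend.attention.attention.AscendMLAAttentionBackend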

vllm_ascend/worker/draft_model_runner.py

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,8 @@
                                              ModelRunnerInputBase,
                                              ModelRunnerWrapperBase)
 
-from vllm_ascend.attention import AscendMetadata as FlashAttentionMetadata
+from vllm_ascend.attention.attention import \
+    AscendMetadata as FlashAttentionMetadata
 
 logger = init_logger(__name__)
 