Merged
2 changes: 1 addition & 1 deletion README.md
@@ -44,7 +44,7 @@ Find more about how to setup your environment step by step in [here](docs/source
## Getting Started

> [!NOTE]
-> Currently, we are actively collaborating with the vLLM community to support the Ascend backend plugin, once supported you can use one line command `pip install vllm vllm-ascend` to compelete installation.
+> Currently, we are actively collaborating with the vLLM community to support the Ascend backend plugin, once supported you can use one line command `pip install vllm vllm-ascend` to complete installation.

Installation from source code:
```bash
# (installation commands collapsed in the diff view)
```
2 changes: 1 addition & 1 deletion docs/source/installation.md
@@ -64,7 +64,7 @@ docker run --rm \
-it $IMAGE bash
```

-:::{dropdown} Click here to see "Install CANN manally"
+:::{dropdown} Click here to see "Install CANN manually"
:animate: fade-in-slide-down
You can also install CANN manually:

(manual installation steps collapsed in the diff view)
:::
2 changes: 1 addition & 1 deletion docs/source/tutorials/multi_node.md
@@ -56,7 +56,7 @@ ray start --address='{head_node_ip}:{port_num}' --num-gpus=8 --node-ip-address={
```

:::{note}
-If you're running DeepSeek V3/R1, please remove `quantization_config` section in `config.json` file since it's not supported by vllm-ascend currentlly.
+If you're running DeepSeek V3/R1, please remove `quantization_config` section in `config.json` file since it's not supported by vllm-ascend currently.
:::

Start the vLLM server on head node:
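The note in the multi-node tutorial above tells DeepSeek V3/R1 users to delete the `quantization_config` section from the model's `config.json` before serving with vllm-ascend. A minimal sketch of that edit, assuming a standard Hugging Face-style `config.json`; the helper name is illustrative, not part of the PR:

```python
import json


def strip_quantization_config(path: str) -> bool:
    """Drop the `quantization_config` section from a model config file.

    Returns True if the section was present and removed, False otherwise.
    Illustrative helper only -- not part of vllm-ascend.
    """
    with open(path) as f:
        config = json.load(f)
    if config.pop("quantization_config", None) is None:
        return False  # nothing to do
    with open(path, "w") as f:
        json.dump(config, f, indent=2)
    return True
```

Running it once removes the section and rewrites the file; running it again is a no-op, so it is safe to call before every launch.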
10 changes: 5 additions & 5 deletions docs/source/user_guide/release_notes.md
@@ -25,8 +25,8 @@
- Pin modelscope<1.23.0 on vLLM v0.7.3 to resolve: https://github.com/vllm-project/vllm/pull/13807

### Known issues
-- In [some cases](https://github.com/vllm-project/vllm-ascend/issues/324), expecially when the input/output is very long, the accuracy of output may be incorrect. We are working on it. It'll be fixed in the next release.
-- Improved and reduced the garbled code in model output. But if you still hit the issue, try to change the gerneration config value, such as `temperature`, and try again. There is also a knonwn issue shown below. Any [feedback](https://github.com/vllm-project/vllm-ascend/issues/267) is welcome. [#277](https://github.com/vllm-project/vllm-ascend/pull/277)
+- In [some cases](https://github.com/vllm-project/vllm-ascend/issues/324), especially when the input/output is very long, the accuracy of output may be incorrect. We are working on it. It'll be fixed in the next release.
+- Improved and reduced the garbled code in model output. But if you still hit the issue, try to change the generation config value, such as `temperature`, and try again. There is also a knonwn issue shown below. Any [feedback](https://github.com/vllm-project/vllm-ascend/issues/267) is welcome. [#277](https://github.com/vllm-project/vllm-ascend/pull/277)

## v0.7.1rc1

@@ -46,7 +46,7 @@ Please follow the [official doc](https://vllm-ascend.readthedocs.io/en/v0.7.1-de

### Core

-- Added the Ascend quantization config option, the implementation will comming soon. [#7](https://github.com/vllm-project/vllm-ascend/pull/7) [#73](https://github.com/vllm-project/vllm-ascend/pull/73)
+- Added the Ascend quantization config option, the implementation will coming soon. [#7](https://github.com/vllm-project/vllm-ascend/pull/7) [#73](https://github.com/vllm-project/vllm-ascend/pull/73)
- Add silu_and_mul and rope ops and add mix ops into attention layer. [#18](https://github.com/vllm-project/vllm-ascend/pull/18)

### Other
@@ -58,5 +58,5 @@ Please follow the [official doc](https://vllm-ascend.readthedocs.io/en/v0.7.1-de
### Known issues

- This release relies on an unreleased torch_npu version. It has been installed within official container image already. Please [install](https://vllm-ascend.readthedocs.io/en/v0.7.1rc1/installation.html) it manually if you are using non-container environment.
-- There are logs like `No platform deteced, vLLM is running on UnspecifiedPlatform` or `Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")` shown when runing vllm-ascend. It actually doesn't affect any functionality and performance. You can just ignore it. And it has been fixed in this [PR](https://github.com/vllm-project/vllm/pull/12432) which will be included in v0.7.3 soon.
-- There are logs like `# CPU blocks: 35064, # CPU blocks: 2730` shown when runing vllm-ascend which should be `# NPU blocks:` . It actually doesn't affect any functionality and performance. You can just ignore it. And it has been fixed in this [PR](https://github.com/vllm-project/vllm/pull/13378) which will be included in v0.7.3 soon.
+- There are logs like `No platform detected, vLLM is running on UnspecifiedPlatform` or `Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")` shown when running vllm-ascend. It actually doesn't affect any functionality and performance. You can just ignore it. And it has been fixed in this [PR](https://github.com/vllm-project/vllm/pull/12432) which will be included in v0.7.3 soon.
+- There are logs like `# CPU blocks: 35064, # CPU blocks: 2730` shown when running vllm-ascend which should be `# NPU blocks:` . It actually doesn't affect any functionality and performance. You can just ignore it. And it has been fixed in this [PR](https://github.com/vllm-project/vllm/pull/13378) which will be included in v0.7.3 soon.
2 changes: 1 addition & 1 deletion docs/source/user_guide/suppoted_features.md
@@ -9,7 +9,7 @@
| Speculative decoding | ✅ | | | Basic functions available | Need fully test |
| Pooling | ✅ | | | Basic functions available(Bert) | Need fully test and add more models support|
| Enc-dec | ❌ | | | NA | Plan in 2025.06.30|
-| Multi Modality | ✅ | | ✅ | Basic functions available(LLaVA/Qwen2-vl/Qwen2-audio/internVL)| Improve perforamance, and add more models support |
+| Multi Modality | ✅ | | ✅ | Basic functions available(LLaVA/Qwen2-vl/Qwen2-audio/internVL)| Improve performance, and add more models support |
| LogProbs | ✅ | | | Basic functions available | Need fully test |
| Prompt logProbs | ✅ | | | Basic functions available | Need fully test |
| Async output | ✅ | | | Basic functions available | Need fully test |
5 changes: 3 additions & 2 deletions format.sh
@@ -144,15 +144,15 @@ CODESPELL_EXCLUDES=(
)

CODESPELL_IGNORE_WORDS=(
-'-L' 'CANN,NNAL,ASCEND'
+'-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend'
)

# check spelling of specified files
spell_check() {
codespell "$@" "${CODESPELL_IGNORE_WORDS[@]}"
}

-spell_check_all(){
+spell_check_all() {
codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
}

@@ -168,6 +168,7 @@ spell_check_changed() {
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
codespell "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
+codespell "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
fi
}

31 changes: 31 additions & 0 deletions test.py
@@ -0,0 +1,31 @@
import os

import torch
import torch_npu # noqa: F401

device_id = 0


def _device_id_to_physical_device_id(device_id: int) -> int:
if "ASCEND_RT_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")
if device_ids == [""]:
raise RuntimeError("ASCEND_RT_VISIBLE_DEVICES is set to empty "
"string, which means Ascend NPU support is "
"disabled.")
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else:
return device_id


physical_device_id = _device_id_to_physical_device_id(device_id)
print("physical_device_id: " + str(physical_device_id))

# return torch.npu.get_device_name(physical_device_id)
torch.npu.get_device_name(device_id)

for k, v in os.environ.items():
if k == "ASCEND_RT_VISIBLE_DEVICES":
print(k)
print(v)
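The `test.py` script above probes how a logical device index is translated through `ASCEND_RT_VISIBLE_DEVICES`. The mapping itself needs no NPU stack, so it can be sketched and exercised without `torch_npu`; the function name below is mine, not part of the PR:

```python
import os


def physical_device_id(logical_id: int) -> int:
    """Map a logical Ascend device index to its physical id, mirroring
    the `_device_id_to_physical_device_id` helper in test.py above.
    Illustrative sketch only."""
    visible = os.environ.get("ASCEND_RT_VISIBLE_DEVICES")
    if visible is None:
        # No masking in effect: logical and physical ids coincide.
        return logical_id
    device_ids = visible.split(",")
    if device_ids == [""]:
        raise RuntimeError("ASCEND_RT_VISIBLE_DEVICES is set to an empty "
                           "string, which means Ascend NPU support is "
                           "disabled.")
    # e.g. ASCEND_RT_VISIBLE_DEVICES="4,5,6,7" maps logical 0 -> physical 4
    return int(device_ids[logical_id])
```

With `ASCEND_RT_VISIBLE_DEVICES="4,5,6,7"`, logical device 0 resolves to physical device 4, which is the behavior the script prints via `torch.npu.get_device_name`.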
2 changes: 1 addition & 1 deletion tools/actionlint.sh
@@ -20,7 +20,7 @@
#

if command -v actionlint &> /dev/null; then
-# NOTE: avoid check .github/workflows/vllm_ascend_test.yaml becase sel-hosted runner `npu-arm64` is unknown
+# NOTE: avoid check .github/workflows/vllm_ascend_test.yaml because sel-hosted runner `npu-arm64` is unknown
actionlint .github/workflows/*.yml .github/workflows/mypy.yaml
exit 0
elif [ -x ./actionlint ]; then
File renamed without changes.
225 changes: 225 additions & 0 deletions vllm_ascend/attention/attention_v1.py
@@ -0,0 +1,225 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState


class AscendAttentionBackend(AttentionBackend):

@staticmethod
def get_name() -> str:
return "ASCEND"

@staticmethod
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
return AscendAttentionBackendImpl

@staticmethod
def get_metadata_cls() -> Type["AscendMetadata"]:
return AscendMetadata

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads * head_size)

@staticmethod
def swap_blocks(
src_kv_cache: List[torch.Tensor],
dst_kv_cache: List[torch.Tensor],
src_to_dst: torch.Tensor,
) -> None:
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
src_indices = src_to_dst[:, 0]
dst_indices = src_to_dst[:, 1]

dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
dst_key_cache.device)
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
dst_key_cache.device)

@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
src_indices = src_to_dists[:, 0]
dst_indices = src_to_dists[:, 1]

for kv_cache in kv_caches:
key_caches = kv_cache[0]
value_caches = kv_cache[1]
key_caches[dst_indices] = key_caches[src_indices]
value_caches[dst_indices] = value_caches[src_indices]


@dataclass
class AscendMetadata:
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
block_tables: Optional[torch.Tensor]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]] = None
context_lens: Optional[List[int]] = None
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor = None
# TODO: Indicates whether there are only prefill requests.
# FlashAttention can be used when there are only prefill requests.
# FlashAttention has better performance than PagedAttention,
# but it does not support decode requests.
is_only_prefill: bool = False

attn_mask: Optional[torch.Tensor] = None


class AscendAttentionBackendImpl(AttentionImpl):

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.hidden_size = self.num_heads * self.head_size
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes,
dtype=torch.float32,
device="npu")
self.alibi_slopes = alibi_slopes
self.attn_type = attn_type

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.seq_len_cpu_tensor = None

def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Ascend attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache: shape = [2, num_blocks, block_size,
num_kv_heads * head_size]
key_cache = [num_blocks, block_size,
num_kv_heads * head_size]
value_cache = [num_blocks, block_size,
num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size * seq_len, num_heads, head_size]
"""
num_tokens = query.shape[0]
output = torch.empty(num_tokens,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)

if attn_metadata is None:
# Profiling run.
return output.view(num_tokens, self.hidden_size)
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"AscendAttentionBackendImpl")
# View q k v to BSH.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# TODO: Remove this contiguous in the future.
value = value.contiguous()

if hasattr(layer, 'quant_method'):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
else:
if kv_cache.numel() > 0:
key_cache, value_cache = kv_cache[0], kv_cache[1]
num_blocks, block_size, _ = key_cache.shape
key_cache = key_cache.view(num_blocks, block_size,
self.num_kv_heads, self.head_size)
value_cache = value_cache.view(num_blocks, block_size,
self.num_kv_heads,
self.head_size)
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_indices=slots)

# use paged attention
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=key_cache,
value_cache=value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.seq_lens,
context_lens=attn_metadata.context_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
return output.view(num_tokens, self.hidden_size)
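The `slot_mapping` comment in `AscendMetadata` above describes how flat slot indices decompose into (block number, offset-in-block) positions within the paged KV cache laid out by `get_kv_cache_shape`. A minimal, dependency-free sketch of that decomposition; the helper name is illustrative, not part of vllm-ascend:

```python
from typing import List, Tuple


def decompose_slots(slot_mapping: List[int],
                    block_size: int) -> List[Tuple[int, int]]:
    """Split flat slot indices into (block_number, offset_in_block) pairs,
    as described by the `slot_mapping` comment in AscendMetadata above.
    Illustrative helper only."""
    return [(slot // block_size, slot % block_size) for slot in slot_mapping]


# With block_size 16, slot 35 lands at offset 3 of block 2,
# slot 2 at offset 2 of block 0, and slot 17 at offset 1 of block 1.
```

This is the arithmetic that `_npu_reshape_and_cache` relies on when scattering new key/value vectors into `key_cache`/`value_cache` by `slot_indices`.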