[FIXBUG] Allow disabling rocm_aiter_fa backend for ROCm GPUs not compatible with AITER #22795
vllm/v1/spec_decode/eagle.py

```diff
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import ast
+import os
 from dataclasses import replace
 from typing import Optional
 
@@ -20,8 +21,6 @@
 from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.v1.attention.backends.rocm_aiter_fa import (
-    AiterFlashAttentionMetadata)
 from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
                                                   TreeAttentionMetadataBuilder)
 from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
@@ -233,14 +232,16 @@
         # TODO: Currently, MTP module released by deepseek only has
         # one layer. Adapt this code to support multiple layers once
         # there's a multi-layer MTP module.
 
         # On ROCm, both AiterFlashAttention and TritonAttention
         # support multi-token eagle spec decode.
         if current_platform.is_rocm():
-            assert isinstance(
-                attn_metadata,
-                (TritonAttentionMetadata, AiterFlashAttentionMetadata,
-                 FlashAttentionMetadata))
+            allowed_types = (TritonAttentionMetadata, FlashAttentionMetadata)
+            if os.environ.get("VLLM_ROCM_USE_AITER") == "1":
+                from vllm.v1.attention.backends.rocm_aiter_fa import (
+                    AiterFlashAttentionMetadata)
+                allowed_types += (AiterFlashAttentionMetadata, )
+            assert isinstance(attn_metadata, allowed_types)
         else:
             # Currently, only FlashAttention and TreeAttention support
             # multi-token eagle spec decode. This is because the code below
@@ -744,4 +745,4 @@
             greedy_token_ids,
             next_token_ids,
         )
-        return next_token_ids, probs
+        return next_token_ids, probs
```

CI annotations on this commit flagged check failures at lines 235 and 243 of vllm/v1/spec_decode/eagle.py.
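To see the effect of the change in isolation: the AITER metadata type is now imported and accepted only when `VLLM_ROCM_USE_AITER=1`, so ROCm GPUs that cannot run AITER never touch the `rocm_aiter_fa` module. A minimal, self-contained sketch of the same guard pattern, using stand-in classes rather than the real vLLM metadata types:

```python
import os


# Stand-in classes; the real ones live in vllm.v1.attention.backends.*.
class TritonAttentionMetadata: ...
class FlashAttentionMetadata: ...
class AiterFlashAttentionMetadata: ...


def check_metadata(attn_metadata) -> None:
    # Backends that are always acceptable on ROCm.
    allowed_types: tuple[type, ...] = (TritonAttentionMetadata,
                                       FlashAttentionMetadata)
    # Accept the AITER backend only when it is explicitly enabled.
    if os.environ.get("VLLM_ROCM_USE_AITER") == "1":
        allowed_types += (AiterFlashAttentionMetadata,)
    assert isinstance(attn_metadata, allowed_types), (
        f"unsupported attention metadata type: {type(attn_metadata)}")


os.environ["VLLM_ROCM_USE_AITER"] = "0"
check_metadata(TritonAttentionMetadata())        # passes
# check_metadata(AiterFlashAttentionMetadata())  # would raise AssertionError
```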
russellb: Is there any way you can make this more dynamic, if it's known which device types would support this vs. not?
JartX: @russellb I think the architecture names can be used, but the list will always have to be expanded. Do you know of another mechanism for this? For example:
```python
import torch


def _is_rocm_gpu_with_matrix_cores() -> bool:
    if not torch.cuda.is_available() or not torch.version.hip:
        return False
    try:
        device_properties = torch.cuda.get_device_properties(
            torch.cuda.current_device())
        gcn_arch_name = getattr(device_properties, "gcnArchName", "")
        supported_archs = ("gfx908", "gfx90a", "gfx940", "gfx941", "gfx942")
        return any(gcn_arch_name.startswith(arch) for arch in supported_archs)
    except (RuntimeError, AttributeError):
        return False
```
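If such a probe were adopted, the guard from the diff could require both the opt-in flag and compatible hardware. A hedged sketch reusing the helper above (this combined policy is a suggestion from the discussion, not what the PR merged):

```python
import os

from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata

allowed_types = (TritonAttentionMetadata, FlashAttentionMetadata)
# Enable AITER only when requested via env var AND the GPU arch supports it.
if (os.environ.get("VLLM_ROCM_USE_AITER") == "1"
        and _is_rocm_gpu_with_matrix_cores()):
    from vllm.v1.attention.backends.rocm_aiter_fa import (
        AiterFlashAttentionMetadata)
    allowed_types += (AiterFlashAttentionMetadata,)
```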
tjtanaa: @JartX Let's cache the value of `os.environ.get`, as its overhead is large, similar to #17067. An alternative approach is to check whether `aiter` is installed using `from importlib.util import find_spec`. However, this is also a very costly operation; it should only be called once, when a class is initialized or a file is imported.
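A hedged sketch of both suggestions: cache the environment lookup and the installation probe so each runs at most once per process (the helper names here are illustrative, not existing vLLM APIs):

```python
import os
from functools import cache
from importlib.util import find_spec


@cache
def rocm_aiter_requested() -> bool:
    # Read the flag once; repeated os.environ lookups in a hot path add up.
    return os.environ.get("VLLM_ROCM_USE_AITER") == "1"


@cache
def aiter_installed() -> bool:
    # find_spec searches sys.path, which is costly, so cache the result too.
    return find_spec("aiter") is not None
```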
JartX: @tjtanaa Many thanks for your answer, went the other way :) 47f9141 ("Using fallback")