15 changes: 8 additions & 7 deletions vllm/v1/spec_decode/eagle.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import os
from dataclasses import replace
from typing import Optional

@@ -20,8 +21,6 @@
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
@@ -233,14 +232,16 @@
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.

Check failure on line 235 in vllm/v1/spec_decode/eagle.py (GitHub Actions / pre-commit):
Incompatible types in assignment (expression has type "tuple[type[TritonAttentionMetadata], type[FlashAttentionMetadata], type[AiterFlashAttentionMetadata]]", variable has type "tuple[type[TritonAttentionMetadata], type[FlashAttentionMetadata]]") [assignment]
# On ROCm, both AiterFlashAttention and TritonAttention
# support multi-token eagle spec decode.
if current_platform.is_rocm():
assert isinstance(
attn_metadata,
(TritonAttentionMetadata, AiterFlashAttentionMetadata,
FlashAttentionMetadata))
allowed_types = (TritonAttentionMetadata, FlashAttentionMetadata)
if os.environ.get("VLLM_ROCM_USE_AITER") == "1":
Member
Is there any way you can make this more dynamic if it's known what device types would support this vs not?

Contributor Author
@russellb
I think the architecture names can be used, but it will always have to be expanded. Do you know of another mechanism for this?

For example:
import torch

def _is_rocm_gpu_with_matrix_cores() -> bool:
    if not torch.cuda.is_available() or not torch.version.hip:
        return False
    try:
        device_properties = torch.cuda.get_device_properties(
            torch.cuda.current_device())
        gcn_arch_name = getattr(device_properties, "gcnArchName", "")
        supported_archs = ("gfx908", "gfx90a", "gfx940", "gfx941", "gfx942")
        return any(gcn_arch_name.startswith(arch) for arch in supported_archs)
    except (RuntimeError, AttributeError):
        return False

Collaborator
@tjtanaa tjtanaa Aug 17, 2025
@JartX
Let's cache the value of os.environ.get, as its overhead is large, similar to
#17067

An alternative approach is to check whether aiter is installed using from importlib.util import find_spec. However, this is also a very costly operation; it should only be called once, when a class is initialized or when the file is imported. A sketch of that cached check is shown below.
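As a rough sketch of that idea (the helper name and placement are illustrative, not part of this PR), the result of find_spec can be memoized so the expensive lookup happens only once per process:

from functools import cache
from importlib.util import find_spec

@cache
def _is_aiter_installed() -> bool:
    # find_spec is costly, so memoize the result; only the first call
    # actually probes the import machinery, later calls hit the cache.
    return find_spec("aiter") is not None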

Contributor Author

@tjtanaa Many thanks for your answer; I went the other way :) 47f9141

Using the fallback:

from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
allowed_types += (AiterFlashAttentionMetadata, )

Check failure on line 243 in vllm/v1/spec_decode/eagle.py (GitHub Actions / pre-commit):
Incompatible types in assignment (expression has type "tuple[type[TritonAttentionMetadata], type[FlashAttentionMetadata], type[AiterFlashAttentionMetadata]]", variable has type "tuple[type[TritonAttentionMetadata], type[FlashAttentionMetadata]]") [assignment]
Contributor

high

Calling os.environ.get() inside the propose method can introduce performance overhead, as this method is on a hot path during inference. It's better to check the environment variable only once when the module is imported.

I recommend defining a module-level constant at the top of the file:

# At the top of vllm/v1/spec_decode/eagle.py
import os
_VLLM_ROCM_USE_AITER = os.environ.get("VLLM_ROCM_USE_AITER") == "1"

Then, you can use this constant here:

if _VLLM_ROCM_USE_AITER:
    from vllm.v1.attention.backends.rocm_aiter_fa import (
        AiterFlashAttentionMetadata)
    allowed_types += (AiterFlashAttentionMetadata, )

This change will improve performance by avoiding repeated environment variable lookups.

Collaborator

Agree

Member

See the pre-commit failures under this line
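One possible way to satisfy mypy here (a sketch, not a change committed in this PR) is to annotate the tuple as variable-length, so the later += of AiterFlashAttentionMetadata no longer conflicts with the inferred fixed-length tuple type:

# Declare a variable-length tuple of metadata classes so mypy
# accepts extending it conditionally on ROCm/AITER.
allowed_types: tuple[type, ...] = (TritonAttentionMetadata,
                                   FlashAttentionMetadata)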

assert isinstance(attn_metadata, allowed_types)
else:
# Currently, only FlashAttention and TreeAttention support
# multi-token eagle spec decode. This is because the code below
@@ -744,4 +745,4 @@
greedy_token_ids,
next_token_ids,
)
return next_token_ids, probs
return next_token_ids, probs