vllm/attention/ops/triton_unified_attention.py (2 additions & 2 deletions)
@@ -56,11 +56,11 @@ def kernel_unified_attention_2d(
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
-stride_k_cache_3: tl.int64, # int
+stride_k_cache_3: tl.constexpr, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
-stride_v_cache_3: tl.int64, # int
+stride_v_cache_3: tl.constexpr, # int
Comment on lines 56 to +63
Member

Why not make all of the strides constexpr to be safe?

Collaborator

I would think it's best not to, in order to avoid over-recompilation, but for the last stride it makes sense since it is almost always 1 (and when it is 1 we want the compiler to optimize around this, i.e. use wider loads).

Member Author

Avoiding recompilation is one reason, but I also think that for really long sequences there is a risk that the strides can overflow unless they are explicitly marked as tl.int64. This can't happen for stride_k_cache_3 and stride_v_cache_3, though, so I think we are safe to do this.
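For anyone skimming this thread later: a minimal, standalone Triton sketch of the trade-off discussed above (the kernel and tensor names are made up for illustration and are not part of vLLM). Large strides stay tl.int64 so that offset arithmetic cannot overflow 32-bit math on very long sequences, while the innermost stride is tl.constexpr so that when it is 1 the compiler can fold the multiply away and emit wider, contiguous loads; the cost is one extra compilation per distinct constexpr value.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_last_dim_kernel(
        src_ptr,
        dst_ptr,
        stride_src_0: tl.int64,      # runtime value, may be huge for long sequences
        stride_src_1: tl.constexpr,  # almost always 1 -> specialize and vectorize
        BLOCK: tl.constexpr,
):
    # One program copies one row of BLOCK contiguous elements.
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    vals = tl.load(src_ptr + row * stride_src_0 + offs * stride_src_1)
    tl.store(dst_ptr + row * BLOCK + offs, vals)


# Each distinct value of a constexpr argument triggers a recompilation, which
# is why only the (effectively constant) last stride is specialized here.
x = torch.randn(4, 128, device="cuda")
y = torch.empty_like(x)
copy_last_dim_kernel[(4, )](x, y, x.stride(0), x.stride(1), BLOCK=128)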

query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
vllm/v1/attention/backends/triton_attn.py (3 additions & 3 deletions)
@@ -5,13 +5,13 @@
import torch

from vllm import _custom_ops as ops
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
-    FlashAttentionMetadata, FlashAttentionMetadataBuilder)
+    FlashAttentionImpl, FlashAttentionMetadata, FlashAttentionMetadataBuilder)

logger = init_logger(__name__)

@@ -56,7 +56,7 @@ def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder


-class TritonAttentionImpl(AttentionImpl):
+class TritonAttentionImpl(FlashAttentionImpl):
Collaborator

I think we should do:

class TritonAttentionMetadata(FlashAttentionMetadata):
    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table)
        self.aot_schedule = False

instead of this, so we also avoid calls to the FA AOT scheduler, unless you see an issue with this approach.
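For illustration only, a hypothetical sketch of how that suggestion could plug into the get_builder_cls hook shown in the diff above. The TritonAttentionMetadataBuilder name is a placeholder, the constructor is just passed through so no exact builder signature is assumed, and aot_schedule is the flag referenced in the snippet:

from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadataBuilder


class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
    """Reuse the FlashAttention metadata builder, but never AOT-schedule."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Disable the FlashAttention ahead-of-time scheduler so it is never
        # called when building metadata for the Triton backend.
        self.aot_schedule = False


# The backend's get_builder_cls() (see the hunk above) would then return
# TritonAttentionMetadataBuilder instead of FlashAttentionMetadataBuilder.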

Collaborator

++ on this as a temp solution

Member Author (@tdoublep, May 14, 2025)

Sure, trying this now.

I will implement a clean version of TritonAttentionMetadata in a follow-on PR; I think there might be some benefits to the kernel if we do this.

Member Author

done

Member

@LucasWilkinson nice catch. Is it good to go now?

Contributor

A similar workaround is done in PR #16606.


def __init__(
self,