-
-
Notifications
You must be signed in to change notification settings - Fork 11.8k
[Kernel] [V1] Fix performance regression for triton unified attention #18161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,13 +5,13 @@ | |
| import torch | ||
|
|
||
| from vllm import _custom_ops as ops | ||
| from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, | ||
| from vllm.attention.backends.abstract import (AttentionBackend, | ||
robertgshaw2-redhat marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| AttentionMetadata, AttentionType) | ||
| from vllm.attention.ops.triton_unified_attention import unified_attention | ||
| from vllm.logger import init_logger | ||
| from vllm.platforms import current_platform | ||
| from vllm.v1.attention.backends.flash_attn import ( | ||
| FlashAttentionMetadata, FlashAttentionMetadataBuilder) | ||
| FlashAttentionImpl, FlashAttentionMetadata, FlashAttentionMetadataBuilder) | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
@@ -56,7 +56,7 @@ def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]: | |
| return FlashAttentionMetadataBuilder | ||
|
|
||
|
|
||
| class TritonAttentionImpl(AttentionImpl): | ||
| class TritonAttentionImpl(FlashAttentionImpl): | ||
|
||
|
|
||
| def __init__( | ||
| self, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not make all of the strides constexpr to be safe?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would think its best to not to avoid over recompilation, but for the last stride it makes sense since this is almost always 1 (and when it is 1 we want the compiler to optimize around this, i.e. use wider loads)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoiding recompilation is one reason, but I also think for really long sequences there is a risk that the strides can overflow unless they are explicitly marked as
tl.int64. This can't happen for thestride_k_cache_3andstride_v_cache_3though, so I think we are safe to do this.