[Spyre-Next] Pytorch Native Attention on Spyre: 4D Attention Kernel#2
Merged
joerunde merged 3 commits into torch-spyre:main on May 1, 2026
Conversation
bringlein (Collaborator) approved these changes Apr 29, 2026
We approved it in torch-spyre/sendnn-inference#914, but please fix the merge conflict.
Signed-off-by: Jan van Lunteren <[email protected]>
Force-pushed 7c06eef to fe6d3e4
bot:next-test
joerunde (Collaborator) requested changes Apr 30, 2026
It looks like the tests need to be updated for this change as well. I see our local test_spyre_attn tests are failing for some cases:
=========================== short test summary info ============================
PASSED tests/test_vllm_spyre_next.py::test_basic_model_load
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[False-64-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[False-128-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[False-256-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[False-512-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[True-64-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[True-128-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[True-256-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[True-512-1]
PASSED tests/test_rms_norm.py::test_rmsnorm_oot_dispatch[False]
PASSED tests/test_rms_norm.py::test_rmsnorm_oot_dispatch[True]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-128-num_heads0-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-128-num_heads0-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-128-num_heads0-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-128-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-128-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-128-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-256-num_heads0-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-256-num_heads0-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-256-num_heads0-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-256-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-256-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-256-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-128-num_heads0-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-128-num_heads0-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-128-num_heads0-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-128-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-128-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-128-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-256-num_heads0-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-256-num_heads0-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-256-num_heads0-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-256-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-256-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-256-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-128-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-128-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-128-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-256-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-256-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-256-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-128-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-128-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-128-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-256-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-256-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-256-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn_single_sequence[dtype0-16-128-num_heads0]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-small_decode]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-small_prefill]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-mixed_small]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-medium_decode]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-medium_prefill]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-mixed_medium]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-large_decode]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-large_prefill]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-single_decode]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-single_prefill]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-128-num_heads0-seq_lens0]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-128-num_heads0-seq_lens1]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-128-num_heads0-seq_lens2]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-256-num_heads0-seq_lens0]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-256-num_heads0-seq_lens1]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-256-num_heads0-seq_lens2]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-128-num_heads0-seq_lens0]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-128-num_heads0-seq_lens1]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-128-num_heads0-seq_lens2]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-256-num_heads0-seq_lens0]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-256-num_heads0-seq_lens1]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-256-num_heads0-seq_lens2]
====== 12 failed, 58 passed, 42 skipped, 26 warnings in 301.97s (0:05:01) ======
@jvlunteren Do we still require test_spyre_attn now that we have the upstream test_causal_backend_correctness tests running?
bot:next-test
bot:next-test
joerunde (Collaborator) approved these changes May 1, 2026
Nice, looks like we're all good to go with the torch-spyre + image bump 👍
Description
This PR replaces the 2D transposed attention kernel with a 4D broadcast matmul kernel, eliminating per‑sequence and per‑chunk loops, GQA head duplication, and block‑diagonal masking.
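The idea described above can be sketched in plain PyTorch — a minimal, illustrative version, not the actual torch-spyre kernel (the function name and tensor layout are assumptions). GQA is handled by viewing the query heads as `[KV heads, group]` so a single broadcast matmul attends all groups against the shared KV heads, with no head duplication and no per-sequence or per-chunk loop:

```python
import torch

def attn_4d(q, k, v, causal=True):
    """Illustrative 4D broadcast-matmul attention with GQA.

    q:    [B, Hq,  Tq, D]
    k, v: [B, Hkv, Tk, D], with Hq % Hkv == 0
    """
    b, hq, tq, d = q.shape
    hkv, tk = k.shape[1], k.shape[2]
    g = hq // hkv  # query heads per KV head (GQA group size)

    # Fold the GQA group into its own dim; it broadcasts against KV heads,
    # so KV tensors are never duplicated per query head.
    q = q.view(b, hkv, g, tq, d)

    scores = torch.einsum("bhgqd,bhkd->bhgqk", q, k) / d ** 0.5
    if causal:
        # Standard causal mask (assumes queries are the last Tq positions).
        mask = torch.triu(
            torch.ones(tq, tk, dtype=torch.bool, device=q.device),
            diagonal=tk - tq + 1,
        )
        scores = scores.masked_fill(mask, float("-inf"))

    probs = scores.softmax(dim=-1)
    out = torch.einsum("bhgqk,bhkd->bhgqd", probs, v)
    return out.reshape(b, hq, tq, d)
```

For a quick sanity check, this matches the "duplicate the KV heads" formulation that the PR removes: repeating `k`/`v` with `repeat_interleave(g, dim=1)` and calling `torch.nn.functional.scaled_dot_product_attention` produces the same result.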
Related Issues
Relates to #647
Test Plan
Same approach as in PR #853.
Checklist
- Code formatted (bash format.sh)
- Commits include a Signed-off-by: line (DCO compliance)