Skip to content

[Spyre-Next] Pytorch Native Attention on Spyre: 4D Attention Kernel#2

Merged
joerunde merged 3 commits intotorch-spyre:mainfrom
jvlunteren:pytorch_native_attention_v2
May 1, 2026
Merged

[Spyre-Next] Pytorch Native Attention on Spyre: 4D Attention Kernel#2
joerunde merged 3 commits intotorch-spyre:mainfrom
jvlunteren:pytorch_native_attention_v2

Conversation

@jvlunteren
Copy link
Copy Markdown
Collaborator

Description

This PR replaces the 2D transposed attention kernel with a 4D broadcast matmul kernel, eliminating per‑sequence and per‑chunk loops, GQA head duplication, and block‑diagonal masking.

Related Issues

Relates to #647

Test Plan

Same approach as in PR #853.

Checklist

  • I have read the contributing guidelines
  • My code follows the project's code style (run bash format.sh)
  • I have added tests for my changes (if applicable)
  • I have updated the documentation (if applicable)
  • My commits include a Signed-off-by: line (DCO compliance)

@jvlunteren jvlunteren requested a review from tdoublep April 21, 2026 12:05
Copy link
Copy Markdown
Collaborator

@bringlein bringlein left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we approved it in torch-spyre/sendnn-inference#914, but please fix the merge conflict

@jvlunteren jvlunteren force-pushed the pytorch_native_attention_v2 branch from 7c06eef to fe6d3e4 Compare April 30, 2026 13:28
@joerunde
Copy link
Copy Markdown
Collaborator

bot:next-test

Copy link
Copy Markdown
Collaborator

@joerunde joerunde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the tests need to be updated for this change as well. I see our local test_spyre_attn tests are failing for some cases:

=========================== short test summary info ============================
PASSED tests/test_vllm_spyre_next.py::test_basic_model_load
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[False-64-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[False-128-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[False-256-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[False-512-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[True-64-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[True-128-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[True-256-1]
PASSED tests/test_rms_norm.py::test_spyre_rmsnorm_matches_reference[True-512-1]
PASSED tests/test_rms_norm.py::test_rmsnorm_oot_dispatch[False]
PASSED tests/test_rms_norm.py::test_rmsnorm_oot_dispatch[True]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-128-num_heads0-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-128-num_heads0-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-128-num_heads0-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-128-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-128-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-128-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-256-num_heads0-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-256-num_heads0-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-256-num_heads0-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-256-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-256-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-2048-None-dtype0-None-16-256-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-128-num_heads0-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-128-num_heads0-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-128-num_heads0-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-128-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-128-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-128-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-256-num_heads0-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-256-num_heads0-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-256-num_heads0-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-256-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-256-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[True-32768-None-dtype0-None-16-256-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-128-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-128-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-128-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-256-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-256-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-256-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-128-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-128-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-128-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-256-num_heads1-seq_lens0]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-256-num_heads1-seq_lens1]
PASSED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-256-num_heads1-seq_lens2]
PASSED tests/test_spyre_attn.py::test_spyre_attn_single_sequence[dtype0-16-128-num_heads0]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-small_decode]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-small_prefill]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-mixed_small]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-medium_decode]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-medium_prefill]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-mixed_medium]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-large_decode]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-large_prefill]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-single_decode]
PASSED ::test_causal_backend_correctness[1-meta-llama/Meta-Llama-3-8B-single_prefill]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-128-num_heads0-seq_lens0]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-128-num_heads0-seq_lens1]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-128-num_heads0-seq_lens2]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-256-num_heads0-seq_lens0]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-256-num_heads0-seq_lens1]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-2048-None-dtype0-None-16-256-num_heads0-seq_lens2]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-128-num_heads0-seq_lens0]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-128-num_heads0-seq_lens1]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-128-num_heads0-seq_lens2]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-256-num_heads0-seq_lens0]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-256-num_heads0-seq_lens1]
FAILED tests/test_spyre_attn.py::test_spyre_attn[False-32768-None-dtype0-None-16-256-num_heads0-seq_lens2]
====== 12 failed, 58 passed, 42 skipped, 26 warnings in 301.97s (0:05:01) ======

@jvlunteren Do we still require test_spyre_attn now that we have the upstream test_causal_backend_correctness tests running?

@joerunde
Copy link
Copy Markdown
Collaborator

joerunde commented May 1, 2026

bot:next-test
(checking if the last image bump has fixed those failures)

@joerunde
Copy link
Copy Markdown
Collaborator

joerunde commented May 1, 2026

bot:next-test

Copy link
Copy Markdown
Collaborator

@joerunde joerunde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, looks like we're all good to go wit the torch-spyre + image bump 👍

@joerunde joerunde merged commit 242da3b into torch-spyre:main May 1, 2026
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants