Changes from all commits (54 commits)
a0375d7
Integrated custom attention backend
bohnstingl Feb 27, 2026
11255ac
Formatting issues
bohnstingl Feb 27, 2026
89d5a75
Changed the name of the attention operation
bohnstingl Feb 27, 2026
bfbc64a
Changed filename
bohnstingl Feb 27, 2026
f8afb02
Implemented gather to avoid using full KV cache
bohnstingl Mar 3, 2026
df3ab2c
Removed .item() calls
bohnstingl Mar 3, 2026
8e0bd74
Cleanup and adding of example
bohnstingl Mar 3, 2026
90ce563
Lint
bohnstingl Mar 3, 2026
8d314b9
Added testcase for attention backend
bohnstingl Mar 3, 2026
0f34475
Added missing utils file
bohnstingl Mar 5, 2026
3bc3ee6
Reformat
bohnstingl Mar 6, 2026
c2d264b
Functional update
bohnstingl Mar 8, 2026
2e8e4aa
Lint issues
bohnstingl Mar 8, 2026
14b6ef7
:art: linting, vllm compatibility, test integration
joerunde Mar 9, 2026
c98a9a2
refactored attention backend to support compilation and execution on …
jvlunteren Mar 19, 2026
825a95c
formatting
jvlunteren Mar 20, 2026
a5c719f
add unit test
jvlunteren Mar 20, 2026
6da9be4
formatting
jvlunteren Mar 23, 2026
61e22f1
removed redundant code
jvlunteren Mar 24, 2026
2d8bb12
added empty line back
jvlunteren Mar 24, 2026
139ab4a
formatting
jvlunteren Mar 24, 2026
9bf1283
removed custom num_heads handling
jvlunteren Mar 24, 2026
0931224
removed compat_utils.py
jvlunteren Mar 25, 2026
621df53
renamed spyre_paged_attn.py to spyre_attn.py
jvlunteren Mar 25, 2026
2bf45c1
add dynamic=False argument to torch.compile
jvlunteren Mar 25, 2026
9919ba2
adapted test_spyre_attn.py to previous name change
jvlunteren Mar 25, 2026
a8c26f6
limit supported data types to float16
jvlunteren Mar 25, 2026
6338c32
limit supported kv cache data types to float16
jvlunteren Mar 25, 2026
e118284
removed redundant code
jvlunteren Mar 25, 2026
82d2daf
indicated if steps are executed on CPU and/or Spyre
jvlunteren Mar 25, 2026
781c095
renaming
jvlunteren Mar 25, 2026
c2556b3
further renaming
jvlunteren Mar 25, 2026
a2eadbd
use utils for transfers between cpu and spyre
jvlunteren Mar 25, 2026
0c6100f
various updates to test
jvlunteren Mar 25, 2026
d25ec3f
formatting
jvlunteren Mar 25, 2026
49b6109
WIP: reworked D2H movements
bohnstingl Mar 26, 2026
88e12ea
fixed supports_head_size()
jvlunteren Mar 26, 2026
17d2194
Merge branch 'pytorch_native_attention' of github.com:jvlunteren/vllm…
bohnstingl Mar 26, 2026
a13a657
Enforce dtype="float16"
bohnstingl Mar 26, 2026
91a24d6
Moved assert
bohnstingl Mar 26, 2026
dc5a07c
Corrected stripped attention test
bohnstingl Mar 26, 2026
80c7cc9
Updates to address review comments
bohnstingl Mar 26, 2026
8adae0a
Merge branch 'main' of github.com:vllm-project/vllm-spyre into pytorc…
bohnstingl Mar 30, 2026
af9d8f9
Integrated minor review findings
bohnstingl Mar 30, 2026
c6fd7f9
Merge branch 'main' of github.com:vllm-project/vllm-spyre into pytorc…
bohnstingl Apr 2, 2026
7882018
Integrated reviewer comments and suggestions
bohnstingl Apr 2, 2026
fef4e7f
Fixing formatting errors
bohnstingl Apr 3, 2026
3d3a169
Switched KV cache format to (num_blocks, 2, ...)
bohnstingl Apr 3, 2026
1749821
Removed outdated max_num_seqs==1 restriction
bohnstingl Apr 3, 2026
a3eecc5
Removed enforce_eager argument
bohnstingl Apr 9, 2026
3b02c2c
Merge branch 'main' into pytorch_native_attention
jvlunteren Apr 13, 2026
32e0b63
replace 2D transposed attention kernel with batched 4D broadcast matmul
jvlunteren Apr 14, 2026
da41c0e
Merge branch 'main' into pytorch_native_attention_v2
jvlunteren Apr 14, 2026
3b7ac99
clarify shape comments
jvlunteren Apr 15, 2026
301 changes: 92 additions & 209 deletions vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py
@@ -173,26 +173,22 @@ class SpyreAttentionImpl(AttentionImpl[SpyreAttentionMetadata]):
QUERY_CHUNK_SIZE = 32
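(QUERY_CHUNK_SIZE drives the query-length alignment used further down; the ceil-to-multiple arithmetic is just the following, shown here with illustrative lengths:)

```python
CHUNK = 32
for n in (1, 32, 33, 70):
    padded = (n + CHUNK - 1) // CHUNK * CHUNK
    print(n, "->", padded)  # 1->32, 32->32, 33->64, 70->96
```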

@staticmethod
def _attn_transposed(qt, k, vt, sm_scale, mask_values):
"""Transposed attention for Spyre: handles all heads at once.
def _attn_4d(q, k, v, scale, mask):
"""4D broadcast attention for Spyre: handles batched GQA.

Args:
qt: Query transposed [head_size, num_heads * query_len_padded]
k: Key [num_heads * kv_len, head_size]
vt: Value transposed [head_size, num_heads * kv_len]
sm_scale: Scale factor (1D tensor) [num_heads * query_len_padded]
mask_values: Mask values tensor [num_heads * kv_len, num_heads * query_len_padded]
Pre-computed on CPU: 0.0 for valid, -65504.0 for masked/padded
q: Query [num_seqs*num_kv_heads, num_queries_per_kv, query_len, head_size]
k: Key [num_seqs*num_kv_heads, 1, kv_len, head_size]
v: Value [num_seqs*num_kv_heads, 1, kv_len, head_size]
scale: Scale factor (float)
mask: Additive mask [num_seqs*num_kv_heads, 1, query_len, kv_len]
Pre-computed on CPU: 0.0 for valid, -65504.0 for masked/padded
"""
kq = k @ qt # [num_heads * kv_len, num_heads * query_len_padded]
kq = kq * sm_scale

# Add pre-computed mask values
# Valid positions have 0.0, masked/padded positions have -65504.0
kq = kq + mask_values

p = kq.softmax(dim=0)
return vt @ p # [head_size, num_heads * query_len_padded]
scores = q @ k.transpose(-2, -1)
scores = scores * scale
scores = scores + mask
p = scores.softmax(dim=-1)
return p @ v
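
As a sanity check, the 4D broadcast kernel can be verified against PyTorch's SDPA on CPU. This is a minimal sketch with illustrative sizes in plain float32, not the compiled Spyre path:

```python
import torch

B, KVH, G, QL, KVL, D = 2, 4, 3, 8, 16, 32  # illustrative sizes
q = torch.randn(B * KVH, G, QL, D)          # grouped queries
k = torch.randn(B * KVH, 1, KVL, D)         # one KV head per group, broadcast
v = torch.randn(B * KVH, 1, KVL, D)
mask = torch.zeros(B * KVH, 1, QL, KVL)     # additive mask: 0.0 = visible
scale = D ** -0.5

# Same computation as _attn_4d above
scores = q @ k.transpose(-2, -1) * scale + mask
out = scores.softmax(dim=-1) @ v

# Reference: expand KV heads across the group and call SDPA
ref = torch.nn.functional.scaled_dot_product_attention(
    q, k.expand(-1, G, -1, -1), v.expand(-1, G, -1, -1), scale=scale
)
assert torch.allclose(out, ref, atol=1e-5)
```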

def __init__(
self,
@@ -221,13 +217,13 @@ def __init__(
self._target_dtype = torch.float16

# When True, use torch.nn.functional.scaled_dot_product_attention.
# Otherwise, use the transposed matmul kernel (_attn_transposed).
# Otherwise, use the 4D matmul kernel (_attn_4d).
self.use_sdpa = use_sdpa

if self.use_sdpa:
self.attn_op = torch.nn.functional.scaled_dot_product_attention
else:
self.attn_op = self._attn_transposed
self.attn_op = self._attn_4d

# Compile the attention function once for reuse.
# dynamic=False forces static shapes, required by the Spyre compiler.
@@ -303,7 +299,7 @@ def forward(
query.device,
)

# Step 5: Compute batched per-sequence attention (CPU, Spyre)
# Step 5: Compute batched attention (CPU, Spyre)
# attn_output: [num_seqs, max_query_len, num_heads, head_size]
attn_output = self._compute_attention(
query_per_seq, compact_k, compact_v, mask, query.device, query.dtype
@@ -509,39 +505,16 @@ def _compute_attention(
device: torch.device, # device for intermediate allocations
dtype: torch.dtype, # dtype for intermediate allocations
) -> torch.Tensor:
"""Dispatch attention: SDPA path or per-sequence chunked Spyre path.
"""Dispatch attention: SDPA path or batched Spyre path.

Returns:
[num_seqs, max_query_len, num_heads, head_size]
"""
num_seqs = query.shape[0]

# As fallback, use SDPA implementation
if self.use_sdpa:
return self._compute_attention_sdpa(query, key, value, mask)

# Allocate output tensor for all sequences
output_all_seqs = torch.zeros_like(query)

# Process each sequence separately
for seq_idx in range(num_seqs):
# Extract single sequence
query_seq = query[seq_idx : seq_idx + 1] # [1, max_query_len, num_heads, head_size]
key_seq = key[seq_idx : seq_idx + 1] # [1, max_seq_len, num_kv_heads, head_size]
value_seq = value[seq_idx : seq_idx + 1] # [1, max_seq_len, num_kv_heads, head_size]
mask_seq = (
mask[seq_idx : seq_idx + 1] if mask is not None else None
) # [1, 1, max_query_len, max_seq_len]

# Compute attention for this sequence
output_seq = self._compute_attention_single_seq(
query_seq, key_seq, value_seq, mask_seq, device, dtype
)

# Store result
output_all_seqs[seq_idx] = output_seq.squeeze(0)

return output_all_seqs
return self._compute_attention_impl(query, key, value, mask, device, dtype)

def _compute_attention_sdpa(
self,
@@ -567,190 +540,100 @@ def _compute_attention_sdpa(
)
return out.transpose(1, 2)
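
For context, an SDPA fallback honoring this module's layout contract amounts to roughly the following sketch (not the collapsed body itself; the additive-mask convention and the [num_seqs, len, heads, head_size] layout are assumptions carried over from the surrounding code):

```python
import torch

def sdpa_fallback(query, key, value, mask, scale):
    # query:      [num_seqs, q_len, num_heads, head_size]
    # key/value:  [num_seqs, kv_len, num_heads, head_size]
    #             (KV heads already repeated for GQA)
    # mask:       additive, [num_seqs, 1, q_len, kv_len]; 0.0 keep, -65504 drop
    out = torch.nn.functional.scaled_dot_product_attention(
        query.transpose(1, 2),   # -> [num_seqs, heads, q_len, head_size]
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=mask,
        scale=scale,
    )
    return out.transpose(1, 2)   # back to [num_seqs, q_len, heads, head_size]
```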

def _compute_attention_single_seq(
def _compute_attention_impl(
self,
query: torch.Tensor, # [1, max_query_len, num_heads, head_size]
key: torch.Tensor, # [1, max_seq_len, num_kv_heads, head_size]
value: torch.Tensor, # [1, max_seq_len, num_kv_heads, head_size]
mask: torch.Tensor | None, # [1, 1, max_query_len, max_seq_len]
query: torch.Tensor, # [num_seqs, max_query_len, num_heads, head_size]
key: torch.Tensor, # [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]
value: torch.Tensor, # [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]
Comment on lines +545 to +547

Collaborator:

so we expect key and value to be padded, but not the query? What is the rationale behind this interface? (if there is one, I'm fully aware this could also just be temporary)

Collaborator:

and as @tdoublep pointed out, is there a way to support the flattened varlen format?

Collaborator (Author):

The query at the input is "flat" [num_tokens, num_heads, head_size]. The query gets padded inside the code (lines 573-577).

mask: torch.Tensor | None, # [num_seqs, 1, max_query_len, aligned_max_seq_len]
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
"""Compute attention for a single sequence using Spyre.
"""Compute batched 4D attention on Spyre.

Processes queries in fixed-size chunks of QUERY_CHUNK_SIZE tokens.
"""
Pads query to QUERY_CHUNK_SIZE-aligned length, merges batch into
kv_heads, issues one compiled 4D kernel call, and trims output.

_, _, num_heads, head_size = query.shape
Returns:
[num_seqs, max_query_len, num_heads, head_size]
"""
num_seqs, max_query_len, num_heads, head_size = query.shape
_, kv_len, num_kv_heads, _ = key.shape
num_queries_per_kv = self.num_queries_per_kv

# Handle grouped-query attention by repeating KV heads
if self.num_queries_per_kv > 1:
key = key.repeat_interleave(self.num_queries_per_kv, dim=2)
value = value.repeat_interleave(self.num_queries_per_kv, dim=2)

# Squeeze batch dimension
query_squeezed = query.squeeze(0) # [query_len, num_heads, head_size]
key_squeezed = key.squeeze(0) # [kv_len, num_heads, head_size]
value_squeezed = value.squeeze(0) # [kv_len, num_heads, head_size]

# Calculate number of chunks needed
actual_query_len = query_squeezed.shape[0]
num_chunks = (actual_query_len + self.QUERY_CHUNK_SIZE - 1) // self.QUERY_CHUNK_SIZE

output_full = torch.empty(
actual_query_len,
num_heads,
head_size,
dtype=dtype,
device=device,
# Pad query length to QUERY_CHUNK_SIZE alignment
padded_query_len = (
(max_query_len + self.QUERY_CHUNK_SIZE - 1)
// self.QUERY_CHUNK_SIZE
* self.QUERY_CHUNK_SIZE
)

# Process each chunk
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * self.QUERY_CHUNK_SIZE
chunk_end = min(chunk_start + self.QUERY_CHUNK_SIZE, actual_query_len)
chunk_len = chunk_end - chunk_start

# Extract query chunk
query_chunk = query_squeezed[chunk_start:chunk_end]

# Pad chunk if needed
if chunk_len < self.QUERY_CHUNK_SIZE:
padding_size = self.QUERY_CHUNK_SIZE - chunk_len
query_chunk_padded = torch.nn.functional.pad(
query_chunk, (0, 0, 0, 0, 0, padding_size), mode="constant", value=0.0
)
else:
query_chunk_padded = query_chunk

# Extract corresponding mask for this chunk
if mask is not None:
mask_chunk = mask[:, :, chunk_start:chunk_end, :] # [1, 1, chunk_len, kv_len]
else:
mask_chunk = None

# Compute attention for this chunk
chunk_output = self._compute_attention_chunk(
query_chunk_padded,
key_squeezed,
value_squeezed,
mask_chunk,
chunk_len,
num_heads,
head_size,
kv_len,
device,
dtype,
if padded_query_len > max_query_len:
padding_size = padded_query_len - max_query_len
query = torch.nn.functional.pad(
query, (0, 0, 0, 0, 0, padding_size), mode="constant", value=0.0
)

# Store chunk output (only valid positions)
output_full[chunk_start:chunk_end] = chunk_output[:chunk_len]

return output_full.unsqueeze(0) # [1, query_len, num_heads, head_size]

def _compute_attention_chunk(
self,
query_chunk_padded: torch.Tensor, # [QUERY_CHUNK_SIZE, num_heads, head_size]
key_squeezed: torch.Tensor, # [kv_len, num_heads, head_size]
value_squeezed: torch.Tensor, # [kv_len, num_heads, head_size]
mask_chunk: torch.Tensor | None, # [1, 1, chunk_len, kv_len]
chunk_len: int,
num_heads: int,
head_size: int,
kv_len: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
"""Compute attention for a single query chunk on Spyre.

Prepares tensors on CPU (reshape, stickify, build mask), transfers to
Spyre for the compiled matmul kernel, then transfers the result back.
# Q: [num_seqs, query_len_padded, num_heads, head_size]
# -> [num_seqs, num_heads, query_len_padded, head_size]
# -> [num_seqs*num_kv_heads, num_queries_per_kv, query_len_padded, head_size]
q = query.transpose(1, 2).contiguous()
q = q.reshape(num_seqs * num_kv_heads, num_queries_per_kv, padded_query_len, head_size)

# K/V: [num_seqs, kv_len, num_kv_heads, head_size]
# -> [num_seqs*num_kv_heads, 1, kv_len, head_size]
k = key.transpose(1, 2).contiguous()
k = k.reshape(num_seqs * num_kv_heads, 1, kv_len, head_size)
v = value.transpose(1, 2).contiguous()
v = v.reshape(num_seqs * num_kv_heads, 1, kv_len, head_size)
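
(This reshape relies on query heads being grouped contiguously per KV head: head h serves KV head h // num_queries_per_kv, which matches the repeat_interleave semantics of the old path. A small check with illustrative sizes:)

```python
import torch

KVH, G, L, D = 2, 3, 4, 2        # illustrative sizes
H = KVH * G
q = torch.randn(1, L, H, D)      # [num_seqs=1, len, heads, head_size]
q4 = q.transpose(1, 2).contiguous().reshape(KVH, G, L, D)

# Original head h lands at (kv_head h // G, group h % G)
for h in range(H):
    assert torch.equal(q4[h // G, h % G], q[0, :, h])
```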

# --- Build additive mask [num_seqs*num_kv_heads, 1, query_len_padded, kv_len] ---
if mask is not None:
# mask: [num_seqs, 1, max_query_len, kv_len] (bool: True = masked)
mask_3d = mask[:, 0, :, :] # [num_seqs, max_query_len, kv_len]
if padded_query_len > max_query_len:
padding_size = padded_query_len - max_query_len
mask_padding = torch.ones(
(num_seqs, padding_size, kv_len), dtype=torch.bool, device=device
)
mask_3d = torch.cat([mask_3d, mask_padding], dim=1)

Returns:
[QUERY_CHUNK_SIZE, num_heads, head_size] — attention output (padded)
"""
padded_query_len = self.QUERY_CHUNK_SIZE

# Reshape query to flatten heads into query dimension
query_reordered = query_chunk_padded.transpose(
0, 1
).contiguous() # [num_heads, QUERY_CHUNK_SIZE, head_size]
query_flat = query_reordered.reshape(num_heads * padded_query_len, head_size)

# Key and value: also flatten across heads
key_reordered = key_squeezed.transpose(0, 1).contiguous() # [num_heads, kv_len, head_size]
value_reordered = value_squeezed.transpose(
0, 1
).contiguous() # [num_heads, kv_len, head_size]

key_flat = key_reordered.reshape(num_heads * kv_len, head_size)
value_flat = value_reordered.reshape(num_heads * kv_len, head_size)

# Transpose for attention computation
qt = query_flat.T.contiguous() # [head_size, num_heads * QUERY_CHUNK_SIZE]
vt = value_flat.T.contiguous() # [head_size, num_heads * kv_len]
k = key_flat # [num_heads * kv_len, head_size]

# Stickification: force Spyre-friendly memory layout.
# Transposed tensors need double transpose-contiguous; standard tensors just contiguous.
qt_stickified = qt.transpose(0, 1).contiguous().transpose(0, 1).contiguous()
vt_stickified = vt.transpose(0, 1).contiguous().transpose(0, 1).contiguous()
k_stickified = k.contiguous()

# Scale factor: 1D tensor replicated per head × query position
sm_scale_1d = torch.tensor(self.scale, dtype=dtype, device=device).repeat(
num_heads * padded_query_len
) # [num_heads * QUERY_CHUNK_SIZE]

# --- Build block-diagonal additive mask ---
# The transposed kernel flattens all heads into one matmul, so the mask
# must be block-diagonal: each head's causal/padding mask sits on the
# diagonal, off-diagonal blocks are masked (-65504).
if mask_chunk is not None:
mask_all_heads = mask_chunk[0, 0] # [chunk_len, kv_len]

# Pad query dimension to QUERY_CHUNK_SIZE if this is the last chunk
if chunk_len < self.QUERY_CHUNK_SIZE:
padding_size = self.QUERY_CHUNK_SIZE - chunk_len
mask_padding = torch.ones((padding_size, kv_len), dtype=torch.bool, device=device)
mask_all_heads = torch.cat([mask_all_heads, mask_padding], dim=0)

head_mask_t = mask_all_heads.T # [kv_len, QUERY_CHUNK_SIZE], True = masked
mask_bool = ~torch.block_diag(*([~head_mask_t] * num_heads))
# Convert boolean mask to additive: True -> -65504.0, False -> 0.0
mask_additive = torch.where(
mask_3d,
torch.tensor(-65504.0, dtype=dtype, device=device),
torch.tensor(0.0, dtype=dtype, device=device),
)
# [num_seqs, query_len_padded, kv_len]
# -> expand [num_seqs, num_kv_heads, query_len_padded, kv_len]
# -> [num_seqs*num_kv_heads, 1, query_len_padded, kv_len]
mask_4d = (
mask_additive.unsqueeze(1)
.expand(-1, num_kv_heads, -1, -1)
.reshape(num_seqs * num_kv_heads, 1, padded_query_len, kv_len)
.contiguous()
)
else:
# No causal/padding mask: only cross-head positions are masked.
ones_block = torch.ones(kv_len, padded_query_len, dtype=torch.bool, device=device)
mask_bool = ~torch.block_diag(*([ones_block] * num_heads))

# Convert boolean mask to additive: True → -65504.0, False → 0.0
mask_values = torch.where(
mask_bool,
torch.tensor(-65504.0, dtype=dtype, device=device),
torch.tensor(0.0, dtype=dtype, device=device),
).contiguous()
mask_4d = torch.zeros(1, 1, padded_query_len, kv_len, dtype=dtype, device=device)
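
(A note on the -65504.0 constant: it is the most negative finite float16 value, used instead of -inf, presumably so every tensor handed to the compiled kernel stays finite. The masked softmax weight still underflows to exactly zero; a tiny CPU-only demo:)

```python
import torch

scores = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
mask = torch.tensor([0.0, -65504.0, 0.0], dtype=torch.float16)

weights = (scores + mask).softmax(dim=-1)
print(weights)                      # middle (masked) weight underflows to 0.0
assert weights[1] == 0.0
assert torch.isfinite(weights).all()
```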

# --- Transfer to Spyre, compute, transfer back ---
qt_spyre = convert(qt_stickified, self._target_device, self._target_dtype)
k_spyre = convert(k_stickified, self._target_device, self._target_dtype)
vt_spyre = convert(vt_stickified, self._target_device, self._target_dtype)
sm_scale_spyre = convert(sm_scale_1d, self._target_device, self._target_dtype)
mask_spyre = convert(mask_values, self._target_device, self._target_dtype)
q_spyre = convert(q, self._target_device, self._target_dtype)
k_spyre = convert(k, self._target_device, self._target_dtype)
v_spyre = convert(v, self._target_device, self._target_dtype)
mask_spyre = convert(mask_4d, self._target_device, self._target_dtype)

# Compiled attention on Spyre
output_spyre_t = self.attn_op(qt_spyre, k_spyre, vt_spyre, sm_scale_spyre, mask_spyre)
output_spyre = self.attn_op(q_spyre, k_spyre, v_spyre, self.scale, mask_spyre)
Collaborator:

can we actually start profiling the performance of the different versions?

Collaborator (Author):

Yes

# Transfer back to CPU
output_flat = convert(
output_spyre_t, device, dtype
).contiguous() # [head_size, num_heads * QUERY_CHUNK_SIZE]

# Reshape: [head_size, N*Q] → [N, Q, head_size] → [Q, N, head_size]
output_transposed = output_flat.T # [num_heads * QUERY_CHUNK_SIZE, head_size]
output_reshaped = output_transposed.reshape(num_heads, padded_query_len, head_size)

# [QUERY_CHUNK_SIZE, num_heads, head_size]
return output_reshaped.transpose(0, 1).contiguous()
# [num_seqs*num_kv_heads, num_queries_per_kv, query_len_padded, head_size]
# -> [num_seqs, num_heads, query_len_padded, head_size]
# -> [num_seqs, query_len_padded, num_heads, head_size]
# -> trim to [num_seqs, max_query_len, num_heads, head_size]
output_4d = convert(output_spyre, device, dtype)
output_reshaped = output_4d.reshape(num_seqs, num_heads, padded_query_len, head_size)
output = output_reshaped.transpose(1, 2).contiguous()
return output[:, :max_query_len, :, :]
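
Putting the pieces together, the batched padded path can be checked end to end against a per-sequence SDPA reference. This is a self-contained sketch with illustrative sizes in plain CPU float32, not the Spyre device path:

```python
import torch

torch.manual_seed(0)
S, KVH, G, QL, KVL, D = 2, 2, 3, 5, 16, 8      # illustrative sizes
H, CHUNK = KVH * G, 32
QP = -(-QL // CHUNK) * CHUNK                   # pad query len to CHUNK multiple

query = torch.randn(S, QL, H, D)
key = torch.randn(S, KVL, KVH, D)
value = torch.randn(S, KVL, KVH, D)
bool_mask = torch.rand(S, 1, QL, KVL) > 0.3    # True = masked
bool_mask[..., 0] = False                      # keep every row attendable

# Batched 4D path: pad query, merge batch into KV heads, one kernel call
q = torch.nn.functional.pad(query, (0, 0, 0, 0, 0, QP - QL))
q = q.transpose(1, 2).reshape(S * KVH, G, QP, D)
k = key.transpose(1, 2).reshape(S * KVH, 1, KVL, D)
v = value.transpose(1, 2).reshape(S * KVH, 1, KVL, D)
m = torch.cat(
    [bool_mask[:, 0], torch.ones(S, QP - QL, KVL, dtype=torch.bool)], dim=1
)
m = (m.float() * -65504.0).unsqueeze(1).expand(-1, KVH, -1, -1)
m = m.reshape(S * KVH, 1, QP, KVL)
scores = (q @ k.transpose(-2, -1)) * D**-0.5 + m
out = (scores.softmax(-1) @ v).reshape(S, H, QP, D).transpose(1, 2)[:, :QL]

# Reference: repeat KV heads and run SDPA per the usual GQA recipe
ref = torch.nn.functional.scaled_dot_product_attention(
    query.transpose(1, 2),
    key.repeat_interleave(G, dim=2).transpose(1, 2),
    value.repeat_interleave(G, dim=2).transpose(1, 2),
    attn_mask=~bool_mask, scale=D**-0.5,
).transpose(1, 2)
assert torch.allclose(out, ref, atol=1e-5)
```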

def _extract_relevant_output(
self,