[Spyre-Next] Pytorch Native Attention on Spyre: 4D Attention Kernel #914
Open

jvlunteren wants to merge 54 commits into torch-spyre:main from jvlunteren:pytorch_native_attention_v2

+92 −209
Commits (54)
- a0375d7 (bohnstingl) Integrated custom attention backend
- 11255ac (bohnstingl) Formatting issues
- 89d5a75 (bohnstingl) Changed the name of the attention operation
- bfbc64a (bohnstingl) Changed filename
- f8afb02 (bohnstingl) Implemented gather to avoid using full KV cache
- df3ab2c (bohnstingl) Removed .item() calls
- 8e0bd74 (bohnstingl) Cleanup and adding of example
- 90ce563 (bohnstingl) Lint
- 8d314b9 (bohnstingl) Added testcase for attention backend
- 0f34475 (bohnstingl) Added missing utils file
- 3bc3ee6 (bohnstingl) Reformat
- c2d264b (bohnstingl) Functional update
- 2e8e4aa (bohnstingl) Lint issues
- 14b6ef7 (joerunde) :art: linting, vllm compatibility, test integration
- c98a9a2 (jvlunteren) refactored attention backend to support compilation and execution on …
- 825a95c (jvlunteren) formatting
- a5c719f (jvlunteren) add unit test
- 6da9be4 (jvlunteren) formatting
- 61e22f1 (jvlunteren) removed redundant code
- 2d8bb12 (jvlunteren) added empty line back
- 139ab4a (jvlunteren) formatting
- 9bf1283 (jvlunteren) removed custom num_heads handling
- 0931224 (jvlunteren) removed compat_utils.py
- 621df53 (jvlunteren) renamed spyre_paged_attn.py to spyre_attn.py
- 2bf45c1 (jvlunteren) add dynamic=False argument to torch.compile
- 9919ba2 (jvlunteren) adapted test_spyre_attn.py to previous name change
- a8c26f6 (jvlunteren) limit supported data types to float16
- 6338c32 (jvlunteren) limit supported kv cache data types to float16
- e118284 (jvlunteren) removed redundant code
- 82d2daf (jvlunteren) indicated if steps are executed on CPU and/or Spyre
- 781c095 (jvlunteren) renaming
- c2556b3 (jvlunteren) further renaming
- a2eadbd (jvlunteren) use utils for transfers between cpu and spyre
- 0c6100f (jvlunteren) various updates to test
- d25ec3f (jvlunteren) formatting
- 49b6109 (bohnstingl) WIP: reworked D2H movements
- 88e12ea (jvlunteren) fixed supports_head_size()
- 17d2194 (bohnstingl) Merge branch 'pytorch_native_attention' of github.com:jvlunteren/vllm…
- a13a657 (bohnstingl) Enforce dtype="float16"
- 91a24d6 (bohnstingl) Moved assert
- dc5a07c (bohnstingl) Corrected stripped attention test
- 80c7cc9 (bohnstingl) Updates to address review comments
- 8adae0a (bohnstingl) Merge branch 'main' of github.com:vllm-project/vllm-spyre into pytorc…
- af9d8f9 (bohnstingl) Integrated minor review findings
- c6fd7f9 (bohnstingl) Merge branch 'main' of github.com:vllm-project/vllm-spyre into pytorc…
- 7882018 (bohnstingl) Integrated reviewer comments and suggestions
- fef4e7f (bohnstingl) Fixing formatting errors
- 3d3a169 (bohnstingl) Switched KV cache format to (num_blocks, 2, ...)
- 1749821 (bohnstingl) Removed outdated max_num_seqs==1 restriction
- a3eecc5 (bohnstingl) Removed enforce_eager argument
- 3b02c2c (jvlunteren) Merge branch 'main' into pytorch_native_attention
- 32e0b63 (jvlunteren) replace 2D transposed attention kernel with batched 4D broadcast matmul
- da41c0e (jvlunteren) Merge branch 'main' into pytorch_native_attention_v2
- 3b7ac99 (jvlunteren) clarify shape comments
Changes from all commits
```diff
@@ -173,26 +173,22 @@ class SpyreAttentionImpl(AttentionImpl[SpyreAttentionMetadata]):
     QUERY_CHUNK_SIZE = 32

     @staticmethod
-    def _attn_transposed(qt, k, vt, sm_scale, mask_values):
-        """Transposed attention for Spyre: handles all heads at once.
+    def _attn_4d(q, k, v, scale, mask):
+        """4D broadcast attention for Spyre: handles batched GQA.

         Args:
-            qt: Query transposed [head_size, num_heads * query_len_padded]
-            k: Key [num_heads * kv_len, head_size]
-            vt: Value transposed [head_size, num_heads * kv_len]
-            sm_scale: Scale factor (1D tensor) [num_heads * query_len_padded]
-            mask_values: Mask values tensor [num_heads * kv_len, num_heads * query_len_padded]
-                Pre-computed on CPU: 0.0 for valid, -65504.0 for masked/padded
+            q: Query [num_seqs*num_kv_heads, num_queries_per_kv, query_len, head_size]
+            k: Key [num_seqs*num_kv_heads, 1, kv_len, head_size]
+            v: Value [num_seqs*num_kv_heads, 1, kv_len, head_size]
+            scale: Scale factor (float)
+            mask: Additive mask [num_seqs*num_kv_heads, 1, query_len, kv_len]
+                Pre-computed on CPU: 0.0 for valid, -65504.0 for masked/padded
         """
-        kq = k @ qt  # [num_heads * kv_len, num_heads * query_len_padded]
-        kq = kq * sm_scale
-
-        # Add pre-computed mask values
-        # Valid positions have 0.0, masked/padded positions have -65504.0
-        kq = kq + mask_values
-
-        p = kq.softmax(dim=0)
-        return vt @ p  # [head_size, num_heads * query_len_padded]
+        scores = q @ k.transpose(-2, -1)
+        scores = scores * scale
+        scores = scores + mask
+        p = scores.softmax(dim=-1)
+        return p @ v

     def __init__(
         self,
```
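To make the shape gymnastics concrete, here is a minimal CPU-only sketch of the broadcast trick the new kernel relies on. All sizes are made up for illustration, and the cross-check against SDPA with explicitly expanded KV heads is not part of the PR:

```python
import torch

num_seqs, num_kv_heads, num_queries_per_kv = 2, 4, 3
q_len, kv_len, head_size = 32, 64, 16
num_heads = num_kv_heads * num_queries_per_kv
scale = head_size ** -0.5

# Queries keep a per-KV-head group dim; K/V carry a broadcast dim of 1.
q = torch.randn(num_seqs * num_kv_heads, num_queries_per_kv, q_len, head_size)
k = torch.randn(num_seqs * num_kv_heads, 1, kv_len, head_size)
v = torch.randn(num_seqs * num_kv_heads, 1, kv_len, head_size)
mask = torch.zeros(num_seqs * num_kv_heads, 1, q_len, kv_len)  # additive, 0.0 = keep

def attn_4d(q, k, v, scale, mask):
    # K/V broadcast over the group dim, so GQA needs no repeat_interleave
    scores = q @ k.transpose(-2, -1) * scale + mask
    return scores.softmax(dim=-1) @ v

out = attn_4d(q, k, v, scale, mask)
assert out.shape == (num_seqs * num_kv_heads, num_queries_per_kv, q_len, head_size)

# Cross-check against SDPA with KV heads expanded explicitly
q_ref = q.reshape(num_seqs, num_heads, q_len, head_size)
k_ref = k.expand(-1, num_queries_per_kv, -1, -1).reshape(num_seqs, num_heads, kv_len, head_size)
v_ref = v.expand(-1, num_queries_per_kv, -1, -1).reshape(num_seqs, num_heads, kv_len, head_size)
ref = torch.nn.functional.scaled_dot_product_attention(q_ref, k_ref, v_ref)
torch.testing.assert_close(
    out.reshape(num_seqs, num_heads, q_len, head_size), ref, rtol=1e-4, atol=1e-4
)
```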
```diff
@@ -221,13 +217,13 @@ def __init__(
         self._target_dtype = torch.float16

         # When True, use torch.nn.functional.scaled_dot_product_attention.
-        # Otherwise, use the transposed matmul kernel (_attn_transposed).
+        # Otherwise, use the 4D matmul kernel (_attn_4d).
         self.use_sdpa = use_sdpa

         if self.use_sdpa:
             self.attn_op = torch.nn.functional.scaled_dot_product_attention
         else:
-            self.attn_op = self._attn_transposed
+            self.attn_op = self._attn_4d

         # Compile the attention function once for reuse.
         # dynamic=False forces static shapes, required by the Spyre compiler.
```
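For context, the compile-once wiring described by the context lines above might look like the following sketch; the flag and variable names here are illustrative, not the PR's exact code:

```python
import torch

def attn_4d(q, k, v, scale, mask):
    scores = q @ k.transpose(-2, -1) * scale + mask
    return scores.softmax(dim=-1) @ v

use_sdpa = False  # illustrative flag, mirroring self.use_sdpa
attn_op = (
    torch.nn.functional.scaled_dot_product_attention if use_sdpa else attn_4d
)
# Compile once; dynamic=False pins static shapes for the Spyre compiler.
compiled_attn_op = torch.compile(attn_op, dynamic=False)
```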
```diff
@@ -303,7 +299,7 @@ def forward(
             query.device,
         )

-        # Step 5: Compute batched per-sequence attention (CPU, Spyre)
+        # Step 5: Compute batched attention (CPU, Spyre)
         # attn_output: [num_seqs, max_query_len, num_heads, head_size]
         attn_output = self._compute_attention(
             query_per_seq, compact_k, compact_v, mask, query.device, query.dtype
```
```diff
@@ -509,39 +505,16 @@ def _compute_attention(
         device: torch.device,  # device for intermediate allocations
         dtype: torch.dtype,  # dtype for intermediate allocations
     ) -> torch.Tensor:
-        """Dispatch attention: SDPA path or per-sequence chunked Spyre path.
+        """Dispatch attention: SDPA path or batched Spyre path.

         Returns:
             [num_seqs, max_query_len, num_heads, head_size]
         """
-        num_seqs = query.shape[0]
-
         # As fallback, use SDPA implementation
         if self.use_sdpa:
             return self._compute_attention_sdpa(query, key, value, mask)

-        # Allocate output tensor for all sequences
-        output_all_seqs = torch.zeros_like(query)
-
-        # Process each sequence separately
-        for seq_idx in range(num_seqs):
-            # Extract single sequence
-            query_seq = query[seq_idx : seq_idx + 1]  # [1, max_query_len, num_heads, head_size]
-            key_seq = key[seq_idx : seq_idx + 1]  # [1, max_seq_len, num_kv_heads, head_size]
-            value_seq = value[seq_idx : seq_idx + 1]  # [1, max_seq_len, num_kv_heads, head_size]
-            mask_seq = (
-                mask[seq_idx : seq_idx + 1] if mask is not None else None
-            )  # [1, 1, max_query_len, max_seq_len]
-
-            # Compute attention for this sequence
-            output_seq = self._compute_attention_single_seq(
-                query_seq, key_seq, value_seq, mask_seq, device, dtype
-            )
-
-            # Store result
-            output_all_seqs[seq_idx] = output_seq.squeeze(0)
-
-        return output_all_seqs
+        return self._compute_attention_impl(query, key, value, mask, device, dtype)

     def _compute_attention_sdpa(
         self,
```
```diff
@@ -567,190 +540,100 @@ def _compute_attention_sdpa(
         )
         return out.transpose(1, 2)

-    def _compute_attention_single_seq(
+    def _compute_attention_impl(
         self,
-        query: torch.Tensor,  # [1, max_query_len, num_heads, head_size]
-        key: torch.Tensor,  # [1, max_seq_len, num_kv_heads, head_size]
-        value: torch.Tensor,  # [1, max_seq_len, num_kv_heads, head_size]
-        mask: torch.Tensor | None,  # [1, 1, max_query_len, max_seq_len]
+        query: torch.Tensor,  # [num_seqs, max_query_len, num_heads, head_size]
+        key: torch.Tensor,  # [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]
+        value: torch.Tensor,  # [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]
+        mask: torch.Tensor | None,  # [num_seqs, 1, max_query_len, aligned_max_seq_len]
         device: torch.device,
         dtype: torch.dtype,
     ) -> torch.Tensor:
-        """Compute attention for a single sequence using Spyre.
+        """Compute batched 4D attention on Spyre.

-        Processes queries in fixed-size chunks of QUERY_CHUNK_SIZE tokens.
-        """
-
-        _, _, num_heads, head_size = query.shape
-
-        # Handle grouped-query attention by repeating KV heads
-        if self.num_queries_per_kv > 1:
-            key = key.repeat_interleave(self.num_queries_per_kv, dim=2)
-            value = value.repeat_interleave(self.num_queries_per_kv, dim=2)
-
-        # Squeeze batch dimension
-        query_squeezed = query.squeeze(0)  # [query_len, num_heads, head_size]
-        key_squeezed = key.squeeze(0)  # [kv_len, num_heads, head_size]
-        value_squeezed = value.squeeze(0)  # [kv_len, num_heads, head_size]
-
-        # Calculate number of chunks needed
-        actual_query_len = query_squeezed.shape[0]
-        num_chunks = (actual_query_len + self.QUERY_CHUNK_SIZE - 1) // self.QUERY_CHUNK_SIZE
-
-        output_full = torch.empty(
-            actual_query_len,
-            num_heads,
-            head_size,
-            dtype=dtype,
-            device=device,
-        )
+        Pads query to QUERY_CHUNK_SIZE-aligned length, merges batch into
+        kv_heads, issues one compiled 4D kernel call, and trims output.
+
+        Returns:
+            [num_seqs, max_query_len, num_heads, head_size]
+        """
+        num_seqs, max_query_len, num_heads, head_size = query.shape
+        _, kv_len, num_kv_heads, _ = key.shape
+        num_queries_per_kv = self.num_queries_per_kv
+
+        # Pad query length to QUERY_CHUNK_SIZE alignment
+        padded_query_len = (
+            (max_query_len + self.QUERY_CHUNK_SIZE - 1)
+            // self.QUERY_CHUNK_SIZE
+            * self.QUERY_CHUNK_SIZE
+        )
+        if padded_query_len > max_query_len:
+            padding_size = padded_query_len - max_query_len
+            query = torch.nn.functional.pad(
+                query, (0, 0, 0, 0, 0, padding_size), mode="constant", value=0.0
+            )

-        # Process each chunk
-        for chunk_idx in range(num_chunks):
-            chunk_start = chunk_idx * self.QUERY_CHUNK_SIZE
-            chunk_end = min(chunk_start + self.QUERY_CHUNK_SIZE, actual_query_len)
-            chunk_len = chunk_end - chunk_start
-
-            # Extract query chunk
-            query_chunk = query_squeezed[chunk_start:chunk_end]
-
-            # Pad chunk if needed
-            if chunk_len < self.QUERY_CHUNK_SIZE:
-                padding_size = self.QUERY_CHUNK_SIZE - chunk_len
-                query_chunk_padded = torch.nn.functional.pad(
-                    query_chunk, (0, 0, 0, 0, 0, padding_size), mode="constant", value=0.0
-                )
-            else:
-                query_chunk_padded = query_chunk
-
-            # Extract corresponding mask for this chunk
-            if mask is not None:
-                mask_chunk = mask[:, :, chunk_start:chunk_end, :]  # [1, 1, chunk_len, kv_len]
-            else:
-                mask_chunk = None
-
-            # Compute attention for this chunk
-            chunk_output = self._compute_attention_chunk(
-                query_chunk_padded,
-                key_squeezed,
-                value_squeezed,
-                mask_chunk,
-                chunk_len,
-                num_heads,
-                head_size,
-                kv_len,
-                device,
-                dtype,
-            )
-
-            # Store chunk output (only valid positions)
-            output_full[chunk_start:chunk_end] = chunk_output[:chunk_len]
-
-        return output_full.unsqueeze(0)  # [1, query_len, num_heads, head_size]
-
-    def _compute_attention_chunk(
-        self,
-        query_chunk_padded: torch.Tensor,  # [QUERY_CHUNK_SIZE, num_heads, head_size]
-        key_squeezed: torch.Tensor,  # [kv_len, num_heads, head_size]
-        value_squeezed: torch.Tensor,  # [kv_len, num_heads, head_size]
-        mask_chunk: torch.Tensor | None,  # [1, 1, chunk_len, kv_len]
-        chunk_len: int,
-        num_heads: int,
-        head_size: int,
-        kv_len: int,
-        device: torch.device,
-        dtype: torch.dtype,
-    ) -> torch.Tensor:
-        """Compute attention for a single query chunk on Spyre.
-
-        Prepares tensors on CPU (reshape, stickify, build mask), transfers to
-        Spyre for the compiled matmul kernel, then transfers the result back.
-
-        Returns:
-            [QUERY_CHUNK_SIZE, num_heads, head_size] — attention output (padded)
-        """
-        padded_query_len = self.QUERY_CHUNK_SIZE
-
-        # Reshape query to flatten heads into query dimension
-        query_reordered = query_chunk_padded.transpose(
-            0, 1
-        ).contiguous()  # [num_heads, QUERY_CHUNK_SIZE, head_size]
-        query_flat = query_reordered.reshape(num_heads * padded_query_len, head_size)
-
-        # Key and value: also flatten across heads
-        key_reordered = key_squeezed.transpose(0, 1).contiguous()  # [num_heads, kv_len, head_size]
-        value_reordered = value_squeezed.transpose(
-            0, 1
-        ).contiguous()  # [num_heads, kv_len, head_size]
-
-        key_flat = key_reordered.reshape(num_heads * kv_len, head_size)
-        value_flat = value_reordered.reshape(num_heads * kv_len, head_size)
-
-        # Transpose for attention computation
-        qt = query_flat.T.contiguous()  # [head_size, num_heads * QUERY_CHUNK_SIZE]
-        vt = value_flat.T.contiguous()  # [head_size, num_heads * kv_len]
-        k = key_flat  # [num_heads * kv_len, head_size]
-
-        # Stickification: force Spyre-friendly memory layout.
-        # Transposed tensors need double transpose-contiguous; standard tensors just contiguous.
-        qt_stickified = qt.transpose(0, 1).contiguous().transpose(0, 1).contiguous()
-        vt_stickified = vt.transpose(0, 1).contiguous().transpose(0, 1).contiguous()
-        k_stickified = k.contiguous()
-
-        # Scale factor: 1D tensor replicated per head × query position
-        sm_scale_1d = torch.tensor(self.scale, dtype=dtype, device=device).repeat(
-            num_heads * padded_query_len
-        )  # [num_heads * QUERY_CHUNK_SIZE]
-
-        # --- Build block-diagonal additive mask ---
-        # The transposed kernel flattens all heads into one matmul, so the mask
-        # must be block-diagonal: each head's causal/padding mask sits on the
-        # diagonal, off-diagonal blocks are masked (-65504).
-        if mask_chunk is not None:
-            mask_all_heads = mask_chunk[0, 0]  # [chunk_len, kv_len]
-
-            # Pad query dimension to QUERY_CHUNK_SIZE if this is the last chunk
-            if chunk_len < self.QUERY_CHUNK_SIZE:
-                padding_size = self.QUERY_CHUNK_SIZE - chunk_len
-                mask_padding = torch.ones((padding_size, kv_len), dtype=torch.bool, device=device)
-                mask_all_heads = torch.cat([mask_all_heads, mask_padding], dim=0)
-
-            head_mask_t = mask_all_heads.T  # [kv_len, QUERY_CHUNK_SIZE], True = masked
-            mask_bool = ~torch.block_diag(*([~head_mask_t] * num_heads))
-        else:
-            # No causal/padding mask: only cross-head positions are masked.
-            ones_block = torch.ones(kv_len, padded_query_len, dtype=torch.bool, device=device)
-            mask_bool = ~torch.block_diag(*([ones_block] * num_heads))
-
-        # Convert boolean mask to additive: True → -65504.0, False → 0.0
-        mask_values = torch.where(
-            mask_bool,
-            torch.tensor(-65504.0, dtype=dtype, device=device),
-            torch.tensor(0.0, dtype=dtype, device=device),
-        ).contiguous()
+        # Q: [num_seqs, query_len_padded, num_heads, head_size]
+        #    -> [num_seqs, num_heads, query_len_padded, head_size]
+        #    -> [num_seqs*num_kv_heads, num_queries_per_kv, query_len_padded, head_size]
+        q = query.transpose(1, 2).contiguous()
+        q = q.reshape(num_seqs * num_kv_heads, num_queries_per_kv, padded_query_len, head_size)
+
+        # K/V: [num_seqs, kv_len, num_kv_heads, head_size]
+        #      -> [num_seqs*num_kv_heads, 1, kv_len, head_size]
+        k = key.transpose(1, 2).contiguous()
+        k = k.reshape(num_seqs * num_kv_heads, 1, kv_len, head_size)
+        v = value.transpose(1, 2).contiguous()
+        v = v.reshape(num_seqs * num_kv_heads, 1, kv_len, head_size)
+
+        # --- Build additive mask [num_seqs*num_kv_heads, 1, query_len_padded, kv_len] ---
+        if mask is not None:
+            # mask: [num_seqs, 1, max_query_len, kv_len] (bool: True = masked)
+            mask_3d = mask[:, 0, :, :]  # [num_seqs, max_query_len, kv_len]
+            if padded_query_len > max_query_len:
+                padding_size = padded_query_len - max_query_len
+                mask_padding = torch.ones(
+                    (num_seqs, padding_size, kv_len), dtype=torch.bool, device=device
+                )
+                mask_3d = torch.cat([mask_3d, mask_padding], dim=1)
+
+            # Convert boolean mask to additive: True -> -65504.0, False -> 0.0
+            mask_additive = torch.where(
+                mask_3d,
+                torch.tensor(-65504.0, dtype=dtype, device=device),
+                torch.tensor(0.0, dtype=dtype, device=device),
+            )
+            # [num_seqs, query_len_padded, kv_len]
+            # -> expand [num_seqs, num_kv_heads, query_len_padded, kv_len]
+            # -> [num_seqs*num_kv_heads, 1, query_len_padded, kv_len]
+            mask_4d = (
+                mask_additive.unsqueeze(1)
+                .expand(-1, num_kv_heads, -1, -1)
+                .reshape(num_seqs * num_kv_heads, 1, padded_query_len, kv_len)
+                .contiguous()
+            )
+        else:
+            mask_4d = torch.zeros(1, 1, padded_query_len, kv_len, dtype=dtype, device=device)

         # --- Transfer to Spyre, compute, transfer back ---
-        qt_spyre = convert(qt_stickified, self._target_device, self._target_dtype)
-        k_spyre = convert(k_stickified, self._target_device, self._target_dtype)
-        vt_spyre = convert(vt_stickified, self._target_device, self._target_dtype)
-        sm_scale_spyre = convert(sm_scale_1d, self._target_device, self._target_dtype)
-        mask_spyre = convert(mask_values, self._target_device, self._target_dtype)
+        q_spyre = convert(q, self._target_device, self._target_dtype)
+        k_spyre = convert(k, self._target_device, self._target_dtype)
+        v_spyre = convert(v, self._target_device, self._target_dtype)
+        mask_spyre = convert(mask_4d, self._target_device, self._target_dtype)

         # Compiled attention on Spyre
-        output_spyre_t = self.attn_op(qt_spyre, k_spyre, vt_spyre, sm_scale_spyre, mask_spyre)
+        output_spyre = self.attn_op(q_spyre, k_spyre, v_spyre, self.scale, mask_spyre)
```
Review thread on the compiled attention call (`output_spyre = self.attn_op(...)`):

Collaborator: Can we actually start profiling the performance of the different versions?

Author: Yes.
```diff
         # Transfer back to CPU
-        output_flat = convert(
-            output_spyre_t, device, dtype
-        ).contiguous()  # [head_size, num_heads * QUERY_CHUNK_SIZE]
-
-        # Reshape: [head_size, N*Q] → [N, Q, head_size] → [Q, N, head_size]
-        output_transposed = output_flat.T  # [num_heads * QUERY_CHUNK_SIZE, head_size]
-        output_reshaped = output_transposed.reshape(num_heads, padded_query_len, head_size)
-
-        # [QUERY_CHUNK_SIZE, num_heads, head_size]
-        return output_reshaped.transpose(0, 1).contiguous()
+        # [num_seqs*num_kv_heads, num_queries_per_kv, query_len_padded, head_size]
+        # -> [num_seqs, num_heads, query_len_padded, head_size]
+        # -> [num_seqs, query_len_padded, num_heads, head_size]
+        # -> trim to [num_seqs, max_query_len, num_heads, head_size]
+        output_4d = convert(output_spyre, device, dtype)
+        output_reshaped = output_4d.reshape(num_seqs, num_heads, padded_query_len, head_size)
+        output = output_reshaped.transpose(1, 2).contiguous()
+        return output[:, :max_query_len, :, :]

     def _extract_relevant_output(
         self,
```
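As a reference for the mask construction in the diff above, this standalone sketch (made-up sizes, not the PR's exact code) shows the boolean-to-additive conversion and the fold of the batch dimension into the KV-head dimension:

```python
import torch

num_seqs, num_kv_heads = 2, 4
padded_query_len, kv_len = 64, 96
dtype = torch.float16

# Boolean mask, True = masked out (hypothetical pattern)
bool_mask = torch.rand(num_seqs, padded_query_len, kv_len) < 0.1

# True -> -65504.0 (most negative finite float16), False -> 0.0
additive = torch.where(
    bool_mask,
    torch.tensor(-65504.0, dtype=dtype),
    torch.tensor(0.0, dtype=dtype),
)

# Broadcast the per-sequence mask across KV heads, then fold the batch
# dimension into the head dimension expected by the 4D kernel.
mask_4d = (
    additive.unsqueeze(1)              # [num_seqs, 1, q, kv]
    .expand(-1, num_kv_heads, -1, -1)  # [num_seqs, num_kv_heads, q, kv]
    .reshape(num_seqs * num_kv_heads, 1, padded_query_len, kv_len)
    .contiguous()
)
assert mask_4d.shape == (num_seqs * num_kv_heads, 1, padded_query_len, kv_len)
```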
Review thread on the padding interface:

Reviewer: So we expect key and value to be padded, but not the query? What is the rationale behind this interface? (If there is one; I'm fully aware this could also just be temporary.)

Reviewer: And, as @tdoublep pointed out, is there a way to support the flattened varlen format?

Author: The query at the input is "flat", [num_tokens, num_heads, head_size]. The query gets padded inside the code (lines 573-577).
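For illustration, the round-up-and-pad step the author refers to can be sketched as follows (hypothetical sizes; the cited lines 573-577 refer to the PR's file, not this snippet):

```python
import torch

QUERY_CHUNK_SIZE = 32
max_query_len = 70  # hypothetical ragged query length
query = torch.randn(2, max_query_len, 8, 128)  # [num_seqs, q_len, num_heads, head_size]

# Round the query length up to the next multiple of QUERY_CHUNK_SIZE (70 -> 96) ...
padded_query_len = (
    (max_query_len + QUERY_CHUNK_SIZE - 1) // QUERY_CHUNK_SIZE * QUERY_CHUNK_SIZE
)
# ... and zero-pad only the query-length dimension (dim 1).
if padded_query_len > max_query_len:
    query = torch.nn.functional.pad(
        query, (0, 0, 0, 0, 0, padded_query_len - max_query_len)
    )
assert query.shape == (2, padded_query_len, 8, 128)
```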