-
Notifications
You must be signed in to change notification settings - Fork 5.2k
Add intel_amx backend for Radix Attention for CPU #6408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 13 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
5fc516b
add a sgl_kernel.cpu wrapper for CPU OPs in sgl-kernel (#10)
chunyuan-w d787b52
Add intel_amx backend for Radix Attention, including extend attention…
yanbing-j e30f852
Update intel_amx_backend.py and cpu.py with fuse decode attention wit…
yanbing-j 3fb86ba
bug fixes
gau-nernst db7357f
k/v supports non-contiguous tensors, update extend_attention calling
yanbing-j b71983e
add check on whether AMX is supported
yanbing-j 6e61935
add support_triton
yanbing-j 474f05e
Add log and update comments
yanbing-j 5f6ada2
Remove cpu.py
yanbing-j aba8d0a
Merge branch 'main' into yanbing/split_intel_amx
zhyncs 7fcf8e5
Merge branch 'main' into yanbing/split_intel_amx
zhyncs 715d15f
Merge branch 'main' into yanbing/split_intel_amx
zhyncs 2ff26fc
Merge branch 'main' into yanbing/split_intel_amx
zhyncs 4eeeb7a
upd
zhyncs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
128 changes: 128 additions & 0 deletions
128
python/sglang/srt/layers/attention/intel_amx_backend.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,128 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import torch | ||
|
|
||
| from sglang.srt.layers.attention.base_attn_backend import AttentionBackend | ||
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch | ||
|
|
||
| if TYPE_CHECKING: | ||
| from sglang.srt.layers.radix_attention import RadixAttention | ||
| from sglang.srt.model_executor.model_runner import ModelRunner | ||
|
|
||
|
|
||
| class IntelAMXAttnBackend(AttentionBackend): | ||
| def __init__(self, model_runner: ModelRunner): | ||
| import sgl_kernel | ||
|
|
||
| super().__init__() | ||
| self.forward_metadata = None | ||
| self.device = model_runner.device | ||
|
|
||
| self.num_head = ( | ||
| model_runner.model_config.num_attention_heads // model_runner.tp_size | ||
| ) | ||
|
|
||
| self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] | ||
|
|
||
| self.decode_attention_fwd = torch.ops.sgl_kernel.decode_attention_cpu | ||
| self.extend_attention_fwd = torch.ops.sgl_kernel.extend_attention_cpu | ||
|
|
||
| def init_forward_metadata(self, forward_batch: ForwardBatch): | ||
| """Init the metadata for a forward pass.""" | ||
|
|
||
| bs = forward_batch.batch_size | ||
| attn_logits = torch.zeros( | ||
| ( | ||
| bs, | ||
| self.num_head, | ||
| 8, # self.num_kv_splits, | ||
| self.v_head_dim + 1, | ||
| ), | ||
| dtype=torch.float32, | ||
| device=self.device, | ||
| ) | ||
| if forward_batch.forward_mode.is_decode_or_idle(): | ||
| max_extend_len = None | ||
| else: | ||
| max_extend_len = torch.max(forward_batch.extend_seq_lens).item() | ||
| self.forward_metadata = (attn_logits, max_extend_len) | ||
|
|
||
| def forward_extend( | ||
| self, | ||
| q, | ||
| k, | ||
| v, | ||
| layer: RadixAttention, | ||
| forward_batch: ForwardBatch, | ||
| save_kv_cache=True, | ||
| ): | ||
| if layer.qk_head_dim != layer.v_head_dim: | ||
| o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) | ||
| else: | ||
| o = torch.empty_like(q) | ||
|
|
||
| if save_kv_cache: | ||
| forward_batch.token_to_kv_pool.set_kv_buffer( | ||
| layer, forward_batch.out_cache_loc, k, v | ||
| ) | ||
|
|
||
| _, max_extend_len = self.forward_metadata | ||
|
|
||
| self.extend_attention_fwd( | ||
| q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), | ||
| k, | ||
| v, | ||
| o.view(-1, layer.tp_q_head_num, layer.v_head_dim), | ||
| forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), | ||
| forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), | ||
| forward_batch.req_to_token_pool.req_to_token, | ||
| forward_batch.req_pool_indices, | ||
| forward_batch.seq_lens, | ||
| forward_batch.extend_seq_lens, | ||
| forward_batch.extend_start_loc, | ||
| max_extend_len, | ||
| layer.scaling, | ||
| layer.logit_cap, | ||
| ) | ||
| return o | ||
|
|
||
| def forward_decode( | ||
| self, | ||
| q: torch.Tensor, | ||
| k: torch.Tensor, | ||
| v: torch.Tensor, | ||
| layer: RadixAttention, | ||
| forward_batch: ForwardBatch, | ||
| save_kv_cache=True, | ||
| ): | ||
| attn_logits, _ = self.forward_metadata | ||
|
|
||
| q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) | ||
|
|
||
| if layer.qk_head_dim != layer.v_head_dim: | ||
| o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) | ||
| else: | ||
| o = torch.empty_like(q) | ||
|
|
||
| self.decode_attention_fwd( | ||
| q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), | ||
| forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), | ||
| forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), | ||
| o.view(-1, layer.tp_q_head_num, layer.v_head_dim), | ||
| k, | ||
| v, | ||
| forward_batch.out_cache_loc, | ||
| attn_logits, | ||
| forward_batch.req_to_token_pool.req_to_token, | ||
| forward_batch.req_pool_indices, | ||
| forward_batch.seq_lens, | ||
| layer.scaling, | ||
| layer.logit_cap, | ||
| ) | ||
|
|
||
| return o | ||
|
|
||
| def support_triton(self): | ||
| return False |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -265,3 +265,6 @@ def forward_decode( | |
| ) | ||
|
|
||
| return o | ||
|
|
||
| def support_triton(self): | ||
| return False | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.