Merged
Changes from 13 commits
4 changes: 4 additions & 0 deletions python/sglang/srt/layers/attention/base_attn_backend.py
@@ -109,3 +109,7 @@ def forward_extend(
     ):
         """Run a forward for extend."""
         raise NotImplementedError()
+
+    def support_triton(self):
+        """Check if the current backend supports triton."""
+        return True
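For context, a minimal sketch (not part of the diff) of how this hook is consumed: backends that cannot feed the Triton kernels override it to return False, and callers gate the Triton path on it, as the scheduler change further down does. The MyCpuBackend class and write_req_to_token_pool helper here are hypothetical.

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend


class MyCpuBackend(AttentionBackend):
    """Hypothetical backend that opts out of the Triton kernels."""

    def support_triton(self):
        return False


def write_req_to_token_pool(backend: AttentionBackend):
    # Mirrors the gating added in schedule_batch.py below.
    if backend.support_triton():
        pass  # launch write_req_to_token_pool_triton[(bs,)](...)
    else:
        pass  # fall back to plain torch indexing into req_to_token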
128 changes: 128 additions & 0 deletions python/sglang/srt/layers/attention/intel_amx_backend.py
@@ -0,0 +1,128 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner


class IntelAMXAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        import sgl_kernel

        super().__init__()
        self.forward_metadata = None
        self.device = model_runner.device

        self.num_head = (
            model_runner.model_config.num_attention_heads // model_runner.tp_size
        )

        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

        self.decode_attention_fwd = torch.ops.sgl_kernel.decode_attention_cpu
        self.extend_attention_fwd = torch.ops.sgl_kernel.extend_attention_cpu

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""

        bs = forward_batch.batch_size
        attn_logits = torch.zeros(
            (
                bs,
                self.num_head,
                8,  # self.num_kv_splits,
                self.v_head_dim + 1,
            ),
            dtype=torch.float32,
            device=self.device,
        )
        if forward_batch.forward_mode.is_decode_or_idle():
            max_extend_len = None
        else:
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
        self.forward_metadata = (attn_logits, max_extend_len)

    def forward_extend(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        _, max_extend_len = self.forward_metadata

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k,
            v,
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.extend_seq_lens,
            forward_batch.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        attn_logits, _ = self.forward_metadata

        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            k,
            v,
            forward_batch.out_cache_loc,
            attn_logits,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            layer.scaling,
            layer.logit_cap,
        )

        return o

    def support_triton(self):
        return False
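As an aside (not in the diff), the intermediate buffer allocated in init_forward_metadata has one extra slot per head beyond the value head dim; a plausible reading is that it carries the per-split softmax normalizer next to the partial output, in line with the split-KV decode kernels. A toy shape check with arbitrarily chosen example values:

import torch

bs, num_head, num_kv_splits, v_head_dim = 2, 32, 8, 128  # example values only
attn_logits = torch.zeros(
    (bs, num_head, num_kv_splits, v_head_dim + 1), dtype=torch.float32
)
print(attn_logits.shape)  # torch.Size([2, 32, 8, 129])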
3 changes: 3 additions & 0 deletions python/sglang/srt/layers/attention/torch_native_backend.py
@@ -265,3 +265,6 @@ def forward_decode(
         )
 
         return o
+
+    def support_triton(self):
+        return False
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/schedule_batch.py
@@ -60,7 +60,7 @@
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import flatten_nested_list, get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -1257,7 +1257,7 @@ def prepare_for_extend(self):
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
 
         # Write to req_to_token_pool
-        if global_server_args_dict["attention_backend"] != "torch_native":
+        if support_triton(global_server_args_dict.get("attention_backend")):
             # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
 
             write_req_to_token_pool_triton[(bs,)](
4 changes: 2 additions & 2 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -39,7 +39,7 @@
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import flatten_nested_list, get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -351,7 +351,7 @@ def init_new(
         ret.extend_prefix_lens = torch.tensor(
             batch.extend_prefix_lens, dtype=torch.int32
         ).to(device, non_blocking=True)
-        if model_runner.server_args.attention_backend != "torch_native":
+        if support_triton(model_runner.server_args.attention_backend):
             ret.extend_num_tokens = batch.extend_num_tokens
             positions, ret.extend_start_loc = compute_position_triton(
                 ret.extend_prefix_lens,
23 changes: 22 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -91,6 +91,7 @@
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
+    cpu_has_amx_support,
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -317,6 +318,16 @@ def initialize(self, min_per_gpu_memory: float):
     def model_specific_adjustment(self):
         server_args = self.server_args
 
+        if (
+            server_args.attention_backend == "intel_amx"
+            and server_args.device == "cpu"
+            and not cpu_has_amx_support()
+        ):
+            logger.info(
+                "The current platform does not support Intel AMX; falling back to the torch_native backend."
+            )
+            server_args.attention_backend = "torch_native"
+
         if server_args.attention_backend is None:
             """
             Auto select the fastest attention backend.
@@ -369,7 +380,10 @@ def model_specific_adjustment(self):
                     f"Invalid attention backend for MLA: {server_args.attention_backend}"
                 )
             else:
-                raise ValueError("MLA optimization not supported on CPU.")
+                if server_args.attention_backend != "intel_amx":
+                    raise ValueError(
+                        "MLA optimization not supported on CPU except for intel_amx backend."
+                    )
 
         if (
             server_args.attention_backend == "fa3"
@@ -1067,6 +1081,13 @@ def _get_attention_backend(self):
             )
 
             return CutlassMLABackend(self)
+        elif self.server_args.attention_backend == "intel_amx":
+            from sglang.srt.layers.attention.intel_amx_backend import (
+                IntelAMXAttnBackend,
+            )
+
+            logger.info("Intel AMX attention backend is enabled.")
+            return IntelAMXAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
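To see the new CPU fallback rule in isolation, here is a small sketch (a hypothetical standalone helper, not how the PR structures it inside model_specific_adjustment):

def resolve_cpu_attention_backend(requested: str, device: str, has_amx: bool) -> str:
    """Apply the Intel AMX fallback rule and return the effective backend."""
    if requested == "intel_amx" and device == "cpu" and not has_amx:
        return "torch_native"
    return requested


assert resolve_cpu_attention_backend("intel_amx", "cpu", has_amx=False) == "torch_native"
assert resolve_cpu_attention_backend("intel_amx", "cpu", has_amx=True) == "intel_amx"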
6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
@@ -323,6 +323,11 @@ def __post_init__(self):
             self.sampling_backend = "pytorch"
 
         # Set kernel backends
+        if self.device == "cpu":
+            if self.attention_backend is None:
+                self.attention_backend = "intel_amx"
+            self.sampling_backend = "pytorch"
+
         if self.sampling_backend is None:
             self.sampling_backend = (
                 "flashinfer" if is_flashinfer_available() else "pytorch"
@@ -993,6 +998,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
                 "fa3",
                 "flashmla",
                 "cutlass_mla",
+                "intel_amx",
             ],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
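With these defaults, a CPU deployment can leave the backend unset and get intel_amx automatically, or request it explicitly. An illustrative check (the model path is a placeholder, and constructing ServerArgs directly is only a sketch of the __post_init__ behavior added above):

from sglang.srt.server_args import ServerArgs

# CLI equivalent (roughly): python -m sglang.launch_server --model-path my-model --device cpu
args = ServerArgs(model_path="my-model", device="cpu")
print(args.attention_backend, args.sampling_backend)  # expected: intel_amx pytorch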
14 changes: 14 additions & 0 deletions python/sglang/srt/utils.py
@@ -2225,3 +2225,17 @@ def bind_or_assign(target, source):
         return target
     else:
         return source
+
+
+def support_triton(backend: str) -> bool:
+    return backend not in ["torch_native", "intel_amx"]
+
+try:
+    import sgl_kernel
+    is_intel_amx_backend_available = hasattr(torch.ops.sgl_kernel, "convert_weight_packed")
+except:
+    is_intel_amx_backend_available = False
+
+
+def cpu_has_amx_support():
+    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
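A quick, illustrative check of the two new helpers (the shown results assume a machine without AMX support):

from sglang.srt.utils import cpu_has_amx_support, support_triton

print(support_triton("triton"))      # True  -> Triton kernel paths stay enabled
print(support_triton("intel_amx"))   # False -> torch-native/CPU paths are used instead
print(cpu_has_amx_support())         # True only with AMX hardware and the sgl_kernel CPU ops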