From f32d15795260522313126582e7d8046d2a562188 Mon Sep 17 00:00:00 2001
From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Date: Thu, 10 Jul 2025 20:42:50 +0000
Subject: [PATCH 01/27] trtllm gen mla initial commit
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
---
python/pyproject.toml | 1 +
.../layers/attention/trtllm_mla_backend.py | 200 ++++++++++++++++
.../sglang/srt/model_executor/model_runner.py | 20 +-
.../test/attention/test_trtllm_mla_backend.py | 224 ++++++++++++++++++
4 files changed, 444 insertions(+), 1 deletion(-)
create mode 100644 python/sglang/srt/layers/attention/trtllm_mla_backend.py
create mode 100644 python/sglang/test/attention/test_trtllm_mla_backend.py
diff --git a/python/pyproject.toml b/python/pyproject.toml
index d916fcb57e6c..c79681ae1c3d 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -75,6 +75,7 @@ blackwell = [
"tiktoken",
]
+
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
srt_hip = [
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
new file mode 100644
index 000000000000..4534249d083e
--- /dev/null
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -0,0 +1,200 @@
+from __future__ import annotations
+
+"""
+Support attention backend for TRTLLM MLA kernels from flashinfer.
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_flashinfer_available
+
+if TYPE_CHECKING:
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.model_executor.model_runner import ModelRunner
+ from sglang.srt.speculative.spec_info import SpecInfo
+
+if is_flashinfer_available():
+ import flashinfer
+
+
+# TRTLLM MLA supports variable page sizes
+
+
+@dataclass
+class TRTLLMMLADecodeMetadata:
+ workspace: Optional[torch.Tensor] = None
+ block_kv_indices: Optional[torch.Tensor] = None
+
+
+class TRTLLMMLABackend(FlashInferMLAAttnBackend):
+ """TRTLLM MLA attention kernels from flashinfer."""
+
+ def __init__(
+ self,
+ model_runner: ModelRunner,
+ skip_prefill: bool = False,
+ kv_indptr_buf: Optional[torch.Tensor] = None,
+ kv_last_page_len_buf: Optional[torch.Tensor] = None,
+ ):
+ super().__init__(
+ model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+ )
+
+ # Model parameters
+ self.num_q_heads = (
+ model_runner.model_config.num_attention_heads // get_attention_tp_size()
+ )
+ self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+ get_attention_tp_size()
+ )
+ self.req_to_token = model_runner.req_to_token_pool.req_to_token
+ self.num_local_heads = (
+ model_runner.model_config.num_attention_heads // get_attention_tp_size()
+ )
+        self.forward_metadata: Optional[TRTLLMMLADecodeMetadata] = None
+ self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+ self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+ self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+ self.v_head_dim = model_runner.model_config.v_head_dim
+ self.scaling = model_runner.model_config.scaling
+ self.data_type = model_runner.kv_cache_dtype
+ self.q_data_type = model_runner.dtype
+ self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+ self.page_size = model_runner.page_size # Use page size from model runner
+
+ # Validate dimensions for TRTLLM MLA (based on test requirements)
+ if self.qk_nope_head_dim != 128:
+ raise ValueError(f"TRTLLM MLA requires qk_nope_head_dim=128, got {self.qk_nope_head_dim}")
+ if self.kv_lora_rank != 512:
+ raise ValueError(f"TRTLLM MLA requires kv_lora_rank=512, got {self.kv_lora_rank}")
+ if self.qk_rope_head_dim != 64:
+ raise ValueError(f"TRTLLM MLA requires qk_rope_head_dim=64, got {self.qk_rope_head_dim}")
+
+ # Allocate larger workspace for TRTLLM (128MB as in the test)
+ self.workspace_size = 128 * 1024 * 1024
+ self.workspace_buffer = torch.empty(
+ self.workspace_size, dtype=torch.int8, device=self.device
+ )
+
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
+ """Init the metadata for a forward pass."""
+ bs = forward_batch.batch_size
+ spec_info = forward_batch.spec_info
+
+ if forward_batch.forward_mode.is_decode_or_idle():
+ if spec_info is None:
+ # Calculate max sequence length padded to page boundary
+ max_seqlen_pad = triton.cdiv(
+ forward_batch.seq_lens_cpu.max().item(), self.page_size
+ )
+
+ # Create block indices
+ block_kv_indices = torch.full(
+ (bs, max_seqlen_pad),
+ -1,
+ dtype=torch.int32,
+ device=forward_batch.seq_lens.device,
+ )
+
+ # Fill block indices using the existing triton kernel
+ create_flashmla_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ None,
+ block_kv_indices,
+ self.req_to_token.stride(0),
+ max_seqlen_pad,
+ self.page_size,
+ )
+
+ forward_batch.decode_trtllm_mla_metadata = TRTLLMMLADecodeMetadata(
+ self.workspace_buffer,
+ block_kv_indices,
+ )
+ self.forward_metadata = forward_batch.decode_trtllm_mla_metadata
+ else:
+ # Speculative decoding: use parent class implementation
+ super().init_forward_metadata(forward_batch)
+ else:
+ # Prefill: use parent class implementation
+ super().init_forward_metadata(forward_batch)
+
+ def forward_decode(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ layer: RadixAttention,
+ forward_batch: ForwardBatch,
+ save_kv_cache: bool = True,
+ # For multi-head latent attention
+ q_rope: Optional[torch.Tensor] = None,
+ k_rope: Optional[torch.Tensor] = None,
+ ):
+ """Run forward for decode using TRTLLM kernel."""
+ cache_loc = forward_batch.out_cache_loc
+
+ if k is not None:
+ assert v is not None
+ if save_kv_cache:
+ if k_rope is not None:
+ # MLA style KV cache storage
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+ layer,
+ cache_loc,
+ k,
+ k_rope,
+ )
+ else:
+ # Standard KV cache storage
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer,
+ cache_loc,
+ k,
+ v,
+ )
+
+ # Prepare query tensor - concatenate q_nope and q_rope
+ if q_rope is not None:
+ # q and q_rope are separate
+ q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+ q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.qk_rope_head_dim)
+ query = torch.cat([q_nope, q_rope], dim=-1)
+ else:
+ # q already contains both nope and rope parts
+ query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+ # Get KV cache
+ k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+ # Reshape KV cache for TRTLLM format
+ kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim)
+
+ # Call TRTLLM MLA decode kernel
+ output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+ query=query,
+ kv_cache=kv_cache,
+ workspace_buffer=forward_batch.decode_trtllm_mla_metadata.workspace,
+ qk_nope_head_dim=self.qk_nope_head_dim,
+ kv_lora_rank=self.kv_lora_rank,
+ qk_rope_head_dim=self.qk_rope_head_dim,
+ block_tables=forward_batch.decode_trtllm_mla_metadata.block_kv_indices,
+ seq_lens=forward_batch.seq_lens.to(torch.int32),
+ block_size=self.page_size,
+ max_seq_len=forward_batch.seq_lens.max().item(),
+ scale=layer.scaling,
+ out=None,
+ bmm1_scale=1.0, # Only needed for FP8
+ bmm2_scale=1.0, # Only needed for FP8
+ )
+
+ return output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
\ No newline at end of file
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 13555adeb186..0f53e52285e5 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -396,7 +396,18 @@ def model_specific_adjustment(self):
):
server_args.attention_backend = "fa3"
elif is_sm100_supported():
- server_args.attention_backend = "flashinfer"
+ # On Blackwell, prefer TRTLLM MLA if available, otherwise flashinfer
+ if is_flashinfer_available():
+ try:
+ import flashinfer
+ if hasattr(flashinfer.decode, 'trtllm_batch_decode_with_kv_cache_mla'):
+ server_args.attention_backend = "trtllm_mla"
+ else:
+ server_args.attention_backend = "flashinfer"
+            except Exception:
+ server_args.attention_backend = "flashinfer"
+ else:
+ server_args.attention_backend = "flashinfer"
elif _is_hip:
head_num = self.model_config.get_num_kv_heads(self.tp_size)
# TODO current aiter only support head number 16 or 128 head number
@@ -422,6 +433,7 @@ def model_specific_adjustment(self):
"triton",
"flashmla",
"cutlass_mla",
+ "trtllm_mla",
"ascend",
]:
logger.info(
@@ -1430,6 +1442,12 @@ def _get_attention_backend_from_str(self, backend_str: str):
)
return CutlassMLABackend(self)
+ elif self.server_args.attention_backend == "trtllm_mla":
+ from sglang.srt.layers.attention.trtllm_mla_backend import (
+ TRTLLMMLABackend,
+ )
+
+ return TRTLLMMLABackend(self)
elif self.server_args.attention_backend == "intel_amx":
from sglang.srt.layers.attention.intel_amx_backend import (
IntelAMXAttnBackend,
diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py
new file mode 100644
index 000000000000..b2e05874d663
--- /dev/null
+++ b/python/sglang/test/attention/test_trtllm_mla_backend.py
@@ -0,0 +1,224 @@
+import unittest
+
+import pytest
+import torch
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_flashinfer_available
+from sglang.test.test_utils import CustomTestCase
+
+
+class MockModelRunner:
+ def __init__(
+ self,
+ kv_lora_rank,
+ qk_rope_head_dim,
+ page_size=16,
+ ):
+ attention_arch = AttentionArch.MLA
+ self.device = "cuda"
+ self.dtype = torch.bfloat16
+ self.kv_cache_dtype = torch.bfloat16
+ context_len = 2048
+ self.model_config = type(
+ "ModelConfig",
+ (),
+ {
+ "context_len": context_len,
+ "attention_arch": attention_arch,
+ "num_attention_heads": 128,
+ "kv_lora_rank": kv_lora_rank,
+ "qk_nope_head_dim": 128,
+ "qk_rope_head_dim": qk_rope_head_dim,
+ "v_head_dim": kv_lora_rank,
+ "scaling": 1.0 / ((128 + 64) ** 0.5),
+ },
+ )
+ self.sliding_window_size = None
+ self.page_size = page_size
+
+ batch_size = 256
+ # Create a proper req_to_token_pool with the req_to_token attribute
+ self.req_to_token_pool = type(
+ "TokenPool",
+ (),
+ {
+ "size": batch_size,
+ "req_to_token": torch.zeros(
+ batch_size, context_len, dtype=torch.int32, device=self.device
+ ),
+ },
+ )
+
+ max_total_num_tokens = batch_size * context_len
+ self.token_to_kv_pool = MLATokenToKVPool(
+ size=max_total_num_tokens,
+ page_size=self.page_size,
+ dtype=self.dtype,
+ kv_lora_rank=kv_lora_rank,
+ qk_rope_head_dim=qk_rope_head_dim,
+ layer_num=1, # only consider layer=1 for unit test
+ device=self.device,
+ enable_memory_saver=False,
+ )
+
+ def get_num_kv_heads(self, tp_size):
+ """MLA uses single KV head."""
+ return 1
+
+
+@pytest.mark.skipif(
+ not torch.cuda.is_available() or not is_flashinfer_available(),
+ reason="Test requires CUDA and flashinfer"
+)
+@pytest.mark.parametrize("batch_size", [16, 32, 64])
+@pytest.mark.parametrize("page_size", [16, 32, 64])
+@pytest.mark.parametrize("seq_len", [256, 512, 1024])
+class TestTRTLLMMLABackend(CustomTestCase):
+ def test_trtllm_decode_mla(self, batch_size, page_size, seq_len):
+ """Test TRTLLM MLA decode operation with various configurations."""
+ # Check if PyTorch supports current GPU
+ if torch.cuda.is_available():
+ capability = torch.cuda.get_device_capability()
+ archs = torch.cuda.get_arch_list()
+ current_arch = f"sm_{capability[0]}{capability[1]}"
+ supported = any(current_arch in arch for arch in archs)
+ if not supported:
+ pytest.skip(f"PyTorch doesn't support {current_arch} - need nightly build")
+
+ # DeepSeek MLA configuration
+ num_q_heads = 128
+ kv_lora_rank = 512
+ qk_nope_head_dim = 128
+ qk_rope_head_dim = 64
+ device = "cuda"
+ dtype = torch.bfloat16
+
+ # Initialize model runner and backend
+ model_runner = MockModelRunner(
+ kv_lora_rank=kv_lora_rank,
+ qk_rope_head_dim=qk_rope_head_dim,
+ page_size=page_size,
+ )
+
+ # Check if flashinfer has TRTLLM MLA support
+ try:
+ import flashinfer
+ if not hasattr(flashinfer.decode, 'trtllm_batch_decode_with_kv_cache_mla'):
+ pytest.skip("flashinfer version does not have TRTLLM MLA support")
+ except ImportError:
+ pytest.skip("flashinfer not available")
+
+ backend = TRTLLMMLABackend(model_runner)
+
+ # Create attention layer
+ layer = RadixAttention(
+ num_heads=num_q_heads,
+ head_dim=kv_lora_rank + qk_rope_head_dim,
+ scaling=model_runner.model_config.scaling,
+ num_kv_heads=1,
+ layer_id=0,
+ v_head_dim=kv_lora_rank,
+ prefix="attn_mqa",
+ )
+
+ # Generate sequence lengths
+ seq_lens = torch.randint(1, seq_len, (batch_size,), device=device)
+ seq_lens[-1] = seq_len # Ensure at least one max length
+ max_seq_len = seq_lens.max().item()
+
+ # Calculate blocks needed
+ blocks_per_seq = (seq_lens + page_size - 1) // page_size
+ total_blocks_needed = blocks_per_seq.sum().item()
+
+ # Create req_to_token mapping
+ req_to_token = torch.zeros(batch_size, seq_len, dtype=torch.int32, device=device)
+ token_offset = 0
+ for i in range(batch_size):
+ seq_len_i = seq_lens[i].item()
+ req_to_token[i, :seq_len_i] = torch.arange(
+ token_offset, token_offset + seq_len_i, device=device
+ )
+ token_offset += seq_len_i
+
+ model_runner.req_to_token_pool.req_to_token = req_to_token
+
+ # Create forward batch for decode
+ forward_batch = ForwardBatch(
+ batch_size=batch_size,
+ input_ids=torch.randint(0, 100, (batch_size, 1), device=device),
+ out_cache_loc=torch.arange(batch_size, device=device),
+ seq_lens_sum=seq_lens.sum().item(),
+ forward_mode=ForwardMode.DECODE,
+ req_pool_indices=torch.arange(batch_size, device=device),
+ seq_lens=seq_lens,
+ seq_lens_cpu=seq_lens.cpu(),
+ attn_backend=backend,
+ )
+
+ # Add pools to forward batch
+ forward_batch.req_to_token_pool = model_runner.req_to_token_pool
+ forward_batch.token_to_kv_pool = model_runner.token_to_kv_pool
+
+ # Fill KV cache with some data
+ cache_data = torch.randn(
+ seq_lens.sum().item(),
+ 1, # num_kv_heads
+ kv_lora_rank + qk_rope_head_dim,
+ dtype=dtype,
+ device=device,
+ )
+ cache_indices = torch.arange(seq_lens.sum().item(), device=device)
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer, cache_indices, cache_data, None
+ )
+
+ # Initialize metadata
+ backend.init_forward_metadata(forward_batch)
+
+ # Create input tensors
+ q_shape = (batch_size, num_q_heads, kv_lora_rank + qk_rope_head_dim)
+ q = torch.randn(q_shape, dtype=dtype, device=device)
+
+ # For MLA, k contains compressed KV, v is not used
+ k = torch.randn(batch_size, 1, kv_lora_rank + qk_rope_head_dim, dtype=dtype, device=device)
+ v = None
+
+ # Run forward decode
+ output = backend.forward_decode(q, k, v, layer, forward_batch)
+
+ # Verify output
+ expected_shape = (batch_size, num_q_heads * kv_lora_rank)
+ assert output.shape == expected_shape, f"Expected shape {expected_shape}, got {output.shape}"
+ assert output.dtype == dtype
+ assert output.device.type == "cuda"
+ assert not torch.isnan(output).any(), "Output contains NaN values"
+ assert not torch.isinf(output).any(), "Output contains Inf values"
+
+
+# Simplified test for quick verification
+@pytest.mark.skipif(
+ not torch.cuda.is_available() or not is_flashinfer_available(),
+ reason="Test requires CUDA and flashinfer"
+)
+def test_trtllm_mla_basic():
+ """Basic test to verify TRTLLM MLA backend works."""
+ # Check if flashinfer has TRTLLM MLA support
+ try:
+ import flashinfer
+ if not hasattr(flashinfer.decode, 'trtllm_batch_decode_with_kv_cache_mla'):
+ pytest.skip("flashinfer version does not have TRTLLM MLA support")
+ except ImportError:
+ pytest.skip("flashinfer not available")
+
+ test = TestTRTLLMMLABackend()
+ test.test_trtllm_decode_mla(batch_size=32, page_size=32, seq_len=512)
+ print("TRTLLM MLA basic test passed!")
+
+
+if __name__ == "__main__":
+ test_trtllm_mla_basic()
\ No newline at end of file
From 3bca375ceaabc939942d5f72788606a7a7a735fa Mon Sep 17 00:00:00 2001
From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Date: Fri, 11 Jul 2025 09:51:15 -0700
Subject: [PATCH 02/27] Unittest passing
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
---
.../layers/attention/trtllm_mla_backend.py | 37 +--
python/sglang/srt/server_args.py | 1 +
.../test/attention/test_trtllm_mla_backend.py | 280 +++++++-----------
3 files changed, 121 insertions(+), 197 deletions(-)
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 4534249d083e..9fae41785c64 100644
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -143,19 +143,18 @@ def forward_decode(
"""Run forward for decode using TRTLLM kernel."""
cache_loc = forward_batch.out_cache_loc
- if k is not None:
- assert v is not None
- if save_kv_cache:
- if k_rope is not None:
- # MLA style KV cache storage
- forward_batch.token_to_kv_pool.set_mla_kv_buffer(
- layer,
- cache_loc,
- k,
- k_rope,
- )
- else:
- # Standard KV cache storage
+ if k is not None and save_kv_cache:
+ if k_rope is not None:
+ # MLA style KV cache storage
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+ layer,
+ cache_loc,
+ k,
+ k_rope,
+ )
+ else:
+ # Standard KV cache storage path. Skip if value tensor is absent (e.g., MLA decode tests).
+ if v is not None:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
@@ -176,11 +175,11 @@ def forward_decode(
# Get KV cache
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
- # Reshape KV cache for TRTLLM format
- kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim)
+ # Reshape KV cache to 4-D (num_kv_heads, num_blocks, page_size, kv_dim)
+ kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(0) # 1 KV head
# Call TRTLLM MLA decode kernel
- output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+ raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=query,
kv_cache=kv_cache,
workspace_buffer=forward_batch.decode_trtllm_mla_metadata.workspace,
@@ -196,5 +195,7 @@ def forward_decode(
bmm1_scale=1.0, # Only needed for FP8
bmm2_scale=1.0, # Only needed for FP8
)
-
- return output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
\ No newline at end of file
+ output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+ if output.shape[0] > forward_batch.batch_size:
+ output = output[: forward_batch.batch_size]
+ return output
\ No newline at end of file
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index dc0c6cd1acab..bd5a8336aab2 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1226,6 +1226,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"torch_native",
"ascend",
"triton",
+ "trtllm_mla",
],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py
index b2e05874d663..7f4f91d0e588 100644
--- a/python/sglang/test/attention/test_trtllm_mla_backend.py
+++ b/python/sglang/test/attention/test_trtllm_mla_backend.py
@@ -1,224 +1,146 @@
import unittest
-
-import pytest
import torch
+from sglang.srt.layers import dp_attention as _dp_attn
+
+# Patch DP-attention globals **before** importing the backend so that all
+# downstream `from … import get_attention_tp_size` statements receive the
+# patched version.
+_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test
+
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import is_flashinfer_available
from sglang.test.test_utils import CustomTestCase
+from sglang.srt.utils import is_flashinfer_available
class MockModelRunner:
- def __init__(
- self,
- kv_lora_rank,
- qk_rope_head_dim,
- page_size=16,
- ):
- attention_arch = AttentionArch.MLA
+ """Minimal fake `ModelRunner` for MLA backend unit tests."""
+
+ def __init__(self, page_size: int):
self.device = "cuda"
self.dtype = torch.bfloat16
self.kv_cache_dtype = torch.bfloat16
- context_len = 2048
+ self.page_size = page_size
+
+ # Model-config stub – only the attributes accessed by the backend.
self.model_config = type(
"ModelConfig",
(),
{
- "context_len": context_len,
- "attention_arch": attention_arch,
+ "context_len": 2048,
+ "attention_arch": AttentionArch.MLA,
"num_attention_heads": 128,
- "kv_lora_rank": kv_lora_rank,
+ "kv_lora_rank": 512,
"qk_nope_head_dim": 128,
- "qk_rope_head_dim": qk_rope_head_dim,
- "v_head_dim": kv_lora_rank,
+ "qk_rope_head_dim": 64,
+ "v_head_dim": 512,
"scaling": 1.0 / ((128 + 64) ** 0.5),
+ "get_num_kv_heads": staticmethod(lambda _: 1),
},
)
- self.sliding_window_size = None
- self.page_size = page_size
- batch_size = 256
- # Create a proper req_to_token_pool with the req_to_token attribute
+ # Req-to-token pool (dummy)
+ max_bs = 64
+ max_ctx = self.model_config.context_len
self.req_to_token_pool = type(
"TokenPool",
(),
{
- "size": batch_size,
- "req_to_token": torch.zeros(
- batch_size, context_len, dtype=torch.int32, device=self.device
- ),
+ "size": max_bs,
+ "req_to_token": torch.zeros(max_bs, max_ctx, dtype=torch.int32, device=self.device),
},
)
-
- max_total_num_tokens = batch_size * context_len
+
+ # KV-token pool
self.token_to_kv_pool = MLATokenToKVPool(
- size=max_total_num_tokens,
- page_size=self.page_size,
- dtype=self.dtype,
- kv_lora_rank=kv_lora_rank,
- qk_rope_head_dim=qk_rope_head_dim,
- layer_num=1, # only consider layer=1 for unit test
+ size=max_bs * max_ctx,
+ page_size=page_size,
+ dtype=self.kv_cache_dtype,
+ kv_lora_rank=512,
+ qk_rope_head_dim=64,
+ layer_num=1,
device=self.device,
enable_memory_saver=False,
)
- def get_num_kv_heads(self, tp_size):
- """MLA uses single KV head."""
- return 1
-
-@pytest.mark.skipif(
- not torch.cuda.is_available() or not is_flashinfer_available(),
- reason="Test requires CUDA and flashinfer"
-)
-@pytest.mark.parametrize("batch_size", [16, 32, 64])
-@pytest.mark.parametrize("page_size", [16, 32, 64])
-@pytest.mark.parametrize("seq_len", [256, 512, 1024])
+@unittest.skipIf(not torch.cuda.is_available() or not is_flashinfer_available(), "CUDA + flashinfer required")
class TestTRTLLMMLABackend(CustomTestCase):
- def test_trtllm_decode_mla(self, batch_size, page_size, seq_len):
- """Test TRTLLM MLA decode operation with various configurations."""
- # Check if PyTorch supports current GPU
- if torch.cuda.is_available():
- capability = torch.cuda.get_device_capability()
- archs = torch.cuda.get_arch_list()
- current_arch = f"sm_{capability[0]}{capability[1]}"
- supported = any(current_arch in arch for arch in archs)
- if not supported:
- pytest.skip(f"PyTorch doesn't support {current_arch} - need nightly build")
-
- # DeepSeek MLA configuration
- num_q_heads = 128
- kv_lora_rank = 512
- qk_nope_head_dim = 128
- qk_rope_head_dim = 64
- device = "cuda"
- dtype = torch.bfloat16
-
- # Initialize model runner and backend
- model_runner = MockModelRunner(
- kv_lora_rank=kv_lora_rank,
- qk_rope_head_dim=qk_rope_head_dim,
- page_size=page_size,
- )
-
- # Check if flashinfer has TRTLLM MLA support
- try:
- import flashinfer
- if not hasattr(flashinfer.decode, 'trtllm_batch_decode_with_kv_cache_mla'):
- pytest.skip("flashinfer version does not have TRTLLM MLA support")
- except ImportError:
- pytest.skip("flashinfer not available")
-
- backend = TRTLLMMLABackend(model_runner)
-
- # Create attention layer
- layer = RadixAttention(
- num_heads=num_q_heads,
- head_dim=kv_lora_rank + qk_rope_head_dim,
- scaling=model_runner.model_config.scaling,
- num_kv_heads=1,
- layer_id=0,
- v_head_dim=kv_lora_rank,
- prefix="attn_mqa",
- )
-
- # Generate sequence lengths
- seq_lens = torch.randint(1, seq_len, (batch_size,), device=device)
- seq_lens[-1] = seq_len # Ensure at least one max length
- max_seq_len = seq_lens.max().item()
-
- # Calculate blocks needed
- blocks_per_seq = (seq_lens + page_size - 1) // page_size
- total_blocks_needed = blocks_per_seq.sum().item()
-
- # Create req_to_token mapping
- req_to_token = torch.zeros(batch_size, seq_len, dtype=torch.int32, device=device)
- token_offset = 0
- for i in range(batch_size):
- seq_len_i = seq_lens[i].item()
- req_to_token[i, :seq_len_i] = torch.arange(
- token_offset, token_offset + seq_len_i, device=device
- )
- token_offset += seq_len_i
-
- model_runner.req_to_token_pool.req_to_token = req_to_token
-
- # Create forward batch for decode
- forward_batch = ForwardBatch(
- batch_size=batch_size,
- input_ids=torch.randint(0, 100, (batch_size, 1), device=device),
- out_cache_loc=torch.arange(batch_size, device=device),
- seq_lens_sum=seq_lens.sum().item(),
+ """Structure mirrors `test_flashattn_backend.py` but focuses on MLA decode."""
+
+ def setUp(self):
+ self.batch_size = 16
+ self.seq_len = 512
+ self.page_sizes = [16, 32, 64]
+ self.device = "cuda"
+ self.dtype = torch.bfloat16
+
+    # -- helpers ---------------------------------------------------------
+ def _init(self, page_size: int):
+ self.model_runner = MockModelRunner(page_size)
+ self.backend = TRTLLMMLABackend(self.model_runner)
+ # Attach num_heads required by RadixAttention convenience
+ self.model_runner.model_config.num_attention_heads = 128
+
+ def _alloc_qkv(self):
+ head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim
+ q_shape = (self.batch_size, 128, head_dim)
+ q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
+ k = torch.randn(self.batch_size, 1, head_dim, dtype=self.dtype, device=self.device)
+ v = None # TRTLLM MLA decode kernel ignores v
+ return q, k, v
+
+ def _create_forward_batch(self, seq_lens: torch.Tensor):
+ fb = ForwardBatch(
+ batch_size=self.batch_size,
+ input_ids=torch.randint(0, 100, (self.batch_size, 1), device=self.device),
+ out_cache_loc=torch.arange(self.batch_size, device=self.device),
+ seq_lens_sum=int(seq_lens.sum().item()),
forward_mode=ForwardMode.DECODE,
- req_pool_indices=torch.arange(batch_size, device=device),
+ req_pool_indices=torch.arange(self.batch_size, device=self.device),
seq_lens=seq_lens,
seq_lens_cpu=seq_lens.cpu(),
- attn_backend=backend,
- )
-
- # Add pools to forward batch
- forward_batch.req_to_token_pool = model_runner.req_to_token_pool
- forward_batch.token_to_kv_pool = model_runner.token_to_kv_pool
-
- # Fill KV cache with some data
- cache_data = torch.randn(
- seq_lens.sum().item(),
- 1, # num_kv_heads
- kv_lora_rank + qk_rope_head_dim,
- dtype=dtype,
- device=device,
+ attn_backend=self.backend,
)
- cache_indices = torch.arange(seq_lens.sum().item(), device=device)
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_indices, cache_data, None
- )
-
- # Initialize metadata
- backend.init_forward_metadata(forward_batch)
-
- # Create input tensors
- q_shape = (batch_size, num_q_heads, kv_lora_rank + qk_rope_head_dim)
- q = torch.randn(q_shape, dtype=dtype, device=device)
-
- # For MLA, k contains compressed KV, v is not used
- k = torch.randn(batch_size, 1, kv_lora_rank + qk_rope_head_dim, dtype=dtype, device=device)
- v = None
-
- # Run forward decode
- output = backend.forward_decode(q, k, v, layer, forward_batch)
-
- # Verify output
- expected_shape = (batch_size, num_q_heads * kv_lora_rank)
- assert output.shape == expected_shape, f"Expected shape {expected_shape}, got {output.shape}"
- assert output.dtype == dtype
- assert output.device.type == "cuda"
- assert not torch.isnan(output).any(), "Output contains NaN values"
- assert not torch.isinf(output).any(), "Output contains Inf values"
-
-
-# Simplified test for quick verification
-@pytest.mark.skipif(
- not torch.cuda.is_available() or not is_flashinfer_available(),
- reason="Test requires CUDA and flashinfer"
-)
-def test_trtllm_mla_basic():
- """Basic test to verify TRTLLM MLA backend works."""
- # Check if flashinfer has TRTLLM MLA support
- try:
- import flashinfer
- if not hasattr(flashinfer.decode, 'trtllm_batch_decode_with_kv_cache_mla'):
- pytest.skip("flashinfer version does not have TRTLLM MLA support")
- except ImportError:
- pytest.skip("flashinfer not available")
-
- test = TestTRTLLMMLABackend()
- test.test_trtllm_decode_mla(batch_size=32, page_size=32, seq_len=512)
- print("TRTLLM MLA basic test passed!")
+ fb.req_to_token_pool = self.model_runner.req_to_token_pool
+ fb.token_to_kv_pool = self.model_runner.token_to_kv_pool
+ return fb
+
+    # -- actual tests ----------------------------------------------------
+ def test_forward_decode(self):
+ """Smoke test decode across several page sizes."""
+ for ps in self.page_sizes:
+ self._init(ps)
+
+ # Random seq lens (ensure one matches max)
+ seq_lens = torch.randint(1, self.seq_len, (self.batch_size,), device=self.device)
+ seq_lens[0] = self.seq_len
+
+ forward_batch = self._create_forward_batch(seq_lens)
+ self.backend.init_forward_metadata(forward_batch)
+
+ q, k, v = self._alloc_qkv()
+ layer = RadixAttention(
+ num_heads=128,
+ head_dim=512 + 64,
+ scaling=self.model_runner.model_config.scaling,
+ num_kv_heads=1,
+ layer_id=0,
+ v_head_dim=512,
+ prefix="attn_mqa",
+ )
+ out = self.backend.forward_decode(q, k, v, layer, forward_batch)
+
+ self.assertEqual(out.shape, (self.batch_size, 128 * 512))
+ self.assertEqual(out.dtype, self.dtype)
+ self.assertEqual(out.device.type, "cuda")
+ self.assertFalse(torch.isnan(out).any())
+ self.assertFalse(torch.isinf(out).any())
if __name__ == "__main__":
- test_trtllm_mla_basic()
\ No newline at end of file
+ unittest.main()
\ No newline at end of file
From 15f3b8079d5d75ac5e7d07c2ceb51c0035453169 Mon Sep 17 00:00:00 2001
From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Date: Sun, 13 Jul 2025 18:42:11 -0700
Subject: [PATCH 03/27] diff output vs flashinfer mla
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
---
.../layers/attention/trtllm_mla_backend.py | 265 ++++++++++++++----
.../test/attention/test_trtllm_mla_backend.py | 10 +-
.../test_trtllm_vs_flashinfer_mla.py | 218 ++++++++++++++
3 files changed, 431 insertions(+), 62 deletions(-)
create mode 100644 python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 9fae41785c64..14d6bad3e271 100644
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -9,6 +9,8 @@
import torch
import triton
+import math # Needed for scale correction
+import os
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
@@ -16,17 +18,14 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
+if is_flashinfer_available():
+ import flashinfer
+
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
-if is_flashinfer_available():
- import flashinfer
-
-
-# TRTLLM MLA supports variable page sizes
-
@dataclass
class TRTLLMMLADecodeMetadata:
@@ -70,41 +69,166 @@ def __init__(
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
self.page_size = model_runner.page_size # Use page size from model runner
- # Validate dimensions for TRTLLM MLA (based on test requirements)
- if self.qk_nope_head_dim != 128:
- raise ValueError(f"TRTLLM MLA requires qk_nope_head_dim=128, got {self.qk_nope_head_dim}")
- if self.kv_lora_rank != 512:
- raise ValueError(f"TRTLLM MLA requires kv_lora_rank=512, got {self.kv_lora_rank}")
- if self.qk_rope_head_dim != 64:
- raise ValueError(f"TRTLLM MLA requires qk_rope_head_dim=64, got {self.qk_rope_head_dim}")
-
# Allocate larger workspace for TRTLLM (128MB as in the test)
self.workspace_size = 128 * 1024 * 1024
self.workspace_buffer = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
)
+
+ # CUDA graph metadata storage
+ self.decode_cuda_graph_metadata = {}
+ self.cuda_graph_kv_indices = None
+
+ def _calc_padded_blocks(self, max_seq_len: int) -> int:
+ """Return number of blocks padded so that it satisfies TRTLLM constraint."""
+ blocks = triton.cdiv(max_seq_len, self.page_size)
+ min_blocks = 128 // self.page_size # kernel requirement
+ if blocks % min_blocks != 0:
+ blocks = triton.cdiv(blocks, min_blocks) * min_blocks
+ return blocks
+
+ def init_cuda_graph_state(
+ self,
+ max_bs: int,
+ max_num_tokens: int,
+ kv_indices_buf: Optional[torch.Tensor] = None,
+ ):
+ """Initialize CUDA graph state for TRTLLM MLA."""
+ # Calculate padded block size that satisfies TRTLLM constraint
+ max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+
+ self.cuda_graph_kv_indices = torch.full(
+ (max_bs, max_blocks_per_seq),
+ -1,
+ dtype=torch.int32,
+ device=self.device,
+ )
+ self.cuda_graph_workspace = torch.empty(
+ self.workspace_size, dtype=torch.int8, device=self.device
+ )
+
+ def init_forward_metadata_capture_cuda_graph(
+ self,
+ bs: int,
+ num_tokens: int,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ forward_mode: ForwardMode,
+ spec_info: Optional[SpecInfo],
+ ):
+ """Initialize metadata for CUDA graph capture."""
+ if forward_mode.is_decode_or_idle():
+ if spec_info is None:
+ max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
+
+ block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+ create_flashmla_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices,
+ seq_lens,
+ None,
+ block_kv_indices,
+ self.req_to_token.stride(0),
+ max_seqlen_pad,
+ self.page_size,
+ )
+ metadata = TRTLLMMLADecodeMetadata(
+ self.cuda_graph_workspace,
+ block_kv_indices,
+ )
+ self.decode_cuda_graph_metadata[bs] = metadata
+ self.forward_metadata = metadata
+ else:
+ super().init_forward_metadata_capture_cuda_graph(
+ bs,
+ num_tokens,
+ req_pool_indices,
+ seq_lens,
+ encoder_lens,
+ forward_mode,
+ spec_info,
+ )
+ else:
+ super().init_forward_metadata_capture_cuda_graph(
+ bs,
+ num_tokens,
+ req_pool_indices,
+ seq_lens,
+ encoder_lens,
+ forward_mode,
+ spec_info,
+ )
+
+ def init_forward_metadata_replay_cuda_graph(
+ self,
+ bs: int,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ encoder_lens: Optional[torch.Tensor],
+ forward_mode: ForwardMode,
+ spec_info: Optional[SpecInfo],
+ seq_lens_cpu: Optional[torch.Tensor],
+ ):
+ """Replay CUDA graph with new inputs."""
+ if forward_mode.is_decode_or_idle():
+ if spec_info is None:
+ # Reuse cached metadata
+ metadata = self.decode_cuda_graph_metadata[bs]
+
+ # Update block indices for new sequences
+ create_flashmla_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices[:bs],
+ seq_lens[:bs],
+ None,
+ metadata.block_kv_indices,
+ self.req_to_token.stride(0),
+ metadata.block_kv_indices.shape[1],
+ self.page_size,
+ )
+
+ self.forward_metadata = metadata
+ else:
+ # Speculative decoding: use parent class implementation
+ super().init_forward_metadata_replay_cuda_graph(
+ bs, req_pool_indices, seq_lens, seq_lens_sum,
+ encoder_lens, forward_mode, spec_info, seq_lens_cpu
+ )
+ else:
+ # Prefill: use parent class implementation
+ super().init_forward_metadata_replay_cuda_graph(
+ bs, req_pool_indices, seq_lens, seq_lens_sum,
+ encoder_lens, forward_mode, spec_info, seq_lens_cpu
+ )
+
+ def get_cuda_graph_seq_len_fill_value(self):
+ """Get the fill value for sequence lengths in CUDA graph."""
+ return 1
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
bs = forward_batch.batch_size
spec_info = forward_batch.spec_info
-
+
if forward_batch.forward_mode.is_decode_or_idle():
if spec_info is None:
- # Calculate max sequence length padded to page boundary
- max_seqlen_pad = triton.cdiv(
- forward_batch.seq_lens_cpu.max().item(), self.page_size
- )
-
- # Create block indices
+ # seq_lens_cpu may be None when cuda-graphs are disabled
+ if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+ max_seq = forward_batch.seq_lens_cpu.max().item()
+ else:
+ max_seq = forward_batch.seq_lens.max().item()
+
+ max_seqlen_pad = self._calc_padded_blocks(max_seq)
+
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=forward_batch.seq_lens.device,
)
-
- # Fill block indices using the existing triton kernel
+
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
@@ -115,17 +239,16 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
max_seqlen_pad,
self.page_size,
)
-
- forward_batch.decode_trtllm_mla_metadata = TRTLLMMLADecodeMetadata(
+
+ self.forward_metadata = TRTLLMMLADecodeMetadata(
self.workspace_buffer,
block_kv_indices,
)
- self.forward_metadata = forward_batch.decode_trtllm_mla_metadata
+ # Expose to the ForwardBatch so that other components can access it
+ forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
else:
- # Speculative decoding: use parent class implementation
super().init_forward_metadata(forward_batch)
else:
- # Prefill: use parent class implementation
super().init_forward_metadata(forward_batch)
def forward_decode(
@@ -136,66 +259,94 @@ def forward_decode(
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
- # For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):
- """Run forward for decode using TRTLLM kernel."""
+ """Run forward for decode using TRTLLM MLA kernel."""
cache_loc = forward_batch.out_cache_loc
+ # Save KV cache if requested
if k is not None and save_kv_cache:
if k_rope is not None:
- # MLA style KV cache storage
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
- layer,
- cache_loc,
- k,
- k_rope,
+ layer, cache_loc, k, k_rope
)
else:
- # Standard KV cache storage path. Skip if value tensor is absent (e.g., MLA decode tests).
if v is not None:
forward_batch.token_to_kv_pool.set_kv_buffer(
- layer,
- cache_loc,
- k,
- v,
+ layer, cache_loc, k, v
)
- # Prepare query tensor - concatenate q_nope and q_rope
+ # Build query tensor
if q_rope is not None:
- # q and q_rope are separate
- q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
- q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.qk_rope_head_dim)
+ q_nope = q.view(-1, layer.tp_q_head_num, self.qk_nope_head_dim)
+ q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim)
query = torch.cat([q_nope, q_rope], dim=-1)
else:
- # q already contains both nope and rope parts
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
- # Get KV cache
+ # Scale factor for TRT-LLM MLA kernel.
+ # The kernel appears to apply scale / sqrt(head_dim_qk) to the logits — TODO confirm against flashinfer docs —
+ # where head_dim_qk = 576 (kv_lora_rank + qk_rope_head_dim).
+ # To get the same result as FlashInfer (which uses layer.scaling = 1/sqrt(192)),
+ # we need: scale / sqrt(576) = 1 / sqrt(192)
+ # Therefore: scale = sqrt(576) / sqrt(192) = sqrt(3)
+ scale = math.sqrt(self.kv_lora_rank + self.qk_rope_head_dim) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
+
+ # KV cache tensor: reshape to (num_pages, page_size, dim)
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+ # Build KV cache slices expected by TRT-LLM: slice 0 → CKV+K, slice 1 → KPE.
+ pages = k_cache.view(-1, self.page_size, self.kv_cache_dim) # (P, blk, 576)
+
+ # According to the FlashInfer test, both slices should contain the full 576-dim tensor.
+ # Use torch.stack so each slice is an independent contiguous view.
+ # NOTE: this duplicates the storage but matches the reference behaviour.
+ kv_cache = torch.stack([pages, pages], dim=1) # (P, 2, blk, 576)
- # Reshape KV cache to 4-D (num_kv_heads, num_blocks, page_size, kv_dim)
- kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(0) # 1 KV head
+ # Metadata (prefer attribute on forward_batch for compatibility)
+ metadata = getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+ if metadata is None:
+ metadata = self.forward_metadata
+
+ # ---------- Debug output (enable with env var) ----------
+ if os.getenv("SGLANG_DEBUG_TRTLLM_MLA", "0") == "1":
+ print(
+ f"[TRTLLM-MLA] Debug shapes before kernel call:\n"
+ f" query: {query.shape} dtype={query.dtype}\n"
+ f" kv_cache: {kv_cache.shape} dtype={kv_cache.dtype}\n"
+ f" block_tables: {metadata.block_kv_indices.shape} dtype={metadata.block_kv_indices.dtype}\n"
+ f" seq_lens: {forward_batch.seq_lens.shape} dtype={forward_batch.seq_lens.dtype}\n"
+ f" page_size: {self.page_size}\n"
+ f" max_seq_len: {metadata.block_kv_indices.shape[1] * self.page_size}\n"
+ f" scale: {scale}\n"
+ f" qk_nope_head_dim: {self.qk_nope_head_dim}\n"
+ f" kv_lora_rank: {self.kv_lora_rank}\n"
+ f" qk_rope_head_dim: {self.qk_rope_head_dim}"
+ )
+ # --------------------------------------------------------
- # Call TRTLLM MLA decode kernel
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=query,
kv_cache=kv_cache,
- workspace_buffer=forward_batch.decode_trtllm_mla_metadata.workspace,
+ workspace_buffer=metadata.workspace,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
- block_tables=forward_batch.decode_trtllm_mla_metadata.block_kv_indices,
+ block_tables=metadata.block_kv_indices,
seq_lens=forward_batch.seq_lens.to(torch.int32),
block_size=self.page_size,
- max_seq_len=forward_batch.seq_lens.max().item(),
- scale=layer.scaling,
- out=None,
- bmm1_scale=1.0, # Only needed for FP8
- bmm2_scale=1.0, # Only needed for FP8
+ # Avoid .item() (host sync) during CUDA graph capture.
+ # max_seq_len equals padded_blocks * page_size.
+ max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
+ scale=scale,
+ bmm1_scale=1.0,
+ bmm2_scale=1.0,
)
- output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+ # TRTLLM kernel may return both V and ROPE dims (kv_lora_rank + qk_rope_head_dim).
+ # We only need the value projection part (v_head_dim).
+ raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
+
+ output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
if output.shape[0] > forward_batch.batch_size:
output = output[: forward_batch.batch_size]
return output
\ No newline at end of file
diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py
index 7f4f91d0e588..8d1b2e39c500 100644
--- a/python/sglang/test/attention/test_trtllm_mla_backend.py
+++ b/python/sglang/test/attention/test_trtllm_mla_backend.py
@@ -54,7 +54,7 @@ def __init__(self, page_size: int):
"req_to_token": torch.zeros(max_bs, max_ctx, dtype=torch.int32, device=self.device),
},
)
-
+
# KV-token pool
self.token_to_kv_pool = MLATokenToKVPool(
size=max_bs * max_ctx,
@@ -75,10 +75,10 @@ class TestTRTLLMMLABackend(CustomTestCase):
def setUp(self):
self.batch_size = 16
self.seq_len = 512
- self.page_sizes = [16, 32, 64]
+ self.page_sizes = [32, 64]
self.device = "cuda"
self.dtype = torch.bfloat16
-
+
# ‑- helpers ---------------------------------------------------------
def _init(self, page_size: int):
self.model_runner = MockModelRunner(page_size)
@@ -122,7 +122,7 @@ def test_forward_decode(self):
forward_batch = self._create_forward_batch(seq_lens)
self.backend.init_forward_metadata(forward_batch)
-
+
q, k, v = self._alloc_qkv()
layer = RadixAttention(
num_heads=128,
@@ -134,7 +134,7 @@ def test_forward_decode(self):
prefix="attn_mqa",
)
out = self.backend.forward_decode(q, k, v, layer, forward_batch)
-
+
self.assertEqual(out.shape, (self.batch_size, 128 * 512))
self.assertEqual(out.dtype, self.dtype)
self.assertEqual(out.device.type, "cuda")
diff --git a/python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py b/python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py
new file mode 100644
index 000000000000..db0affb7c048
--- /dev/null
+++ b/python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py
@@ -0,0 +1,218 @@
+import unittest
+import torch
+import math
+
+from sglang.srt.layers import dp_attention as _dp_attn
+
+# Patch DP-attention globals before importing backends
+_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.test.test_utils import CustomTestCase
+from sglang.srt.utils import is_flashinfer_available
+
+
+class MockModelRunner:
+ """Minimal fake ModelRunner for comparing both MLA backends."""
+
+ def __init__(self, page_size: int):
+ self.device = "cuda"
+ self.dtype = torch.bfloat16
+ self.kv_cache_dtype = torch.bfloat16
+ self.page_size = page_size
+
+ # Model-config stub with MLA attributes
+ self.model_config = type(
+ "ModelConfig",
+ (),
+ {
+ "context_len": 2048,
+ "attention_arch": AttentionArch.MLA,
+ "num_attention_heads": 128,
+ "kv_lora_rank": 512,
+ "qk_nope_head_dim": 128,
+ "qk_rope_head_dim": 64,
+ "v_head_dim": 512,
+ "scaling": 1.0 / ((128 + 64) ** 0.5),
+ "get_num_kv_heads": staticmethod(lambda _: 1),
+ },
+ )
+
+ # Req-to-token pool
+ max_bs = 64
+ max_ctx = self.model_config.context_len
+ self.req_to_token_pool = type(
+ "TokenPool",
+ (),
+ {
+ "size": max_bs,
+ "req_to_token": torch.zeros(max_bs, max_ctx, dtype=torch.int32, device=self.device),
+ },
+ )
+
+ # KV-token pool (MLA)
+ self.token_to_kv_pool = MLATokenToKVPool(
+ size=max_bs * max_ctx,
+ page_size=page_size,
+ dtype=self.kv_cache_dtype,
+ kv_lora_rank=512,
+ qk_rope_head_dim=64,
+ layer_num=1,
+ device=self.device,
+ enable_memory_saver=False,
+ )
+
+
+@unittest.skipIf(not torch.cuda.is_available() or not is_flashinfer_available(), "CUDA + flashinfer required")
+class TestTRTLLMvsFlashInferMLA(CustomTestCase):
+ """Test numerical equivalence between TRTLLM and FlashInfer MLA backends."""
+
+ def setUp(self):
+ self.batch_size = 8
+ self.seq_len = 256
+ self.page_size = 32
+ self.device = "cuda"
+ self.dtype = torch.bfloat16
+
+ # Create model runner
+ self.model_runner = MockModelRunner(self.page_size)
+
+ # Initialize both backends
+ self.trtllm_backend = TRTLLMMLABackend(self.model_runner)
+ self.flashinfer_backend = FlashInferMLAAttnBackend(self.model_runner)
+
+ # Create RadixAttention layer for testing
+ self.layer = RadixAttention(
+ num_heads=128,
+ head_dim=512 + 64, # kv_lora_rank + qk_rope_head_dim
+ scaling=self.model_runner.model_config.scaling,
+ num_kv_heads=1,
+ layer_id=0,
+ v_head_dim=512,
+ prefix="attn_mqa",
+ )
+
+ def _create_qkv_tensors(self):
+ """Create Q, K, V tensors for testing."""
+ head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim
+ q = torch.randn((self.batch_size, 128, head_dim), dtype=self.dtype, device=self.device)
+ k = torch.randn((self.batch_size, 1, head_dim), dtype=self.dtype, device=self.device)
+ # For FlashInfer MLA, if k is provided v must not be None.
+ v = torch.randn((self.batch_size, 1, 512), dtype=self.dtype, device=self.device)
+ return q, k, v
+
+ def _create_forward_batch(self, backend):
+ """Create a forward batch for the given backend."""
+ # Random sequence lengths
+ seq_lens = torch.randint(1, self.seq_len, (self.batch_size,), device=self.device)
+ seq_lens[0] = self.seq_len # Ensure at least one max length
+
+ fb = ForwardBatch(
+ batch_size=self.batch_size,
+ input_ids=torch.randint(0, 100, (self.batch_size, 1), device=self.device),
+ out_cache_loc=torch.arange(self.batch_size, device=self.device),
+ seq_lens_sum=int(seq_lens.sum().item()),
+ forward_mode=ForwardMode.DECODE,
+ req_pool_indices=torch.arange(self.batch_size, device=self.device),
+ seq_lens=seq_lens,
+ seq_lens_cpu=seq_lens.cpu(),
+ attn_backend=backend,
+ )
+ fb.req_to_token_pool = self.model_runner.req_to_token_pool
+ fb.token_to_kv_pool = self.model_runner.token_to_kv_pool
+ return fb
+
+ def test_decode_output_match(self):
+ """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
+ # Create identical forward batches for both backends
+ fb_trtllm = self._create_forward_batch(self.trtllm_backend)
+ fb_flashinfer = self._create_forward_batch(self.flashinfer_backend)
+
+ # Initialize metadata for both backends
+ self.trtllm_backend.init_forward_metadata(fb_trtllm)
+ self.flashinfer_backend.init_forward_metadata(fb_flashinfer)
+
+ # Create Q, K, V tensors
+ q, k, v = self._create_qkv_tensors()
+
+ # Run forward decode on both backends
+ out_trtllm = self.trtllm_backend.forward_decode(q.clone(), k.clone(), v, self.layer, fb_trtllm)
+ out_flashinfer = self.flashinfer_backend.forward_decode(q.clone(), k.clone(), v.clone(), self.layer, fb_flashinfer)
+
+ # Debug: print scale info
+ print(f"\n[DEBUG] Scale analysis:")
+ print(f" layer.scaling = {self.layer.scaling}")
+ print(f" qk_nope_head_dim = {self.model_runner.model_config.qk_nope_head_dim}")
+ print(f" qk_rope_head_dim = {self.model_runner.model_config.qk_rope_head_dim}")
+ print(f" kv_lora_rank = {self.model_runner.model_config.kv_lora_rank}")
+ print(f" Expected TRT scale factor = {math.sqrt(128 + 64) / math.sqrt(512 + 64)} = {math.sqrt(192) / math.sqrt(576)}")
+ print(f" Output shapes: TRTLLM {out_trtllm.shape}, FlashInfer {out_flashinfer.shape}")
+ print(f" Output means: TRTLLM {out_trtllm.mean().item():.6f}, FlashInfer {out_flashinfer.mean().item():.6f}")
+ print(f" Output stds: TRTLLM {out_trtllm.std().item():.6f}, FlashInfer {out_flashinfer.std().item():.6f}")
+ print(f" Max diff = {(out_trtllm - out_flashinfer).abs().max().item()}")
+ print(f" Ratio of means = {out_trtllm.mean().item() / out_flashinfer.mean().item() if out_flashinfer.mean().item() != 0 else 'inf'}")
+
+ # Additional debug
+ print(f"\n[DEBUG] Scale computation:")
+ print(f" layer.scaling = 1/sqrt(192) = {1/math.sqrt(192)}")
+ print(f" TRT scale passed = layer.scaling * sqrt(192)/sqrt(576) = {self.layer.scaling * math.sqrt(192) / math.sqrt(576)}")
+ print(f" TRT kernel will compute: 1 / (sqrt(576) * scale) = {1 / (math.sqrt(576) * self.layer.scaling * math.sqrt(192) / math.sqrt(576))}")
+ print(f" Which equals: 1 / (layer.scaling * sqrt(192)) = {1 / (self.layer.scaling * math.sqrt(192))}")
+ print(f" But FlashInfer uses: layer.scaling = {self.layer.scaling}")
+ print(f" Ratio: {(1 / (self.layer.scaling * math.sqrt(192))) / self.layer.scaling} = sqrt(192) = {math.sqrt(192)}")
+
+ # Check output shapes match
+ self.assertEqual(out_trtllm.shape, out_flashinfer.shape,
+ f"Output shapes differ: TRTLLM {out_trtllm.shape} vs FlashInfer {out_flashinfer.shape}")
+
+ # Check output dtypes match
+ self.assertEqual(out_trtllm.dtype, out_flashinfer.dtype)
+
+ # Check numerical equivalence with tolerance
+ # Note: Using higher tolerance due to potential numerical differences in implementations
+ self.assertTrue(
+ torch.allclose(out_trtllm, out_flashinfer, atol=1e-2, rtol=1e-2),
+ f"TRTLLM and FlashInfer outputs differ beyond tolerance. "
+ f"Max diff: {(out_trtllm - out_flashinfer).abs().max().item()}"
+ )
+
+ # Additional checks
+ self.assertFalse(torch.isnan(out_trtllm).any(), "TRTLLM output contains NaN")
+ self.assertFalse(torch.isnan(out_flashinfer).any(), "FlashInfer output contains NaN")
+ self.assertFalse(torch.isinf(out_trtllm).any(), "TRTLLM output contains Inf")
+ self.assertFalse(torch.isinf(out_flashinfer).any(), "FlashInfer output contains Inf")
+
+ def test_decode_with_different_page_sizes(self):
+ """Test output consistency across different page sizes."""
+ page_sizes = [32, 64]
+ outputs = []
+
+ for ps in page_sizes:
+ # Reinitialize with new page size
+ self.model_runner = MockModelRunner(ps)
+ self.trtllm_backend = TRTLLMMLABackend(self.model_runner)
+
+ # Create batch and run decode
+ fb = self._create_forward_batch(self.trtllm_backend)
+ self.trtllm_backend.init_forward_metadata(fb)
+
+ q, k, v = self._create_qkv_tensors()
+ out = self.trtllm_backend.forward_decode(q, k, v, self.layer, fb)
+ outputs.append(out)
+
+ # Check that outputs are consistent across page sizes
+ # Note: Different page sizes might lead to slightly different numerical results
+ for i in range(1, len(outputs)):
+ self.assertTrue(
+ torch.allclose(outputs[0], outputs[i], atol=5e-2, rtol=5e-2),
+ f"Output with page_size={page_sizes[0]} differs from page_size={page_sizes[i]}"
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
From cdd315e6a37ab7cc6a2e5823775168209a1b5e52 Mon Sep 17 00:00:00 2001
From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Date: Mon, 14 Jul 2025 18:06:43 -0700
Subject: [PATCH 04/27] trtllm mla kernel working
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
---
...a_backend.py => trtllm_gen_mla_backend.py} | 204 ++++++---
.../sglang/srt/model_executor/model_runner.py | 6 +-
.../attention/test_trtllm_gen_mla_backend.py | 386 ++++++++++++++++++
.../test/attention/test_trtllm_mla_backend.py | 146 -------
.../test_trtllm_vs_flashinfer_mla.py | 218 ----------
5 files changed, 547 insertions(+), 413 deletions(-)
rename python/sglang/srt/layers/attention/{trtllm_mla_backend.py => trtllm_gen_mla_backend.py} (56%)
create mode 100644 python/sglang/test/attention/test_trtllm_gen_mla_backend.py
delete mode 100644 python/sglang/test/attention/test_trtllm_mla_backend.py
delete mode 100644 python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
similarity index 56%
rename from python/sglang/srt/layers/attention/trtllm_mla_backend.py
rename to python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
index 14d6bad3e271..cdb817984894 100644
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
@@ -1,7 +1,7 @@
from __future__ import annotations
"""
-Support attention backend for TRTLLM MLA kernels from flashinfer.
+Support attention backend for TRTLLM-Gen MLA kernels from flashinfer.
"""
from dataclasses import dataclass
@@ -28,12 +28,12 @@
@dataclass
-class TRTLLMMLADecodeMetadata:
+class TRTLLMGENMLADecodeMetadata:
workspace: Optional[torch.Tensor] = None
block_kv_indices: Optional[torch.Tensor] = None
-class TRTLLMMLABackend(FlashInferMLAAttnBackend):
+class TRTLLMGENMLABackend(FlashInferMLAAttnBackend):
"""TRTLLM MLA attention kernels from flashinfer."""
def __init__(
@@ -58,7 +58,7 @@ def __init__(
self.num_local_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
- self.forward_metadata: Union[TRTLLMMLADecodeMetadata] = None
+ self.forward_metadata: Union[TRTLLMGENMLADecodeMetadata] = None
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
@@ -82,7 +82,13 @@ def __init__(
def _calc_padded_blocks(self, max_seq_len: int) -> int:
"""Return number of blocks padded so that it satisfies TRTLLM constraint."""
blocks = triton.cdiv(max_seq_len, self.page_size)
- min_blocks = 128 // self.page_size # kernel requirement
+
+ # The Triton helper that builds `block_kv_indices` emits `NUM_PAGE_PER_BLOCK`
+ # (= 64) indices at a time, independent of `page_size`. To avoid it writing
+ # past the end of the row we **must** make every row at least 64 long.
+ # (Side-effect: max_seq_len is effectively rounded up to 64 × page_size = 2 K
+ # tokens for page_size 32, which is still below the 2048-ctx used here.)
+ min_blocks = 64
if blocks % min_blocks != 0:
blocks = triton.cdiv(blocks, min_blocks) * min_blocks
return blocks
@@ -133,7 +139,7 @@ def init_forward_metadata_capture_cuda_graph(
max_seqlen_pad,
self.page_size,
)
- metadata = TRTLLMMLADecodeMetadata(
+ metadata = TRTLLMGENMLADecodeMetadata(
self.cuda_graph_workspace,
block_kv_indices,
)
@@ -172,6 +178,9 @@ def init_forward_metadata_replay_cuda_graph(
seq_lens_cpu: Optional[torch.Tensor],
):
"""Replay CUDA graph with new inputs."""
+ if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
+ print(f"[TRTLLM-MLA] init_forward_metadata_replay_cuda_graph: bs={bs}, forward_mode={forward_mode}, spec_info={spec_info is not None}")
+
if forward_mode.is_decode_or_idle():
if spec_info is None:
# Reuse cached metadata
@@ -188,8 +197,11 @@ def init_forward_metadata_replay_cuda_graph(
metadata.block_kv_indices.shape[1],
self.page_size,
)
-
+ # Pad invalid blocks to keep TRTLLM happy
+ # self._pad_invalid_blocks(metadata.block_kv_indices)
+
self.forward_metadata = metadata
+
else:
# Speculative decoding: use parent class implementation
super().init_forward_metadata_replay_cuda_graph(
@@ -212,6 +224,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
spec_info = forward_batch.spec_info
+ if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
+ print(f"[TRTLLM-MLA] init_forward_metadata: bs={bs}, forward_mode={forward_batch.forward_mode}, spec_info={spec_info is not None}")
if forward_batch.forward_mode.is_decode_or_idle():
if spec_info is None:
# seq_lens_cpu may be None when cuda-graphs are disabled
@@ -240,7 +254,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
self.page_size,
)
- self.forward_metadata = TRTLLMMLADecodeMetadata(
+ # Ensure padded blocks are valid to avoid TRTLLM skipping shorter seqs
+ # self._pad_invalid_blocks(block_kv_indices)
+
+ self.forward_metadata = TRTLLMGENMLADecodeMetadata(
self.workspace_buffer,
block_kv_indices,
)
@@ -279,52 +296,97 @@ def forward_decode(
# Build query tensor
if q_rope is not None:
- q_nope = q.view(-1, layer.tp_q_head_num, self.qk_nope_head_dim)
- q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim)
- query = torch.cat([q_nope, q_rope], dim=-1)
+ # q contains the NOPE part (v_head_dim)
+ q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+ q_rope = q_rope.view(
+ -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+ )
else:
- query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+ reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+ q_nope = reshaped_q[:, :, : layer.v_head_dim]
+ q_rope = reshaped_q[:, :, layer.v_head_dim :]
- # Scale factor for TRT-LLM MLA kernel.
- # The kernel computes softmax with scale: 1 / (sqrt(head_dim_qk) * scale)
- # where head_dim_qk = 576 (kv_lora_rank + qk_rope_head_dim).
- # To get the same result as FlashInfer (which uses layer.scaling = 1/sqrt(192)),
- # we need: 1 / (sqrt(576) * scale) = 1 / sqrt(192)
- # Therefore: scale = sqrt(576) / sqrt(192) = sqrt(3)
- scale = math.sqrt(self.kv_lora_rank + self.qk_rope_head_dim) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
+ # Concatenate to build the full query as expected by TRTLLM kernel
+ query = torch.cat([q_nope, q_rope], dim=-1)
+
+
+ if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
+ print(f"[TRTLLM-MLA]: model_runner.model_config.scaling.scaling is {self.model_runner.model_config.scaling}")
+
+ # Get the model scaling factor
+ # TRTLLM kernel applies the 1/sqrt(192) factor internally, so we
+ # should pass `1.0` here (see Flash-Infer equivalence tests).
+ sm_scale = 1
+ # (scale * ((512 + 64) ** 0.5)) / ((128 + 64) ** 0.5)
+ # ( sqrt(3)) / sqrt(192)
# KV cache tensor: reshape to (num_pages, page_size, dim)
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
# Build KV cache slices expected by TRT-LLM: slice 0 → CKV+K, slice 1 → KPE.
pages = k_cache.view(-1, self.page_size, self.kv_cache_dim) # (P, blk, 576)
- # According to the FlashInfer test, both slices should contain the full 576-dim tensor.
- # Use torch.stack so each slice is an independent contiguous view.
- # NOTE: this duplicates the storage but matches the reference behaviour.
- kv_cache = torch.stack([pages, pages], dim=1) # (P, 2, blk, 576)
-
- # Metadata (prefer attribute on forward_batch for compatibility)
+ # Retrieve metadata so we can check block tables.
metadata = getattr(forward_batch, "decode_trtllm_mla_metadata", None)
if metadata is None:
metadata = self.forward_metadata
- # ---------- Debug output (enable with env var) ----------
- if os.getenv("SGLANG_DEBUG_TRTLLM_MLA", "0") == "1":
+ # ---------------------------------------------------------------------
+ # According to the FlashInfer test, both slices should contain the full 576-dim tensor.
+ # Use torch.stack so each slice is an independent contiguous view **after** we have
+ # patched the pages in-place.
+ kv_cache = torch.stack([pages, pages], dim=1) # (P, 2, blk, 576)
+
+ # Metadata already obtained above; no change needed.
+
+ # ------------- DEBUG KV CACHE CONSTRUCTION -------------
+ if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
+ print(
+ f"[TRTLLM-MLA] k_cache_raw: {k_cache.shape} {k_cache.dtype}\n"
+ f"[TRTLLM-MLA] pages : {pages.shape} {pages.dtype}\n"
+ f"[TRTLLM-MLA] kv_cache : {kv_cache.shape} {kv_cache.dtype}\n"
+ f"[TRTLLM-MLA] block_kv_indices: {metadata.block_kv_indices.shape} {metadata.block_kv_indices.dtype}\n"
+ f"[TRTLLM-MLA] workspace: {metadata.workspace.shape if metadata.workspace is not None else 'None'}\n"
+ f"[TRTLLM-MLA] max_seq_len: {metadata.block_kv_indices.shape[1] * self.page_size}\n"
+ f"[TRTLLM-MLA] k_cache[0,0,:3]: {k_cache[0,0,:3] if k_cache.numel() > 0 else 'empty'}\n"
+ f"[TRTLLM-MLA] pages[0,0,:3]: {pages[0,0,:3] if pages.numel() > 0 else 'empty'}\n"
+ f"[TRTLLM-MLA] kv_cache[0,0,0,:3]: {kv_cache[0,0,0,:3] if kv_cache.numel() > 0 else 'empty'}\n"
+ f"[TRTLLM-MLA] kv_cache[0,1,0,:3]: {kv_cache[0,1,0,:3] if kv_cache.numel() > 0 else 'empty'}\n"
+ f"[TRTLLM-MLA] block_kv_indices[0,:5]: {metadata.block_kv_indices[0,:5] if metadata.block_kv_indices.numel() > 0 else 'empty'}"
+ )
+
+ # ------------- DEBUG (align with FlashInfer-MLA) -------------
+ if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
+ q_nope_str = f"{q_nope.shape} {q_rope.dtype}" if q_nope is not None else "None"
+ q_rope_str = f"{q_rope.shape} {q_rope.dtype}" if q_rope is not None else "None"
+ q_nope_sample = f"{q_nope[0,:3,:3] if q_rope is not None and q_nope.numel() > 0 else 'empty'}"
+ q_rope_sample = f"{q_rope[0,:3,:3] if q_rope is not None and q_rope.numel() > 0 else 'empty'}"
+
print(
- f"[TRTLLM-MLA] Debug shapes before kernel call:\n"
- f" query: {query.shape} dtype={query.dtype}\n"
- f" kv_cache: {kv_cache.shape} dtype={kv_cache.dtype}\n"
- f" block_tables: {metadata.block_kv_indices.shape} dtype={metadata.block_kv_indices.dtype}\n"
- f" seq_lens: {forward_batch.seq_lens.shape} dtype={forward_batch.seq_lens.dtype}\n"
- f" page_size: {self.page_size}\n"
- f" max_seq_len: {metadata.block_kv_indices.shape[1] * self.page_size}\n"
- f" scale: {scale}\n"
- f" qk_nope_head_dim: {self.qk_nope_head_dim}\n"
- f" kv_lora_rank: {self.kv_lora_rank}\n"
- f" qk_rope_head_dim: {self.qk_rope_head_dim}"
+ f"[TRTLLM-MLA] q_nope : {q_nope_str}\n"
+ f"[TRTLLM-MLA] q_rope : {q_rope_str}\n"
+ f"[TRTLLM-MLA] query : {query.shape} {query.dtype}\n"
+ f"[TRTLLM-MLA] kv_cache: {kv_cache.shape} {kv_cache.dtype}\n"
+ f"[TRTLLM-MLA] scale : {sm_scale}\n"
+ f"[TRTLLM-MLA] cache_loc: {cache_loc.shape} {cache_loc.dtype}\n"
+ f"[TRTLLM-MLA] seq_lens: {forward_batch.seq_lens.shape} {forward_batch.seq_lens.dtype}\n"
+ f"[TRTLLM-MLA] batch_size: {forward_batch.batch_size}\n"
+ f"[TRTLLM-MLA] page_size: {self.page_size}\n"
+ f"[TRTLLM-MLA] qk_nope_head_dim: {self.qk_nope_head_dim}\n"
+ f"[TRTLLM-MLA] qk_rope_head_dim: {self.qk_rope_head_dim}\n"
+ f"[TRTLLM-MLA] kv_lora_rank: {self.kv_lora_rank}\n"
+ f"[TRTLLM-MLA] v_head_dim: {layer.v_head_dim}\n"
+ f"[TRTLLM-MLA] kv_cache_dim: {self.kv_cache_dim}\n"
+ f"[TRTLLM-MLA] tp_q_head_num: {layer.tp_q_head_num}\n"
+ f"[TRTLLM-MLA] tp_k_head_num: {layer.tp_k_head_num}\n"
+ f"[TRTLLM-MLA] layer_id: {layer.layer_id}\n"
+ f"[TRTLLM-MLA] q_nope[0,:3,:3]: {q_nope_sample}\n"
+ f"[TRTLLM-MLA] q_rope[0,:3,:3]: {q_rope_sample}\n"
+ f"[TRTLLM-MLA] query[0,:3,:3]: {query[0,:3,:3] if query.numel() > 0 else 'empty'}\n"
+ f"[TRTLLM-MLA] seq_lens values: {forward_batch.seq_lens[:min(5, len(forward_batch.seq_lens))]}\n"
+ f"[TRTLLM-MLA] cache_loc values: {cache_loc[:min(5, len(cache_loc))]}"
)
- # --------------------------------------------------------
+
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=query,
kv_cache=kv_cache,
@@ -335,18 +397,68 @@ def forward_decode(
block_tables=metadata.block_kv_indices,
seq_lens=forward_batch.seq_lens.to(torch.int32),
block_size=self.page_size,
- # Avoid .item() (host sync) during CUDA graph capture.
- # max_seq_len equals padded_blocks * page_size.
- max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
- scale=scale,
- bmm1_scale=1.0,
- bmm2_scale=1.0,
+ max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size), # max_seq_len equals padded_blocks * page_size.
+ q_scale=1.0,
+ k_scale=1.0,
+ v_scale=1.0,
+ sm_scale=sm_scale,
+ o_scale=1.0,
)
# TRTLLM kernel may return both V and ROPE dims (kv_lora_rank + qk_rope_head_dim).
# We only need the value projection part (v_head_dim).
raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
+ # ------------- DEBUG KERNEL OUTPUT -------------
+ if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
+ print(
+ f"[TRTLLM-MLA] raw_out : {raw_out.shape} {raw_out.dtype}\n"
+ f"[TRTLLM-MLA] raw_out_v : {raw_out_v.shape} {raw_out_v.dtype}\n"
+ f"[TRTLLM-MLA] raw_out[0,:3,:3]: {raw_out[0,:3,:3] if raw_out.numel() > 0 else 'empty'}\n"
+ f"[TRTLLM-MLA] raw_out_v[0,:3,:3]: {raw_out_v[0,:3,:3] if raw_out_v.numel() > 0 else 'empty'}\n"
+ f"[TRTLLM-MLA] raw_out stats: min={raw_out.min():.6f}, max={raw_out.max():.6f}, mean={raw_out.mean():.6f}\n"
+ f"[TRTLLM-MLA] raw_out_v stats: min={raw_out_v.min():.6f}, max={raw_out_v.max():.6f}, mean={raw_out_v.mean():.6f}"
+ )
+
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
if output.shape[0] > forward_batch.batch_size:
output = output[: forward_batch.batch_size]
- return output
\ No newline at end of file
+
+ # ------------- DEBUG FINAL OUTPUT -------------
+ if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
+ print(
+ f"[TRTLLM-MLA] output : {output.shape} {output.dtype}\n"
+ f"[TRTLLM-MLA] output[0,:10]: {output[0,:10] if output.numel() > 0 else 'empty'}\n"
+ f"[TRTLLM-MLA] output[0,10:20]: {output[0,10:20] if output.numel() > 10 else 'empty'}\n"
+ f"[TRTLLM-MLA] output[0,20:30]: {output[0,20:30] if output.numel() > 20 else 'empty'}\n"
+ f"[TRTLLM-MLA] output stats: min={output.min():.6f}, max={output.max():.6f}, mean={output.mean():.6f}\n"
+ f"[TRTLLM-MLA] output std: {output.std():.6f}\n"
+ f"[TRTLLM-MLA] output_reshaped: {output.shape[0]} -> {forward_batch.batch_size}\n"
+ f"[TRTLLM-MLA] ===== END TRTLLM-MLA DEBUG ====="
+ )
+
+ return output
+
+ @staticmethod
+ def _pad_invalid_blocks(block_kv_indices: torch.Tensor):
+ """Replace -1 paddings with the last valid page id for each row.
+
+ TRT-LLM treats a leading -1 as "empty sequence" and will skip the
+ whole sequence. For shorter sequences we therefore replicate the
+ last real page id into the padded region instead of leaving -1.
+ """
+ for row in range(block_kv_indices.size(0)):
+ row_view = block_kv_indices[row]
+ # Find last non-negative entry (every sequence has at least one)
+ valid_mask = row_view >= 0
+ if not valid_mask.any():
+ # Defensive – shouldn’t happen in decode mode
+ continue
+ last_valid = row_view[valid_mask][-1]
+ row_view[~valid_mask] = last_valid
+
+ # Extra visibility when debugging
+ if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
+ preview_rows = min(6, block_kv_indices.size(0))
+ print("[TRTLLM-MLA] block_kv_indices preview (after padding):")
+ for i in range(preview_rows):
+ print(f" seq {i}:", block_kv_indices[i].tolist())
\ No newline at end of file
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 0f53e52285e5..c14c3fbef48f 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1443,11 +1443,11 @@ def _get_attention_backend_from_str(self, backend_str: str):
return CutlassMLABackend(self)
elif self.server_args.attention_backend == "trtllm_mla":
- from sglang.srt.layers.attention.trtllm_mla_backend import (
- TRTLLMMLABackend,
+ from sglang.srt.layers.attention.trtllm_gen_mla_backend import (
+ TRTLLMGENMLABackend,
)
- return TRTLLMMLABackend(self)
+ return TRTLLMGENMLABackend(self)
elif self.server_args.attention_backend == "intel_amx":
from sglang.srt.layers.attention.intel_amx_backend import (
IntelAMXAttnBackend,
diff --git a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py
new file mode 100644
index 000000000000..acf752fd23ce
--- /dev/null
+++ b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py
@@ -0,0 +1,386 @@
+"""
+Clean test suite for TRTLLM MLA backend.
+
+This test file provides comprehensive testing for the TRTLLM MLA (Multi-Head Latent Attention) backend:
+
+1. test_basic_functionality: Basic smoke test with minimal setup
+2. test_decode_output_match: Compares TRTLLM MLA output against FlashInfer MLA reference
+ across different batch sizes and sequence lengths
+3. test_different_page_sizes: Tests consistency across different page sizes
+4. test_forward_decode_shape_sanity: Shape and sanity checks across various configurations
+
+The tests use unittest with subTest for parameterized testing, following the sglang test patterns.
+"""
+import unittest
+import torch
+import numpy as np
+from types import SimpleNamespace
+
+from sglang.srt.layers import dp_attention as _dp_attn
+
+# Patch DP-attention globals before importing backends
+_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.trtllm_gen_mla_backend import TRTLLMGENMLABackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.test.test_utils import CustomTestCase
+from sglang.srt.utils import is_flashinfer_available
+
+
+class MockModelRunner:
+ """Minimal fake ModelRunner for testing MLA backends."""
+
+ def __init__(self, page_size: int):
+ self.device = "cuda"
+ self.dtype = torch.bfloat16
+ self.kv_cache_dtype = torch.bfloat16
+ self.page_size = page_size
+
+ # Model-config stub with MLA attributes
+ self.model_config = type(
+ "ModelConfig",
+ (),
+ {
+ "context_len": 2048,
+ "attention_arch": AttentionArch.MLA,
+ "num_attention_heads": 128,
+ "kv_lora_rank": 512,
+ "qk_nope_head_dim": 128,
+ "qk_rope_head_dim": 64,
+ "v_head_dim": 512,
+ "scaling": 1.0 / ((128 + 64) ** 0.5),
+ "get_num_kv_heads": staticmethod(lambda _: 1),
+ },
+ )
+
+ # Req-to-token pool
+ max_bs = 64
+ max_ctx = self.model_config.context_len
+ req_to_token = torch.arange(max_bs * max_ctx, dtype=torch.int32, device=self.device).reshape(max_bs, max_ctx)
+ self.req_to_token_pool = type(
+ "TokenPool",
+ (),
+ {
+ "size": max_bs,
+ "req_to_token": req_to_token,
+ },
+ )
+
+ # KV-token pool (MLA)
+ self.token_to_kv_pool = MLATokenToKVPool(
+ size=max_bs * max_ctx,
+ page_size=page_size,
+ dtype=self.kv_cache_dtype,
+ kv_lora_rank=512,
+ qk_rope_head_dim=64,
+ layer_num=1,
+ device=self.device,
+ enable_memory_saver=False,
+ )
+
+
+def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
+ """Compare outputs with detailed analysis."""
+
+ # Basic checks
+ assert trtllm_out.shape == reference_out.shape, f"Shape mismatch: {trtllm_out.shape} vs {reference_out.shape}"
+ assert trtllm_out.dtype == reference_out.dtype, f"Dtype mismatch: {trtllm_out.dtype} vs {reference_out.dtype}"
+
+ # Check for NaN/Inf
+ assert not torch.isnan(trtllm_out).any(), "TRTLLM output contains NaN"
+ assert not torch.isnan(reference_out).any(), "Reference output contains NaN"
+ assert not torch.isinf(trtllm_out).any(), "TRTLLM output contains Inf"
+ assert not torch.isinf(reference_out).any(), "Reference output contains Inf"
+
+ # Element-wise differences
+ diff = (trtllm_out - reference_out).abs()
+ max_diff = diff.max().item()
+ mean_diff = diff.mean().item()
+
+ # Check numerical equivalence
+ all_close = torch.allclose(trtllm_out, reference_out, rtol=tolerance, atol=tolerance)
+
+ if not all_close:
+ print(f"Comparison failed: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, tolerance={tolerance}")
+ # Find top differences for debugging
+ flat_diff = diff.flatten()
+ top_diff_indices = torch.topk(flat_diff, k=min(5, flat_diff.numel())).indices
+ print("Top 5 differences:")
+ for i, idx in enumerate(top_diff_indices):
+ idx_tuple = np.unravel_index(idx.cpu().numpy(), trtllm_out.shape)
+ trt_val = trtllm_out[idx_tuple].item()
+ ref_val = reference_out[idx_tuple].item()
+ print(f" [{idx_tuple}]: TRTLLM={trt_val:.6f}, Reference={ref_val:.6f}, diff={abs(trt_val-ref_val):.6f}")
+
+ return all_close
+
+
+@unittest.skipIf(not torch.cuda.is_available() or not is_flashinfer_available(),
+ "CUDA + flashinfer required")
+class TestTRTLLMMLAClean(CustomTestCase):
+ """Test TRTLLM MLA backend against FlashInfer MLA backend (reference)."""
+
+ def setUp(self):
+ """Setup test fixtures."""
+ self.device = "cuda"
+ self.dtype = torch.bfloat16
+ self.page_size = 32
+
+ # Create model runner and backends
+ self.model_runner_trtllm = MockModelRunner(self.page_size)
+ self.model_runner_reference = MockModelRunner(self.page_size)
+
+ self.trtllm_backend = TRTLLMGENMLABackend(self.model_runner_trtllm)
+ self.reference_backend = FlashInferMLAAttnBackend(self.model_runner_reference)
+
+ # Create RadixAttention layer
+ self.layer = RadixAttention(
+ num_heads=128,
+ head_dim=512 + 64, # kv_lora_rank + qk_rope_head_dim
+ scaling=self.model_runner_trtllm.model_config.scaling,
+ num_kv_heads=1,
+ layer_id=0,
+ v_head_dim=512,
+ prefix="attn_mqa",
+ )
+
+ def _create_qkv_tensors(self, batch_size):
+ """Create Q, K, V tensors for testing."""
+ head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim
+ q = torch.randn((batch_size, 128, head_dim), dtype=self.dtype, device=self.device)
+ k = torch.randn((batch_size, 1, head_dim), dtype=self.dtype, device=self.device)
+ v = torch.randn((batch_size, 1, 512), dtype=self.dtype, device=self.device)
+ return q, k, v
+
+ def _create_forward_batch(self, batch_size, seq_lens, backend, model_runner):
+ """Create a forward batch for the given backend."""
+ fb = ForwardBatch(
+ batch_size=batch_size,
+ input_ids=torch.randint(0, 100, (batch_size, 1), device=self.device),
+ out_cache_loc=torch.arange(batch_size, device=self.device),
+ seq_lens_sum=int(seq_lens.sum().item()),
+ forward_mode=ForwardMode.DECODE,
+ req_pool_indices=torch.arange(batch_size, device=self.device),
+ seq_lens=seq_lens,
+ seq_lens_cpu=seq_lens.cpu(),
+ attn_backend=backend,
+ )
+ fb.req_to_token_pool = model_runner.req_to_token_pool
+ fb.token_to_kv_pool = model_runner.token_to_kv_pool
+ return fb
+
+ def _populate_kv_cache(self, batch_size, seq_lens, model_runners):
+ """Populate KV cache with identical data for both backends."""
+ torch.manual_seed(42) # Fixed seed for reproducible cache
+
+ for model_runner in model_runners:
+ torch.manual_seed(42) # Reset seed for each backend
+ for i in range(batch_size):
+ seq_len = int(seq_lens[i].item())
+ for token_idx in range(seq_len - 1):
+ # Create random K components for MLA
+ cache_k_nope = torch.randn((1, 128), dtype=self.dtype, device=self.device)
+ cache_k_rope = torch.randn((1, 64), dtype=self.dtype, device=self.device)
+
+ # Calculate cache location
+ cache_loc = model_runner.req_to_token_pool.req_to_token[i, token_idx]
+
+ # Save to KV cache
+ model_runner.token_to_kv_pool.set_mla_kv_buffer(
+ self.layer,
+ cache_loc.unsqueeze(0),
+ cache_k_nope.squeeze(0),
+ cache_k_rope.squeeze(0)
+ )
+
+ def test_decode_output_match(self):
+ """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
+ # Test different batch sizes and sequence lengths
+ test_cases = [
+ (1, 64), (4, 64), (16, 64), (32, 64),
+ (1, 128), (4, 128), (16, 128), (32, 128),
+ (1, 256), (4, 256), (16, 256), (32, 256),
+ ]
+
+ for batch_size, max_seq_len in test_cases:
+ with self.subTest(batch_size=batch_size, max_seq_len=max_seq_len):
+ # Create identical sequence lengths for both backends
+ torch.manual_seed(42)
+ seq_lens = torch.randint(1, max_seq_len, (batch_size,), device=self.device)
+ seq_lens[0] = max_seq_len # Ensure at least one max length
+
+ # Create forward batches with identical inputs
+ fb_trtllm = self._create_forward_batch(
+ batch_size, seq_lens.clone(), self.trtllm_backend, self.model_runner_trtllm
+ )
+ fb_reference = self._create_forward_batch(
+ batch_size, seq_lens.clone(), self.reference_backend, self.model_runner_reference
+ )
+
+ # Initialize metadata for both backends
+ self.trtllm_backend.init_forward_metadata(fb_trtllm)
+ self.reference_backend.init_forward_metadata(fb_reference)
+
+ # Populate both KV caches identically
+ self._populate_kv_cache(batch_size, seq_lens, [self.model_runner_trtllm, self.model_runner_reference])
+
+ # Create Q, K, V tensors for current decode step
+ torch.manual_seed(123) # Fixed seed for Q, K, V
+ q, k, v = self._create_qkv_tensors(batch_size)
+
+ # Run forward decode on both backends
+ out_trtllm = self.trtllm_backend.forward_decode(q.clone(), k.clone(), v, self.layer, fb_trtllm)
+ out_reference = self.reference_backend.forward_decode(q.clone(), k.clone(), v.clone(), self.layer, fb_reference)
+
+ # Compare outputs
+ comparison_passed = compare_outputs(out_trtllm, out_reference, tolerance=1e-2)
+
+ self.assertTrue(comparison_passed,
+ f"TRTLLM and Reference outputs differ beyond tolerance. "
+ f"batch_size={batch_size}, max_seq_len={max_seq_len}, "
+ f"Max diff: {(out_trtllm - out_reference).abs().max().item()}"
+ )
+
+ def test_different_page_sizes(self):
+ """Test output consistency across different page sizes."""
+ page_sizes = [32, 64]
+ batch_size = 8
+ max_seq_len = 128
+
+ for page_size in page_sizes:
+ with self.subTest(page_size=page_size):
+ # Create model runner with specific page size
+ model_runner = MockModelRunner(page_size)
+ backend = TRTLLMGENMLABackend(model_runner)
+
+ # Create sequence lengths
+ torch.manual_seed(42)
+ seq_lens = torch.randint(1, max_seq_len, (batch_size,), device=self.device)
+ seq_lens[0] = max_seq_len
+
+ # Create forward batch
+ fb = self._create_forward_batch(batch_size, seq_lens, backend, model_runner)
+ backend.init_forward_metadata(fb)
+
+ # Populate KV cache
+ self._populate_kv_cache(batch_size, seq_lens, [model_runner])
+
+ # Create Q, K, V tensors
+ torch.manual_seed(123)
+ q, k, v = self._create_qkv_tensors(batch_size)
+
+ # Run forward decode
+ output = backend.forward_decode(q, k, v, self.layer, fb)
+
+ # Basic checks
+ expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim
+ self.assertEqual(output.shape, expected_shape, f"Output shape mismatch: {output.shape} vs {expected_shape}")
+ self.assertFalse(torch.isnan(output).any(), "Output contains NaN")
+ self.assertFalse(torch.isinf(output).any(), "Output contains Inf")
+
+ def test_basic_functionality(self):
+ """Test basic functionality with minimal setup."""
+ batch_size = 2
+ max_seq_len = 32
+
+ # Create sequence lengths
+ seq_lens = torch.tensor([max_seq_len, max_seq_len // 2], device=self.device)
+
+ # Create forward batch
+ fb = self._create_forward_batch(batch_size, seq_lens, self.trtllm_backend, self.model_runner_trtllm)
+ self.trtllm_backend.init_forward_metadata(fb)
+
+ # Populate KV cache
+ self._populate_kv_cache(batch_size, seq_lens, [self.model_runner_trtllm])
+
+ # Create Q, K, V tensors
+ q, k, v = self._create_qkv_tensors(batch_size)
+
+ # Run forward decode
+ output = self.trtllm_backend.forward_decode(q, k, v, self.layer, fb)
+
+ # Basic checks
+ expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim
+ self.assertEqual(output.shape, expected_shape)
+ self.assertEqual(output.dtype, self.dtype)
+ self.assertFalse(torch.isnan(output).any())
+ self.assertFalse(torch.isinf(output).any())
+
+ def test_forward_decode_shape_sanity(self):
+ """Smoke test decode across several page sizes and batch configurations."""
+ # Test configurations similar to the original test
+ test_configs = [
+ (16, 512, 32), # batch_size, seq_len, page_size
+ (16, 512, 64),
+ (8, 256, 32),
+ (4, 128, 32),
+ (1, 64, 32),
+ (32, 1024, 64),
+ ]
+
+ for batch_size, seq_len, page_size in test_configs:
+ with self.subTest(batch_size=batch_size, seq_len=seq_len, page_size=page_size):
+ # Create model runner with specific page size
+ model_runner = MockModelRunner(page_size)
+ backend = TRTLLMGENMLABackend(model_runner)
+
+ # Random seq lens (ensure one matches max)
+ torch.manual_seed(42)
+ seq_lens = torch.randint(1, seq_len, (batch_size,), device=self.device)
+ seq_lens[0] = seq_len
+
+ # Create forward batch
+ fb = ForwardBatch(
+ batch_size=batch_size,
+ input_ids=torch.randint(0, 100, (batch_size, 1), device=self.device),
+ out_cache_loc=torch.arange(batch_size, device=self.device),
+ seq_lens_sum=int(seq_lens.sum().item()),
+ forward_mode=ForwardMode.DECODE,
+ req_pool_indices=torch.arange(batch_size, device=self.device),
+ seq_lens=seq_lens,
+ seq_lens_cpu=seq_lens.cpu(),
+ attn_backend=backend,
+ )
+ fb.req_to_token_pool = model_runner.req_to_token_pool
+ fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+ backend.init_forward_metadata(fb)
+
+ # Create Q, K, V tensors
+ head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim
+ q = torch.randn((batch_size, 128, head_dim), dtype=self.dtype, device=self.device)
+ k = torch.randn((batch_size, 1, head_dim), dtype=self.dtype, device=self.device)
+ v = None # TRTLLM MLA decode kernel ignores v
+
+ # Create layer
+ layer = RadixAttention(
+ num_heads=128,
+ head_dim=512 + 64,
+ scaling=model_runner.model_config.scaling,
+ num_kv_heads=1,
+ layer_id=0,
+ v_head_dim=512,
+ prefix="attn_mqa",
+ )
+
+ # Run forward decode
+ output = backend.forward_decode(q, k, v, layer, fb)
+
+ # Shape and sanity checks
+ expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim
+ self.assertEqual(output.shape, expected_shape,
+ f"Output shape mismatch for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})")
+ self.assertEqual(output.dtype, self.dtype)
+ self.assertEqual(output.device.type, "cuda")
+ self.assertFalse(torch.isnan(output).any(),
+ f"Output contains NaN for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})")
+ self.assertFalse(torch.isinf(output).any(),
+ f"Output contains Inf for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})")
+
+
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py
deleted file mode 100644
index 8d1b2e39c500..000000000000
--- a/python/sglang/test/attention/test_trtllm_mla_backend.py
+++ /dev/null
@@ -1,146 +0,0 @@
-import unittest
-import torch
-
-from sglang.srt.layers import dp_attention as _dp_attn
-
-# Patch DP-attention globals **before** importing the backend so that all
-# downstream `from … import get_attention_tp_size` statements receive the
-# patched version.
-_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test
-
-from sglang.srt.configs.model_config import AttentionArch
-from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.test.test_utils import CustomTestCase
-from sglang.srt.utils import is_flashinfer_available
-
-
-class MockModelRunner:
- """Minimal fake `ModelRunner` for MLA backend unit tests."""
-
- def __init__(self, page_size: int):
- self.device = "cuda"
- self.dtype = torch.bfloat16
- self.kv_cache_dtype = torch.bfloat16
- self.page_size = page_size
-
- # Model-config stub – only the attributes accessed by the backend.
- self.model_config = type(
- "ModelConfig",
- (),
- {
- "context_len": 2048,
- "attention_arch": AttentionArch.MLA,
- "num_attention_heads": 128,
- "kv_lora_rank": 512,
- "qk_nope_head_dim": 128,
- "qk_rope_head_dim": 64,
- "v_head_dim": 512,
- "scaling": 1.0 / ((128 + 64) ** 0.5),
- "get_num_kv_heads": staticmethod(lambda _: 1),
- },
- )
-
- # Req-to-token pool (dummy)
- max_bs = 64
- max_ctx = self.model_config.context_len
- self.req_to_token_pool = type(
- "TokenPool",
- (),
- {
- "size": max_bs,
- "req_to_token": torch.zeros(max_bs, max_ctx, dtype=torch.int32, device=self.device),
- },
- )
-
- # KV-token pool
- self.token_to_kv_pool = MLATokenToKVPool(
- size=max_bs * max_ctx,
- page_size=page_size,
- dtype=self.kv_cache_dtype,
- kv_lora_rank=512,
- qk_rope_head_dim=64,
- layer_num=1,
- device=self.device,
- enable_memory_saver=False,
- )
-
-
-@unittest.skipIf(not torch.cuda.is_available() or not is_flashinfer_available(), "CUDA + flashinfer required")
-class TestTRTLLMMLABackend(CustomTestCase):
- """Structure mirrors `test_flashattn_backend.py` but focuses on MLA decode."""
-
- def setUp(self):
- self.batch_size = 16
- self.seq_len = 512
- self.page_sizes = [32, 64]
- self.device = "cuda"
- self.dtype = torch.bfloat16
-
- # ‑- helpers ---------------------------------------------------------
- def _init(self, page_size: int):
- self.model_runner = MockModelRunner(page_size)
- self.backend = TRTLLMMLABackend(self.model_runner)
- # Attach num_heads required by RadixAttention convenience
- self.model_runner.model_config.num_attention_heads = 128
-
- def _alloc_qkv(self):
- head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim
- q_shape = (self.batch_size, 128, head_dim)
- q = torch.randn(q_shape, dtype=self.dtype, device=self.device)
- k = torch.randn(self.batch_size, 1, head_dim, dtype=self.dtype, device=self.device)
- v = None # TRTLLM MLA decode kernel ignores v
- return q, k, v
-
- def _create_forward_batch(self, seq_lens: torch.Tensor):
- fb = ForwardBatch(
- batch_size=self.batch_size,
- input_ids=torch.randint(0, 100, (self.batch_size, 1), device=self.device),
- out_cache_loc=torch.arange(self.batch_size, device=self.device),
- seq_lens_sum=int(seq_lens.sum().item()),
- forward_mode=ForwardMode.DECODE,
- req_pool_indices=torch.arange(self.batch_size, device=self.device),
- seq_lens=seq_lens,
- seq_lens_cpu=seq_lens.cpu(),
- attn_backend=self.backend,
- )
- fb.req_to_token_pool = self.model_runner.req_to_token_pool
- fb.token_to_kv_pool = self.model_runner.token_to_kv_pool
- return fb
-
- # ‑- actual tests ----------------------------------------------------
- def test_forward_decode(self):
- """Smoke test decode across several page sizes."""
- for ps in self.page_sizes:
- self._init(ps)
-
- # Random seq lens (ensure one matches max)
- seq_lens = torch.randint(1, self.seq_len, (self.batch_size,), device=self.device)
- seq_lens[0] = self.seq_len
-
- forward_batch = self._create_forward_batch(seq_lens)
- self.backend.init_forward_metadata(forward_batch)
-
- q, k, v = self._alloc_qkv()
- layer = RadixAttention(
- num_heads=128,
- head_dim=512 + 64,
- scaling=self.model_runner.model_config.scaling,
- num_kv_heads=1,
- layer_id=0,
- v_head_dim=512,
- prefix="attn_mqa",
- )
- out = self.backend.forward_decode(q, k, v, layer, forward_batch)
-
- self.assertEqual(out.shape, (self.batch_size, 128 * 512))
- self.assertEqual(out.dtype, self.dtype)
- self.assertEqual(out.device.type, "cuda")
- self.assertFalse(torch.isnan(out).any())
- self.assertFalse(torch.isinf(out).any())
-
-
-if __name__ == "__main__":
- unittest.main()
\ No newline at end of file
diff --git a/python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py b/python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py
deleted file mode 100644
index db0affb7c048..000000000000
--- a/python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py
+++ /dev/null
@@ -1,218 +0,0 @@
-import unittest
-import torch
-import math
-
-from sglang.srt.layers import dp_attention as _dp_attn
-
-# Patch DP-attention globals before importing backends
-_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test
-
-from sglang.srt.configs.model_config import AttentionArch
-from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.test.test_utils import CustomTestCase
-from sglang.srt.utils import is_flashinfer_available
-
-
-class MockModelRunner:
- """Minimal fake ModelRunner for comparing both MLA backends."""
-
- def __init__(self, page_size: int):
- self.device = "cuda"
- self.dtype = torch.bfloat16
- self.kv_cache_dtype = torch.bfloat16
- self.page_size = page_size
-
- # Model-config stub with MLA attributes
- self.model_config = type(
- "ModelConfig",
- (),
- {
- "context_len": 2048,
- "attention_arch": AttentionArch.MLA,
- "num_attention_heads": 128,
- "kv_lora_rank": 512,
- "qk_nope_head_dim": 128,
- "qk_rope_head_dim": 64,
- "v_head_dim": 512,
- "scaling": 1.0 / ((128 + 64) ** 0.5),
- "get_num_kv_heads": staticmethod(lambda _: 1),
- },
- )
-
- # Req-to-token pool
- max_bs = 64
- max_ctx = self.model_config.context_len
- self.req_to_token_pool = type(
- "TokenPool",
- (),
- {
- "size": max_bs,
- "req_to_token": torch.zeros(max_bs, max_ctx, dtype=torch.int32, device=self.device),
- },
- )
-
- # KV-token pool (MLA)
- self.token_to_kv_pool = MLATokenToKVPool(
- size=max_bs * max_ctx,
- page_size=page_size,
- dtype=self.kv_cache_dtype,
- kv_lora_rank=512,
- qk_rope_head_dim=64,
- layer_num=1,
- device=self.device,
- enable_memory_saver=False,
- )
-
-
-@unittest.skipIf(not torch.cuda.is_available() or not is_flashinfer_available(), "CUDA + flashinfer required")
-class TestTRTLLMvsFlashInferMLA(CustomTestCase):
- """Test numerical equivalence between TRTLLM and FlashInfer MLA backends."""
-
- def setUp(self):
- self.batch_size = 8
- self.seq_len = 256
- self.page_size = 32
- self.device = "cuda"
- self.dtype = torch.bfloat16
-
- # Create model runner
- self.model_runner = MockModelRunner(self.page_size)
-
- # Initialize both backends
- self.trtllm_backend = TRTLLMMLABackend(self.model_runner)
- self.flashinfer_backend = FlashInferMLAAttnBackend(self.model_runner)
-
- # Create RadixAttention layer for testing
- self.layer = RadixAttention(
- num_heads=128,
- head_dim=512 + 64, # kv_lora_rank + qk_rope_head_dim
- scaling=self.model_runner.model_config.scaling,
- num_kv_heads=1,
- layer_id=0,
- v_head_dim=512,
- prefix="attn_mqa",
- )
-
- def _create_qkv_tensors(self):
- """Create Q, K, V tensors for testing."""
- head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim
- q = torch.randn((self.batch_size, 128, head_dim), dtype=self.dtype, device=self.device)
- k = torch.randn((self.batch_size, 1, head_dim), dtype=self.dtype, device=self.device)
- # For FlashInfer MLA, if k is provided v must not be None.
- v = torch.randn((self.batch_size, 1, 512), dtype=self.dtype, device=self.device)
- return q, k, v
-
- def _create_forward_batch(self, backend):
- """Create a forward batch for the given backend."""
- # Random sequence lengths
- seq_lens = torch.randint(1, self.seq_len, (self.batch_size,), device=self.device)
- seq_lens[0] = self.seq_len # Ensure at least one max length
-
- fb = ForwardBatch(
- batch_size=self.batch_size,
- input_ids=torch.randint(0, 100, (self.batch_size, 1), device=self.device),
- out_cache_loc=torch.arange(self.batch_size, device=self.device),
- seq_lens_sum=int(seq_lens.sum().item()),
- forward_mode=ForwardMode.DECODE,
- req_pool_indices=torch.arange(self.batch_size, device=self.device),
- seq_lens=seq_lens,
- seq_lens_cpu=seq_lens.cpu(),
- attn_backend=backend,
- )
- fb.req_to_token_pool = self.model_runner.req_to_token_pool
- fb.token_to_kv_pool = self.model_runner.token_to_kv_pool
- return fb
-
- def test_decode_output_match(self):
- """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
- # Create identical forward batches for both backends
- fb_trtllm = self._create_forward_batch(self.trtllm_backend)
- fb_flashinfer = self._create_forward_batch(self.flashinfer_backend)
-
- # Initialize metadata for both backends
- self.trtllm_backend.init_forward_metadata(fb_trtllm)
- self.flashinfer_backend.init_forward_metadata(fb_flashinfer)
-
- # Create Q, K, V tensors
- q, k, v = self._create_qkv_tensors()
-
- # Run forward decode on both backends
- out_trtllm = self.trtllm_backend.forward_decode(q.clone(), k.clone(), v, self.layer, fb_trtllm)
- out_flashinfer = self.flashinfer_backend.forward_decode(q.clone(), k.clone(), v.clone(), self.layer, fb_flashinfer)
-
- # Debug: print scale info
- print(f"\n[DEBUG] Scale analysis:")
- print(f" layer.scaling = {self.layer.scaling}")
- print(f" qk_nope_head_dim = {self.model_runner.model_config.qk_nope_head_dim}")
- print(f" qk_rope_head_dim = {self.model_runner.model_config.qk_rope_head_dim}")
- print(f" kv_lora_rank = {self.model_runner.model_config.kv_lora_rank}")
- print(f" Expected TRT scale factor = {math.sqrt(128 + 64) / math.sqrt(512 + 64)} = {math.sqrt(192) / math.sqrt(576)}")
- print(f" Output shapes: TRTLLM {out_trtllm.shape}, FlashInfer {out_flashinfer.shape}")
- print(f" Output means: TRTLLM {out_trtllm.mean().item():.6f}, FlashInfer {out_flashinfer.mean().item():.6f}")
- print(f" Output stds: TRTLLM {out_trtllm.std().item():.6f}, FlashInfer {out_flashinfer.std().item():.6f}")
- print(f" Max diff = {(out_trtllm - out_flashinfer).abs().max().item()}")
- print(f" Ratio of means = {out_trtllm.mean().item() / out_flashinfer.mean().item() if out_flashinfer.mean().item() != 0 else 'inf'}")
-
- # Additional debug
- print(f"\n[DEBUG] Scale computation:")
- print(f" layer.scaling = 1/sqrt(192) = {1/math.sqrt(192)}")
- print(f" TRT scale passed = layer.scaling * sqrt(192)/sqrt(576) = {self.layer.scaling * math.sqrt(192) / math.sqrt(576)}")
- print(f" TRT kernel will compute: 1 / (sqrt(576) * scale) = {1 / (math.sqrt(576) * self.layer.scaling * math.sqrt(192) / math.sqrt(576))}")
- print(f" Which equals: 1 / (layer.scaling * sqrt(192)) = {1 / (self.layer.scaling * math.sqrt(192))}")
- print(f" But FlashInfer uses: layer.scaling = {self.layer.scaling}")
- print(f" Ratio: {(1 / (self.layer.scaling * math.sqrt(192))) / self.layer.scaling} = sqrt(192) = {math.sqrt(192)}")
-
- # Check output shapes match
- self.assertEqual(out_trtllm.shape, out_flashinfer.shape,
- f"Output shapes differ: TRTLLM {out_trtllm.shape} vs FlashInfer {out_flashinfer.shape}")
-
- # Check output dtypes match
- self.assertEqual(out_trtllm.dtype, out_flashinfer.dtype)
-
- # Check numerical equivalence with tolerance
- # Note: Using higher tolerance due to potential numerical differences in implementations
- self.assertTrue(
- torch.allclose(out_trtllm, out_flashinfer, atol=1e-2, rtol=1e-2),
- f"TRTLLM and FlashInfer outputs differ beyond tolerance. "
- f"Max diff: {(out_trtllm - out_flashinfer).abs().max().item()}"
- )
-
- # Additional checks
- self.assertFalse(torch.isnan(out_trtllm).any(), "TRTLLM output contains NaN")
- self.assertFalse(torch.isnan(out_flashinfer).any(), "FlashInfer output contains NaN")
- self.assertFalse(torch.isinf(out_trtllm).any(), "TRTLLM output contains Inf")
- self.assertFalse(torch.isinf(out_flashinfer).any(), "FlashInfer output contains Inf")
-
- def test_decode_with_different_page_sizes(self):
- """Test output consistency across different page sizes."""
- page_sizes = [32, 64]
- outputs = []
-
- for ps in page_sizes:
- # Reinitialize with new page size
- self.model_runner = MockModelRunner(ps)
- self.trtllm_backend = TRTLLMMLABackend(self.model_runner)
-
- # Create batch and run decode
- fb = self._create_forward_batch(self.trtllm_backend)
- self.trtllm_backend.init_forward_metadata(fb)
-
- q, k, v = self._create_qkv_tensors()
- out = self.trtllm_backend.forward_decode(q, k, v, self.layer, fb)
- outputs.append(out)
-
- # Check that outputs are consistent across page sizes
- # Note: Different page sizes might lead to slightly different numerical results
- for i in range(1, len(outputs)):
- self.assertTrue(
- torch.allclose(outputs[0], outputs[i], atol=5e-2, rtol=5e-2),
- f"Output with page_size={page_sizes[0]} differs from page_size={page_sizes[i]}"
- )
-
-
-if __name__ == "__main__":
- unittest.main()
\ No newline at end of file
From 274fdbb6ff78ef994a4be078c0b8ce5f993b1c97 Mon Sep 17 00:00:00 2001
From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Date: Mon, 14 Jul 2025 20:01:04 -0700
Subject: [PATCH 05/27] refactor code
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
---
.../attention/trtllm_gen_mla_backend.py | 550 ++++++++----------
.../sglang/srt/model_executor/model_runner.py | 13 +-
2 files changed, 240 insertions(+), 323 deletions(-)
diff --git a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
index cdb817984894..eaf1aa1c2eef 100644
--- a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
@@ -6,14 +6,13 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
+import math
import torch
import triton
-import math # Needed for scale correction
-import os
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
+from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton, NUM_PAGE_PER_BLOCK
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
@@ -26,9 +25,16 @@
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
+# Constants
+TRTLLM_BLOCK_CONSTRAINT = 128 # TRT-LLM kernel constraint: block_num % (128 / block_size) == 0
+DEFAULT_WORKSPACE_SIZE_MB = 128
+DEFAULT_QUANTIZATION_SCALE = 1.0 # Default scale for FP8 quantization
+DEFAULT_SM_SCALE = 1.0 # Default softmax scale
+
@dataclass
class TRTLLMGENMLADecodeMetadata:
+ """Metadata for TRTLLM MLA decode operations."""
workspace: Optional[torch.Tensor] = None
block_kv_indices: Optional[torch.Tensor] = None
@@ -47,52 +53,170 @@ def __init__(
model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
)
+ # Cache model config for easier access
+ config = model_runner.model_config
+
# Model parameters
- self.num_q_heads = (
- model_runner.model_config.num_attention_heads // get_attention_tp_size()
- )
- self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
- get_attention_tp_size()
- )
- self.req_to_token = model_runner.req_to_token_pool.req_to_token
- self.num_local_heads = (
- model_runner.model_config.num_attention_heads // get_attention_tp_size()
- )
- self.forward_metadata: Union[TRTLLMGENMLADecodeMetadata] = None
- self.kv_lora_rank = model_runner.model_config.kv_lora_rank
- self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
- self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
- self.v_head_dim = model_runner.model_config.v_head_dim
- self.scaling = model_runner.model_config.scaling
+ self.num_q_heads = config.num_attention_heads // get_attention_tp_size()
+ self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
+ self.num_local_heads = config.num_attention_heads // get_attention_tp_size()
+
+ # MLA-specific dimensions
+ self.kv_lora_rank = config.kv_lora_rank
+ self.qk_nope_head_dim = config.qk_nope_head_dim
+ self.qk_rope_head_dim = config.qk_rope_head_dim
+ self.v_head_dim = config.v_head_dim
+ self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+ # Runtime parameters
+ self.scaling = config.scaling
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
- self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
- self.page_size = model_runner.page_size # Use page size from model runner
+ self.page_size = model_runner.page_size
+ self.req_to_token = model_runner.req_to_token_pool.req_to_token
- # Allocate larger workspace for TRTLLM (128MB as in the test)
- self.workspace_size = 128 * 1024 * 1024
+ # Workspace allocation
+ self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
self.workspace_buffer = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
)
- # CUDA graph metadata storage
+ # CUDA graph state
self.decode_cuda_graph_metadata = {}
self.cuda_graph_kv_indices = None
+ self.forward_metadata: Union[TRTLLMGENMLADecodeMetadata] = None
def _calc_padded_blocks(self, max_seq_len: int) -> int:
- """Return number of blocks padded so that it satisfies TRTLLM constraint."""
+ """
+ Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
+
+ Args:
+ max_seq_len: Maximum sequence length in tokens
+
+ Returns:
+ Number of blocks padded to satisfy all constraints
+ """
blocks = triton.cdiv(max_seq_len, self.page_size)
-
- # The Triton helper that builds `block_kv_indices` emits `NUM_PAGE_PER_BLOCK`
- # (= 64) indices at a time, independent of `page_size`. To avoid it writing
- # past the end of the row we **must** make every row at least 64 long.
- # (Side-effect: max_seq_len is effectively rounded up to 64 × page_size = 2 K
- # tokens for page_size 32, which is still below the 2048-ctx used here.)
- min_blocks = 64
+
+ # TWO constraints require padding:
+ # 1. TRT-LLM kernel expects: block_num % (128 / block_size) == 0
+ # This is a hard requirement from the CUDA kernel implementation.
+ # 2. Our Triton helper `create_flashmla_kv_indices_triton` builds page tables
+ # in fixed bursts of 64 indices per outer loop iteration. Each burst writes
+ # unconditionally to positions [i*64, (i+1)*64) without per-element masking.
+ # If the row length isn't a multiple of 64, the final burst would overrun
+ # the allocated buffer and corrupt memory.
+ #
+ # We need to satisfy BOTH constraints, so we take the LCM of:
+ # - trtllm_constraint = 128 // page_size
+ # - triton_constraint = 64
+ trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
+ triton_constraint = NUM_PAGE_PER_BLOCK
+ min_blocks = math.lcm(trtllm_constraint, triton_constraint)
+
if blocks % min_blocks != 0:
blocks = triton.cdiv(blocks, min_blocks) * min_blocks
return blocks
+ def _get_quantization_scales(self) -> tuple[float, float, float, float, float]:
+ """
+ Get quantization scales for q, k, v, sm, and o tensors.
+
+ Returns:
+ Tuple of (q_scale, k_scale, v_scale, sm_scale, o_scale)
+ """
+ # TODO: Implement proper quantization scale inference based on model config
+ # For now, use default values for FP8
+ return (
+ DEFAULT_QUANTIZATION_SCALE, # q_scale
+ DEFAULT_QUANTIZATION_SCALE, # k_scale
+ DEFAULT_QUANTIZATION_SCALE, # v_scale
+ DEFAULT_SM_SCALE, # sm_scale
+ DEFAULT_QUANTIZATION_SCALE, # o_scale
+ )
+
+ def _prepare_kv_cache(self, layer: RadixAttention, forward_batch: ForwardBatch) -> torch.Tensor:
+ """
+ Prepare KV cache tensor in the format expected by TRT-LLM kernel.
+
+ Args:
+ layer: Attention layer
+ forward_batch: Forward batch info
+
+ Returns:
+ KV cache tensor shaped (num_pages, 2, page_size, kv_cache_dim)
+ """
+ k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+ pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
+ # TRT-LLM expects stacked format: slice 0 → CKV+K, slice 1 → KPE
+ return torch.stack([pages, pages], dim=1)
+
+ def _prepare_query_tensor(
+ self,
+ q: torch.Tensor,
+ q_rope: Optional[torch.Tensor],
+ layer: RadixAttention
+ ) -> torch.Tensor:
+ """
+ Prepare query tensor in the format expected by TRT-LLM kernel.
+
+ Args:
+ q: Query tensor
+ q_rope: Optional RoPE query tensor
+ layer: Attention layer
+
+ Returns:
+ Query tensor with concatenated NOPE and RoPE parts
+ """
+ if q_rope is not None:
+ # q contains the NOPE part (v_head_dim)
+ q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+ q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim)
+ else:
+ reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+ q_nope = reshaped_q[:, :, : layer.v_head_dim]
+ q_rope = reshaped_q[:, :, layer.v_head_dim :]
+
+ return torch.cat([q_nope, q_rope], dim=-1)
+
+ def _create_block_kv_indices(
+ self,
+ batch_size: int,
+ max_blocks: int,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ device: torch.device
+ ) -> torch.Tensor:
+ """
+ Create block KV indices tensor using Triton kernel.
+
+ Args:
+ batch_size: Batch size
+ max_blocks: Maximum number of blocks per sequence
+ req_pool_indices: Request pool indices
+ seq_lens: Sequence lengths
+ device: Target device
+
+ Returns:
+ Block KV indices tensor
+ """
+ block_kv_indices = torch.full(
+ (batch_size, max_blocks), -1, dtype=torch.int32, device=device
+ )
+
+ create_flashmla_kv_indices_triton[(batch_size,)](
+ self.req_to_token,
+ req_pool_indices,
+ seq_lens,
+ None,
+ block_kv_indices,
+ self.req_to_token.stride(0),
+ max_blocks,
+ self.page_size,
+ )
+
+ return block_kv_indices
+
def init_cuda_graph_state(
self,
max_bs: int,
@@ -100,14 +224,10 @@ def init_cuda_graph_state(
kv_indices_buf: Optional[torch.Tensor] = None,
):
"""Initialize CUDA graph state for TRTLLM MLA."""
- # Calculate padded block size that satisfies TRTLLM constraint
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
self.cuda_graph_kv_indices = torch.full(
- (max_bs, max_blocks_per_seq),
- -1,
- dtype=torch.int32,
- device=self.device,
+ (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
)
self.cuda_graph_workspace = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
@@ -124,46 +244,29 @@ def init_forward_metadata_capture_cuda_graph(
spec_info: Optional[SpecInfo],
):
"""Initialize metadata for CUDA graph capture."""
- if forward_mode.is_decode_or_idle():
- if spec_info is None:
- max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-
- block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
- create_flashmla_kv_indices_triton[(bs,)](
- self.req_to_token,
- req_pool_indices,
- seq_lens,
- None,
- block_kv_indices,
- self.req_to_token.stride(0),
- max_seqlen_pad,
- self.page_size,
- )
- metadata = TRTLLMGENMLADecodeMetadata(
- self.cuda_graph_workspace,
- block_kv_indices,
- )
- self.decode_cuda_graph_metadata[bs] = metadata
- self.forward_metadata = metadata
- else:
- super().init_forward_metadata_capture_cuda_graph(
- bs,
- num_tokens,
- req_pool_indices,
- seq_lens,
- encoder_lens,
- forward_mode,
- spec_info,
- )
- else:
- super().init_forward_metadata_capture_cuda_graph(
- bs,
- num_tokens,
+ if forward_mode.is_decode_or_idle() and spec_info is None:
+ max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
+ block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+
+ create_flashmla_kv_indices_triton[(bs,)](
+ self.req_to_token,
req_pool_indices,
seq_lens,
- encoder_lens,
- forward_mode,
- spec_info,
+ None,
+ block_kv_indices,
+ self.req_to_token.stride(0),
+ max_seqlen_pad,
+ self.page_size,
+ )
+
+ metadata = TRTLLMGENMLADecodeMetadata(
+ self.cuda_graph_workspace, block_kv_indices
+ )
+ self.decode_cuda_graph_metadata[bs] = metadata
+ self.forward_metadata = metadata
+ else:
+ super().init_forward_metadata_capture_cuda_graph(
+ bs, num_tokens, req_pool_indices, seq_lens, encoder_lens, forward_mode, spec_info
)
def init_forward_metadata_replay_cuda_graph(
@@ -178,93 +281,53 @@ def init_forward_metadata_replay_cuda_graph(
seq_lens_cpu: Optional[torch.Tensor],
):
"""Replay CUDA graph with new inputs."""
- if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
- print(f"[TRTLLM-MLA] init_forward_metadata_replay_cuda_graph: bs={bs}, forward_mode={forward_mode}, spec_info={spec_info is not None}")
-
- if forward_mode.is_decode_or_idle():
- if spec_info is None:
- # Reuse cached metadata
- metadata = self.decode_cuda_graph_metadata[bs]
-
- # Update block indices for new sequences
- create_flashmla_kv_indices_triton[(bs,)](
- self.req_to_token,
- req_pool_indices[:bs],
- seq_lens[:bs],
- None,
- metadata.block_kv_indices,
- self.req_to_token.stride(0),
- metadata.block_kv_indices.shape[1],
- self.page_size,
- )
- # Pad invalid blocks to keep TRTLLM happy
- # self._pad_invalid_blocks(metadata.block_kv_indices)
-
- self.forward_metadata = metadata
-
- else:
- # Speculative decoding: use parent class implementation
- super().init_forward_metadata_replay_cuda_graph(
- bs, req_pool_indices, seq_lens, seq_lens_sum,
- encoder_lens, forward_mode, spec_info, seq_lens_cpu
- )
+ if forward_mode.is_decode_or_idle() and spec_info is None:
+ metadata = self.decode_cuda_graph_metadata[bs]
+
+ # Update block indices for new sequences
+ create_flashmla_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices[:bs],
+ seq_lens[:bs],
+ None,
+ metadata.block_kv_indices,
+ self.req_to_token.stride(0),
+ metadata.block_kv_indices.shape[1],
+ self.page_size,
+ )
+
+ self.forward_metadata = metadata
else:
- # Prefill: use parent class implementation
super().init_forward_metadata_replay_cuda_graph(
bs, req_pool_indices, seq_lens, seq_lens_sum,
encoder_lens, forward_mode, spec_info, seq_lens_cpu
)
- def get_cuda_graph_seq_len_fill_value(self):
+ def get_cuda_graph_seq_len_fill_value(self) -> int:
"""Get the fill value for sequence lengths in CUDA graph."""
return 1
def init_forward_metadata(self, forward_batch: ForwardBatch):
- """Init the metadata for a forward pass."""
- bs = forward_batch.batch_size
- spec_info = forward_batch.spec_info
-
- if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
- print(f"[TRTLLM-MLA] init_forward_metadata: bs={bs}, forward_mode={forward_batch.forward_mode}, spec_info={spec_info is not None}")
- if forward_batch.forward_mode.is_decode_or_idle():
- if spec_info is None:
- # seq_lens_cpu may be None when cuda-graphs are disabled
- if getattr(forward_batch, "seq_lens_cpu", None) is not None:
- max_seq = forward_batch.seq_lens_cpu.max().item()
- else:
- max_seq = forward_batch.seq_lens.max().item()
-
- max_seqlen_pad = self._calc_padded_blocks(max_seq)
-
- block_kv_indices = torch.full(
- (bs, max_seqlen_pad),
- -1,
- dtype=torch.int32,
- device=forward_batch.seq_lens.device,
- )
-
- create_flashmla_kv_indices_triton[(bs,)](
- self.req_to_token,
- forward_batch.req_pool_indices,
- forward_batch.seq_lens,
- None,
- block_kv_indices,
- self.req_to_token.stride(0),
- max_seqlen_pad,
- self.page_size,
- )
+ """Initialize the metadata for a forward pass."""
+ if forward_batch.forward_mode.is_decode_or_idle() and forward_batch.spec_info is None:
+ bs = forward_batch.batch_size
+
+ # Get maximum sequence length
+ if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+ max_seq = forward_batch.seq_lens_cpu.max().item()
+ else:
+ max_seq = forward_batch.seq_lens.max().item()
- # Ensure padded blocks are valid to avoid TRTLLM skipping shorter seqs
- # self._pad_invalid_blocks(block_kv_indices)
+ max_seqlen_pad = self._calc_padded_blocks(max_seq)
+ block_kv_indices = self._create_block_kv_indices(
+ bs, max_seqlen_pad, forward_batch.req_pool_indices,
+ forward_batch.seq_lens, forward_batch.seq_lens.device
+ )
- self.forward_metadata = TRTLLMGENMLADecodeMetadata(
- self.workspace_buffer,
- block_kv_indices,
- )
- # Expose to the ForwardBatch so that other components can access it
- forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
- else:
- super().init_forward_metadata(forward_batch)
+ self.forward_metadata = TRTLLMGENMLADecodeMetadata(
+ self.workspace_buffer, block_kv_indices
+ )
+ forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
else:
super().init_forward_metadata(forward_batch)
@@ -278,115 +341,27 @@ def forward_decode(
save_kv_cache: bool = True,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
- ):
+ ) -> torch.Tensor:
"""Run forward for decode using TRTLLM MLA kernel."""
- cache_loc = forward_batch.out_cache_loc
-
# Save KV cache if requested
if k is not None and save_kv_cache:
+ cache_loc = forward_batch.out_cache_loc
if k_rope is not None:
- forward_batch.token_to_kv_pool.set_mla_kv_buffer(
- layer, cache_loc, k, k_rope
- )
- else:
- if v is not None:
- forward_batch.token_to_kv_pool.set_kv_buffer(
- layer, cache_loc, k, v
- )
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(layer, cache_loc, k, k_rope)
+ elif v is not None:
+ forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
- # Build query tensor
- if q_rope is not None:
- # q contains the NOPE part (v_head_dim)
- q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
- q_rope = q_rope.view(
- -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
- )
- else:
- reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
- q_nope = reshaped_q[:, :, : layer.v_head_dim]
- q_rope = reshaped_q[:, :, layer.v_head_dim :]
-
- # Concatenate to build the full query as expected by TRTLLM kernel
- query = torch.cat([q_nope, q_rope], dim=-1)
-
-
- if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
- print(f"[TRTLLM-MLA]: model_runner.model_config.scaling.scaling is {self.model_runner.model_config.scaling}")
-
- # Get the model scaling factor
- # TRTLLM kernel applies the 1/sqrt(192) factor internally, so we
- # should pass `1.0` here (see Flash-Infer equivalence tests).
- sm_scale = 1
- # (scale * ((512 + 64) ** 0.5)) / ((128 + 64) ** 0.5)
- # ( sqrt(3)) / sqrt(192)
-
- # KV cache tensor: reshape to (num_pages, page_size, dim)
- k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
- # Build KV cache slices expected by TRT-LLM: slice 0 → CKV+K, slice 1 → KPE.
- pages = k_cache.view(-1, self.page_size, self.kv_cache_dim) # (P, blk, 576)
-
- # Retrieve metadata so we can check block tables.
- metadata = getattr(forward_batch, "decode_trtllm_mla_metadata", None)
- if metadata is None:
- metadata = self.forward_metadata
-
- # ---------------------------------------------------------------------
- # According to the FlashInfer test, both slices should contain the full 576-dim tensor.
- # Use torch.stack so each slice is an independent contiguous view **after** we have
- # patched the pages in-place.
- kv_cache = torch.stack([pages, pages], dim=1) # (P, 2, blk, 576)
-
- # Metadata already obtained above; no change needed.
-
- # ------------- DEBUG KV CACHE CONSTRUCTION -------------
- if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
- print(
- f"[TRTLLM-MLA] k_cache_raw: {k_cache.shape} {k_cache.dtype}\n"
- f"[TRTLLM-MLA] pages : {pages.shape} {pages.dtype}\n"
- f"[TRTLLM-MLA] kv_cache : {kv_cache.shape} {kv_cache.dtype}\n"
- f"[TRTLLM-MLA] block_kv_indices: {metadata.block_kv_indices.shape} {metadata.block_kv_indices.dtype}\n"
- f"[TRTLLM-MLA] workspace: {metadata.workspace.shape if metadata.workspace is not None else 'None'}\n"
- f"[TRTLLM-MLA] max_seq_len: {metadata.block_kv_indices.shape[1] * self.page_size}\n"
- f"[TRTLLM-MLA] k_cache[0,0,:3]: {k_cache[0,0,:3] if k_cache.numel() > 0 else 'empty'}\n"
- f"[TRTLLM-MLA] pages[0,0,:3]: {pages[0,0,:3] if pages.numel() > 0 else 'empty'}\n"
- f"[TRTLLM-MLA] kv_cache[0,0,0,:3]: {kv_cache[0,0,0,:3] if kv_cache.numel() > 0 else 'empty'}\n"
- f"[TRTLLM-MLA] kv_cache[0,1,0,:3]: {kv_cache[0,1,0,:3] if kv_cache.numel() > 0 else 'empty'}\n"
- f"[TRTLLM-MLA] block_kv_indices[0,:5]: {metadata.block_kv_indices[0,:5] if metadata.block_kv_indices.numel() > 0 else 'empty'}"
- )
-
- # ------------- DEBUG (align with FlashInfer-MLA) -------------
- if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
- q_nope_str = f"{q_nope.shape} {q_rope.dtype}" if q_nope is not None else "None"
- q_rope_str = f"{q_rope.shape} {q_rope.dtype}" if q_rope is not None else "None"
- q_nope_sample = f"{q_nope[0,:3,:3] if q_rope is not None and q_nope.numel() > 0 else 'empty'}"
- q_rope_sample = f"{q_rope[0,:3,:3] if q_rope is not None and q_rope.numel() > 0 else 'empty'}"
-
- print(
- f"[TRTLLM-MLA] q_nope : {q_nope_str}\n"
- f"[TRTLLM-MLA] q_rope : {q_rope_str}\n"
- f"[TRTLLM-MLA] query : {query.shape} {query.dtype}\n"
- f"[TRTLLM-MLA] kv_cache: {kv_cache.shape} {kv_cache.dtype}\n"
- f"[TRTLLM-MLA] scale : {sm_scale}\n"
- f"[TRTLLM-MLA] cache_loc: {cache_loc.shape} {cache_loc.dtype}\n"
- f"[TRTLLM-MLA] seq_lens: {forward_batch.seq_lens.shape} {forward_batch.seq_lens.dtype}\n"
- f"[TRTLLM-MLA] batch_size: {forward_batch.batch_size}\n"
- f"[TRTLLM-MLA] page_size: {self.page_size}\n"
- f"[TRTLLM-MLA] qk_nope_head_dim: {self.qk_nope_head_dim}\n"
- f"[TRTLLM-MLA] qk_rope_head_dim: {self.qk_rope_head_dim}\n"
- f"[TRTLLM-MLA] kv_lora_rank: {self.kv_lora_rank}\n"
- f"[TRTLLM-MLA] v_head_dim: {layer.v_head_dim}\n"
- f"[TRTLLM-MLA] kv_cache_dim: {self.kv_cache_dim}\n"
- f"[TRTLLM-MLA] tp_q_head_num: {layer.tp_q_head_num}\n"
- f"[TRTLLM-MLA] tp_k_head_num: {layer.tp_k_head_num}\n"
- f"[TRTLLM-MLA] layer_id: {layer.layer_id}\n"
- f"[TRTLLM-MLA] q_nope[0,:3,:3]: {q_nope_sample}\n"
- f"[TRTLLM-MLA] q_rope[0,:3,:3]: {q_rope_sample}\n"
- f"[TRTLLM-MLA] query[0,:3,:3]: {query[0,:3,:3] if query.numel() > 0 else 'empty'}\n"
- f"[TRTLLM-MLA] seq_lens values: {forward_batch.seq_lens[:min(5, len(forward_batch.seq_lens))]}\n"
- f"[TRTLLM-MLA] cache_loc values: {cache_loc[:min(5, len(cache_loc))]}"
- )
+ # Prepare tensors for TRT-LLM kernel
+ query = self._prepare_query_tensor(q, q_rope, layer)
+ kv_cache = self._prepare_kv_cache(layer, forward_batch)
+
+ # Get metadata
+ metadata = getattr(forward_batch, "decode_trtllm_mla_metadata", None) or self.forward_metadata
+
+ # Get quantization scales
+ q_scale, k_scale, v_scale, sm_scale, o_scale = self._get_quantization_scales()
-
+ # Call TRT-LLM kernel
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=query,
kv_cache=kv_cache,
@@ -397,68 +372,21 @@ def forward_decode(
block_tables=metadata.block_kv_indices,
seq_lens=forward_batch.seq_lens.to(torch.int32),
block_size=self.page_size,
- max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size), # max_seq_len equals padded_blocks * page_size.
- q_scale=1.0,
- k_scale=1.0,
- v_scale=1.0,
+ max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
+ q_scale=q_scale,
+ k_scale=k_scale,
+ v_scale=v_scale,
sm_scale=sm_scale,
- o_scale=1.0,
+ o_scale=o_scale,
)
- # TRTLLM kernel may return both V and ROPE dims (kv_lora_rank + qk_rope_head_dim).
- # We only need the value projection part (v_head_dim).
- raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
-
- # ------------- DEBUG KERNEL OUTPUT -------------
- if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
- print(
- f"[TRTLLM-MLA] raw_out : {raw_out.shape} {raw_out.dtype}\n"
- f"[TRTLLM-MLA] raw_out_v : {raw_out_v.shape} {raw_out_v.dtype}\n"
- f"[TRTLLM-MLA] raw_out[0,:3,:3]: {raw_out[0,:3,:3] if raw_out.numel() > 0 else 'empty'}\n"
- f"[TRTLLM-MLA] raw_out_v[0,:3,:3]: {raw_out_v[0,:3,:3] if raw_out_v.numel() > 0 else 'empty'}\n"
- f"[TRTLLM-MLA] raw_out stats: min={raw_out.min():.6f}, max={raw_out.max():.6f}, mean={raw_out.mean():.6f}\n"
- f"[TRTLLM-MLA] raw_out_v stats: min={raw_out_v.min():.6f}, max={raw_out_v.max():.6f}, mean={raw_out_v.mean():.6f}"
- )
+ # Extract value projection part and reshape
+ raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+ # Truncate if needed
if output.shape[0] > forward_batch.batch_size:
output = output[: forward_batch.batch_size]
-
- # ------------- DEBUG FINAL OUTPUT -------------
- if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
- print(
- f"[TRTLLM-MLA] output : {output.shape} {output.dtype}\n"
- f"[TRTLLM-MLA] output[0,:10]: {output[0,:10] if output.numel() > 0 else 'empty'}\n"
- f"[TRTLLM-MLA] output[0,10:20]: {output[0,10:20] if output.numel() > 10 else 'empty'}\n"
- f"[TRTLLM-MLA] output[0,20:30]: {output[0,20:30] if output.numel() > 20 else 'empty'}\n"
- f"[TRTLLM-MLA] output stats: min={output.min():.6f}, max={output.max():.6f}, mean={output.mean():.6f}\n"
- f"[TRTLLM-MLA] output std: {output.std():.6f}\n"
- f"[TRTLLM-MLA] output_reshaped: {output.shape[0]} -> {forward_batch.batch_size}\n"
- f"[TRTLLM-MLA] ===== END TRTLLM-MLA DEBUG ====="
- )
-
- return output
- @staticmethod
- def _pad_invalid_blocks(block_kv_indices: torch.Tensor):
- """Replace -1 paddings with the last valid page id for each row.
-
- TRT-LLM treats a leading -1 as "empty sequence" and will skip the
- whole sequence. For shorter sequences we therefore replicate the
- last real page id into the padded region instead of leaving -1.
- """
- for row in range(block_kv_indices.size(0)):
- row_view = block_kv_indices[row]
- # Find last non-negative entry (every sequence has at least one)
- valid_mask = row_view >= 0
- if not valid_mask.any():
- # Defensive – shouldn’t happen in decode mode
- continue
- last_valid = row_view[valid_mask][-1]
- row_view[~valid_mask] = last_valid
-
- # Extra visibility when debugging
- if os.getenv("SGLANG_DEBUG_MLA", "0") == "1":
- preview_rows = min(6, block_kv_indices.size(0))
- print("[TRTLLM-MLA] block_kv_indices preview (after padding):")
- for i in range(preview_rows):
- print(f" seq {i}:", block_kv_indices[i].tolist())
\ No newline at end of file
+ return output
+
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index c14c3fbef48f..6f6bdb284331 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -396,18 +396,7 @@ def model_specific_adjustment(self):
):
server_args.attention_backend = "fa3"
elif is_sm100_supported():
- # On Blackwell, prefer TRTLLM MLA if available, otherwise flashinfer
- if is_flashinfer_available():
- try:
- import flashinfer
- if hasattr(flashinfer.decode, 'trtllm_batch_decode_with_kv_cache_mla'):
- server_args.attention_backend = "trtllm_mla"
- else:
- server_args.attention_backend = "flashinfer"
- except:
- server_args.attention_backend = "flashinfer"
- else:
- server_args.attention_backend = "flashinfer"
+ server_args.attention_backend = "trtllm_mla"
elif _is_hip:
head_num = self.model_config.get_num_kv_heads(self.tp_size)
# TODO current aiter only support head number 16 or 128 head number
From ba1220d8f0007b19ec9cb315fe2098957f042f65 Mon Sep 17 00:00:00 2001
From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Date: Mon, 14 Jul 2025 20:02:01 -0700
Subject: [PATCH 06/27] add utils.py modification
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
---
python/sglang/srt/layers/attention/utils.py | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py
index 71633d12dce5..35b5e2d25225 100644
--- a/python/sglang/srt/layers/attention/utils.py
+++ b/python/sglang/srt/layers/attention/utils.py
@@ -1,6 +1,10 @@
import triton
import triton.language as tl
+# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
+# Number of pages that the kernel writes per iteration.
+# Exposed here so other Python modules can import it instead of hard-coding 64.
+NUM_PAGE_PER_BLOCK = 64
@triton.jit
def create_flashinfer_kv_indices_triton(
@@ -53,6 +57,7 @@ def create_flashmla_kv_indices_triton(
PAGED_SIZE: tl.constexpr = 64,
):
BLOCK_SIZE: tl.constexpr = 4096
+ # Keep in sync with module-level NUM_PAGE_PER_BLOCK constant above
NUM_PAGE_PER_BLOCK: tl.constexpr = 64
pid = tl.program_id(axis=0)
From f67fe412f26f5e232d4730a03af1f504c803500a Mon Sep 17 00:00:00 2001
From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Date: Mon, 14 Jul 2025 20:05:38 -0700
Subject: [PATCH 07/27] pre-commit
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
---
.../attention/trtllm_gen_mla_backend.py | 135 ++++++----
python/sglang/srt/layers/attention/utils.py | 3 +-
.../attention/test_trtllm_gen_mla_backend.py | 255 ++++++++++++------
3 files changed, 250 insertions(+), 143 deletions(-)
mode change 100644 => 100755 python/sglang/srt/layers/attention/utils.py
diff --git a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
index eaf1aa1c2eef..8243fc09bf8d 100644
--- a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
@@ -4,15 +4,18 @@
Support attention backend for TRTLLM-Gen MLA kernels from flashinfer.
"""
+import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
-import math
import torch
import triton
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton, NUM_PAGE_PER_BLOCK
+from sglang.srt.layers.attention.utils import (
+ NUM_PAGE_PER_BLOCK,
+ create_flashmla_kv_indices_triton,
+)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
@@ -26,7 +29,9 @@
from sglang.srt.speculative.spec_info import SpecInfo
# Constants
-TRTLLM_BLOCK_CONSTRAINT = 128 # TRT-LLM kernel constraint: block_num % (128 / block_size) == 0
+TRTLLM_BLOCK_CONSTRAINT = (
+ 128 # TRT-LLM kernel constraint: block_num % (128 / block_size) == 0
+)
DEFAULT_WORKSPACE_SIZE_MB = 128
DEFAULT_QUANTIZATION_SCALE = 1.0 # Default scale for FP8 quantization
DEFAULT_SM_SCALE = 1.0 # Default softmax scale
@@ -35,6 +40,7 @@
@dataclass
class TRTLLMGENMLADecodeMetadata:
"""Metadata for TRTLLM MLA decode operations."""
+
workspace: Optional[torch.Tensor] = None
block_kv_indices: Optional[torch.Tensor] = None
@@ -55,32 +61,32 @@ def __init__(
# Cache model config for easier access
config = model_runner.model_config
-
+
# Model parameters
self.num_q_heads = config.num_attention_heads // get_attention_tp_size()
self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
self.num_local_heads = config.num_attention_heads // get_attention_tp_size()
-
+
# MLA-specific dimensions
self.kv_lora_rank = config.kv_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.v_head_dim = config.v_head_dim
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
-
+
# Runtime parameters
self.scaling = config.scaling
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.page_size = model_runner.page_size
self.req_to_token = model_runner.req_to_token_pool.req_to_token
-
+
# Workspace allocation
self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
self.workspace_buffer = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
)
-
+
# CUDA graph state
self.decode_cuda_graph_metadata = {}
self.cuda_graph_kv_indices = None
@@ -89,15 +95,15 @@ def __init__(
def _calc_padded_blocks(self, max_seq_len: int) -> int:
"""
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
-
+
Args:
max_seq_len: Maximum sequence length in tokens
-
+
Returns:
Number of blocks padded to satisfy all constraints
"""
blocks = triton.cdiv(max_seq_len, self.page_size)
-
+
# TWO constraints require padding:
# 1. TRT-LLM kernel expects: block_num % (128 / block_size) == 0
# This is a hard requirement from the CUDA kernel implementation.
@@ -106,14 +112,14 @@ def _calc_padded_blocks(self, max_seq_len: int) -> int:
# unconditionally to positions [i*64, (i+1)*64) without per-element masking.
# If the row length isn't a multiple of 64, the final burst would overrun
# the allocated buffer and corrupt memory.
- #
+ #
# We need to satisfy BOTH constraints, so we take the LCM of:
# - trtllm_constraint = 128 // page_size
# - triton_constraint = 64
trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
triton_constraint = NUM_PAGE_PER_BLOCK
min_blocks = math.lcm(trtllm_constraint, triton_constraint)
-
+
if blocks % min_blocks != 0:
blocks = triton.cdiv(blocks, min_blocks) * min_blocks
return blocks
@@ -121,7 +127,7 @@ def _calc_padded_blocks(self, max_seq_len: int) -> int:
def _get_quantization_scales(self) -> tuple[float, float, float, float, float]:
"""
Get quantization scales for q, k, v, sm, and o tensors.
-
+
Returns:
Tuple of (q_scale, k_scale, v_scale, sm_scale, o_scale)
"""
@@ -129,20 +135,22 @@ def _get_quantization_scales(self) -> tuple[float, float, float, float, float]:
# For now, use default values for FP8
return (
DEFAULT_QUANTIZATION_SCALE, # q_scale
- DEFAULT_QUANTIZATION_SCALE, # k_scale
+ DEFAULT_QUANTIZATION_SCALE, # k_scale
DEFAULT_QUANTIZATION_SCALE, # v_scale
- DEFAULT_SM_SCALE, # sm_scale
+ DEFAULT_SM_SCALE, # sm_scale
DEFAULT_QUANTIZATION_SCALE, # o_scale
)
- def _prepare_kv_cache(self, layer: RadixAttention, forward_batch: ForwardBatch) -> torch.Tensor:
+ def _prepare_kv_cache(
+ self, layer: RadixAttention, forward_batch: ForwardBatch
+ ) -> torch.Tensor:
"""
Prepare KV cache tensor in the format expected by TRT-LLM kernel.
-
+
Args:
layer: Attention layer
forward_batch: Forward batch info
-
+
Returns:
KV cache tensor shaped (num_pages, 2, page_size, kv_cache_dim)
"""
@@ -152,26 +160,25 @@ def _prepare_kv_cache(self, layer: RadixAttention, forward_batch: ForwardBatch)
return torch.stack([pages, pages], dim=1)
def _prepare_query_tensor(
- self,
- q: torch.Tensor,
- q_rope: Optional[torch.Tensor],
- layer: RadixAttention
+ self, q: torch.Tensor, q_rope: Optional[torch.Tensor], layer: RadixAttention
) -> torch.Tensor:
"""
Prepare query tensor in the format expected by TRT-LLM kernel.
-
+
Args:
q: Query tensor
q_rope: Optional RoPE query tensor
layer: Attention layer
-
+
Returns:
Query tensor with concatenated NOPE and RoPE parts
"""
if q_rope is not None:
# q contains the NOPE part (v_head_dim)
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
- q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim)
+ q_rope = q_rope.view(
+ -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+ )
else:
reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = reshaped_q[:, :, : layer.v_head_dim]
@@ -180,30 +187,30 @@ def _prepare_query_tensor(
return torch.cat([q_nope, q_rope], dim=-1)
def _create_block_kv_indices(
- self,
- batch_size: int,
- max_blocks: int,
+ self,
+ batch_size: int,
+ max_blocks: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
- device: torch.device
+ device: torch.device,
) -> torch.Tensor:
"""
Create block KV indices tensor using Triton kernel.
-
+
Args:
batch_size: Batch size
max_blocks: Maximum number of blocks per sequence
req_pool_indices: Request pool indices
seq_lens: Sequence lengths
device: Target device
-
+
Returns:
Block KV indices tensor
"""
block_kv_indices = torch.full(
(batch_size, max_blocks), -1, dtype=torch.int32, device=device
)
-
+
create_flashmla_kv_indices_triton[(batch_size,)](
self.req_to_token,
req_pool_indices,
@@ -214,7 +221,7 @@ def _create_block_kv_indices(
max_blocks,
self.page_size,
)
-
+
return block_kv_indices
def init_cuda_graph_state(
@@ -225,7 +232,7 @@ def init_cuda_graph_state(
):
"""Initialize CUDA graph state for TRTLLM MLA."""
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
-
+
self.cuda_graph_kv_indices = torch.full(
(max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
)
@@ -247,7 +254,7 @@ def init_forward_metadata_capture_cuda_graph(
if forward_mode.is_decode_or_idle() and spec_info is None:
max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
-
+
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
@@ -258,7 +265,7 @@ def init_forward_metadata_capture_cuda_graph(
max_seqlen_pad,
self.page_size,
)
-
+
metadata = TRTLLMGENMLADecodeMetadata(
self.cuda_graph_workspace, block_kv_indices
)
@@ -266,7 +273,13 @@ def init_forward_metadata_capture_cuda_graph(
self.forward_metadata = metadata
else:
super().init_forward_metadata_capture_cuda_graph(
- bs, num_tokens, req_pool_indices, seq_lens, encoder_lens, forward_mode, spec_info
+ bs,
+ num_tokens,
+ req_pool_indices,
+ seq_lens,
+ encoder_lens,
+ forward_mode,
+ spec_info,
)
def init_forward_metadata_replay_cuda_graph(
@@ -283,7 +296,7 @@ def init_forward_metadata_replay_cuda_graph(
"""Replay CUDA graph with new inputs."""
if forward_mode.is_decode_or_idle() and spec_info is None:
metadata = self.decode_cuda_graph_metadata[bs]
-
+
# Update block indices for new sequences
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
@@ -295,12 +308,18 @@ def init_forward_metadata_replay_cuda_graph(
metadata.block_kv_indices.shape[1],
self.page_size,
)
-
+
self.forward_metadata = metadata
else:
super().init_forward_metadata_replay_cuda_graph(
- bs, req_pool_indices, seq_lens, seq_lens_sum,
- encoder_lens, forward_mode, spec_info, seq_lens_cpu
+ bs,
+ req_pool_indices,
+ seq_lens,
+ seq_lens_sum,
+ encoder_lens,
+ forward_mode,
+ spec_info,
+ seq_lens_cpu,
)
def get_cuda_graph_seq_len_fill_value(self) -> int:
@@ -309,9 +328,12 @@ def get_cuda_graph_seq_len_fill_value(self) -> int:
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize the metadata for a forward pass."""
- if forward_batch.forward_mode.is_decode_or_idle() and forward_batch.spec_info is None:
+ if (
+ forward_batch.forward_mode.is_decode_or_idle()
+ and forward_batch.spec_info is None
+ ):
bs = forward_batch.batch_size
-
+
# Get maximum sequence length
if getattr(forward_batch, "seq_lens_cpu", None) is not None:
max_seq = forward_batch.seq_lens_cpu.max().item()
@@ -320,8 +342,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
max_seqlen_pad = self._calc_padded_blocks(max_seq)
block_kv_indices = self._create_block_kv_indices(
- bs, max_seqlen_pad, forward_batch.req_pool_indices,
- forward_batch.seq_lens, forward_batch.seq_lens.device
+ bs,
+ max_seqlen_pad,
+ forward_batch.req_pool_indices,
+ forward_batch.seq_lens,
+ forward_batch.seq_lens.device,
)
self.forward_metadata = TRTLLMGENMLADecodeMetadata(
@@ -347,17 +372,22 @@ def forward_decode(
if k is not None and save_kv_cache:
cache_loc = forward_batch.out_cache_loc
if k_rope is not None:
- forward_batch.token_to_kv_pool.set_mla_kv_buffer(layer, cache_loc, k, k_rope)
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+ layer, cache_loc, k, k_rope
+ )
elif v is not None:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
# Prepare tensors for TRT-LLM kernel
query = self._prepare_query_tensor(q, q_rope, layer)
kv_cache = self._prepare_kv_cache(layer, forward_batch)
-
+
# Get metadata
- metadata = getattr(forward_batch, "decode_trtllm_mla_metadata", None) or self.forward_metadata
-
+ metadata = (
+ getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+ or self.forward_metadata
+ )
+
# Get quantization scales
q_scale, k_scale, v_scale, sm_scale, o_scale = self._get_quantization_scales()
@@ -383,10 +413,9 @@ def forward_decode(
# Extract value projection part and reshape
raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-
+
# Truncate if needed
if output.shape[0] > forward_batch.batch_size:
output = output[: forward_batch.batch_size]
- return output
-
+ return output
diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py
old mode 100644
new mode 100755
index 35b5e2d25225..700248efdc91
--- a/python/sglang/srt/layers/attention/utils.py
+++ b/python/sglang/srt/layers/attention/utils.py
@@ -4,7 +4,8 @@
# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
# Number of pages that the kernel writes per iteration.
# Exposed here so other Python modules can import it instead of hard-coding 64.
-NUM_PAGE_PER_BLOCK = 64
+NUM_PAGE_PER_BLOCK = 64
+
@triton.jit
def create_flashinfer_kv_indices_triton(
diff --git a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py
index acf752fd23ce..b5a50cffb96e 100644
--- a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py
+++ b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py
@@ -11,24 +11,26 @@
The tests use unittest with subTest for parameterized testing, following the sglang test patterns.
"""
+
import unittest
-import torch
-import numpy as np
from types import SimpleNamespace
+import numpy as np
+import torch
+
from sglang.srt.layers import dp_attention as _dp_attn
# Patch DP-attention globals before importing backends
_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test
from sglang.srt.configs.model_config import AttentionArch
-from sglang.srt.layers.attention.trtllm_gen_mla_backend import TRTLLMGENMLABackend
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.trtllm_gen_mla_backend import TRTLLMGENMLABackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.test.test_utils import CustomTestCase
from sglang.srt.utils import is_flashinfer_available
+from sglang.test.test_utils import CustomTestCase
class MockModelRunner:
@@ -60,7 +62,9 @@ def __init__(self, page_size: int):
# Req-to-token pool
max_bs = 64
max_ctx = self.model_config.context_len
- req_to_token = torch.arange(max_bs * max_ctx, dtype=torch.int32, device=self.device).reshape(max_bs, max_ctx)
+ req_to_token = torch.arange(
+ max_bs * max_ctx, dtype=torch.int32, device=self.device
+ ).reshape(max_bs, max_ctx)
self.req_to_token_pool = type(
"TokenPool",
(),
@@ -69,7 +73,7 @@ def __init__(self, page_size: int):
"req_to_token": req_to_token,
},
)
-
+
# KV-token pool (MLA)
self.token_to_kv_pool = MLATokenToKVPool(
size=max_bs * max_ctx,
@@ -85,27 +89,35 @@ def __init__(self, page_size: int):
def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
"""Compare outputs with detailed analysis."""
-
+
# Basic checks
- assert trtllm_out.shape == reference_out.shape, f"Shape mismatch: {trtllm_out.shape} vs {reference_out.shape}"
- assert trtllm_out.dtype == reference_out.dtype, f"Dtype mismatch: {trtllm_out.dtype} vs {reference_out.dtype}"
-
+ assert (
+ trtllm_out.shape == reference_out.shape
+ ), f"Shape mismatch: {trtllm_out.shape} vs {reference_out.shape}"
+ assert (
+ trtllm_out.dtype == reference_out.dtype
+ ), f"Dtype mismatch: {trtllm_out.dtype} vs {reference_out.dtype}"
+
# Check for NaN/Inf
assert not torch.isnan(trtllm_out).any(), "TRTLLM output contains NaN"
assert not torch.isnan(reference_out).any(), "Reference output contains NaN"
assert not torch.isinf(trtllm_out).any(), "TRTLLM output contains Inf"
assert not torch.isinf(reference_out).any(), "Reference output contains Inf"
-
+
# Element-wise differences
diff = (trtllm_out - reference_out).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
-
+
# Check numerical equivalence
- all_close = torch.allclose(trtllm_out, reference_out, rtol=tolerance, atol=tolerance)
-
+ all_close = torch.allclose(
+ trtllm_out, reference_out, rtol=tolerance, atol=tolerance
+ )
+
if not all_close:
- print(f"Comparison failed: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, tolerance={tolerance}")
+ print(
+ f"Comparison failed: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, tolerance={tolerance}"
+ )
# Find top differences for debugging
flat_diff = diff.flatten()
top_diff_indices = torch.topk(flat_diff, k=min(5, flat_diff.numel())).indices
@@ -114,13 +126,17 @@ def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
idx_tuple = np.unravel_index(idx.cpu().numpy(), trtllm_out.shape)
trt_val = trtllm_out[idx_tuple].item()
ref_val = reference_out[idx_tuple].item()
- print(f" [{idx_tuple}]: TRTLLM={trt_val:.6f}, Reference={ref_val:.6f}, diff={abs(trt_val-ref_val):.6f}")
-
+ print(
+ f" [{idx_tuple}]: TRTLLM={trt_val:.6f}, Reference={ref_val:.6f}, diff={abs(trt_val-ref_val):.6f}"
+ )
+
return all_close
-@unittest.skipIf(not torch.cuda.is_available() or not is_flashinfer_available(),
- "CUDA + flashinfer required")
+@unittest.skipIf(
+ not torch.cuda.is_available() or not is_flashinfer_available(),
+ "CUDA + flashinfer required",
+)
class TestTRTLLMMLAClean(CustomTestCase):
"""Test TRTLLM MLA backend against FlashInfer MLA backend (reference)."""
@@ -129,14 +145,14 @@ def setUp(self):
self.device = "cuda"
self.dtype = torch.bfloat16
self.page_size = 32
-
+
# Create model runner and backends
self.model_runner_trtllm = MockModelRunner(self.page_size)
self.model_runner_reference = MockModelRunner(self.page_size)
-
+
self.trtllm_backend = TRTLLMGENMLABackend(self.model_runner_trtllm)
self.reference_backend = FlashInferMLAAttnBackend(self.model_runner_reference)
-
+
# Create RadixAttention layer
self.layer = RadixAttention(
num_heads=128,
@@ -151,7 +167,9 @@ def setUp(self):
def _create_qkv_tensors(self, batch_size):
"""Create Q, K, V tensors for testing."""
head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim
- q = torch.randn((batch_size, 128, head_dim), dtype=self.dtype, device=self.device)
+ q = torch.randn(
+ (batch_size, 128, head_dim), dtype=self.dtype, device=self.device
+ )
k = torch.randn((batch_size, 1, head_dim), dtype=self.dtype, device=self.device)
v = torch.randn((batch_size, 1, 512), dtype=self.dtype, device=self.device)
return q, k, v
@@ -176,73 +194,107 @@ def _create_forward_batch(self, batch_size, seq_lens, backend, model_runner):
def _populate_kv_cache(self, batch_size, seq_lens, model_runners):
"""Populate KV cache with identical data for both backends."""
torch.manual_seed(42) # Fixed seed for reproducible cache
-
+
for model_runner in model_runners:
torch.manual_seed(42) # Reset seed for each backend
for i in range(batch_size):
seq_len = int(seq_lens[i].item())
for token_idx in range(seq_len - 1):
# Create random K components for MLA
- cache_k_nope = torch.randn((1, 128), dtype=self.dtype, device=self.device)
- cache_k_rope = torch.randn((1, 64), dtype=self.dtype, device=self.device)
-
+ cache_k_nope = torch.randn(
+ (1, 128), dtype=self.dtype, device=self.device
+ )
+ cache_k_rope = torch.randn(
+ (1, 64), dtype=self.dtype, device=self.device
+ )
+
# Calculate cache location
- cache_loc = model_runner.req_to_token_pool.req_to_token[i, token_idx]
-
+ cache_loc = model_runner.req_to_token_pool.req_to_token[
+ i, token_idx
+ ]
+
# Save to KV cache
model_runner.token_to_kv_pool.set_mla_kv_buffer(
- self.layer,
- cache_loc.unsqueeze(0),
- cache_k_nope.squeeze(0),
- cache_k_rope.squeeze(0)
+ self.layer,
+ cache_loc.unsqueeze(0),
+ cache_k_nope.squeeze(0),
+ cache_k_rope.squeeze(0),
)
def test_decode_output_match(self):
"""Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
# Test different batch sizes and sequence lengths
test_cases = [
- (1, 64), (4, 64), (16, 64), (32, 64),
- (1, 128), (4, 128), (16, 128), (32, 128),
- (1, 256), (4, 256), (16, 256), (32, 256),
+ (1, 64),
+ (4, 64),
+ (16, 64),
+ (32, 64),
+ (1, 128),
+ (4, 128),
+ (16, 128),
+ (32, 128),
+ (1, 256),
+ (4, 256),
+ (16, 256),
+ (32, 256),
]
-
+
for batch_size, max_seq_len in test_cases:
with self.subTest(batch_size=batch_size, max_seq_len=max_seq_len):
# Create identical sequence lengths for both backends
torch.manual_seed(42)
- seq_lens = torch.randint(1, max_seq_len, (batch_size,), device=self.device)
+ seq_lens = torch.randint(
+ 1, max_seq_len, (batch_size,), device=self.device
+ )
seq_lens[0] = max_seq_len # Ensure at least one max length
-
+
# Create forward batches with identical inputs
fb_trtllm = self._create_forward_batch(
- batch_size, seq_lens.clone(), self.trtllm_backend, self.model_runner_trtllm
+ batch_size,
+ seq_lens.clone(),
+ self.trtllm_backend,
+ self.model_runner_trtllm,
)
fb_reference = self._create_forward_batch(
- batch_size, seq_lens.clone(), self.reference_backend, self.model_runner_reference
+ batch_size,
+ seq_lens.clone(),
+ self.reference_backend,
+ self.model_runner_reference,
)
-
+
# Initialize metadata for both backends
self.trtllm_backend.init_forward_metadata(fb_trtllm)
self.reference_backend.init_forward_metadata(fb_reference)
-
+
# Populate both KV caches identically
- self._populate_kv_cache(batch_size, seq_lens, [self.model_runner_trtllm, self.model_runner_reference])
-
+ self._populate_kv_cache(
+ batch_size,
+ seq_lens,
+ [self.model_runner_trtllm, self.model_runner_reference],
+ )
+
# Create Q, K, V tensors for current decode step
torch.manual_seed(123) # Fixed seed for Q, K, V
q, k, v = self._create_qkv_tensors(batch_size)
-
+
# Run forward decode on both backends
- out_trtllm = self.trtllm_backend.forward_decode(q.clone(), k.clone(), v, self.layer, fb_trtllm)
- out_reference = self.reference_backend.forward_decode(q.clone(), k.clone(), v.clone(), self.layer, fb_reference)
-
+ out_trtllm = self.trtllm_backend.forward_decode(
+ q.clone(), k.clone(), v, self.layer, fb_trtllm
+ )
+ out_reference = self.reference_backend.forward_decode(
+ q.clone(), k.clone(), v.clone(), self.layer, fb_reference
+ )
+
# Compare outputs
- comparison_passed = compare_outputs(out_trtllm, out_reference, tolerance=1e-2)
-
- self.assertTrue(comparison_passed,
+ comparison_passed = compare_outputs(
+ out_trtllm, out_reference, tolerance=1e-2
+ )
+
+ self.assertTrue(
+ comparison_passed,
f"TRTLLM and Reference outputs differ beyond tolerance. "
f"batch_size={batch_size}, max_seq_len={max_seq_len}, "
- f"Max diff: {(out_trtllm - out_reference).abs().max().item()}"
+ f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
)
def test_different_page_sizes(self):
@@ -250,35 +302,43 @@ def test_different_page_sizes(self):
page_sizes = [32, 64]
batch_size = 8
max_seq_len = 128
-
+
for page_size in page_sizes:
with self.subTest(page_size=page_size):
# Create model runner with specific page size
model_runner = MockModelRunner(page_size)
backend = TRTLLMGENMLABackend(model_runner)
-
+
# Create sequence lengths
torch.manual_seed(42)
- seq_lens = torch.randint(1, max_seq_len, (batch_size,), device=self.device)
+ seq_lens = torch.randint(
+ 1, max_seq_len, (batch_size,), device=self.device
+ )
seq_lens[0] = max_seq_len
-
+
# Create forward batch
- fb = self._create_forward_batch(batch_size, seq_lens, backend, model_runner)
+ fb = self._create_forward_batch(
+ batch_size, seq_lens, backend, model_runner
+ )
backend.init_forward_metadata(fb)
-
+
# Populate KV cache
self._populate_kv_cache(batch_size, seq_lens, [model_runner])
-
+
# Create Q, K, V tensors
torch.manual_seed(123)
q, k, v = self._create_qkv_tensors(batch_size)
-
+
# Run forward decode
output = backend.forward_decode(q, k, v, self.layer, fb)
-
+
# Basic checks
expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim
- self.assertEqual(output.shape, expected_shape, f"Output shape mismatch: {output.shape} vs {expected_shape}")
+ self.assertEqual(
+ output.shape,
+ expected_shape,
+ f"Output shape mismatch: {output.shape} vs {expected_shape}",
+ )
self.assertFalse(torch.isnan(output).any(), "Output contains NaN")
self.assertFalse(torch.isinf(output).any(), "Output contains Inf")
@@ -286,23 +346,25 @@ def test_basic_functionality(self):
"""Test basic functionality with minimal setup."""
batch_size = 2
max_seq_len = 32
-
+
# Create sequence lengths
seq_lens = torch.tensor([max_seq_len, max_seq_len // 2], device=self.device)
-
+
# Create forward batch
- fb = self._create_forward_batch(batch_size, seq_lens, self.trtllm_backend, self.model_runner_trtllm)
+ fb = self._create_forward_batch(
+ batch_size, seq_lens, self.trtllm_backend, self.model_runner_trtllm
+ )
self.trtllm_backend.init_forward_metadata(fb)
-
+
# Populate KV cache
self._populate_kv_cache(batch_size, seq_lens, [self.model_runner_trtllm])
-
+
# Create Q, K, V tensors
q, k, v = self._create_qkv_tensors(batch_size)
-
+
# Run forward decode
output = self.trtllm_backend.forward_decode(q, k, v, self.layer, fb)
-
+
# Basic checks
expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim
self.assertEqual(output.shape, expected_shape)
@@ -321,22 +383,26 @@ def test_forward_decode_shape_sanity(self):
(1, 64, 32),
(32, 1024, 64),
]
-
+
for batch_size, seq_len, page_size in test_configs:
- with self.subTest(batch_size=batch_size, seq_len=seq_len, page_size=page_size):
+ with self.subTest(
+ batch_size=batch_size, seq_len=seq_len, page_size=page_size
+ ):
# Create model runner with specific page size
model_runner = MockModelRunner(page_size)
backend = TRTLLMGENMLABackend(model_runner)
-
+
# Random seq lens (ensure one matches max)
torch.manual_seed(42)
seq_lens = torch.randint(1, seq_len, (batch_size,), device=self.device)
seq_lens[0] = seq_len
-
+
# Create forward batch
fb = ForwardBatch(
batch_size=batch_size,
- input_ids=torch.randint(0, 100, (batch_size, 1), device=self.device),
+ input_ids=torch.randint(
+ 0, 100, (batch_size, 1), device=self.device
+ ),
out_cache_loc=torch.arange(batch_size, device=self.device),
seq_lens_sum=int(seq_lens.sum().item()),
forward_mode=ForwardMode.DECODE,
@@ -347,15 +413,19 @@ def test_forward_decode_shape_sanity(self):
)
fb.req_to_token_pool = model_runner.req_to_token_pool
fb.token_to_kv_pool = model_runner.token_to_kv_pool
-
+
backend.init_forward_metadata(fb)
-
+
# Create Q, K, V tensors
head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim
- q = torch.randn((batch_size, 128, head_dim), dtype=self.dtype, device=self.device)
- k = torch.randn((batch_size, 1, head_dim), dtype=self.dtype, device=self.device)
+ q = torch.randn(
+ (batch_size, 128, head_dim), dtype=self.dtype, device=self.device
+ )
+ k = torch.randn(
+ (batch_size, 1, head_dim), dtype=self.dtype, device=self.device
+ )
v = None # TRTLLM MLA decode kernel ignores v
-
+
# Create layer
layer = RadixAttention(
num_heads=128,
@@ -366,21 +436,28 @@ def test_forward_decode_shape_sanity(self):
v_head_dim=512,
prefix="attn_mqa",
)
-
+
# Run forward decode
output = backend.forward_decode(q, k, v, layer, fb)
-
+
# Shape and sanity checks
expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim
- self.assertEqual(output.shape, expected_shape,
- f"Output shape mismatch for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})")
+ self.assertEqual(
+ output.shape,
+ expected_shape,
+ f"Output shape mismatch for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})",
+ )
self.assertEqual(output.dtype, self.dtype)
self.assertEqual(output.device.type, "cuda")
- self.assertFalse(torch.isnan(output).any(),
- f"Output contains NaN for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})")
- self.assertFalse(torch.isinf(output).any(),
- f"Output contains Inf for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})")
+ self.assertFalse(
+ torch.isnan(output).any(),
+ f"Output contains NaN for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})",
+ )
+ self.assertFalse(
+ torch.isinf(output).any(),
+ f"Output contains Inf for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})",
+ )
if __name__ == "__main__":
- unittest.main()
\ No newline at end of file
+ unittest.main()
From f69b3e07396adaee6c44a54c0acd2274921ea3d9 Mon Sep 17 00:00:00 2001
From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Date: Mon, 21 Jul 2025 10:56:31 -0700
Subject: [PATCH 08/27] some neat picks
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
---
.../attention/trtllm_gen_mla_backend.py | 46 +-
python/sglang/srt/layers/attention/utils.py | 4 +-
.../attention/test_trtllm_gen_mla_backend.py | 840 +++++++++++++-----
3 files changed, 640 insertions(+), 250 deletions(-)
mode change 100644 => 100755 python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
diff --git a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
old mode 100644
new mode 100755
index 8243fc09bf8d..d80614ccd27a
--- a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
@@ -13,7 +13,7 @@
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.utils import (
- NUM_PAGE_PER_BLOCK,
+ TRITON_PAD_NUM_PAGE_PER_BLOCK,
create_flashmla_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -29,14 +29,17 @@
from sglang.srt.speculative.spec_info import SpecInfo
# Constants
-TRTLLM_BLOCK_CONSTRAINT = (
- 128 # TRT-LLM kernel constraint: block_num % (128 / block_size) == 0
-)
-DEFAULT_WORKSPACE_SIZE_MB = 128
-DEFAULT_QUANTIZATION_SCALE = 1.0 # Default scale for FP8 quantization
-DEFAULT_SM_SCALE = 1.0 # Default softmax scale
+DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
+
+# Block constraint from flashinfer requirements
+# See: https://github.com/flashinfer-ai/flashinfer/blob/fe29ed63cb923f25cae70ef83f3fd16139305b35/flashinfer/decode.py#L2057
+TRTLLM_BLOCK_CONSTRAINT = 128
+# Quantization scale (1.0 for fp8->bf16 and bf16->bf16 conversions)
+DEFAULT_QUANTIZATION_SCALE = 1.0
+# Softmax scale (1.0 since TRTLLM applies 1/sqrt(head_dim) internally)
+DEFAULT_SM_SCALE = 1.0
@dataclass
class TRTLLMGENMLADecodeMetadata:
"""Metadata for TRTLLM MLA decode operations."""
@@ -46,7 +49,7 @@ class TRTLLMGENMLADecodeMetadata:
class TRTLLMGENMLABackend(FlashInferMLAAttnBackend):
- """TRTLLM MLA attention kernels from flashinfer."""
+ """TRTLLM MLA attention kernel from flashinfer."""
def __init__(
self,
@@ -59,7 +62,6 @@ def __init__(
model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
)
- # Cache model config for easier access
config = model_runner.model_config
# Model parameters
@@ -104,24 +106,15 @@ def _calc_padded_blocks(self, max_seq_len: int) -> int:
"""
blocks = triton.cdiv(max_seq_len, self.page_size)
- # TWO constraints require padding:
- # 1. TRT-LLM kernel expects: block_num % (128 / block_size) == 0
- # This is a hard requirement from the CUDA kernel implementation.
- # 2. Our Triton helper `create_flashmla_kv_indices_triton` builds page tables
- # in fixed bursts of 64 indices per outer loop iteration. Each burst writes
- # unconditionally to positions [i*64, (i+1)*64) without per-element masking.
- # If the row length isn't a multiple of 64, the final burst would overrun
- # the allocated buffer and corrupt memory.
- #
- # We need to satisfy BOTH constraints, so we take the LCM of:
- # - trtllm_constraint = 128 // page_size
- # - triton_constraint = 64
+ # Apply dual constraints (take LCM to satisfy both):
+ # 1. TRT-LLM: block_num % (128 / page_size) == 0
+ # Reference: https://github.com/flashinfer-ai/flashinfer/blob/fe29ed63cb923f25cae70ef83f3fd16139305b35/flashinfer/decode.py#L2057
+ # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
- triton_constraint = NUM_PAGE_PER_BLOCK
- min_blocks = math.lcm(trtllm_constraint, triton_constraint)
-
- if blocks % min_blocks != 0:
- blocks = triton.cdiv(blocks, min_blocks) * min_blocks
+ constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
+
+ if blocks % constraint_lcm != 0:
+ blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
return blocks
def _get_quantization_scales(self) -> tuple[float, float, float, float, float]:
@@ -410,6 +403,7 @@ def forward_decode(
o_scale=o_scale,
)
+ # TODO: add a unit test covering the v_head_dim slice/reshape below
# Extract value projection part and reshape
raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py
index 700248efdc91..0cfb6a359cf2 100755
--- a/python/sglang/srt/layers/attention/utils.py
+++ b/python/sglang/srt/layers/attention/utils.py
@@ -4,7 +4,7 @@
# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
# Number of pages that the kernel writes per iteration.
# Exposed here so other Python modules can import it instead of hard-coding 64.
-NUM_PAGE_PER_BLOCK = 64
+TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
@triton.jit
@@ -58,7 +58,7 @@ def create_flashmla_kv_indices_triton(
PAGED_SIZE: tl.constexpr = 64,
):
BLOCK_SIZE: tl.constexpr = 4096
- # Keep in sync with module-level NUM_PAGE_PER_BLOCK constant above
+ # Keep in sync with module-level TRITON_PAD_NUM_PAGE_PER_BLOCK constant above
NUM_PAGE_PER_BLOCK: tl.constexpr = 64
pid = tl.program_id(axis=0)
diff --git a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py
index b5a50cffb96e..e2ea8d1da744 100644
--- a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py
+++ b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py
@@ -1,17 +1,3 @@
-"""
-Clean test suite for TRTLLM MLA backend.
-
-This test file provides comprehensive testing for the TRTLLM MLA (Multi-Head Latent Attention) backend:
-
-1. test_basic_functionality: Basic smoke test with minimal setup
-2. test_decode_output_match: Compares TRTLLM MLA output against FlashInfer MLA reference
- across different batch sizes and sequence lengths
-3. test_different_page_sizes: Tests consistency across different page sizes
-4. test_forward_decode_shape_sanity: Shape and sanity checks across various configurations
-
-The tests use unittest with subTest for parameterized testing, following the sglang test patterns.
-"""
-
import unittest
from types import SimpleNamespace
@@ -25,7 +11,7 @@
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.trtllm_gen_mla_backend import TRTLLMGENMLABackend
+from sglang.srt.layers.attention.trtllm_gen_mla_backend import TRTLLMGENMLABackend, TRTLLMGENMLADecodeMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -33,34 +19,166 @@
from sglang.test.test_utils import CustomTestCase
+# Global configuration for all tests
+DEFAULT_CONFIG = {
+ "device": "cuda",
+ "dtype": torch.bfloat16,
+ "kv_cache_dtype": torch.bfloat16,
+ "context_len": 2048,
+ "max_bs": 64,
+ "tolerance": 1e-2,
+ "seed_cache": 42,
+ "seed_qkv": 123,
+ # MLA model config (TRTLLM MLA has fixed constraints)
+ "num_attention_heads": 128,
+ "kv_lora_rank": 512,
+ "qk_nope_head_dim": 128,
+ "qk_rope_head_dim": 64,
+ "v_head_dim": 512,
+ "num_kv_heads": 1,
+ "layer_id": 0,
+}
+
+# Centralized test cases for different test scenarios
+TEST_CASES = {
+ "basic_functionality": [
+ {
+ "name": "single",
+ "batch_size": 1,
+ "max_seq_len": 32,
+ "page_size": 32,
+ "description": "Minimal smoke test",
+ },
+ {
+ "name": "batch",
+ "batch_size": 32,
+ "max_seq_len": 128,
+ "page_size": 32,
+ "description": "Medium-scale batch",
+ },
+ ],
+
+ "decode_output_match": [
+ {
+ "name": "single",
+ "batch_size": 1,
+ "max_seq_len": 64,
+ "page_size": 32,
+ "description": "Single vs reference",
+ },
+ {
+ "name": "batch",
+ "batch_size": 32,
+ "max_seq_len": 64,
+ "page_size": 32,
+ "description": "Batch vs reference",
+ }
+ ],
+
+ "page_size_consistency": [
+ # Only 32 and 64 supported for now https://github.com/flashinfer-ai/flashinfer/blob/fe29ed63cb923f25cae70ef83f3fd16139305b35/flashinfer/decode.py#L2115
+ # TODO: Test 16 and 128. Pending cubins
+ {
+ "name": "page_32",
+ "batch_size": 8,
+ "max_seq_len": 128,
+ "page_size": 32,
+ "description": "32-token pages",
+ },
+ {
+ "name": "page_64",
+ "batch_size": 8,
+ "max_seq_len": 128,
+ "page_size": 64,
+ "description": "64-token pages",
+ },
+ ],
+
+ "shape_sanity_tests": [
+ {
+ "name": "basic",
+ "batch_size": 1,
+ "max_seq_len": 128,
+ "page_size": 32,
+ "description": "Single sequence",
+ },
+ {
+ "name": "basic_different_pagesize",
+ "batch_size": 1,
+ "max_seq_len": 128,
+ "page_size": 64,
+ "description": "Different page size",
+ },
+ {
+ "name": "batch",
+ "batch_size": 8,
+ "max_seq_len": 128,
+ "page_size": 32,
+ "description": "Batch shapes",
+ },
+ ],
+
+ "metadata_tests": [
+ {
+ "name": "single_sequence",
+ "batch_size": 1,
+ "max_seq_len": 64,
+ "page_size": 32,
+ "description": "Single sequence metadata",
+ },
+ {
+ "name": "batch_mixed_lengths",
+ "batch_size": 8,
+ "max_seq_len": 128,
+ "page_size": 32,
+ "description": "Mixed sequence lengths",
+ },
+ {
+ "name": "large_batch",
+ "batch_size": 32,
+ "max_seq_len": 256,
+ "page_size": 64,
+ "description": "Large batch stress test",
+ },
+ {
+ "name": "edge_case_short",
+ "batch_size": 4,
+ "max_seq_len": 16,
+ "page_size": 32,
+ "description": "Sub-page sequences",
+ },
+ ],
+}
+
+
class MockModelRunner:
"""Minimal fake ModelRunner for testing MLA backends."""
- def __init__(self, page_size: int):
- self.device = "cuda"
- self.dtype = torch.bfloat16
- self.kv_cache_dtype = torch.bfloat16
- self.page_size = page_size
+ def __init__(self, config):
+ self.device = config["device"]
+ self.dtype = config["dtype"]
+ self.kv_cache_dtype = config["kv_cache_dtype"]
+ self.page_size = config["page_size"]
# Model-config stub with MLA attributes
self.model_config = type(
"ModelConfig",
(),
{
- "context_len": 2048,
+ "context_len": config["context_len"],
"attention_arch": AttentionArch.MLA,
- "num_attention_heads": 128,
- "kv_lora_rank": 512,
- "qk_nope_head_dim": 128,
- "qk_rope_head_dim": 64,
- "v_head_dim": 512,
- "scaling": 1.0 / ((128 + 64) ** 0.5),
- "get_num_kv_heads": staticmethod(lambda _: 1),
+ "num_attention_heads": config["num_attention_heads"],
+ "kv_lora_rank": config["kv_lora_rank"],
+ "qk_nope_head_dim": config["qk_nope_head_dim"],
+ "qk_rope_head_dim": config["qk_rope_head_dim"],
+ "v_head_dim": config["v_head_dim"],
+ "scaling": 1.0 / ((config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) ** 0.5),
+ "get_num_kv_heads": staticmethod(lambda _: config["num_kv_heads"]),
},
)
# Req-to-token pool
- max_bs = 64
+ max_bs = config["max_bs"]
max_ctx = self.model_config.context_len
req_to_token = torch.arange(
max_bs * max_ctx, dtype=torch.int32, device=self.device
@@ -77,10 +195,10 @@ def __init__(self, page_size: int):
# KV-token pool (MLA)
self.token_to_kv_pool = MLATokenToKVPool(
size=max_bs * max_ctx,
- page_size=page_size,
+ page_size=config["page_size"],
dtype=self.kv_cache_dtype,
- kv_lora_rank=512,
- qk_rope_head_dim=64,
+ kv_lora_rank=config["kv_lora_rank"],
+ qk_rope_head_dim=config["qk_rope_head_dim"],
layer_num=1,
device=self.device,
enable_memory_saver=False,
@@ -89,7 +207,7 @@ def __init__(self, page_size: int):
def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
"""Compare outputs with detailed analysis."""
-
+
# Basic checks
assert (
trtllm_out.shape == reference_out.shape
@@ -137,52 +255,67 @@ def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
not torch.cuda.is_available() or not is_flashinfer_available(),
"CUDA + flashinfer required",
)
-class TestTRTLLMMLAClean(CustomTestCase):
- """Test TRTLLM MLA backend against FlashInfer MLA backend (reference)."""
+class TestTRTLLMMLA(CustomTestCase):
+ """Test suite for TRTLLM MLA backend with centralized configuration."""
- def setUp(self):
- """Setup test fixtures."""
- self.device = "cuda"
- self.dtype = torch.bfloat16
- self.page_size = 32
+ def _merge_config(self, test_case):
+ """Merge test case with default configuration."""
+ config = DEFAULT_CONFIG.copy()
+ config.update(test_case)
+ return config
- # Create model runner and backends
- self.model_runner_trtllm = MockModelRunner(self.page_size)
- self.model_runner_reference = MockModelRunner(self.page_size)
+ def _create_model_components(self, config):
+ """Create model runners, backends, and layer for testing."""
+ # Create model runners
+ model_runner_trtllm = MockModelRunner(config)
+ model_runner_reference = MockModelRunner(config)
- self.trtllm_backend = TRTLLMGENMLABackend(self.model_runner_trtllm)
- self.reference_backend = FlashInferMLAAttnBackend(self.model_runner_reference)
+ # Create backends
+ trtllm_backend = TRTLLMGENMLABackend(model_runner_trtllm)
+ reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
# Create RadixAttention layer
- self.layer = RadixAttention(
- num_heads=128,
- head_dim=512 + 64, # kv_lora_rank + qk_rope_head_dim
- scaling=self.model_runner_trtllm.model_config.scaling,
- num_kv_heads=1,
- layer_id=0,
- v_head_dim=512,
+ layer = RadixAttention(
+ num_heads=config["num_attention_heads"],
+ head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
+ scaling=model_runner_trtllm.model_config.scaling,
+ num_kv_heads=config["num_kv_heads"],
+ layer_id=config["layer_id"],
+ v_head_dim=config["v_head_dim"],
prefix="attn_mqa",
)
- def _create_qkv_tensors(self, batch_size):
+ return model_runner_trtllm, model_runner_reference, trtllm_backend, reference_backend, layer
+
+ def _create_qkv_tensors(self, batch_size, config):
"""Create Q, K, V tensors for testing."""
- head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim
+ head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
+ device = config["device"]
+ dtype = config["dtype"]
+
q = torch.randn(
- (batch_size, 128, head_dim), dtype=self.dtype, device=self.device
+ (batch_size, config["num_attention_heads"], head_dim),
+ dtype=dtype, device=device
+ )
+ k = torch.randn(
+ (batch_size, config["num_kv_heads"], head_dim),
+ dtype=dtype, device=device
+ )
+ v = torch.randn(
+ (batch_size, config["num_kv_heads"], config["v_head_dim"]),
+ dtype=dtype, device=device
)
- k = torch.randn((batch_size, 1, head_dim), dtype=self.dtype, device=self.device)
- v = torch.randn((batch_size, 1, 512), dtype=self.dtype, device=self.device)
return q, k, v
- def _create_forward_batch(self, batch_size, seq_lens, backend, model_runner):
+ def _create_forward_batch(self, batch_size, seq_lens, backend, model_runner, config):
"""Create a forward batch for the given backend."""
fb = ForwardBatch(
batch_size=batch_size,
- input_ids=torch.randint(0, 100, (batch_size, 1), device=self.device),
- out_cache_loc=torch.arange(batch_size, device=self.device),
+ input_ids=torch.randint(0, 100, (batch_size, 1), device=config["device"]),
+ out_cache_loc=torch.arange(batch_size, device=config["device"]),
seq_lens_sum=int(seq_lens.sum().item()),
forward_mode=ForwardMode.DECODE,
- req_pool_indices=torch.arange(batch_size, device=self.device),
+ req_pool_indices=torch.arange(batch_size, device=config["device"]),
seq_lens=seq_lens,
seq_lens_cpu=seq_lens.cpu(),
attn_backend=backend,
@@ -191,273 +324,536 @@ def _create_forward_batch(self, batch_size, seq_lens, backend, model_runner):
fb.token_to_kv_pool = model_runner.token_to_kv_pool
return fb
- def _populate_kv_cache(self, batch_size, seq_lens, model_runners):
+ def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
"""Populate KV cache with identical data for both backends."""
- torch.manual_seed(42) # Fixed seed for reproducible cache
+ torch.manual_seed(config["seed_cache"]) # Fixed seed for reproducible cache
for model_runner in model_runners:
- torch.manual_seed(42) # Reset seed for each backend
+ torch.manual_seed(config["seed_cache"]) # Reset seed for each backend
for i in range(batch_size):
seq_len = int(seq_lens[i].item())
for token_idx in range(seq_len - 1):
# Create random K components for MLA
cache_k_nope = torch.randn(
- (1, 128), dtype=self.dtype, device=self.device
+ (1, config["qk_nope_head_dim"]),
+ dtype=config["dtype"], device=config["device"]
)
cache_k_rope = torch.randn(
- (1, 64), dtype=self.dtype, device=self.device
+ (1, config["qk_rope_head_dim"]),
+ dtype=config["dtype"], device=config["device"]
)
# Calculate cache location
- cache_loc = model_runner.req_to_token_pool.req_to_token[
- i, token_idx
- ]
+ cache_loc = model_runner.req_to_token_pool.req_to_token[i, token_idx]
# Save to KV cache
model_runner.token_to_kv_pool.set_mla_kv_buffer(
- self.layer,
+ layer,
cache_loc.unsqueeze(0),
cache_k_nope.squeeze(0),
cache_k_rope.squeeze(0),
)
+ def test_basic_functionality(self):
+ """Test basic functionality with minimal setup."""
+ print(f"\nRunning basic functionality tests...")
+
+ for test_case in TEST_CASES["basic_functionality"]:
+ with self.subTest(test_case=test_case["name"]):
+ print(f" Testing {test_case['name']}: {test_case['description']}")
+
+ config = self._merge_config(test_case)
+ batch_size = config["batch_size"]
+ max_seq_len = config["max_seq_len"]
+
+ # Create components
+ model_runner_trtllm, _, trtllm_backend, _, layer = self._create_model_components(config)
+
+ # Create sequence lengths - properly handle different batch sizes
+ if batch_size == 2:
+ seq_lens = torch.tensor(
+ [max_seq_len, max_seq_len // 2],
+ device=config["device"]
+ )
+ else:
+ # For larger batch sizes, create varied sequence lengths
+ torch.manual_seed(config["seed_cache"])
+ seq_lens = torch.randint(
+ max_seq_len // 2, max_seq_len + 1,
+ (batch_size,),
+ device=config["device"]
+ )
+ seq_lens[0] = max_seq_len # Ensure at least one max length
+
+ # Create forward batch
+ fb = self._create_forward_batch(
+ batch_size, seq_lens, trtllm_backend, model_runner_trtllm, config
+ )
+ trtllm_backend.init_forward_metadata(fb)
+
+ # Populate KV cache
+ self._populate_kv_cache(batch_size, seq_lens, [model_runner_trtllm], layer, config)
+
+ # Create Q, K, V tensors
+ torch.manual_seed(config["seed_qkv"])
+ q, k, v = self._create_qkv_tensors(batch_size, config)
+
+ # Run forward decode
+ output = trtllm_backend.forward_decode(q, k, v, layer, fb)
+
+ # Basic checks
+ expected_shape = (batch_size, config["num_attention_heads"] * config["v_head_dim"])
+ self.assertEqual(output.shape, expected_shape)
+ self.assertEqual(output.dtype, config["dtype"])
+ self.assertFalse(torch.isnan(output).any())
+ self.assertFalse(torch.isinf(output).any())
+
def test_decode_output_match(self):
"""Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
- # Test different batch sizes and sequence lengths
- test_cases = [
- (1, 64),
- (4, 64),
- (16, 64),
- (32, 64),
- (1, 128),
- (4, 128),
- (16, 128),
- (32, 128),
- (1, 256),
- (4, 256),
- (16, 256),
- (32, 256),
- ]
+ print(f"\nRunning decode output matching tests...")
+
+ for test_case in TEST_CASES["decode_output_match"]:
+ with self.subTest(test_case=test_case["name"]):
+ print(f" Testing {test_case['name']}: {test_case['description']}")
+
+ config = self._merge_config(test_case)
+ batch_size = config["batch_size"]
+ max_seq_len = config["max_seq_len"]
+
+ # Create components
+ (model_runner_trtllm, model_runner_reference,
+ trtllm_backend, reference_backend, layer) = self._create_model_components(config)
- for batch_size, max_seq_len in test_cases:
- with self.subTest(batch_size=batch_size, max_seq_len=max_seq_len):
# Create identical sequence lengths for both backends
- torch.manual_seed(42)
+ torch.manual_seed(config["seed_cache"])
seq_lens = torch.randint(
- 1, max_seq_len, (batch_size,), device=self.device
+ 1, max_seq_len, (batch_size,), device=config["device"]
)
seq_lens[0] = max_seq_len # Ensure at least one max length
# Create forward batches with identical inputs
fb_trtllm = self._create_forward_batch(
- batch_size,
- seq_lens.clone(),
- self.trtllm_backend,
- self.model_runner_trtllm,
+ batch_size, seq_lens.clone(), trtllm_backend, model_runner_trtllm, config
)
fb_reference = self._create_forward_batch(
- batch_size,
- seq_lens.clone(),
- self.reference_backend,
- self.model_runner_reference,
+ batch_size, seq_lens.clone(), reference_backend, model_runner_reference, config
)
# Initialize metadata for both backends
- self.trtllm_backend.init_forward_metadata(fb_trtllm)
- self.reference_backend.init_forward_metadata(fb_reference)
+ trtllm_backend.init_forward_metadata(fb_trtllm)
+ reference_backend.init_forward_metadata(fb_reference)
# Populate both KV caches identically
self._populate_kv_cache(
- batch_size,
- seq_lens,
- [self.model_runner_trtllm, self.model_runner_reference],
+ batch_size, seq_lens, [model_runner_trtllm, model_runner_reference], layer, config
)
# Create Q, K, V tensors for current decode step
- torch.manual_seed(123) # Fixed seed for Q, K, V
- q, k, v = self._create_qkv_tensors(batch_size)
+ torch.manual_seed(config["seed_qkv"])
+ q, k, v = self._create_qkv_tensors(batch_size, config)
# Run forward decode on both backends
- out_trtllm = self.trtllm_backend.forward_decode(
- q.clone(), k.clone(), v, self.layer, fb_trtllm
+ out_trtllm = trtllm_backend.forward_decode(
+ q.clone(), k.clone(), v.clone(), layer, fb_trtllm
)
- out_reference = self.reference_backend.forward_decode(
- q.clone(), k.clone(), v.clone(), self.layer, fb_reference
+ out_reference = reference_backend.forward_decode(
+ q.clone(), k.clone(), v.clone(), layer, fb_reference
)
# Compare outputs
comparison_passed = compare_outputs(
- out_trtllm, out_reference, tolerance=1e-2
+ out_trtllm, out_reference, tolerance=config["tolerance"]
)
self.assertTrue(
comparison_passed,
f"TRTLLM and Reference outputs differ beyond tolerance. "
- f"batch_size={batch_size}, max_seq_len={max_seq_len}, "
+ f"Config: {test_case['name']}, "
f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
)
- def test_different_page_sizes(self):
+ def test_page_size_consistency(self):
"""Test output consistency across different page sizes."""
- page_sizes = [32, 64]
- batch_size = 8
- max_seq_len = 128
-
- for page_size in page_sizes:
- with self.subTest(page_size=page_size):
- # Create model runner with specific page size
- model_runner = MockModelRunner(page_size)
- backend = TRTLLMGENMLABackend(model_runner)
+ print(f"\nRunning page size consistency tests...")
+
+ for test_case in TEST_CASES["page_size_consistency"]:
+ with self.subTest(test_case=test_case["name"]):
+ print(f" Testing {test_case['name']}: {test_case['description']}")
+
+ config = self._merge_config(test_case)
+ batch_size = config["batch_size"]
+ max_seq_len = config["max_seq_len"]
+
+ # Create components
+ model_runner, _, backend, _, layer = self._create_model_components(config)
# Create sequence lengths
- torch.manual_seed(42)
+ torch.manual_seed(config["seed_cache"])
seq_lens = torch.randint(
- 1, max_seq_len, (batch_size,), device=self.device
+ 1, max_seq_len, (batch_size,), device=config["device"]
)
seq_lens[0] = max_seq_len
# Create forward batch
fb = self._create_forward_batch(
- batch_size, seq_lens, backend, model_runner
+ batch_size, seq_lens, backend, model_runner, config
)
backend.init_forward_metadata(fb)
# Populate KV cache
- self._populate_kv_cache(batch_size, seq_lens, [model_runner])
+ self._populate_kv_cache(batch_size, seq_lens, [model_runner], layer, config)
# Create Q, K, V tensors
- torch.manual_seed(123)
- q, k, v = self._create_qkv_tensors(batch_size)
+ torch.manual_seed(config["seed_qkv"])
+ q, k, v = self._create_qkv_tensors(batch_size, config)
# Run forward decode
- output = backend.forward_decode(q, k, v, self.layer, fb)
+ output = backend.forward_decode(q, k, v, layer, fb)
- # Basic checks
- expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim
+ expected_shape = (batch_size, config["num_attention_heads"] * config["v_head_dim"])
self.assertEqual(
- output.shape,
- expected_shape,
- f"Output shape mismatch: {output.shape} vs {expected_shape}",
+ output.shape, expected_shape,
+ f"Output shape mismatch: {output.shape} vs {expected_shape}"
)
self.assertFalse(torch.isnan(output).any(), "Output contains NaN")
self.assertFalse(torch.isinf(output).any(), "Output contains Inf")
- def test_basic_functionality(self):
- """Test basic functionality with minimal setup."""
- batch_size = 2
- max_seq_len = 32
-
- # Create sequence lengths
- seq_lens = torch.tensor([max_seq_len, max_seq_len // 2], device=self.device)
-
- # Create forward batch
- fb = self._create_forward_batch(
- batch_size, seq_lens, self.trtllm_backend, self.model_runner_trtllm
- )
- self.trtllm_backend.init_forward_metadata(fb)
-
- # Populate KV cache
- self._populate_kv_cache(batch_size, seq_lens, [self.model_runner_trtllm])
-
- # Create Q, K, V tensors
- q, k, v = self._create_qkv_tensors(batch_size)
-
- # Run forward decode
- output = self.trtllm_backend.forward_decode(q, k, v, self.layer, fb)
-
- # Basic checks
- expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim
- self.assertEqual(output.shape, expected_shape)
- self.assertEqual(output.dtype, self.dtype)
- self.assertFalse(torch.isnan(output).any())
- self.assertFalse(torch.isinf(output).any())
-
- def test_forward_decode_shape_sanity(self):
- """Smoke test decode across several page sizes and batch configurations."""
- # Test configurations similar to the original test
- test_configs = [
- (16, 512, 32), # batch_size, seq_len, page_size
- (16, 512, 64),
- (8, 256, 32),
- (4, 128, 32),
- (1, 64, 32),
- (32, 1024, 64),
- ]
+ def test_shape_sanity(self):
+ """Smoke test decode across several configurations."""
+ print(f"\nRunning shape sanity tests...")
+
+ for test_case in TEST_CASES["shape_sanity_tests"]:
+ with self.subTest(test_case=test_case["name"]):
+ print(f" Testing {test_case['name']}: {test_case['description']}")
+
+ config = self._merge_config(test_case)
+ batch_size = config["batch_size"]
+ max_seq_len = config["max_seq_len"]
- for batch_size, seq_len, page_size in test_configs:
- with self.subTest(
- batch_size=batch_size, seq_len=seq_len, page_size=page_size
- ):
- # Create model runner with specific page size
- model_runner = MockModelRunner(page_size)
- backend = TRTLLMGENMLABackend(model_runner)
+ model_runner, _, backend, _, layer = self._create_model_components(config)
# Random seq lens (ensure one matches max)
- torch.manual_seed(42)
- seq_lens = torch.randint(1, seq_len, (batch_size,), device=self.device)
- seq_lens[0] = seq_len
+ torch.manual_seed(config["seed_cache"])
+ seq_lens = torch.randint(1, max_seq_len, (batch_size,), device=config["device"])
+ seq_lens[0] = max_seq_len
- # Create forward batch
- fb = ForwardBatch(
- batch_size=batch_size,
- input_ids=torch.randint(
- 0, 100, (batch_size, 1), device=self.device
- ),
- out_cache_loc=torch.arange(batch_size, device=self.device),
- seq_lens_sum=int(seq_lens.sum().item()),
- forward_mode=ForwardMode.DECODE,
- req_pool_indices=torch.arange(batch_size, device=self.device),
- seq_lens=seq_lens,
- seq_lens_cpu=seq_lens.cpu(),
- attn_backend=backend,
+ fb = self._create_forward_batch(
+ batch_size, seq_lens, backend, model_runner, config
)
- fb.req_to_token_pool = model_runner.req_to_token_pool
- fb.token_to_kv_pool = model_runner.token_to_kv_pool
-
backend.init_forward_metadata(fb)
# Create Q, K, V tensors
- head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim
+ torch.manual_seed(config["seed_qkv"])
+ head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
q = torch.randn(
- (batch_size, 128, head_dim), dtype=self.dtype, device=self.device
+ (batch_size, config["num_attention_heads"], head_dim),
+ dtype=config["dtype"], device=config["device"]
)
k = torch.randn(
- (batch_size, 1, head_dim), dtype=self.dtype, device=self.device
- )
- v = None # TRTLLM MLA decode kernel ignores v
-
- # Create layer
- layer = RadixAttention(
- num_heads=128,
- head_dim=512 + 64,
- scaling=model_runner.model_config.scaling,
- num_kv_heads=1,
- layer_id=0,
- v_head_dim=512,
- prefix="attn_mqa",
+ (batch_size, config["num_kv_heads"], head_dim),
+ dtype=config["dtype"], device=config["device"]
)
+ v = None
# Run forward decode
output = backend.forward_decode(q, k, v, layer, fb)
# Shape and sanity checks
- expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim
+ expected_shape = (batch_size, config["num_attention_heads"] * config["v_head_dim"])
self.assertEqual(
- output.shape,
- expected_shape,
- f"Output shape mismatch for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})",
+ output.shape, expected_shape,
+ f"Output shape mismatch for {test_case['name']}"
)
- self.assertEqual(output.dtype, self.dtype)
+ self.assertEqual(output.dtype, config["dtype"])
self.assertEqual(output.device.type, "cuda")
self.assertFalse(
torch.isnan(output).any(),
- f"Output contains NaN for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})",
+ f"Output contains NaN for {test_case['name']}"
)
self.assertFalse(
torch.isinf(output).any(),
- f"Output contains Inf for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})",
+ f"Output contains Inf for {test_case['name']}"
+ )
+
+ def test_metadata_initialization(self):
+ """Test TRTLLM MLA metadata initialization and structure."""
+ print(f"\nRunning metadata initialization tests...")
+
+ for test_case in TEST_CASES["metadata_tests"]:
+ with self.subTest(test_case=test_case["name"]):
+ print(f" Testing {test_case['name']}: {test_case['description']}")
+
+ config = self._merge_config(test_case)
+ batch_size = config["batch_size"]
+ max_seq_len = config["max_seq_len"]
+
+ # Create components
+ model_runner, _, backend, _, layer = self._create_model_components(config)
+
+ # Create varied sequence lengths
+ torch.manual_seed(config["seed_cache"])
+ if batch_size == 1:
+ seq_lens = torch.tensor([max_seq_len], device=config["device"])
+ else:
+ seq_lens = torch.randint(
+ max(1, max_seq_len // 4), max_seq_len + 1,
+ (batch_size,), device=config["device"]
+ )
+ seq_lens[0] = max_seq_len # Ensure at least one max length
+
+ # Create forward batch
+ fb = self._create_forward_batch(
+ batch_size, seq_lens, backend, model_runner, config
+ )
+
+ # Initialize metadata
+ backend.init_forward_metadata(fb)
+
+ # Verify metadata exists
+ self.assertIsNotNone(backend.forward_metadata)
+ self.assertIsInstance(
+ backend.forward_metadata,
+ TRTLLMGENMLADecodeMetadata
+ )
+
+ # Test metadata structure
+ metadata = backend.forward_metadata
+ self.assertIsNotNone(metadata.workspace, "Workspace should be allocated")
+ self.assertIsNotNone(metadata.block_kv_indices, "Block KV indices should be created")
+
+ # Test workspace properties
+ self.assertEqual(metadata.workspace.device.type, "cuda")
+ self.assertEqual(metadata.workspace.dtype, torch.int8)
+ self.assertGreater(metadata.workspace.numel(), 0, "Workspace should have non-zero size")
+
+ # Test block KV indices properties
+ self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
+ self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
+ self.assertEqual(metadata.block_kv_indices.shape[0], batch_size)
+
+ # Verify block indices are valid (>= -1, since -1 is padding)
+ self.assertTrue(
+ (metadata.block_kv_indices >= -1).all(),
+ "All block indices should be >= -1 (with -1 as padding)"
+ )
+
+ def test_metadata_block_calculation(self):
+ """Test block count calculation logic."""
+ print(f"\nRunning metadata block calculation tests...")
+
+ test_scenarios = [
+ {"seq_len": 31, "page_size": 32, "expected_min_blocks": 1},
+ {"seq_len": 32, "page_size": 32, "expected_min_blocks": 1},
+ {"seq_len": 33, "page_size": 32, "expected_min_blocks": 2},
+ {"seq_len": 128, "page_size": 32, "expected_min_blocks": 4},
+ {"seq_len": 128, "page_size": 64, "expected_min_blocks": 2},
+ ]
+
+ for scenario in test_scenarios:
+ with self.subTest(scenario=scenario):
+ config = self._merge_config({
+ "batch_size": 1,
+ "max_seq_len": scenario["seq_len"],
+ "page_size": scenario["page_size"]
+ })
+
+ model_runner, _, backend, _, _ = self._create_model_components(config)
+
+ # Test internal block calculation
+ calculated_blocks = backend._calc_padded_blocks(scenario["seq_len"])
+
+ # Should be at least the minimum required
+ self.assertGreaterEqual(
+ calculated_blocks, scenario["expected_min_blocks"],
+ f"Calculated blocks ({calculated_blocks}) should be >= minimum required ({scenario['expected_min_blocks']})"
+ )
+
+ # Should satisfy page_size constraint
+ total_tokens = calculated_blocks * scenario["page_size"]
+ self.assertGreaterEqual(
+ total_tokens, scenario["seq_len"],
+ f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})"
+ )
+
+ # Should satisfy TRT-LLM constraint (128 / page_size)
+ trtllm_constraint = 128 // scenario["page_size"]
+ self.assertEqual(
+ calculated_blocks % trtllm_constraint, 0,
+ f"Block count should be multiple of TRT-LLM constraint ({trtllm_constraint})"
)
+ def test_metadata_kv_indices_correctness(self):
+ """Test KV indices creation and correctness."""
+ print(f"\nRunning KV indices correctness tests...")
+
+ for test_case in TEST_CASES["metadata_tests"][:2]: # Test subset for performance
+ with self.subTest(test_case=test_case["name"]):
+ print(f" Testing {test_case['name']}: {test_case['description']}")
+
+ config = self._merge_config(test_case)
+ batch_size = config["batch_size"]
+ max_seq_len = config["max_seq_len"]
+
+ model_runner, _, backend, _, layer = self._create_model_components(config)
+
+ # Create known sequence lengths
+ torch.manual_seed(config["seed_cache"])
+ if batch_size == 1:
+ seq_lens = torch.tensor([max_seq_len], device=config["device"])
+ else:
+ seq_lens = torch.randint(
+ max_seq_len // 2, max_seq_len + 1,
+ (batch_size,), device=config["device"]
+ )
+
+ fb = self._create_forward_batch(
+ batch_size, seq_lens, backend, model_runner, config
+ )
+
+ # Populate some KV cache to have valid indices
+ self._populate_kv_cache(batch_size, seq_lens, [model_runner], layer, config)
+
+ # Initialize metadata
+ backend.init_forward_metadata(fb)
+ metadata = backend.forward_metadata
+
+ # Verify KV indices structure
+ block_kv_indices = metadata.block_kv_indices
+
+ for i in range(batch_size):
+ seq_len = seq_lens[i].item()
+ expected_blocks = backend._calc_padded_blocks(seq_len)
+
+ # Count valid (non -1) indices for this sequence
+ valid_indices = (block_kv_indices[i] >= 0).sum().item()
+
+ # Should have at least enough blocks for the sequence
+ min_required_blocks = (seq_len + config["page_size"] - 1) // config["page_size"]
+ self.assertGreaterEqual(
+ valid_indices, min_required_blocks,
+ f"Sequence {i} should have at least {min_required_blocks} valid blocks, got {valid_indices}"
+ )
+
+ # Verify indices are within valid range
+ valid_block_indices = block_kv_indices[i][block_kv_indices[i] >= 0]
+ if len(valid_block_indices) > 0:
+ max_possible_blocks = model_runner.token_to_kv_pool.size // config["page_size"]
+ self.assertTrue(
+ (valid_block_indices < max_possible_blocks).all(),
+ f"All block indices should be < {max_possible_blocks}"
+ )
+
+ def test_metadata_cuda_graph_compatibility(self):
+ """Test metadata compatibility with CUDA graph capture/replay."""
+ print(f"\nRunning CUDA graph compatibility tests...")
+
+ config = self._merge_config({
+ "batch_size": 4,
+ "max_seq_len": 64,
+ "page_size": 32
+ })
+
+ model_runner, _, backend, _, layer = self._create_model_components(config)
+ batch_size = config["batch_size"]
+
+ # Initialize CUDA graph state
+ backend.init_cuda_graph_state(
+ max_bs=batch_size,
+ max_num_tokens=config["max_seq_len"] * batch_size
+ )
+
+ # Verify CUDA graph buffers are allocated
+ self.assertIsNotNone(backend.cuda_graph_kv_indices)
+ self.assertIsNotNone(backend.cuda_graph_workspace)
+
+ # Test capture metadata
+ seq_lens = torch.full((batch_size,), config["max_seq_len"], device=config["device"])
+ req_pool_indices = torch.arange(batch_size, device=config["device"])
+
+ backend.init_forward_metadata_capture_cuda_graph(
+ bs=batch_size,
+ num_tokens=batch_size,
+ req_pool_indices=req_pool_indices,
+ seq_lens=seq_lens,
+ encoder_lens=None,
+ forward_mode=ForwardMode.DECODE,
+ spec_info=None,
+ )
+
+ # Verify capture metadata
+ self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
+ capture_metadata = backend.decode_cuda_graph_metadata[batch_size]
+
+ self.assertIsNotNone(capture_metadata.workspace)
+ self.assertIsNotNone(capture_metadata.block_kv_indices)
+
+ # Test replay with different sequence lengths
+ new_seq_lens = torch.randint(
+ config["max_seq_len"] // 2, config["max_seq_len"] + 1,
+ (batch_size,), device=config["device"]
+ )
+
+ backend.init_forward_metadata_replay_cuda_graph(
+ bs=batch_size,
+ req_pool_indices=req_pool_indices,
+ seq_lens=new_seq_lens,
+ seq_lens_sum=new_seq_lens.sum().item(),
+ encoder_lens=None,
+ forward_mode=ForwardMode.DECODE,
+ spec_info=None,
+ seq_lens_cpu=new_seq_lens.cpu(),
+ )
+
+ # Verify replay updated the metadata
+ replay_metadata = backend.forward_metadata
+ self.assertIsNotNone(replay_metadata)
+ self.assertEqual(replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr())
+
+ def test_metadata_consistency_across_calls(self):
+ """Test metadata consistency across multiple forward calls."""
+ print(f"\nRunning metadata consistency tests...")
+
+ config = self._merge_config({
+ "batch_size": 2,
+ "max_seq_len": 64,
+ "page_size": 32
+ })
+
+ model_runner, _, backend, _, layer = self._create_model_components(config)
+
+ # First call
+ seq_lens_1 = torch.tensor([32, 48], device=config["device"])
+ fb_1 = self._create_forward_batch(
+ config["batch_size"], seq_lens_1, backend, model_runner, config
+ )
+ backend.init_forward_metadata(fb_1)
+ metadata_1 = backend.forward_metadata
+
+ # Second call with same sequence lengths
+ seq_lens_2 = torch.tensor([32, 48], device=config["device"])
+ fb_2 = self._create_forward_batch(
+ config["batch_size"], seq_lens_2, backend, model_runner, config
+ )
+ backend.init_forward_metadata(fb_2)
+ metadata_2 = backend.forward_metadata
+
+ # Metadata structure should be consistent
+ self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
+ self.assertEqual(metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape)
+
+ # Third call with different sequence lengths
+ seq_lens_3 = torch.tensor([16, 64], device=config["device"])
+ fb_3 = self._create_forward_batch(
+ config["batch_size"], seq_lens_3, backend, model_runner, config
+ )
+ backend.init_forward_metadata(fb_3)
+ metadata_3 = backend.forward_metadata
+
+ # Should still have valid structure
+ self.assertIsNotNone(metadata_3.workspace)
+ self.assertIsNotNone(metadata_3.block_kv_indices)
+ self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
+
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
\ No newline at end of file
From 8a119cce2fa7aac62c8f5cc69bfe25299687bcf9 Mon Sep 17 00:00:00 2001
From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Date: Mon, 21 Jul 2025 14:32:27 -0700
Subject: [PATCH 09/27] neater interface
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
---
...n_mla_backend.py => trtllm_mla_backend.py} | 128 ++++++------------
.../sglang/srt/model_executor/model_runner.py | 8 +-
.../attention/test_trtllm_gen_mla_backend.py | 6 +-
3 files changed, 47 insertions(+), 95 deletions(-)
rename python/sglang/srt/layers/attention/{trtllm_gen_mla_backend.py => trtllm_mla_backend.py} (79%)
diff --git a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
similarity index 79%
rename from python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
rename to python/sglang/srt/layers/attention/trtllm_mla_backend.py
index d80614ccd27a..b0213b1b3fd3 100755
--- a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -1,7 +1,7 @@
from __future__ import annotations
"""
-Support attention backend for TRTLLM-Gen MLA kernels from flashinfer.
+Support attention backend for TRTLLM MLA kernels from flashinfer.
"""
import math
@@ -35,20 +35,16 @@
# See: https://github.com/flashinfer-ai/flashinfer/blob/fe29ed63cb923f25cae70ef83f3fd16139305b35/flashinfer/decode.py#L2057
TRTLLM_BLOCK_CONSTRAINT = 128
-# Quantization scale (1.0 for fp8->bf16 and bf16->bf16 conversions)
-DEFAULT_QUANTIZATION_SCALE = 1.0
-# Softmax scale (1.0 since TRTLLM applies 1/sqrt(head_dim) internally)
-DEFAULT_SM_SCALE = 1.0
@dataclass
-class TRTLLMGENMLADecodeMetadata:
+class TRTLLMMLADecodeMetadata:
"""Metadata for TRTLLM MLA decode operations."""
workspace: Optional[torch.Tensor] = None
block_kv_indices: Optional[torch.Tensor] = None
-class TRTLLMGENMLABackend(FlashInferMLAAttnBackend):
+class TRTLLMMLABackend(FlashInferMLAAttnBackend):
"""TRTLLM MLA attention kernel from flashinfer."""
def __init__(
@@ -92,7 +88,7 @@ def __init__(
# CUDA graph state
self.decode_cuda_graph_metadata = {}
self.cuda_graph_kv_indices = None
- self.forward_metadata: Union[TRTLLMGENMLADecodeMetadata] = None
+ self.forward_metadata: Union[TRTLLMMLADecodeMetadata] = None
def _calc_padded_blocks(self, max_seq_len: int) -> int:
"""
@@ -117,68 +113,6 @@ def _calc_padded_blocks(self, max_seq_len: int) -> int:
blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
return blocks
- def _get_quantization_scales(self) -> tuple[float, float, float, float, float]:
- """
- Get quantization scales for q, k, v, sm, and o tensors.
-
- Returns:
- Tuple of (q_scale, k_scale, v_scale, sm_scale, o_scale)
- """
- # TODO: Implement proper quantization scale inference based on model config
- # For now, use default values for FP8
- return (
- DEFAULT_QUANTIZATION_SCALE, # q_scale
- DEFAULT_QUANTIZATION_SCALE, # k_scale
- DEFAULT_QUANTIZATION_SCALE, # v_scale
- DEFAULT_SM_SCALE, # sm_scale
- DEFAULT_QUANTIZATION_SCALE, # o_scale
- )
-
- def _prepare_kv_cache(
- self, layer: RadixAttention, forward_batch: ForwardBatch
- ) -> torch.Tensor:
- """
- Prepare KV cache tensor in the format expected by TRT-LLM kernel.
-
- Args:
- layer: Attention layer
- forward_batch: Forward batch info
-
- Returns:
- KV cache tensor shaped (num_pages, 2, page_size, kv_cache_dim)
- """
- k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
- pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
- # TRT-LLM expects stacked format: slice 0 → CKV+K, slice 1 → KPE
- return torch.stack([pages, pages], dim=1)
-
- def _prepare_query_tensor(
- self, q: torch.Tensor, q_rope: Optional[torch.Tensor], layer: RadixAttention
- ) -> torch.Tensor:
- """
- Prepare query tensor in the format expected by TRT-LLM kernel.
-
- Args:
- q: Query tensor
- q_rope: Optional RoPE query tensor
- layer: Attention layer
-
- Returns:
- Query tensor with concatenated NOPE and RoPE parts
- """
- if q_rope is not None:
- # q contains the NOPE part (v_head_dim)
- q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
- q_rope = q_rope.view(
- -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
- )
- else:
- reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
- q_nope = reshaped_q[:, :, : layer.v_head_dim]
- q_rope = reshaped_q[:, :, layer.v_head_dim :]
-
- return torch.cat([q_nope, q_rope], dim=-1)
-
def _create_block_kv_indices(
self,
batch_size: int,
@@ -259,7 +193,7 @@ def init_forward_metadata_capture_cuda_graph(
self.page_size,
)
- metadata = TRTLLMGENMLADecodeMetadata(
+ metadata = TRTLLMMLADecodeMetadata(
self.cuda_graph_workspace, block_kv_indices
)
self.decode_cuda_graph_metadata[bs] = metadata
@@ -342,7 +276,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
forward_batch.seq_lens.device,
)
- self.forward_metadata = TRTLLMGENMLADecodeMetadata(
+ self.forward_metadata = TRTLLMMLADecodeMetadata(
self.workspace_buffer, block_kv_indices
)
forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
@@ -371,9 +305,27 @@ def forward_decode(
elif v is not None:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
- # Prepare tensors for TRT-LLM kernel
- query = self._prepare_query_tensor(q, q_rope, layer)
- kv_cache = self._prepare_kv_cache(layer, forward_batch)
+ # Prepare query tensor inline (avoid helper for shape tweaks)
+ if q_rope is not None:
+ # q contains NOPE part (v_head_dim)
+ q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+ q_rope_reshaped = q_rope.view(
+ -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+ )
+ query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
+ else:
+ # q already has both parts
+ query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+ # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
+ if query.dim() == 3:
+ query = query.unsqueeze(1)
+
+ # Prepare KV cache inline (TRT-LLM expects stacked format)
+ k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+ pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
+ # TRT-LLM expects stacked format: slice 0 → CKV+K, slice 1 → KPE
+ kv_cache = torch.stack([pages, pages], dim=1)
# Get metadata
metadata = (
@@ -381,10 +333,16 @@ def forward_decode(
or self.forward_metadata
)
- # Get quantization scales
- q_scale, k_scale, v_scale, sm_scale, o_scale = self._get_quantization_scales()
+ # Scale computation for TRTLLM MLA kernel:
+ # - BMM1 scale = q_scale * k_scale * softmax_scale
+ # - For FP16 output in DeepSeek R1: q_scale = k_scale = 1.0, softmax_scale = 1/sqrt(head_dim)
+ # - This reduces to layer.scaling which is pre-computed as 1/sqrt(head_dim)
+ bmm1_scale = layer.scaling
+ bmm2_scale = 1.0
+ bmm1_scale_tensor = bmm2_scale_tensor = None
- # Call TRT-LLM kernel
+
+ # Call TRT-LLM kernel with proper scale configuration
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=query,
kv_cache=kv_cache,
@@ -396,20 +354,14 @@ def forward_decode(
seq_lens=forward_batch.seq_lens.to(torch.int32),
block_size=self.page_size,
max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
- q_scale=q_scale,
- k_scale=k_scale,
- v_scale=v_scale,
- sm_scale=sm_scale,
- o_scale=o_scale,
+ bmm1_scale=bmm1_scale,
+ bmm2_scale=bmm2_scale,
+ bmm1_scale_tensor=bmm1_scale_tensor,
+ bmm2_scale_tensor=bmm2_scale_tensor,
)
- #TODO: test?
# Extract value projection part and reshape
raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
- # Truncate if needed
- if output.shape[0] > forward_batch.batch_size:
- output = output[: forward_batch.batch_size]
-
return output
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 6f6bdb284331..189d66093561 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -396,7 +396,7 @@ def model_specific_adjustment(self):
):
server_args.attention_backend = "fa3"
elif is_sm100_supported():
- server_args.attention_backend = "trtllm_mla"
+ server_args.attention_backend = "flashinfer"
elif _is_hip:
head_num = self.model_config.get_num_kv_heads(self.tp_size)
# TODO current aiter only support head number 16 or 128 head number
@@ -1432,11 +1432,11 @@ def _get_attention_backend_from_str(self, backend_str: str):
return CutlassMLABackend(self)
elif self.server_args.attention_backend == "trtllm_mla":
- from sglang.srt.layers.attention.trtllm_gen_mla_backend import (
- TRTLLMGENMLABackend,
+ from sglang.srt.layers.attention.trtllm_mla_backend import (
+ TRTLLMMLABackend,
)
- return TRTLLMGENMLABackend(self)
+ return TRTLLMMLABackend(self)
elif self.server_args.attention_backend == "intel_amx":
from sglang.srt.layers.attention.intel_amx_backend import (
IntelAMXAttnBackend,
diff --git a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py
index e2ea8d1da744..de4bf8a963d3 100644
--- a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py
+++ b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py
@@ -11,7 +11,7 @@
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.trtllm_gen_mla_backend import TRTLLMGENMLABackend, TRTLLMGENMLADecodeMetadata
+from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend, TRTLLMMLADecodeMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -271,7 +271,7 @@ def _create_model_components(self, config):
model_runner_reference = MockModelRunner(config)
# Create backends
- trtllm_backend = TRTLLMGENMLABackend(model_runner_trtllm)
+ trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
# Create RadixAttention layer
@@ -613,7 +613,7 @@ def test_metadata_initialization(self):
self.assertIsNotNone(backend.forward_metadata)
self.assertIsInstance(
backend.forward_metadata,
- TRTLLMGENMLADecodeMetadata
+ TRTLLMMLADecodeMetadata
)
# Test metadata structure
From d772ccc2b675d9b64cf8c0e53e240cd5de45e119 Mon Sep 17 00:00:00 2001
From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Date: Mon, 21 Jul 2025 14:38:17 -0700
Subject: [PATCH 10/27] precommit+rename
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
---
.../layers/attention/trtllm_mla_backend.py | 5 +-
..._backend.py => test_trtllm_mla_backend.py} | 380 +++++++++++-------
2 files changed, 233 insertions(+), 152 deletions(-)
rename python/sglang/test/attention/{test_trtllm_gen_mla_backend.py => test_trtllm_mla_backend.py} (81%)
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index b0213b1b3fd3..5cbb84d08fba 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -102,13 +102,13 @@ def _calc_padded_blocks(self, max_seq_len: int) -> int:
"""
blocks = triton.cdiv(max_seq_len, self.page_size)
- # Apply dual constraints (take LCM to satisfy both):
+ # Apply dual constraints (take LCM to satisfy both):
# 1. TRT-LLM: block_num % (128 / page_size) == 0
# Reference: https://github.com/NVIDIA/TensorRT-LLM/issues/XYZ # TODO: add actual link
# 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
-
+
if blocks % constraint_lcm != 0:
blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
return blocks
@@ -341,7 +341,6 @@ def forward_decode(
bmm2_scale = 1.0
bmm1_scale_tensor = bmm2_scale_tensor = None
-
# Call TRT-LLM kernel with proper scale configuration
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=query,
diff --git a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py
similarity index 81%
rename from python/sglang/test/attention/test_trtllm_gen_mla_backend.py
rename to python/sglang/test/attention/test_trtllm_mla_backend.py
index de4bf8a963d3..92c5045f93ad 100644
--- a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py
+++ b/python/sglang/test/attention/test_trtllm_mla_backend.py
@@ -9,16 +9,18 @@
# Patch DP-attention globals before importing backends
_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test
+from sglang.srt.layers.attention.trtllm_mla_backend import (
+ TRTLLMMLABackend,
+ TRTLLMMLADecodeMetadata,
+)
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from python.sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend, TRTLLMMLADecodeMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
from sglang.test.test_utils import CustomTestCase
-
# Global configuration for all tests
DEFAULT_CONFIG = {
"device": "cuda",
@@ -57,7 +59,6 @@
"description": "Medium-scale batch",
},
],
-
"decode_output_match": [
{
"name": "single",
@@ -72,12 +73,11 @@
"max_seq_len": 64,
"page_size": 32,
"description": "Batch vs reference",
- }
+ },
],
-
"page_size_consistency": [
# Only 32 and 64 supported for now https://github.com/flashinfer-ai/flashinfer/blob/fe29ed63cb923f25cae70ef83f3fd16139305b35/flashinfer/decode.py#L2115
- # TODO: Test 16 and 128. Pending cubins
+ # TODO: Test 16 and 128. Pending cubins
{
"name": "page_32",
"batch_size": 8,
@@ -93,7 +93,6 @@
"description": "64-token pages",
},
],
-
"shape_sanity_tests": [
{
"name": "basic",
@@ -117,7 +116,6 @@
"description": "Batch shapes",
},
],
-
"metadata_tests": [
{
"name": "single_sequence",
@@ -172,7 +170,8 @@ def __init__(self, config):
"qk_nope_head_dim": config["qk_nope_head_dim"],
"qk_rope_head_dim": config["qk_rope_head_dim"],
"v_head_dim": config["v_head_dim"],
- "scaling": 1.0 / ((config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) ** 0.5),
+ "scaling": 1.0
+ / ((config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) ** 0.5),
"get_num_kv_heads": staticmethod(lambda _: config["num_kv_heads"]),
},
)
@@ -207,7 +206,7 @@ def __init__(self, config):
def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
"""Compare outputs with detailed analysis."""
-
+
# Basic checks
assert (
trtllm_out.shape == reference_out.shape
@@ -285,29 +284,38 @@ def _create_model_components(self, config):
prefix="attn_mqa",
)
- return model_runner_trtllm, model_runner_reference, trtllm_backend, reference_backend, layer
+ return (
+ model_runner_trtllm,
+ model_runner_reference,
+ trtllm_backend,
+ reference_backend,
+ layer,
+ )
def _create_qkv_tensors(self, batch_size, config):
"""Create Q, K, V tensors for testing."""
head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
device = config["device"]
dtype = config["dtype"]
-
+
q = torch.randn(
- (batch_size, config["num_attention_heads"], head_dim),
- dtype=dtype, device=device
+ (batch_size, config["num_attention_heads"], head_dim),
+ dtype=dtype,
+ device=device,
)
k = torch.randn(
- (batch_size, config["num_kv_heads"], head_dim),
- dtype=dtype, device=device
+ (batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
)
v = torch.randn(
- (batch_size, config["num_kv_heads"], config["v_head_dim"]),
- dtype=dtype, device=device
+ (batch_size, config["num_kv_heads"], config["v_head_dim"]),
+ dtype=dtype,
+ device=device,
)
return q, k, v
- def _create_forward_batch(self, batch_size, seq_lens, backend, model_runner, config):
+ def _create_forward_batch(
+ self, batch_size, seq_lens, backend, model_runner, config
+ ):
"""Create a forward batch for the given backend."""
fb = ForwardBatch(
batch_size=batch_size,
@@ -335,16 +343,20 @@ def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config)
for token_idx in range(seq_len - 1):
# Create random K components for MLA
cache_k_nope = torch.randn(
- (1, config["qk_nope_head_dim"]),
- dtype=config["dtype"], device=config["device"]
+ (1, config["qk_nope_head_dim"]),
+ dtype=config["dtype"],
+ device=config["device"],
)
cache_k_rope = torch.randn(
- (1, config["qk_rope_head_dim"]),
- dtype=config["dtype"], device=config["device"]
+ (1, config["qk_rope_head_dim"]),
+ dtype=config["dtype"],
+ device=config["device"],
)
# Calculate cache location
- cache_loc = model_runner.req_to_token_pool.req_to_token[i, token_idx]
+ cache_loc = model_runner.req_to_token_pool.req_to_token[
+ i, token_idx
+ ]
# Save to KV cache
model_runner.token_to_kv_pool.set_mla_kv_buffer(
@@ -357,31 +369,33 @@ def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config)
def test_basic_functionality(self):
"""Test basic functionality with minimal setup."""
print(f"\nRunning basic functionality tests...")
-
+
for test_case in TEST_CASES["basic_functionality"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
-
+
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
# Create components
- model_runner_trtllm, _, trtllm_backend, _, layer = self._create_model_components(config)
+ model_runner_trtllm, _, trtllm_backend, _, layer = (
+ self._create_model_components(config)
+ )
# Create sequence lengths - properly handle different batch sizes
if batch_size == 2:
seq_lens = torch.tensor(
- [max_seq_len, max_seq_len // 2],
- device=config["device"]
+ [max_seq_len, max_seq_len // 2], device=config["device"]
)
else:
# For larger batch sizes, create varied sequence lengths
torch.manual_seed(config["seed_cache"])
seq_lens = torch.randint(
- max_seq_len // 2, max_seq_len + 1,
- (batch_size,),
- device=config["device"]
+ max_seq_len // 2,
+ max_seq_len + 1,
+ (batch_size,),
+ device=config["device"],
)
seq_lens[0] = max_seq_len # Ensure at least one max length
@@ -392,7 +406,9 @@ def test_basic_functionality(self):
trtllm_backend.init_forward_metadata(fb)
# Populate KV cache
- self._populate_kv_cache(batch_size, seq_lens, [model_runner_trtllm], layer, config)
+ self._populate_kv_cache(
+ batch_size, seq_lens, [model_runner_trtllm], layer, config
+ )
# Create Q, K, V tensors
torch.manual_seed(config["seed_qkv"])
@@ -402,7 +418,10 @@ def test_basic_functionality(self):
output = trtllm_backend.forward_decode(q, k, v, layer, fb)
# Basic checks
- expected_shape = (batch_size, config["num_attention_heads"] * config["v_head_dim"])
+ expected_shape = (
+ batch_size,
+ config["num_attention_heads"] * config["v_head_dim"],
+ )
self.assertEqual(output.shape, expected_shape)
self.assertEqual(output.dtype, config["dtype"])
self.assertFalse(torch.isnan(output).any())
@@ -411,18 +430,23 @@ def test_basic_functionality(self):
def test_decode_output_match(self):
"""Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
print(f"\nRunning decode output matching tests...")
-
+
for test_case in TEST_CASES["decode_output_match"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
-
+
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
# Create components
- (model_runner_trtllm, model_runner_reference,
- trtllm_backend, reference_backend, layer) = self._create_model_components(config)
+ (
+ model_runner_trtllm,
+ model_runner_reference,
+ trtllm_backend,
+ reference_backend,
+ layer,
+ ) = self._create_model_components(config)
# Create identical sequence lengths for both backends
torch.manual_seed(config["seed_cache"])
@@ -433,10 +457,18 @@ def test_decode_output_match(self):
# Create forward batches with identical inputs
fb_trtllm = self._create_forward_batch(
- batch_size, seq_lens.clone(), trtllm_backend, model_runner_trtllm, config
+ batch_size,
+ seq_lens.clone(),
+ trtllm_backend,
+ model_runner_trtllm,
+ config,
)
fb_reference = self._create_forward_batch(
- batch_size, seq_lens.clone(), reference_backend, model_runner_reference, config
+ batch_size,
+ seq_lens.clone(),
+ reference_backend,
+ model_runner_reference,
+ config,
)
# Initialize metadata for both backends
@@ -445,7 +477,11 @@ def test_decode_output_match(self):
# Populate both KV caches identically
self._populate_kv_cache(
- batch_size, seq_lens, [model_runner_trtllm, model_runner_reference], layer, config
+ batch_size,
+ seq_lens,
+ [model_runner_trtllm, model_runner_reference],
+ layer,
+ config,
)
# Create Q, K, V tensors for current decode step
@@ -475,17 +511,19 @@ def test_decode_output_match(self):
def test_page_size_consistency(self):
"""Test output consistency across different page sizes."""
print(f"\nRunning page size consistency tests...")
-
+
for test_case in TEST_CASES["page_size_consistency"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
-
+
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
# Create components
- model_runner, _, backend, _, layer = self._create_model_components(config)
+ model_runner, _, backend, _, layer = self._create_model_components(
+ config
+ )
# Create sequence lengths
torch.manual_seed(config["seed_cache"])
@@ -501,7 +539,9 @@ def test_page_size_consistency(self):
backend.init_forward_metadata(fb)
# Populate KV cache
- self._populate_kv_cache(batch_size, seq_lens, [model_runner], layer, config)
+ self._populate_kv_cache(
+ batch_size, seq_lens, [model_runner], layer, config
+ )
# Create Q, K, V tensors
torch.manual_seed(config["seed_qkv"])
@@ -510,10 +550,14 @@ def test_page_size_consistency(self):
# Run forward decode
output = backend.forward_decode(q, k, v, layer, fb)
- expected_shape = (batch_size, config["num_attention_heads"] * config["v_head_dim"])
+ expected_shape = (
+ batch_size,
+ config["num_attention_heads"] * config["v_head_dim"],
+ )
self.assertEqual(
- output.shape, expected_shape,
- f"Output shape mismatch: {output.shape} vs {expected_shape}"
+ output.shape,
+ expected_shape,
+ f"Output shape mismatch: {output.shape} vs {expected_shape}",
)
self.assertFalse(torch.isnan(output).any(), "Output contains NaN")
self.assertFalse(torch.isinf(output).any(), "Output contains Inf")
@@ -521,20 +565,24 @@ def test_page_size_consistency(self):
def test_shape_sanity(self):
"""Smoke test decode across several configurations."""
print(f"\nRunning shape sanity tests...")
-
+
for test_case in TEST_CASES["shape_sanity_tests"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
-
+
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
- model_runner, _, backend, _, layer = self._create_model_components(config)
+ model_runner, _, backend, _, layer = self._create_model_components(
+ config
+ )
# Random seq lens (ensure one matches max)
torch.manual_seed(config["seed_cache"])
- seq_lens = torch.randint(1, max_seq_len, (batch_size,), device=config["device"])
+ seq_lens = torch.randint(
+ 1, max_seq_len, (batch_size,), device=config["device"]
+ )
seq_lens[0] = max_seq_len
fb = self._create_forward_batch(
@@ -546,49 +594,57 @@ def test_shape_sanity(self):
torch.manual_seed(config["seed_qkv"])
head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
q = torch.randn(
- (batch_size, config["num_attention_heads"], head_dim),
- dtype=config["dtype"], device=config["device"]
+ (batch_size, config["num_attention_heads"], head_dim),
+ dtype=config["dtype"],
+ device=config["device"],
)
k = torch.randn(
- (batch_size, config["num_kv_heads"], head_dim),
- dtype=config["dtype"], device=config["device"]
+ (batch_size, config["num_kv_heads"], head_dim),
+ dtype=config["dtype"],
+ device=config["device"],
)
- v = None
+ v = None
# Run forward decode
output = backend.forward_decode(q, k, v, layer, fb)
# Shape and sanity checks
- expected_shape = (batch_size, config["num_attention_heads"] * config["v_head_dim"])
+ expected_shape = (
+ batch_size,
+ config["num_attention_heads"] * config["v_head_dim"],
+ )
self.assertEqual(
- output.shape, expected_shape,
- f"Output shape mismatch for {test_case['name']}"
+ output.shape,
+ expected_shape,
+ f"Output shape mismatch for {test_case['name']}",
)
self.assertEqual(output.dtype, config["dtype"])
self.assertEqual(output.device.type, "cuda")
self.assertFalse(
torch.isnan(output).any(),
- f"Output contains NaN for {test_case['name']}"
+ f"Output contains NaN for {test_case['name']}",
)
self.assertFalse(
torch.isinf(output).any(),
- f"Output contains Inf for {test_case['name']}"
+ f"Output contains Inf for {test_case['name']}",
)
def test_metadata_initialization(self):
"""Test TRTLLM MLA metadata initialization and structure."""
print(f"\nRunning metadata initialization tests...")
-
+
for test_case in TEST_CASES["metadata_tests"]:
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
-
+
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
# Create components
- model_runner, _, backend, _, layer = self._create_model_components(config)
+ model_runner, _, backend, _, layer = self._create_model_components(
+ config
+ )
# Create varied sequence lengths
torch.manual_seed(config["seed_cache"])
@@ -596,8 +652,10 @@ def test_metadata_initialization(self):
seq_lens = torch.tensor([max_seq_len], device=config["device"])
else:
seq_lens = torch.randint(
- max(1, max_seq_len // 4), max_seq_len + 1,
- (batch_size,), device=config["device"]
+ max(1, max_seq_len // 4),
+ max_seq_len + 1,
+ (batch_size,),
+ device=config["device"],
)
seq_lens[0] = max_seq_len # Ensure at least one max length
@@ -608,39 +666,42 @@ def test_metadata_initialization(self):
# Initialize metadata
backend.init_forward_metadata(fb)
-
+
# Verify metadata exists
self.assertIsNotNone(backend.forward_metadata)
- self.assertIsInstance(
- backend.forward_metadata,
- TRTLLMMLADecodeMetadata
- )
+ self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata)
# Test metadata structure
metadata = backend.forward_metadata
- self.assertIsNotNone(metadata.workspace, "Workspace should be allocated")
- self.assertIsNotNone(metadata.block_kv_indices, "Block KV indices should be created")
-
+ self.assertIsNotNone(
+ metadata.workspace, "Workspace should be allocated"
+ )
+ self.assertIsNotNone(
+ metadata.block_kv_indices, "Block KV indices should be created"
+ )
+
# Test workspace properties
self.assertEqual(metadata.workspace.device.type, "cuda")
self.assertEqual(metadata.workspace.dtype, torch.int8)
- self.assertGreater(metadata.workspace.numel(), 0, "Workspace should have non-zero size")
-
+ self.assertGreater(
+ metadata.workspace.numel(), 0, "Workspace should have non-zero size"
+ )
+
# Test block KV indices properties
self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
self.assertEqual(metadata.block_kv_indices.shape[0], batch_size)
-
+
# Verify block indices are valid (>= -1, since -1 is padding)
self.assertTrue(
(metadata.block_kv_indices >= -1).all(),
- "All block indices should be >= -1 (with -1 as padding)"
+ "All block indices should be >= -1 (with -1 as padding)",
)
def test_metadata_block_calculation(self):
"""Test block count calculation logic."""
print(f"\nRunning metadata block calculation tests...")
-
+
test_scenarios = [
{"seq_len": 31, "page_size": 32, "expected_min_blocks": 1},
{"seq_len": 32, "page_size": 32, "expected_min_blocks": 1},
@@ -648,53 +709,62 @@ def test_metadata_block_calculation(self):
{"seq_len": 128, "page_size": 32, "expected_min_blocks": 4},
{"seq_len": 128, "page_size": 64, "expected_min_blocks": 2},
]
-
+
for scenario in test_scenarios:
with self.subTest(scenario=scenario):
- config = self._merge_config({
- "batch_size": 1,
- "max_seq_len": scenario["seq_len"],
- "page_size": scenario["page_size"]
- })
-
+ config = self._merge_config(
+ {
+ "batch_size": 1,
+ "max_seq_len": scenario["seq_len"],
+ "page_size": scenario["page_size"],
+ }
+ )
+
model_runner, _, backend, _, _ = self._create_model_components(config)
-
+
# Test internal block calculation
calculated_blocks = backend._calc_padded_blocks(scenario["seq_len"])
-
+
# Should be at least the minimum required
self.assertGreaterEqual(
- calculated_blocks, scenario["expected_min_blocks"],
- f"Calculated blocks ({calculated_blocks}) should be >= minimum required ({scenario['expected_min_blocks']})"
+ calculated_blocks,
+ scenario["expected_min_blocks"],
+ f"Calculated blocks ({calculated_blocks}) should be >= minimum required ({scenario['expected_min_blocks']})",
)
-
+
# Should satisfy page_size constraint
total_tokens = calculated_blocks * scenario["page_size"]
self.assertGreaterEqual(
- total_tokens, scenario["seq_len"],
- f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})"
+ total_tokens,
+ scenario["seq_len"],
+ f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})",
)
-
+
# Should satisfy TRT-LLM constraint (128 / page_size)
trtllm_constraint = 128 // scenario["page_size"]
self.assertEqual(
- calculated_blocks % trtllm_constraint, 0,
- f"Block count should be multiple of TRT-LLM constraint ({trtllm_constraint})"
+ calculated_blocks % trtllm_constraint,
+ 0,
+ f"Block count should be multiple of TRT-LLM constraint ({trtllm_constraint})",
)
def test_metadata_kv_indices_correctness(self):
"""Test KV indices creation and correctness."""
print(f"\nRunning KV indices correctness tests...")
-
- for test_case in TEST_CASES["metadata_tests"][:2]: # Test subset for performance
+
+ for test_case in TEST_CASES["metadata_tests"][
+ :2
+ ]: # Test subset for performance
with self.subTest(test_case=test_case["name"]):
print(f" Testing {test_case['name']}: {test_case['description']}")
-
+
config = self._merge_config(test_case)
batch_size = config["batch_size"]
max_seq_len = config["max_seq_len"]
- model_runner, _, backend, _, layer = self._create_model_components(config)
+ model_runner, _, backend, _, layer = self._create_model_components(
+ config
+ )
# Create known sequence lengths
torch.manual_seed(config["seed_cache"])
@@ -702,8 +772,10 @@ def test_metadata_kv_indices_correctness(self):
seq_lens = torch.tensor([max_seq_len], device=config["device"])
else:
seq_lens = torch.randint(
- max_seq_len // 2, max_seq_len + 1,
- (batch_size,), device=config["device"]
+ max_seq_len // 2,
+ max_seq_len + 1,
+ (batch_size,),
+ device=config["device"],
)
fb = self._create_forward_batch(
@@ -711,7 +783,9 @@ def test_metadata_kv_indices_correctness(self):
)
# Populate some KV cache to have valid indices
- self._populate_kv_cache(batch_size, seq_lens, [model_runner], layer, config)
+ self._populate_kv_cache(
+ batch_size, seq_lens, [model_runner], layer, config
+ )
# Initialize metadata
backend.init_forward_metadata(fb)
@@ -719,57 +793,61 @@ def test_metadata_kv_indices_correctness(self):
# Verify KV indices structure
block_kv_indices = metadata.block_kv_indices
-
+
for i in range(batch_size):
seq_len = seq_lens[i].item()
expected_blocks = backend._calc_padded_blocks(seq_len)
-
+
# Count valid (non -1) indices for this sequence
valid_indices = (block_kv_indices[i] >= 0).sum().item()
-
+
# Should have at least enough blocks for the sequence
- min_required_blocks = (seq_len + config["page_size"] - 1) // config["page_size"]
+ min_required_blocks = (seq_len + config["page_size"] - 1) // config[
+ "page_size"
+ ]
self.assertGreaterEqual(
- valid_indices, min_required_blocks,
- f"Sequence {i} should have at least {min_required_blocks} valid blocks, got {valid_indices}"
+ valid_indices,
+ min_required_blocks,
+ f"Sequence {i} should have at least {min_required_blocks} valid blocks, got {valid_indices}",
)
-
+
# Verify indices are within valid range
valid_block_indices = block_kv_indices[i][block_kv_indices[i] >= 0]
if len(valid_block_indices) > 0:
- max_possible_blocks = model_runner.token_to_kv_pool.size // config["page_size"]
+ max_possible_blocks = (
+ model_runner.token_to_kv_pool.size // config["page_size"]
+ )
self.assertTrue(
(valid_block_indices < max_possible_blocks).all(),
- f"All block indices should be < {max_possible_blocks}"
+ f"All block indices should be < {max_possible_blocks}",
)
def test_metadata_cuda_graph_compatibility(self):
"""Test metadata compatibility with CUDA graph capture/replay."""
print(f"\nRunning CUDA graph compatibility tests...")
-
- config = self._merge_config({
- "batch_size": 4,
- "max_seq_len": 64,
- "page_size": 32
- })
-
+
+ config = self._merge_config(
+ {"batch_size": 4, "max_seq_len": 64, "page_size": 32}
+ )
+
model_runner, _, backend, _, layer = self._create_model_components(config)
batch_size = config["batch_size"]
-
+
# Initialize CUDA graph state
backend.init_cuda_graph_state(
- max_bs=batch_size,
- max_num_tokens=config["max_seq_len"] * batch_size
+ max_bs=batch_size, max_num_tokens=config["max_seq_len"] * batch_size
)
-
+
# Verify CUDA graph buffers are allocated
self.assertIsNotNone(backend.cuda_graph_kv_indices)
self.assertIsNotNone(backend.cuda_graph_workspace)
-
+
# Test capture metadata
- seq_lens = torch.full((batch_size,), config["max_seq_len"], device=config["device"])
+ seq_lens = torch.full(
+ (batch_size,), config["max_seq_len"], device=config["device"]
+ )
req_pool_indices = torch.arange(batch_size, device=config["device"])
-
+
backend.init_forward_metadata_capture_cuda_graph(
bs=batch_size,
num_tokens=batch_size,
@@ -779,20 +857,22 @@ def test_metadata_cuda_graph_compatibility(self):
forward_mode=ForwardMode.DECODE,
spec_info=None,
)
-
+
# Verify capture metadata
self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
capture_metadata = backend.decode_cuda_graph_metadata[batch_size]
-
+
self.assertIsNotNone(capture_metadata.workspace)
self.assertIsNotNone(capture_metadata.block_kv_indices)
-
+
# Test replay with different sequence lengths
new_seq_lens = torch.randint(
- config["max_seq_len"] // 2, config["max_seq_len"] + 1,
- (batch_size,), device=config["device"]
+ config["max_seq_len"] // 2,
+ config["max_seq_len"] + 1,
+ (batch_size,),
+ device=config["device"],
)
-
+
backend.init_forward_metadata_replay_cuda_graph(
bs=batch_size,
req_pool_indices=req_pool_indices,
@@ -803,24 +883,24 @@ def test_metadata_cuda_graph_compatibility(self):
spec_info=None,
seq_lens_cpu=new_seq_lens.cpu(),
)
-
+
# Verify replay updated the metadata
replay_metadata = backend.forward_metadata
self.assertIsNotNone(replay_metadata)
- self.assertEqual(replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr())
+ self.assertEqual(
+ replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
+ )
def test_metadata_consistency_across_calls(self):
"""Test metadata consistency across multiple forward calls."""
print(f"\nRunning metadata consistency tests...")
-
- config = self._merge_config({
- "batch_size": 2,
- "max_seq_len": 64,
- "page_size": 32
- })
-
+
+ config = self._merge_config(
+ {"batch_size": 2, "max_seq_len": 64, "page_size": 32}
+ )
+
model_runner, _, backend, _, layer = self._create_model_components(config)
-
+
# First call
seq_lens_1 = torch.tensor([32, 48], device=config["device"])
fb_1 = self._create_forward_batch(
@@ -828,7 +908,7 @@ def test_metadata_consistency_across_calls(self):
)
backend.init_forward_metadata(fb_1)
metadata_1 = backend.forward_metadata
-
+
# Second call with same sequence lengths
seq_lens_2 = torch.tensor([32, 48], device=config["device"])
fb_2 = self._create_forward_batch(
@@ -836,11 +916,13 @@ def test_metadata_consistency_across_calls(self):
)
backend.init_forward_metadata(fb_2)
metadata_2 = backend.forward_metadata
-
+
# Metadata structure should be consistent
self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
- self.assertEqual(metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape)
-
+ self.assertEqual(
+ metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
+ )
+
# Third call with different sequence lengths
seq_lens_3 = torch.tensor([16, 64], device=config["device"])
fb_3 = self._create_forward_batch(
@@ -848,7 +930,7 @@ def test_metadata_consistency_across_calls(self):
)
backend.init_forward_metadata(fb_3)
metadata_3 = backend.forward_metadata
-
+
# Should still have valid structure
self.assertIsNotNone(metadata_3.workspace)
self.assertIsNotNone(metadata_3.block_kv_indices)
@@ -856,4 +938,4 @@ def test_metadata_consistency_across_calls(self):
if __name__ == "__main__":
- unittest.main()
\ No newline at end of file
+ unittest.main()
From 00125a2fe243377d36a823e3bfd83802916ebd5f Mon Sep 17 00:00:00 2001
From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Date: Mon, 21 Jul 2025 18:41:39 -0700
Subject: [PATCH 11/27] updated docs
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
---
docs/backend/attention_backend.md | 9 +++++++++
docs/references/deepseek.md | 6 +++---
2 files changed, 12 insertions(+), 3 deletions(-)
diff --git a/docs/backend/attention_backend.md b/docs/backend/attention_backend.md
index caf23446f5a6..3c18e76bc69a 100644
--- a/docs/backend/attention_backend.md
+++ b/docs/backend/attention_backend.md
@@ -9,8 +9,12 @@
| **Triton** | ❌ | ✅ | ✅ | ✅ | ❌ |
| **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ |
+| **TRTLLM MLA** | ✅ | ❌* | ✅ | ✅ | ❌** |
| **Ascend** | ✅ | ❌ | ❌ | ❌ | ❌ |
+**Notes:**
+- \*/\*\* TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to the FlashInfer MLA backend.
+
Note: Every kernel backend is compatible with a page size > 1 by specifying an argument such as `--page-size 16`.
This is because a page size of 16 can be converted to a page size of 1 in the kernel backend.
The "❌" and "✅" symbols in the table above under "Page Size > 1" indicate whether the kernel actually operates with a page size greater than 1, rather than treating a page size of 16 as a page size of 1.
@@ -48,6 +52,11 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code
```
+- TRTLLM MLA (Optimized for Blackwell Architecture, e.g., B200)
+```bash
+python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code
+```
+
- Ascend
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md
index 8b6d688d1507..117cbf65863c 100644
--- a/docs/references/deepseek.md
+++ b/docs/references/deepseek.md
@@ -90,7 +90,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
-- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), [FlashMLA](https://github.com/deepseek-ai/FlashMLA), [CutlassMLA](https://github.com/sgl-project/sglang/pull/5390), and [Triton](https://github.com/triton-lang/triton) backends. The default FA3 provides good performance across wide workloads.
+- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), [FlashMLA](https://github.com/deepseek-ai/FlashMLA), [CutlassMLA](https://github.com/sgl-project/sglang/pull/5390), **TRTLLM MLA** (optimized for Blackwell architecture), and [Triton](https://github.com/triton-lang/triton) backends. The default FA3 provides good performance across wide workloads.
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
@@ -104,7 +104,7 @@ Overall, with these optimizations, we have achieved up to **7x** acceleration in