From f32d15795260522313126582e7d8046d2a562188 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Thu, 10 Jul 2025 20:42:50 +0000 Subject: [PATCH 01/27] trtllm gen mla initial commit Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- python/pyproject.toml | 1 + .../layers/attention/trtllm_mla_backend.py | 200 ++++++++++++++++ .../sglang/srt/model_executor/model_runner.py | 20 +- .../test/attention/test_trtllm_mla_backend.py | 224 ++++++++++++++++++ 4 files changed, 444 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/layers/attention/trtllm_mla_backend.py create mode 100644 python/sglang/test/attention/test_trtllm_mla_backend.py diff --git a/python/pyproject.toml b/python/pyproject.toml index d916fcb57e6c..c79681ae1c3d 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -75,6 +75,7 @@ blackwell = [ "tiktoken", ] + # HIP (Heterogeneous-computing Interface for Portability) for AMD # => base docker rocm/vllm-dev:20250114, not from public vllm whl srt_hip = [ diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py new file mode 100644 index 000000000000..4534249d083e --- /dev/null +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +""" +Support attention backend for TRTLLM MLA kernels from flashinfer. 
+""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union + +import torch +import triton + +from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend +from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.utils import is_flashinfer_available + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInfo + +if is_flashinfer_available(): + import flashinfer + + +# TRTLLM MLA supports variable page sizes + + +@dataclass +class TRTLLMMLADecodeMetadata: + workspace: Optional[torch.Tensor] = None + block_kv_indices: Optional[torch.Tensor] = None + + +class TRTLLMMLABackend(FlashInferMLAAttnBackend): + """TRTLLM MLA attention kernels from flashinfer.""" + + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + kv_last_page_len_buf: Optional[torch.Tensor] = None, + ): + super().__init__( + model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf + ) + + # Model parameters + self.num_q_heads = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + self.num_kv_heads = model_runner.model_config.get_num_kv_heads( + get_attention_tp_size() + ) + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.num_local_heads = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + self.forward_metadata: Union[TRTLLMMLADecodeMetadata] = None + self.kv_lora_rank = model_runner.model_config.kv_lora_rank + self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim + self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim + 
self.v_head_dim = model_runner.model_config.v_head_dim + self.scaling = model_runner.model_config.scaling + self.data_type = model_runner.kv_cache_dtype + self.q_data_type = model_runner.dtype + self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim + self.page_size = model_runner.page_size # Use page size from model runner + + # Validate dimensions for TRTLLM MLA (based on test requirements) + if self.qk_nope_head_dim != 128: + raise ValueError(f"TRTLLM MLA requires qk_nope_head_dim=128, got {self.qk_nope_head_dim}") + if self.kv_lora_rank != 512: + raise ValueError(f"TRTLLM MLA requires kv_lora_rank=512, got {self.kv_lora_rank}") + if self.qk_rope_head_dim != 64: + raise ValueError(f"TRTLLM MLA requires qk_rope_head_dim=64, got {self.qk_rope_head_dim}") + + # Allocate larger workspace for TRTLLM (128MB as in the test) + self.workspace_size = 128 * 1024 * 1024 + self.workspace_buffer = torch.empty( + self.workspace_size, dtype=torch.int8, device=self.device + ) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init the metadata for a forward pass.""" + bs = forward_batch.batch_size + spec_info = forward_batch.spec_info + + if forward_batch.forward_mode.is_decode_or_idle(): + if spec_info is None: + # Calculate max sequence length padded to page boundary + max_seqlen_pad = triton.cdiv( + forward_batch.seq_lens_cpu.max().item(), self.page_size + ) + + # Create block indices + block_kv_indices = torch.full( + (bs, max_seqlen_pad), + -1, + dtype=torch.int32, + device=forward_batch.seq_lens.device, + ) + + # Fill block indices using the existing triton kernel + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + None, + block_kv_indices, + self.req_to_token.stride(0), + max_seqlen_pad, + self.page_size, + ) + + forward_batch.decode_trtllm_mla_metadata = TRTLLMMLADecodeMetadata( + self.workspace_buffer, + block_kv_indices, + ) + self.forward_metadata = 
forward_batch.decode_trtllm_mla_metadata + else: + # Speculative decoding: use parent class implementation + super().init_forward_metadata(forward_batch) + else: + # Prefill: use parent class implementation + super().init_forward_metadata(forward_batch) + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + # For multi-head latent attention + q_rope: Optional[torch.Tensor] = None, + k_rope: Optional[torch.Tensor] = None, + ): + """Run forward for decode using TRTLLM kernel.""" + cache_loc = forward_batch.out_cache_loc + + if k is not None: + assert v is not None + if save_kv_cache: + if k_rope is not None: + # MLA style KV cache storage + forward_batch.token_to_kv_pool.set_mla_kv_buffer( + layer, + cache_loc, + k, + k_rope, + ) + else: + # Standard KV cache storage + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + cache_loc, + k, + v, + ) + + # Prepare query tensor - concatenate q_nope and q_rope + if q_rope is not None: + # q and q_rope are separate + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.qk_rope_head_dim) + query = torch.cat([q_nope, q_rope], dim=-1) + else: + # q already contains both nope and rope parts + query = q.view(-1, layer.tp_q_head_num, layer.head_dim) + + # Get KV cache + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + + # Reshape KV cache for TRTLLM format + kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim) + + # Call TRTLLM MLA decode kernel + output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + query=query, + kv_cache=kv_cache, + workspace_buffer=forward_batch.decode_trtllm_mla_metadata.workspace, + qk_nope_head_dim=self.qk_nope_head_dim, + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim, + block_tables=forward_batch.decode_trtllm_mla_metadata.block_kv_indices, + 
seq_lens=forward_batch.seq_lens.to(torch.int32), + block_size=self.page_size, + max_seq_len=forward_batch.seq_lens.max().item(), + scale=layer.scaling, + out=None, + bmm1_scale=1.0, # Only needed for FP8 + bmm2_scale=1.0, # Only needed for FP8 + ) + + return output.view(-1, layer.tp_q_head_num * layer.v_head_dim) \ No newline at end of file diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 13555adeb186..0f53e52285e5 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -396,7 +396,18 @@ def model_specific_adjustment(self): ): server_args.attention_backend = "fa3" elif is_sm100_supported(): - server_args.attention_backend = "flashinfer" + # On Blackwell, prefer TRTLLM MLA if available, otherwise flashinfer + if is_flashinfer_available(): + try: + import flashinfer + if hasattr(flashinfer.decode, 'trtllm_batch_decode_with_kv_cache_mla'): + server_args.attention_backend = "trtllm_mla" + else: + server_args.attention_backend = "flashinfer" + except: + server_args.attention_backend = "flashinfer" + else: + server_args.attention_backend = "flashinfer" elif _is_hip: head_num = self.model_config.get_num_kv_heads(self.tp_size) # TODO current aiter only support head number 16 or 128 head number @@ -422,6 +433,7 @@ def model_specific_adjustment(self): "triton", "flashmla", "cutlass_mla", + "trtllm_mla", "ascend", ]: logger.info( @@ -1430,6 +1442,12 @@ def _get_attention_backend_from_str(self, backend_str: str): ) return CutlassMLABackend(self) + elif self.server_args.attention_backend == "trtllm_mla": + from sglang.srt.layers.attention.trtllm_mla_backend import ( + TRTLLMMLABackend, + ) + + return TRTLLMMLABackend(self) elif self.server_args.attention_backend == "intel_amx": from sglang.srt.layers.attention.intel_amx_backend import ( IntelAMXAttnBackend, diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py 
b/python/sglang/test/attention/test_trtllm_mla_backend.py new file mode 100644 index 000000000000..b2e05874d663 --- /dev/null +++ b/python/sglang/test/attention/test_trtllm_mla_backend.py @@ -0,0 +1,224 @@ +import unittest + +import pytest +import torch + +from sglang.srt.configs.model_config import AttentionArch +from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.utils import is_flashinfer_available +from sglang.test.test_utils import CustomTestCase + + +class MockModelRunner: + def __init__( + self, + kv_lora_rank, + qk_rope_head_dim, + page_size=16, + ): + attention_arch = AttentionArch.MLA + self.device = "cuda" + self.dtype = torch.bfloat16 + self.kv_cache_dtype = torch.bfloat16 + context_len = 2048 + self.model_config = type( + "ModelConfig", + (), + { + "context_len": context_len, + "attention_arch": attention_arch, + "num_attention_heads": 128, + "kv_lora_rank": kv_lora_rank, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": qk_rope_head_dim, + "v_head_dim": kv_lora_rank, + "scaling": 1.0 / ((128 + 64) ** 0.5), + }, + ) + self.sliding_window_size = None + self.page_size = page_size + + batch_size = 256 + # Create a proper req_to_token_pool with the req_to_token attribute + self.req_to_token_pool = type( + "TokenPool", + (), + { + "size": batch_size, + "req_to_token": torch.zeros( + batch_size, context_len, dtype=torch.int32, device=self.device + ), + }, + ) + + max_total_num_tokens = batch_size * context_len + self.token_to_kv_pool = MLATokenToKVPool( + size=max_total_num_tokens, + page_size=self.page_size, + dtype=self.dtype, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + layer_num=1, # only consider layer=1 for unit test + device=self.device, + enable_memory_saver=False, + ) + + def 
get_num_kv_heads(self, tp_size): + """MLA uses single KV head.""" + return 1 + + +@pytest.mark.skipif( + not torch.cuda.is_available() or not is_flashinfer_available(), + reason="Test requires CUDA and flashinfer" +) +@pytest.mark.parametrize("batch_size", [16, 32, 64]) +@pytest.mark.parametrize("page_size", [16, 32, 64]) +@pytest.mark.parametrize("seq_len", [256, 512, 1024]) +class TestTRTLLMMLABackend(CustomTestCase): + def test_trtllm_decode_mla(self, batch_size, page_size, seq_len): + """Test TRTLLM MLA decode operation with various configurations.""" + # Check if PyTorch supports current GPU + if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability() + archs = torch.cuda.get_arch_list() + current_arch = f"sm_{capability[0]}{capability[1]}" + supported = any(current_arch in arch for arch in archs) + if not supported: + pytest.skip(f"PyTorch doesn't support {current_arch} - need nightly build") + + # DeepSeek MLA configuration + num_q_heads = 128 + kv_lora_rank = 512 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + device = "cuda" + dtype = torch.bfloat16 + + # Initialize model runner and backend + model_runner = MockModelRunner( + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + page_size=page_size, + ) + + # Check if flashinfer has TRTLLM MLA support + try: + import flashinfer + if not hasattr(flashinfer.decode, 'trtllm_batch_decode_with_kv_cache_mla'): + pytest.skip("flashinfer version does not have TRTLLM MLA support") + except ImportError: + pytest.skip("flashinfer not available") + + backend = TRTLLMMLABackend(model_runner) + + # Create attention layer + layer = RadixAttention( + num_heads=num_q_heads, + head_dim=kv_lora_rank + qk_rope_head_dim, + scaling=model_runner.model_config.scaling, + num_kv_heads=1, + layer_id=0, + v_head_dim=kv_lora_rank, + prefix="attn_mqa", + ) + + # Generate sequence lengths + seq_lens = torch.randint(1, seq_len, (batch_size,), device=device) + seq_lens[-1] = seq_len # Ensure at least 
one max length + max_seq_len = seq_lens.max().item() + + # Calculate blocks needed + blocks_per_seq = (seq_lens + page_size - 1) // page_size + total_blocks_needed = blocks_per_seq.sum().item() + + # Create req_to_token mapping + req_to_token = torch.zeros(batch_size, seq_len, dtype=torch.int32, device=device) + token_offset = 0 + for i in range(batch_size): + seq_len_i = seq_lens[i].item() + req_to_token[i, :seq_len_i] = torch.arange( + token_offset, token_offset + seq_len_i, device=device + ) + token_offset += seq_len_i + + model_runner.req_to_token_pool.req_to_token = req_to_token + + # Create forward batch for decode + forward_batch = ForwardBatch( + batch_size=batch_size, + input_ids=torch.randint(0, 100, (batch_size, 1), device=device), + out_cache_loc=torch.arange(batch_size, device=device), + seq_lens_sum=seq_lens.sum().item(), + forward_mode=ForwardMode.DECODE, + req_pool_indices=torch.arange(batch_size, device=device), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.cpu(), + attn_backend=backend, + ) + + # Add pools to forward batch + forward_batch.req_to_token_pool = model_runner.req_to_token_pool + forward_batch.token_to_kv_pool = model_runner.token_to_kv_pool + + # Fill KV cache with some data + cache_data = torch.randn( + seq_lens.sum().item(), + 1, # num_kv_heads + kv_lora_rank + qk_rope_head_dim, + dtype=dtype, + device=device, + ) + cache_indices = torch.arange(seq_lens.sum().item(), device=device) + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_indices, cache_data, None + ) + + # Initialize metadata + backend.init_forward_metadata(forward_batch) + + # Create input tensors + q_shape = (batch_size, num_q_heads, kv_lora_rank + qk_rope_head_dim) + q = torch.randn(q_shape, dtype=dtype, device=device) + + # For MLA, k contains compressed KV, v is not used + k = torch.randn(batch_size, 1, kv_lora_rank + qk_rope_head_dim, dtype=dtype, device=device) + v = None + + # Run forward decode + output = backend.forward_decode(q, k, v, layer, 
forward_batch) + + # Verify output + expected_shape = (batch_size, num_q_heads * kv_lora_rank) + assert output.shape == expected_shape, f"Expected shape {expected_shape}, got {output.shape}" + assert output.dtype == dtype + assert output.device.type == "cuda" + assert not torch.isnan(output).any(), "Output contains NaN values" + assert not torch.isinf(output).any(), "Output contains Inf values" + + +# Simplified test for quick verification +@pytest.mark.skipif( + not torch.cuda.is_available() or not is_flashinfer_available(), + reason="Test requires CUDA and flashinfer" +) +def test_trtllm_mla_basic(): + """Basic test to verify TRTLLM MLA backend works.""" + # Check if flashinfer has TRTLLM MLA support + try: + import flashinfer + if not hasattr(flashinfer.decode, 'trtllm_batch_decode_with_kv_cache_mla'): + pytest.skip("flashinfer version does not have TRTLLM MLA support") + except ImportError: + pytest.skip("flashinfer not available") + + test = TestTRTLLMMLABackend() + test.test_trtllm_decode_mla(batch_size=32, page_size=32, seq_len=512) + print("TRTLLM MLA basic test passed!") + + +if __name__ == "__main__": + test_trtllm_mla_basic() \ No newline at end of file From 3bca375ceaabc939942d5f72788606a7a7a735fa Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Fri, 11 Jul 2025 09:51:15 -0700 Subject: [PATCH 02/27] Unittest passing Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- .../layers/attention/trtllm_mla_backend.py | 37 +-- python/sglang/srt/server_args.py | 1 + .../test/attention/test_trtllm_mla_backend.py | 280 +++++++----------- 3 files changed, 121 insertions(+), 197 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 4534249d083e..9fae41785c64 100644 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ 
-143,19 +143,18 @@ def forward_decode( """Run forward for decode using TRTLLM kernel.""" cache_loc = forward_batch.out_cache_loc - if k is not None: - assert v is not None - if save_kv_cache: - if k_rope is not None: - # MLA style KV cache storage - forward_batch.token_to_kv_pool.set_mla_kv_buffer( - layer, - cache_loc, - k, - k_rope, - ) - else: - # Standard KV cache storage + if k is not None and save_kv_cache: + if k_rope is not None: + # MLA style KV cache storage + forward_batch.token_to_kv_pool.set_mla_kv_buffer( + layer, + cache_loc, + k, + k_rope, + ) + else: + # Standard KV cache storage path. Skip if value tensor is absent (e.g., MLA decode tests). + if v is not None: forward_batch.token_to_kv_pool.set_kv_buffer( layer, cache_loc, @@ -176,11 +175,11 @@ def forward_decode( # Get KV cache k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - # Reshape KV cache for TRTLLM format - kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim) + # Reshape KV cache to 4-D (num_kv_heads, num_blocks, page_size, kv_dim) + kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(0) # 1 KV head # Call TRTLLM MLA decode kernel - output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, kv_cache=kv_cache, workspace_buffer=forward_batch.decode_trtllm_mla_metadata.workspace, @@ -196,5 +195,7 @@ def forward_decode( bmm1_scale=1.0, # Only needed for FP8 bmm2_scale=1.0, # Only needed for FP8 ) - - return output.view(-1, layer.tp_q_head_num * layer.v_head_dim) \ No newline at end of file + output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim) + if output.shape[0] > forward_batch.batch_size: + output = output[: forward_batch.batch_size] + return output \ No newline at end of file diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index dc0c6cd1acab..bd5a8336aab2 100644 --- a/python/sglang/srt/server_args.py 
+++ b/python/sglang/srt/server_args.py @@ -1226,6 +1226,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "torch_native", "ascend", "triton", + "trtllm_mla", ], default=ServerArgs.attention_backend, help="Choose the kernels for attention layers.", diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py index b2e05874d663..7f4f91d0e588 100644 --- a/python/sglang/test/attention/test_trtllm_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_mla_backend.py @@ -1,224 +1,146 @@ import unittest - -import pytest import torch +from sglang.srt.layers import dp_attention as _dp_attn + +# Patch DP-attention globals **before** importing the backend so that all +# downstream `from … import get_attention_tp_size` statements receive the +# patched version. +_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test + from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.utils import is_flashinfer_available from sglang.test.test_utils import CustomTestCase +from sglang.srt.utils import is_flashinfer_available class MockModelRunner: - def __init__( - self, - kv_lora_rank, - qk_rope_head_dim, - page_size=16, - ): - attention_arch = AttentionArch.MLA + """Minimal fake `ModelRunner` for MLA backend unit tests.""" + + def __init__(self, page_size: int): self.device = "cuda" self.dtype = torch.bfloat16 self.kv_cache_dtype = torch.bfloat16 - context_len = 2048 + self.page_size = page_size + + # Model-config stub – only the attributes accessed by the backend. 
self.model_config = type( "ModelConfig", (), { - "context_len": context_len, - "attention_arch": attention_arch, + "context_len": 2048, + "attention_arch": AttentionArch.MLA, "num_attention_heads": 128, - "kv_lora_rank": kv_lora_rank, + "kv_lora_rank": 512, "qk_nope_head_dim": 128, - "qk_rope_head_dim": qk_rope_head_dim, - "v_head_dim": kv_lora_rank, + "qk_rope_head_dim": 64, + "v_head_dim": 512, "scaling": 1.0 / ((128 + 64) ** 0.5), + "get_num_kv_heads": staticmethod(lambda _: 1), }, ) - self.sliding_window_size = None - self.page_size = page_size - batch_size = 256 - # Create a proper req_to_token_pool with the req_to_token attribute + # Req-to-token pool (dummy) + max_bs = 64 + max_ctx = self.model_config.context_len self.req_to_token_pool = type( "TokenPool", (), { - "size": batch_size, - "req_to_token": torch.zeros( - batch_size, context_len, dtype=torch.int32, device=self.device - ), + "size": max_bs, + "req_to_token": torch.zeros(max_bs, max_ctx, dtype=torch.int32, device=self.device), }, ) - - max_total_num_tokens = batch_size * context_len + + # KV-token pool self.token_to_kv_pool = MLATokenToKVPool( - size=max_total_num_tokens, - page_size=self.page_size, - dtype=self.dtype, - kv_lora_rank=kv_lora_rank, - qk_rope_head_dim=qk_rope_head_dim, - layer_num=1, # only consider layer=1 for unit test + size=max_bs * max_ctx, + page_size=page_size, + dtype=self.kv_cache_dtype, + kv_lora_rank=512, + qk_rope_head_dim=64, + layer_num=1, device=self.device, enable_memory_saver=False, ) - def get_num_kv_heads(self, tp_size): - """MLA uses single KV head.""" - return 1 - -@pytest.mark.skipif( - not torch.cuda.is_available() or not is_flashinfer_available(), - reason="Test requires CUDA and flashinfer" -) -@pytest.mark.parametrize("batch_size", [16, 32, 64]) -@pytest.mark.parametrize("page_size", [16, 32, 64]) -@pytest.mark.parametrize("seq_len", [256, 512, 1024]) +@unittest.skipIf(not torch.cuda.is_available() or not is_flashinfer_available(), "CUDA + flashinfer 
required") class TestTRTLLMMLABackend(CustomTestCase): - def test_trtllm_decode_mla(self, batch_size, page_size, seq_len): - """Test TRTLLM MLA decode operation with various configurations.""" - # Check if PyTorch supports current GPU - if torch.cuda.is_available(): - capability = torch.cuda.get_device_capability() - archs = torch.cuda.get_arch_list() - current_arch = f"sm_{capability[0]}{capability[1]}" - supported = any(current_arch in arch for arch in archs) - if not supported: - pytest.skip(f"PyTorch doesn't support {current_arch} - need nightly build") - - # DeepSeek MLA configuration - num_q_heads = 128 - kv_lora_rank = 512 - qk_nope_head_dim = 128 - qk_rope_head_dim = 64 - device = "cuda" - dtype = torch.bfloat16 - - # Initialize model runner and backend - model_runner = MockModelRunner( - kv_lora_rank=kv_lora_rank, - qk_rope_head_dim=qk_rope_head_dim, - page_size=page_size, - ) - - # Check if flashinfer has TRTLLM MLA support - try: - import flashinfer - if not hasattr(flashinfer.decode, 'trtllm_batch_decode_with_kv_cache_mla'): - pytest.skip("flashinfer version does not have TRTLLM MLA support") - except ImportError: - pytest.skip("flashinfer not available") - - backend = TRTLLMMLABackend(model_runner) - - # Create attention layer - layer = RadixAttention( - num_heads=num_q_heads, - head_dim=kv_lora_rank + qk_rope_head_dim, - scaling=model_runner.model_config.scaling, - num_kv_heads=1, - layer_id=0, - v_head_dim=kv_lora_rank, - prefix="attn_mqa", - ) - - # Generate sequence lengths - seq_lens = torch.randint(1, seq_len, (batch_size,), device=device) - seq_lens[-1] = seq_len # Ensure at least one max length - max_seq_len = seq_lens.max().item() - - # Calculate blocks needed - blocks_per_seq = (seq_lens + page_size - 1) // page_size - total_blocks_needed = blocks_per_seq.sum().item() - - # Create req_to_token mapping - req_to_token = torch.zeros(batch_size, seq_len, dtype=torch.int32, device=device) - token_offset = 0 - for i in range(batch_size): - 
seq_len_i = seq_lens[i].item() - req_to_token[i, :seq_len_i] = torch.arange( - token_offset, token_offset + seq_len_i, device=device - ) - token_offset += seq_len_i - - model_runner.req_to_token_pool.req_to_token = req_to_token - - # Create forward batch for decode - forward_batch = ForwardBatch( - batch_size=batch_size, - input_ids=torch.randint(0, 100, (batch_size, 1), device=device), - out_cache_loc=torch.arange(batch_size, device=device), - seq_lens_sum=seq_lens.sum().item(), + """Structure mirrors `test_flashattn_backend.py` but focuses on MLA decode.""" + + def setUp(self): + self.batch_size = 16 + self.seq_len = 512 + self.page_sizes = [16, 32, 64] + self.device = "cuda" + self.dtype = torch.bfloat16 + + # ‑- helpers --------------------------------------------------------- + def _init(self, page_size: int): + self.model_runner = MockModelRunner(page_size) + self.backend = TRTLLMMLABackend(self.model_runner) + # Attach num_heads required by RadixAttention convenience + self.model_runner.model_config.num_attention_heads = 128 + + def _alloc_qkv(self): + head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim + q_shape = (self.batch_size, 128, head_dim) + q = torch.randn(q_shape, dtype=self.dtype, device=self.device) + k = torch.randn(self.batch_size, 1, head_dim, dtype=self.dtype, device=self.device) + v = None # TRTLLM MLA decode kernel ignores v + return q, k, v + + def _create_forward_batch(self, seq_lens: torch.Tensor): + fb = ForwardBatch( + batch_size=self.batch_size, + input_ids=torch.randint(0, 100, (self.batch_size, 1), device=self.device), + out_cache_loc=torch.arange(self.batch_size, device=self.device), + seq_lens_sum=int(seq_lens.sum().item()), forward_mode=ForwardMode.DECODE, - req_pool_indices=torch.arange(batch_size, device=device), + req_pool_indices=torch.arange(self.batch_size, device=self.device), seq_lens=seq_lens, seq_lens_cpu=seq_lens.cpu(), - attn_backend=backend, - ) - - # Add pools to forward batch - forward_batch.req_to_token_pool = 
model_runner.req_to_token_pool - forward_batch.token_to_kv_pool = model_runner.token_to_kv_pool - - # Fill KV cache with some data - cache_data = torch.randn( - seq_lens.sum().item(), - 1, # num_kv_heads - kv_lora_rank + qk_rope_head_dim, - dtype=dtype, - device=device, + attn_backend=self.backend, ) - cache_indices = torch.arange(seq_lens.sum().item(), device=device) - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_indices, cache_data, None - ) - - # Initialize metadata - backend.init_forward_metadata(forward_batch) - - # Create input tensors - q_shape = (batch_size, num_q_heads, kv_lora_rank + qk_rope_head_dim) - q = torch.randn(q_shape, dtype=dtype, device=device) - - # For MLA, k contains compressed KV, v is not used - k = torch.randn(batch_size, 1, kv_lora_rank + qk_rope_head_dim, dtype=dtype, device=device) - v = None - - # Run forward decode - output = backend.forward_decode(q, k, v, layer, forward_batch) - - # Verify output - expected_shape = (batch_size, num_q_heads * kv_lora_rank) - assert output.shape == expected_shape, f"Expected shape {expected_shape}, got {output.shape}" - assert output.dtype == dtype - assert output.device.type == "cuda" - assert not torch.isnan(output).any(), "Output contains NaN values" - assert not torch.isinf(output).any(), "Output contains Inf values" - - -# Simplified test for quick verification -@pytest.mark.skipif( - not torch.cuda.is_available() or not is_flashinfer_available(), - reason="Test requires CUDA and flashinfer" -) -def test_trtllm_mla_basic(): - """Basic test to verify TRTLLM MLA backend works.""" - # Check if flashinfer has TRTLLM MLA support - try: - import flashinfer - if not hasattr(flashinfer.decode, 'trtllm_batch_decode_with_kv_cache_mla'): - pytest.skip("flashinfer version does not have TRTLLM MLA support") - except ImportError: - pytest.skip("flashinfer not available") - - test = TestTRTLLMMLABackend() - test.test_trtllm_decode_mla(batch_size=32, page_size=32, seq_len=512) - print("TRTLLM 
MLA basic test passed!") + fb.req_to_token_pool = self.model_runner.req_to_token_pool + fb.token_to_kv_pool = self.model_runner.token_to_kv_pool + return fb + + # ‑- actual tests ---------------------------------------------------- + def test_forward_decode(self): + """Smoke test decode across several page sizes.""" + for ps in self.page_sizes: + self._init(ps) + + # Random seq lens (ensure one matches max) + seq_lens = torch.randint(1, self.seq_len, (self.batch_size,), device=self.device) + seq_lens[0] = self.seq_len + + forward_batch = self._create_forward_batch(seq_lens) + self.backend.init_forward_metadata(forward_batch) + + q, k, v = self._alloc_qkv() + layer = RadixAttention( + num_heads=128, + head_dim=512 + 64, + scaling=self.model_runner.model_config.scaling, + num_kv_heads=1, + layer_id=0, + v_head_dim=512, + prefix="attn_mqa", + ) + out = self.backend.forward_decode(q, k, v, layer, forward_batch) + + self.assertEqual(out.shape, (self.batch_size, 128 * 512)) + self.assertEqual(out.dtype, self.dtype) + self.assertEqual(out.device.type, "cuda") + self.assertFalse(torch.isnan(out).any()) + self.assertFalse(torch.isinf(out).any()) if __name__ == "__main__": - test_trtllm_mla_basic() \ No newline at end of file + unittest.main() \ No newline at end of file From 15f3b8079d5d75ac5e7d07c2ceb51c0035453169 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Sun, 13 Jul 2025 18:42:11 -0700 Subject: [PATCH 03/27] diff output vs flashinfer mla Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- .../layers/attention/trtllm_mla_backend.py | 265 ++++++++++++++---- .../test/attention/test_trtllm_mla_backend.py | 10 +- .../test_trtllm_vs_flashinfer_mla.py | 218 ++++++++++++++ 3 files changed, 431 insertions(+), 62 deletions(-) create mode 100644 python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py 
b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 9fae41785c64..14d6bad3e271 100644 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -9,6 +9,8 @@ import torch import triton +import math # Needed for scale correction +import os from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton @@ -16,17 +18,14 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_flashinfer_available +if is_flashinfer_available(): + import flashinfer + if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.speculative.spec_info import SpecInfo -if is_flashinfer_available(): - import flashinfer - - -# TRTLLM MLA supports variable page sizes - @dataclass class TRTLLMMLADecodeMetadata: @@ -70,41 +69,166 @@ def __init__( self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim self.page_size = model_runner.page_size # Use page size from model runner - # Validate dimensions for TRTLLM MLA (based on test requirements) - if self.qk_nope_head_dim != 128: - raise ValueError(f"TRTLLM MLA requires qk_nope_head_dim=128, got {self.qk_nope_head_dim}") - if self.kv_lora_rank != 512: - raise ValueError(f"TRTLLM MLA requires kv_lora_rank=512, got {self.kv_lora_rank}") - if self.qk_rope_head_dim != 64: - raise ValueError(f"TRTLLM MLA requires qk_rope_head_dim=64, got {self.qk_rope_head_dim}") - # Allocate larger workspace for TRTLLM (128MB as in the test) self.workspace_size = 128 * 1024 * 1024 self.workspace_buffer = torch.empty( self.workspace_size, dtype=torch.int8, device=self.device ) + + # CUDA graph metadata storage + self.decode_cuda_graph_metadata = {} + self.cuda_graph_kv_indices = None + + def _calc_padded_blocks(self, 
max_seq_len: int) -> int: + """Return number of blocks padded so that it satisfies TRTLLM constraint.""" + blocks = triton.cdiv(max_seq_len, self.page_size) + min_blocks = 128 // self.page_size # kernel requirement + if blocks % min_blocks != 0: + blocks = triton.cdiv(blocks, min_blocks) * min_blocks + return blocks + + def init_cuda_graph_state( + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, + ): + """Initialize CUDA graph state for TRTLLM MLA.""" + # Calculate padded block size that satisfies TRTLLM constraint + max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len) + + self.cuda_graph_kv_indices = torch.full( + (max_bs, max_blocks_per_seq), + -1, + dtype=torch.int32, + device=self.device, + ) + self.cuda_graph_workspace = torch.empty( + self.workspace_size, dtype=torch.int8, device=self.device + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + ): + """Initialize metadata for CUDA graph capture.""" + if forward_mode.is_decode_or_idle(): + if spec_info is None: + max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item()) + + block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad] + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + None, + block_kv_indices, + self.req_to_token.stride(0), + max_seqlen_pad, + self.page_size, + ) + metadata = TRTLLMMLADecodeMetadata( + self.cuda_graph_workspace, + block_kv_indices, + ) + self.decode_cuda_graph_metadata[bs] = metadata + self.forward_metadata = metadata + else: + super().init_forward_metadata_capture_cuda_graph( + bs, + num_tokens, + req_pool_indices, + seq_lens, + encoder_lens, + forward_mode, + spec_info, + ) + else: + super().init_forward_metadata_capture_cuda_graph( + bs, + num_tokens, + 
req_pool_indices, + seq_lens, + encoder_lens, + forward_mode, + spec_info, + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + seq_lens_cpu: Optional[torch.Tensor], + ): + """Replay CUDA graph with new inputs.""" + if forward_mode.is_decode_or_idle(): + if spec_info is None: + # Reuse cached metadata + metadata = self.decode_cuda_graph_metadata[bs] + + # Update block indices for new sequences + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens[:bs], + None, + metadata.block_kv_indices, + self.req_to_token.stride(0), + metadata.block_kv_indices.shape[1], + self.page_size, + ) + + self.forward_metadata = metadata + else: + # Speculative decoding: use parent class implementation + super().init_forward_metadata_replay_cuda_graph( + bs, req_pool_indices, seq_lens, seq_lens_sum, + encoder_lens, forward_mode, spec_info, seq_lens_cpu + ) + else: + # Prefill: use parent class implementation + super().init_forward_metadata_replay_cuda_graph( + bs, req_pool_indices, seq_lens, seq_lens_sum, + encoder_lens, forward_mode, spec_info, seq_lens_cpu + ) + + def get_cuda_graph_seq_len_fill_value(self): + """Get the fill value for sequence lengths in CUDA graph.""" + return 1 def init_forward_metadata(self, forward_batch: ForwardBatch): """Init the metadata for a forward pass.""" bs = forward_batch.batch_size spec_info = forward_batch.spec_info - + if forward_batch.forward_mode.is_decode_or_idle(): if spec_info is None: - # Calculate max sequence length padded to page boundary - max_seqlen_pad = triton.cdiv( - forward_batch.seq_lens_cpu.max().item(), self.page_size - ) - - # Create block indices + # seq_lens_cpu may be None when cuda-graphs are disabled + if getattr(forward_batch, "seq_lens_cpu", None) is not None: + max_seq = 
forward_batch.seq_lens_cpu.max().item() + else: + max_seq = forward_batch.seq_lens.max().item() + + max_seqlen_pad = self._calc_padded_blocks(max_seq) + block_kv_indices = torch.full( (bs, max_seqlen_pad), -1, dtype=torch.int32, device=forward_batch.seq_lens.device, ) - - # Fill block indices using the existing triton kernel + create_flashmla_kv_indices_triton[(bs,)]( self.req_to_token, forward_batch.req_pool_indices, @@ -115,17 +239,16 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): max_seqlen_pad, self.page_size, ) - - forward_batch.decode_trtllm_mla_metadata = TRTLLMMLADecodeMetadata( + + self.forward_metadata = TRTLLMMLADecodeMetadata( self.workspace_buffer, block_kv_indices, ) - self.forward_metadata = forward_batch.decode_trtllm_mla_metadata + # Expose to the ForwardBatch so that other components can access it + forward_batch.decode_trtllm_mla_metadata = self.forward_metadata else: - # Speculative decoding: use parent class implementation super().init_forward_metadata(forward_batch) else: - # Prefill: use parent class implementation super().init_forward_metadata(forward_batch) def forward_decode( @@ -136,66 +259,94 @@ def forward_decode( layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache: bool = True, - # For multi-head latent attention q_rope: Optional[torch.Tensor] = None, k_rope: Optional[torch.Tensor] = None, ): - """Run forward for decode using TRTLLM kernel.""" + """Run forward for decode using TRTLLM MLA kernel.""" cache_loc = forward_batch.out_cache_loc + # Save KV cache if requested if k is not None and save_kv_cache: if k_rope is not None: - # MLA style KV cache storage forward_batch.token_to_kv_pool.set_mla_kv_buffer( - layer, - cache_loc, - k, - k_rope, + layer, cache_loc, k, k_rope ) else: - # Standard KV cache storage path. Skip if value tensor is absent (e.g., MLA decode tests). 
if v is not None: forward_batch.token_to_kv_pool.set_kv_buffer( - layer, - cache_loc, - k, - v, + layer, cache_loc, k, v ) - # Prepare query tensor - concatenate q_nope and q_rope + # Build query tensor if q_rope is not None: - # q and q_rope are separate - q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) - q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.qk_rope_head_dim) + q_nope = q.view(-1, layer.tp_q_head_num, self.qk_nope_head_dim) + q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim) query = torch.cat([q_nope, q_rope], dim=-1) else: - # q already contains both nope and rope parts query = q.view(-1, layer.tp_q_head_num, layer.head_dim) - # Get KV cache + # Scale factor for TRT-LLM MLA kernel. + # The kernel computes softmax with scale: 1 / (sqrt(head_dim_qk) * scale) + # where head_dim_qk = 576 (kv_lora_rank + qk_rope_head_dim). + # To get the same result as FlashInfer (which uses layer.scaling = 1/sqrt(192)), + # we need: 1 / (sqrt(576) * scale) = 1 / sqrt(192) + # Therefore: scale = sqrt(576) / sqrt(192) = sqrt(3) + scale = math.sqrt(self.kv_lora_rank + self.qk_rope_head_dim) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) + + # KV cache tensor: reshape to (num_pages, page_size, dim) k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + # Build KV cache slices expected by TRT-LLM: slice 0 → CKV+K, slice 1 → KPE. + pages = k_cache.view(-1, self.page_size, self.kv_cache_dim) # (P, blk, 576) + + # According to the FlashInfer test, both slices should contain the full 576-dim tensor. + # Use torch.stack so each slice is an independent contiguous view. + # NOTE: this duplicates the storage but matches the reference behaviour. 
+ kv_cache = torch.stack([pages, pages], dim=1) # (P, 2, blk, 576) - # Reshape KV cache to 4-D (num_kv_heads, num_blocks, page_size, kv_dim) - kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(0) # 1 KV head + # Metadata (prefer attribute on forward_batch for compatibility) + metadata = getattr(forward_batch, "decode_trtllm_mla_metadata", None) + if metadata is None: + metadata = self.forward_metadata + + # ---------- Debug output (enable with env var) ---------- + if os.getenv("SGLANG_DEBUG_TRTLLM_MLA", "0") == "1": + print( + f"[TRTLLM-MLA] Debug shapes before kernel call:\n" + f" query: {query.shape} dtype={query.dtype}\n" + f" kv_cache: {kv_cache.shape} dtype={kv_cache.dtype}\n" + f" block_tables: {metadata.block_kv_indices.shape} dtype={metadata.block_kv_indices.dtype}\n" + f" seq_lens: {forward_batch.seq_lens.shape} dtype={forward_batch.seq_lens.dtype}\n" + f" page_size: {self.page_size}\n" + f" max_seq_len: {metadata.block_kv_indices.shape[1] * self.page_size}\n" + f" scale: {scale}\n" + f" qk_nope_head_dim: {self.qk_nope_head_dim}\n" + f" kv_lora_rank: {self.kv_lora_rank}\n" + f" qk_rope_head_dim: {self.qk_rope_head_dim}" + ) + # -------------------------------------------------------- - # Call TRTLLM MLA decode kernel raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, kv_cache=kv_cache, - workspace_buffer=forward_batch.decode_trtllm_mla_metadata.workspace, + workspace_buffer=metadata.workspace, qk_nope_head_dim=self.qk_nope_head_dim, kv_lora_rank=self.kv_lora_rank, qk_rope_head_dim=self.qk_rope_head_dim, - block_tables=forward_batch.decode_trtllm_mla_metadata.block_kv_indices, + block_tables=metadata.block_kv_indices, seq_lens=forward_batch.seq_lens.to(torch.int32), block_size=self.page_size, - max_seq_len=forward_batch.seq_lens.max().item(), - scale=layer.scaling, - out=None, - bmm1_scale=1.0, # Only needed for FP8 - bmm2_scale=1.0, # Only needed for FP8 + # Avoid .item() (host sync) during CUDA graph 
capture. + # max_seq_len equals padded_blocks * page_size. + max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size), + scale=scale, + bmm1_scale=1.0, + bmm2_scale=1.0, ) - output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim) + # TRTLLM kernel may return both V and ROPE dims (kv_lora_rank + qk_rope_head_dim). + # We only need the value projection part (v_head_dim). + raw_out_v = raw_out[..., : layer.v_head_dim].contiguous() + + output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim) if output.shape[0] > forward_batch.batch_size: output = output[: forward_batch.batch_size] return output \ No newline at end of file diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py index 7f4f91d0e588..8d1b2e39c500 100644 --- a/python/sglang/test/attention/test_trtllm_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_mla_backend.py @@ -54,7 +54,7 @@ def __init__(self, page_size: int): "req_to_token": torch.zeros(max_bs, max_ctx, dtype=torch.int32, device=self.device), }, ) - + # KV-token pool self.token_to_kv_pool = MLATokenToKVPool( size=max_bs * max_ctx, @@ -75,10 +75,10 @@ class TestTRTLLMMLABackend(CustomTestCase): def setUp(self): self.batch_size = 16 self.seq_len = 512 - self.page_sizes = [16, 32, 64] + self.page_sizes = [32, 64] self.device = "cuda" self.dtype = torch.bfloat16 - + # ‑- helpers --------------------------------------------------------- def _init(self, page_size: int): self.model_runner = MockModelRunner(page_size) @@ -122,7 +122,7 @@ def test_forward_decode(self): forward_batch = self._create_forward_batch(seq_lens) self.backend.init_forward_metadata(forward_batch) - + q, k, v = self._alloc_qkv() layer = RadixAttention( num_heads=128, @@ -134,7 +134,7 @@ def test_forward_decode(self): prefix="attn_mqa", ) out = self.backend.forward_decode(q, k, v, layer, forward_batch) - + self.assertEqual(out.shape, (self.batch_size, 128 * 512)) 
self.assertEqual(out.dtype, self.dtype) self.assertEqual(out.device.type, "cuda") diff --git a/python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py b/python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py new file mode 100644 index 000000000000..db0affb7c048 --- /dev/null +++ b/python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py @@ -0,0 +1,218 @@ +import unittest +import torch +import math + +from sglang.srt.layers import dp_attention as _dp_attn + +# Patch DP-attention globals before importing backends +_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test + +from sglang.srt.configs.model_config import AttentionArch +from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend +from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.test.test_utils import CustomTestCase +from sglang.srt.utils import is_flashinfer_available + + +class MockModelRunner: + """Minimal fake ModelRunner for comparing both MLA backends.""" + + def __init__(self, page_size: int): + self.device = "cuda" + self.dtype = torch.bfloat16 + self.kv_cache_dtype = torch.bfloat16 + self.page_size = page_size + + # Model-config stub with MLA attributes + self.model_config = type( + "ModelConfig", + (), + { + "context_len": 2048, + "attention_arch": AttentionArch.MLA, + "num_attention_heads": 128, + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 512, + "scaling": 1.0 / ((128 + 64) ** 0.5), + "get_num_kv_heads": staticmethod(lambda _: 1), + }, + ) + + # Req-to-token pool + max_bs = 64 + max_ctx = self.model_config.context_len + self.req_to_token_pool = type( + "TokenPool", + (), + { + "size": max_bs, + "req_to_token": torch.zeros(max_bs, max_ctx, 
dtype=torch.int32, device=self.device), + }, + ) + + # KV-token pool (MLA) + self.token_to_kv_pool = MLATokenToKVPool( + size=max_bs * max_ctx, + page_size=page_size, + dtype=self.kv_cache_dtype, + kv_lora_rank=512, + qk_rope_head_dim=64, + layer_num=1, + device=self.device, + enable_memory_saver=False, + ) + + +@unittest.skipIf(not torch.cuda.is_available() or not is_flashinfer_available(), "CUDA + flashinfer required") +class TestTRTLLMvsFlashInferMLA(CustomTestCase): + """Test numerical equivalence between TRTLLM and FlashInfer MLA backends.""" + + def setUp(self): + self.batch_size = 8 + self.seq_len = 256 + self.page_size = 32 + self.device = "cuda" + self.dtype = torch.bfloat16 + + # Create model runner + self.model_runner = MockModelRunner(self.page_size) + + # Initialize both backends + self.trtllm_backend = TRTLLMMLABackend(self.model_runner) + self.flashinfer_backend = FlashInferMLAAttnBackend(self.model_runner) + + # Create RadixAttention layer for testing + self.layer = RadixAttention( + num_heads=128, + head_dim=512 + 64, # kv_lora_rank + qk_rope_head_dim + scaling=self.model_runner.model_config.scaling, + num_kv_heads=1, + layer_id=0, + v_head_dim=512, + prefix="attn_mqa", + ) + + def _create_qkv_tensors(self): + """Create Q, K, V tensors for testing.""" + head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim + q = torch.randn((self.batch_size, 128, head_dim), dtype=self.dtype, device=self.device) + k = torch.randn((self.batch_size, 1, head_dim), dtype=self.dtype, device=self.device) + # For FlashInfer MLA, if k is provided v must not be None. 
+ v = torch.randn((self.batch_size, 1, 512), dtype=self.dtype, device=self.device) + return q, k, v + + def _create_forward_batch(self, backend): + """Create a forward batch for the given backend.""" + # Random sequence lengths + seq_lens = torch.randint(1, self.seq_len, (self.batch_size,), device=self.device) + seq_lens[0] = self.seq_len # Ensure at least one max length + + fb = ForwardBatch( + batch_size=self.batch_size, + input_ids=torch.randint(0, 100, (self.batch_size, 1), device=self.device), + out_cache_loc=torch.arange(self.batch_size, device=self.device), + seq_lens_sum=int(seq_lens.sum().item()), + forward_mode=ForwardMode.DECODE, + req_pool_indices=torch.arange(self.batch_size, device=self.device), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.cpu(), + attn_backend=backend, + ) + fb.req_to_token_pool = self.model_runner.req_to_token_pool + fb.token_to_kv_pool = self.model_runner.token_to_kv_pool + return fb + + def test_decode_output_match(self): + """Test that TRTLLM and FlashInfer MLA backends produce matching outputs.""" + # Create identical forward batches for both backends + fb_trtllm = self._create_forward_batch(self.trtllm_backend) + fb_flashinfer = self._create_forward_batch(self.flashinfer_backend) + + # Initialize metadata for both backends + self.trtllm_backend.init_forward_metadata(fb_trtllm) + self.flashinfer_backend.init_forward_metadata(fb_flashinfer) + + # Create Q, K, V tensors + q, k, v = self._create_qkv_tensors() + + # Run forward decode on both backends + out_trtllm = self.trtllm_backend.forward_decode(q.clone(), k.clone(), v, self.layer, fb_trtllm) + out_flashinfer = self.flashinfer_backend.forward_decode(q.clone(), k.clone(), v.clone(), self.layer, fb_flashinfer) + + # Debug: print scale info + print(f"\n[DEBUG] Scale analysis:") + print(f" layer.scaling = {self.layer.scaling}") + print(f" qk_nope_head_dim = {self.model_runner.model_config.qk_nope_head_dim}") + print(f" qk_rope_head_dim = 
{self.model_runner.model_config.qk_rope_head_dim}") + print(f" kv_lora_rank = {self.model_runner.model_config.kv_lora_rank}") + print(f" Expected TRT scale factor = {math.sqrt(128 + 64) / math.sqrt(512 + 64)} = {math.sqrt(192) / math.sqrt(576)}") + print(f" Output shapes: TRTLLM {out_trtllm.shape}, FlashInfer {out_flashinfer.shape}") + print(f" Output means: TRTLLM {out_trtllm.mean().item():.6f}, FlashInfer {out_flashinfer.mean().item():.6f}") + print(f" Output stds: TRTLLM {out_trtllm.std().item():.6f}, FlashInfer {out_flashinfer.std().item():.6f}") + print(f" Max diff = {(out_trtllm - out_flashinfer).abs().max().item()}") + print(f" Ratio of means = {out_trtllm.mean().item() / out_flashinfer.mean().item() if out_flashinfer.mean().item() != 0 else 'inf'}") + + # Additional debug + print(f"\n[DEBUG] Scale computation:") + print(f" layer.scaling = 1/sqrt(192) = {1/math.sqrt(192)}") + print(f" TRT scale passed = layer.scaling * sqrt(192)/sqrt(576) = {self.layer.scaling * math.sqrt(192) / math.sqrt(576)}") + print(f" TRT kernel will compute: 1 / (sqrt(576) * scale) = {1 / (math.sqrt(576) * self.layer.scaling * math.sqrt(192) / math.sqrt(576))}") + print(f" Which equals: 1 / (layer.scaling * sqrt(192)) = {1 / (self.layer.scaling * math.sqrt(192))}") + print(f" But FlashInfer uses: layer.scaling = {self.layer.scaling}") + print(f" Ratio: {(1 / (self.layer.scaling * math.sqrt(192))) / self.layer.scaling} = sqrt(192) = {math.sqrt(192)}") + + # Check output shapes match + self.assertEqual(out_trtllm.shape, out_flashinfer.shape, + f"Output shapes differ: TRTLLM {out_trtllm.shape} vs FlashInfer {out_flashinfer.shape}") + + # Check output dtypes match + self.assertEqual(out_trtllm.dtype, out_flashinfer.dtype) + + # Check numerical equivalence with tolerance + # Note: Using higher tolerance due to potential numerical differences in implementations + self.assertTrue( + torch.allclose(out_trtllm, out_flashinfer, atol=1e-2, rtol=1e-2), + f"TRTLLM and FlashInfer outputs differ 
beyond tolerance. " + f"Max diff: {(out_trtllm - out_flashinfer).abs().max().item()}" + ) + + # Additional checks + self.assertFalse(torch.isnan(out_trtllm).any(), "TRTLLM output contains NaN") + self.assertFalse(torch.isnan(out_flashinfer).any(), "FlashInfer output contains NaN") + self.assertFalse(torch.isinf(out_trtllm).any(), "TRTLLM output contains Inf") + self.assertFalse(torch.isinf(out_flashinfer).any(), "FlashInfer output contains Inf") + + def test_decode_with_different_page_sizes(self): + """Test output consistency across different page sizes.""" + page_sizes = [32, 64] + outputs = [] + + for ps in page_sizes: + # Reinitialize with new page size + self.model_runner = MockModelRunner(ps) + self.trtllm_backend = TRTLLMMLABackend(self.model_runner) + + # Create batch and run decode + fb = self._create_forward_batch(self.trtllm_backend) + self.trtllm_backend.init_forward_metadata(fb) + + q, k, v = self._create_qkv_tensors() + out = self.trtllm_backend.forward_decode(q, k, v, self.layer, fb) + outputs.append(out) + + # Check that outputs are consistent across page sizes + # Note: Different page sizes might lead to slightly different numerical results + for i in range(1, len(outputs)): + self.assertTrue( + torch.allclose(outputs[0], outputs[i], atol=5e-2, rtol=5e-2), + f"Output with page_size={page_sizes[0]} differs from page_size={page_sizes[i]}" + ) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From cdd315e6a37ab7cc6a2e5823775168209a1b5e52 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Mon, 14 Jul 2025 18:06:43 -0700 Subject: [PATCH 04/27] trtllm mla kernel working Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- ...a_backend.py => trtllm_gen_mla_backend.py} | 204 ++++++--- .../sglang/srt/model_executor/model_runner.py | 6 +- .../attention/test_trtllm_gen_mla_backend.py | 386 ++++++++++++++++++ .../test/attention/test_trtllm_mla_backend.py | 
146 ------- .../test_trtllm_vs_flashinfer_mla.py | 218 ---------- 5 files changed, 547 insertions(+), 413 deletions(-) rename python/sglang/srt/layers/attention/{trtllm_mla_backend.py => trtllm_gen_mla_backend.py} (56%) create mode 100644 python/sglang/test/attention/test_trtllm_gen_mla_backend.py delete mode 100644 python/sglang/test/attention/test_trtllm_mla_backend.py delete mode 100644 python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py similarity index 56% rename from python/sglang/srt/layers/attention/trtllm_mla_backend.py rename to python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py index 14d6bad3e271..cdb817984894 100644 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py @@ -1,7 +1,7 @@ from __future__ import annotations """ -Support attention backend for TRTLLM MLA kernels from flashinfer. +Support attention backend for TRTLLM-Gen MLA kernels from flashinfer. 
""" from dataclasses import dataclass @@ -28,12 +28,12 @@ @dataclass -class TRTLLMMLADecodeMetadata: +class TRTLLMGENMLADecodeMetadata: workspace: Optional[torch.Tensor] = None block_kv_indices: Optional[torch.Tensor] = None -class TRTLLMMLABackend(FlashInferMLAAttnBackend): +class TRTLLMGENMLABackend(FlashInferMLAAttnBackend): """TRTLLM MLA attention kernels from flashinfer.""" def __init__( @@ -58,7 +58,7 @@ def __init__( self.num_local_heads = ( model_runner.model_config.num_attention_heads // get_attention_tp_size() ) - self.forward_metadata: Union[TRTLLMMLADecodeMetadata] = None + self.forward_metadata: Union[TRTLLMGENMLADecodeMetadata] = None self.kv_lora_rank = model_runner.model_config.kv_lora_rank self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim @@ -82,7 +82,13 @@ def __init__( def _calc_padded_blocks(self, max_seq_len: int) -> int: """Return number of blocks padded so that it satisfies TRTLLM constraint.""" blocks = triton.cdiv(max_seq_len, self.page_size) - min_blocks = 128 // self.page_size # kernel requirement + + # The Triton helper that builds `block_kv_indices` emits `NUM_PAGE_PER_BLOCK` + # (= 64) indices at a time, independent of `page_size`. To avoid it writing + # past the end of the row we **must** make every row at least 64 long. + # (Side-effect: max_seq_len is effectively rounded up to 64 × page_size = 2 K + # tokens for page_size 32, which is still below the 2048-ctx used here.) 
+ min_blocks = 64 if blocks % min_blocks != 0: blocks = triton.cdiv(blocks, min_blocks) * min_blocks return blocks @@ -133,7 +139,7 @@ def init_forward_metadata_capture_cuda_graph( max_seqlen_pad, self.page_size, ) - metadata = TRTLLMMLADecodeMetadata( + metadata = TRTLLMGENMLADecodeMetadata( self.cuda_graph_workspace, block_kv_indices, ) @@ -172,6 +178,9 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_cpu: Optional[torch.Tensor], ): """Replay CUDA graph with new inputs.""" + if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": + print(f"[TRTLLM-MLA] init_forward_metadata_replay_cuda_graph: bs={bs}, forward_mode={forward_mode}, spec_info={spec_info is not None}") + if forward_mode.is_decode_or_idle(): if spec_info is None: # Reuse cached metadata @@ -188,8 +197,11 @@ def init_forward_metadata_replay_cuda_graph( metadata.block_kv_indices.shape[1], self.page_size, ) - + # Pad invalid blocks to keep TRTLLM happy + # self._pad_invalid_blocks(metadata.block_kv_indices) + self.forward_metadata = metadata + else: # Speculative decoding: use parent class implementation super().init_forward_metadata_replay_cuda_graph( @@ -212,6 +224,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): bs = forward_batch.batch_size spec_info = forward_batch.spec_info + if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": + print(f"[TRTLLM-MLA] init_forward_metadata: bs={bs}, forward_mode={forward_batch.forward_mode}, spec_info={spec_info is not None}") if forward_batch.forward_mode.is_decode_or_idle(): if spec_info is None: # seq_lens_cpu may be None when cuda-graphs are disabled @@ -240,7 +254,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.page_size, ) - self.forward_metadata = TRTLLMMLADecodeMetadata( + # Ensure padded blocks are valid to avoid TRTLLM skipping shorter seqs + # self._pad_invalid_blocks(block_kv_indices) + + self.forward_metadata = TRTLLMGENMLADecodeMetadata( self.workspace_buffer, block_kv_indices, ) @@ -279,52 +296,97 @@ def 
forward_decode( # Build query tensor if q_rope is not None: - q_nope = q.view(-1, layer.tp_q_head_num, self.qk_nope_head_dim) - q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim) - query = torch.cat([q_nope, q_rope], dim=-1) + # q contains the NOPE part (v_head_dim) + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope = q_rope.view( + -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim + ) else: - query = q.view(-1, layer.tp_q_head_num, layer.head_dim) + reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = reshaped_q[:, :, : layer.v_head_dim] + q_rope = reshaped_q[:, :, layer.v_head_dim :] - # Scale factor for TRT-LLM MLA kernel. - # The kernel computes softmax with scale: 1 / (sqrt(head_dim_qk) * scale) - # where head_dim_qk = 576 (kv_lora_rank + qk_rope_head_dim). - # To get the same result as FlashInfer (which uses layer.scaling = 1/sqrt(192)), - # we need: 1 / (sqrt(576) * scale) = 1 / sqrt(192) - # Therefore: scale = sqrt(576) / sqrt(192) = sqrt(3) - scale = math.sqrt(self.kv_lora_rank + self.qk_rope_head_dim) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) + # Concatenate to build the full query as expected by TRTLLM kernel + query = torch.cat([q_nope, q_rope], dim=-1) + + + if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": + print(f"[TRTLLM-MLA]: model_runner.model_config.scaling.scaling is {self.model_runner.model_config.scaling}") + + # Get the model scaling factor + # TRTLLM kernel applies the 1/sqrt(192) factor internally, so we + # should pass `1.0` here (see Flash-Infer equivalence tests). + sm_scale = 1 + # (scale * ((512 + 64) ** 0.5)) / ((128 + 64) ** 0.5) + # ( sqrt(3)) / sqrt(192) # KV cache tensor: reshape to (num_pages, page_size, dim) k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) # Build KV cache slices expected by TRT-LLM: slice 0 → CKV+K, slice 1 → KPE. 
pages = k_cache.view(-1, self.page_size, self.kv_cache_dim) # (P, blk, 576) - # According to the FlashInfer test, both slices should contain the full 576-dim tensor. - # Use torch.stack so each slice is an independent contiguous view. - # NOTE: this duplicates the storage but matches the reference behaviour. - kv_cache = torch.stack([pages, pages], dim=1) # (P, 2, blk, 576) - - # Metadata (prefer attribute on forward_batch for compatibility) + # Retrieve metadata so we can check block tables. metadata = getattr(forward_batch, "decode_trtllm_mla_metadata", None) if metadata is None: metadata = self.forward_metadata - # ---------- Debug output (enable with env var) ---------- - if os.getenv("SGLANG_DEBUG_TRTLLM_MLA", "0") == "1": + # --------------------------------------------------------------------- + # According to the FlashInfer test, both slices should contain the full 576-dim tensor. + # Use torch.stack so each slice is an independent contiguous view **after** we have + # patched the pages in-place. + kv_cache = torch.stack([pages, pages], dim=1) # (P, 2, blk, 576) + + # Metadata already obtained above; no change needed. 
+ + # ------------- DEBUG KV CACHE CONSTRUCTION ------------- + if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": + print( + f"[TRTLLM-MLA] k_cache_raw: {k_cache.shape} {k_cache.dtype}\n" + f"[TRTLLM-MLA] pages : {pages.shape} {pages.dtype}\n" + f"[TRTLLM-MLA] kv_cache : {kv_cache.shape} {kv_cache.dtype}\n" + f"[TRTLLM-MLA] block_kv_indices: {metadata.block_kv_indices.shape} {metadata.block_kv_indices.dtype}\n" + f"[TRTLLM-MLA] workspace: {metadata.workspace.shape if metadata.workspace is not None else 'None'}\n" + f"[TRTLLM-MLA] max_seq_len: {metadata.block_kv_indices.shape[1] * self.page_size}\n" + f"[TRTLLM-MLA] k_cache[0,0,:3]: {k_cache[0,0,:3] if k_cache.numel() > 0 else 'empty'}\n" + f"[TRTLLM-MLA] pages[0,0,:3]: {pages[0,0,:3] if pages.numel() > 0 else 'empty'}\n" + f"[TRTLLM-MLA] kv_cache[0,0,0,:3]: {kv_cache[0,0,0,:3] if kv_cache.numel() > 0 else 'empty'}\n" + f"[TRTLLM-MLA] kv_cache[0,1,0,:3]: {kv_cache[0,1,0,:3] if kv_cache.numel() > 0 else 'empty'}\n" + f"[TRTLLM-MLA] block_kv_indices[0,:5]: {metadata.block_kv_indices[0,:5] if metadata.block_kv_indices.numel() > 0 else 'empty'}" + ) + + # ------------- DEBUG (align with FlashInfer-MLA) ------------- + if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": + q_nope_str = f"{q_nope.shape} {q_rope.dtype}" if q_nope is not None else "None" + q_rope_str = f"{q_rope.shape} {q_rope.dtype}" if q_rope is not None else "None" + q_nope_sample = f"{q_nope[0,:3,:3] if q_rope is not None and q_nope.numel() > 0 else 'empty'}" + q_rope_sample = f"{q_rope[0,:3,:3] if q_rope is not None and q_rope.numel() > 0 else 'empty'}" + print( - f"[TRTLLM-MLA] Debug shapes before kernel call:\n" - f" query: {query.shape} dtype={query.dtype}\n" - f" kv_cache: {kv_cache.shape} dtype={kv_cache.dtype}\n" - f" block_tables: {metadata.block_kv_indices.shape} dtype={metadata.block_kv_indices.dtype}\n" - f" seq_lens: {forward_batch.seq_lens.shape} dtype={forward_batch.seq_lens.dtype}\n" - f" page_size: {self.page_size}\n" - f" max_seq_len: 
{metadata.block_kv_indices.shape[1] * self.page_size}\n" - f" scale: {scale}\n" - f" qk_nope_head_dim: {self.qk_nope_head_dim}\n" - f" kv_lora_rank: {self.kv_lora_rank}\n" - f" qk_rope_head_dim: {self.qk_rope_head_dim}" + f"[TRTLLM-MLA] q_nope : {q_nope_str}\n" + f"[TRTLLM-MLA] q_rope : {q_rope_str}\n" + f"[TRTLLM-MLA] query : {query.shape} {query.dtype}\n" + f"[TRTLLM-MLA] kv_cache: {kv_cache.shape} {kv_cache.dtype}\n" + f"[TRTLLM-MLA] scale : {sm_scale}\n" + f"[TRTLLM-MLA] cache_loc: {cache_loc.shape} {cache_loc.dtype}\n" + f"[TRTLLM-MLA] seq_lens: {forward_batch.seq_lens.shape} {forward_batch.seq_lens.dtype}\n" + f"[TRTLLM-MLA] batch_size: {forward_batch.batch_size}\n" + f"[TRTLLM-MLA] page_size: {self.page_size}\n" + f"[TRTLLM-MLA] qk_nope_head_dim: {self.qk_nope_head_dim}\n" + f"[TRTLLM-MLA] qk_rope_head_dim: {self.qk_rope_head_dim}\n" + f"[TRTLLM-MLA] kv_lora_rank: {self.kv_lora_rank}\n" + f"[TRTLLM-MLA] v_head_dim: {layer.v_head_dim}\n" + f"[TRTLLM-MLA] kv_cache_dim: {self.kv_cache_dim}\n" + f"[TRTLLM-MLA] tp_q_head_num: {layer.tp_q_head_num}\n" + f"[TRTLLM-MLA] tp_k_head_num: {layer.tp_k_head_num}\n" + f"[TRTLLM-MLA] layer_id: {layer.layer_id}\n" + f"[TRTLLM-MLA] q_nope[0,:3,:3]: {q_nope_sample}\n" + f"[TRTLLM-MLA] q_rope[0,:3,:3]: {q_rope_sample}\n" + f"[TRTLLM-MLA] query[0,:3,:3]: {query[0,:3,:3] if query.numel() > 0 else 'empty'}\n" + f"[TRTLLM-MLA] seq_lens values: {forward_batch.seq_lens[:min(5, len(forward_batch.seq_lens))]}\n" + f"[TRTLLM-MLA] cache_loc values: {cache_loc[:min(5, len(cache_loc))]}" ) - # -------------------------------------------------------- + raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, kv_cache=kv_cache, @@ -335,18 +397,68 @@ def forward_decode( block_tables=metadata.block_kv_indices, seq_lens=forward_batch.seq_lens.to(torch.int32), block_size=self.page_size, - # Avoid .item() (host sync) during CUDA graph capture. - # max_seq_len equals padded_blocks * page_size. 
- max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size), - scale=scale, - bmm1_scale=1.0, - bmm2_scale=1.0, + max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size), # max_seq_len equals padded_blocks * page_size. + q_scale=1.0, + k_scale=1.0, + v_scale=1.0, + sm_scale=sm_scale, + o_scale=1.0, ) # TRTLLM kernel may return both V and ROPE dims (kv_lora_rank + qk_rope_head_dim). # We only need the value projection part (v_head_dim). raw_out_v = raw_out[..., : layer.v_head_dim].contiguous() + # ------------- DEBUG KERNEL OUTPUT ------------- + if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": + print( + f"[TRTLLM-MLA] raw_out : {raw_out.shape} {raw_out.dtype}\n" + f"[TRTLLM-MLA] raw_out_v : {raw_out_v.shape} {raw_out_v.dtype}\n" + f"[TRTLLM-MLA] raw_out[0,:3,:3]: {raw_out[0,:3,:3] if raw_out.numel() > 0 else 'empty'}\n" + f"[TRTLLM-MLA] raw_out_v[0,:3,:3]: {raw_out_v[0,:3,:3] if raw_out_v.numel() > 0 else 'empty'}\n" + f"[TRTLLM-MLA] raw_out stats: min={raw_out.min():.6f}, max={raw_out.max():.6f}, mean={raw_out.mean():.6f}\n" + f"[TRTLLM-MLA] raw_out_v stats: min={raw_out_v.min():.6f}, max={raw_out_v.max():.6f}, mean={raw_out_v.mean():.6f}" + ) + output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim) if output.shape[0] > forward_batch.batch_size: output = output[: forward_batch.batch_size] - return output \ No newline at end of file + + # ------------- DEBUG FINAL OUTPUT ------------- + if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": + print( + f"[TRTLLM-MLA] output : {output.shape} {output.dtype}\n" + f"[TRTLLM-MLA] output[0,:10]: {output[0,:10] if output.numel() > 0 else 'empty'}\n" + f"[TRTLLM-MLA] output[0,10:20]: {output[0,10:20] if output.numel() > 10 else 'empty'}\n" + f"[TRTLLM-MLA] output[0,20:30]: {output[0,20:30] if output.numel() > 20 else 'empty'}\n" + f"[TRTLLM-MLA] output stats: min={output.min():.6f}, max={output.max():.6f}, mean={output.mean():.6f}\n" + f"[TRTLLM-MLA] output std: {output.std():.6f}\n" + 
f"[TRTLLM-MLA] output_reshaped: {output.shape[0]} -> {forward_batch.batch_size}\n" + f"[TRTLLM-MLA] ===== END TRTLLM-MLA DEBUG =====" + ) + + return output + + @staticmethod + def _pad_invalid_blocks(block_kv_indices: torch.Tensor): + """Replace -1 paddings with the last valid page id for each row. + + TRT-LLM treats a leading -1 as "empty sequence" and will skip the + whole sequence. For shorter sequences we therefore replicate the + last real page id into the padded region instead of leaving -1. + """ + for row in range(block_kv_indices.size(0)): + row_view = block_kv_indices[row] + # Find last non-negative entry (every sequence has at least one) + valid_mask = row_view >= 0 + if not valid_mask.any(): + # Defensive – shouldn’t happen in decode mode + continue + last_valid = row_view[valid_mask][-1] + row_view[~valid_mask] = last_valid + + # Extra visibility when debugging + if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": + preview_rows = min(6, block_kv_indices.size(0)) + print("[TRTLLM-MLA] block_kv_indices preview (after padding):") + for i in range(preview_rows): + print(f" seq {i}:", block_kv_indices[i].tolist()) \ No newline at end of file diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0f53e52285e5..c14c3fbef48f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1443,11 +1443,11 @@ def _get_attention_backend_from_str(self, backend_str: str): return CutlassMLABackend(self) elif self.server_args.attention_backend == "trtllm_mla": - from sglang.srt.layers.attention.trtllm_mla_backend import ( - TRTLLMMLABackend, + from sglang.srt.layers.attention.trtllm_gen_mla_backend import ( + TRTLLMGENMLABackend, ) - return TRTLLMMLABackend(self) + return TRTLLMGENMLABackend(self) elif self.server_args.attention_backend == "intel_amx": from sglang.srt.layers.attention.intel_amx_backend import ( IntelAMXAttnBackend, diff --git 
a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py new file mode 100644 index 000000000000..acf752fd23ce --- /dev/null +++ b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py @@ -0,0 +1,386 @@ +""" +Clean test suite for TRTLLM MLA backend. + +This test file provides comprehensive testing for the TRTLLM MLA (Multi-Head Latent Attention) backend: + +1. test_basic_functionality: Basic smoke test with minimal setup +2. test_decode_output_match: Compares TRTLLM MLA output against FlashInfer MLA reference + across different batch sizes and sequence lengths +3. test_different_page_sizes: Tests consistency across different page sizes +4. test_forward_decode_shape_sanity: Shape and sanity checks across various configurations + +The tests use unittest with subTest for parameterized testing, following the sglang test patterns. +""" +import unittest +import torch +import numpy as np +from types import SimpleNamespace + +from sglang.srt.layers import dp_attention as _dp_attn + +# Patch DP-attention globals before importing backends +_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test + +from sglang.srt.configs.model_config import AttentionArch +from sglang.srt.layers.attention.trtllm_gen_mla_backend import TRTLLMGENMLABackend +from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.test.test_utils import CustomTestCase +from sglang.srt.utils import is_flashinfer_available + + +class MockModelRunner: + """Minimal fake ModelRunner for testing MLA backends.""" + + def __init__(self, page_size: int): + self.device = "cuda" + self.dtype = torch.bfloat16 + self.kv_cache_dtype = torch.bfloat16 + self.page_size = page_size + + # 
Model-config stub with MLA attributes + self.model_config = type( + "ModelConfig", + (), + { + "context_len": 2048, + "attention_arch": AttentionArch.MLA, + "num_attention_heads": 128, + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 512, + "scaling": 1.0 / ((128 + 64) ** 0.5), + "get_num_kv_heads": staticmethod(lambda _: 1), + }, + ) + + # Req-to-token pool + max_bs = 64 + max_ctx = self.model_config.context_len + req_to_token = torch.arange(max_bs * max_ctx, dtype=torch.int32, device=self.device).reshape(max_bs, max_ctx) + self.req_to_token_pool = type( + "TokenPool", + (), + { + "size": max_bs, + "req_to_token": req_to_token, + }, + ) + + # KV-token pool (MLA) + self.token_to_kv_pool = MLATokenToKVPool( + size=max_bs * max_ctx, + page_size=page_size, + dtype=self.kv_cache_dtype, + kv_lora_rank=512, + qk_rope_head_dim=64, + layer_num=1, + device=self.device, + enable_memory_saver=False, + ) + + +def compare_outputs(trtllm_out, reference_out, tolerance=1e-2): + """Compare outputs with detailed analysis.""" + + # Basic checks + assert trtllm_out.shape == reference_out.shape, f"Shape mismatch: {trtllm_out.shape} vs {reference_out.shape}" + assert trtllm_out.dtype == reference_out.dtype, f"Dtype mismatch: {trtllm_out.dtype} vs {reference_out.dtype}" + + # Check for NaN/Inf + assert not torch.isnan(trtllm_out).any(), "TRTLLM output contains NaN" + assert not torch.isnan(reference_out).any(), "Reference output contains NaN" + assert not torch.isinf(trtllm_out).any(), "TRTLLM output contains Inf" + assert not torch.isinf(reference_out).any(), "Reference output contains Inf" + + # Element-wise differences + diff = (trtllm_out - reference_out).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + # Check numerical equivalence + all_close = torch.allclose(trtllm_out, reference_out, rtol=tolerance, atol=tolerance) + + if not all_close: + print(f"Comparison failed: max_diff={max_diff:.6f}, 
mean_diff={mean_diff:.6f}, tolerance={tolerance}") + # Find top differences for debugging + flat_diff = diff.flatten() + top_diff_indices = torch.topk(flat_diff, k=min(5, flat_diff.numel())).indices + print("Top 5 differences:") + for i, idx in enumerate(top_diff_indices): + idx_tuple = np.unravel_index(idx.cpu().numpy(), trtllm_out.shape) + trt_val = trtllm_out[idx_tuple].item() + ref_val = reference_out[idx_tuple].item() + print(f" [{idx_tuple}]: TRTLLM={trt_val:.6f}, Reference={ref_val:.6f}, diff={abs(trt_val-ref_val):.6f}") + + return all_close + + +@unittest.skipIf(not torch.cuda.is_available() or not is_flashinfer_available(), + "CUDA + flashinfer required") +class TestTRTLLMMLAClean(CustomTestCase): + """Test TRTLLM MLA backend against FlashInfer MLA backend (reference).""" + + def setUp(self): + """Setup test fixtures.""" + self.device = "cuda" + self.dtype = torch.bfloat16 + self.page_size = 32 + + # Create model runner and backends + self.model_runner_trtllm = MockModelRunner(self.page_size) + self.model_runner_reference = MockModelRunner(self.page_size) + + self.trtllm_backend = TRTLLMGENMLABackend(self.model_runner_trtllm) + self.reference_backend = FlashInferMLAAttnBackend(self.model_runner_reference) + + # Create RadixAttention layer + self.layer = RadixAttention( + num_heads=128, + head_dim=512 + 64, # kv_lora_rank + qk_rope_head_dim + scaling=self.model_runner_trtllm.model_config.scaling, + num_kv_heads=1, + layer_id=0, + v_head_dim=512, + prefix="attn_mqa", + ) + + def _create_qkv_tensors(self, batch_size): + """Create Q, K, V tensors for testing.""" + head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim + q = torch.randn((batch_size, 128, head_dim), dtype=self.dtype, device=self.device) + k = torch.randn((batch_size, 1, head_dim), dtype=self.dtype, device=self.device) + v = torch.randn((batch_size, 1, 512), dtype=self.dtype, device=self.device) + return q, k, v + + def _create_forward_batch(self, batch_size, seq_lens, backend, model_runner): + 
"""Create a forward batch for the given backend.""" + fb = ForwardBatch( + batch_size=batch_size, + input_ids=torch.randint(0, 100, (batch_size, 1), device=self.device), + out_cache_loc=torch.arange(batch_size, device=self.device), + seq_lens_sum=int(seq_lens.sum().item()), + forward_mode=ForwardMode.DECODE, + req_pool_indices=torch.arange(batch_size, device=self.device), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.cpu(), + attn_backend=backend, + ) + fb.req_to_token_pool = model_runner.req_to_token_pool + fb.token_to_kv_pool = model_runner.token_to_kv_pool + return fb + + def _populate_kv_cache(self, batch_size, seq_lens, model_runners): + """Populate KV cache with identical data for both backends.""" + torch.manual_seed(42) # Fixed seed for reproducible cache + + for model_runner in model_runners: + torch.manual_seed(42) # Reset seed for each backend + for i in range(batch_size): + seq_len = int(seq_lens[i].item()) + for token_idx in range(seq_len - 1): + # Create random K components for MLA + cache_k_nope = torch.randn((1, 128), dtype=self.dtype, device=self.device) + cache_k_rope = torch.randn((1, 64), dtype=self.dtype, device=self.device) + + # Calculate cache location + cache_loc = model_runner.req_to_token_pool.req_to_token[i, token_idx] + + # Save to KV cache + model_runner.token_to_kv_pool.set_mla_kv_buffer( + self.layer, + cache_loc.unsqueeze(0), + cache_k_nope.squeeze(0), + cache_k_rope.squeeze(0) + ) + + def test_decode_output_match(self): + """Test that TRTLLM and FlashInfer MLA backends produce matching outputs.""" + # Test different batch sizes and sequence lengths + test_cases = [ + (1, 64), (4, 64), (16, 64), (32, 64), + (1, 128), (4, 128), (16, 128), (32, 128), + (1, 256), (4, 256), (16, 256), (32, 256), + ] + + for batch_size, max_seq_len in test_cases: + with self.subTest(batch_size=batch_size, max_seq_len=max_seq_len): + # Create identical sequence lengths for both backends + torch.manual_seed(42) + seq_lens = torch.randint(1, max_seq_len, 
(batch_size,), device=self.device) + seq_lens[0] = max_seq_len # Ensure at least one max length + + # Create forward batches with identical inputs + fb_trtllm = self._create_forward_batch( + batch_size, seq_lens.clone(), self.trtllm_backend, self.model_runner_trtllm + ) + fb_reference = self._create_forward_batch( + batch_size, seq_lens.clone(), self.reference_backend, self.model_runner_reference + ) + + # Initialize metadata for both backends + self.trtllm_backend.init_forward_metadata(fb_trtllm) + self.reference_backend.init_forward_metadata(fb_reference) + + # Populate both KV caches identically + self._populate_kv_cache(batch_size, seq_lens, [self.model_runner_trtllm, self.model_runner_reference]) + + # Create Q, K, V tensors for current decode step + torch.manual_seed(123) # Fixed seed for Q, K, V + q, k, v = self._create_qkv_tensors(batch_size) + + # Run forward decode on both backends + out_trtllm = self.trtllm_backend.forward_decode(q.clone(), k.clone(), v, self.layer, fb_trtllm) + out_reference = self.reference_backend.forward_decode(q.clone(), k.clone(), v.clone(), self.layer, fb_reference) + + # Compare outputs + comparison_passed = compare_outputs(out_trtllm, out_reference, tolerance=1e-2) + + self.assertTrue(comparison_passed, + f"TRTLLM and Reference outputs differ beyond tolerance. 
" + f"batch_size={batch_size}, max_seq_len={max_seq_len}, " + f"Max diff: {(out_trtllm - out_reference).abs().max().item()}" + ) + + def test_different_page_sizes(self): + """Test output consistency across different page sizes.""" + page_sizes = [32, 64] + batch_size = 8 + max_seq_len = 128 + + for page_size in page_sizes: + with self.subTest(page_size=page_size): + # Create model runner with specific page size + model_runner = MockModelRunner(page_size) + backend = TRTLLMGENMLABackend(model_runner) + + # Create sequence lengths + torch.manual_seed(42) + seq_lens = torch.randint(1, max_seq_len, (batch_size,), device=self.device) + seq_lens[0] = max_seq_len + + # Create forward batch + fb = self._create_forward_batch(batch_size, seq_lens, backend, model_runner) + backend.init_forward_metadata(fb) + + # Populate KV cache + self._populate_kv_cache(batch_size, seq_lens, [model_runner]) + + # Create Q, K, V tensors + torch.manual_seed(123) + q, k, v = self._create_qkv_tensors(batch_size) + + # Run forward decode + output = backend.forward_decode(q, k, v, self.layer, fb) + + # Basic checks + expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim + self.assertEqual(output.shape, expected_shape, f"Output shape mismatch: {output.shape} vs {expected_shape}") + self.assertFalse(torch.isnan(output).any(), "Output contains NaN") + self.assertFalse(torch.isinf(output).any(), "Output contains Inf") + + def test_basic_functionality(self): + """Test basic functionality with minimal setup.""" + batch_size = 2 + max_seq_len = 32 + + # Create sequence lengths + seq_lens = torch.tensor([max_seq_len, max_seq_len // 2], device=self.device) + + # Create forward batch + fb = self._create_forward_batch(batch_size, seq_lens, self.trtllm_backend, self.model_runner_trtllm) + self.trtllm_backend.init_forward_metadata(fb) + + # Populate KV cache + self._populate_kv_cache(batch_size, seq_lens, [self.model_runner_trtllm]) + + # Create Q, K, V tensors + q, k, v = 
self._create_qkv_tensors(batch_size) + + # Run forward decode + output = self.trtllm_backend.forward_decode(q, k, v, self.layer, fb) + + # Basic checks + expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim + self.assertEqual(output.shape, expected_shape) + self.assertEqual(output.dtype, self.dtype) + self.assertFalse(torch.isnan(output).any()) + self.assertFalse(torch.isinf(output).any()) + + def test_forward_decode_shape_sanity(self): + """Smoke test decode across several page sizes and batch configurations.""" + # Test configurations similar to the original test + test_configs = [ + (16, 512, 32), # batch_size, seq_len, page_size + (16, 512, 64), + (8, 256, 32), + (4, 128, 32), + (1, 64, 32), + (32, 1024, 64), + ] + + for batch_size, seq_len, page_size in test_configs: + with self.subTest(batch_size=batch_size, seq_len=seq_len, page_size=page_size): + # Create model runner with specific page size + model_runner = MockModelRunner(page_size) + backend = TRTLLMGENMLABackend(model_runner) + + # Random seq lens (ensure one matches max) + torch.manual_seed(42) + seq_lens = torch.randint(1, seq_len, (batch_size,), device=self.device) + seq_lens[0] = seq_len + + # Create forward batch + fb = ForwardBatch( + batch_size=batch_size, + input_ids=torch.randint(0, 100, (batch_size, 1), device=self.device), + out_cache_loc=torch.arange(batch_size, device=self.device), + seq_lens_sum=int(seq_lens.sum().item()), + forward_mode=ForwardMode.DECODE, + req_pool_indices=torch.arange(batch_size, device=self.device), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.cpu(), + attn_backend=backend, + ) + fb.req_to_token_pool = model_runner.req_to_token_pool + fb.token_to_kv_pool = model_runner.token_to_kv_pool + + backend.init_forward_metadata(fb) + + # Create Q, K, V tensors + head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim + q = torch.randn((batch_size, 128, head_dim), dtype=self.dtype, device=self.device) + k = torch.randn((batch_size, 1, head_dim), dtype=self.dtype, 
device=self.device) + v = None # TRTLLM MLA decode kernel ignores v + + # Create layer + layer = RadixAttention( + num_heads=128, + head_dim=512 + 64, + scaling=model_runner.model_config.scaling, + num_kv_heads=1, + layer_id=0, + v_head_dim=512, + prefix="attn_mqa", + ) + + # Run forward decode + output = backend.forward_decode(q, k, v, layer, fb) + + # Shape and sanity checks + expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim + self.assertEqual(output.shape, expected_shape, + f"Output shape mismatch for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})") + self.assertEqual(output.dtype, self.dtype) + self.assertEqual(output.device.type, "cuda") + self.assertFalse(torch.isnan(output).any(), + f"Output contains NaN for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})") + self.assertFalse(torch.isinf(output).any(), + f"Output contains Inf for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py deleted file mode 100644 index 8d1b2e39c500..000000000000 --- a/python/sglang/test/attention/test_trtllm_mla_backend.py +++ /dev/null @@ -1,146 +0,0 @@ -import unittest -import torch - -from sglang.srt.layers import dp_attention as _dp_attn - -# Patch DP-attention globals **before** importing the backend so that all -# downstream `from … import get_attention_tp_size` statements receive the -# patched version. 
-_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test - -from sglang.srt.configs.model_config import AttentionArch -from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend -from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.test.test_utils import CustomTestCase -from sglang.srt.utils import is_flashinfer_available - - -class MockModelRunner: - """Minimal fake `ModelRunner` for MLA backend unit tests.""" - - def __init__(self, page_size: int): - self.device = "cuda" - self.dtype = torch.bfloat16 - self.kv_cache_dtype = torch.bfloat16 - self.page_size = page_size - - # Model-config stub – only the attributes accessed by the backend. - self.model_config = type( - "ModelConfig", - (), - { - "context_len": 2048, - "attention_arch": AttentionArch.MLA, - "num_attention_heads": 128, - "kv_lora_rank": 512, - "qk_nope_head_dim": 128, - "qk_rope_head_dim": 64, - "v_head_dim": 512, - "scaling": 1.0 / ((128 + 64) ** 0.5), - "get_num_kv_heads": staticmethod(lambda _: 1), - }, - ) - - # Req-to-token pool (dummy) - max_bs = 64 - max_ctx = self.model_config.context_len - self.req_to_token_pool = type( - "TokenPool", - (), - { - "size": max_bs, - "req_to_token": torch.zeros(max_bs, max_ctx, dtype=torch.int32, device=self.device), - }, - ) - - # KV-token pool - self.token_to_kv_pool = MLATokenToKVPool( - size=max_bs * max_ctx, - page_size=page_size, - dtype=self.kv_cache_dtype, - kv_lora_rank=512, - qk_rope_head_dim=64, - layer_num=1, - device=self.device, - enable_memory_saver=False, - ) - - -@unittest.skipIf(not torch.cuda.is_available() or not is_flashinfer_available(), "CUDA + flashinfer required") -class TestTRTLLMMLABackend(CustomTestCase): - """Structure mirrors `test_flashattn_backend.py` but focuses on MLA decode.""" - - def setUp(self): - self.batch_size = 16 - 
self.seq_len = 512 - self.page_sizes = [32, 64] - self.device = "cuda" - self.dtype = torch.bfloat16 - - # ‑- helpers --------------------------------------------------------- - def _init(self, page_size: int): - self.model_runner = MockModelRunner(page_size) - self.backend = TRTLLMMLABackend(self.model_runner) - # Attach num_heads required by RadixAttention convenience - self.model_runner.model_config.num_attention_heads = 128 - - def _alloc_qkv(self): - head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim - q_shape = (self.batch_size, 128, head_dim) - q = torch.randn(q_shape, dtype=self.dtype, device=self.device) - k = torch.randn(self.batch_size, 1, head_dim, dtype=self.dtype, device=self.device) - v = None # TRTLLM MLA decode kernel ignores v - return q, k, v - - def _create_forward_batch(self, seq_lens: torch.Tensor): - fb = ForwardBatch( - batch_size=self.batch_size, - input_ids=torch.randint(0, 100, (self.batch_size, 1), device=self.device), - out_cache_loc=torch.arange(self.batch_size, device=self.device), - seq_lens_sum=int(seq_lens.sum().item()), - forward_mode=ForwardMode.DECODE, - req_pool_indices=torch.arange(self.batch_size, device=self.device), - seq_lens=seq_lens, - seq_lens_cpu=seq_lens.cpu(), - attn_backend=self.backend, - ) - fb.req_to_token_pool = self.model_runner.req_to_token_pool - fb.token_to_kv_pool = self.model_runner.token_to_kv_pool - return fb - - # ‑- actual tests ---------------------------------------------------- - def test_forward_decode(self): - """Smoke test decode across several page sizes.""" - for ps in self.page_sizes: - self._init(ps) - - # Random seq lens (ensure one matches max) - seq_lens = torch.randint(1, self.seq_len, (self.batch_size,), device=self.device) - seq_lens[0] = self.seq_len - - forward_batch = self._create_forward_batch(seq_lens) - self.backend.init_forward_metadata(forward_batch) - - q, k, v = self._alloc_qkv() - layer = RadixAttention( - num_heads=128, - head_dim=512 + 64, - 
scaling=self.model_runner.model_config.scaling, - num_kv_heads=1, - layer_id=0, - v_head_dim=512, - prefix="attn_mqa", - ) - out = self.backend.forward_decode(q, k, v, layer, forward_batch) - - self.assertEqual(out.shape, (self.batch_size, 128 * 512)) - self.assertEqual(out.dtype, self.dtype) - self.assertEqual(out.device.type, "cuda") - self.assertFalse(torch.isnan(out).any()) - self.assertFalse(torch.isinf(out).any()) - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py b/python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py deleted file mode 100644 index db0affb7c048..000000000000 --- a/python/sglang/test/attention/test_trtllm_vs_flashinfer_mla.py +++ /dev/null @@ -1,218 +0,0 @@ -import unittest -import torch -import math - -from sglang.srt.layers import dp_attention as _dp_attn - -# Patch DP-attention globals before importing backends -_dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test - -from sglang.srt.configs.model_config import AttentionArch -from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend -from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend -from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.test.test_utils import CustomTestCase -from sglang.srt.utils import is_flashinfer_available - - -class MockModelRunner: - """Minimal fake ModelRunner for comparing both MLA backends.""" - - def __init__(self, page_size: int): - self.device = "cuda" - self.dtype = torch.bfloat16 - self.kv_cache_dtype = torch.bfloat16 - self.page_size = page_size - - # Model-config stub with MLA attributes - self.model_config = type( - "ModelConfig", - (), - { - "context_len": 2048, - "attention_arch": AttentionArch.MLA, - 
"num_attention_heads": 128, - "kv_lora_rank": 512, - "qk_nope_head_dim": 128, - "qk_rope_head_dim": 64, - "v_head_dim": 512, - "scaling": 1.0 / ((128 + 64) ** 0.5), - "get_num_kv_heads": staticmethod(lambda _: 1), - }, - ) - - # Req-to-token pool - max_bs = 64 - max_ctx = self.model_config.context_len - self.req_to_token_pool = type( - "TokenPool", - (), - { - "size": max_bs, - "req_to_token": torch.zeros(max_bs, max_ctx, dtype=torch.int32, device=self.device), - }, - ) - - # KV-token pool (MLA) - self.token_to_kv_pool = MLATokenToKVPool( - size=max_bs * max_ctx, - page_size=page_size, - dtype=self.kv_cache_dtype, - kv_lora_rank=512, - qk_rope_head_dim=64, - layer_num=1, - device=self.device, - enable_memory_saver=False, - ) - - -@unittest.skipIf(not torch.cuda.is_available() or not is_flashinfer_available(), "CUDA + flashinfer required") -class TestTRTLLMvsFlashInferMLA(CustomTestCase): - """Test numerical equivalence between TRTLLM and FlashInfer MLA backends.""" - - def setUp(self): - self.batch_size = 8 - self.seq_len = 256 - self.page_size = 32 - self.device = "cuda" - self.dtype = torch.bfloat16 - - # Create model runner - self.model_runner = MockModelRunner(self.page_size) - - # Initialize both backends - self.trtllm_backend = TRTLLMMLABackend(self.model_runner) - self.flashinfer_backend = FlashInferMLAAttnBackend(self.model_runner) - - # Create RadixAttention layer for testing - self.layer = RadixAttention( - num_heads=128, - head_dim=512 + 64, # kv_lora_rank + qk_rope_head_dim - scaling=self.model_runner.model_config.scaling, - num_kv_heads=1, - layer_id=0, - v_head_dim=512, - prefix="attn_mqa", - ) - - def _create_qkv_tensors(self): - """Create Q, K, V tensors for testing.""" - head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim - q = torch.randn((self.batch_size, 128, head_dim), dtype=self.dtype, device=self.device) - k = torch.randn((self.batch_size, 1, head_dim), dtype=self.dtype, device=self.device) - # For FlashInfer MLA, if k is provided v must 
not be None. - v = torch.randn((self.batch_size, 1, 512), dtype=self.dtype, device=self.device) - return q, k, v - - def _create_forward_batch(self, backend): - """Create a forward batch for the given backend.""" - # Random sequence lengths - seq_lens = torch.randint(1, self.seq_len, (self.batch_size,), device=self.device) - seq_lens[0] = self.seq_len # Ensure at least one max length - - fb = ForwardBatch( - batch_size=self.batch_size, - input_ids=torch.randint(0, 100, (self.batch_size, 1), device=self.device), - out_cache_loc=torch.arange(self.batch_size, device=self.device), - seq_lens_sum=int(seq_lens.sum().item()), - forward_mode=ForwardMode.DECODE, - req_pool_indices=torch.arange(self.batch_size, device=self.device), - seq_lens=seq_lens, - seq_lens_cpu=seq_lens.cpu(), - attn_backend=backend, - ) - fb.req_to_token_pool = self.model_runner.req_to_token_pool - fb.token_to_kv_pool = self.model_runner.token_to_kv_pool - return fb - - def test_decode_output_match(self): - """Test that TRTLLM and FlashInfer MLA backends produce matching outputs.""" - # Create identical forward batches for both backends - fb_trtllm = self._create_forward_batch(self.trtllm_backend) - fb_flashinfer = self._create_forward_batch(self.flashinfer_backend) - - # Initialize metadata for both backends - self.trtllm_backend.init_forward_metadata(fb_trtllm) - self.flashinfer_backend.init_forward_metadata(fb_flashinfer) - - # Create Q, K, V tensors - q, k, v = self._create_qkv_tensors() - - # Run forward decode on both backends - out_trtllm = self.trtllm_backend.forward_decode(q.clone(), k.clone(), v, self.layer, fb_trtllm) - out_flashinfer = self.flashinfer_backend.forward_decode(q.clone(), k.clone(), v.clone(), self.layer, fb_flashinfer) - - # Debug: print scale info - print(f"\n[DEBUG] Scale analysis:") - print(f" layer.scaling = {self.layer.scaling}") - print(f" qk_nope_head_dim = {self.model_runner.model_config.qk_nope_head_dim}") - print(f" qk_rope_head_dim = 
{self.model_runner.model_config.qk_rope_head_dim}") - print(f" kv_lora_rank = {self.model_runner.model_config.kv_lora_rank}") - print(f" Expected TRT scale factor = {math.sqrt(128 + 64) / math.sqrt(512 + 64)} = {math.sqrt(192) / math.sqrt(576)}") - print(f" Output shapes: TRTLLM {out_trtllm.shape}, FlashInfer {out_flashinfer.shape}") - print(f" Output means: TRTLLM {out_trtllm.mean().item():.6f}, FlashInfer {out_flashinfer.mean().item():.6f}") - print(f" Output stds: TRTLLM {out_trtllm.std().item():.6f}, FlashInfer {out_flashinfer.std().item():.6f}") - print(f" Max diff = {(out_trtllm - out_flashinfer).abs().max().item()}") - print(f" Ratio of means = {out_trtllm.mean().item() / out_flashinfer.mean().item() if out_flashinfer.mean().item() != 0 else 'inf'}") - - # Additional debug - print(f"\n[DEBUG] Scale computation:") - print(f" layer.scaling = 1/sqrt(192) = {1/math.sqrt(192)}") - print(f" TRT scale passed = layer.scaling * sqrt(192)/sqrt(576) = {self.layer.scaling * math.sqrt(192) / math.sqrt(576)}") - print(f" TRT kernel will compute: 1 / (sqrt(576) * scale) = {1 / (math.sqrt(576) * self.layer.scaling * math.sqrt(192) / math.sqrt(576))}") - print(f" Which equals: 1 / (layer.scaling * sqrt(192)) = {1 / (self.layer.scaling * math.sqrt(192))}") - print(f" But FlashInfer uses: layer.scaling = {self.layer.scaling}") - print(f" Ratio: {(1 / (self.layer.scaling * math.sqrt(192))) / self.layer.scaling} = sqrt(192) = {math.sqrt(192)}") - - # Check output shapes match - self.assertEqual(out_trtllm.shape, out_flashinfer.shape, - f"Output shapes differ: TRTLLM {out_trtllm.shape} vs FlashInfer {out_flashinfer.shape}") - - # Check output dtypes match - self.assertEqual(out_trtllm.dtype, out_flashinfer.dtype) - - # Check numerical equivalence with tolerance - # Note: Using higher tolerance due to potential numerical differences in implementations - self.assertTrue( - torch.allclose(out_trtllm, out_flashinfer, atol=1e-2, rtol=1e-2), - f"TRTLLM and FlashInfer outputs differ 
beyond tolerance. " - f"Max diff: {(out_trtllm - out_flashinfer).abs().max().item()}" - ) - - # Additional checks - self.assertFalse(torch.isnan(out_trtllm).any(), "TRTLLM output contains NaN") - self.assertFalse(torch.isnan(out_flashinfer).any(), "FlashInfer output contains NaN") - self.assertFalse(torch.isinf(out_trtllm).any(), "TRTLLM output contains Inf") - self.assertFalse(torch.isinf(out_flashinfer).any(), "FlashInfer output contains Inf") - - def test_decode_with_different_page_sizes(self): - """Test output consistency across different page sizes.""" - page_sizes = [32, 64] - outputs = [] - - for ps in page_sizes: - # Reinitialize with new page size - self.model_runner = MockModelRunner(ps) - self.trtllm_backend = TRTLLMMLABackend(self.model_runner) - - # Create batch and run decode - fb = self._create_forward_batch(self.trtllm_backend) - self.trtllm_backend.init_forward_metadata(fb) - - q, k, v = self._create_qkv_tensors() - out = self.trtllm_backend.forward_decode(q, k, v, self.layer, fb) - outputs.append(out) - - # Check that outputs are consistent across page sizes - # Note: Different page sizes might lead to slightly different numerical results - for i in range(1, len(outputs)): - self.assertTrue( - torch.allclose(outputs[0], outputs[i], atol=5e-2, rtol=5e-2), - f"Output with page_size={page_sizes[0]} differs from page_size={page_sizes[i]}" - ) - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file From 274fdbb6ff78ef994a4be078c0b8ce5f993b1c97 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Mon, 14 Jul 2025 20:01:04 -0700 Subject: [PATCH 05/27] refator code Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- .../attention/trtllm_gen_mla_backend.py | 550 ++++++++---------- .../sglang/srt/model_executor/model_runner.py | 13 +- 2 files changed, 240 insertions(+), 323 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py 
b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py index cdb817984894..eaf1aa1c2eef 100644 --- a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py @@ -6,14 +6,13 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Union +import math import torch import triton -import math # Needed for scale correction -import os from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend -from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton +from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton, NUM_PAGE_PER_BLOCK from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_flashinfer_available @@ -26,9 +25,16 @@ from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.speculative.spec_info import SpecInfo +# Constants +TRTLLM_BLOCK_CONSTRAINT = 128 # TRT-LLM kernel constraint: block_num % (128 / block_size) == 0 +DEFAULT_WORKSPACE_SIZE_MB = 128 +DEFAULT_QUANTIZATION_SCALE = 1.0 # Default scale for FP8 quantization +DEFAULT_SM_SCALE = 1.0 # Default softmax scale + @dataclass class TRTLLMGENMLADecodeMetadata: + """Metadata for TRTLLM MLA decode operations.""" workspace: Optional[torch.Tensor] = None block_kv_indices: Optional[torch.Tensor] = None @@ -47,52 +53,170 @@ def __init__( model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf ) + # Cache model config for easier access + config = model_runner.model_config + # Model parameters - self.num_q_heads = ( - model_runner.model_config.num_attention_heads // get_attention_tp_size() - ) - self.num_kv_heads = model_runner.model_config.get_num_kv_heads( - get_attention_tp_size() - ) - self.req_to_token = model_runner.req_to_token_pool.req_to_token - self.num_local_heads = ( - 
model_runner.model_config.num_attention_heads // get_attention_tp_size() - ) - self.forward_metadata: Union[TRTLLMGENMLADecodeMetadata] = None - self.kv_lora_rank = model_runner.model_config.kv_lora_rank - self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim - self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim - self.v_head_dim = model_runner.model_config.v_head_dim - self.scaling = model_runner.model_config.scaling + self.num_q_heads = config.num_attention_heads // get_attention_tp_size() + self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size()) + self.num_local_heads = config.num_attention_heads // get_attention_tp_size() + + # MLA-specific dimensions + self.kv_lora_rank = config.kv_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.v_head_dim = config.v_head_dim + self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim + + # Runtime parameters + self.scaling = config.scaling self.data_type = model_runner.kv_cache_dtype self.q_data_type = model_runner.dtype - self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim - self.page_size = model_runner.page_size # Use page size from model runner + self.page_size = model_runner.page_size + self.req_to_token = model_runner.req_to_token_pool.req_to_token - # Allocate larger workspace for TRTLLM (128MB as in the test) - self.workspace_size = 128 * 1024 * 1024 + # Workspace allocation + self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024 self.workspace_buffer = torch.empty( self.workspace_size, dtype=torch.int8, device=self.device ) - # CUDA graph metadata storage + # CUDA graph state self.decode_cuda_graph_metadata = {} self.cuda_graph_kv_indices = None + self.forward_metadata: Union[TRTLLMGENMLADecodeMetadata] = None def _calc_padded_blocks(self, max_seq_len: int) -> int: - """Return number of blocks padded so that it satisfies TRTLLM constraint.""" + """ + Calculate padded block count that 
satisfies both TRT-LLM and Triton constraints. + + Args: + max_seq_len: Maximum sequence length in tokens + + Returns: + Number of blocks padded to satisfy all constraints + """ blocks = triton.cdiv(max_seq_len, self.page_size) - - # The Triton helper that builds `block_kv_indices` emits `NUM_PAGE_PER_BLOCK` - # (= 64) indices at a time, independent of `page_size`. To avoid it writing - # past the end of the row we **must** make every row at least 64 long. - # (Side-effect: max_seq_len is effectively rounded up to 64 × page_size = 2 K - # tokens for page_size 32, which is still below the 2048-ctx used here.) - min_blocks = 64 + + # TWO constraints require padding: + # 1. TRT-LLM kernel expects: block_num % (128 / block_size) == 0 + # This is a hard requirement from the CUDA kernel implementation. + # 2. Our Triton helper `create_flashmla_kv_indices_triton` builds page tables + # in fixed bursts of 64 indices per outer loop iteration. Each burst writes + # unconditionally to positions [i*64, (i+1)*64) without per-element masking. + # If the row length isn't a multiple of 64, the final burst would overrun + # the allocated buffer and corrupt memory. + # + # We need to satisfy BOTH constraints, so we take the LCM of: + # - trtllm_constraint = 128 // page_size + # - triton_constraint = 64 + trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size + triton_constraint = NUM_PAGE_PER_BLOCK + min_blocks = math.lcm(trtllm_constraint, triton_constraint) + if blocks % min_blocks != 0: blocks = triton.cdiv(blocks, min_blocks) * min_blocks return blocks + def _get_quantization_scales(self) -> tuple[float, float, float, float, float]: + """ + Get quantization scales for q, k, v, sm, and o tensors. 
+ + Returns: + Tuple of (q_scale, k_scale, v_scale, sm_scale, o_scale) + """ + # TODO: Implement proper quantization scale inference based on model config + # For now, use default values for FP8 + return ( + DEFAULT_QUANTIZATION_SCALE, # q_scale + DEFAULT_QUANTIZATION_SCALE, # k_scale + DEFAULT_QUANTIZATION_SCALE, # v_scale + DEFAULT_SM_SCALE, # sm_scale + DEFAULT_QUANTIZATION_SCALE, # o_scale + ) + + def _prepare_kv_cache(self, layer: RadixAttention, forward_batch: ForwardBatch) -> torch.Tensor: + """ + Prepare KV cache tensor in the format expected by TRT-LLM kernel. + + Args: + layer: Attention layer + forward_batch: Forward batch info + + Returns: + KV cache tensor shaped (num_pages, 2, page_size, kv_cache_dim) + """ + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + pages = k_cache.view(-1, self.page_size, self.kv_cache_dim) + # TRT-LLM expects stacked format: slice 0 → CKV+K, slice 1 → KPE + return torch.stack([pages, pages], dim=1) + + def _prepare_query_tensor( + self, + q: torch.Tensor, + q_rope: Optional[torch.Tensor], + layer: RadixAttention + ) -> torch.Tensor: + """ + Prepare query tensor in the format expected by TRT-LLM kernel. 
+ + Args: + q: Query tensor + q_rope: Optional RoPE query tensor + layer: Attention layer + + Returns: + Query tensor with concatenated NOPE and RoPE parts + """ + if q_rope is not None: + # q contains the NOPE part (v_head_dim) + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim) + else: + reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = reshaped_q[:, :, : layer.v_head_dim] + q_rope = reshaped_q[:, :, layer.v_head_dim :] + + return torch.cat([q_nope, q_rope], dim=-1) + + def _create_block_kv_indices( + self, + batch_size: int, + max_blocks: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + device: torch.device + ) -> torch.Tensor: + """ + Create block KV indices tensor using Triton kernel. + + Args: + batch_size: Batch size + max_blocks: Maximum number of blocks per sequence + req_pool_indices: Request pool indices + seq_lens: Sequence lengths + device: Target device + + Returns: + Block KV indices tensor + """ + block_kv_indices = torch.full( + (batch_size, max_blocks), -1, dtype=torch.int32, device=device + ) + + create_flashmla_kv_indices_triton[(batch_size,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + None, + block_kv_indices, + self.req_to_token.stride(0), + max_blocks, + self.page_size, + ) + + return block_kv_indices + def init_cuda_graph_state( self, max_bs: int, @@ -100,14 +224,10 @@ def init_cuda_graph_state( kv_indices_buf: Optional[torch.Tensor] = None, ): """Initialize CUDA graph state for TRTLLM MLA.""" - # Calculate padded block size that satisfies TRTLLM constraint max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len) self.cuda_graph_kv_indices = torch.full( - (max_bs, max_blocks_per_seq), - -1, - dtype=torch.int32, - device=self.device, + (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device ) self.cuda_graph_workspace = torch.empty( self.workspace_size, dtype=torch.int8, 
device=self.device @@ -124,46 +244,29 @@ def init_forward_metadata_capture_cuda_graph( spec_info: Optional[SpecInfo], ): """Initialize metadata for CUDA graph capture.""" - if forward_mode.is_decode_or_idle(): - if spec_info is None: - max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item()) - - block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad] - create_flashmla_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - seq_lens, - None, - block_kv_indices, - self.req_to_token.stride(0), - max_seqlen_pad, - self.page_size, - ) - metadata = TRTLLMGENMLADecodeMetadata( - self.cuda_graph_workspace, - block_kv_indices, - ) - self.decode_cuda_graph_metadata[bs] = metadata - self.forward_metadata = metadata - else: - super().init_forward_metadata_capture_cuda_graph( - bs, - num_tokens, - req_pool_indices, - seq_lens, - encoder_lens, - forward_mode, - spec_info, - ) - else: - super().init_forward_metadata_capture_cuda_graph( - bs, - num_tokens, + if forward_mode.is_decode_or_idle() and spec_info is None: + max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item()) + block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad] + + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, req_pool_indices, seq_lens, - encoder_lens, - forward_mode, - spec_info, + None, + block_kv_indices, + self.req_to_token.stride(0), + max_seqlen_pad, + self.page_size, + ) + + metadata = TRTLLMGENMLADecodeMetadata( + self.cuda_graph_workspace, block_kv_indices + ) + self.decode_cuda_graph_metadata[bs] = metadata + self.forward_metadata = metadata + else: + super().init_forward_metadata_capture_cuda_graph( + bs, num_tokens, req_pool_indices, seq_lens, encoder_lens, forward_mode, spec_info ) def init_forward_metadata_replay_cuda_graph( @@ -178,93 +281,53 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_cpu: Optional[torch.Tensor], ): """Replay CUDA graph with new inputs.""" - if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": - 
print(f"[TRTLLM-MLA] init_forward_metadata_replay_cuda_graph: bs={bs}, forward_mode={forward_mode}, spec_info={spec_info is not None}") - - if forward_mode.is_decode_or_idle(): - if spec_info is None: - # Reuse cached metadata - metadata = self.decode_cuda_graph_metadata[bs] - - # Update block indices for new sequences - create_flashmla_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices[:bs], - seq_lens[:bs], - None, - metadata.block_kv_indices, - self.req_to_token.stride(0), - metadata.block_kv_indices.shape[1], - self.page_size, - ) - # Pad invalid blocks to keep TRTLLM happy - # self._pad_invalid_blocks(metadata.block_kv_indices) - - self.forward_metadata = metadata - - else: - # Speculative decoding: use parent class implementation - super().init_forward_metadata_replay_cuda_graph( - bs, req_pool_indices, seq_lens, seq_lens_sum, - encoder_lens, forward_mode, spec_info, seq_lens_cpu - ) + if forward_mode.is_decode_or_idle() and spec_info is None: + metadata = self.decode_cuda_graph_metadata[bs] + + # Update block indices for new sequences + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens[:bs], + None, + metadata.block_kv_indices, + self.req_to_token.stride(0), + metadata.block_kv_indices.shape[1], + self.page_size, + ) + + self.forward_metadata = metadata else: - # Prefill: use parent class implementation super().init_forward_metadata_replay_cuda_graph( bs, req_pool_indices, seq_lens, seq_lens_sum, encoder_lens, forward_mode, spec_info, seq_lens_cpu ) - def get_cuda_graph_seq_len_fill_value(self): + def get_cuda_graph_seq_len_fill_value(self) -> int: """Get the fill value for sequence lengths in CUDA graph.""" return 1 def init_forward_metadata(self, forward_batch: ForwardBatch): - """Init the metadata for a forward pass.""" - bs = forward_batch.batch_size - spec_info = forward_batch.spec_info - - if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": - print(f"[TRTLLM-MLA] init_forward_metadata: bs={bs}, 
forward_mode={forward_batch.forward_mode}, spec_info={spec_info is not None}") - if forward_batch.forward_mode.is_decode_or_idle(): - if spec_info is None: - # seq_lens_cpu may be None when cuda-graphs are disabled - if getattr(forward_batch, "seq_lens_cpu", None) is not None: - max_seq = forward_batch.seq_lens_cpu.max().item() - else: - max_seq = forward_batch.seq_lens.max().item() - - max_seqlen_pad = self._calc_padded_blocks(max_seq) - - block_kv_indices = torch.full( - (bs, max_seqlen_pad), - -1, - dtype=torch.int32, - device=forward_batch.seq_lens.device, - ) - - create_flashmla_kv_indices_triton[(bs,)]( - self.req_to_token, - forward_batch.req_pool_indices, - forward_batch.seq_lens, - None, - block_kv_indices, - self.req_to_token.stride(0), - max_seqlen_pad, - self.page_size, - ) + """Initialize the metadata for a forward pass.""" + if forward_batch.forward_mode.is_decode_or_idle() and forward_batch.spec_info is None: + bs = forward_batch.batch_size + + # Get maximum sequence length + if getattr(forward_batch, "seq_lens_cpu", None) is not None: + max_seq = forward_batch.seq_lens_cpu.max().item() + else: + max_seq = forward_batch.seq_lens.max().item() - # Ensure padded blocks are valid to avoid TRTLLM skipping shorter seqs - # self._pad_invalid_blocks(block_kv_indices) + max_seqlen_pad = self._calc_padded_blocks(max_seq) + block_kv_indices = self._create_block_kv_indices( + bs, max_seqlen_pad, forward_batch.req_pool_indices, + forward_batch.seq_lens, forward_batch.seq_lens.device + ) - self.forward_metadata = TRTLLMGENMLADecodeMetadata( - self.workspace_buffer, - block_kv_indices, - ) - # Expose to the ForwardBatch so that other components can access it - forward_batch.decode_trtllm_mla_metadata = self.forward_metadata - else: - super().init_forward_metadata(forward_batch) + self.forward_metadata = TRTLLMGENMLADecodeMetadata( + self.workspace_buffer, block_kv_indices + ) + forward_batch.decode_trtllm_mla_metadata = self.forward_metadata else: 
super().init_forward_metadata(forward_batch) @@ -278,115 +341,27 @@ def forward_decode( save_kv_cache: bool = True, q_rope: Optional[torch.Tensor] = None, k_rope: Optional[torch.Tensor] = None, - ): + ) -> torch.Tensor: """Run forward for decode using TRTLLM MLA kernel.""" - cache_loc = forward_batch.out_cache_loc - # Save KV cache if requested if k is not None and save_kv_cache: + cache_loc = forward_batch.out_cache_loc if k_rope is not None: - forward_batch.token_to_kv_pool.set_mla_kv_buffer( - layer, cache_loc, k, k_rope - ) - else: - if v is not None: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v - ) + forward_batch.token_to_kv_pool.set_mla_kv_buffer(layer, cache_loc, k, k_rope) + elif v is not None: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) - # Build query tensor - if q_rope is not None: - # q contains the NOPE part (v_head_dim) - q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) - q_rope = q_rope.view( - -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim - ) - else: - reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim) - q_nope = reshaped_q[:, :, : layer.v_head_dim] - q_rope = reshaped_q[:, :, layer.v_head_dim :] - - # Concatenate to build the full query as expected by TRTLLM kernel - query = torch.cat([q_nope, q_rope], dim=-1) - - - if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": - print(f"[TRTLLM-MLA]: model_runner.model_config.scaling.scaling is {self.model_runner.model_config.scaling}") - - # Get the model scaling factor - # TRTLLM kernel applies the 1/sqrt(192) factor internally, so we - # should pass `1.0` here (see Flash-Infer equivalence tests). - sm_scale = 1 - # (scale * ((512 + 64) ** 0.5)) / ((128 + 64) ** 0.5) - # ( sqrt(3)) / sqrt(192) - - # KV cache tensor: reshape to (num_pages, page_size, dim) - k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - # Build KV cache slices expected by TRT-LLM: slice 0 → CKV+K, slice 1 → KPE. 
- pages = k_cache.view(-1, self.page_size, self.kv_cache_dim) # (P, blk, 576) - - # Retrieve metadata so we can check block tables. - metadata = getattr(forward_batch, "decode_trtllm_mla_metadata", None) - if metadata is None: - metadata = self.forward_metadata - - # --------------------------------------------------------------------- - # According to the FlashInfer test, both slices should contain the full 576-dim tensor. - # Use torch.stack so each slice is an independent contiguous view **after** we have - # patched the pages in-place. - kv_cache = torch.stack([pages, pages], dim=1) # (P, 2, blk, 576) - - # Metadata already obtained above; no change needed. - - # ------------- DEBUG KV CACHE CONSTRUCTION ------------- - if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": - print( - f"[TRTLLM-MLA] k_cache_raw: {k_cache.shape} {k_cache.dtype}\n" - f"[TRTLLM-MLA] pages : {pages.shape} {pages.dtype}\n" - f"[TRTLLM-MLA] kv_cache : {kv_cache.shape} {kv_cache.dtype}\n" - f"[TRTLLM-MLA] block_kv_indices: {metadata.block_kv_indices.shape} {metadata.block_kv_indices.dtype}\n" - f"[TRTLLM-MLA] workspace: {metadata.workspace.shape if metadata.workspace is not None else 'None'}\n" - f"[TRTLLM-MLA] max_seq_len: {metadata.block_kv_indices.shape[1] * self.page_size}\n" - f"[TRTLLM-MLA] k_cache[0,0,:3]: {k_cache[0,0,:3] if k_cache.numel() > 0 else 'empty'}\n" - f"[TRTLLM-MLA] pages[0,0,:3]: {pages[0,0,:3] if pages.numel() > 0 else 'empty'}\n" - f"[TRTLLM-MLA] kv_cache[0,0,0,:3]: {kv_cache[0,0,0,:3] if kv_cache.numel() > 0 else 'empty'}\n" - f"[TRTLLM-MLA] kv_cache[0,1,0,:3]: {kv_cache[0,1,0,:3] if kv_cache.numel() > 0 else 'empty'}\n" - f"[TRTLLM-MLA] block_kv_indices[0,:5]: {metadata.block_kv_indices[0,:5] if metadata.block_kv_indices.numel() > 0 else 'empty'}" - ) - - # ------------- DEBUG (align with FlashInfer-MLA) ------------- - if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": - q_nope_str = f"{q_nope.shape} {q_rope.dtype}" if q_nope is not None else "None" - q_rope_str = 
f"{q_rope.shape} {q_rope.dtype}" if q_rope is not None else "None" - q_nope_sample = f"{q_nope[0,:3,:3] if q_rope is not None and q_nope.numel() > 0 else 'empty'}" - q_rope_sample = f"{q_rope[0,:3,:3] if q_rope is not None and q_rope.numel() > 0 else 'empty'}" - - print( - f"[TRTLLM-MLA] q_nope : {q_nope_str}\n" - f"[TRTLLM-MLA] q_rope : {q_rope_str}\n" - f"[TRTLLM-MLA] query : {query.shape} {query.dtype}\n" - f"[TRTLLM-MLA] kv_cache: {kv_cache.shape} {kv_cache.dtype}\n" - f"[TRTLLM-MLA] scale : {sm_scale}\n" - f"[TRTLLM-MLA] cache_loc: {cache_loc.shape} {cache_loc.dtype}\n" - f"[TRTLLM-MLA] seq_lens: {forward_batch.seq_lens.shape} {forward_batch.seq_lens.dtype}\n" - f"[TRTLLM-MLA] batch_size: {forward_batch.batch_size}\n" - f"[TRTLLM-MLA] page_size: {self.page_size}\n" - f"[TRTLLM-MLA] qk_nope_head_dim: {self.qk_nope_head_dim}\n" - f"[TRTLLM-MLA] qk_rope_head_dim: {self.qk_rope_head_dim}\n" - f"[TRTLLM-MLA] kv_lora_rank: {self.kv_lora_rank}\n" - f"[TRTLLM-MLA] v_head_dim: {layer.v_head_dim}\n" - f"[TRTLLM-MLA] kv_cache_dim: {self.kv_cache_dim}\n" - f"[TRTLLM-MLA] tp_q_head_num: {layer.tp_q_head_num}\n" - f"[TRTLLM-MLA] tp_k_head_num: {layer.tp_k_head_num}\n" - f"[TRTLLM-MLA] layer_id: {layer.layer_id}\n" - f"[TRTLLM-MLA] q_nope[0,:3,:3]: {q_nope_sample}\n" - f"[TRTLLM-MLA] q_rope[0,:3,:3]: {q_rope_sample}\n" - f"[TRTLLM-MLA] query[0,:3,:3]: {query[0,:3,:3] if query.numel() > 0 else 'empty'}\n" - f"[TRTLLM-MLA] seq_lens values: {forward_batch.seq_lens[:min(5, len(forward_batch.seq_lens))]}\n" - f"[TRTLLM-MLA] cache_loc values: {cache_loc[:min(5, len(cache_loc))]}" - ) + # Prepare tensors for TRT-LLM kernel + query = self._prepare_query_tensor(q, q_rope, layer) + kv_cache = self._prepare_kv_cache(layer, forward_batch) + + # Get metadata + metadata = getattr(forward_batch, "decode_trtllm_mla_metadata", None) or self.forward_metadata + + # Get quantization scales + q_scale, k_scale, v_scale, sm_scale, o_scale = self._get_quantization_scales() - + # Call TRT-LLM kernel 
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, kv_cache=kv_cache, @@ -397,68 +372,21 @@ def forward_decode( block_tables=metadata.block_kv_indices, seq_lens=forward_batch.seq_lens.to(torch.int32), block_size=self.page_size, - max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size), # max_seq_len equals padded_blocks * page_size. - q_scale=1.0, - k_scale=1.0, - v_scale=1.0, + max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size), + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, sm_scale=sm_scale, - o_scale=1.0, + o_scale=o_scale, ) - # TRTLLM kernel may return both V and ROPE dims (kv_lora_rank + qk_rope_head_dim). - # We only need the value projection part (v_head_dim). - raw_out_v = raw_out[..., : layer.v_head_dim].contiguous() - - # ------------- DEBUG KERNEL OUTPUT ------------- - if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": - print( - f"[TRTLLM-MLA] raw_out : {raw_out.shape} {raw_out.dtype}\n" - f"[TRTLLM-MLA] raw_out_v : {raw_out_v.shape} {raw_out_v.dtype}\n" - f"[TRTLLM-MLA] raw_out[0,:3,:3]: {raw_out[0,:3,:3] if raw_out.numel() > 0 else 'empty'}\n" - f"[TRTLLM-MLA] raw_out_v[0,:3,:3]: {raw_out_v[0,:3,:3] if raw_out_v.numel() > 0 else 'empty'}\n" - f"[TRTLLM-MLA] raw_out stats: min={raw_out.min():.6f}, max={raw_out.max():.6f}, mean={raw_out.mean():.6f}\n" - f"[TRTLLM-MLA] raw_out_v stats: min={raw_out_v.min():.6f}, max={raw_out_v.max():.6f}, mean={raw_out_v.mean():.6f}" - ) + # Extract value projection part and reshape + raw_out_v = raw_out[..., : layer.v_head_dim].contiguous() output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + # Truncate if needed if output.shape[0] > forward_batch.batch_size: output = output[: forward_batch.batch_size] - - # ------------- DEBUG FINAL OUTPUT ------------- - if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": - print( - f"[TRTLLM-MLA] output : {output.shape} {output.dtype}\n" - f"[TRTLLM-MLA] output[0,:10]: {output[0,:10] if output.numel() > 
0 else 'empty'}\n" - f"[TRTLLM-MLA] output[0,10:20]: {output[0,10:20] if output.numel() > 10 else 'empty'}\n" - f"[TRTLLM-MLA] output[0,20:30]: {output[0,20:30] if output.numel() > 20 else 'empty'}\n" - f"[TRTLLM-MLA] output stats: min={output.min():.6f}, max={output.max():.6f}, mean={output.mean():.6f}\n" - f"[TRTLLM-MLA] output std: {output.std():.6f}\n" - f"[TRTLLM-MLA] output_reshaped: {output.shape[0]} -> {forward_batch.batch_size}\n" - f"[TRTLLM-MLA] ===== END TRTLLM-MLA DEBUG =====" - ) - - return output - @staticmethod - def _pad_invalid_blocks(block_kv_indices: torch.Tensor): - """Replace -1 paddings with the last valid page id for each row. - - TRT-LLM treats a leading -1 as "empty sequence" and will skip the - whole sequence. For shorter sequences we therefore replicate the - last real page id into the padded region instead of leaving -1. - """ - for row in range(block_kv_indices.size(0)): - row_view = block_kv_indices[row] - # Find last non-negative entry (every sequence has at least one) - valid_mask = row_view >= 0 - if not valid_mask.any(): - # Defensive – shouldn’t happen in decode mode - continue - last_valid = row_view[valid_mask][-1] - row_view[~valid_mask] = last_valid - - # Extra visibility when debugging - if os.getenv("SGLANG_DEBUG_MLA", "0") == "1": - preview_rows = min(6, block_kv_indices.size(0)) - print("[TRTLLM-MLA] block_kv_indices preview (after padding):") - for i in range(preview_rows): - print(f" seq {i}:", block_kv_indices[i].tolist()) \ No newline at end of file + return output + diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index c14c3fbef48f..6f6bdb284331 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -396,18 +396,7 @@ def model_specific_adjustment(self): ): server_args.attention_backend = "fa3" elif is_sm100_supported(): - # On Blackwell, prefer TRTLLM MLA if available, otherwise flashinfer 
- if is_flashinfer_available(): - try: - import flashinfer - if hasattr(flashinfer.decode, 'trtllm_batch_decode_with_kv_cache_mla'): - server_args.attention_backend = "trtllm_mla" - else: - server_args.attention_backend = "flashinfer" - except: - server_args.attention_backend = "flashinfer" - else: - server_args.attention_backend = "flashinfer" + server_args.attention_backend = "trtllm_mla" elif _is_hip: head_num = self.model_config.get_num_kv_heads(self.tp_size) # TODO current aiter only support head number 16 or 128 head number From ba1220d8f0007b19ec9cb315fe2098957f042f65 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Mon, 14 Jul 2025 20:02:01 -0700 Subject: [PATCH 06/27] add utils.py modification Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- python/sglang/srt/layers/attention/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py index 71633d12dce5..35b5e2d25225 100644 --- a/python/sglang/srt/layers/attention/utils.py +++ b/python/sglang/srt/layers/attention/utils.py @@ -1,6 +1,10 @@ import triton import triton.language as tl +# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`. +# Number of pages that the kernel writes per iteration. +# Exposed here so other Python modules can import it instead of hard-coding 64. 
+NUM_PAGE_PER_BLOCK = 64 @triton.jit def create_flashinfer_kv_indices_triton( @@ -53,6 +57,7 @@ def create_flashmla_kv_indices_triton( PAGED_SIZE: tl.constexpr = 64, ): BLOCK_SIZE: tl.constexpr = 4096 + # Keep in sync with module-level NUM_PAGE_PER_BLOCK constant above NUM_PAGE_PER_BLOCK: tl.constexpr = 64 pid = tl.program_id(axis=0) From f67fe412f26f5e232d4730a03af1f504c803500a Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Mon, 14 Jul 2025 20:05:38 -0700 Subject: [PATCH 07/27] pre-commit Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- .../attention/trtllm_gen_mla_backend.py | 135 ++++++---- python/sglang/srt/layers/attention/utils.py | 3 +- .../attention/test_trtllm_gen_mla_backend.py | 255 ++++++++++++------ 3 files changed, 250 insertions(+), 143 deletions(-) mode change 100644 => 100755 python/sglang/srt/layers/attention/utils.py diff --git a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py index eaf1aa1c2eef..8243fc09bf8d 100644 --- a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py @@ -4,15 +4,18 @@ Support attention backend for TRTLLM-Gen MLA kernels from flashinfer. 
""" +import math from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Union -import math import torch import triton from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend -from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton, NUM_PAGE_PER_BLOCK +from sglang.srt.layers.attention.utils import ( + NUM_PAGE_PER_BLOCK, + create_flashmla_kv_indices_triton, +) from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_flashinfer_available @@ -26,7 +29,9 @@ from sglang.srt.speculative.spec_info import SpecInfo # Constants -TRTLLM_BLOCK_CONSTRAINT = 128 # TRT-LLM kernel constraint: block_num % (128 / block_size) == 0 +TRTLLM_BLOCK_CONSTRAINT = ( + 128 # TRT-LLM kernel constraint: block_num % (128 / block_size) == 0 +) DEFAULT_WORKSPACE_SIZE_MB = 128 DEFAULT_QUANTIZATION_SCALE = 1.0 # Default scale for FP8 quantization DEFAULT_SM_SCALE = 1.0 # Default softmax scale @@ -35,6 +40,7 @@ @dataclass class TRTLLMGENMLADecodeMetadata: """Metadata for TRTLLM MLA decode operations.""" + workspace: Optional[torch.Tensor] = None block_kv_indices: Optional[torch.Tensor] = None @@ -55,32 +61,32 @@ def __init__( # Cache model config for easier access config = model_runner.model_config - + # Model parameters self.num_q_heads = config.num_attention_heads // get_attention_tp_size() self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size()) self.num_local_heads = config.num_attention_heads // get_attention_tp_size() - + # MLA-specific dimensions self.kv_lora_rank = config.kv_lora_rank self.qk_nope_head_dim = config.qk_nope_head_dim self.qk_rope_head_dim = config.qk_rope_head_dim self.v_head_dim = config.v_head_dim self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim - + # Runtime parameters self.scaling = config.scaling self.data_type = model_runner.kv_cache_dtype 
self.q_data_type = model_runner.dtype self.page_size = model_runner.page_size self.req_to_token = model_runner.req_to_token_pool.req_to_token - + # Workspace allocation self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024 self.workspace_buffer = torch.empty( self.workspace_size, dtype=torch.int8, device=self.device ) - + # CUDA graph state self.decode_cuda_graph_metadata = {} self.cuda_graph_kv_indices = None @@ -89,15 +95,15 @@ def __init__( def _calc_padded_blocks(self, max_seq_len: int) -> int: """ Calculate padded block count that satisfies both TRT-LLM and Triton constraints. - + Args: max_seq_len: Maximum sequence length in tokens - + Returns: Number of blocks padded to satisfy all constraints """ blocks = triton.cdiv(max_seq_len, self.page_size) - + # TWO constraints require padding: # 1. TRT-LLM kernel expects: block_num % (128 / block_size) == 0 # This is a hard requirement from the CUDA kernel implementation. @@ -106,14 +112,14 @@ def _calc_padded_blocks(self, max_seq_len: int) -> int: # unconditionally to positions [i*64, (i+1)*64) without per-element masking. # If the row length isn't a multiple of 64, the final burst would overrun # the allocated buffer and corrupt memory. - # + # # We need to satisfy BOTH constraints, so we take the LCM of: # - trtllm_constraint = 128 // page_size # - triton_constraint = 64 trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size triton_constraint = NUM_PAGE_PER_BLOCK min_blocks = math.lcm(trtllm_constraint, triton_constraint) - + if blocks % min_blocks != 0: blocks = triton.cdiv(blocks, min_blocks) * min_blocks return blocks @@ -121,7 +127,7 @@ def _calc_padded_blocks(self, max_seq_len: int) -> int: def _get_quantization_scales(self) -> tuple[float, float, float, float, float]: """ Get quantization scales for q, k, v, sm, and o tensors. 
- + Returns: Tuple of (q_scale, k_scale, v_scale, sm_scale, o_scale) """ @@ -129,20 +135,22 @@ def _get_quantization_scales(self) -> tuple[float, float, float, float, float]: # For now, use default values for FP8 return ( DEFAULT_QUANTIZATION_SCALE, # q_scale - DEFAULT_QUANTIZATION_SCALE, # k_scale + DEFAULT_QUANTIZATION_SCALE, # k_scale DEFAULT_QUANTIZATION_SCALE, # v_scale - DEFAULT_SM_SCALE, # sm_scale + DEFAULT_SM_SCALE, # sm_scale DEFAULT_QUANTIZATION_SCALE, # o_scale ) - def _prepare_kv_cache(self, layer: RadixAttention, forward_batch: ForwardBatch) -> torch.Tensor: + def _prepare_kv_cache( + self, layer: RadixAttention, forward_batch: ForwardBatch + ) -> torch.Tensor: """ Prepare KV cache tensor in the format expected by TRT-LLM kernel. - + Args: layer: Attention layer forward_batch: Forward batch info - + Returns: KV cache tensor shaped (num_pages, 2, page_size, kv_cache_dim) """ @@ -152,26 +160,25 @@ def _prepare_kv_cache(self, layer: RadixAttention, forward_batch: ForwardBatch) return torch.stack([pages, pages], dim=1) def _prepare_query_tensor( - self, - q: torch.Tensor, - q_rope: Optional[torch.Tensor], - layer: RadixAttention + self, q: torch.Tensor, q_rope: Optional[torch.Tensor], layer: RadixAttention ) -> torch.Tensor: """ Prepare query tensor in the format expected by TRT-LLM kernel. 
- + Args: q: Query tensor q_rope: Optional RoPE query tensor layer: Attention layer - + Returns: Query tensor with concatenated NOPE and RoPE parts """ if q_rope is not None: # q contains the NOPE part (v_head_dim) q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) - q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim) + q_rope = q_rope.view( + -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim + ) else: reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim) q_nope = reshaped_q[:, :, : layer.v_head_dim] @@ -180,30 +187,30 @@ def _prepare_query_tensor( return torch.cat([q_nope, q_rope], dim=-1) def _create_block_kv_indices( - self, - batch_size: int, - max_blocks: int, + self, + batch_size: int, + max_blocks: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - device: torch.device + device: torch.device, ) -> torch.Tensor: """ Create block KV indices tensor using Triton kernel. - + Args: batch_size: Batch size max_blocks: Maximum number of blocks per sequence req_pool_indices: Request pool indices seq_lens: Sequence lengths device: Target device - + Returns: Block KV indices tensor """ block_kv_indices = torch.full( (batch_size, max_blocks), -1, dtype=torch.int32, device=device ) - + create_flashmla_kv_indices_triton[(batch_size,)]( self.req_to_token, req_pool_indices, @@ -214,7 +221,7 @@ def _create_block_kv_indices( max_blocks, self.page_size, ) - + return block_kv_indices def init_cuda_graph_state( @@ -225,7 +232,7 @@ def init_cuda_graph_state( ): """Initialize CUDA graph state for TRTLLM MLA.""" max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len) - + self.cuda_graph_kv_indices = torch.full( (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device ) @@ -247,7 +254,7 @@ def init_forward_metadata_capture_cuda_graph( if forward_mode.is_decode_or_idle() and spec_info is None: max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item()) block_kv_indices = 
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad] - + create_flashmla_kv_indices_triton[(bs,)]( self.req_to_token, req_pool_indices, @@ -258,7 +265,7 @@ def init_forward_metadata_capture_cuda_graph( max_seqlen_pad, self.page_size, ) - + metadata = TRTLLMGENMLADecodeMetadata( self.cuda_graph_workspace, block_kv_indices ) @@ -266,7 +273,13 @@ def init_forward_metadata_capture_cuda_graph( self.forward_metadata = metadata else: super().init_forward_metadata_capture_cuda_graph( - bs, num_tokens, req_pool_indices, seq_lens, encoder_lens, forward_mode, spec_info + bs, + num_tokens, + req_pool_indices, + seq_lens, + encoder_lens, + forward_mode, + spec_info, ) def init_forward_metadata_replay_cuda_graph( @@ -283,7 +296,7 @@ def init_forward_metadata_replay_cuda_graph( """Replay CUDA graph with new inputs.""" if forward_mode.is_decode_or_idle() and spec_info is None: metadata = self.decode_cuda_graph_metadata[bs] - + # Update block indices for new sequences create_flashmla_kv_indices_triton[(bs,)]( self.req_to_token, @@ -295,12 +308,18 @@ def init_forward_metadata_replay_cuda_graph( metadata.block_kv_indices.shape[1], self.page_size, ) - + self.forward_metadata = metadata else: super().init_forward_metadata_replay_cuda_graph( - bs, req_pool_indices, seq_lens, seq_lens_sum, - encoder_lens, forward_mode, spec_info, seq_lens_cpu + bs, + req_pool_indices, + seq_lens, + seq_lens_sum, + encoder_lens, + forward_mode, + spec_info, + seq_lens_cpu, ) def get_cuda_graph_seq_len_fill_value(self) -> int: @@ -309,9 +328,12 @@ def get_cuda_graph_seq_len_fill_value(self) -> int: def init_forward_metadata(self, forward_batch: ForwardBatch): """Initialize the metadata for a forward pass.""" - if forward_batch.forward_mode.is_decode_or_idle() and forward_batch.spec_info is None: + if ( + forward_batch.forward_mode.is_decode_or_idle() + and forward_batch.spec_info is None + ): bs = forward_batch.batch_size - + # Get maximum sequence length if getattr(forward_batch, "seq_lens_cpu", None) is not 
None: max_seq = forward_batch.seq_lens_cpu.max().item() @@ -320,8 +342,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): max_seqlen_pad = self._calc_padded_blocks(max_seq) block_kv_indices = self._create_block_kv_indices( - bs, max_seqlen_pad, forward_batch.req_pool_indices, - forward_batch.seq_lens, forward_batch.seq_lens.device + bs, + max_seqlen_pad, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens.device, ) self.forward_metadata = TRTLLMGENMLADecodeMetadata( @@ -347,17 +372,22 @@ def forward_decode( if k is not None and save_kv_cache: cache_loc = forward_batch.out_cache_loc if k_rope is not None: - forward_batch.token_to_kv_pool.set_mla_kv_buffer(layer, cache_loc, k, k_rope) + forward_batch.token_to_kv_pool.set_mla_kv_buffer( + layer, cache_loc, k, k_rope + ) elif v is not None: forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) # Prepare tensors for TRT-LLM kernel query = self._prepare_query_tensor(q, q_rope, layer) kv_cache = self._prepare_kv_cache(layer, forward_batch) - + # Get metadata - metadata = getattr(forward_batch, "decode_trtllm_mla_metadata", None) or self.forward_metadata - + metadata = ( + getattr(forward_batch, "decode_trtllm_mla_metadata", None) + or self.forward_metadata + ) + # Get quantization scales q_scale, k_scale, v_scale, sm_scale, o_scale = self._get_quantization_scales() @@ -383,10 +413,9 @@ def forward_decode( # Extract value projection part and reshape raw_out_v = raw_out[..., : layer.v_head_dim].contiguous() output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim) - + # Truncate if needed if output.shape[0] > forward_batch.batch_size: output = output[: forward_batch.batch_size] - return output - + return output diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py old mode 100644 new mode 100755 index 35b5e2d25225..700248efdc91 --- a/python/sglang/srt/layers/attention/utils.py +++ 
b/python/sglang/srt/layers/attention/utils.py @@ -4,7 +4,8 @@ # Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`. # Number of pages that the kernel writes per iteration. # Exposed here so other Python modules can import it instead of hard-coding 64. -NUM_PAGE_PER_BLOCK = 64 +NUM_PAGE_PER_BLOCK = 64 + @triton.jit def create_flashinfer_kv_indices_triton( diff --git a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py index acf752fd23ce..b5a50cffb96e 100644 --- a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py @@ -11,24 +11,26 @@ The tests use unittest with subTest for parameterized testing, following the sglang test patterns. """ + import unittest -import torch -import numpy as np from types import SimpleNamespace +import numpy as np +import torch + from sglang.srt.layers import dp_attention as _dp_attn # Patch DP-attention globals before importing backends _dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test from sglang.srt.configs.model_config import AttentionArch -from sglang.srt.layers.attention.trtllm_gen_mla_backend import TRTLLMGENMLABackend from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend +from sglang.srt.layers.attention.trtllm_gen_mla_backend import TRTLLMGENMLABackend from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.test.test_utils import CustomTestCase from sglang.srt.utils import is_flashinfer_available +from sglang.test.test_utils import CustomTestCase class MockModelRunner: @@ -60,7 +62,9 @@ def __init__(self, page_size: int): # Req-to-token pool max_bs = 64 max_ctx = self.model_config.context_len - req_to_token = torch.arange(max_bs * max_ctx, 
dtype=torch.int32, device=self.device).reshape(max_bs, max_ctx) + req_to_token = torch.arange( + max_bs * max_ctx, dtype=torch.int32, device=self.device + ).reshape(max_bs, max_ctx) self.req_to_token_pool = type( "TokenPool", (), @@ -69,7 +73,7 @@ def __init__(self, page_size: int): "req_to_token": req_to_token, }, ) - + # KV-token pool (MLA) self.token_to_kv_pool = MLATokenToKVPool( size=max_bs * max_ctx, @@ -85,27 +89,35 @@ def __init__(self, page_size: int): def compare_outputs(trtllm_out, reference_out, tolerance=1e-2): """Compare outputs with detailed analysis.""" - + # Basic checks - assert trtllm_out.shape == reference_out.shape, f"Shape mismatch: {trtllm_out.shape} vs {reference_out.shape}" - assert trtllm_out.dtype == reference_out.dtype, f"Dtype mismatch: {trtllm_out.dtype} vs {reference_out.dtype}" - + assert ( + trtllm_out.shape == reference_out.shape + ), f"Shape mismatch: {trtllm_out.shape} vs {reference_out.shape}" + assert ( + trtllm_out.dtype == reference_out.dtype + ), f"Dtype mismatch: {trtllm_out.dtype} vs {reference_out.dtype}" + # Check for NaN/Inf assert not torch.isnan(trtllm_out).any(), "TRTLLM output contains NaN" assert not torch.isnan(reference_out).any(), "Reference output contains NaN" assert not torch.isinf(trtllm_out).any(), "TRTLLM output contains Inf" assert not torch.isinf(reference_out).any(), "Reference output contains Inf" - + # Element-wise differences diff = (trtllm_out - reference_out).abs() max_diff = diff.max().item() mean_diff = diff.mean().item() - + # Check numerical equivalence - all_close = torch.allclose(trtllm_out, reference_out, rtol=tolerance, atol=tolerance) - + all_close = torch.allclose( + trtllm_out, reference_out, rtol=tolerance, atol=tolerance + ) + if not all_close: - print(f"Comparison failed: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, tolerance={tolerance}") + print( + f"Comparison failed: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, tolerance={tolerance}" + ) # Find top differences for 
debugging flat_diff = diff.flatten() top_diff_indices = torch.topk(flat_diff, k=min(5, flat_diff.numel())).indices @@ -114,13 +126,17 @@ def compare_outputs(trtllm_out, reference_out, tolerance=1e-2): idx_tuple = np.unravel_index(idx.cpu().numpy(), trtllm_out.shape) trt_val = trtllm_out[idx_tuple].item() ref_val = reference_out[idx_tuple].item() - print(f" [{idx_tuple}]: TRTLLM={trt_val:.6f}, Reference={ref_val:.6f}, diff={abs(trt_val-ref_val):.6f}") - + print( + f" [{idx_tuple}]: TRTLLM={trt_val:.6f}, Reference={ref_val:.6f}, diff={abs(trt_val-ref_val):.6f}" + ) + return all_close -@unittest.skipIf(not torch.cuda.is_available() or not is_flashinfer_available(), - "CUDA + flashinfer required") +@unittest.skipIf( + not torch.cuda.is_available() or not is_flashinfer_available(), + "CUDA + flashinfer required", +) class TestTRTLLMMLAClean(CustomTestCase): """Test TRTLLM MLA backend against FlashInfer MLA backend (reference).""" @@ -129,14 +145,14 @@ def setUp(self): self.device = "cuda" self.dtype = torch.bfloat16 self.page_size = 32 - + # Create model runner and backends self.model_runner_trtllm = MockModelRunner(self.page_size) self.model_runner_reference = MockModelRunner(self.page_size) - + self.trtllm_backend = TRTLLMGENMLABackend(self.model_runner_trtllm) self.reference_backend = FlashInferMLAAttnBackend(self.model_runner_reference) - + # Create RadixAttention layer self.layer = RadixAttention( num_heads=128, @@ -151,7 +167,9 @@ def setUp(self): def _create_qkv_tensors(self, batch_size): """Create Q, K, V tensors for testing.""" head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim - q = torch.randn((batch_size, 128, head_dim), dtype=self.dtype, device=self.device) + q = torch.randn( + (batch_size, 128, head_dim), dtype=self.dtype, device=self.device + ) k = torch.randn((batch_size, 1, head_dim), dtype=self.dtype, device=self.device) v = torch.randn((batch_size, 1, 512), dtype=self.dtype, device=self.device) return q, k, v @@ -176,73 +194,107 @@ def 
_create_forward_batch(self, batch_size, seq_lens, backend, model_runner): def _populate_kv_cache(self, batch_size, seq_lens, model_runners): """Populate KV cache with identical data for both backends.""" torch.manual_seed(42) # Fixed seed for reproducible cache - + for model_runner in model_runners: torch.manual_seed(42) # Reset seed for each backend for i in range(batch_size): seq_len = int(seq_lens[i].item()) for token_idx in range(seq_len - 1): # Create random K components for MLA - cache_k_nope = torch.randn((1, 128), dtype=self.dtype, device=self.device) - cache_k_rope = torch.randn((1, 64), dtype=self.dtype, device=self.device) - + cache_k_nope = torch.randn( + (1, 128), dtype=self.dtype, device=self.device + ) + cache_k_rope = torch.randn( + (1, 64), dtype=self.dtype, device=self.device + ) + # Calculate cache location - cache_loc = model_runner.req_to_token_pool.req_to_token[i, token_idx] - + cache_loc = model_runner.req_to_token_pool.req_to_token[ + i, token_idx + ] + # Save to KV cache model_runner.token_to_kv_pool.set_mla_kv_buffer( - self.layer, - cache_loc.unsqueeze(0), - cache_k_nope.squeeze(0), - cache_k_rope.squeeze(0) + self.layer, + cache_loc.unsqueeze(0), + cache_k_nope.squeeze(0), + cache_k_rope.squeeze(0), ) def test_decode_output_match(self): """Test that TRTLLM and FlashInfer MLA backends produce matching outputs.""" # Test different batch sizes and sequence lengths test_cases = [ - (1, 64), (4, 64), (16, 64), (32, 64), - (1, 128), (4, 128), (16, 128), (32, 128), - (1, 256), (4, 256), (16, 256), (32, 256), + (1, 64), + (4, 64), + (16, 64), + (32, 64), + (1, 128), + (4, 128), + (16, 128), + (32, 128), + (1, 256), + (4, 256), + (16, 256), + (32, 256), ] - + for batch_size, max_seq_len in test_cases: with self.subTest(batch_size=batch_size, max_seq_len=max_seq_len): # Create identical sequence lengths for both backends torch.manual_seed(42) - seq_lens = torch.randint(1, max_seq_len, (batch_size,), device=self.device) + seq_lens = torch.randint( 
+ 1, max_seq_len, (batch_size,), device=self.device + ) seq_lens[0] = max_seq_len # Ensure at least one max length - + # Create forward batches with identical inputs fb_trtllm = self._create_forward_batch( - batch_size, seq_lens.clone(), self.trtllm_backend, self.model_runner_trtllm + batch_size, + seq_lens.clone(), + self.trtllm_backend, + self.model_runner_trtllm, ) fb_reference = self._create_forward_batch( - batch_size, seq_lens.clone(), self.reference_backend, self.model_runner_reference + batch_size, + seq_lens.clone(), + self.reference_backend, + self.model_runner_reference, ) - + # Initialize metadata for both backends self.trtllm_backend.init_forward_metadata(fb_trtllm) self.reference_backend.init_forward_metadata(fb_reference) - + # Populate both KV caches identically - self._populate_kv_cache(batch_size, seq_lens, [self.model_runner_trtllm, self.model_runner_reference]) - + self._populate_kv_cache( + batch_size, + seq_lens, + [self.model_runner_trtllm, self.model_runner_reference], + ) + # Create Q, K, V tensors for current decode step torch.manual_seed(123) # Fixed seed for Q, K, V q, k, v = self._create_qkv_tensors(batch_size) - + # Run forward decode on both backends - out_trtllm = self.trtllm_backend.forward_decode(q.clone(), k.clone(), v, self.layer, fb_trtllm) - out_reference = self.reference_backend.forward_decode(q.clone(), k.clone(), v.clone(), self.layer, fb_reference) - + out_trtllm = self.trtllm_backend.forward_decode( + q.clone(), k.clone(), v, self.layer, fb_trtllm + ) + out_reference = self.reference_backend.forward_decode( + q.clone(), k.clone(), v.clone(), self.layer, fb_reference + ) + # Compare outputs - comparison_passed = compare_outputs(out_trtllm, out_reference, tolerance=1e-2) - - self.assertTrue(comparison_passed, + comparison_passed = compare_outputs( + out_trtllm, out_reference, tolerance=1e-2 + ) + + self.assertTrue( + comparison_passed, f"TRTLLM and Reference outputs differ beyond tolerance. 
" f"batch_size={batch_size}, max_seq_len={max_seq_len}, " - f"Max diff: {(out_trtllm - out_reference).abs().max().item()}" + f"Max diff: {(out_trtllm - out_reference).abs().max().item()}", ) def test_different_page_sizes(self): @@ -250,35 +302,43 @@ def test_different_page_sizes(self): page_sizes = [32, 64] batch_size = 8 max_seq_len = 128 - + for page_size in page_sizes: with self.subTest(page_size=page_size): # Create model runner with specific page size model_runner = MockModelRunner(page_size) backend = TRTLLMGENMLABackend(model_runner) - + # Create sequence lengths torch.manual_seed(42) - seq_lens = torch.randint(1, max_seq_len, (batch_size,), device=self.device) + seq_lens = torch.randint( + 1, max_seq_len, (batch_size,), device=self.device + ) seq_lens[0] = max_seq_len - + # Create forward batch - fb = self._create_forward_batch(batch_size, seq_lens, backend, model_runner) + fb = self._create_forward_batch( + batch_size, seq_lens, backend, model_runner + ) backend.init_forward_metadata(fb) - + # Populate KV cache self._populate_kv_cache(batch_size, seq_lens, [model_runner]) - + # Create Q, K, V tensors torch.manual_seed(123) q, k, v = self._create_qkv_tensors(batch_size) - + # Run forward decode output = backend.forward_decode(q, k, v, self.layer, fb) - + # Basic checks expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim - self.assertEqual(output.shape, expected_shape, f"Output shape mismatch: {output.shape} vs {expected_shape}") + self.assertEqual( + output.shape, + expected_shape, + f"Output shape mismatch: {output.shape} vs {expected_shape}", + ) self.assertFalse(torch.isnan(output).any(), "Output contains NaN") self.assertFalse(torch.isinf(output).any(), "Output contains Inf") @@ -286,23 +346,25 @@ def test_basic_functionality(self): """Test basic functionality with minimal setup.""" batch_size = 2 max_seq_len = 32 - + # Create sequence lengths seq_lens = torch.tensor([max_seq_len, max_seq_len // 2], device=self.device) - + # Create forward 
batch - fb = self._create_forward_batch(batch_size, seq_lens, self.trtllm_backend, self.model_runner_trtllm) + fb = self._create_forward_batch( + batch_size, seq_lens, self.trtllm_backend, self.model_runner_trtllm + ) self.trtllm_backend.init_forward_metadata(fb) - + # Populate KV cache self._populate_kv_cache(batch_size, seq_lens, [self.model_runner_trtllm]) - + # Create Q, K, V tensors q, k, v = self._create_qkv_tensors(batch_size) - + # Run forward decode output = self.trtllm_backend.forward_decode(q, k, v, self.layer, fb) - + # Basic checks expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim self.assertEqual(output.shape, expected_shape) @@ -321,22 +383,26 @@ def test_forward_decode_shape_sanity(self): (1, 64, 32), (32, 1024, 64), ] - + for batch_size, seq_len, page_size in test_configs: - with self.subTest(batch_size=batch_size, seq_len=seq_len, page_size=page_size): + with self.subTest( + batch_size=batch_size, seq_len=seq_len, page_size=page_size + ): # Create model runner with specific page size model_runner = MockModelRunner(page_size) backend = TRTLLMGENMLABackend(model_runner) - + # Random seq lens (ensure one matches max) torch.manual_seed(42) seq_lens = torch.randint(1, seq_len, (batch_size,), device=self.device) seq_lens[0] = seq_len - + # Create forward batch fb = ForwardBatch( batch_size=batch_size, - input_ids=torch.randint(0, 100, (batch_size, 1), device=self.device), + input_ids=torch.randint( + 0, 100, (batch_size, 1), device=self.device + ), out_cache_loc=torch.arange(batch_size, device=self.device), seq_lens_sum=int(seq_lens.sum().item()), forward_mode=ForwardMode.DECODE, @@ -347,15 +413,19 @@ def test_forward_decode_shape_sanity(self): ) fb.req_to_token_pool = model_runner.req_to_token_pool fb.token_to_kv_pool = model_runner.token_to_kv_pool - + backend.init_forward_metadata(fb) - + # Create Q, K, V tensors head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim - q = torch.randn((batch_size, 128, head_dim), dtype=self.dtype, 
device=self.device) - k = torch.randn((batch_size, 1, head_dim), dtype=self.dtype, device=self.device) + q = torch.randn( + (batch_size, 128, head_dim), dtype=self.dtype, device=self.device + ) + k = torch.randn( + (batch_size, 1, head_dim), dtype=self.dtype, device=self.device + ) v = None # TRTLLM MLA decode kernel ignores v - + # Create layer layer = RadixAttention( num_heads=128, @@ -366,21 +436,28 @@ def test_forward_decode_shape_sanity(self): v_head_dim=512, prefix="attn_mqa", ) - + # Run forward decode output = backend.forward_decode(q, k, v, layer, fb) - + # Shape and sanity checks expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim - self.assertEqual(output.shape, expected_shape, - f"Output shape mismatch for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})") + self.assertEqual( + output.shape, + expected_shape, + f"Output shape mismatch for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})", + ) self.assertEqual(output.dtype, self.dtype) self.assertEqual(output.device.type, "cuda") - self.assertFalse(torch.isnan(output).any(), - f"Output contains NaN for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})") - self.assertFalse(torch.isinf(output).any(), - f"Output contains Inf for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})") + self.assertFalse( + torch.isnan(output).any(), + f"Output contains NaN for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})", + ) + self.assertFalse( + torch.isinf(output).any(), + f"Output contains Inf for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})", + ) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From f69b3e07396adaee6c44a54c0acd2274921ea3d9 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Mon, 21 Jul 2025 10:56:31 -0700 Subject: [PATCH 08/27] some neat picks Signed-off-by: Faraz Khoubsirat 
<58580514+farazkh80@users.noreply.github.com> --- .../attention/trtllm_gen_mla_backend.py | 46 +- python/sglang/srt/layers/attention/utils.py | 4 +- .../attention/test_trtllm_gen_mla_backend.py | 840 +++++++++++++----- 3 files changed, 640 insertions(+), 250 deletions(-) mode change 100644 => 100755 python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py diff --git a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py old mode 100644 new mode 100755 index 8243fc09bf8d..d80614ccd27a --- a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py @@ -13,7 +13,7 @@ from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend from sglang.srt.layers.attention.utils import ( - NUM_PAGE_PER_BLOCK, + TRITON_PAD_NUM_PAGE_PER_BLOCK, create_flashmla_kv_indices_triton, ) from sglang.srt.layers.dp_attention import get_attention_tp_size @@ -29,14 +29,17 @@ from sglang.srt.speculative.spec_info import SpecInfo # Constants -TRTLLM_BLOCK_CONSTRAINT = ( - 128 # TRT-LLM kernel constraint: block_num % (128 / block_size) == 0 -) -DEFAULT_WORKSPACE_SIZE_MB = 128 -DEFAULT_QUANTIZATION_SCALE = 1.0 # Default scale for FP8 quantization -DEFAULT_SM_SCALE = 1.0 # Default softmax scale +DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB + +# Block constraint from flashinfer requirements +# See: https://github.com/flashinfer-ai/flashinfer/blob/fe29ed63cb923f25cae70ef83f3fd16139305b35/flashinfer/decode.py#L2057 +TRTLLM_BLOCK_CONSTRAINT = 128 +# Quantization scale (1.0 for fp8->bf16 and bf16->bf16 conversions) +DEFAULT_QUANTIZATION_SCALE = 1.0 +# Softmax scale (1.0 since TRTLLM applies 1/sqrt(head_dim) internally) +DEFAULT_SM_SCALE = 1.0 @dataclass class TRTLLMGENMLADecodeMetadata: """Metadata for TRTLLM MLA decode operations.""" @@ -46,7 +49,7 @@ class TRTLLMGENMLADecodeMetadata: class 
TRTLLMGENMLABackend(FlashInferMLAAttnBackend): - """TRTLLM MLA attention kernels from flashinfer.""" + """TRTLLM MLA attention kernel from flashinfer.""" def __init__( self, @@ -59,7 +62,6 @@ def __init__( model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf ) - # Cache model config for easier access config = model_runner.model_config # Model parameters @@ -104,24 +106,15 @@ def _calc_padded_blocks(self, max_seq_len: int) -> int: """ blocks = triton.cdiv(max_seq_len, self.page_size) - # TWO constraints require padding: - # 1. TRT-LLM kernel expects: block_num % (128 / block_size) == 0 - # This is a hard requirement from the CUDA kernel implementation. - # 2. Our Triton helper `create_flashmla_kv_indices_triton` builds page tables - # in fixed bursts of 64 indices per outer loop iteration. Each burst writes - # unconditionally to positions [i*64, (i+1)*64) without per-element masking. - # If the row length isn't a multiple of 64, the final burst would overrun - # the allocated buffer and corrupt memory. - # - # We need to satisfy BOTH constraints, so we take the LCM of: - # - trtllm_constraint = 128 // page_size - # - triton_constraint = 64 + # Apply dual constraints (take LCM to satisfy both): + # 1. TRT-LLM: block_num % (128 / page_size) == 0 + # Reference: https://github.com/NVIDIA/TensorRT-LLM/issues/XYZ # TODO: add actual link + # 2. 
Triton: page table builder uses 64-index bursts, needs multiple of 64 trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size - triton_constraint = NUM_PAGE_PER_BLOCK - min_blocks = math.lcm(trtllm_constraint, triton_constraint) - - if blocks % min_blocks != 0: - blocks = triton.cdiv(blocks, min_blocks) * min_blocks + constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK) + + if blocks % constraint_lcm != 0: + blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm return blocks def _get_quantization_scales(self) -> tuple[float, float, float, float, float]: @@ -410,6 +403,7 @@ def forward_decode( o_scale=o_scale, ) + #TODO: test? # Extract value projection part and reshape raw_out_v = raw_out[..., : layer.v_head_dim].contiguous() output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim) diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py index 700248efdc91..0cfb6a359cf2 100755 --- a/python/sglang/srt/layers/attention/utils.py +++ b/python/sglang/srt/layers/attention/utils.py @@ -4,7 +4,7 @@ # Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`. # Number of pages that the kernel writes per iteration. # Exposed here so other Python modules can import it instead of hard-coding 64. 
-NUM_PAGE_PER_BLOCK = 64 +TRITON_PAD_NUM_PAGE_PER_BLOCK = 64 @triton.jit @@ -58,7 +58,7 @@ def create_flashmla_kv_indices_triton( PAGED_SIZE: tl.constexpr = 64, ): BLOCK_SIZE: tl.constexpr = 4096 - # Keep in sync with module-level NUM_PAGE_PER_BLOCK constant above + # Keep in sync with module-level TRITON_PAD_NUM_PAGE_PER_BLOCK constant above NUM_PAGE_PER_BLOCK: tl.constexpr = 64 pid = tl.program_id(axis=0) diff --git a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py index b5a50cffb96e..e2ea8d1da744 100644 --- a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py @@ -1,17 +1,3 @@ -""" -Clean test suite for TRTLLM MLA backend. - -This test file provides comprehensive testing for the TRTLLM MLA (Multi-Head Latent Attention) backend: - -1. test_basic_functionality: Basic smoke test with minimal setup -2. test_decode_output_match: Compares TRTLLM MLA output against FlashInfer MLA reference - across different batch sizes and sequence lengths -3. test_different_page_sizes: Tests consistency across different page sizes -4. test_forward_decode_shape_sanity: Shape and sanity checks across various configurations - -The tests use unittest with subTest for parameterized testing, following the sglang test patterns. 
-""" - import unittest from types import SimpleNamespace @@ -25,7 +11,7 @@ from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend -from sglang.srt.layers.attention.trtllm_gen_mla_backend import TRTLLMGENMLABackend +from sglang.srt.layers.attention.trtllm_gen_mla_backend import TRTLLMGENMLABackend, TRTLLMGENMLADecodeMetadata from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -33,34 +19,166 @@ from sglang.test.test_utils import CustomTestCase +# Global configuration for all tests +DEFAULT_CONFIG = { + "device": "cuda", + "dtype": torch.bfloat16, + "kv_cache_dtype": torch.bfloat16, + "context_len": 2048, + "max_bs": 64, + "tolerance": 1e-2, + "seed_cache": 42, + "seed_qkv": 123, + # MLA model config (TRTLLM MLA has fixed constraints) + "num_attention_heads": 128, + "kv_lora_rank": 512, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 512, + "num_kv_heads": 1, + "layer_id": 0, +} + +# Centralized test cases for different test scenarios +TEST_CASES = { + "basic_functionality": [ + { + "name": "single", + "batch_size": 1, + "max_seq_len": 32, + "page_size": 32, + "description": "Minimal smoke test", + }, + { + "name": "batch", + "batch_size": 32, + "max_seq_len": 128, + "page_size": 32, + "description": "Medium-scale batch", + }, + ], + + "decode_output_match": [ + { + "name": "single", + "batch_size": 1, + "max_seq_len": 64, + "page_size": 32, + "description": "Single vs reference", + }, + { + "name": "batch", + "batch_size": 32, + "max_seq_len": 64, + "page_size": 32, + "description": "Batch vs reference", + } + ], + + "page_size_consistency": [ + # Only 32 and 64 supported for now 
https://github.com/flashinfer-ai/flashinfer/blob/fe29ed63cb923f25cae70ef83f3fd16139305b35/flashinfer/decode.py#L2115 + # TODO: Test 16 and 128. Pending cubins + { + "name": "page_32", + "batch_size": 8, + "max_seq_len": 128, + "page_size": 32, + "description": "32-token pages", + }, + { + "name": "page_64", + "batch_size": 8, + "max_seq_len": 128, + "page_size": 64, + "description": "64-token pages", + }, + ], + + "shape_sanity_tests": [ + { + "name": "basic", + "batch_size": 1, + "max_seq_len": 128, + "page_size": 32, + "description": "Single sequence", + }, + { + "name": "basic_different_pagesize", + "batch_size": 1, + "max_seq_len": 128, + "page_size": 64, + "description": "Different page size", + }, + { + "name": "batch", + "batch_size": 8, + "max_seq_len": 128, + "page_size": 32, + "description": "Batch shapes", + }, + ], + + "metadata_tests": [ + { + "name": "single_sequence", + "batch_size": 1, + "max_seq_len": 64, + "page_size": 32, + "description": "Single sequence metadata", + }, + { + "name": "batch_mixed_lengths", + "batch_size": 8, + "max_seq_len": 128, + "page_size": 32, + "description": "Mixed sequence lengths", + }, + { + "name": "large_batch", + "batch_size": 32, + "max_seq_len": 256, + "page_size": 64, + "description": "Large batch stress test", + }, + { + "name": "edge_case_short", + "batch_size": 4, + "max_seq_len": 16, + "page_size": 32, + "description": "Sub-page sequences", + }, + ], +} + + class MockModelRunner: """Minimal fake ModelRunner for testing MLA backends.""" - def __init__(self, page_size: int): - self.device = "cuda" - self.dtype = torch.bfloat16 - self.kv_cache_dtype = torch.bfloat16 - self.page_size = page_size + def __init__(self, config): + self.device = config["device"] + self.dtype = config["dtype"] + self.kv_cache_dtype = config["kv_cache_dtype"] + self.page_size = config["page_size"] # Model-config stub with MLA attributes self.model_config = type( "ModelConfig", (), { - "context_len": 2048, + 
"context_len": config["context_len"], "attention_arch": AttentionArch.MLA, - "num_attention_heads": 128, - "kv_lora_rank": 512, - "qk_nope_head_dim": 128, - "qk_rope_head_dim": 64, - "v_head_dim": 512, - "scaling": 1.0 / ((128 + 64) ** 0.5), - "get_num_kv_heads": staticmethod(lambda _: 1), + "num_attention_heads": config["num_attention_heads"], + "kv_lora_rank": config["kv_lora_rank"], + "qk_nope_head_dim": config["qk_nope_head_dim"], + "qk_rope_head_dim": config["qk_rope_head_dim"], + "v_head_dim": config["v_head_dim"], + "scaling": 1.0 / ((config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) ** 0.5), + "get_num_kv_heads": staticmethod(lambda _: config["num_kv_heads"]), }, ) # Req-to-token pool - max_bs = 64 + max_bs = config["max_bs"] max_ctx = self.model_config.context_len req_to_token = torch.arange( max_bs * max_ctx, dtype=torch.int32, device=self.device @@ -77,10 +195,10 @@ def __init__(self, page_size: int): # KV-token pool (MLA) self.token_to_kv_pool = MLATokenToKVPool( size=max_bs * max_ctx, - page_size=page_size, + page_size=config["page_size"], dtype=self.kv_cache_dtype, - kv_lora_rank=512, - qk_rope_head_dim=64, + kv_lora_rank=config["kv_lora_rank"], + qk_rope_head_dim=config["qk_rope_head_dim"], layer_num=1, device=self.device, enable_memory_saver=False, @@ -89,7 +207,7 @@ def __init__(self, page_size: int): def compare_outputs(trtllm_out, reference_out, tolerance=1e-2): """Compare outputs with detailed analysis.""" - + # Basic checks assert ( trtllm_out.shape == reference_out.shape @@ -137,52 +255,67 @@ def compare_outputs(trtllm_out, reference_out, tolerance=1e-2): not torch.cuda.is_available() or not is_flashinfer_available(), "CUDA + flashinfer required", ) -class TestTRTLLMMLAClean(CustomTestCase): - """Test TRTLLM MLA backend against FlashInfer MLA backend (reference).""" +class TestTRTLLMMLA(CustomTestCase): + """Test suite for TRTLLM MLA backend with centralized configuration.""" - def setUp(self): - """Setup test fixtures.""" - self.device 
= "cuda" - self.dtype = torch.bfloat16 - self.page_size = 32 + def _merge_config(self, test_case): + """Merge test case with default configuration.""" + config = DEFAULT_CONFIG.copy() + config.update(test_case) + return config - # Create model runner and backends - self.model_runner_trtllm = MockModelRunner(self.page_size) - self.model_runner_reference = MockModelRunner(self.page_size) + def _create_model_components(self, config): + """Create model runners, backends, and layer for testing.""" + # Create model runners + model_runner_trtllm = MockModelRunner(config) + model_runner_reference = MockModelRunner(config) - self.trtllm_backend = TRTLLMGENMLABackend(self.model_runner_trtllm) - self.reference_backend = FlashInferMLAAttnBackend(self.model_runner_reference) + # Create backends + trtllm_backend = TRTLLMGENMLABackend(model_runner_trtllm) + reference_backend = FlashInferMLAAttnBackend(model_runner_reference) # Create RadixAttention layer - self.layer = RadixAttention( - num_heads=128, - head_dim=512 + 64, # kv_lora_rank + qk_rope_head_dim - scaling=self.model_runner_trtllm.model_config.scaling, - num_kv_heads=1, - layer_id=0, - v_head_dim=512, + layer = RadixAttention( + num_heads=config["num_attention_heads"], + head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"], + scaling=model_runner_trtllm.model_config.scaling, + num_kv_heads=config["num_kv_heads"], + layer_id=config["layer_id"], + v_head_dim=config["v_head_dim"], prefix="attn_mqa", ) - def _create_qkv_tensors(self, batch_size): + return model_runner_trtllm, model_runner_reference, trtllm_backend, reference_backend, layer + + def _create_qkv_tensors(self, batch_size, config): """Create Q, K, V tensors for testing.""" - head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim + head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"] + device = config["device"] + dtype = config["dtype"] + q = torch.randn( - (batch_size, 128, head_dim), dtype=self.dtype, device=self.device + (batch_size, 
config["num_attention_heads"], head_dim), + dtype=dtype, device=device + ) + k = torch.randn( + (batch_size, config["num_kv_heads"], head_dim), + dtype=dtype, device=device + ) + v = torch.randn( + (batch_size, config["num_kv_heads"], config["v_head_dim"]), + dtype=dtype, device=device ) - k = torch.randn((batch_size, 1, head_dim), dtype=self.dtype, device=self.device) - v = torch.randn((batch_size, 1, 512), dtype=self.dtype, device=self.device) return q, k, v - def _create_forward_batch(self, batch_size, seq_lens, backend, model_runner): + def _create_forward_batch(self, batch_size, seq_lens, backend, model_runner, config): """Create a forward batch for the given backend.""" fb = ForwardBatch( batch_size=batch_size, - input_ids=torch.randint(0, 100, (batch_size, 1), device=self.device), - out_cache_loc=torch.arange(batch_size, device=self.device), + input_ids=torch.randint(0, 100, (batch_size, 1), device=config["device"]), + out_cache_loc=torch.arange(batch_size, device=config["device"]), seq_lens_sum=int(seq_lens.sum().item()), forward_mode=ForwardMode.DECODE, - req_pool_indices=torch.arange(batch_size, device=self.device), + req_pool_indices=torch.arange(batch_size, device=config["device"]), seq_lens=seq_lens, seq_lens_cpu=seq_lens.cpu(), attn_backend=backend, @@ -191,273 +324,536 @@ def _create_forward_batch(self, batch_size, seq_lens, backend, model_runner): fb.token_to_kv_pool = model_runner.token_to_kv_pool return fb - def _populate_kv_cache(self, batch_size, seq_lens, model_runners): + def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config): """Populate KV cache with identical data for both backends.""" - torch.manual_seed(42) # Fixed seed for reproducible cache + torch.manual_seed(config["seed_cache"]) # Fixed seed for reproducible cache for model_runner in model_runners: - torch.manual_seed(42) # Reset seed for each backend + torch.manual_seed(config["seed_cache"]) # Reset seed for each backend for i in range(batch_size): seq_len 
= int(seq_lens[i].item()) for token_idx in range(seq_len - 1): # Create random K components for MLA cache_k_nope = torch.randn( - (1, 128), dtype=self.dtype, device=self.device + (1, config["qk_nope_head_dim"]), + dtype=config["dtype"], device=config["device"] ) cache_k_rope = torch.randn( - (1, 64), dtype=self.dtype, device=self.device + (1, config["qk_rope_head_dim"]), + dtype=config["dtype"], device=config["device"] ) # Calculate cache location - cache_loc = model_runner.req_to_token_pool.req_to_token[ - i, token_idx - ] + cache_loc = model_runner.req_to_token_pool.req_to_token[i, token_idx] # Save to KV cache model_runner.token_to_kv_pool.set_mla_kv_buffer( - self.layer, + layer, cache_loc.unsqueeze(0), cache_k_nope.squeeze(0), cache_k_rope.squeeze(0), ) + def test_basic_functionality(self): + """Test basic functionality with minimal setup.""" + print(f"\nRunning basic functionality tests...") + + for test_case in TEST_CASES["basic_functionality"]: + with self.subTest(test_case=test_case["name"]): + print(f" Testing {test_case['name']}: {test_case['description']}") + + config = self._merge_config(test_case) + batch_size = config["batch_size"] + max_seq_len = config["max_seq_len"] + + # Create components + model_runner_trtllm, _, trtllm_backend, _, layer = self._create_model_components(config) + + # Create sequence lengths - properly handle different batch sizes + if batch_size == 2: + seq_lens = torch.tensor( + [max_seq_len, max_seq_len // 2], + device=config["device"] + ) + else: + # For larger batch sizes, create varied sequence lengths + torch.manual_seed(config["seed_cache"]) + seq_lens = torch.randint( + max_seq_len // 2, max_seq_len + 1, + (batch_size,), + device=config["device"] + ) + seq_lens[0] = max_seq_len # Ensure at least one max length + + # Create forward batch + fb = self._create_forward_batch( + batch_size, seq_lens, trtllm_backend, model_runner_trtllm, config + ) + trtllm_backend.init_forward_metadata(fb) + + # Populate KV cache + 
self._populate_kv_cache(batch_size, seq_lens, [model_runner_trtllm], layer, config) + + # Create Q, K, V tensors + torch.manual_seed(config["seed_qkv"]) + q, k, v = self._create_qkv_tensors(batch_size, config) + + # Run forward decode + output = trtllm_backend.forward_decode(q, k, v, layer, fb) + + # Basic checks + expected_shape = (batch_size, config["num_attention_heads"] * config["v_head_dim"]) + self.assertEqual(output.shape, expected_shape) + self.assertEqual(output.dtype, config["dtype"]) + self.assertFalse(torch.isnan(output).any()) + self.assertFalse(torch.isinf(output).any()) + def test_decode_output_match(self): """Test that TRTLLM and FlashInfer MLA backends produce matching outputs.""" - # Test different batch sizes and sequence lengths - test_cases = [ - (1, 64), - (4, 64), - (16, 64), - (32, 64), - (1, 128), - (4, 128), - (16, 128), - (32, 128), - (1, 256), - (4, 256), - (16, 256), - (32, 256), - ] + print(f"\nRunning decode output matching tests...") + + for test_case in TEST_CASES["decode_output_match"]: + with self.subTest(test_case=test_case["name"]): + print(f" Testing {test_case['name']}: {test_case['description']}") + + config = self._merge_config(test_case) + batch_size = config["batch_size"] + max_seq_len = config["max_seq_len"] + + # Create components + (model_runner_trtllm, model_runner_reference, + trtllm_backend, reference_backend, layer) = self._create_model_components(config) - for batch_size, max_seq_len in test_cases: - with self.subTest(batch_size=batch_size, max_seq_len=max_seq_len): # Create identical sequence lengths for both backends - torch.manual_seed(42) + torch.manual_seed(config["seed_cache"]) seq_lens = torch.randint( - 1, max_seq_len, (batch_size,), device=self.device + 1, max_seq_len, (batch_size,), device=config["device"] ) seq_lens[0] = max_seq_len # Ensure at least one max length # Create forward batches with identical inputs fb_trtllm = self._create_forward_batch( - batch_size, - seq_lens.clone(), - 
self.trtllm_backend, - self.model_runner_trtllm, + batch_size, seq_lens.clone(), trtllm_backend, model_runner_trtllm, config ) fb_reference = self._create_forward_batch( - batch_size, - seq_lens.clone(), - self.reference_backend, - self.model_runner_reference, + batch_size, seq_lens.clone(), reference_backend, model_runner_reference, config ) # Initialize metadata for both backends - self.trtllm_backend.init_forward_metadata(fb_trtllm) - self.reference_backend.init_forward_metadata(fb_reference) + trtllm_backend.init_forward_metadata(fb_trtllm) + reference_backend.init_forward_metadata(fb_reference) # Populate both KV caches identically self._populate_kv_cache( - batch_size, - seq_lens, - [self.model_runner_trtllm, self.model_runner_reference], + batch_size, seq_lens, [model_runner_trtllm, model_runner_reference], layer, config ) # Create Q, K, V tensors for current decode step - torch.manual_seed(123) # Fixed seed for Q, K, V - q, k, v = self._create_qkv_tensors(batch_size) + torch.manual_seed(config["seed_qkv"]) + q, k, v = self._create_qkv_tensors(batch_size, config) # Run forward decode on both backends - out_trtllm = self.trtllm_backend.forward_decode( - q.clone(), k.clone(), v, self.layer, fb_trtllm + out_trtllm = trtllm_backend.forward_decode( + q.clone(), k.clone(), v.clone(), layer, fb_trtllm ) - out_reference = self.reference_backend.forward_decode( - q.clone(), k.clone(), v.clone(), self.layer, fb_reference + out_reference = reference_backend.forward_decode( + q.clone(), k.clone(), v.clone(), layer, fb_reference ) # Compare outputs comparison_passed = compare_outputs( - out_trtllm, out_reference, tolerance=1e-2 + out_trtllm, out_reference, tolerance=config["tolerance"] ) self.assertTrue( comparison_passed, f"TRTLLM and Reference outputs differ beyond tolerance. 
" - f"batch_size={batch_size}, max_seq_len={max_seq_len}, " + f"Config: {test_case['name']}, " f"Max diff: {(out_trtllm - out_reference).abs().max().item()}", ) - def test_different_page_sizes(self): + def test_page_size_consistency(self): """Test output consistency across different page sizes.""" - page_sizes = [32, 64] - batch_size = 8 - max_seq_len = 128 - - for page_size in page_sizes: - with self.subTest(page_size=page_size): - # Create model runner with specific page size - model_runner = MockModelRunner(page_size) - backend = TRTLLMGENMLABackend(model_runner) + print(f"\nRunning page size consistency tests...") + + for test_case in TEST_CASES["page_size_consistency"]: + with self.subTest(test_case=test_case["name"]): + print(f" Testing {test_case['name']}: {test_case['description']}") + + config = self._merge_config(test_case) + batch_size = config["batch_size"] + max_seq_len = config["max_seq_len"] + + # Create components + model_runner, _, backend, _, layer = self._create_model_components(config) # Create sequence lengths - torch.manual_seed(42) + torch.manual_seed(config["seed_cache"]) seq_lens = torch.randint( - 1, max_seq_len, (batch_size,), device=self.device + 1, max_seq_len, (batch_size,), device=config["device"] ) seq_lens[0] = max_seq_len # Create forward batch fb = self._create_forward_batch( - batch_size, seq_lens, backend, model_runner + batch_size, seq_lens, backend, model_runner, config ) backend.init_forward_metadata(fb) # Populate KV cache - self._populate_kv_cache(batch_size, seq_lens, [model_runner]) + self._populate_kv_cache(batch_size, seq_lens, [model_runner], layer, config) # Create Q, K, V tensors - torch.manual_seed(123) - q, k, v = self._create_qkv_tensors(batch_size) + torch.manual_seed(config["seed_qkv"]) + q, k, v = self._create_qkv_tensors(batch_size, config) # Run forward decode - output = backend.forward_decode(q, k, v, self.layer, fb) + output = backend.forward_decode(q, k, v, layer, fb) - # Basic checks - expected_shape = 
(batch_size, 128 * 512) # num_heads * v_head_dim + expected_shape = (batch_size, config["num_attention_heads"] * config["v_head_dim"]) self.assertEqual( - output.shape, - expected_shape, - f"Output shape mismatch: {output.shape} vs {expected_shape}", + output.shape, expected_shape, + f"Output shape mismatch: {output.shape} vs {expected_shape}" ) self.assertFalse(torch.isnan(output).any(), "Output contains NaN") self.assertFalse(torch.isinf(output).any(), "Output contains Inf") - def test_basic_functionality(self): - """Test basic functionality with minimal setup.""" - batch_size = 2 - max_seq_len = 32 - - # Create sequence lengths - seq_lens = torch.tensor([max_seq_len, max_seq_len // 2], device=self.device) - - # Create forward batch - fb = self._create_forward_batch( - batch_size, seq_lens, self.trtllm_backend, self.model_runner_trtllm - ) - self.trtllm_backend.init_forward_metadata(fb) - - # Populate KV cache - self._populate_kv_cache(batch_size, seq_lens, [self.model_runner_trtllm]) - - # Create Q, K, V tensors - q, k, v = self._create_qkv_tensors(batch_size) - - # Run forward decode - output = self.trtllm_backend.forward_decode(q, k, v, self.layer, fb) - - # Basic checks - expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim - self.assertEqual(output.shape, expected_shape) - self.assertEqual(output.dtype, self.dtype) - self.assertFalse(torch.isnan(output).any()) - self.assertFalse(torch.isinf(output).any()) - - def test_forward_decode_shape_sanity(self): - """Smoke test decode across several page sizes and batch configurations.""" - # Test configurations similar to the original test - test_configs = [ - (16, 512, 32), # batch_size, seq_len, page_size - (16, 512, 64), - (8, 256, 32), - (4, 128, 32), - (1, 64, 32), - (32, 1024, 64), - ] + def test_shape_sanity(self): + """Smoke test decode across several configurations.""" + print(f"\nRunning shape sanity tests...") + + for test_case in TEST_CASES["shape_sanity_tests"]: + with 
self.subTest(test_case=test_case["name"]): + print(f" Testing {test_case['name']}: {test_case['description']}") + + config = self._merge_config(test_case) + batch_size = config["batch_size"] + max_seq_len = config["max_seq_len"] - for batch_size, seq_len, page_size in test_configs: - with self.subTest( - batch_size=batch_size, seq_len=seq_len, page_size=page_size - ): - # Create model runner with specific page size - model_runner = MockModelRunner(page_size) - backend = TRTLLMGENMLABackend(model_runner) + model_runner, _, backend, _, layer = self._create_model_components(config) # Random seq lens (ensure one matches max) - torch.manual_seed(42) - seq_lens = torch.randint(1, seq_len, (batch_size,), device=self.device) - seq_lens[0] = seq_len + torch.manual_seed(config["seed_cache"]) + seq_lens = torch.randint(1, max_seq_len, (batch_size,), device=config["device"]) + seq_lens[0] = max_seq_len - # Create forward batch - fb = ForwardBatch( - batch_size=batch_size, - input_ids=torch.randint( - 0, 100, (batch_size, 1), device=self.device - ), - out_cache_loc=torch.arange(batch_size, device=self.device), - seq_lens_sum=int(seq_lens.sum().item()), - forward_mode=ForwardMode.DECODE, - req_pool_indices=torch.arange(batch_size, device=self.device), - seq_lens=seq_lens, - seq_lens_cpu=seq_lens.cpu(), - attn_backend=backend, + fb = self._create_forward_batch( + batch_size, seq_lens, backend, model_runner, config ) - fb.req_to_token_pool = model_runner.req_to_token_pool - fb.token_to_kv_pool = model_runner.token_to_kv_pool - backend.init_forward_metadata(fb) # Create Q, K, V tensors - head_dim = 512 + 64 # kv_lora_rank + qk_rope_head_dim + torch.manual_seed(config["seed_qkv"]) + head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"] q = torch.randn( - (batch_size, 128, head_dim), dtype=self.dtype, device=self.device + (batch_size, config["num_attention_heads"], head_dim), + dtype=config["dtype"], device=config["device"] ) k = torch.randn( - (batch_size, 1, head_dim), 
dtype=self.dtype, device=self.device - ) - v = None # TRTLLM MLA decode kernel ignores v - - # Create layer - layer = RadixAttention( - num_heads=128, - head_dim=512 + 64, - scaling=model_runner.model_config.scaling, - num_kv_heads=1, - layer_id=0, - v_head_dim=512, - prefix="attn_mqa", + (batch_size, config["num_kv_heads"], head_dim), + dtype=config["dtype"], device=config["device"] ) + v = None # Run forward decode output = backend.forward_decode(q, k, v, layer, fb) # Shape and sanity checks - expected_shape = (batch_size, 128 * 512) # num_heads * v_head_dim + expected_shape = (batch_size, config["num_attention_heads"] * config["v_head_dim"]) self.assertEqual( - output.shape, - expected_shape, - f"Output shape mismatch for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})", + output.shape, expected_shape, + f"Output shape mismatch for {test_case['name']}" ) - self.assertEqual(output.dtype, self.dtype) + self.assertEqual(output.dtype, config["dtype"]) self.assertEqual(output.device.type, "cuda") self.assertFalse( torch.isnan(output).any(), - f"Output contains NaN for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})", + f"Output contains NaN for {test_case['name']}" ) self.assertFalse( torch.isinf(output).any(), - f"Output contains Inf for config (bs={batch_size}, seq_len={seq_len}, page_size={page_size})", + f"Output contains Inf for {test_case['name']}" + ) + + def test_metadata_initialization(self): + """Test TRTLLM MLA metadata initialization and structure.""" + print(f"\nRunning metadata initialization tests...") + + for test_case in TEST_CASES["metadata_tests"]: + with self.subTest(test_case=test_case["name"]): + print(f" Testing {test_case['name']}: {test_case['description']}") + + config = self._merge_config(test_case) + batch_size = config["batch_size"] + max_seq_len = config["max_seq_len"] + + # Create components + model_runner, _, backend, _, layer = self._create_model_components(config) + + # Create varied sequence 
lengths + torch.manual_seed(config["seed_cache"]) + if batch_size == 1: + seq_lens = torch.tensor([max_seq_len], device=config["device"]) + else: + seq_lens = torch.randint( + max(1, max_seq_len // 4), max_seq_len + 1, + (batch_size,), device=config["device"] + ) + seq_lens[0] = max_seq_len # Ensure at least one max length + + # Create forward batch + fb = self._create_forward_batch( + batch_size, seq_lens, backend, model_runner, config + ) + + # Initialize metadata + backend.init_forward_metadata(fb) + + # Verify metadata exists + self.assertIsNotNone(backend.forward_metadata) + self.assertIsInstance( + backend.forward_metadata, + TRTLLMGENMLADecodeMetadata + ) + + # Test metadata structure + metadata = backend.forward_metadata + self.assertIsNotNone(metadata.workspace, "Workspace should be allocated") + self.assertIsNotNone(metadata.block_kv_indices, "Block KV indices should be created") + + # Test workspace properties + self.assertEqual(metadata.workspace.device.type, "cuda") + self.assertEqual(metadata.workspace.dtype, torch.int8) + self.assertGreater(metadata.workspace.numel(), 0, "Workspace should have non-zero size") + + # Test block KV indices properties + self.assertEqual(metadata.block_kv_indices.device.type, "cuda") + self.assertEqual(metadata.block_kv_indices.dtype, torch.int32) + self.assertEqual(metadata.block_kv_indices.shape[0], batch_size) + + # Verify block indices are valid (>= -1, since -1 is padding) + self.assertTrue( + (metadata.block_kv_indices >= -1).all(), + "All block indices should be >= -1 (with -1 as padding)" + ) + + def test_metadata_block_calculation(self): + """Test block count calculation logic.""" + print(f"\nRunning metadata block calculation tests...") + + test_scenarios = [ + {"seq_len": 31, "page_size": 32, "expected_min_blocks": 1}, + {"seq_len": 32, "page_size": 32, "expected_min_blocks": 1}, + {"seq_len": 33, "page_size": 32, "expected_min_blocks": 2}, + {"seq_len": 128, "page_size": 32, "expected_min_blocks": 4}, + 
{"seq_len": 128, "page_size": 64, "expected_min_blocks": 2}, + ] + + for scenario in test_scenarios: + with self.subTest(scenario=scenario): + config = self._merge_config({ + "batch_size": 1, + "max_seq_len": scenario["seq_len"], + "page_size": scenario["page_size"] + }) + + model_runner, _, backend, _, _ = self._create_model_components(config) + + # Test internal block calculation + calculated_blocks = backend._calc_padded_blocks(scenario["seq_len"]) + + # Should be at least the minimum required + self.assertGreaterEqual( + calculated_blocks, scenario["expected_min_blocks"], + f"Calculated blocks ({calculated_blocks}) should be >= minimum required ({scenario['expected_min_blocks']})" + ) + + # Should satisfy page_size constraint + total_tokens = calculated_blocks * scenario["page_size"] + self.assertGreaterEqual( + total_tokens, scenario["seq_len"], + f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})" + ) + + # Should satisfy TRT-LLM constraint (128 / page_size) + trtllm_constraint = 128 // scenario["page_size"] + self.assertEqual( + calculated_blocks % trtllm_constraint, 0, + f"Block count should be multiple of TRT-LLM constraint ({trtllm_constraint})" ) + def test_metadata_kv_indices_correctness(self): + """Test KV indices creation and correctness.""" + print(f"\nRunning KV indices correctness tests...") + + for test_case in TEST_CASES["metadata_tests"][:2]: # Test subset for performance + with self.subTest(test_case=test_case["name"]): + print(f" Testing {test_case['name']}: {test_case['description']}") + + config = self._merge_config(test_case) + batch_size = config["batch_size"] + max_seq_len = config["max_seq_len"] + + model_runner, _, backend, _, layer = self._create_model_components(config) + + # Create known sequence lengths + torch.manual_seed(config["seed_cache"]) + if batch_size == 1: + seq_lens = torch.tensor([max_seq_len], device=config["device"]) + else: + seq_lens = torch.randint( + max_seq_len // 2, max_seq_len + 
1, + (batch_size,), device=config["device"] + ) + + fb = self._create_forward_batch( + batch_size, seq_lens, backend, model_runner, config + ) + + # Populate some KV cache to have valid indices + self._populate_kv_cache(batch_size, seq_lens, [model_runner], layer, config) + + # Initialize metadata + backend.init_forward_metadata(fb) + metadata = backend.forward_metadata + + # Verify KV indices structure + block_kv_indices = metadata.block_kv_indices + + for i in range(batch_size): + seq_len = seq_lens[i].item() + expected_blocks = backend._calc_padded_blocks(seq_len) + + # Count valid (non -1) indices for this sequence + valid_indices = (block_kv_indices[i] >= 0).sum().item() + + # Should have at least enough blocks for the sequence + min_required_blocks = (seq_len + config["page_size"] - 1) // config["page_size"] + self.assertGreaterEqual( + valid_indices, min_required_blocks, + f"Sequence {i} should have at least {min_required_blocks} valid blocks, got {valid_indices}" + ) + + # Verify indices are within valid range + valid_block_indices = block_kv_indices[i][block_kv_indices[i] >= 0] + if len(valid_block_indices) > 0: + max_possible_blocks = model_runner.token_to_kv_pool.size // config["page_size"] + self.assertTrue( + (valid_block_indices < max_possible_blocks).all(), + f"All block indices should be < {max_possible_blocks}" + ) + + def test_metadata_cuda_graph_compatibility(self): + """Test metadata compatibility with CUDA graph capture/replay.""" + print(f"\nRunning CUDA graph compatibility tests...") + + config = self._merge_config({ + "batch_size": 4, + "max_seq_len": 64, + "page_size": 32 + }) + + model_runner, _, backend, _, layer = self._create_model_components(config) + batch_size = config["batch_size"] + + # Initialize CUDA graph state + backend.init_cuda_graph_state( + max_bs=batch_size, + max_num_tokens=config["max_seq_len"] * batch_size + ) + + # Verify CUDA graph buffers are allocated + self.assertIsNotNone(backend.cuda_graph_kv_indices) + 
self.assertIsNotNone(backend.cuda_graph_workspace) + + # Test capture metadata + seq_lens = torch.full((batch_size,), config["max_seq_len"], device=config["device"]) + req_pool_indices = torch.arange(batch_size, device=config["device"]) + + backend.init_forward_metadata_capture_cuda_graph( + bs=batch_size, + num_tokens=batch_size, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=None, + ) + + # Verify capture metadata + self.assertIn(batch_size, backend.decode_cuda_graph_metadata) + capture_metadata = backend.decode_cuda_graph_metadata[batch_size] + + self.assertIsNotNone(capture_metadata.workspace) + self.assertIsNotNone(capture_metadata.block_kv_indices) + + # Test replay with different sequence lengths + new_seq_lens = torch.randint( + config["max_seq_len"] // 2, config["max_seq_len"] + 1, + (batch_size,), device=config["device"] + ) + + backend.init_forward_metadata_replay_cuda_graph( + bs=batch_size, + req_pool_indices=req_pool_indices, + seq_lens=new_seq_lens, + seq_lens_sum=new_seq_lens.sum().item(), + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=None, + seq_lens_cpu=new_seq_lens.cpu(), + ) + + # Verify replay updated the metadata + replay_metadata = backend.forward_metadata + self.assertIsNotNone(replay_metadata) + self.assertEqual(replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()) + + def test_metadata_consistency_across_calls(self): + """Test metadata consistency across multiple forward calls.""" + print(f"\nRunning metadata consistency tests...") + + config = self._merge_config({ + "batch_size": 2, + "max_seq_len": 64, + "page_size": 32 + }) + + model_runner, _, backend, _, layer = self._create_model_components(config) + + # First call + seq_lens_1 = torch.tensor([32, 48], device=config["device"]) + fb_1 = self._create_forward_batch( + config["batch_size"], seq_lens_1, backend, model_runner, config + ) + 
backend.init_forward_metadata(fb_1) + metadata_1 = backend.forward_metadata + + # Second call with same sequence lengths + seq_lens_2 = torch.tensor([32, 48], device=config["device"]) + fb_2 = self._create_forward_batch( + config["batch_size"], seq_lens_2, backend, model_runner, config + ) + backend.init_forward_metadata(fb_2) + metadata_2 = backend.forward_metadata + + # Metadata structure should be consistent + self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape) + self.assertEqual(metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape) + + # Third call with different sequence lengths + seq_lens_3 = torch.tensor([16, 64], device=config["device"]) + fb_3 = self._create_forward_batch( + config["batch_size"], seq_lens_3, backend, model_runner, config + ) + backend.init_forward_metadata(fb_3) + metadata_3 = backend.forward_metadata + + # Should still have valid structure + self.assertIsNotNone(metadata_3.workspace) + self.assertIsNotNone(metadata_3.block_kv_indices) + self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"]) + if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file From 8a119cce2fa7aac62c8f5cc69bfe25299687bcf9 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Mon, 21 Jul 2025 14:32:27 -0700 Subject: [PATCH 09/27] neater interface Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- ...n_mla_backend.py => trtllm_mla_backend.py} | 128 ++++++------------ .../sglang/srt/model_executor/model_runner.py | 8 +- .../attention/test_trtllm_gen_mla_backend.py | 6 +- 3 files changed, 47 insertions(+), 95 deletions(-) rename python/sglang/srt/layers/attention/{trtllm_gen_mla_backend.py => trtllm_mla_backend.py} (79%) diff --git a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py similarity index 79% rename from 
python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py rename to python/sglang/srt/layers/attention/trtllm_mla_backend.py index d80614ccd27a..b0213b1b3fd3 100755 --- a/python/sglang/srt/layers/attention/trtllm_gen_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -1,7 +1,7 @@ from __future__ import annotations """ -Support attention backend for TRTLLM-Gen MLA kernels from flashinfer. +Support attention backend for TRTLLM MLA kernels from flashinfer. """ import math @@ -35,20 +35,16 @@ # See: https://github.com/flashinfer-ai/flashinfer/blob/fe29ed63cb923f25cae70ef83f3fd16139305b35/flashinfer/decode.py#L2057 TRTLLM_BLOCK_CONSTRAINT = 128 -# Quantization scale (1.0 for fp8->bf16 and bf16->bf16 conversions) -DEFAULT_QUANTIZATION_SCALE = 1.0 -# Softmax scale (1.0 since TRTLLM applies 1/sqrt(head_dim) internally) -DEFAULT_SM_SCALE = 1.0 @dataclass -class TRTLLMGENMLADecodeMetadata: +class TRTLLMMLADecodeMetadata: """Metadata for TRTLLM MLA decode operations.""" workspace: Optional[torch.Tensor] = None block_kv_indices: Optional[torch.Tensor] = None -class TRTLLMGENMLABackend(FlashInferMLAAttnBackend): +class TRTLLMMLABackend(FlashInferMLAAttnBackend): """TRTLLM MLA attention kernel from flashinfer.""" def __init__( @@ -92,7 +88,7 @@ def __init__( # CUDA graph state self.decode_cuda_graph_metadata = {} self.cuda_graph_kv_indices = None - self.forward_metadata: Union[TRTLLMGENMLADecodeMetadata] = None + self.forward_metadata: Union[TRTLLMMLADecodeMetadata] = None def _calc_padded_blocks(self, max_seq_len: int) -> int: """ @@ -117,68 +113,6 @@ def _calc_padded_blocks(self, max_seq_len: int) -> int: blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm return blocks - def _get_quantization_scales(self) -> tuple[float, float, float, float, float]: - """ - Get quantization scales for q, k, v, sm, and o tensors. 
- - Returns: - Tuple of (q_scale, k_scale, v_scale, sm_scale, o_scale) - """ - # TODO: Implement proper quantization scale inference based on model config - # For now, use default values for FP8 - return ( - DEFAULT_QUANTIZATION_SCALE, # q_scale - DEFAULT_QUANTIZATION_SCALE, # k_scale - DEFAULT_QUANTIZATION_SCALE, # v_scale - DEFAULT_SM_SCALE, # sm_scale - DEFAULT_QUANTIZATION_SCALE, # o_scale - ) - - def _prepare_kv_cache( - self, layer: RadixAttention, forward_batch: ForwardBatch - ) -> torch.Tensor: - """ - Prepare KV cache tensor in the format expected by TRT-LLM kernel. - - Args: - layer: Attention layer - forward_batch: Forward batch info - - Returns: - KV cache tensor shaped (num_pages, 2, page_size, kv_cache_dim) - """ - k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - pages = k_cache.view(-1, self.page_size, self.kv_cache_dim) - # TRT-LLM expects stacked format: slice 0 → CKV+K, slice 1 → KPE - return torch.stack([pages, pages], dim=1) - - def _prepare_query_tensor( - self, q: torch.Tensor, q_rope: Optional[torch.Tensor], layer: RadixAttention - ) -> torch.Tensor: - """ - Prepare query tensor in the format expected by TRT-LLM kernel. 
- - Args: - q: Query tensor - q_rope: Optional RoPE query tensor - layer: Attention layer - - Returns: - Query tensor with concatenated NOPE and RoPE parts - """ - if q_rope is not None: - # q contains the NOPE part (v_head_dim) - q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) - q_rope = q_rope.view( - -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim - ) - else: - reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim) - q_nope = reshaped_q[:, :, : layer.v_head_dim] - q_rope = reshaped_q[:, :, layer.v_head_dim :] - - return torch.cat([q_nope, q_rope], dim=-1) - def _create_block_kv_indices( self, batch_size: int, @@ -259,7 +193,7 @@ def init_forward_metadata_capture_cuda_graph( self.page_size, ) - metadata = TRTLLMGENMLADecodeMetadata( + metadata = TRTLLMMLADecodeMetadata( self.cuda_graph_workspace, block_kv_indices ) self.decode_cuda_graph_metadata[bs] = metadata @@ -342,7 +276,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.seq_lens.device, ) - self.forward_metadata = TRTLLMGENMLADecodeMetadata( + self.forward_metadata = TRTLLMMLADecodeMetadata( self.workspace_buffer, block_kv_indices ) forward_batch.decode_trtllm_mla_metadata = self.forward_metadata @@ -371,9 +305,27 @@ def forward_decode( elif v is not None: forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) - # Prepare tensors for TRT-LLM kernel - query = self._prepare_query_tensor(q, q_rope, layer) - kv_cache = self._prepare_kv_cache(layer, forward_batch) + # Prepare query tensor inline (avoid helper for shape tweaks) + if q_rope is not None: + # q contains NOPE part (v_head_dim) + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope_reshaped = q_rope.view( + -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim + ) + query = torch.cat([q_nope, q_rope_reshaped], dim=-1) + else: + # q already has both parts + query = q.view(-1, layer.tp_q_head_num, layer.head_dim) + + # Ensure query has shape [bs, acc_q_len, 
num_q_heads, head_dim] when seq_len 1 + if query.dim() == 3: + query = query.unsqueeze(1) + + # Prepare KV cache inline (TRT-LLM expects stacked format) + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + pages = k_cache.view(-1, self.page_size, self.kv_cache_dim) + # TRT-LLM expects stacked format: slice 0 → CKV+K, slice 1 → KPE + kv_cache = torch.stack([pages, pages], dim=1) # Get metadata metadata = ( @@ -381,10 +333,16 @@ def forward_decode( or self.forward_metadata ) - # Get quantization scales - q_scale, k_scale, v_scale, sm_scale, o_scale = self._get_quantization_scales() + # Scale computation for TRTLLM MLA kernel: + # - BMM1 scale = q_scale * k_scale * softmax_scale + # - For FP16 output in DeepSeek R1: q_scale = k_scale = 1.0, softmax_scale = 1/sqrt(head_dim) + # - This reduces to layer.scaling which is pre-computed as 1/sqrt(head_dim) + bmm1_scale = layer.scaling + bmm2_scale = 1.0 + bmm1_scale_tensor = bmm2_scale_tensor = None - # Call TRT-LLM kernel + + # Call TRT-LLM kernel with proper scale configuration raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, kv_cache=kv_cache, @@ -396,20 +354,14 @@ def forward_decode( seq_lens=forward_batch.seq_lens.to(torch.int32), block_size=self.page_size, max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size), - q_scale=q_scale, - k_scale=k_scale, - v_scale=v_scale, - sm_scale=sm_scale, - o_scale=o_scale, + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, + bmm1_scale_tensor=bmm1_scale_tensor, + bmm2_scale_tensor=bmm2_scale_tensor, ) - #TODO: test? 
# Extract value projection part and reshape raw_out_v = raw_out[..., : layer.v_head_dim].contiguous() output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim) - # Truncate if needed - if output.shape[0] > forward_batch.batch_size: - output = output[: forward_batch.batch_size] - return output diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6f6bdb284331..189d66093561 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -396,7 +396,7 @@ def model_specific_adjustment(self): ): server_args.attention_backend = "fa3" elif is_sm100_supported(): - server_args.attention_backend = "trtllm_mla" + server_args.attention_backend = "flashinfer" elif _is_hip: head_num = self.model_config.get_num_kv_heads(self.tp_size) # TODO current aiter only support head number 16 or 128 head number @@ -1432,11 +1432,11 @@ def _get_attention_backend_from_str(self, backend_str: str): return CutlassMLABackend(self) elif self.server_args.attention_backend == "trtllm_mla": - from sglang.srt.layers.attention.trtllm_gen_mla_backend import ( - TRTLLMGENMLABackend, + from python.sglang.srt.layers.attention.trtllm_mla_backend import ( + TRTLLMMLABackend, ) - return TRTLLMGENMLABackend(self) + return TRTLLMMLABackend(self) elif self.server_args.attention_backend == "intel_amx": from sglang.srt.layers.attention.intel_amx_backend import ( IntelAMXAttnBackend, diff --git a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py index e2ea8d1da744..de4bf8a963d3 100644 --- a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_gen_mla_backend.py @@ -11,7 +11,7 @@ from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend -from 
sglang.srt.layers.attention.trtllm_gen_mla_backend import TRTLLMGENMLABackend, TRTLLMGENMLADecodeMetadata +from python.sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend, TRTLLMMLADecodeMetadata from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -271,7 +271,7 @@ def _create_model_components(self, config): model_runner_reference = MockModelRunner(config) # Create backends - trtllm_backend = TRTLLMGENMLABackend(model_runner_trtllm) + trtllm_backend = TRTLLMMLABackend(model_runner_trtllm) reference_backend = FlashInferMLAAttnBackend(model_runner_reference) # Create RadixAttention layer @@ -613,7 +613,7 @@ def test_metadata_initialization(self): self.assertIsNotNone(backend.forward_metadata) self.assertIsInstance( backend.forward_metadata, - TRTLLMGENMLADecodeMetadata + TRTLLMMLADecodeMetadata ) # Test metadata structure From d772ccc2b675d9b64cf8c0e53e240cd5de45e119 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Mon, 21 Jul 2025 14:38:17 -0700 Subject: [PATCH 10/27] precommit+rename Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- .../layers/attention/trtllm_mla_backend.py | 5 +- ..._backend.py => test_trtllm_mla_backend.py} | 380 +++++++++++------- 2 files changed, 233 insertions(+), 152 deletions(-) rename python/sglang/test/attention/{test_trtllm_gen_mla_backend.py => test_trtllm_mla_backend.py} (81%) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index b0213b1b3fd3..5cbb84d08fba 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -102,13 +102,13 @@ def _calc_padded_blocks(self, max_seq_len: int) -> int: """ blocks = triton.cdiv(max_seq_len, 
self.page_size) - # Apply dual constraints (take LCM to satisfy both): + # Apply dual constraints (take LCM to satisfy both): # 1. TRT-LLM: block_num % (128 / page_size) == 0 # Reference: https://github.com/NVIDIA/TensorRT-LLM/issues/XYZ # TODO: add actual link # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64 trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK) - + if blocks % constraint_lcm != 0: blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm return blocks @@ -341,7 +341,6 @@ def forward_decode( bmm2_scale = 1.0 bmm1_scale_tensor = bmm2_scale_tensor = None - # Call TRT-LLM kernel with proper scale configuration raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, diff --git a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py similarity index 81% rename from python/sglang/test/attention/test_trtllm_gen_mla_backend.py rename to python/sglang/test/attention/test_trtllm_mla_backend.py index de4bf8a963d3..92c5045f93ad 100644 --- a/python/sglang/test/attention/test_trtllm_gen_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_mla_backend.py @@ -9,16 +9,18 @@ # Patch DP-attention globals before importing backends _dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test +from python.sglang.srt.layers.attention.trtllm_mla_backend import ( + TRTLLMMLABackend, + TRTLLMMLADecodeMetadata, +) from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend -from python.sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend, TRTLLMMLADecodeMetadata from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool from sglang.srt.model_executor.forward_batch_info import 
ForwardBatch, ForwardMode from sglang.srt.utils import is_flashinfer_available from sglang.test.test_utils import CustomTestCase - # Global configuration for all tests DEFAULT_CONFIG = { "device": "cuda", @@ -57,7 +59,6 @@ "description": "Medium-scale batch", }, ], - "decode_output_match": [ { "name": "single", @@ -72,12 +73,11 @@ "max_seq_len": 64, "page_size": 32, "description": "Batch vs reference", - } + }, ], - "page_size_consistency": [ # Only 32 and 64 supported for now https://github.com/flashinfer-ai/flashinfer/blob/fe29ed63cb923f25cae70ef83f3fd16139305b35/flashinfer/decode.py#L2115 - # TODO: Test 16 and 128. Pending cubins + # TODO: Test 16 and 128. Pending cubins { "name": "page_32", "batch_size": 8, @@ -93,7 +93,6 @@ "description": "64-token pages", }, ], - "shape_sanity_tests": [ { "name": "basic", @@ -117,7 +116,6 @@ "description": "Batch shapes", }, ], - "metadata_tests": [ { "name": "single_sequence", @@ -172,7 +170,8 @@ def __init__(self, config): "qk_nope_head_dim": config["qk_nope_head_dim"], "qk_rope_head_dim": config["qk_rope_head_dim"], "v_head_dim": config["v_head_dim"], - "scaling": 1.0 / ((config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) ** 0.5), + "scaling": 1.0 + / ((config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) ** 0.5), "get_num_kv_heads": staticmethod(lambda _: config["num_kv_heads"]), }, ) @@ -207,7 +206,7 @@ def __init__(self, config): def compare_outputs(trtllm_out, reference_out, tolerance=1e-2): """Compare outputs with detailed analysis.""" - + # Basic checks assert ( trtllm_out.shape == reference_out.shape @@ -285,29 +284,38 @@ def _create_model_components(self, config): prefix="attn_mqa", ) - return model_runner_trtllm, model_runner_reference, trtllm_backend, reference_backend, layer + return ( + model_runner_trtllm, + model_runner_reference, + trtllm_backend, + reference_backend, + layer, + ) def _create_qkv_tensors(self, batch_size, config): """Create Q, K, V tensors for testing.""" head_dim 
= config["kv_lora_rank"] + config["qk_rope_head_dim"] device = config["device"] dtype = config["dtype"] - + q = torch.randn( - (batch_size, config["num_attention_heads"], head_dim), - dtype=dtype, device=device + (batch_size, config["num_attention_heads"], head_dim), + dtype=dtype, + device=device, ) k = torch.randn( - (batch_size, config["num_kv_heads"], head_dim), - dtype=dtype, device=device + (batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device ) v = torch.randn( - (batch_size, config["num_kv_heads"], config["v_head_dim"]), - dtype=dtype, device=device + (batch_size, config["num_kv_heads"], config["v_head_dim"]), + dtype=dtype, + device=device, ) return q, k, v - def _create_forward_batch(self, batch_size, seq_lens, backend, model_runner, config): + def _create_forward_batch( + self, batch_size, seq_lens, backend, model_runner, config + ): """Create a forward batch for the given backend.""" fb = ForwardBatch( batch_size=batch_size, @@ -335,16 +343,20 @@ def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config) for token_idx in range(seq_len - 1): # Create random K components for MLA cache_k_nope = torch.randn( - (1, config["qk_nope_head_dim"]), - dtype=config["dtype"], device=config["device"] + (1, config["qk_nope_head_dim"]), + dtype=config["dtype"], + device=config["device"], ) cache_k_rope = torch.randn( - (1, config["qk_rope_head_dim"]), - dtype=config["dtype"], device=config["device"] + (1, config["qk_rope_head_dim"]), + dtype=config["dtype"], + device=config["device"], ) # Calculate cache location - cache_loc = model_runner.req_to_token_pool.req_to_token[i, token_idx] + cache_loc = model_runner.req_to_token_pool.req_to_token[ + i, token_idx + ] # Save to KV cache model_runner.token_to_kv_pool.set_mla_kv_buffer( @@ -357,31 +369,33 @@ def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config) def test_basic_functionality(self): """Test basic functionality with minimal setup.""" 
print(f"\nRunning basic functionality tests...") - + for test_case in TEST_CASES["basic_functionality"]: with self.subTest(test_case=test_case["name"]): print(f" Testing {test_case['name']}: {test_case['description']}") - + config = self._merge_config(test_case) batch_size = config["batch_size"] max_seq_len = config["max_seq_len"] # Create components - model_runner_trtllm, _, trtllm_backend, _, layer = self._create_model_components(config) + model_runner_trtllm, _, trtllm_backend, _, layer = ( + self._create_model_components(config) + ) # Create sequence lengths - properly handle different batch sizes if batch_size == 2: seq_lens = torch.tensor( - [max_seq_len, max_seq_len // 2], - device=config["device"] + [max_seq_len, max_seq_len // 2], device=config["device"] ) else: # For larger batch sizes, create varied sequence lengths torch.manual_seed(config["seed_cache"]) seq_lens = torch.randint( - max_seq_len // 2, max_seq_len + 1, - (batch_size,), - device=config["device"] + max_seq_len // 2, + max_seq_len + 1, + (batch_size,), + device=config["device"], ) seq_lens[0] = max_seq_len # Ensure at least one max length @@ -392,7 +406,9 @@ def test_basic_functionality(self): trtllm_backend.init_forward_metadata(fb) # Populate KV cache - self._populate_kv_cache(batch_size, seq_lens, [model_runner_trtllm], layer, config) + self._populate_kv_cache( + batch_size, seq_lens, [model_runner_trtllm], layer, config + ) # Create Q, K, V tensors torch.manual_seed(config["seed_qkv"]) @@ -402,7 +418,10 @@ def test_basic_functionality(self): output = trtllm_backend.forward_decode(q, k, v, layer, fb) # Basic checks - expected_shape = (batch_size, config["num_attention_heads"] * config["v_head_dim"]) + expected_shape = ( + batch_size, + config["num_attention_heads"] * config["v_head_dim"], + ) self.assertEqual(output.shape, expected_shape) self.assertEqual(output.dtype, config["dtype"]) self.assertFalse(torch.isnan(output).any()) @@ -411,18 +430,23 @@ def test_basic_functionality(self): def 
test_decode_output_match(self): """Test that TRTLLM and FlashInfer MLA backends produce matching outputs.""" print(f"\nRunning decode output matching tests...") - + for test_case in TEST_CASES["decode_output_match"]: with self.subTest(test_case=test_case["name"]): print(f" Testing {test_case['name']}: {test_case['description']}") - + config = self._merge_config(test_case) batch_size = config["batch_size"] max_seq_len = config["max_seq_len"] # Create components - (model_runner_trtllm, model_runner_reference, - trtllm_backend, reference_backend, layer) = self._create_model_components(config) + ( + model_runner_trtllm, + model_runner_reference, + trtllm_backend, + reference_backend, + layer, + ) = self._create_model_components(config) # Create identical sequence lengths for both backends torch.manual_seed(config["seed_cache"]) @@ -433,10 +457,18 @@ def test_decode_output_match(self): # Create forward batches with identical inputs fb_trtllm = self._create_forward_batch( - batch_size, seq_lens.clone(), trtllm_backend, model_runner_trtllm, config + batch_size, + seq_lens.clone(), + trtllm_backend, + model_runner_trtllm, + config, ) fb_reference = self._create_forward_batch( - batch_size, seq_lens.clone(), reference_backend, model_runner_reference, config + batch_size, + seq_lens.clone(), + reference_backend, + model_runner_reference, + config, ) # Initialize metadata for both backends @@ -445,7 +477,11 @@ def test_decode_output_match(self): # Populate both KV caches identically self._populate_kv_cache( - batch_size, seq_lens, [model_runner_trtllm, model_runner_reference], layer, config + batch_size, + seq_lens, + [model_runner_trtllm, model_runner_reference], + layer, + config, ) # Create Q, K, V tensors for current decode step @@ -475,17 +511,19 @@ def test_decode_output_match(self): def test_page_size_consistency(self): """Test output consistency across different page sizes.""" print(f"\nRunning page size consistency tests...") - + for test_case in 
TEST_CASES["page_size_consistency"]: with self.subTest(test_case=test_case["name"]): print(f" Testing {test_case['name']}: {test_case['description']}") - + config = self._merge_config(test_case) batch_size = config["batch_size"] max_seq_len = config["max_seq_len"] # Create components - model_runner, _, backend, _, layer = self._create_model_components(config) + model_runner, _, backend, _, layer = self._create_model_components( + config + ) # Create sequence lengths torch.manual_seed(config["seed_cache"]) @@ -501,7 +539,9 @@ def test_page_size_consistency(self): backend.init_forward_metadata(fb) # Populate KV cache - self._populate_kv_cache(batch_size, seq_lens, [model_runner], layer, config) + self._populate_kv_cache( + batch_size, seq_lens, [model_runner], layer, config + ) # Create Q, K, V tensors torch.manual_seed(config["seed_qkv"]) @@ -510,10 +550,14 @@ def test_page_size_consistency(self): # Run forward decode output = backend.forward_decode(q, k, v, layer, fb) - expected_shape = (batch_size, config["num_attention_heads"] * config["v_head_dim"]) + expected_shape = ( + batch_size, + config["num_attention_heads"] * config["v_head_dim"], + ) self.assertEqual( - output.shape, expected_shape, - f"Output shape mismatch: {output.shape} vs {expected_shape}" + output.shape, + expected_shape, + f"Output shape mismatch: {output.shape} vs {expected_shape}", ) self.assertFalse(torch.isnan(output).any(), "Output contains NaN") self.assertFalse(torch.isinf(output).any(), "Output contains Inf") @@ -521,20 +565,24 @@ def test_page_size_consistency(self): def test_shape_sanity(self): """Smoke test decode across several configurations.""" print(f"\nRunning shape sanity tests...") - + for test_case in TEST_CASES["shape_sanity_tests"]: with self.subTest(test_case=test_case["name"]): print(f" Testing {test_case['name']}: {test_case['description']}") - + config = self._merge_config(test_case) batch_size = config["batch_size"] max_seq_len = config["max_seq_len"] - model_runner, _, 
backend, _, layer = self._create_model_components(config) + model_runner, _, backend, _, layer = self._create_model_components( + config + ) # Random seq lens (ensure one matches max) torch.manual_seed(config["seed_cache"]) - seq_lens = torch.randint(1, max_seq_len, (batch_size,), device=config["device"]) + seq_lens = torch.randint( + 1, max_seq_len, (batch_size,), device=config["device"] + ) seq_lens[0] = max_seq_len fb = self._create_forward_batch( @@ -546,49 +594,57 @@ def test_shape_sanity(self): torch.manual_seed(config["seed_qkv"]) head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"] q = torch.randn( - (batch_size, config["num_attention_heads"], head_dim), - dtype=config["dtype"], device=config["device"] + (batch_size, config["num_attention_heads"], head_dim), + dtype=config["dtype"], + device=config["device"], ) k = torch.randn( - (batch_size, config["num_kv_heads"], head_dim), - dtype=config["dtype"], device=config["device"] + (batch_size, config["num_kv_heads"], head_dim), + dtype=config["dtype"], + device=config["device"], ) - v = None + v = None # Run forward decode output = backend.forward_decode(q, k, v, layer, fb) # Shape and sanity checks - expected_shape = (batch_size, config["num_attention_heads"] * config["v_head_dim"]) + expected_shape = ( + batch_size, + config["num_attention_heads"] * config["v_head_dim"], + ) self.assertEqual( - output.shape, expected_shape, - f"Output shape mismatch for {test_case['name']}" + output.shape, + expected_shape, + f"Output shape mismatch for {test_case['name']}", ) self.assertEqual(output.dtype, config["dtype"]) self.assertEqual(output.device.type, "cuda") self.assertFalse( torch.isnan(output).any(), - f"Output contains NaN for {test_case['name']}" + f"Output contains NaN for {test_case['name']}", ) self.assertFalse( torch.isinf(output).any(), - f"Output contains Inf for {test_case['name']}" + f"Output contains Inf for {test_case['name']}", ) def test_metadata_initialization(self): """Test TRTLLM MLA 
metadata initialization and structure.""" print(f"\nRunning metadata initialization tests...") - + for test_case in TEST_CASES["metadata_tests"]: with self.subTest(test_case=test_case["name"]): print(f" Testing {test_case['name']}: {test_case['description']}") - + config = self._merge_config(test_case) batch_size = config["batch_size"] max_seq_len = config["max_seq_len"] # Create components - model_runner, _, backend, _, layer = self._create_model_components(config) + model_runner, _, backend, _, layer = self._create_model_components( + config + ) # Create varied sequence lengths torch.manual_seed(config["seed_cache"]) @@ -596,8 +652,10 @@ def test_metadata_initialization(self): seq_lens = torch.tensor([max_seq_len], device=config["device"]) else: seq_lens = torch.randint( - max(1, max_seq_len // 4), max_seq_len + 1, - (batch_size,), device=config["device"] + max(1, max_seq_len // 4), + max_seq_len + 1, + (batch_size,), + device=config["device"], ) seq_lens[0] = max_seq_len # Ensure at least one max length @@ -608,39 +666,42 @@ def test_metadata_initialization(self): # Initialize metadata backend.init_forward_metadata(fb) - + # Verify metadata exists self.assertIsNotNone(backend.forward_metadata) - self.assertIsInstance( - backend.forward_metadata, - TRTLLMMLADecodeMetadata - ) + self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata) # Test metadata structure metadata = backend.forward_metadata - self.assertIsNotNone(metadata.workspace, "Workspace should be allocated") - self.assertIsNotNone(metadata.block_kv_indices, "Block KV indices should be created") - + self.assertIsNotNone( + metadata.workspace, "Workspace should be allocated" + ) + self.assertIsNotNone( + metadata.block_kv_indices, "Block KV indices should be created" + ) + # Test workspace properties self.assertEqual(metadata.workspace.device.type, "cuda") self.assertEqual(metadata.workspace.dtype, torch.int8) - self.assertGreater(metadata.workspace.numel(), 0, "Workspace should have 
non-zero size") - + self.assertGreater( + metadata.workspace.numel(), 0, "Workspace should have non-zero size" + ) + # Test block KV indices properties self.assertEqual(metadata.block_kv_indices.device.type, "cuda") self.assertEqual(metadata.block_kv_indices.dtype, torch.int32) self.assertEqual(metadata.block_kv_indices.shape[0], batch_size) - + # Verify block indices are valid (>= -1, since -1 is padding) self.assertTrue( (metadata.block_kv_indices >= -1).all(), - "All block indices should be >= -1 (with -1 as padding)" + "All block indices should be >= -1 (with -1 as padding)", ) def test_metadata_block_calculation(self): """Test block count calculation logic.""" print(f"\nRunning metadata block calculation tests...") - + test_scenarios = [ {"seq_len": 31, "page_size": 32, "expected_min_blocks": 1}, {"seq_len": 32, "page_size": 32, "expected_min_blocks": 1}, @@ -648,53 +709,62 @@ def test_metadata_block_calculation(self): {"seq_len": 128, "page_size": 32, "expected_min_blocks": 4}, {"seq_len": 128, "page_size": 64, "expected_min_blocks": 2}, ] - + for scenario in test_scenarios: with self.subTest(scenario=scenario): - config = self._merge_config({ - "batch_size": 1, - "max_seq_len": scenario["seq_len"], - "page_size": scenario["page_size"] - }) - + config = self._merge_config( + { + "batch_size": 1, + "max_seq_len": scenario["seq_len"], + "page_size": scenario["page_size"], + } + ) + model_runner, _, backend, _, _ = self._create_model_components(config) - + # Test internal block calculation calculated_blocks = backend._calc_padded_blocks(scenario["seq_len"]) - + # Should be at least the minimum required self.assertGreaterEqual( - calculated_blocks, scenario["expected_min_blocks"], - f"Calculated blocks ({calculated_blocks}) should be >= minimum required ({scenario['expected_min_blocks']})" + calculated_blocks, + scenario["expected_min_blocks"], + f"Calculated blocks ({calculated_blocks}) should be >= minimum required ({scenario['expected_min_blocks']})", ) - + # 
Should satisfy page_size constraint total_tokens = calculated_blocks * scenario["page_size"] self.assertGreaterEqual( - total_tokens, scenario["seq_len"], - f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})" + total_tokens, + scenario["seq_len"], + f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})", ) - + # Should satisfy TRT-LLM constraint (128 / page_size) trtllm_constraint = 128 // scenario["page_size"] self.assertEqual( - calculated_blocks % trtllm_constraint, 0, - f"Block count should be multiple of TRT-LLM constraint ({trtllm_constraint})" + calculated_blocks % trtllm_constraint, + 0, + f"Block count should be multiple of TRT-LLM constraint ({trtllm_constraint})", ) def test_metadata_kv_indices_correctness(self): """Test KV indices creation and correctness.""" print(f"\nRunning KV indices correctness tests...") - - for test_case in TEST_CASES["metadata_tests"][:2]: # Test subset for performance + + for test_case in TEST_CASES["metadata_tests"][ + :2 + ]: # Test subset for performance with self.subTest(test_case=test_case["name"]): print(f" Testing {test_case['name']}: {test_case['description']}") - + config = self._merge_config(test_case) batch_size = config["batch_size"] max_seq_len = config["max_seq_len"] - model_runner, _, backend, _, layer = self._create_model_components(config) + model_runner, _, backend, _, layer = self._create_model_components( + config + ) # Create known sequence lengths torch.manual_seed(config["seed_cache"]) @@ -702,8 +772,10 @@ def test_metadata_kv_indices_correctness(self): seq_lens = torch.tensor([max_seq_len], device=config["device"]) else: seq_lens = torch.randint( - max_seq_len // 2, max_seq_len + 1, - (batch_size,), device=config["device"] + max_seq_len // 2, + max_seq_len + 1, + (batch_size,), + device=config["device"], ) fb = self._create_forward_batch( @@ -711,7 +783,9 @@ def test_metadata_kv_indices_correctness(self): ) # Populate some KV cache to have 
valid indices - self._populate_kv_cache(batch_size, seq_lens, [model_runner], layer, config) + self._populate_kv_cache( + batch_size, seq_lens, [model_runner], layer, config + ) # Initialize metadata backend.init_forward_metadata(fb) @@ -719,57 +793,61 @@ def test_metadata_kv_indices_correctness(self): # Verify KV indices structure block_kv_indices = metadata.block_kv_indices - + for i in range(batch_size): seq_len = seq_lens[i].item() expected_blocks = backend._calc_padded_blocks(seq_len) - + # Count valid (non -1) indices for this sequence valid_indices = (block_kv_indices[i] >= 0).sum().item() - + # Should have at least enough blocks for the sequence - min_required_blocks = (seq_len + config["page_size"] - 1) // config["page_size"] + min_required_blocks = (seq_len + config["page_size"] - 1) // config[ + "page_size" + ] self.assertGreaterEqual( - valid_indices, min_required_blocks, - f"Sequence {i} should have at least {min_required_blocks} valid blocks, got {valid_indices}" + valid_indices, + min_required_blocks, + f"Sequence {i} should have at least {min_required_blocks} valid blocks, got {valid_indices}", ) - + # Verify indices are within valid range valid_block_indices = block_kv_indices[i][block_kv_indices[i] >= 0] if len(valid_block_indices) > 0: - max_possible_blocks = model_runner.token_to_kv_pool.size // config["page_size"] + max_possible_blocks = ( + model_runner.token_to_kv_pool.size // config["page_size"] + ) self.assertTrue( (valid_block_indices < max_possible_blocks).all(), - f"All block indices should be < {max_possible_blocks}" + f"All block indices should be < {max_possible_blocks}", ) def test_metadata_cuda_graph_compatibility(self): """Test metadata compatibility with CUDA graph capture/replay.""" print(f"\nRunning CUDA graph compatibility tests...") - - config = self._merge_config({ - "batch_size": 4, - "max_seq_len": 64, - "page_size": 32 - }) - + + config = self._merge_config( + {"batch_size": 4, "max_seq_len": 64, "page_size": 32} + ) + 
model_runner, _, backend, _, layer = self._create_model_components(config) batch_size = config["batch_size"] - + # Initialize CUDA graph state backend.init_cuda_graph_state( - max_bs=batch_size, - max_num_tokens=config["max_seq_len"] * batch_size + max_bs=batch_size, max_num_tokens=config["max_seq_len"] * batch_size ) - + # Verify CUDA graph buffers are allocated self.assertIsNotNone(backend.cuda_graph_kv_indices) self.assertIsNotNone(backend.cuda_graph_workspace) - + # Test capture metadata - seq_lens = torch.full((batch_size,), config["max_seq_len"], device=config["device"]) + seq_lens = torch.full( + (batch_size,), config["max_seq_len"], device=config["device"] + ) req_pool_indices = torch.arange(batch_size, device=config["device"]) - + backend.init_forward_metadata_capture_cuda_graph( bs=batch_size, num_tokens=batch_size, @@ -779,20 +857,22 @@ def test_metadata_cuda_graph_compatibility(self): forward_mode=ForwardMode.DECODE, spec_info=None, ) - + # Verify capture metadata self.assertIn(batch_size, backend.decode_cuda_graph_metadata) capture_metadata = backend.decode_cuda_graph_metadata[batch_size] - + self.assertIsNotNone(capture_metadata.workspace) self.assertIsNotNone(capture_metadata.block_kv_indices) - + # Test replay with different sequence lengths new_seq_lens = torch.randint( - config["max_seq_len"] // 2, config["max_seq_len"] + 1, - (batch_size,), device=config["device"] + config["max_seq_len"] // 2, + config["max_seq_len"] + 1, + (batch_size,), + device=config["device"], ) - + backend.init_forward_metadata_replay_cuda_graph( bs=batch_size, req_pool_indices=req_pool_indices, @@ -803,24 +883,24 @@ def test_metadata_cuda_graph_compatibility(self): spec_info=None, seq_lens_cpu=new_seq_lens.cpu(), ) - + # Verify replay updated the metadata replay_metadata = backend.forward_metadata self.assertIsNotNone(replay_metadata) - self.assertEqual(replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()) + self.assertEqual( + 
replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr() + ) def test_metadata_consistency_across_calls(self): """Test metadata consistency across multiple forward calls.""" print(f"\nRunning metadata consistency tests...") - - config = self._merge_config({ - "batch_size": 2, - "max_seq_len": 64, - "page_size": 32 - }) - + + config = self._merge_config( + {"batch_size": 2, "max_seq_len": 64, "page_size": 32} + ) + model_runner, _, backend, _, layer = self._create_model_components(config) - + # First call seq_lens_1 = torch.tensor([32, 48], device=config["device"]) fb_1 = self._create_forward_batch( @@ -828,7 +908,7 @@ def test_metadata_consistency_across_calls(self): ) backend.init_forward_metadata(fb_1) metadata_1 = backend.forward_metadata - + # Second call with same sequence lengths seq_lens_2 = torch.tensor([32, 48], device=config["device"]) fb_2 = self._create_forward_batch( @@ -836,11 +916,13 @@ def test_metadata_consistency_across_calls(self): ) backend.init_forward_metadata(fb_2) metadata_2 = backend.forward_metadata - + # Metadata structure should be consistent self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape) - self.assertEqual(metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape) - + self.assertEqual( + metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape + ) + # Third call with different sequence lengths seq_lens_3 = torch.tensor([16, 64], device=config["device"]) fb_3 = self._create_forward_batch( @@ -848,7 +930,7 @@ def test_metadata_consistency_across_calls(self): ) backend.init_forward_metadata(fb_3) metadata_3 = backend.forward_metadata - + # Should still have valid structure self.assertIsNotNone(metadata_3.workspace) self.assertIsNotNone(metadata_3.block_kv_indices) @@ -856,4 +938,4 @@ def test_metadata_consistency_across_calls(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 00125a2fe243377d36a823e3bfd83802916ebd5f Mon 
Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Mon, 21 Jul 2025 18:41:39 -0700 Subject: [PATCH 11/27] updated docs Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- docs/backend/attention_backend.md | 9 +++++++++ docs/references/deepseek.md | 6 +++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/backend/attention_backend.md b/docs/backend/attention_backend.md index caf23446f5a6..3c18e76bc69a 100644 --- a/docs/backend/attention_backend.md +++ b/docs/backend/attention_backend.md @@ -9,8 +9,12 @@ | **Triton** | ❌ | ✅ | ✅ | ✅ | ❌ | | **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ | | **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ | +| **TRTLLM MLA** | ✅ | ❌* | ✅ | ✅ | ❌** | | **Ascend** | ✅ | ❌ | ❌ | ❌ | ❌ | +**Notes:** +- \*\*TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to FlashInfer MLA backend. + Note: Every kernel backend is compatible with a page size > 1 by specifying an argument such as `--page-size 16`. This is because a page size of 16 can be converted to a page size of 1 in the kernel backend. The "❌" and "✅" symbols in the table above under "Page Size > 1" indicate whether the kernel actually operates with a page size greater than 1, rather than treating a page size of 16 as a page size of 1. 
@@ -48,6 +52,11 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code ``` +- TRTLLM MLA (Optimized for Blackwell Architecture, e.g., B200) +```bash +python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code +``` + - Ascend ```bash python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 8b6d688d1507..117cbf65863c 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -90,7 +90,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase. -- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), [FlashMLA](https://github.com/deepseek-ai/FlashMLA), [CutlassMLA](https://github.com/sgl-project/sglang/pull/5390), and [Triton](https://github.com/triton-lang/triton) backends. The default FA3 provides good performance across wide workloads. 
+- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), [FlashMLA](https://github.com/deepseek-ai/FlashMLA), [CutlassMLA](https://github.com/sgl-project/sglang/pull/5390), **TRTLLM MLA** (optimized for Blackwell architecture), and [Triton](https://github.com/triton-lang/triton) backends. The default FA3 provides good performance across wide workloads. - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption. @@ -104,7 +104,7 @@ Overall, with these optimizations, we have achieved up to **7x** acceleration in Multi-head Latent Attention for DeepSeek Series Models

-**Usage**: MLA optimization is enabled by default. +**Usage**: MLA optimization is enabled by default. For MLA models on Blackwell architecture (e.g., B200), the default backend is FlashInfer. To use the optimized TRTLLM MLA backend for decode operations, explicitly specify `--attention-backend trtllm_mla`. Note that TRTLLM MLA only optimizes decode operations - prefill operations (including multimodal inputs) will fall back to FlashInfer MLA. **Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details. @@ -161,7 +161,7 @@ Add arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculati python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote-code --tp 8 ``` - The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes. -- FlashAttention3 FlashMLA and Triton backend fully supports MTP usage. For FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding,`--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA backend is still under development. +- FlashAttention3, FlashMLA, and Triton backends fully support MTP usage. TRTLLM MLA falls back to FlashInfer MLA for speculative decoding operations.
For FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding,`--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA backend is still under development. - To enable DeepSeek MTP for large batch sizes (>32), there are some parameters should be changed (Reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)): - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP. For larger batch sizes, you should increase this value beyond the default value. - Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding is set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes into it. From cd0d5662b299dbaaa668c5d8db2595b597918e38 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Tue, 22 Jul 2025 10:50:48 -0700 Subject: [PATCH 12/27] kv-cache fix Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- .../sglang/srt/layers/attention/trtllm_mla_backend.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 5cbb84d08fba..dbd172c23ac5 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -321,11 +321,11 @@ def forward_decode( if query.dim() == 3: query = query.unsqueeze(1) - # Prepare KV cache inline (TRT-LLM expects stacked format) + # Prepare KV cache inline k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) pages = k_cache.view(-1, self.page_size, self.kv_cache_dim) - # TRT-LLM expects 
stacked format: slice 0 → CKV+K, slice 1 → KPE - kv_cache = torch.stack([pages, pages], dim=1) + # TRT-LLM expects single KV data with extra dimension + kv_cache = pages.unsqueeze(1) # Get metadata metadata = ( @@ -339,7 +339,7 @@ def forward_decode( # - This reduces to layer.scaling which is pre-computed as 1/sqrt(head_dim) bmm1_scale = layer.scaling bmm2_scale = 1.0 - bmm1_scale_tensor = bmm2_scale_tensor = None + bmm1_scale_log2_tensor = bmm2_scale_tensor = None # Call TRT-LLM kernel with proper scale configuration raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( @@ -351,11 +351,10 @@ def forward_decode( qk_rope_head_dim=self.qk_rope_head_dim, block_tables=metadata.block_kv_indices, seq_lens=forward_batch.seq_lens.to(torch.int32), - block_size=self.page_size, max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size), bmm1_scale=bmm1_scale, bmm2_scale=bmm2_scale, - bmm1_scale_tensor=bmm1_scale_tensor, + bmm1_scale_log2_tensor=bmm1_scale_log2_tensor, bmm2_scale_tensor=bmm2_scale_tensor, ) From a06f25225545fa86677e46c7b04fcdce68eb2d8a Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Tue, 22 Jul 2025 19:44:03 -0700 Subject: [PATCH 13/27] remove query concat Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- .../layers/attention/trtllm_mla_backend.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index dbd172c23ac5..f05e8d412d6a 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -312,7 +312,11 @@ def forward_decode( q_rope_reshaped = q_rope.view( -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim ) - query = torch.cat([q_nope, q_rope_reshaped], dim=-1) + # Use a pre-allocated staging buffer to avoid per-step 
tensor + # allocation and the CatArrayBatchedCopy kernel. + query = q.new_empty(q_nope.shape[0], layer.tp_q_head_num, layer.head_dim) + query[..., : layer.v_head_dim].copy_(q_nope) + query[..., layer.v_head_dim :].copy_(q_rope_reshaped) else: # q already has both parts query = q.view(-1, layer.tp_q_head_num, layer.head_dim) @@ -335,11 +339,10 @@ def forward_decode( # Scale computation for TRTLLM MLA kernel: # - BMM1 scale = q_scale * k_scale * softmax_scale - # - For FP16 output in DeepSeek R1: q_scale = k_scale = 1.0, softmax_scale = 1/sqrt(head_dim) - # - This reduces to layer.scaling which is pre-computed as 1/sqrt(head_dim) - bmm1_scale = layer.scaling - bmm2_scale = 1.0 - bmm1_scale_log2_tensor = bmm2_scale_tensor = None + # - For FP16 path we keep q_scale = k_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling + # TODO: change once fp8 path is supported + q_scale = k_scale = 1.0 # for fp16 we keep 1 + bmm1_scale = q_scale * k_scale * layer.scaling # Call TRT-LLM kernel with proper scale configuration raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( @@ -352,10 +355,7 @@ def forward_decode( block_tables=metadata.block_kv_indices, seq_lens=forward_batch.seq_lens.to(torch.int32), max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size), - bmm1_scale=bmm1_scale, - bmm2_scale=bmm2_scale, - bmm1_scale_log2_tensor=bmm1_scale_log2_tensor, - bmm2_scale_tensor=bmm2_scale_tensor, + bmm1_scale=bmm1_scale ) # Extract value projection part and reshape From ab0df43684de8adbece35e03c17d35f60e5572b5 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Wed, 23 Jul 2025 14:35:28 -0700 Subject: [PATCH 14/27] server level changes Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- python/sglang/srt/layers/attention/trtllm_mla_backend.py | 8 ++------ python/sglang/srt/models/deepseek_v2.py | 1 + python/sglang/srt/server_args.py | 7 +++++++ 3 
files changed, 10 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index f05e8d412d6a..fe20de95654d 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -305,18 +305,14 @@ def forward_decode( elif v is not None: forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) - # Prepare query tensor inline (avoid helper for shape tweaks) + # Prepare query tensor inline if q_rope is not None: # q contains NOPE part (v_head_dim) q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) q_rope_reshaped = q_rope.view( -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim ) - # Use a pre-allocated staging buffer to avoid per-step tensor - # allocation and the CatArrayBatchedCopy kernel. - query = q.new_empty(q_nope.shape[0], layer.tp_q_head_num, layer.head_dim) - query[..., : layer.v_head_dim].copy_(q_nope) - query[..., layer.v_head_dim :].copy_(q_rope_reshaped) + query = torch.cat([q_nope, q_rope_reshaped], dim=-1) else: # q already has both parts query = q.view(-1, layer.tp_q_head_num, layer.head_dim) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index b5305f923fe4..6676f77667f7 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1258,6 +1258,7 @@ def forward_absorb_core( self.current_attention_backend == "fa3" or self.current_attention_backend == "flashinfer" or self.current_attention_backend == "cutlass_mla" + or self.current_attention_backend == "trtllm_mla" ): attn_output = self.attn_mqa( q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index bd5a8336aab2..1ee15efdb7de 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -409,6 
+409,13 @@ def __post_init__(self): ) self.page_size = 128 + if self.attention_backend == "trtllm_mla": + if self.page_size not in [32, 64]: + logger.warning( + f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64." + ) + self.page_size = 64 + # Set page size if self.page_size is None: self.page_size = 1 From 523293d69f4031b98db2b22eee7c0633d141d82f Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Thu, 24 Jul 2025 13:56:18 -0700 Subject: [PATCH 15/27] update args and toml Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- .../layers/attention/trtllm_mla_backend.py | 34 ++++++++++++------- .../sglang/srt/model_executor/model_runner.py | 8 +++++ python/sglang/srt/server_args.py | 4 +++ 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index fe20de95654d..e340e44351c6 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -32,7 +32,11 @@ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB # Block constraint from flashinfer requirements -# See: https://github.com/flashinfer-ai/flashinfer/blob/fe29ed63cb923f25cae70ef83f3fd16139305b35/flashinfer/decode.py#L2057 +# From flashinfer.decode._check_trtllm_gen_mla_shape: +# block_num % (128 / block_size) == 0 +# This imposes that the total number of blocks must be divisible by +# (128 / block_size). We capture the 128 constant here so we can +# compute the LCM with other padding constraints. 
TRTLLM_BLOCK_CONSTRAINT = 128 @@ -52,11 +56,9 @@ def __init__( model_runner: ModelRunner, skip_prefill: bool = False, kv_indptr_buf: Optional[torch.Tensor] = None, - kv_last_page_len_buf: Optional[torch.Tensor] = None, + q_indptr_decode_buf: Optional[torch.Tensor] = None, ): - super().__init__( - model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf - ) + super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf) config = model_runner.model_config @@ -88,7 +90,7 @@ def __init__( # CUDA graph state self.decode_cuda_graph_metadata = {} self.cuda_graph_kv_indices = None - self.forward_metadata: Union[TRTLLMMLADecodeMetadata] = None + self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None def _calc_padded_blocks(self, max_seq_len: int) -> int: """ @@ -104,7 +106,6 @@ def _calc_padded_blocks(self, max_seq_len: int) -> int: # Apply dual constraints (take LCM to satisfy both): # 1. TRT-LLM: block_num % (128 / page_size) == 0 - # Reference: https://github.com/NVIDIA/TensorRT-LLM/issues/XYZ # TODO: add actual link # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64 trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK) @@ -281,6 +282,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) forward_batch.decode_trtllm_mla_metadata = self.forward_metadata else: + # For prefill or other modes, fallback to parent implementation. 
super().init_forward_metadata(forward_batch) def forward_decode( @@ -312,7 +314,7 @@ def forward_decode( q_rope_reshaped = q_rope.view( -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim ) - query = torch.cat([q_nope, q_rope_reshaped], dim=-1) + query = torch.cat([q_nope, q_rope_reshaped], dim=-1) else: # q already has both parts query = q.view(-1, layer.tp_q_head_num, layer.head_dim) @@ -335,12 +337,18 @@ def forward_decode( # Scale computation for TRTLLM MLA kernel: # - BMM1 scale = q_scale * k_scale * softmax_scale - # - For FP16 path we keep q_scale = k_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling - # TODO: change once fp8 path is supported - q_scale = k_scale = 1.0 # for fp16 we keep 1 - bmm1_scale = q_scale * k_scale * layer.scaling + # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling + # - k_scale is read from model checkpoint if available + # TODO: Change once fp8 path is supported + q_scale = 1.0 # for fp16 we keep q_scale as 1.0 + + # silently pick k_scale (avoid per-call prints in CUDA graph) + k_scale = layer.k_scale_float if getattr(layer, "k_scale_float", None) is not None else 1.0 + + bmm1_scale = q_scale * k_scale * layer.scaling + - # Call TRT-LLM kernel with proper scale configuration + # Call TRT-LLM kernel raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, kv_cache=kv_cache, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 189d66093561..59551fa1e3ca 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1432,6 +1432,14 @@ def _get_attention_backend_from_str(self, backend_str: str): return CutlassMLABackend(self) elif self.server_args.attention_backend == "trtllm_mla": + if not self.use_mla_backend: + raise ValueError( + "trtllm_mla backend can only be used with MLA models." 
+ ) + if not self.spec_algorithm.is_none(): + raise ValueError( + "trtllm_mla backend does not support speculative decoding yet." + ) from python.sglang.srt.layers.attention.trtllm_mla_backend import ( TRTLLMMLABackend, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1ee15efdb7de..a050e20401ac 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -415,6 +415,10 @@ def __post_init__(self): f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64." ) self.page_size = 64 + if self.speculative_algorithm is not None: + raise ValueError( + "trtllm_mla backend does not support speculative decoding yet." + ) # Set page size if self.page_size is None: From a3a8784f49eeb3e6b30bf98195c3ea82a795a4db Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Thu, 24 Jul 2025 14:00:12 -0700 Subject: [PATCH 16/27] remove check Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- python/sglang/srt/model_executor/model_runner.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 59551fa1e3ca..34330aeeaa06 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1436,10 +1436,6 @@ def _get_attention_backend_from_str(self, backend_str: str): raise ValueError( "trtllm_mla backend can only be used with MLA models." ) - if not self.spec_algorithm.is_none(): - raise ValueError( - "trtllm_mla backend does not support speculative decoding yet." 
- ) from python.sglang.srt.layers.attention.trtllm_mla_backend import ( TRTLLMMLABackend, ) From be55b7b60747f55f8019ea9e1b3396d903db6478 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Thu, 24 Jul 2025 14:19:57 -0700 Subject: [PATCH 17/27] lint Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- docs/backend/attention_backend.md | 2 +- python/sglang/srt/layers/attention/trtllm_mla_backend.py | 9 ++++++--- python/sglang/srt/model_executor/model_runner.py | 4 +--- 3 files changed, 8 insertions(+), 7 deletions(-) mode change 100644 => 100755 docs/backend/attention_backend.md mode change 100644 => 100755 python/sglang/srt/model_executor/model_runner.py diff --git a/docs/backend/attention_backend.md b/docs/backend/attention_backend.md old mode 100644 new mode 100755 index 3c18e76bc69a..3929591ea8f5 --- a/docs/backend/attention_backend.md +++ b/docs/backend/attention_backend.md @@ -12,7 +12,7 @@ | **TRTLLM MLA** | ✅ | ❌* | ✅ | ✅ | ❌** | | **Ascend** | ✅ | ❌ | ❌ | ❌ | ❌ | -**Notes:** +**Notes:** - \*\*TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to FlashInfer MLA backend. Note: Every kernel backend is compatible with a page size > 1 by specifying an argument such as `--page-size 16`. 
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index e340e44351c6..2a71c4a1fa00 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -343,11 +343,14 @@ def forward_decode( q_scale = 1.0 # for fp16 we keep q_scale as 1.0 # silently pick k_scale (avoid per-call prints in CUDA graph) - k_scale = layer.k_scale_float if getattr(layer, "k_scale_float", None) is not None else 1.0 + k_scale = ( + layer.k_scale_float + if getattr(layer, "k_scale_float", None) is not None + else 1.0 + ) bmm1_scale = q_scale * k_scale * layer.scaling - # Call TRT-LLM kernel raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, @@ -359,7 +362,7 @@ def forward_decode( block_tables=metadata.block_kv_indices, seq_lens=forward_batch.seq_lens.to(torch.int32), max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size), - bmm1_scale=bmm1_scale + bmm1_scale=bmm1_scale, ) # Extract value projection part and reshape diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py old mode 100644 new mode 100755 index 34330aeeaa06..08c0fbeb40e2 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1433,9 +1433,7 @@ def _get_attention_backend_from_str(self, backend_str: str): return CutlassMLABackend(self) elif self.server_args.attention_backend == "trtllm_mla": if not self.use_mla_backend: - raise ValueError( - "trtllm_mla backend can only be used with MLA models." 
- ) + raise ValueError("trtllm_mla backend can only be used with MLA models.") from python.sglang.srt.layers.attention.trtllm_mla_backend import ( TRTLLMMLABackend, ) From e01890aa9870b11d235301fc8fe32c773b619bb2 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Fri, 25 Jul 2025 09:10:56 -0700 Subject: [PATCH 18/27] bug fix Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- python/sglang/srt/layers/attention/trtllm_mla_backend.py | 4 +--- python/sglang/srt/model_executor/model_runner.py | 4 +--- python/sglang/test/attention/test_trtllm_mla_backend.py | 9 ++++----- 3 files changed, 6 insertions(+), 11 deletions(-) mode change 100644 => 100755 python/sglang/test/attention/test_trtllm_mla_backend.py diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 2a71c4a1fa00..dc41103c60ae 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -340,9 +340,7 @@ def forward_decode( # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling # - k_scale is read from model checkpoint if available # TODO: Change once fp8 path is supported - q_scale = 1.0 # for fp16 we keep q_scale as 1.0 - - # silently pick k_scale (avoid per-call prints in CUDA graph) + q_scale = 1.0 k_scale = ( layer.k_scale_float if getattr(layer, "k_scale_float", None) is not None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 08c0fbeb40e2..4c1c9b362b99 100755 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1434,9 +1434,7 @@ def _get_attention_backend_from_str(self, backend_str: str): elif self.server_args.attention_backend == "trtllm_mla": if not self.use_mla_backend: raise 
ValueError("trtllm_mla backend can only be used with MLA models.") - from python.sglang.srt.layers.attention.trtllm_mla_backend import ( - TRTLLMMLABackend, - ) + from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend return TRTLLMMLABackend(self) elif self.server_args.attention_backend == "intel_amx": diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py old mode 100644 new mode 100755 index 92c5045f93ad..ca055ef3c77b --- a/python/sglang/test/attention/test_trtllm_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_mla_backend.py @@ -9,12 +9,12 @@ # Patch DP-attention globals before importing backends _dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test -from python.sglang.srt.layers.attention.trtllm_mla_backend import ( +from sglang.srt.configs.model_config import AttentionArch +from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend +from sglang.srt.layers.attention.trtllm_mla_backend import ( TRTLLMMLABackend, TRTLLMMLADecodeMetadata, ) -from sglang.srt.configs.model_config import AttentionArch -from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -76,8 +76,7 @@ }, ], "page_size_consistency": [ - # Only 32 and 64 supported for now https://github.com/flashinfer-ai/flashinfer/blob/fe29ed63cb923f25cae70ef83f3fd16139305b35/flashinfer/decode.py#L2115 - # TODO: Test 16 and 128. 
Pending cubins + # Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel { "name": "page_32", "batch_size": 8, From ef24a0baf4604456fa139219a6fd66ab9d34b8ed Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Mon, 28 Jul 2025 13:26:28 -0700 Subject: [PATCH 19/27] fix conflict Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- python/sglang/srt/models/deepseek_v2.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 python/sglang/srt/models/deepseek_v2.py diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py old mode 100644 new mode 100755 From cc21f0bef9668ddd6c680d5fe588ca476e4072f2 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 28 Jul 2025 18:07:10 -0700 Subject: [PATCH 20/27] Update python/pyproject.toml --- python/pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index c79681ae1c3d..d916fcb57e6c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -75,7 +75,6 @@ blackwell = [ "tiktoken", ] - # HIP (Heterogeneous-computing Interface for Portability) for AMD # => base docker rocm/vllm-dev:20250114, not from public vllm whl srt_hip = [ From 39cfd7dae8d55e50e2ab57695b76a98144f2aa8e Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Tue, 29 Jul 2025 10:11:55 -0700 Subject: [PATCH 21/27] some pr review fixes Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- .../layers/attention/trtllm_mla_backend.py | 126 +++++++++--------- python/sglang/srt/layers/attention/utils.py | 3 +- 2 files changed, 65 insertions(+), 64 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index dc41103c60ae..d3320144248c 100755 --- 
a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -147,6 +147,7 @@ def _create_block_kv_indices( block_kv_indices, self.req_to_token.stride(0), max_blocks, + TRITON_PAD_NUM_PAGE_PER_BLOCK, self.page_size, ) @@ -179,28 +180,9 @@ def init_forward_metadata_capture_cuda_graph( spec_info: Optional[SpecInfo], ): """Initialize metadata for CUDA graph capture.""" - if forward_mode.is_decode_or_idle() and spec_info is None: - max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item()) - block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad] - - create_flashmla_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - seq_lens, - None, - block_kv_indices, - self.req_to_token.stride(0), - max_seqlen_pad, - self.page_size, - ) - - metadata = TRTLLMMLADecodeMetadata( - self.cuda_graph_workspace, block_kv_indices - ) - self.decode_cuda_graph_metadata[bs] = metadata - self.forward_metadata = metadata - else: - super().init_forward_metadata_capture_cuda_graph( + # Delegate to parent for non-decode modes or when speculative execution is used. + if not (forward_mode.is_decode_or_idle() and spec_info is None): + return super().init_forward_metadata_capture_cuda_graph( bs, num_tokens, req_pool_indices, @@ -210,6 +192,26 @@ def init_forward_metadata_capture_cuda_graph( spec_info, ) + # Custom fast-path for decode/idle without speculative execution. 
+ max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item()) + block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad] + + create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + None, + block_kv_indices, + self.req_to_token.stride(0), + max_seqlen_pad, + TRITON_PAD_NUM_PAGE_PER_BLOCK, + self.page_size, + ) + + metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices) + self.decode_cuda_graph_metadata[bs] = metadata + self.forward_metadata = metadata + def init_forward_metadata_replay_cuda_graph( self, bs: int, @@ -222,24 +224,9 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_cpu: Optional[torch.Tensor], ): """Replay CUDA graph with new inputs.""" - if forward_mode.is_decode_or_idle() and spec_info is None: - metadata = self.decode_cuda_graph_metadata[bs] - - # Update block indices for new sequences - create_flashmla_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices[:bs], - seq_lens[:bs], - None, - metadata.block_kv_indices, - self.req_to_token.stride(0), - metadata.block_kv_indices.shape[1], - self.page_size, - ) - - self.forward_metadata = metadata - else: - super().init_forward_metadata_replay_cuda_graph( + # Delegate to parent for non-decode modes or when speculative execution is used. + if not (forward_mode.is_decode_or_idle() and spec_info is None): + return super().init_forward_metadata_replay_cuda_graph( bs, req_pool_indices, seq_lens, @@ -250,40 +237,55 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_cpu, ) + metadata = self.decode_cuda_graph_metadata[bs] + + # Update block indices for new sequences. 
+ create_flashmla_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens[:bs], + None, + metadata.block_kv_indices, + self.req_to_token.stride(0), + metadata.block_kv_indices.shape[1], + TRITON_PAD_NUM_PAGE_PER_BLOCK, + self.page_size, + ) + def get_cuda_graph_seq_len_fill_value(self) -> int: """Get the fill value for sequence lengths in CUDA graph.""" return 1 def init_forward_metadata(self, forward_batch: ForwardBatch): """Initialize the metadata for a forward pass.""" - if ( + # Delegate to parent for non-decode modes or when speculative execution is used. + if not ( forward_batch.forward_mode.is_decode_or_idle() and forward_batch.spec_info is None ): - bs = forward_batch.batch_size - - # Get maximum sequence length - if getattr(forward_batch, "seq_lens_cpu", None) is not None: - max_seq = forward_batch.seq_lens_cpu.max().item() - else: - max_seq = forward_batch.seq_lens.max().item() + return super().init_forward_metadata(forward_batch) - max_seqlen_pad = self._calc_padded_blocks(max_seq) - block_kv_indices = self._create_block_kv_indices( - bs, - max_seqlen_pad, - forward_batch.req_pool_indices, - forward_batch.seq_lens, - forward_batch.seq_lens.device, - ) + bs = forward_batch.batch_size - self.forward_metadata = TRTLLMMLADecodeMetadata( - self.workspace_buffer, block_kv_indices - ) - forward_batch.decode_trtllm_mla_metadata = self.forward_metadata + # Get maximum sequence length. + if getattr(forward_batch, "seq_lens_cpu", None) is not None: + max_seq = forward_batch.seq_lens_cpu.max().item() else: - # For prefill or other modes, fallback to parent implementation. 
- super().init_forward_metadata(forward_batch) + max_seq = forward_batch.seq_lens.max().item() + + max_seqlen_pad = self._calc_padded_blocks(max_seq) + block_kv_indices = self._create_block_kv_indices( + bs, + max_seqlen_pad, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens.device, + ) + + self.forward_metadata = TRTLLMMLADecodeMetadata( + self.workspace_buffer, block_kv_indices + ) + forward_batch.decode_trtllm_mla_metadata = self.forward_metadata def forward_decode( self, diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py index 0cfb6a359cf2..e8cd2e1580a1 100755 --- a/python/sglang/srt/layers/attention/utils.py +++ b/python/sglang/srt/layers/attention/utils.py @@ -55,11 +55,10 @@ def create_flashmla_kv_indices_triton( kv_indices_ptr, req_to_token_ptr_stride: tl.constexpr, kv_indices_ptr_stride: tl.constexpr, + NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK, PAGED_SIZE: tl.constexpr = 64, ): BLOCK_SIZE: tl.constexpr = 4096 - # Keep in sync with module-level TRITON_PAD_NUM_PAGE_PER_BLOCK constant above - NUM_PAGE_PER_BLOCK: tl.constexpr = 64 pid = tl.program_id(axis=0) # find the req pool idx, this is for batch to token From a39e81761062e1f2a67def9db4eafffe268dfa5a Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Tue, 29 Jul 2025 16:40:40 -0700 Subject: [PATCH 22/27] add todo comment Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- python/sglang/test/attention/test_trtllm_mla_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py index ca055ef3c77b..49f94bd8ad39 100755 --- a/python/sglang/test/attention/test_trtllm_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_mla_backend.py @@ -1,5 +1,4 @@ import unittest -from types import 
SimpleNamespace import numpy as np import torch @@ -7,6 +6,7 @@ from sglang.srt.layers import dp_attention as _dp_attn # Patch DP-attention globals before importing backends +# TODO: change the interface of both trtllm_mla and flashinfer backends to take tp_size as an argument instead of patching _dp_attn.get_attention_tp_size = lambda: 1 # TP size = 1 for unit test from sglang.srt.configs.model_config import AttentionArch From aa9764f067e82abd6347b6f21fbb7c613eb5899e Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Wed, 30 Jul 2025 18:05:48 -0700 Subject: [PATCH 23/27] perm change Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- docs/backend/attention_backend.md | 0 python/sglang/srt/layers/attention/utils.py | 0 python/sglang/srt/model_executor/model_runner.py | 0 python/sglang/srt/models/deepseek_v2.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 docs/backend/attention_backend.md mode change 100755 => 100644 python/sglang/srt/layers/attention/utils.py mode change 100755 => 100644 python/sglang/srt/model_executor/model_runner.py mode change 100755 => 100644 python/sglang/srt/models/deepseek_v2.py diff --git a/docs/backend/attention_backend.md b/docs/backend/attention_backend.md old mode 100755 new mode 100644 diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py old mode 100755 new mode 100644 diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py old mode 100755 new mode 100644 diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py old mode 100755 new mode 100644 From fadc8a66e8d915f4a31fac14482d22c186713749 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Thu, 31 Jul 2025 13:43:06 -0700 Subject: [PATCH 24/27] some pr changes Signed-off-by: Faraz 
Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- docs/backend/attention_backend.md | 4 ++-- python/sglang/test/attention/test_trtllm_mla_backend.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/backend/attention_backend.md b/docs/backend/attention_backend.md index 3929591ea8f5..3dfe6cb3de5f 100644 --- a/docs/backend/attention_backend.md +++ b/docs/backend/attention_backend.md @@ -9,11 +9,11 @@ | **Triton** | ❌ | ✅ | ✅ | ✅ | ❌ | | **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ | | **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ | -| **TRTLLM MLA** | ✅ | ❌* | ✅ | ✅ | ❌** | +| **TRTLLM MLA** | ✅ | ❌ | ✅ | ✅ | ❌ | | **Ascend** | ✅ | ❌ | ❌ | ❌ | ❌ | **Notes:** -- \*\*TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to FlashInfer MLA backend. +- TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to FlashInfer MLA backend. Note: Every kernel backend is compatible with a page size > 1 by specifying an argument such as `--page-size 16`. This is because a page size of 16 can be converted to a page size of 1 in the kernel backend. 
diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py index 49f94bd8ad39..2d1a3ce91c53 100755 --- a/python/sglang/test/attention/test_trtllm_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_mla_backend.py @@ -1,5 +1,6 @@ import unittest +import math import numpy as np import torch @@ -11,6 +12,7 @@ from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend +from sglang.srt.layers.attention.flashinfer_mla_backend import TRITON_PAD_NUM_PAGE_PER_BLOCK from sglang.srt.layers.attention.trtllm_mla_backend import ( TRTLLMMLABackend, TRTLLMMLADecodeMetadata, @@ -739,12 +741,13 @@ def test_metadata_block_calculation(self): f"Total tokens ({total_tokens}) should cover sequence length ({scenario['seq_len']})", ) - # Should satisfy TRT-LLM constraint (128 / page_size) + # Should satisfy TRT-LLM and Triton constraints trtllm_constraint = 128 // scenario["page_size"] + constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK) self.assertEqual( - calculated_blocks % trtllm_constraint, + calculated_blocks % constraint_lcm, 0, - f"Block count should be multiple of TRT-LLM constraint ({trtllm_constraint})", + f"Block count should be multiple of LCM of constraints ({constraint_lcm})", ) def test_metadata_kv_indices_correctness(self): From da8ab52eebaae09e31fa1be50505498cf82cd349 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Thu, 31 Jul 2025 13:49:40 -0700 Subject: [PATCH 25/27] update doc Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- docs/references/deepseek.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 117cbf65863c..9f0eb6506263 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -161,7 +161,7 @@ 
Add arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculati python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote-code --tp 8 ``` - The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes. -- FlashAttention3, FlashMLA, and Triton backend fully supports MTP usage. TRTLLM MLA falls back to FlashInfer MLA for speculative decoding operations. For FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding,`--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA backend is still under development. +- FlashAttention3, FlashMLA, and Triton backend fully supports MTP usage. For FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding,`--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA and TRTLLM MLA backend is still under development. - To enable DeepSeek MTP for large batch sizes (>32), there are some parameters should be changed (Reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)): - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP. For larger batch sizes, you should increase this value beyond the default value. - Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. 
The default captured batch sizes for speculative decoding is set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes into it. From f6bc5c3daf785e0844a2c7f8ffe825098faa710b Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Thu, 31 Jul 2025 14:03:03 -0700 Subject: [PATCH 26/27] add sm100 check Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- python/sglang/srt/server_args.py | 6 ++++++ python/sglang/test/attention/test_trtllm_mla_backend.py | 8 +++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ef2c1ad94b0e..c4a520f1ce4f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -24,6 +24,7 @@ from typing import List, Literal, Optional, Union from sglang.srt.hf_transformers_utils import check_gguf_file, get_config +from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.utils import ( @@ -403,6 +404,11 @@ def __post_init__(self): self.page_size = 128 if self.attention_backend == "trtllm_mla": + if not is_sm100_supported(): + raise ValueError( + "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend." + ) + if self.page_size not in [32, 64]: logger.warning( f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64." 
diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py index 2d1a3ce91c53..be3ed08f40f3 100755 --- a/python/sglang/test/attention/test_trtllm_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_mla_backend.py @@ -1,6 +1,6 @@ +import math import unittest -import math import numpy as np import torch @@ -12,11 +12,11 @@ from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend -from sglang.srt.layers.attention.flashinfer_mla_backend import TRITON_PAD_NUM_PAGE_PER_BLOCK from sglang.srt.layers.attention.trtllm_mla_backend import ( TRTLLMMLABackend, TRTLLMMLADecodeMetadata, ) +from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -743,7 +743,9 @@ def test_metadata_block_calculation(self): # Should satisfy TRT-LLM and Triton constraints trtllm_constraint = 128 // scenario["page_size"] - constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK) + constraint_lcm = math.lcm( + trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK + ) self.assertEqual( calculated_blocks % constraint_lcm, 0, From 5679c0f383cf11aaa20f7967f73e4643dce09261 Mon Sep 17 00:00:00 2001 From: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> Date: Thu, 31 Jul 2025 14:13:06 -0700 Subject: [PATCH 27/27] dito Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com> --- docs/references/deepseek.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 9f0eb6506263..af5e38677318 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -104,7 +104,7 @@ Overall, with 
these optimizations, we have achieved up to **7x** acceleration in Multi-head Latent Attention for DeepSeek Series Models

-**Usage**: MLA optimization is enabled by default. For MLA models on Blackwell architecture (e.g., B100), the default backend is FlashInfer. To use the optimized TRTLLM MLA backend for decode operations, explicitly specify `--attention-backend trtllm_mla`. Note that TRTLLM MLA only optimizes decode operations - prefill operations (including multimodal inputs) will fall back to FlashInfer MLA. +**Usage**: MLA optimization is enabled by default. For MLA models on Blackwell architecture (e.g., B200), the default backend is FlashInfer. To use the optimized TRTLLM MLA backend for decode operations, explicitly specify `--attention-backend trtllm_mla`. Note that TRTLLM MLA only optimizes decode operations - prefill operations (including multimodal inputs) will fall back to FlashInfer MLA. **Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details. @@ -161,7 +161,7 @@ Add arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculati python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote-code --tp 8 ``` - The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes. -- FlashAttention3, FlashMLA, and Triton backend fully supports MTP usage. 
For FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding,`--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA and TRTLLM MLA backend is still under development. +- FlashAttention3, FlashMLA, and Triton backends fully support MTP usage. For FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA and TRTLLM MLA backends is still under development. - To enable DeepSeek MTP for large batch sizes (>32), there are some parameters should be changed (Reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)): - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP. For larger batch sizes, you should increase this value beyond the default value. - Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding is set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes into it.