From 119d60b8be8035bd4aead676dc7b7f4c825eeec9 Mon Sep 17 00:00:00 2001 From: ccs96307 Date: Mon, 12 May 2025 05:56:37 +0000 Subject: [PATCH 1/5] Add Flashinfer support for encoder-only model --- .../layers/attention/flashinfer_backend.py | 63 +++++++++++++++++-- .../models/test_encoder_embedding_models.py | 11 +++- 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 1c254c4fa502..fbe771bd1bb6 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -25,6 +25,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.utils import is_flashinfer_available, next_power_of_2 @@ -64,6 +65,19 @@ class PrefillMetadata: global_workspace_buffer = None +# global_fake_head_dim is used as a workaround to bypass FlashInfer's current limitation +# which does not support head_dim=32 (or other unsupported dimensions). The actual +# QKV tensors have an effective head dimension of origin, with the remaining padded +# with zeros. +# +# Note: Be sure to set sm_scale = 1.0 / sqrt(actual_dim), based on the real head_dim, +# even though fake_head_dim is passed to the attention kernel. +# +# TODO: Once FlashInfer officially supports head_dim=32, this variable and the +# associated padding logic should be removed to eliminate the workaround. 
+global_fake_head_dim = 64 + + class FlashInferAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" @@ -438,12 +452,50 @@ def forward_extend( v_scale=layer.v_scale, ) else: - if self.forward_metadata.extend_no_prefix: + causal = True + if layer.attn_type == AttentionType.ENCODER_ONLY: + save_kv_cache = False + causal = False + + q = q.view(-1, layer.tp_q_head_num, layer.head_dim) + k = k.view(-1, layer.tp_k_head_num, layer.head_dim) + v = v.view(-1, layer.tp_v_head_num, layer.head_dim) + + # Pad the `head_dim` if it less than `global_fake_head_dim` + original_head_dim = layer.head_dim + head_dim = max(global_fake_head_dim, original_head_dim) + + q_padded_shape = q.shape[:-1] + (head_dim,) + q_padded = torch.zeros(q_padded_shape, dtype=q.dtype, device=q.device) + q_padded[..., :original_head_dim] = q + q = q_padded + + k_padded_shape = k.shape[:-1] + (head_dim,) + k_padded = torch.zeros(k_padded_shape, dtype=k.dtype, device=k.device) + k_padded[..., :original_head_dim] = k + k = k_padded + + v_padded_shape = v.shape[:-1] + (head_dim,) + v_padded = torch.zeros(v_padded_shape, dtype=v.dtype, device=v.device) + v_padded[..., :original_head_dim] = v + v = v_padded + + o = self.prefill_wrapper_ragged.forward( + q, + k, + v, + causal=causal, + sm_scale=layer.scaling, + logits_soft_cap=logits_soft_cap, + ) + o = o[..., :original_head_dim].contiguous() + + elif self.forward_metadata.extend_no_prefix: o = self.prefill_wrapper_ragged.forward( q.view(-1, layer.tp_q_head_num, layer.head_dim), k.view(-1, layer.tp_k_head_num, layer.head_dim), v.view(-1, layer.tp_v_head_num, layer.head_dim), - causal=True, + causal=causal, sm_scale=layer.scaling, logits_soft_cap=logits_soft_cap, ) @@ -708,6 +760,9 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): get_attention_tp_size() ) self.head_dim = model_runner.model_config.head_dim + self.fake_head_dim = ( + None if self.head_dim >= global_fake_head_dim else global_fake_head_dim + ) 
self.data_type = model_runner.kv_cache_dtype self.q_data_type = model_runner.dtype self.sliding_window_size = model_runner.sliding_window_size @@ -913,7 +968,7 @@ def call_begin_forward( qo_indptr, self.num_qo_heads, self.num_kv_heads, - self.head_dim, + self.head_dim if self.fake_head_dim is None else self.fake_head_dim, q_data_type=self.q_data_type, ) @@ -925,7 +980,7 @@ def call_begin_forward( self.kv_last_page_len[:bs], self.num_qo_heads, self.num_kv_heads, - self.head_dim, + self.fake_head_dim if self.fake_head_dim else self.head_dim, 1, q_data_type=self.q_data_type, kv_data_type=self.data_type, diff --git a/test/srt/models/test_encoder_embedding_models.py b/test/srt/models/test_encoder_embedding_models.py index 5202917c4b18..d51490c19094 100644 --- a/test/srt/models/test_encoder_embedding_models.py +++ b/test/srt/models/test_encoder_embedding_models.py @@ -27,9 +27,9 @@ MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)] -ATTENTION_BACKEND = ["torch_native", "triton"] +ATTENTION_BACKEND = ["torch_native", "triton", "flashinfer"] BATCH_SIZE = [1, 2] -TORCH_DTYPES = [torch.float32] +TORCH_DTYPES = [torch.float32, torch.float16] sgl_to_st_ratio = [] @@ -126,6 +126,13 @@ def test_prefill_logits(self): for attention_backend in ATTENTION_BACKEND: for batch_size in BATCH_SIZE: for torch_dtype in TORCH_DTYPES: + # Flashinfer is not support torch.float32 for dtype_q, so pass it + if ( + torch_dtype == torch.float32 + and attention_backend == "flashinfer" + ): + continue + self.assert_close_prefill_logits( DEFAULT_PROMPTS, model, From f5794cecb8994798035d7b467b3181a33e7e751a Mon Sep 17 00:00:00 2001 From: ccs96307 Date: Mon, 12 May 2025 10:29:25 +0000 Subject: [PATCH 2/5] reformat --- .../layers/attention/flashinfer_backend.py | 56 ++++++++++++++----- 1 file changed, 42 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 
fbe771bd1bb6..ede387d166c6 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -7,6 +7,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode. """ +import math import os from dataclasses import dataclass from enum import Enum, auto @@ -463,22 +464,39 @@ def forward_extend( # Pad the `head_dim` if it less than `global_fake_head_dim` original_head_dim = layer.head_dim - head_dim = max(global_fake_head_dim, original_head_dim) + needs_padding = ( + hasattr(self.indices_updater_prefill, "fake_head_dim") + and self.indices_updater_prefill.fake_head_dim is not None + and original_head_dim < self.indices_updater_prefill.fake_head_dim + ) - q_padded_shape = q.shape[:-1] + (head_dim,) - q_padded = torch.zeros(q_padded_shape, dtype=q.dtype, device=q.device) - q_padded[..., :original_head_dim] = q - q = q_padded + if needs_padding: + q_padded_shape = q.shape[:-1] + ( + self.indices_updater_prefill.fake_head_dim, + ) + q_padded = torch.zeros( + q_padded_shape, dtype=q.dtype, device=q.device + ) + q_padded[..., :original_head_dim] = q + q = q_padded - k_padded_shape = k.shape[:-1] + (head_dim,) - k_padded = torch.zeros(k_padded_shape, dtype=k.dtype, device=k.device) - k_padded[..., :original_head_dim] = k - k = k_padded + k_padded_shape = k.shape[:-1] + ( + self.indices_updater_prefill.fake_head_dim, + ) + k_padded = torch.zeros( + k_padded_shape, dtype=k.dtype, device=k.device + ) + k_padded[..., :original_head_dim] = k + k = k_padded - v_padded_shape = v.shape[:-1] + (head_dim,) - v_padded = torch.zeros(v_padded_shape, dtype=v.dtype, device=v.device) - v_padded[..., :original_head_dim] = v - v = v_padded + v_padded_shape = v.shape[:-1] + ( + self.indices_updater_prefill.fake_head_dim, + ) + v_padded = torch.zeros( + v_padded_shape, dtype=v.dtype, device=v.device + ) + v_padded[..., :original_head_dim] = v + v = v_padded o = 
self.prefill_wrapper_ragged.forward( q, @@ -488,7 +506,11 @@ def forward_extend( sm_scale=layer.scaling, logits_soft_cap=logits_soft_cap, ) - o = o[..., :original_head_dim].contiguous() + + if needs_padding: + o = o[..., :original_head_dim] + + o = o.contiguous() elif self.forward_metadata.extend_no_prefix: o = self.prefill_wrapper_ragged.forward( @@ -961,6 +983,11 @@ def call_begin_forward( ) ) + # If self.fake_head_dim is not None + sm_scale_for_begin_forward = None + if self.fake_head_dim is not None: + sm_scale_for_begin_forward = 1.0 / math.sqrt(self.head_dim) + # extend part if use_ragged: wrapper_ragged.begin_forward( @@ -970,6 +997,7 @@ def call_begin_forward( self.num_kv_heads, self.head_dim if self.fake_head_dim is None else self.fake_head_dim, q_data_type=self.q_data_type, + sm_scale=sm_scale_for_begin_forward, ) # cached part From c4e6f0e6e0636e33568d484066335a8c796aad24 Mon Sep 17 00:00:00 2001 From: ccs96307 Date: Mon, 12 May 2025 12:54:08 +0000 Subject: [PATCH 3/5] Fix: determine it is not generate and not encoder-decoder architecture --- python/sglang/srt/layers/attention/flashinfer_backend.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index ede387d166c6..0cb83f47e353 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -783,7 +783,11 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): ) self.head_dim = model_runner.model_config.head_dim self.fake_head_dim = ( - None if self.head_dim >= global_fake_head_dim else global_fake_head_dim + global_fake_head_dim + if self.head_dim < global_fake_head_dim + and not model_runner.model_config.is_generation + and not model_runner.model_config.is_encoder_decoder + else None ) self.data_type = model_runner.kv_cache_dtype self.q_data_type = model_runner.dtype 
From 154232b7bc4c6766c71190e267ce728787b3cc1f Mon Sep 17 00:00:00 2001 From: ccs96307 Date: Thu, 15 May 2025 06:35:42 +0000 Subject: [PATCH 4/5] Delete padding workaround --- .../layers/attention/flashinfer_backend.py | 90 ++----------------- 1 file changed, 6 insertions(+), 84 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 0cb83f47e353..e72aa1f449ca 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -7,7 +7,6 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode. """ -import math import os from dataclasses import dataclass from enum import Enum, auto @@ -66,19 +65,6 @@ class PrefillMetadata: global_workspace_buffer = None -# global_fake_head_dim is used as a workaround to bypass FlashInfer's current limitation -# which does not support head_dim=32 (or other unsupported dimensions). The actual -# QKV tensors have an effective head dimension of origin, with the remaining padded -# with zeros. -# -# Note: Be sure to set sm_scale = 1.0 / sqrt(actual_dim), based on the real head_dim, -# even though fake_head_dim is passed to the attention kernel. -# -# TODO: Once FlashInfer officially supports head_dim=32, this variable and the -# associated padding logic should be removed to eliminate the workaround. 
-global_fake_head_dim = 64 - - class FlashInferAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" @@ -458,61 +444,10 @@ def forward_extend( save_kv_cache = False causal = False - q = q.view(-1, layer.tp_q_head_num, layer.head_dim) - k = k.view(-1, layer.tp_k_head_num, layer.head_dim) - v = v.view(-1, layer.tp_v_head_num, layer.head_dim) - - # Pad the `head_dim` if it less than `global_fake_head_dim` - original_head_dim = layer.head_dim - needs_padding = ( - hasattr(self.indices_updater_prefill, "fake_head_dim") - and self.indices_updater_prefill.fake_head_dim is not None - and original_head_dim < self.indices_updater_prefill.fake_head_dim - ) - - if needs_padding: - q_padded_shape = q.shape[:-1] + ( - self.indices_updater_prefill.fake_head_dim, - ) - q_padded = torch.zeros( - q_padded_shape, dtype=q.dtype, device=q.device - ) - q_padded[..., :original_head_dim] = q - q = q_padded - - k_padded_shape = k.shape[:-1] + ( - self.indices_updater_prefill.fake_head_dim, - ) - k_padded = torch.zeros( - k_padded_shape, dtype=k.dtype, device=k.device - ) - k_padded[..., :original_head_dim] = k - k = k_padded - - v_padded_shape = v.shape[:-1] + ( - self.indices_updater_prefill.fake_head_dim, - ) - v_padded = torch.zeros( - v_padded_shape, dtype=v.dtype, device=v.device - ) - v_padded[..., :original_head_dim] = v - v = v_padded - - o = self.prefill_wrapper_ragged.forward( - q, - k, - v, - causal=causal, - sm_scale=layer.scaling, - logits_soft_cap=logits_soft_cap, - ) - - if needs_padding: - o = o[..., :original_head_dim] - - o = o.contiguous() - - elif self.forward_metadata.extend_no_prefix: + if self.forward_metadata.extend_no_prefix: + # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions + # The FlashInfer head_dim limitation itself is tracked here: + # https://github.com/flashinfer-ai/flashinfer/issues/1048 o = self.prefill_wrapper_ragged.forward( q.view(-1, layer.tp_q_head_num, layer.head_dim), k.view(-1, layer.tp_k_head_num, 
layer.head_dim), @@ -782,13 +717,6 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): get_attention_tp_size() ) self.head_dim = model_runner.model_config.head_dim - self.fake_head_dim = ( - global_fake_head_dim - if self.head_dim < global_fake_head_dim - and not model_runner.model_config.is_generation - and not model_runner.model_config.is_encoder_decoder - else None - ) self.data_type = model_runner.kv_cache_dtype self.q_data_type = model_runner.dtype self.sliding_window_size = model_runner.sliding_window_size @@ -987,11 +915,6 @@ def call_begin_forward( ) ) - # If self.fake_head_dim is not None - sm_scale_for_begin_forward = None - if self.fake_head_dim is not None: - sm_scale_for_begin_forward = 1.0 / math.sqrt(self.head_dim) - # extend part if use_ragged: wrapper_ragged.begin_forward( @@ -999,9 +922,8 @@ def call_begin_forward( qo_indptr, self.num_qo_heads, self.num_kv_heads, - self.head_dim if self.fake_head_dim is None else self.fake_head_dim, + self.head_dim, q_data_type=self.q_data_type, - sm_scale=sm_scale_for_begin_forward, ) # cached part @@ -1012,7 +934,7 @@ def call_begin_forward( self.kv_last_page_len[:bs], self.num_qo_heads, self.num_kv_heads, - self.fake_head_dim if self.fake_head_dim else self.head_dim, + self.head_dim, 1, q_data_type=self.q_data_type, kv_data_type=self.data_type, From 212591fb20ca066b0d133831a5efbb912faa8cb5 Mon Sep 17 00:00:00 2001 From: ccs96307 Date: Thu, 15 May 2025 06:36:54 +0000 Subject: [PATCH 5/5] Add description for limitation of flashinfer --- .../models/test_encoder_embedding_models.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/test/srt/models/test_encoder_embedding_models.py b/test/srt/models/test_encoder_embedding_models.py index 1f7a06b9311c..dafaa72db595 100644 --- a/test/srt/models/test_encoder_embedding_models.py +++ b/test/srt/models/test_encoder_embedding_models.py @@ -126,12 +126,18 @@ def test_prefill_logits(self): for attention_backend 
in ATTENTION_BACKEND: for batch_size in BATCH_SIZE: for torch_dtype in TORCH_DTYPES: - # Flashinfer is not support torch.float32 for dtype_q, so pass it - if ( - torch_dtype == torch.float32 - and attention_backend == "flashinfer" - ): - continue + # NOTE: FlashInfer does not currently support head_dim = 32 (and some + # other head dimensions), which is why BAAI/bge-small-en is skipped below. + # The FlashInfer head_dim limitation itself is tracked here: + # https://github.com/flashinfer-ai/flashinfer/issues/1048 + # + # FlashInfer also does not support torch.float32 for dtype_q, so skip it + if attention_backend == "flashinfer": + if ( + model == "BAAI/bge-small-en" + or torch_dtype == torch.float32 + ): + continue self.assert_close_prefill_logits( DEFAULT_PROMPTS,