Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
119d60b
Add Flashinfer support for encoder-only model
ccs96307 May 12, 2025
f5794ce
reformat
ccs96307 May 12, 2025
91e8412
Merge branch 'main' into flashinfer-attn-support-encoder
ccs96307 May 12, 2025
c4e6f0e
Fix: determine it is not generate and not encoder-decoder architecture
ccs96307 May 12, 2025
d35a35a
Merge branch 'main' into flashinfer-attn-support-encoder
ccs96307 May 13, 2025
d86966e
Merge branch 'main' into flashinfer-attn-support-encoder
ccs96307 May 14, 2025
7bcec6a
Merge branch 'main' into flashinfer-attn-support-encoder
Fridge003 May 14, 2025
6945a23
Merge remote-tracking branch 'upstream/main' into flashinfer-attn-sup…
ccs96307 May 15, 2025
154232b
Delete padding workaround
ccs96307 May 15, 2025
212591f
Add description for limitation of flashinfer
ccs96307 May 15, 2025
75e35d1
Merge branch 'main' into flashinfer-attn-support-encoder
ccs96307 May 15, 2025
02165d3
Merge branch 'main' into flashinfer-attn-support-encoder
Fridge003 May 15, 2025
ff4595c
Merge branch 'main' into flashinfer-attn-support-encoder
Fridge003 May 17, 2025
4d6efb4
Merge branch 'main' into flashinfer-attn-support-encoder
ccs96307 May 18, 2025
60c2bc1
Merge branch 'main' into flashinfer-attn-support-encoder
Fridge003 May 19, 2025
d4ff9a8
Merge branch 'main' into flashinfer-attn-support-encoder
ccs96307 Jun 10, 2025
1c65c7c
Merge branch 'main' into flashinfer-attn-support-encoder
ccs96307 Jun 17, 2025
29d1be5
Merge branch 'main' into flashinfer-attn-support-encoder
ccs96307 Jun 18, 2025
33a8c83
Merge branch 'main' into flashinfer-attn-support-encoder
Fridge003 Jun 21, 2025
f52c693
Merge branch 'main' into flashinfer-attn-support-encoder
ccs96307 Jun 25, 2025
4a888d1
Merge branch 'main' into flashinfer-attn-support-encoder
ccs96307 Jul 3, 2025
b680370
Merge branch 'main' into flashinfer-attn-support-encoder
Fridge003 Jul 12, 2025
2c44ddc
Merge branch 'main' into flashinfer-attn-support-encoder
ccs96307 Jul 14, 2025
e8b01fe
Merge branch 'main' into flashinfer-attn-support-encoder
Fridge003 Jul 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
Expand Down Expand Up @@ -486,12 +487,20 @@ def forward_extend(
v_scale=layer.v_scale,
)
else:
causal = True
if layer.attn_type == AttentionType.ENCODER_ONLY:
save_kv_cache = False
causal = False

if self.forward_metadata.extend_no_prefix:
# NOTE: FlashInfer currently has limitations with some head dimensions (e.g. head_dim = 32).
# The limitation is tracked upstream here:
# https://github.com/flashinfer-ai/flashinfer/issues/1048
o = self.prefill_wrapper_ragged.forward(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_v_head_num, layer.head_dim),
causal=True,
causal=causal,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
Expand Down
17 changes: 15 additions & 2 deletions test/srt/models/test_encoder_embedding_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@

MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)]

ATTENTION_BACKEND = ["torch_native", "triton"]
ATTENTION_BACKEND = ["torch_native", "triton", "flashinfer"]
BATCH_SIZE = [1, 2]
TORCH_DTYPES = [torch.float32]
TORCH_DTYPES = [torch.float32, torch.float16]
sgl_to_st_ratio = []


Expand Down Expand Up @@ -126,6 +126,19 @@ def test_prefill_logits(self):
for attention_backend in ATTENTION_BACKEND:
for batch_size in BATCH_SIZE:
for torch_dtype in TORCH_DTYPES:
# NOTE: FlashInfer currently has limitations with some head dimensions
# (e.g. head_dim = 32).
# The limitation is tracked upstream here:
# https://github.com/flashinfer-ai/flashinfer/issues/1048
#
# FlashInfer also does not support torch.float32 for dtype_q, so skip it.
if attention_backend == "flashinfer":
if (
model == "BAAI/bge-small-en"
or torch_dtype == torch.float32
):
continue

self.assert_close_prefill_logits(
DEFAULT_PROMPTS,
model,
Expand Down
Loading