Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,6 @@ def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int)
i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
]
else:
raise ValueError(
"get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
)
swa_attention_layer_ids = None
full_attention_layer_ids = None
return swa_attention_layer_ids, full_attention_layer_ids
10 changes: 9 additions & 1 deletion python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,15 @@ def _allocatable_tokens(
else 0
)

allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
if self.scheduler.model_config.is_hybrid:
available_size = min(
self.token_to_kv_pool_allocator.full_available_size(),
self.token_to_kv_pool_allocator.swa_available_size(),
)
else:
available_size = self.token_to_kv_pool_allocator.available_size()

allocatable_tokens = available_size - max(
# preserve some space for future decode
self.num_reserved_decode_tokens
* (
Expand Down
30 changes: 30 additions & 0 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
Expand Down Expand Up @@ -589,6 +590,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

# Dispatch the update function
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
Expand Down Expand Up @@ -655,6 +657,10 @@ def update_sliding_window(
paged_kernel_lens_sum_tmp = seq_lens_sum
kv_start_idx_tmp = None

use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)

self.call_begin_forward(
decode_wrappers[wrapper_id],
req_pool_indices,
Expand All @@ -663,6 +669,7 @@ def update_sliding_window(
self.kv_indptr[wrapper_id],
kv_start_idx_tmp,
spec_info,
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
)

def update_cross_attention(
Expand Down Expand Up @@ -704,6 +711,7 @@ def call_begin_forward(
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
use_sliding_window_kv_pool: bool = False,
):
if spec_info is None:
bs = len(req_pool_indices)
Expand Down Expand Up @@ -731,6 +739,14 @@ def call_begin_forward(
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1

if use_sliding_window_kv_pool:
kv_last_index = kv_indptr[-1]
kv_indices[:kv_last_index] = (
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
kv_indices[:kv_last_index]
)
)

wrapper.begin_forward(
kv_indptr,
kv_indices,
Expand Down Expand Up @@ -765,6 +781,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
self.kv_last_page_len = attn_backend.kv_last_page_len
self.qo_indptr = attn_backend.qo_indptr
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

# Dispatch the update function
Expand Down Expand Up @@ -848,6 +865,9 @@ def update_sliding_window(
paged_kernel_lens_sum = seq_lens_sum

kv_start_idx = seq_lens - paged_kernel_lens
use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)

self.call_begin_forward(
self.prefill_wrapper_ragged,
Expand All @@ -862,6 +882,7 @@ def update_sliding_window(
self.qo_indptr[wrapper_id],
use_ragged,
spec_info,
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
)

def update_cross_attention(
Expand Down Expand Up @@ -916,6 +937,7 @@ def call_begin_forward(
qo_indptr: torch.Tensor,
use_ragged: bool,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
use_sliding_window_kv_pool: bool = False,
):
bs = len(seq_lens)
if spec_info is None:
Expand Down Expand Up @@ -964,6 +986,14 @@ def call_begin_forward(
q_data_type=self.q_data_type,
)

if use_sliding_window_kv_pool:
kv_last_index = kv_indptr[-1]
kv_indices[:kv_last_index] = (
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
kv_indices[:kv_last_index]
)
)

# cached part
wrapper_paged.begin_forward(
qo_indptr,
Expand Down
Loading
Loading