Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 48 additions & 49 deletions python/sglang/srt/mem_cache/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ def __init__(
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
self.size = size
self.page_size = page_size
self.dtype = dtype
self.device = device
self._kvcache = kvcache
self.need_sort = need_sort

self.free_pages = None
self.release_pages = None
Expand Down Expand Up @@ -79,6 +81,9 @@ def free_group_end(self):
if self.free_group:
self.free(torch.cat(self.free_group))

def estimated_num_new_pages(self, bs, extend_num_tokens):
    """Return a host-side estimate of how many pages an allocation may need.

    The estimate is ``bs`` times the number of pages required to hold
    ``extend_num_tokens`` tokens, rounded up to whole pages of
    ``self.page_size`` tokens.

    Args:
        bs: Number of requests in the batch.
        extend_num_tokens: Token count used to size each request's share.

    Returns:
        The estimated page count.
    """
    tokens_per_page = self.page_size
    # Ceiling division: a partially filled page still occupies a full page.
    pages_per_request = (extend_num_tokens + (tokens_per_page - 1)) // tokens_per_page
    return pages_per_request * bs

def merge_and_sort_free(self):
if len(self.release_pages) > 0:
self.free_pages = torch.cat((self.free_pages, self.release_pages))
Expand Down Expand Up @@ -117,8 +122,15 @@ def free(self, free_index: torch.Tensor):
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
"""An allocator managing the indices to kv cache data."""

def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
super().__init__(size, 1, dtype, device, kvcache)
def __init__(
self,
size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
super().__init__(size, 1, dtype, device, kvcache, need_sort)
self.clear()

def clear(self):
Expand All @@ -135,7 +147,7 @@ def available_size(self):
return len(self.free_pages) + len(self.release_pages)

def alloc(self, need_size: int):
if need_size > len(self.free_pages):
if self.need_sort and need_size > len(self.free_pages):
self.merge_and_sort_free()
if need_size > len(self.free_pages):
return None
Expand All @@ -149,7 +161,10 @@ def free(self, free_index: torch.Tensor):
return

if self.is_not_in_free_group:
self.release_pages = torch.cat((self.release_pages, free_index))
if self.need_sort:
self.release_pages = torch.cat((self.release_pages, free_index))
else:
self.free_pages = torch.cat((self.free_pages, free_index))
else:
self.free_group.append(free_index)

Expand All @@ -170,8 +185,9 @@ def __init__(
dtype: torch.dtype,
device: str,
kvcache: SWAKVPool,
need_sort: bool,
):
super().__init__(size, 1, dtype, device, kvcache)
super().__init__(size, 1, dtype, device, kvcache, need_sort)
assert isinstance(kvcache, SWAKVPool)
self._size_full = size
self._size_swa = size_swa
Expand All @@ -180,12 +196,14 @@ def __init__(
dtype,
device,
kvcache.full_kv_pool,
need_sort,
)
self.swa_attn_allocator = TokenToKVPoolAllocator(
size_swa,
dtype,
device,
kvcache.swa_kv_pool,
need_sort,
)
self.full_to_swa_index_mapping = torch.empty(
size + size_swa + 1,
Expand Down Expand Up @@ -418,8 +436,9 @@ def __init__(
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
super().__init__(size, page_size, dtype, device, kvcache)
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.num_pages = size // page_size
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
Expand All @@ -433,7 +452,7 @@ def alloc(self, need_size: int):
), "The allocation size should be page-aligned"

num_pages = need_size // self.page_size
if num_pages > len(self.free_pages):
if self.need_sort and num_pages > len(self.free_pages):
self.merge_and_sort_free()
if num_pages > len(self.free_pages):
return None
Expand All @@ -460,18 +479,12 @@ def alloc_extend(
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)

estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
bs = len(prefix_lens)
if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
self.free_pages
):
self.merge_and_sort_free()

bs = len(prefix_lens)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
Expand Down Expand Up @@ -508,18 +521,12 @@ def alloc_decode(
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)

estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (seq_lens - 1 + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
bs = len(seq_lens)
if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
self.free_pages
):
self.merge_and_sort_free()

bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
alloc_decode_kernel[(bs,)](
seq_lens,
Expand Down Expand Up @@ -547,7 +554,10 @@ def free(self, free_index: torch.Tensor):

if self.is_not_in_free_group:
free_page_indices = torch.unique(free_index // self.page_size)
self.release_pages = torch.cat((free_page_indices, self.release_pages))
if self.need_sort:
self.release_pages = torch.cat((free_page_indices, self.release_pages))
else:
self.free_pages = torch.cat((free_page_indices, self.free_pages))
else:
self.free_group.append(free_index)

Expand Down Expand Up @@ -654,8 +664,9 @@ def __init__(
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
super().__init__(size, page_size, dtype, device, kvcache)
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)

def alloc_extend(
Expand All @@ -670,18 +681,12 @@ def alloc_extend(
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)

estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
bs = len(prefix_lens)
if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
self.free_pages
):
self.merge_and_sort_free()

bs = len(prefix_lens)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int32, device=self.device
)
Expand Down Expand Up @@ -716,18 +721,12 @@ def alloc_decode(
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)

estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (seq_lens - 1 + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
bs = len(seq_lens)
if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
self.free_pages
):
self.merge_and_sort_free()

bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)

self.ret_values = alloc_decode_kernel_ascend(
Expand Down
8 changes: 8 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,13 +1272,17 @@ def init_memory_pool(
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode
in ("decode", "prefill"),
)
else:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode
in ("decode", "prefill"),
)
else:
if _is_npu:
Expand All @@ -1288,6 +1292,8 @@ def init_memory_pool(
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode
in ("decode", "prefill"),
)
else:
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
Expand All @@ -1296,6 +1302,8 @@ def init_memory_pool(
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode
in ("decode", "prefill"),
)
else:
assert self.is_draft_worker
Expand Down
Loading