Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 67 additions & 7 deletions python/sglang/srt/mem_cache/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,24 @@ def __init__(
self._kvcache = kvcache

self.free_pages = None
self.release_pages = None
self.is_not_in_free_group = True
self.free_group = []

def debug_print(self) -> str:
return ""

def available_size(self):
return len(self.free_pages) * self.page_size
return (len(self.free_pages) + len(self.release_pages)) * self.page_size

def get_kvcache(self):
return self._kvcache

def restore_state(self, free_pages):
self.free_pages = free_pages
def restore_state(self, state):
self.free_pages, self.release_pages = state

def backup_state(self):
return self.free_pages
return (self.free_pages, self.release_pages)

def free_group_begin(self):
self.is_not_in_free_group = False
Expand All @@ -78,6 +79,14 @@ def free_group_end(self):
if self.free_group:
self.free(torch.cat(self.free_group))

def merge_and_sort_free(self):
if len(self.release_pages) > 0:
self.free_pages = torch.cat((self.free_pages, self.release_pages))
self.free_pages, _ = torch.sort(self.free_pages)
self.release_pages = torch.empty(
(0,), dtype=self.release_pages.dtype, device=self.device
)

def get_cpu_copy(self, *args, **kwargs):
# FIXME: reuse the get_cpu_copy after paged allocator is implemented
raise NotImplementedError()
Expand Down Expand Up @@ -119,12 +128,15 @@ def clear(self):
)
self.is_not_in_free_group = True
self.free_group = []
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)

def available_size(self):
# To avoid minor "len(free_pages) * 1" overhead
return len(self.free_pages)
return len(self.free_pages) + len(self.release_pages)

def alloc(self, need_size: int):
if need_size > len(self.free_pages):
self.merge_and_sort_free()
if need_size > len(self.free_pages):
return None

Expand All @@ -137,7 +149,7 @@ def free(self, free_index: torch.Tensor):
return

if self.is_not_in_free_group:
self.free_pages = torch.cat((self.free_pages, free_index))
self.release_pages = torch.cat((self.release_pages, free_index))
else:
self.free_group.append(free_index)

Expand Down Expand Up @@ -421,6 +433,8 @@ def alloc(self, need_size: int):
), "The allocation size should be page-aligned"

num_pages = need_size // self.page_size
if num_pages > len(self.free_pages):
self.merge_and_sort_free()
if num_pages > len(self.free_pages):
return None

Expand All @@ -446,6 +460,17 @@ def alloc_extend(
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)

estimated_num_new_pages = (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code has been duplicated too many times. Please write a common subfunction for it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comment, I'll fix it soon.

(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to reduce the sync by estimating with extend_num_tokens

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just changed this logic by using extend_num_tokens
#8794

)
if estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()

bs = len(prefix_lens)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device
Expand Down Expand Up @@ -483,6 +508,17 @@ def alloc_decode(
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)

estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (seq_lens - 1 + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()

bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
alloc_decode_kernel[(bs,)](
Expand Down Expand Up @@ -511,7 +547,7 @@ def free(self, free_index: torch.Tensor):

if self.is_not_in_free_group:
free_page_indices = torch.unique(free_index // self.page_size)
self.free_pages = torch.cat((free_page_indices, self.free_pages))
self.release_pages = torch.cat((free_page_indices, self.release_pages))
else:
self.free_group.append(free_index)

Expand All @@ -525,6 +561,7 @@ def clear(self):
)
self.is_not_in_free_group = True
self.free_group = []
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)

def get_cpu_copy(self, indices):
return self._kvcache.get_cpu_copy(indices)
Expand Down Expand Up @@ -633,6 +670,17 @@ def alloc_extend(
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)

estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()

bs = len(prefix_lens)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int32, device=self.device
Expand Down Expand Up @@ -668,6 +716,17 @@ def alloc_decode(
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)

estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (seq_lens - 1 + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()

bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)

Expand All @@ -692,3 +751,4 @@ def alloc_decode(
def clear(self):
super().clear()
self.free_pages = self.free_pages.to(torch.int32)
self.release_pages = self.release_pages.to(torch.int32)
Loading