Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ def finished(self) -> bool:
def init_next_round_input(
self,
tree_cache: Optional[BasePrefixCache] = None,
disable_inc_hit_count: Optional[bool] = False,
):
self.fill_ids = self.origin_input_ids + self.output_ids
if tree_cache is not None:
Expand All @@ -642,6 +643,7 @@ def init_next_round_input(
self.host_hit_length,
) = tree_cache.match_prefix(
key=self.adjust_max_prefix_ids(),
disable_inc_hit_count=disable_inc_hit_count
)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,7 @@ def _add_request_to_queue(self, req: Req):
self.disagg_decode_prealloc_queue.add(req)
else:
if self.enable_hicache_storage:
req.init_next_round_input(self.tree_cache)
req.init_next_round_input(self.tree_cache, disable_inc_hit_count=True)
last_hash = req.last_host_node.get_last_hash_value()
matched_len = len(req.prefix_indices) + req.host_hit_length
if (matched_len > 0 and last_hash is not None) or matched_len == 0:
Expand Down
13 changes: 7 additions & 6 deletions python/sglang/srt/mem_cache/hiradix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def match_prefix(self, key: List[int], **kwargs):
page_aligned_len = len(key) // self.page_size * self.page_size
key = key[:page_aligned_len]

value, last_node = self._match_prefix_helper(self.root_node, key)
value, last_node = self._match_prefix_helper(self.root_node, key, **kwargs)
if value:
value = torch.cat(value)
else:
Expand Down Expand Up @@ -540,7 +540,8 @@ def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value)
node.children[child_key] = new_node
return matched_length

def _match_prefix_helper(self, node: TreeNode, key: List):
def _match_prefix_helper(self, node: TreeNode, key: List, **kwargs):
disable_inc_hit_count = kwargs.pop("disable_inc_hit_count", False)
node.last_access_time = time.monotonic()
child_key = self.get_child_key_fn(key)
value = []
Expand All @@ -551,13 +552,15 @@ def _match_prefix_helper(self, node: TreeNode, key: List):
prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
self.inc_hit_count(new_node)
if not disable_inc_hit_count:
self.inc_hit_count(new_node)
if not new_node.evicted:
value.append(new_node.value)
node = new_node
break
else:
self.inc_hit_count(child)
if not disable_inc_hit_count:
self.inc_hit_count(child)
if not child.evicted:
value.append(child.value)
node = child
Expand Down Expand Up @@ -613,7 +616,6 @@ def _insert_helper(self, node: TreeNode, key: List, value):
self.token_to_kv_pool_host.update_synced(node.host_value)
self.evictable_size_ += len(node.value)
else:
self.inc_hit_count(node)
total_prefix_length += prefix_len
else:
# partial match, split the node
Expand All @@ -623,7 +625,6 @@ def _insert_helper(self, node: TreeNode, key: List, value):
self.token_to_kv_pool_host.update_synced(new_node.host_value)
self.evictable_size_ += len(new_node.value)
else:
self.inc_hit_count(new_node)
total_prefix_length += prefix_len
node = new_node

Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
page_aligned_len = len(key) // self.page_size * self.page_size
key = key[:page_aligned_len]

value, last_node = self._match_prefix_helper(self.root_node, key)
value, last_node = self._match_prefix_helper(self.root_node, key, **kwargs)
if value:
value = torch.cat(value)
else:
Expand Down Expand Up @@ -269,7 +269,7 @@ def cache_unfinished_req(self, req: Req):
)

# The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids, disable_inc_hit_count=True)
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
Expand Down Expand Up @@ -367,7 +367,7 @@ def _dfs_helper(node: TreeNode):

##### Internal Helper Functions #####

def _match_prefix_helper(self, node: TreeNode, key: List):
def _match_prefix_helper(self, node: TreeNode, key: List, **kwargs):
node.last_access_time = time.monotonic()

child_key = self.get_child_key_fn(key)
Expand Down