Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import threading
from enum import Enum, auto
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -436,7 +436,7 @@ def __init__(
self,
rid: str,
origin_input_text: str,
origin_input_ids: Tuple[int],
origin_input_ids: List[int],
sampling_params: SamplingParams,
return_logprob: bool = False,
top_logprobs_num: int = 0,
Expand Down Expand Up @@ -467,7 +467,7 @@ def __init__(
# Each decode stage's output ids
self.output_ids = []
# fill_ids = origin_input_ids + output_ids. Updated if chunked.
self.fill_ids = None
self.fill_ids = []
self.session_id = session_id
self.input_embeds = input_embeds

Expand Down Expand Up @@ -519,13 +519,14 @@ def __init__(

# Prefix info
# The indices to kv cache for the shared prefix.
self.prefix_indices = []
self.prefix_indices: torch.Tensor = []
# Number of tokens to run prefill.
self.extend_input_len = 0
# The relative logprob_start_len in an extend batch
self.extend_logprob_start_len = 0
self.last_node = None
self.last_node_global = None
self.last_node: Any = None
self.last_host_node: Any = None
self.host_hit_length = 0

# Whether the request is chunked. It increments whenever
# it is chunked, and decrements whenever a chunked request is
Expand Down Expand Up @@ -644,21 +645,17 @@ def finished(self) -> bool:
def init_next_round_input(
self,
tree_cache: Optional[BasePrefixCache] = None,
enable_hierarchical_cache=False,
):
self.fill_ids = self.origin_input_ids + self.output_ids
if tree_cache is not None:
# tree cache is None if the prefix is not computed with tree cache.
if enable_hierarchical_cache:
self.prefix_indices, self.last_node, self.last_node_global = (
tree_cache.match_prefix(
key=self.adjust_max_prefix_ids(), include_evicted=True
)
)
else:
self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids()
)
(
self.prefix_indices,
self.last_node,
self.last_host_node,
self.host_hit_length,
) = tree_cache.match_prefix(
key=self.adjust_max_prefix_ids(),
)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

def adjust_max_prefix_ids(self):
Expand Down
58 changes: 26 additions & 32 deletions python/sglang/srt/managers/schedule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(
def calc_priority(self, waiting_queue: List[Req]) -> bool:
if self.policy == CacheAgnosticPolicy.FCFS:
# A shortcut for FCFS
return
return False

policy = self._determine_active_policy(waiting_queue)

Expand Down Expand Up @@ -134,7 +134,7 @@ def _validate_and_adjust_policy(
"""
try:
policy_enum = CacheAwarePolicy(policy)
if tree_cache.disable:
if getattr(tree_cache, "disable", True):
# If tree_cache is disabled, fall back to the cache-agnostic FCFS policy
return CacheAgnosticPolicy.FCFS
return policy_enum
Expand All @@ -158,14 +158,9 @@ def _compute_prefix_matches(
prefix_ids = r.adjust_max_prefix_ids()

# NOTE: the prefix_indices must always be aligned with last_node
if self.enable_hierarchical_cache:
r.prefix_indices, r.last_node, r.last_node_global = (
self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
)
else:
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=prefix_ids
)
r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids)
)

# NOTE(sang): This logic is for in-batch prefix caching;
# If there are more than 1 request that have small matching prefix from
Expand All @@ -175,7 +170,7 @@ def _compute_prefix_matches(
# threshold means we cannot use in-batch prefix caching for short prefixes.
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
in_batch_matching_prefixes, _ = (
in_batch_matching_prefixes, _, _, _ = (
self.waiting_queue_radix_tree.match_prefix(
rid=r.rid, key=prefix_ids
)
Expand Down Expand Up @@ -268,6 +263,7 @@ class AddReqResult(Enum):
class PrefillAdder:
def __init__(
self,
page_size: int,
tree_cache: BasePrefixCache,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
running_batch: ScheduleBatch,
Expand All @@ -276,6 +272,7 @@ def __init__(
rem_chunk_tokens: Optional[int],
mixed_with_decode_tokens: int = 0,
):
self.page_size = page_size
self.tree_cache = tree_cache
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.running_batch = running_batch
Expand Down Expand Up @@ -442,46 +439,43 @@ def add_req_state(r, insert_sort=False):

return self.budget_state()

def add_one_req(
self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
):
def add_one_req(self, req: Req, has_chunked_req: bool):
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
return self.add_one_req_ignore_eos(req, has_chunked_req)

total_tokens = req.extend_input_len + min(
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
)
input_tokens = (
-(-req.extend_input_len // self.tree_cache.page_size)
* self.tree_cache.page_size
)

# adjusting the input_tokens based on host_hit_length and page_size
real_input_tokens = req.extend_input_len - req.host_hit_length
real_input_tokens = -(-real_input_tokens // self.page_size) * self.page_size
prefix_len = len(req.prefix_indices)

if total_tokens >= self.rem_total_tokens:
return AddReqResult.NO_TOKEN

if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
return AddReqResult.OTHER

with self._lock_node(req.last_node):
if total_tokens > self.rem_total_tokens:
# self.rem_total_tokens may decrease after the lock acquisition
if total_tokens >= self.rem_total_tokens:
return AddReqResult.NO_TOKEN

if (
enable_hierarchical_cache
and req.last_node_global is not None
and req.last_node_global.evicted
):
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
req.last_node_global, req.prefix_indices
if req.host_hit_length > 0:
new_indices, req.last_node = self.tree_cache.init_load_back(
req.last_host_node, req.host_hit_length
)
req.prefix_indices = torch.cat([req.prefix_indices, new_indices])
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
input_tokens = (
-(-req.extend_input_len // self.tree_cache.page_size)
* self.tree_cache.page_size
)
prefix_len = len(req.prefix_indices)

input_tokens = -(-req.extend_input_len // self.page_size) * self.page_size

if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
return AddReqResult.OTHER

if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
# Non-chunked prefill
self.can_run_list.append(req)
Expand All @@ -496,7 +490,7 @@ def add_one_req(
)
else:
# Make sure at least one page is available
trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
trunc_len = self.rem_chunk_tokens - self.page_size + 1
if trunc_len <= 0:
return AddReqResult.OTHER

Expand Down
26 changes: 8 additions & 18 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,15 +1468,14 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
return None

if self.enable_hierarchical_cache:
# check for completion of hierarchical cache activities to release memory
self.tree_cache.writing_check()
self.tree_cache.loading_check()
self.tree_cache.check_hicache_events()

# Get priority queue
prefix_computed = self.policy.calc_priority(self.waiting_queue)
self.policy.calc_priority(self.waiting_queue)

# Prefill policy
adder = PrefillAdder(
self.page_size,
self.tree_cache,
self.token_to_kv_pool_allocator,
self.running_batch,
Expand Down Expand Up @@ -1518,19 +1517,8 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.running_batch.batch_is_full = True
break

# bypass prefix_computed if enable_hierarchical_cache
req.init_next_round_input(
(
None
if (prefix_computed and not self.enable_hierarchical_cache)
else self.tree_cache
),
self.enable_hierarchical_cache,
)

res = adder.add_one_req(
req, self.chunked_req, self.enable_hierarchical_cache
)
req.init_next_round_input(self.tree_cache)
res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN:
Expand Down Expand Up @@ -1582,7 +1570,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
)
if self.enable_hierarchical_cache:
# todo (zhiqiang): disable cuda graph execution if hicache loading triggered
new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()
new_batch.hicache_consumer_index = (
self.tree_cache.ready_to_load_host_cache()
)

new_batch.prepare_for_extend()

Expand Down
60 changes: 52 additions & 8 deletions python/sglang/srt/mem_cache/base_prefix_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
from abc import ABC, abstractmethod
from typing import Any, List, Tuple
from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple

import torch

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
else:
Req = Any # Placeholder for Req type when not type checking


class MatchResult(NamedTuple):
    """Outcome of a prefix match operation.

    Fields, in positional order:

    * ``device_indices``: indices of the KV cache on the device matched by
      the common prefix.
    * ``last_device_node``: the last TreeNode on the device that was matched.
    * ``last_host_node``: the last TreeNode on the host that was matched.
      If HiCache is not enabled, this **must** be the same as
      ``last_device_node``.
    * ``host_hit_length``: length of the KV cache hit on the host, if
      applicable; 0 if HiCache is not enabled.
    """

    device_indices: torch.Tensor
    last_device_node: Any
    last_host_node: Any
    host_hit_length: int = 0


class BasePrefixCache(ABC):
Expand All @@ -10,19 +36,15 @@ def reset(self):
pass

@abstractmethod
def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
pass

@abstractmethod
def insert(self, **kwargs):
def cache_finished_req(self, req: Req, **kwargs):
pass

@abstractmethod
def cache_finished_req(self, **kwargs):
pass

@abstractmethod
def cache_unfinished_req(self, **kwargs):
def cache_unfinished_req(self, req: Req, **kwargs):
pass

@abstractmethod
Expand All @@ -49,5 +71,27 @@ def total_size(self):
def pretty_print(self):
raise NotImplementedError()

def init_load_back(
    self,
    last_host_node: Any,
    host_hit_length: int,
) -> Tuple[torch.Tensor, Any]:
    """Prepare KV cache loading from host to device.

    This base implementation does no work and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError()

def ready_to_load_host_cache(self) -> Any:
    """Notify the cache controller to start the KV cache loading.

    This base implementation does no work and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError()

def check_hicache_events(self) -> Any:
    """Check HiCache-related activities.

    Intended to update the radix tree and synchronize across TP workers if
    needed. This base implementation does no work and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError()

def take_events(self):
    """Return a new empty list; this base implementation reports no events."""
    no_events = []
    return no_events
20 changes: 7 additions & 13 deletions python/sglang/srt/mem_cache/chunk_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,13 @@

import torch

from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req


class ChunkCacheEntry:
def __init__(self, rid: str, value: torch.Tensor):
self.rid = rid
self.value = value


class ChunkCache(BasePrefixCache):
def __init__(
self,
Expand All @@ -29,13 +23,16 @@ def __init__(
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = True

def reset(self):
pass

def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
return [], None
def match_prefix(self, **unused_kwargs) -> MatchResult:
return MatchResult(
device_indices=torch.empty((0,), dtype=torch.int64),
last_device_node=None,
last_host_node=None,
)

def cache_finished_req(self, req: Req):
kv_indices = self.req_to_token_pool.req_to_token[
Expand All @@ -54,9 +51,6 @@ def cache_unfinished_req(self, req: Req):
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
req.prefix_indices = kv_indices

def insert(self):
raise NotImplementedError()

def evict(self, num_tokens: int):
pass

Expand Down
Loading
Loading