Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 52 additions & 6 deletions python/sglang/srt/managers/cache_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def __init__(
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
mem_pool_host: HostKVCache,
page_size: int,
tp_group: torch.distributed.ProcessGroup,
load_cache_event: threading.Event = None,
write_policy: str = "write_through_selective",
io_backend: str = "",
Expand All @@ -244,11 +245,17 @@ def __init__(
self.enable_storage = False
# todo: move backend initialization to storage backend module
if storage_backend is not None:
# create a new communication group for synchronizing storage operations across TP workers
self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
if self.tp_world_size > 1:
group_ranks = torch.distributed.get_process_group_ranks(tp_group)
self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")

if storage_backend == "file":
self.storage_backend = HiCacheFile()
self.enable_storage = True
# todo: threshold policy for prefetching
self.prefetch_threshold = prefetch_threshold
self.prefetch_threshold = min(prefetch_threshold, self.page_size)
else:
raise NotImplementedError(
f"Unsupported storage backend: {storage_backend}"
Expand Down Expand Up @@ -568,13 +575,32 @@ def prefetch_thread_func(self):
else:
break

if self.tp_world_size > 1:
storage_hit_count_tensor = torch.tensor(
storage_hit_count, dtype=torch.int
)
torch.distributed.all_reduce(
storage_hit_count_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
storage_hit_count = storage_hit_count_tensor.item()

if storage_hit_count < self.prefetch_threshold:
# skip prefetching when the expected benefit is too small
self.prefetch_revoke_queue.put(operation.request_id)
logger.debug(
f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
)
else:
operation.hash_value = hash_value
operation.hash_value = hash_value[
: (storage_hit_count // self.page_size)
]
# free the pre-allocated memory for pages that are not hit
self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
operation.host_indices = operation.host_indices[:storage_hit_count]
logger.debug(
f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
)
self.prefetch_buffer.put(operation)

Expand Down Expand Up @@ -611,17 +637,37 @@ def backup_thread_func(self):
last_hash = get_hash_str(
tokens_to_backup[i : i + self.page_size], last_hash
)
# todo, handle failures in storage backend
self.storage_backend.set(
success = self.storage_backend.set(
last_hash,
self.mem_pool_host.get_flat_data_page(
operation.host_indices[i]
),
)
if not success:
logger.warning(f"Failed to write page {last_hash} to storage.")
break
operation.completed_tokens += self.page_size
operation.hash_value.append(last_hash)

self.ack_backup_queue.put((operation.id, operation.hash_value))
min_completed_tokens = operation.completed_tokens
if self.tp_world_size > 1:
completed_tokens_tensor = torch.tensor(
min_completed_tokens, dtype=torch.int
)
torch.distributed.all_reduce(
completed_tokens_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
min_completed_tokens = completed_tokens_tensor.item()

self.ack_backup_queue.put(
(
operation.id,
operation.hash_value[:min_completed_tokens],
min_completed_tokens,
)
)

except Empty:
continue
15 changes: 14 additions & 1 deletion python/sglang/srt/mem_cache/hicache_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
logger = logging.getLogger(__name__)


from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)


def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
hasher = hashlib.sha256()

Expand Down Expand Up @@ -80,13 +86,17 @@ class HiCacheFile(HiCacheStorage):

def __init__(self, file_path: str = "/tmp/hicache"):
    """File-backed KV-cache storage rooted at *file_path*.

    Each tensor-parallel worker writes its shard under a per-rank key
    suffix so workers do not overwrite each other's pages.
    """
    self.file_path = file_path
    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()
    # Per-rank suffix keeps each TP shard's files distinct (empty for TP == 1).
    self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
    if tp_rank == 0:
        # exist_ok=True avoids the check-then-create (TOCTOU) race the
        # previous `os.path.exists` guard had when the directory is created
        # concurrently by another process or a previous run.
        os.makedirs(self.file_path, exist_ok=True)
        logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
    # NOTE(review): non-zero ranks may attempt to write before rank 0 has
    # created the directory — confirm callers synchronize ranks first.

def get(
self, key: str, target_location: Optional[torch.Tensor] = None
) -> torch.Tensor | None:
key += self.tp_suffix
tensor_path = os.path.join(self.file_path, f"{key}.bin")
try:
# todo: fixing the target_location logic to enable in-place loading
Expand All @@ -112,6 +122,7 @@ def batch_get(
]

def set(self, key: str, value: torch.Tensor) -> bool:
key += self.tp_suffix
tensor_path = os.path.join(self.file_path, f"{key}.bin")
if self.exists(key):
logger.debug(f"Key {key} already exists. Skipped.")
Expand All @@ -130,10 +141,12 @@ def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
return True

def exists(self, key: str) -> bool:
key += self.tp_suffix
tensor_path = os.path.join(self.file_path, f"{key}.bin")
return os.path.exists(tensor_path)

def delete(self, key: str) -> None:
key += self.tp_suffix
tensor_path = os.path.join(self.file_path, f"{key}.bin")
try:
os.remove(tensor_path)
Expand Down
46 changes: 30 additions & 16 deletions python/sglang/srt/mem_cache/hiradix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

self.tp_group = tp_cache_group
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
self.enable_storage = hicache_storage_backend is not None
# todo: customizable storage prefetch threshold
self.prefetch_threshold = 256
Expand All @@ -59,6 +60,7 @@ def __init__(
token_to_kv_pool_allocator,
self.token_to_kv_pool_host,
page_size,
self.tp_group,
load_cache_event=self.load_cache_event,
write_policy=hicache_write_policy,
io_backend=hicache_io_backend,
Expand Down Expand Up @@ -153,7 +155,7 @@ def writing_check(self, write_back=False):
queue_size = torch.tensor(
self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
if self.tp_world_size > 1:
# synchronize TP workers to make the same update to radix cache
torch.distributed.all_reduce(
queue_size,
Expand Down Expand Up @@ -353,7 +355,7 @@ def check_revoked_prefetch(self):
queue_size = torch.tensor(
self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
if self.tp_world_size > 1:
# synchronize TP workers to make the same update to hiradix cache
torch.distributed.all_reduce(
queue_size,
Expand All @@ -372,17 +374,23 @@ def check_backup_progress(self):
queue_size = torch.tensor(
self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
if self.tp_world_size > 1:
# synchronize TP workers to make the same update to hiradix cache
torch.distributed.all_reduce(
queue_size,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
for _ in range(queue_size.item()):
ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
self.ongoing_backup[ack_id].hash_value = hash_value
self.ongoing_backup[ack_id].release_host()
ack_id, hash_value, completed_tokens = (
self.cache_controller.ack_backup_queue.get()
)
host_node = self.ongoing_backup[ack_id]
if completed_tokens < len(host_node.key):
# backup is only partially successful, split the node
new_node = self._split_node(host_node.key, host_node, completed_tokens)
new_node.hash_value = hash_value
host_node.release_host()
del self.ongoing_backup[ack_id]

def check_prefetch_progress(self, req_id: str):
Expand All @@ -400,15 +408,18 @@ def check_prefetch_progress(self, req_id: str):
)
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
min_completed_tokens = completed_tokens
if self.tp_world_size > 1:
# synchronize TP workers to make the same update to hiradix cache
completed_tokens_tensor = torch.tensor(
min_completed_tokens, dtype=torch.int
)
torch.distributed.all_reduce(
min_completed_tokens,
completed_tokens_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
min_completed_tokens = min_completed_tokens.item()
min_completed_tokens = completed_tokens_tensor.item()
fetched_token_ids = token_ids[:min_completed_tokens]
written_indices = host_indices[:min_completed_tokens]
matched_length = self._insert_helper_host(
Expand Down Expand Up @@ -465,16 +476,19 @@ def prefetch_from_storage(
new_input_tokens: List[int],
last_hash: Optional[str] = None,
):
if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
# align the number of fetching tokens to the page size
prefetch_length = len(new_input_tokens) - (
len(new_input_tokens) % self.page_size
)
new_input_tokens = new_input_tokens[:prefetch_length]
if not self.enable_storage or prefetch_length < self.prefetch_threshold:
return

last_host_node.protect_host()
host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
if host_indices is None:
self.evict_host(len(new_input_tokens))
host_indices = self.cache_controller.mem_pool_host.alloc(
len(new_input_tokens)
)
self.evict_host(prefetch_length)
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
if host_indices is None:
last_host_node.release_host()
# no sufficient host memory to prefetch
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/mem_cache/memory_pool_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def available_size(self):

@synchronized()
def alloc(self, need_size: int) -> torch.Tensor:
assert (
need_size % self.page_size == 0
), "The requested size should be a multiple of the page size."
if need_size > self.available_size():
return None

Expand Down
Loading