Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@
from __future__ import annotations

import logging
import os
from collections import deque
from dataclasses import dataclass
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.distributed import ProcessGroup

Expand All @@ -47,12 +45,9 @@
prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import (
KVCache,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import require_mlp_sync
Expand Down Expand Up @@ -141,7 +136,7 @@ class DecodePreallocQueue:
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
draft_token_to_kv_pool: Optional[KVCache],
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
metadata_buffers: MetadataBuffers,
Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional

import numpy as np
import torch

from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
Expand Down
9 changes: 5 additions & 4 deletions python/sglang/srt/managers/cache_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
import math
import threading
from queue import Empty, Full, PriorityQueue, Queue
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional

import torch

from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool_host import HostKVCache

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -163,7 +164,7 @@ class HiCacheController:

def __init__(
self,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
mem_pool_host: HostKVCache,
page_size: int,
load_cache_event: threading.Event = None,
Expand Down
7 changes: 4 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Expand Down Expand Up @@ -810,7 +811,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Request, memory pool, and cache
reqs: List[Req]
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
tree_cache: BasePrefixCache = None

# Batch configs
Expand Down Expand Up @@ -907,7 +908,7 @@ def init_new(
cls,
reqs: List[Req],
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
tree_cache: BasePrefixCache,
model_config: ModelConfig,
enable_overlap: bool,
Expand Down
8 changes: 5 additions & 3 deletions python/sglang/srt/managers/schedule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from typing import Dict, List, Optional, Set, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union

import torch

from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator

# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
# This can prevent the server from being too conservative.
# Note that this only clips the estimation in the scheduler but does not change the stop
Expand Down Expand Up @@ -265,7 +267,7 @@ def __init__(
self,
page_size: int,
tree_cache: BasePrefixCache,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
running_batch: ScheduleBatch,
new_token_ratio: float,
rem_input_tokens: int,
Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from collections import defaultdict, deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
Expand Down
5 changes: 3 additions & 2 deletions python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
Expand All @@ -57,7 +58,7 @@ def __init__(
nccl_port: int,
is_draft_worker: bool = False,
req_to_token_pool: Optional[ReqToTokenPool] = None,
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
):
# Parse args
self.tp_size = server_args.tp_size
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -17,13 +19,136 @@
Page-aligned memory pool.
"""

import abc
from typing import TYPE_CHECKING

import torch
import triton
import triton.language as tl

from sglang.srt.mem_cache.memory_pool import KVCache
from sglang.srt.utils import get_bool_env_var, next_power_of_2

if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache


class BaseTokenToKVPoolAllocator(abc.ABC):
    """Abstract base for allocators that hand out KV-cache slot indices.

    A pool of ``size`` token slots is managed in pages of ``page_size``
    tokens (``page_size == 1`` for the plain token-level allocator).
    Subclasses must implement ``clear``, ``alloc`` and ``free``.  The
    page-aligned entry points (``alloc_extend`` / ``alloc_decode``) and the
    CPU-copy helpers are optional capabilities and raise
    ``NotImplementedError`` unless overridden.
    """

    @abc.abstractmethod
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        self.size = size
        self.page_size = page_size
        self.dtype = dtype
        self.device = device
        self._kvcache = kvcache

        # Shared bookkeeping.  free_pages is populated by the subclass's
        # clear() (called below), so the shared helpers can rely on it.
        self.free_pages = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()

    def debug_print(self) -> str:
        # Subclasses may return extra diagnostics for logging.
        return ""

    def available_size(self):
        # Number of free token slots (free pages * tokens per page).
        return len(self.free_pages) * self.page_size

    def get_kvcache(self):
        return self._kvcache

    def backup_state(self):
        # Snapshot of the free list; counterpart of restore_state().
        return self.free_pages

    def restore_state(self, free_pages):
        self.free_pages = free_pages

    def free_group_begin(self):
        # Batch subsequent free() calls; they are concatenated and released
        # together in free_group_end().
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.cat(self.free_group))

    # NOTE: the four methods below are intentionally *not* abstract.
    # Marking them @abc.abstractmethod would make every concrete subclass
    # that does not override them uninstantiable (TypeError from abc):
    # TokenToKVPoolAllocator does not implement alloc_extend/alloc_decode,
    # and per the FIXMEs the paged allocator lacks the CPU-copy pair.
    # Defaulting to NotImplementedError means only callers that actually
    # use the missing capability fail.

    def get_cpu_copy(self, indices):
        # FIXME: reuse the get_cpu_copy after paged allocator is implemented
        raise NotImplementedError()

    def load_cpu_copy(self, kv_cache_cpu, indices):
        # FIXME: reuse the load_cpu_copy after paged allocator is implemented
        raise NotImplementedError()

    def alloc_extend(self, *args, **kwargs):
        raise NotImplementedError("alloc_extend is only for paged allocator")

    def alloc_decode(self, *args, **kwargs):
        raise NotImplementedError("alloc_decode is only for paged allocator")

    @abc.abstractmethod
    def clear(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def alloc(self, need_size: int):
        raise NotImplementedError()

    @abc.abstractmethod
    def free(self, free_index: torch.Tensor):
        raise NotImplementedError()


class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """An allocator managing the indices to kv cache data (page_size == 1)."""

    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
        super().__init__(size, 1, dtype, device, kvcache)

    def clear(self):
        # The padded slot 0 is used for writing dummy outputs from padded
        # tokens, so usable indices start at 1.
        self.free_pages = torch.arange(
            1, self.size + 1, dtype=torch.int64, device=self.device
        )
        self.is_not_in_free_group = True
        self.free_group = []

    def available_size(self):
        # To avoid minor "len(free_pages) * 1" overhead
        return len(self.free_pages)

    def alloc(self, need_size: int):
        """Pop ``need_size`` free slot indices, or return None if exhausted."""
        if need_size > len(self.free_pages):
            return None

        select_index = self.free_pages[:need_size]
        self.free_pages = self.free_pages[need_size:]
        return select_index

    def free(self, free_index: torch.Tensor):
        if free_index.numel() == 0:
            return

        if self.is_not_in_free_group:
            self.free_pages = torch.cat((self.free_pages, free_index))
        else:
            # Deferred: released as one batch in free_group_end().
            self.free_group.append(free_index)

    def get_cpu_copy(self, indices):
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)


@triton.jit
def alloc_extend_kernel(
Expand Down Expand Up @@ -154,7 +279,7 @@ def alloc_decode_kernel(
tl.store(out_indices + pid, page * page_size)


class PagedTokenToKVPoolAllocator:
class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
"""
An allocator managing the indices to kv cache data.

Expand All @@ -172,27 +297,11 @@ def __init__(
device: str,
kvcache: KVCache,
):
self.size = size
self.dtype = dtype
self.device = device
self.page_size = page_size
super().__init__(size, page_size, dtype, device, kvcache)
self.num_pages = size // page_size

self.free_pages = None
self.is_not_in_free_group = True
self.free_group = []
self.clear()
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")

self._kvcache = kvcache
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)

def available_size(self):
return len(self.free_pages) * self.page_size

def get_kvcache(self):
return self._kvcache

def alloc(self, need_size: int):
# page-aligned allocation, returning contiguous indices of pages
if self.debug_mode:
Expand Down Expand Up @@ -298,21 +407,6 @@ def free(self, free_index: torch.Tensor):
if self.debug_mode:
assert len(torch.unique(self.free_pages)) == len(self.free_pages)

def free_group_begin(self):
self.is_not_in_free_group = False
self.free_group = []

def free_group_end(self):
self.is_not_in_free_group = True
if self.free_group:
self.free(torch.cat(self.free_group))

def backup_state(self):
return self.free_pages

def restore_state(self, free_pages):
self.free_pages = free_pages

def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_pages = torch.arange(
Expand Down
7 changes: 4 additions & 3 deletions python/sglang/srt/mem_cache/chunk_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

"""Cache for chunked prefill, used when RadixCache is disabled."""

from typing import TYPE_CHECKING, Any, Callable, List, Tuple
from typing import TYPE_CHECKING, Any

import torch

from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
Expand All @@ -17,7 +18,7 @@ class ChunkCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
page_size: int,
):
self.req_to_token_pool = req_to_token_pool
Expand Down
Loading
Loading