Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 5 additions & 17 deletions python/sglang/srt/managers/cache_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,16 +231,7 @@ def __init__(
self.mem_pool_host = mem_pool_host
self.write_policy = write_policy
self.page_size = page_size
# using kernel for small page KV cache transfer and DMA for large pages
if not io_backend:
IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
self.io_backend = (
"direct"
if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
else "kernel"
)
else:
self.io_backend = io_backend
self.io_backend = io_backend

self.enable_storage = False
# todo: move backend initialization to storage backend module
Expand Down Expand Up @@ -447,11 +438,8 @@ def write_thread_func_direct(self):
host_indices, device_indices = self.move_indices(
operation.host_indices, operation.device_indices
)
self.mem_pool_device.backup_to_host_all_layer(
self.mem_pool_host,
host_indices,
device_indices,
self.io_backend,
self.mem_pool_host.backup_from_device_all_layer(
self.mem_pool_device, host_indices, device_indices, self.io_backend
)
self.write_stream.synchronize()
self.mem_pool_host.complete_io(operation.host_indices)
Expand Down Expand Up @@ -491,8 +479,8 @@ def load_thread_func_layer_by_layer(self):
batch_operation.host_indices, batch_operation.device_indices
)
for i in range(self.mem_pool_host.layer_num):
self.mem_pool_device.load_from_host_per_layer(
self.mem_pool_host,
self.mem_pool_host.load_to_device_per_layer(
self.mem_pool_device,
host_indices,
device_indices,
i,
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ def init_memory_pool_and_cache(self):
== "fa3" # hot fix for incompatibility
else server_args.hicache_io_backend
),
hicache_mem_layout=server_args.hicache_mem_layout,
hicache_storage_backend=server_args.hicache_storage_backend,
)
self.tp_worker.register_hicache_layer_transfer_counter(
Expand Down
21 changes: 19 additions & 2 deletions python/sglang/srt/mem_cache/hiradix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,33 @@ def __init__(
hicache_size: int,
hicache_write_policy: str,
hicache_io_backend: str,
hicache_mem_layout: str,
hicache_storage_backend: Optional[str] = None,
):

if hicache_io_backend == "direct":
if hicache_mem_layout == "page_first":
hicache_mem_layout = "layer_first"
logger.warning(
"Page first layout is not supported with direct IO backend, switching to layer first layout"
)

self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
if isinstance(self.kv_cache, MHATokenToKVPool):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
self.kv_cache, hicache_ratio, hicache_size, page_size
self.kv_cache,
hicache_ratio,
hicache_size,
page_size,
hicache_mem_layout,
)
elif isinstance(self.kv_cache, MLATokenToKVPool):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
self.kv_cache, hicache_ratio, hicache_size, page_size
self.kv_cache,
hicache_ratio,
hicache_size,
page_size,
hicache_mem_layout,
)
else:
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
Expand Down
133 changes: 15 additions & 118 deletions python/sglang/srt/mem_cache/memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,17 @@

import numpy as np
import torch
import torch.distributed as dist
import triton
import triton.language as tl

from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2

logger = logging.getLogger(__name__)

GB = 1024 * 1024 * 1024
_is_cuda = is_cuda()
_is_npu = is_npu()
if not _is_npu:
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla


class ReqToTokenPool:
Expand Down Expand Up @@ -153,18 +149,6 @@ def set_kv_buffer(
) -> None:
raise NotImplementedError()

@abc.abstractmethod
def load_from_host_per_layer(
self, host_pool, host_indices, device_indices, layer_id, io_backend
):
raise NotImplementedError()

@abc.abstractmethod
def backup_to_host_all_layer(
self, host_pool, host_indices, device_indices, io_backend
):
raise NotImplementedError()

def register_layer_transfer_counter(self, layer_transfer_counter):
self.layer_transfer_counter = layer_transfer_counter

Expand Down Expand Up @@ -253,12 +237,18 @@ def _create_buffers(self):
)
for _ in range(self.layer_num)
]
self.token_stride = self.head_num * self.head_dim
self.data_ptrs = torch.tensor(
[x.data_ptr() for x in self.k_buffer + self.v_buffer],

self.k_data_ptrs = torch.tensor(
[x.data_ptr() for x in self.k_buffer],
dtype=torch.uint64,
device=self.device,
)
self.v_data_ptrs = torch.tensor(
[x.data_ptr() for x in self.v_buffer],
dtype=torch.uint64,
device=self.device,
)
self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
self.data_strides = torch.tensor(
[
np.prod(x.shape[1:]) * x.dtype.itemsize
Expand Down Expand Up @@ -347,47 +337,6 @@ def load_cpu_copy(self, kv_cache_cpu, indices):
self.v_buffer[layer_id][chunk_indices] = v_chunk
torch.cuda.synchronize()

def load_from_host_per_layer(
self,
host_pool,
host_indices,
device_indices,
layer_id,
io_backend,
):
transfer_kv_per_layer(
src_k=host_pool.k_buffer[layer_id],
dst_k=self.k_buffer[layer_id],
src_v=host_pool.v_buffer[layer_id],
dst_v=self.v_buffer[layer_id],
src_indices=host_indices,
dst_indices=device_indices,
io_backend=io_backend,
page_size=self.page_size,
item_size=self.token_stride,
)

def backup_to_host_all_layer(
self, host_pool, host_indices, device_indices, io_backend
):
# todo: specialized all layer kernels for the layer-non-contiguous memory pool
for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
if layer_id - self.start_layer >= len(host_pool.k_buffer):
raise ValueError(
f"Layer ID {layer_id} exceeds the number of layers in host pool."
)
transfer_kv_per_layer(
src_k=self.k_buffer[layer_id],
dst_k=host_pool.k_buffer[layer_id],
src_v=self.v_buffer[layer_id],
dst_v=host_pool.v_buffer[layer_id],
src_indices=device_indices,
dst_indices=host_indices,
io_backend=io_backend,
page_size=self.page_size,
item_size=self.token_stride,
)

def _get_key_buffer(self, layer_id: int):
# for internal use of referencing
if self.store_dtype != self.dtype:
Expand Down Expand Up @@ -602,16 +551,6 @@ def set_kv_buffer(
layer_id_override=layer_id_pool,
)

def load_from_host_per_layer(
self, host_pool, host_indices, device_indices, layer_id, io_backend
):
raise NotImplementedError("HiCache not supported for SWAKVPool.")

def backup_to_host_all_layer(
self, host_pool, host_indices, device_indices, io_backend
):
raise NotImplementedError("HiCache not supported for SWAKVPool.")


class AscendTokenToKVPool(MHATokenToKVPool):

Expand Down Expand Up @@ -823,7 +762,11 @@ def __init__(
for _ in range(layer_num)
]

self.token_stride = kv_lora_rank + qk_rope_head_dim
self.data_ptrs = torch.tensor(
[x.data_ptr() for x in self.kv_buffer],
dtype=torch.uint64,
device=self.device,
)
self.layer_transfer_counter = None

kv_size = self.get_kv_size_bytes()
Expand Down Expand Up @@ -909,38 +852,6 @@ def set_mla_kv_buffer(
self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
)

def load_from_host_per_layer(
self, host_pool, host_indices, device_indices, layer_id, io_backend
):
transfer_kv_per_layer_mla(
src=host_pool.kv_buffer[layer_id],
dst=self.kv_buffer[layer_id],
src_indices=host_indices,
dst_indices=device_indices,
io_backend=io_backend,
page_size=self.page_size,
item_size=self.token_stride,
)

def backup_to_host_all_layer(
self, host_pool, host_indices, device_indices, io_backend
):
# todo: specialized all layer kernels for the layer-non-contiguous memory pool
for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
if layer_id - self.start_layer >= len(host_pool.kv_buffer):
raise ValueError(
f"Layer ID {layer_id} exceeds the number of layers in host pool."
)
transfer_kv_per_layer_mla(
src=self.kv_buffer[layer_id],
dst=host_pool.kv_buffer[layer_id],
src_indices=device_indices,
dst_indices=host_indices,
io_backend=io_backend,
page_size=self.page_size,
item_size=self.token_stride,
)

def get_cpu_copy(self, indices):
torch.cuda.synchronize()
kv_cache_cpu = []
Expand Down Expand Up @@ -1131,20 +1042,6 @@ def set_kv_buffer(
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
self.label_buffer[layer_id - self.start_layer][loc] = cache_label

def load_from_host_per_layer(
self, host_pool, host_indices, device_indices, layer_id, io_backend
):
raise NotImplementedError(
"HiCache not supported for DoubleSparseTokenToKVPool."
)

def backup_to_host_all_layer(
self, host_pool, host_indices, device_indices, io_backend
):
raise NotImplementedError(
"HiCache not supported for DoubleSparseTokenToKVPool."
)


@triton.jit
def copy_all_layer_kv_cache(
Expand Down
Loading
Loading