Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
a1604af
prevent automatically initializing MPI
zhuofan1123 Oct 28, 2025
c38022a
disable auto-mpi-init
linhu-nv Oct 28, 2025
853d0da
Init support for TensorRT-LLM
axxx03 Oct 23, 2025
d22401c
add scripts
axxx03 Oct 23, 2025
d494491
fix import and interface
axxx03 Oct 23, 2025
4525da9
support the trtllm gpu layout and improve register api of trt_adapter
Luis-xu Oct 24, 2025
b711c2c
modify log
axxx03 Oct 27, 2025
8dff541
modify scripts
axxx03 Oct 28, 2025
dcd542a
use remote transfermanager
zhuofan1123 Oct 28, 2025
f9a5409
some fix by hulin
axxx03 Oct 29, 2025
38e8466
using subprocess instead of multiprocessing
axxx03 Oct 29, 2025
ae3d272
fix dead lock
axxx03 Oct 29, 2025
f4e0fd6
fix some bugs about gpu_register_port
myNameAnn Oct 29, 2025
7a68a87
fix tensor export
myNameAnn Oct 29, 2025
2f5e667
fix head_size
axxx03 Oct 29, 2025
8ce6c48
fix num_kv_heads for deepseek
axxx03 Oct 29, 2025
33426ca
fix ipc open error
peaceforeverCN Oct 30, 2025
186e590
fix head_size calculation error
Luis-xu Oct 30, 2025
5c09e7e
fix interface
axxx03 Oct 30, 2025
62748eb
fix get num_matched_tokens from trtllm
Luis-xu Oct 30, 2025
435ebed
fix head_size calculation error
Luis-xu Oct 30, 2025
e0ff9d9
fix interface
axxx03 Oct 30, 2025
fb0e158
fix short len
axxx03 Oct 30, 2025
972af0d
remove code
axxx03 Oct 31, 2025
62ae825
add patch file
axxx03 Oct 31, 2025
5371df9
modify scripts
axxx03 Oct 31, 2025
d4a98c3
tensorRT LLM will wait until kvmanager is ready
axxx03 Oct 31, 2025
a8bda16
[bugfix] fix token alignment issue in tensorrt-llm by rounding down t…
peaceforeverCN Oct 31, 2025
330cd05
trivial
axxx03 Nov 3, 2025
868995a
support flexkv + cuda graph using flexkv
axxx03 Nov 3, 2025
0c9fda4
modify patch
axxx03 Nov 4, 2025
c7a56a9
modify scripts
axxx03 Nov 4, 2025
b8d7d69
[bugfix] fix some bug
peaceforeverCN Nov 5, 2025
535db85
fix radix_tree
axxx03 Nov 7, 2025
94bab66
modify scripts
axxx03 Nov 7, 2025
83e93e7
add debug log
axxx03 Nov 17, 2025
10b9efa
modify scripts
axxx03 Nov 18, 2025
77b1d1e
fix rebase error
axxx03 Nov 18, 2025
18b1823
fix radix tree
axxx03 Nov 19, 2025
e7ecf89
fix scripts
axxx03 Nov 19, 2025
49c652e
use new config
axxx03 Nov 19, 2025
0257108
rename
axxx03 Nov 19, 2025
004ea7a
fix script
axxx03 Nov 19, 2025
52ca066
add branch for calculation of aligned_length
axxx03 Nov 19, 2025
94376fd
add branch for remote_process
axxx03 Nov 19, 2025
76399f4
take another way to determine branch
axxx03 Nov 19, 2025
6bdad9c
fix scripts
axxx03 Nov 19, 2025
d7cde03
remove useless env and config
axxx03 Nov 19, 2025
1713150
remove useless commit
axxx03 Nov 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,6 @@ cover/

# VSCode
.vscode/

# logs
*.log
39 changes: 38 additions & 1 deletion csrc/radix_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,38 @@

namespace flexkv {

// Helper function matching Python's get_hash with boundary check and _has_hashes branch
// Returns std::nullopt if block_id is out of bounds (like Python returning None)
// If has_hashes is true, reads from block_hashes_ptr; otherwise computes from token_ids
// Helper mirroring Python's get_hash() with a bounds check and the
// _has_hashes branch.
//
// Returns std::nullopt when block_id is out of range (like Python returning
// None). When has_hashes is true the pre-computed hash is read from
// block_hashes_ptr; otherwise the hash is computed over the token prefix
// that ends at this block.
static std::optional<HashType> get_hash_safe(
    int64_t *block_hashes_ptr,
    int64_t *token_ids_ptr, // may be nullptr when has_hashes is true
    int block_id,
    int num_blocks,
    bool has_hashes,
    int tokens_per_block) {
  // Reject negative indices as well as overruns: a negative block_id would
  // otherwise index before the buffer (Python slicing semantics differ).
  if (block_id < 0 || block_id >= num_blocks) {
    return std::nullopt;
  }

  if (has_hashes) {
    // Pre-computed hashes are available (Python: if self._has_hashes).
    return HashType(block_hashes_ptr[block_id]);
  }

  if (token_ids_ptr == nullptr) {
    // Cannot compute a hash without the raw tokens.
    return std::nullopt;
  }

  // Hash all tokens up to and including this block, matching Python:
  // hash_array(self.token_ids[:(block_id + 1) * self.tokens_per_block]).
  // Widen to size_t before multiplying so large prefixes cannot overflow int.
  Hasher hasher;
  hasher.reset(); // matches Python's _HASHER.reset()
  hasher.update(token_ids_ptr,
                static_cast<size_t>(block_id + 1) *
                    static_cast<size_t>(tokens_per_block) * sizeof(int64_t));
  return hasher.digest();
}

CRadixNode::CRadixNode(CRadixTreeIndex *index, bool ready, int lock_cnt) {
assert(index != nullptr);

Expand Down Expand Up @@ -230,6 +262,11 @@ CRadixTreeIndex::match_prefix(torch::Tensor &block_hashes, int num_blocks,
auto physical_blocks = new std::vector<int64_t>();
auto block_hashes_ptr = block_hashes.data_ptr<int64_t>();
HashType child_hash;

// In C++ version, block_hashes is always pre-computed (has_hashes = true)
// token_ids_ptr is nullptr since we don't have token_ids in this function signature
bool has_hashes = true;
int64_t *token_ids_ptr = nullptr;

while (prefix_blocks_num < num_blocks) {
if (update_cache_info) {
Expand Down Expand Up @@ -289,4 +326,4 @@ CRadixTreeIndex::match_prefix(torch::Tensor &block_hashes, int num_blocks,
last_ready_node, current_node, physical_blocks);
}

} // namespace flexkv
} // namespace flexkv
20 changes: 20 additions & 0 deletions examples/trtllm_adaption/extra-llm-api-config-cg.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# TensorRT-LLM extra API options: CUDA graph capture plus the FlexKV
# KV-cache connector.
cuda_graph_config:
  # Pad incoming batches up to the nearest captured graph size.
  enable_padding: true
  # Batch sizes for which CUDA graphs are captured.
  batch_sizes:
    - 1
    - 2
    - 4
    - 8
    - 16
    - 32
enable_chunked_prefill: true
kv_cache_config:
  # Partial-block reuse is disabled; FlexKV manages cross-request reuse.
  enable_partial_reuse: false
  free_gpu_memory_fraction: 0.75
# Route KV-cache offload through the FlexKV TensorRT-LLM adapter.
kv_connector_config:
  connector_module: "flexkv.integration.tensorrt_llm.trtllm_adapter"
  connector_scheduler_class: "FlexKVSchedulerConnector"
  connector_worker_class: "FlexKVWorkerConnector"
# Multi-token prediction speculative decoding.
speculative_config:
  decoding_type: MTP
  num_nextn_predict_layers: 3
6 changes: 6 additions & 0 deletions examples/trtllm_adaption/flexkv_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"cpu_cache_gb": 32,
"ssd_cache_gb": 1024,
"ssd_cache_dir": "/data/flexkv_ssd/",
"enable_gds": false
}
23 changes: 23 additions & 0 deletions examples/trtllm_adaption/launch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/usr/bin/env bash
# Launch trtllm-serve with the FlexKV KV connector enabled.
# Usage: ./launch.sh [MODEL_PATH]
set -euo pipefail

mkdir -p logs
TIMESTAMP=$(date +%Y.%m.%d-%H:%M:%S)
MODEL_PATH=${1:-YOUR_MODEL_PATH}

# Serving shape: 8-way tensor parallelism, expert parallelism matching TP.
BATCH_SIZE=4
TP_SIZE=8
EP_SIZE=$TP_SIZE
MAX_SEQ_LEN=155648
MAX_NUM_TOKENS=16384

# Point FlexKV at its config and tell TensorRT-LLM to use it.
export FLEXKV_CONFIG_PATH="./flexkv_config.json"
export TENSORRT_LLM_USE_FLEXKV=1

trtllm-serve serve "$MODEL_PATH" \
    --host 0.0.0.0 \
    --port 6000 \
    --backend pytorch \
    --tp_size "$TP_SIZE" \
    --ep_size "$EP_SIZE" \
    --max_seq_len "$MAX_SEQ_LEN" \
    --max_num_tokens "$MAX_NUM_TOKENS" \
    --max_batch_size "$BATCH_SIZE" \
    --extra_llm_api_options extra-llm-api-config-cg.yml 2>&1 | tee "logs/$TIMESTAMP.log"
9 changes: 7 additions & 2 deletions flexkv/cache/cache_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import List, Tuple, Optional, Dict, Callable
from dataclasses import dataclass, field

import os
import numpy as np
import nvtx
import torch
Expand Down Expand Up @@ -319,8 +320,12 @@ def get(self,
raise NotImplementedError(f"Layerwise transfer is not supported yet, "
f"layer_num: {layer_num}, layer_granularity: {layer_granularity}")

# ignore the last incomplete block
aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block
if not os.getenv("FLEXKV_WITH_TRTLLM", "0") == "1":
aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block
else:
# When using FlexKV with TensorRT-LLM, we ignore the last incomplete block.
aligned_length = ((token_ids.shape[0] - 1) // self.tokens_per_block) * self.tokens_per_block

aligned_token_ids = token_ids[:aligned_length]
token_mask = token_mask[:aligned_length]

Expand Down
52 changes: 46 additions & 6 deletions flexkv/common/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@
import os
import sys
import time
import inspect
from functools import wraps
from typing import Optional, Callable, Any


FLEXKV_LOGGING_PREFIX = os.getenv("FLEXKV_LOGGING_PREFIX", "FLEXKV")
_FORMAT = (f"[{FLEXKV_LOGGING_PREFIX}] %(levelname)s %(asctime)s.%(msecs)03d "
" %(message)s")
"[%(filename)s:%(lineno)d] %(message)s")
_DATE_FORMAT = "%m-%d %H:%M:%S"

class FlexkvLogger:
def __init__(self, debug_level: str = "INFO"):
self.enabled = False
self.logger = logging.getLogger("FLEXKV")

self.logger.propagate = False

has_console_handler = any(
isinstance(handler, logging.StreamHandler)
for handler in self.logger.handlers
Expand Down Expand Up @@ -44,25 +47,62 @@ def set_level(self, level: str) -> None:
self.logger.setLevel(log_level)
self.enabled = log_level != (logging.CRITICAL + 1)

def _get_caller_info(self):
    # Walk two frames up the stack (past this helper and the logging
    # wrapper) to find the user code that invoked the logger. Falls back
    # to ("unknown", 0) if the stack is too shallow.
    frame = inspect.currentframe()
    try:
        caller = frame
        for _ in range(2):
            if caller is None:
                break
            caller = caller.f_back
        if caller is None:
            return "unknown", 0
        return os.path.basename(caller.f_code.co_filename), caller.f_lineno
    finally:
        # Drop the frame reference promptly to avoid a reference cycle.
        del frame

def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
    """Log ``msg`` at DEBUG level, attributing it to the calling site."""
    # Logger.handle() does NOT filter by level, so without this guard a
    # DEBUG record would be emitted even when the logger level is higher.
    if not self.enabled or not self.logger.isEnabledFor(logging.DEBUG):
        return
    filename, lineno = self._get_caller_info()
    record = self.logger.makeRecord(
        self.logger.name, logging.DEBUG, filename, lineno, msg, args,
        kwargs.get("exc_info"),  # preserve exc_info support from **kwargs
    )
    self.logger.handle(record)

def info(self, msg: str, *args: Any, **kwargs: Any) -> None:
    """Log ``msg`` at INFO level, attributing it to the calling site."""
    # Logger.handle() does NOT filter by level, so without this guard an
    # INFO record would be emitted even when the logger level is higher.
    if not self.enabled or not self.logger.isEnabledFor(logging.INFO):
        return
    filename, lineno = self._get_caller_info()
    record = self.logger.makeRecord(
        self.logger.name, logging.INFO, filename, lineno, msg, args,
        kwargs.get("exc_info"),  # preserve exc_info support from **kwargs
    )
    self.logger.handle(record)

def warning(self, msg: str, *args: Any, **kwargs: Any) -> None:
    """Log ``msg`` at WARNING level, attributing it to the calling site."""
    # Logger.handle() does NOT filter by level, so without this guard a
    # WARNING record would be emitted even when the logger level is higher.
    if not self.enabled or not self.logger.isEnabledFor(logging.WARNING):
        return
    filename, lineno = self._get_caller_info()
    record = self.logger.makeRecord(
        self.logger.name, logging.WARNING, filename, lineno, msg, args,
        kwargs.get("exc_info"),  # preserve exc_info support from **kwargs
    )
    self.logger.handle(record)

def error(self, msg: str, *args: Any, **kwargs: Any) -> None:
    """Log ``msg`` at ERROR level, attributing it to the calling site."""
    # Logger.handle() does NOT filter by level, so without this guard an
    # ERROR record would be emitted even when the logger level is higher.
    if not self.enabled or not self.logger.isEnabledFor(logging.ERROR):
        return
    filename, lineno = self._get_caller_info()
    record = self.logger.makeRecord(
        self.logger.name, logging.ERROR, filename, lineno, msg, args,
        kwargs.get("exc_info"),  # preserve exc_info support from **kwargs
    )
    self.logger.handle(record)

def critical(self, msg: str, *args: Any, **kwargs: Any) -> None:
    """Log ``msg`` at CRITICAL level, attributing it to the calling site."""
    # Logger.handle() does NOT filter by level, so without this guard a
    # CRITICAL record would be emitted even when logging is fully disabled.
    if not self.enabled or not self.logger.isEnabledFor(logging.CRITICAL):
        return
    filename, lineno = self._get_caller_info()
    record = self.logger.makeRecord(
        self.logger.name, logging.CRITICAL, filename, lineno, msg, args,
        kwargs.get("exc_info"),  # preserve exc_info support from **kwargs
    )
    self.logger.handle(record)


flexkv_logger = FlexkvLogger(os.getenv("FLEXKV_LOG_LEVEL", "INFO"))
Expand Down
134 changes: 134 additions & 0 deletions flexkv/common/memory_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,26 @@
from flexkv.common.debug import flexkv_logger


class cudaIpcMemHandle_t(ctypes.Structure):
    # ctypes mirror of CUDA's cudaIpcMemHandle_t: an opaque 64-byte blob.
    _fields_ = [("reserved", ctypes.c_byte * 64)]


def _load_cudart():
    # Locate the CUDA runtime without pinning a specific toolkit version:
    # try the unversioned soname first, then the major-versioned fallbacks.
    last_err = None
    for soname in ("libcudart.so", "libcudart.so.12", "libcudart.so.11"):
        try:
            return ctypes.CDLL(soname)
        except OSError as exc:  # narrow: only library-load failures
            last_err = exc
    raise last_err


# Load CUDA runtime library.
cudart = _load_cudart()

# CUDA IPC handle size (64 bytes on Linux).
CUDA_IPC_HANDLE_SIZE = 64

# CUDA error codes.
cudaSuccess = 0
# NOTE(review): cudaErrorInvalidValue is 1 in CUDA >= 10.1 (it was 11 in
# older releases) -- confirm against the toolkit version actually in use.
cudaErrorInvalidValue = 11

Expand Down Expand Up @@ -211,6 +231,120 @@ def _export_tensor_handle(
device = tensor.device
rebuild_func, rebuild_args = reductions.reduce_tensor(tensor)
return rebuild_func, rebuild_args, device

@staticmethod
def _export_cuda_ipc_handle(tensor: torch.Tensor) -> bytes:
    """Export the device allocation backing ``tensor`` as a raw CUDA IPC handle.

    Args:
        tensor: CUDA tensor whose storage should be shared across processes.

    Returns:
        The 64-byte ``cudaIpcMemHandle_t`` contents as ``bytes``.

    Raises:
        RuntimeError: if ``cudaIpcGetMemHandle`` fails.
    """
    data_ptr = tensor.data_ptr()
    device = tensor.device

    flexkv_logger.debug(f"Exporting CUDA IPC handle: device={device}, data_ptr={hex(data_ptr)}")

    # The handle must be created from the device that owns the allocation.
    torch.cuda.set_device(device)

    ipc_handle = cudaIpcMemHandle_t()
    result = cudart.cudaIpcGetMemHandle(
        ctypes.byref(ipc_handle),
        ctypes.c_void_p(data_ptr)
    )

    if result != cudaSuccess:
        error_msg = (f"cudaIpcGetMemHandle failed with error code {result} "
                     f"for device {device}, ptr={hex(data_ptr)}")
        flexkv_logger.error(error_msg)
        raise RuntimeError(error_msg)

    # Copy the opaque struct out as bytes so it can be pickled / sent over IPC.
    handle_bytes = ctypes.string_at(ctypes.byref(ipc_handle), CUDA_IPC_HANDLE_SIZE)
    flexkv_logger.debug(f"IPC handle exported successfully, first 16 bytes: {handle_bytes.hex()}")
    return handle_bytes

@staticmethod
def _import_cuda_ipc_handle(ipc_handle: bytes, shape: Tuple[int, ...],
                            dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    """Open a CUDA IPC handle and wrap the shared device memory as a tensor.

    Args:
        ipc_handle: 64-byte handle produced by ``_export_cuda_ipc_handle``.
        shape: shape of the exported tensor.
        dtype: torch dtype of the exported tensor.
        device: CUDA device the memory lives on.

    Returns:
        A ``torch.Tensor`` viewing the shared device memory (no copy).

    Raises:
        RuntimeError: if ``cudaIpcOpenMemHandle`` fails.
        ValueError: if ``dtype`` has no CUDA-array-interface typestr.
    """
    flexkv_logger.debug(f"Attempting to import CUDA IPC handle for device {device}")

    # CUDA must be initialized in this process before any runtime IPC call.
    if not torch.cuda.is_initialized():
        flexkv_logger.info("Initializing CUDA in subprocess")
        torch.cuda.init()

    device_id = device.index if device.index is not None else 0
    torch.cuda.set_device(device_id)

    # Allocate a throwaway tensor to force CUDA context creation.
    _ = torch.zeros(1, device=device)
    flexkv_logger.debug(f"CUDA context created for device {device_id}, current_device={torch.cuda.current_device()}")

    # Rebuild the opaque cudaIpcMemHandle_t struct from the raw bytes.
    handle = cudaIpcMemHandle_t()
    ctypes.memmove(ctypes.byref(handle), ipc_handle, CUDA_IPC_HANDLE_SIZE)

    # Map the remote allocation into this process's address space.
    dev_ptr = ctypes.c_void_p()
    result = cudart.cudaIpcOpenMemHandle(
        ctypes.byref(dev_ptr),
        handle,
        ctypes.c_int(1)  # cudaIpcMemLazyEnablePeerAccess = 1
    )
    if result != cudaSuccess:
        error_msg = f"cudaIpcOpenMemHandle failed with error code {result} for device {device_id}"
        flexkv_logger.error(error_msg)
        flexkv_logger.error(f"IPC handle bytes (first 16): {ipc_handle.hex()}")
        flexkv_logger.error(f"Current CUDA device: {torch.cuda.current_device()}")
        flexkv_logger.error(f"Target device: {device_id}")
        raise RuntimeError(error_msg)
    # Log only after the success check: on failure dev_ptr.value is None and
    # hex(None) would raise a TypeError that masks the real error.
    flexkv_logger.debug(f"import CUDA IPC handle: device={device}, dev_ptr={hex(dev_ptr.value)}")

    # torch dtype -> __cuda_array_interface__ typestr.
    typestr_map = {
        torch.float32: "<f4",
        torch.float64: "<f8",
        torch.float16: "<f2",
        # NOTE(review): bfloat16 has no standard array-interface typestr;
        # "<e" is numpy's float16 character -- confirm consumers accept it.
        torch.bfloat16: "<e",
        torch.int32: "<i4",
        torch.int64: "<i8",
        torch.int16: "<i2",
        torch.uint8: "|u1",
        torch.int8: "|i1",
        torch.bool: "|b1",
    }
    if dtype not in typestr_map:
        raise ValueError(f"Unsupported dtype for CUDA IPC import: {dtype}")

    class CudaArrayInterface:
        # Minimal adapter exposing __cuda_array_interface__ (version 3) so
        # torch.as_tensor can adopt the raw device pointer without copying.
        def __init__(self, data_ptr, shape, typestr, strides=None):
            self.__cuda_array_interface__ = {
                "data": (data_ptr, False),  # (pointer, read_only)
                "shape": tuple(shape),
                "typestr": typestr,
                "version": 3,
                "strides": strides,  # None => C-contiguous
                "descr": [("", "")],
            }

    cuda_interface = CudaArrayInterface(dev_ptr.value, shape, typestr_map[dtype])
    tensor = torch.as_tensor(cuda_interface, device=device)

    flexkv_logger.debug(f"Imported tensor with shape {shape} from CUDA IPC handle")
    return tensor

## Import tensor handle
@staticmethod
Expand Down
Loading