Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions examples/trtllm_adaption/launch.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
mkdir -p logs
TIMESTAMP=$(date +%Y.%m.%d-%H:%M:%S)
MODEL_PATH=${1:-YOUR_MODEL_PATH}

BATCH_SIZE=4
Expand All @@ -20,4 +18,4 @@ trtllm-serve serve $MODEL_PATH \
--max_seq_len $MAX_SEQ_LEN \
--max_num_tokens $MAX_NUM_TOKENS \
--max_batch_size $BATCH_SIZE \
--extra_llm_api_options extra-llm-api-config.yml 2>&1 | tee logs/$TIMESTAMP.log
--extra_llm_api_options extra-llm-api-config.yml
54 changes: 54 additions & 0 deletions examples/trtllm_adaption/multi_node_launch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Multi-node launch script for trtllm-serve with FlexKV enabled.
# Usage: ./multi_node_launch.sh [MODEL_PATH]
# Requires an MPI hostfile listing all participating nodes (HOSTFILE below).

# Serving / parallelism knobs. EP size mirrors TP size (16 ranks across 2 nodes).
BATCH_SIZE=4
TP_SIZE=16
EP_SIZE=$TP_SIZE
MAX_SEQ_LEN=155648
MAX_NUM_TOKENS=16384
# MAX_SEQ_LEN=8192
# MAX_NUM_TOKENS=8192
# NOTE(review): placeholders — must be replaced with real paths before use.
HOSTFILE=YOUR_HOSTFILE
MODEL_PATH=${1:-YOUR_MODEL_PATH}

# FlexKV wiring: config file path plus master / TRT-subprocess endpoints.
# NOTE(review): hard-coded IPs/ports below are deployment-specific — confirm.
export FLEXKV_CONFIG_PATH=$(realpath "./flexkv_config.json")
export TENSORRT_LLM_USE_FLEXKV=1
export FLEXKV_MASTER_HOST="172.16.0.30"
export FLEXKV_MASTER_PORTS="5556,5557,5558"
export FLEXKV_TRT_SUBPROCESS_HOST="172.16.0.30"
export FLEXKV_TRT_SUBPROCESS_PORTS="6667,6668,6669"
export TLLM_LOG_FIRST_RANK_ONLY=0

# Launch 16 ranks over MPI (ssh on port 9898, TCP over eth0), forwarding the
# CUDA/NCCL/FlexKV environment to every rank, then start trtllm-serve under
# trtllm-llmapi-launch on port 6000.
mpirun -np 16 \
--hostfile $HOSTFILE \
-mca plm_rsh_args "-p 9898" \
-mca btl tcp,self \
-mca btl_tcp_if_include eth0 \
-x CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
-x GLOO_SOCKET_IFNAME=eth0 \
-x NCCL_DEBUG=INFO \
-x NCCL_IBEXT_DISABLE=0 \
-x NCCL_IB_GID_INDEX=3 \
-x NCCL_IB_DISABLE=0 \
-x NCCL_NET_GDR_LEVEL=2 \
-x NCCL_IB_QPS_PER_CONNECTION=4 \
-x NCCL_IB_TC=160 \
-x NCCL_IB_TIMEOUT=22 \
-x NCCL_SOCKET_IFNAME=eth0 \
-x OMPI_MCA_btl=tcp,self \
-x OMPI_MCA_btl_tcp_if_include=eth0 \
-x FLEXKV_CONFIG_PATH \
-x TENSORRT_LLM_USE_FLEXKV \
-x FLEXKV_MASTER_HOST \
-x FLEXKV_MASTER_PORTS \
-x TLLM_LOG_FIRST_RANK_ONLY \
-x FLEXKV_TRT_SUBPROCESS_HOST \
-x FLEXKV_TRT_SUBPROCESS_PORTS \
--allow-run-as-root \
trtllm-llmapi-launch trtllm-serve $MODEL_PATH \
--host 0.0.0.0 \
--port 6000 \
--backend pytorch \
--tp_size $TP_SIZE \
--ep_size $EP_SIZE \
--max_seq_len $MAX_SEQ_LEN \
--max_num_tokens $MAX_NUM_TOKENS \
--max_batch_size $BATCH_SIZE \
--extra_llm_api_options extra-llm-api-config.yml
18 changes: 10 additions & 8 deletions flexkv/integration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from flexkv.common.debug import flexkv_logger
from flexkv.common.config import *
from transformers import AutoConfig as HFAutoConfig

if TYPE_CHECKING:
from vllm.v1.kv_cache_interface import KVCacheConfig, FullAttentionSpec
Expand Down Expand Up @@ -119,9 +118,6 @@ def post_init_from_sglang_config(
def post_init_from_trt_config(
self,
config,
tp_size: int,
dp_size: int,
dp_rank: int,
):
self.cache_config.tokens_per_block = config.tokens_per_block
# Convert dtype string to torch.dtype
Expand All @@ -141,13 +137,19 @@ def post_init_from_trt_config(
self.model_config.dtype = dtype_map.get(dtype_str, torch.bfloat16)
else:
self.model_config.dtype = dtype_str

self.model_config.tp_size = tp_size
self.model_config.dp_size = dp_size
self.model_config.dp_rank = dp_rank

# Set model config (parallel configs part)
if config.mapping.enable_attention_dp:
self.model_config.tp_size = 1
self.model_config.dp_size = config.mapping.tp_size
else:
self.model_config.tp_size = config.mapping.tp_size
self.model_config.dp_size = 1

# self.model_config (model configs part)
try:
model_path = getattr(config, 'hf_model_dir', None)
from transformers import AutoConfig as HFAutoConfig
hf_config = HFAutoConfig.from_pretrained(
str(model_path),
trust_remote_code=True
Expand Down
103 changes: 89 additions & 14 deletions flexkv/integration/tensorrt_llm/trtllm_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np
import torch

import traceback
from flexkv.kvmanager import KVManager
from flexkv.server.client import KVTPClient
from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType
Expand All @@ -18,8 +18,9 @@
from flexkv.integration.config import FlexKVConfig
from flexkv.integration.tensorrt_llm.meta import(
FlexKVResponse, FlexKVTask, FlexKVGetTask, FlexKVPutTask, FlexKVConnectorMetadata)
from flexkv.transfer_manager import TransferManagerOnRemote

from flexkv.integration.tensorrt_llm.utils import RequestWrapper, get_dp_tp_info
from flexkv.integration.tensorrt_llm.utils import RequestWrapper
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.bindings.executor import ExecutorConfig
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
Expand All @@ -30,9 +31,9 @@ class FlexKVSchedulerConnector(KvCacheConnectorScheduler):
def __init__(self, config: ExecutorConfig):
# Set environment variable for FlexKV with TensorRT-LLM,this must before KVManager initialization
os.environ['FLEXKV_WITH_TRTLLM'] = '1'
tp_size, dp_size, dp_rank = get_dp_tp_info(config)
flexkv_config = FlexKVConfig.from_env()
flexkv_config.post_init_from_trt_config(config, tp_size, dp_size, dp_rank)
flexkv_config = FlexKVConfig.from_env()
flexkv_config.post_init_from_trt_config(config)
self.node_rank, self.tp_rank, self.dp_rank = get_rank_info_from_trt_config(config)

flexkv_logger.info(f"Start init FlexKVSchedulerConnector with {flexkv_config}")
self.server_recv_port = flexkv_config.server_recv_port
Expand All @@ -44,7 +45,7 @@ def __init__(self, config: ExecutorConfig):
self.flexkv_manager = KVManager(model_config=self.model_config,
cache_config=self.cache_config,
server_recv_port=flexkv_config.server_recv_port,
dp_client_id=dp_rank)
dp_client_id=self.dp_rank)
self.flexkv_manager.start()
# self.dp_client = KVDPClient(self.server_recv_port, self.model_config)

Expand All @@ -61,8 +62,6 @@ def __init__(self, config: ExecutorConfig):

flexkv_logger.info("Finish init FlexKVSchedulerConnector")



def is_ready(
self,
) -> bool:
Expand Down Expand Up @@ -449,17 +448,47 @@ def dp_client_id(self) -> int:

class FlexKVWorkerConnector(KvCacheConnectorWorker):
def __init__(self, config: ExecutorConfig):
tp_size, dp_size, dp_rank = get_dp_tp_info(config)
flexkv_config = FlexKVConfig.from_env()
flexkv_config.post_init_from_trt_config(config, tp_size, dp_size, dp_rank)
dp_client_id = dp_rank
self.node_rank, self.tp_rank, self.dp_rank = get_rank_info_from_trt_config(config)
flexkv_config.post_init_from_trt_config(config)
dp_client_id = self.dp_rank

current_device_id = torch.cuda.current_device() + dp_client_id * flexkv_config.model_config.tp_size
self.flexkv_config = flexkv_config
flexkv_logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.gpu_register_port}, \
dp_client_id: {dp_client_id}")

# For multi-node TP on remote node (node B), worker0 needs to create TransferManagerOnRemote process
self.remote_process = None
if self._need_to_create_remote_process():
flexkv_logger.info("Multi-node TP detected on remote node, worker0 creating TransferManagerOnRemote process")
self.remote_process = TransferManagerOnRemote.create_process()
flexkv_logger.info(f"TransferManagerOnRemote process created, PID: {self.remote_process.pid}")

flexkv_logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.gpu_register_port}, dp_client_id: {dp_client_id}")
self.tp_client = KVTPClient(flexkv_config.gpu_register_port, dp_client_id, current_device_id)
flexkv_logger.info("Finish init FlexKVWorkerConnector")

def _need_to_create_remote_process(self) -> bool:
"""Check if need to create TransferManagerOnRemote process.

Returns True when all of the following conditions are met:
- Multi-node TP is detected (tp_size > gpus_per_node)
- Current node is not master node (node_rank > 0)
- Current worker is worker0 in TP group (tp_rank == 0)

Returns:
bool: True if need to create TransferManagerOnRemote process, False otherwise.
"""
try:
is_master_node = self.node_rank == 0
is_first_worker = self.tp_rank % 8 == 0
is_multinode_tp = self.flexkv_config.model_config.tp_size > torch.cuda.device_count()
flexkv_logger.info(f"{is_master_node=}, {is_first_worker=}, {is_multinode_tp=}")

return is_multinode_tp and not is_master_node and is_first_worker
except Exception as e:
flexkv_logger.error(f"Failed to get node info from flexkv_config: {e}")
flexkv_logger.error(traceback.format_exc())
raise e

def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
# vllm kv_caches: dict[str, torch.Tensor]
Expand Down Expand Up @@ -543,4 +572,50 @@ def get_finished(

finished_sending = self._metadata.finished_sending
finished_recving = self._metadata.finished_recving
return finished_sending, finished_recving
return finished_sending, finished_recving

def shutdown(self) -> None:
"""Shutdown the worker connector and cleanup resources.

For node B worker0, this will shutdown the TransferManagerOnRemote process.
"""
if hasattr(self, 'remote_process') and self.remote_process is not None:
flexkv_logger.info("Shutting down TransferManagerOnRemote process on node B worker0")
try:
# Terminate the process
if hasattr(self.remote_process, '_popen'):
self.remote_process._popen.terminate()
self.remote_process._popen.wait(timeout=5.0)
if self.remote_process._popen.poll() is None:
flexkv_logger.warning("TransferManagerOnRemote process did not terminate, killing it")
self.remote_process._popen.kill()
self.remote_process._popen.wait()
else:
# Fallback for subprocess.Popen
self.remote_process.terminate()
self.remote_process.wait(timeout=5.0)
if self.remote_process.poll() is None:
flexkv_logger.warning("TransferManagerOnRemote process did not terminate, killing it")
self.remote_process.kill()
self.remote_process.wait()
flexkv_logger.info("TransferManagerOnRemote process shutdown complete")
except Exception as e:
flexkv_logger.error(f"Error shutting down TransferManagerOnRemote process: {e}")
finally:
self.remote_process = None

def __del__(self):
"""Cleanup on deletion."""
if hasattr(self, 'remote_process') and self.remote_process is not None:
self.shutdown()

def get_rank_info_from_trt_config(config: ExecutorConfig):
    """Derive (node_rank, tp_rank, dp_rank) from a TRT-LLM executor config.

    With attention-DP enabled the mapping's global rank doubles as the DP
    rank; otherwise there is a single DP group and the DP rank is always 0.
    node_rank and tp_rank come straight from the mapping in both cases.
    """
    mapping = config.mapping
    dp_rank = mapping.rank if mapping.enable_attention_dp else 0
    return mapping.node_rank, mapping.tp_rank, dp_rank
16 changes: 1 addition & 15 deletions flexkv/integration/tensorrt_llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,4 @@ def __repr__(self):
return f"RequestWrapper(req_id={self.req_id}, " \
f"num_prompt_tokens={self.num_prompt_tokens}, " \
f"num_tokens={len(self.all_token_ids)}, " \
f"num_new_matched_tokens={self.num_new_matched_tokens})"

def get_dp_tp_info(config: ExecutorConfig):
    """Extract (tp_size, dp_size, dp_rank) from a TRT-LLM executor config.

    When attention-DP is enabled, the mapping's tp_size is reinterpreted as
    the DP world size and the per-rank TP size is treated as 1; otherwise
    there is a single DP group (dp_size=1, dp_rank=0).
    """
    mapping = config.mapping

    if mapping.enable_attention_dp:
        # TRT also does not support enabling TP and DP at the same time
        tp_size = 1
        dp_size = mapping.tp_size
        dp_rank = mapping.rank
    else:
        tp_size = mapping.tp_size
        dp_size = 1
        dp_rank = 0
    return tp_size, dp_size, dp_rank
f"num_new_matched_tokens={self.num_new_matched_tokens})"
9 changes: 6 additions & 3 deletions flexkv/kvtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
from flexkv.cache.cache_engine import GlobalCacheEngine
from flexkv.transfer_manager import TransferManagerHandle, TransferManagerOnRemote
from flexkv.common.request import KVResponseStatus, KVResponse
from flexkv.transfer_manager import get_master_host_and_ports_from_env
from flexkv.transfer_manager import (
get_master_host_and_ports_from_env,
get_trtllm_subprocess_host_and_ports_from_env
)

class TaskStatus(Enum):
# slot mapping is not ready
Expand Down Expand Up @@ -120,8 +123,8 @@ def __init__(self,
# When using FlexKV with TensorRT-LLM, we use remote mode to transfer data
# to avoid the way we launch subprocess in FlexKV
# conflict with TensorRT-LLM's MPI initialization
self.remote_process = TransferManagerOnRemote.create_process()
master_host, master_ports = get_master_host_and_ports_from_env()
master_host, master_ports = get_trtllm_subprocess_host_and_ports_from_env()
self.remote_process = TransferManagerOnRemote.create_process(mode="TrtllmSubprocess")
self.transfer_handles = [
TransferManagerHandle(
model_config_for_transfer,
Expand Down
Loading