Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions flexkv/integration/vllm/vllm_v1_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from flexkv.integration.stats import FlexKVStats
from flexkv.integration.utils import cdiv
from flexkv.integration.config import FlexKVConfig
from flexkv.transfer_manager import TransferManagerOnRemote

# vllm
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_tp_group

if TYPE_CHECKING:
from vllm.config import VllmConfig
Expand Down Expand Up @@ -517,8 +519,14 @@ def __init__(
flexkv_config: FlexKVConfig,
dp_client_id: int,
):
self.is_local_leader = get_tp_group().local_rank == 0
self.launch_remote_transfer_manager = get_tp_group().local_rank == 0 and \
get_tp_group().rank_in_group != 0
current_device_id = torch.cuda.current_device() + dp_client_id * flexkv_config.model_config.tp_size
self.flexkv_config = flexkv_config
if self.launch_remote_transfer_manager:
self.remote_transfer_manager_process = TransferManagerOnRemote.create_process()

logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.gpu_register_port}, \
dp_client_id: {dp_client_id}")
self.tp_client = KVTPClient(flexkv_config.gpu_register_port, dp_client_id, current_device_id)
Expand Down Expand Up @@ -554,6 +562,12 @@ def register_to_server(self, kv_caches: dict[str, torch.Tensor]):
self.tp_client.register_to_server(gpu_blocks, gpu_layout)
logger.info("Finish register kv_caches")

def __del__(self):
    """Best-effort cleanup of the remote TransferManager subprocess.

    Runs on garbage collection / interpreter shutdown, so it must never
    raise and must not block indefinitely.
    """
    # getattr guards against a partially constructed instance where
    # __init__ raised before the attribute was assigned.
    proc = getattr(self, "remote_transfer_manager_process", None)
    if proc is None:
        return
    try:
        # The subprocess is a long-running server and may never exit on
        # its own; ask it to stop first, otherwise an unbounded join()
        # could hang process shutdown forever.
        if proc.is_alive():
            proc.terminate()
        proc.join(timeout=5)
        proc.close()
    except Exception:
        # Never propagate from a destructor — during interpreter
        # teardown many globals may already be gone.
        pass
    finally:
        self.remote_transfer_manager_process = None

class FlexKVConnectorV1Impl:
def __init__(self, vllm_config: "VllmConfig", role: "KVConnectorRole"):
Expand Down
12 changes: 6 additions & 6 deletions flexkv/kvtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from flexkv.transfer_manager import TransferManagerHandle, TransferManagerOnRemote
from flexkv.common.request import KVResponseStatus, KVResponse
from flexkv.transfer_manager import get_master_host_and_ports_from_env
from flexkv.common.debug import flexkv_logger

class TaskStatus(Enum):
# slot mapping is not ready
Expand Down Expand Up @@ -93,22 +92,23 @@ def __init__(self,
self._check_config(model_config, cache_config)

self.is_multinode_tp = False
self.tp_size_per_node = model_config.tp_size
self.tp_size_per_node = min(model_config.tp_size, torch.cuda.device_count())

if self.model_config.tp_size > self.tp_size_per_node:
if self.model_config.tp_size != torch.cuda.device_count() * 2:
raise ValueError("Only support 2 nodes TP")
if self.model_config.dp_size != 1:
raise ValueError("Only support dp_size=1 for multi-node TP")
self.is_multinode_tp = True
self.tp_size_per_node = torch.cuda.device_count()

self.cache_engine = GlobalCacheEngine(cache_config, model_config)

model_config_for_transfer = copy.deepcopy(self.model_config)
if self.is_multinode_tp and not self.model_config.use_mla:
model_config_for_transfer.num_kv_heads = self.tp_size_per_node

if self.is_multinode_tp:
model_config_for_transfer.tp_size = self.tp_size_per_node
if not self.model_config.use_mla:
model_config_for_transfer.num_kv_heads = self.tp_size_per_node

combine_with_trtllm = os.getenv("FLEXKV_WITH_TRTLLM", "0") == "1"
if not combine_with_trtllm:
self.transfer_handles = [TransferManagerHandle(
Expand Down
49 changes: 25 additions & 24 deletions flexkv/transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from flexkv.transfer.transfer_engine import TransferEngine
from flexkv.server.utils import get_zmq_socket
from flexkv.server.request import RegisterTPClientRequest, Response
from flexkv.common.debug import flexkv_logger


class TransferManager:
Expand Down Expand Up @@ -60,7 +59,6 @@ def _handle_gpu_blocks_registration(self, req: RegisterTPClientRequest) -> None:
try:
self.all_gpu_blocks[device_id] = req.handles
self.all_gpu_layouts[device_id] = req.gpu_layout
flexkv_logger.info(f"GPU {device_id} registered successfully")
except Exception as e:
flexkv_logger.error(f"Failed to register GPU {device_id}: {e}")

Expand All @@ -81,6 +79,8 @@ def _register_gpu_blocks_via_socket(self) -> None:
if isinstance(req, RegisterTPClientRequest):
flexkv_logger.info(f"Received GPU blocks registration request: {type(req)}")
self._handle_gpu_blocks_registration(req)
flexkv_logger.info(f"GPU {req.device_id} registered successfully, \
waiting for {expected_gpus - len(self.all_gpu_blocks)} GPUs to register")
else:
flexkv_logger.error(f"Unrecognized RequestType in SchedulerServer: {type(req)}")

Expand Down Expand Up @@ -318,31 +318,32 @@ def shutdown(self) -> None:
def __del__(self) -> None:
    """Destructor: ensure shutdown() has run before the object is reclaimed.

    Uses getattr with a True default so a partially constructed instance
    (where __init__ raised before ``_shutdown_flag`` was assigned) is
    treated as "nothing to shut down" instead of raising AttributeError
    from a destructor.
    """
    if not getattr(self, "_shutdown_flag", True):
        self.shutdown()

@classmethod
def create_process(cls, **kwargs: Any) -> Process:
import tempfile
import os

# Serialize the class and kwargs
cls_data = pickle.dumps(cls)
kwargs_data = pickle.dumps(kwargs)

# Create temporary files for serialized data
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.cls') as f:
f.write(cls_data)
cls_file = f.name

with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.kwargs') as f:
f.write(kwargs_data)
kwargs_file = f.name

# Prepare environment - remove MPI-related variables to avoid conflicts
env = os.environ.copy()
# CRITICAL: Remove CUDA_VISIBLE_DEVICES to allow access to all GPUs
# TransferManager needs to access all physical GPUs for IPC
if 'CUDA_VISIBLE_DEVICES' in env:
flexkv_logger.info(f"Removing CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']} for TransferManager subprocess")
flexkv_logger.info(f"Removing CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']} "
"for TransferManager subprocess")
env.pop('CUDA_VISIBLE_DEVICES', None)

# Create the subprocess script
Expand All @@ -351,26 +352,26 @@ def create_process(cls, **kwargs: Any) -> Process:
import sys
import pickle
import tempfile

# Immediately disable MPI to avoid conflicts
os.environ['MPI4PY_RC_INITIALIZE'] = 'false'

try:
# Load the class and kwargs
with open("{cls_file}", "rb") as f:
cls = pickle.load(f)

with open("{kwargs_file}", "rb") as f:
kwargs = pickle.load(f)

# Create and start TransferManager instance
instance = cls(**kwargs)
instance.start()

# Keep running until worker thread exits
if hasattr(instance, '_worker_thread') and instance._worker_thread is not None:
instance._worker_thread.join()

except Exception as e:
print(f"Error in TransferManager subprocess: {{e}}", file=sys.stderr)
sys.exit(1)
Expand All @@ -382,13 +383,13 @@ def create_process(cls, **kwargs: Any) -> Process:
except Exception:
pass
''').strip()

# Start the subprocess
process = subprocess.Popen([
sys.executable, '-c', transfer_manager_script
], env=env, stdout=None, stderr=None, text=True) # None = inherit parent's stdout/stderr
flexkv_logger.info(f"TransferManager subprocess started, PID: {process.pid}")

# Clean up temporary files after subprocess completes
def cleanup_files():
# Wait for subprocess to complete before cleaning up files
Expand All @@ -398,26 +399,26 @@ def cleanup_files():
os.unlink(kwargs_file)
except Exception:
pass

import threading
cleanup_thread = threading.Thread(target=cleanup_files, daemon=True)
cleanup_thread.start()

# Return a wrapper that mimics multiprocessing.Process interface
class SubprocessWrapper:
def __init__(self, popen_process):
self._popen = popen_process
self.pid = popen_process.pid

def is_alive(self):
return self._popen.poll() is None

def terminate(self):
self._popen.terminate()

def join(self, timeout=None):
return self._popen.wait(timeout)

def close(self):
# Close the subprocess pipes
if self._popen.stdout:
Expand All @@ -426,7 +427,7 @@ def close(self):
self._popen.stderr.close()
if self._popen.stdin:
self._popen.stdin.close()

return SubprocessWrapper(process)

class TransferManagerHandleBase(ABC):
Expand Down Expand Up @@ -500,7 +501,7 @@ def __init__(self,
def _start_process(self) -> None:
if self.process is not None and self.process.is_alive():
return

self.process = self.mp_ctx.Process(
target=self._process_worker,
args=(self.model_config,
Expand Down