diff --git a/flexkv/common/memory_handle.py b/flexkv/common/memory_handle.py index d6d43b7c6c..03beabe3c5 100644 --- a/flexkv/common/memory_handle.py +++ b/flexkv/common/memory_handle.py @@ -13,16 +13,17 @@ class cudaIpcMemHandle_t(ctypes.Structure): - _fields_ = [('reserved', ctypes.c_byte * 64)] + _fields_ = [("reserved", ctypes.c_byte * 64)] + # Load CUDA runtime library try: - cudart = ctypes.CDLL('libcudart.so') -except OSError: + cudart = ctypes.CDLL("libcudart.so") +except: try: - cudart = ctypes.CDLL('libcudart.so.12') - except OSError: - cudart = ctypes.CDLL('libcudart.so.11') + cudart = ctypes.CDLL("libcudart.so.12") + except: + cudart = ctypes.CDLL("libcudart.so.11") # CUDA IPC handle size (64 bytes on Linux) CUDA_IPC_HANDLE_SIZE = 64 @@ -31,6 +32,7 @@ class cudaIpcMemHandle_t(ctypes.Structure): cudaSuccess = 0 cudaErrorInvalidValue = 11 + @dataclass class TensorSharedHandle: rebuild_func: Optional[Callable] @@ -42,6 +44,7 @@ class TensorSharedHandle: tensor_shape: Optional[Tuple[int, ...]] = None tensor_dtype: Optional[torch.dtype] = None tensor_numel: Optional[int] = None + offset: int = 0 def __init__( self, @@ -53,15 +56,13 @@ def __init__( tensor_dtype: Optional[ Union[torch.dtype, str] ] = None, # only used when data is bytes + offset: int = 0, # offset in bytes from base pointer (for memory pool allocations) ): """ Now we support three ways to construct TensorSharedHandle: - If data is a tensor that is managed by torch, we will use the reduce_tensor method - to export the TensorSharedHandle. - If data is a tensor that is allocated by cudamalloc, we will use the cudaIpcGetMemHandle - method to export the TensorSharedHandle. - If data is bytes-like, it means the memory has already been shared by CUDA IPC, - we will skip the export process to construct the TensorSharedHandle. + If data is a tensor that is managed by torch, we will use the reduce_tensor method to export the TensorSharedHandle. + If data is a tensor that is allocated by cudamalloc, we will use the cudaIpcGetMemHandle method to export the TensorSharedHandle. + If data is bytes-like, it means the memory has already been shared by CUDA IPC, we will skip the export process to construct the TensorSharedHandle. """ self.use_direct_ipc = False @@ -76,7 +77,7 @@ def __init__( elif isinstance(data, bytes): self._init_from_ipc_handle( - bytes(data), device_id, tensor_shape, tensor_dtype + bytes(data), device_id, tensor_shape, tensor_dtype, offset=offset ) return else: @@ -108,6 +109,9 @@ def _init_from_tensor( tmp_list = list(self.rebuild_args) tmp_list[6] = device_id self.rebuild_args = tuple(tmp_list) + flexkv_logger.debug( + f"Tensor exported via PyTorch CUDA IPC for device {self.device}" + ) return except RuntimeError as e: flexkv_logger.warning(f"PyTorch CUDA IPC export failed: {e}") @@ -125,8 +129,12 @@ def _init_from_tensor( ) self.rebuild_func = None self.rebuild_args = None + self.offset = 0 ## only used when constructing from direct ipc handle + flexkv_logger.info( + f"Tensor exported via direct CUDA IPC: tensor.device={tensor.device}, passed device_id={device_id}, final self.device={self.device}" + ) except Exception as e: - raise RuntimeError(f"Both PyTorch and direct CUDA IPC export failed: {e}") from e + raise RuntimeError(f"Both PyTorch and direct CUDA IPC export failed: {e}") def _init_from_ipc_handle( self, @@ -134,6 +142,7 @@ def _init_from_ipc_handle( device_id: int, tensor_shape: Optional[Tuple[int, ...]], tensor_dtype: Optional[Union[torch.dtype, str]], + offset: int = 0, ) -> None: if ipc_handle is None: raise ValueError("ipc_handle is required when constructing from external handle") @@ -147,7 +156,7 @@ def _init_from_ipc_handle( resolved_shape = tuple(int(dim) for dim in tensor_shape) resolved_dtype = self._ensure_torch_dtype(tensor_dtype) - self.use_direct_ipc = True + self.use_direct_ipc = True # must set to true when constructing from direct ipc handle self.ipc_handle = bytes(ipc_handle) self.tensor_shape = resolved_shape self.tensor_dtype = resolved_dtype @@ -158,6 +167,13 @@ def _init_from_ipc_handle( self.device = torch.device(f"cuda:{device_id}") self.rebuild_func = None self.rebuild_args = None + self.offset = offset + + + flexkv_logger.info( + f"TensorSharedHandle constructed from external IPC handle {self.ipc_handle.hex()} on device {self.device} \ + with shape {self.tensor_shape} and dtype {self.tensor_dtype}, ptr offset={offset}" + ) @staticmethod def _ensure_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: @@ -188,7 +204,7 @@ def _ensure_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: def get_tensor(self) -> torch.Tensor: if self.use_direct_ipc: return self._import_cuda_ipc_handle( - self.ipc_handle, self.tensor_shape, self.tensor_dtype, self.device + self.ipc_handle, self.tensor_shape, self.tensor_dtype, self.device, offset=self.offset ) else: return self._import_tensor_handle( @@ -204,119 +220,6 @@ def _export_tensor_handle( rebuild_func, rebuild_args = reductions.reduce_tensor(tensor) return rebuild_func, rebuild_args, device - @staticmethod - def _export_cuda_ipc_handle(tensor: torch.Tensor) -> bytes: - """ - 直接使用 CUDA IPC API 导出 tensor 的 IPC handle - """ - # Get device pointer - data_ptr = tensor.data_ptr() - device = tensor.device - - flexkv_logger.debug(f"Exporting CUDA IPC handle: device={device}, data_ptr={hex(data_ptr)}") - - # Ensure we're on the correct device - torch.cuda.set_device(device) - - # Create IPC handle buffer - # ipc_handle = ctypes.create_string_buffer(CUDA_IPC_HANDLE_SIZE) - ipc_handle = cudaIpcMemHandle_t() - - # Call cudaIpcGetMemHandle - result = cudart.cudaIpcGetMemHandle( - ctypes.byref(ipc_handle), - ctypes.c_void_p(data_ptr) - ) - - if result != cudaSuccess: - error_msg = f"cudaIpcGetMemHandle failed with error code {result} for device {device}, ptr={hex(data_ptr)}" - flexkv_logger.error(error_msg) - raise RuntimeError(error_msg) - - # Return handle as bytes - # handle_bytes = bytes(ipc_handle.raw) - handle_bytes = ctypes.string_at(ctypes.byref(ipc_handle), 64) - flexkv_logger.debug(f"IPC handle exported successfully, first 16 bytes: {handle_bytes.hex()}") - return handle_bytes - - @staticmethod - def _import_cuda_ipc_handle(ipc_handle: bytes, shape: Tuple[int, ...], - dtype: torch.dtype, device: torch.device) -> torch.Tensor: - """ - 直接使用 CUDA IPC API 从 handle 导入 tensor - """ - flexkv_logger.debug(f"Attempting to import CUDA IPC handle for device {device}") - - # Ensure CUDA is initialized in this process - if not torch.cuda.is_initialized(): - flexkv_logger.info("Initializing CUDA in subprocess") - torch.cuda.init() - - # Set device and create a dummy tensor to ensure context is created - device_id = device.index if device.index is not None else 0 - torch.cuda.set_device(device_id) - - # Force CUDA context creation - _ = torch.zeros(1, device=device) - - # Create IPC handle buffer - ipc_handle_buf = ctypes.create_string_buffer(ipc_handle, CUDA_IPC_HANDLE_SIZE) - - # 重建 IPC handle - handle = cudaIpcMemHandle_t() - ctypes.memmove(ctypes.byref(handle), ipc_handle, 64) - - # Open IPC memory handle - dev_ptr = ctypes.c_void_p() - result = cudart.cudaIpcOpenMemHandle( - ctypes.byref(dev_ptr), - handle, - ctypes.c_int(1) # cudaIpcMemLazyEnablePeerAccess = 1 - ) - flexkv_logger.debug(f"import CUDA IPC handle: device={device}, dev_ptr={hex(dev_ptr.value)}") - if result != cudaSuccess: - error_msg = f"cudaIpcOpenMemHandle failed with error code {result} for device {device_id}" - flexkv_logger.error(error_msg) - # flexkv_logger.error(f"IPC handle bytes (first 16): {ipc_handle[:16].hex()}") - flexkv_logger.error(f"IPC handle bytes (first 16): {ipc_handle.hex()}") - flexkv_logger.error(f"Current CUDA device: {torch.cuda.current_device()}") - flexkv_logger.error(f"Target device: {device_id}") - raise RuntimeError(error_msg) - - # Create tensor from pointer - numel = 1 - for dim in shape: - numel *= dim - - class CudaArrayInterface: - def __init__(self, data_ptr, shape, dtype, strides=None): - self.__cuda_array_interface__ = { - "data": (data_ptr, False), # (data_ptr, read_only) - "shape": tuple(shape), - "typestr": { - torch.float32: " torch.Tensor: + """ + Helper function to create a PyTorch tensor from a CUDA memory pointer. + + This function handles the special case of bfloat16 by using uint16 as an intermediate + type, since PyTorch's __cuda_array_interface__ doesn't support " bytes: @@ -384,12 +360,18 @@ def _import_cuda_ipc_handle( shape: Tuple[int, ...], dtype: torch.dtype, device: torch.device, + offset: int = 0, ) -> torch.Tensor: """ Using CUDA IPC API to import the tensor from the IPC handle + + Args: + ipc_handle: CUDA IPC memory handle (bytes) + shape: Tensor shape + dtype: Tensor dtype + device: Target CUDA device + offset: Offset in bytes from the base pointer (for memory pool allocations) """ - flexkv_logger.debug(f"Attempting to import CUDA IPC handle for device {device}") - # Ensure CUDA is initialized in this process if not torch.cuda.is_initialized(): flexkv_logger.info("Initializing CUDA in subprocess") @@ -401,9 +383,6 @@ def _import_cuda_ipc_handle( # Force CUDA context creation _ = torch.zeros(1, device=device) - flexkv_logger.debug( - f"CUDA context created for device {device_id}, current_device={torch.cuda.current_device()}" - ) # Create IPC handle buffer ipc_handle_buf = ctypes.create_string_buffer(ipc_handle, CUDA_IPC_HANDLE_SIZE) @@ -412,57 +391,38 @@ def _import_cuda_ipc_handle( handle = cudaIpcMemHandle_t() ctypes.memmove(ctypes.byref(handle), ipc_handle, 64) - # Open IPC memory handle - dev_ptr = ctypes.c_void_p() + # Open IPC memory handle to get base pointer + base_ptr = ctypes.c_void_p() result = cudart.cudaIpcOpenMemHandle( - ctypes.byref(dev_ptr), + ctypes.byref(base_ptr), handle, ctypes.c_int(1), # cudaIpcMemLazyEnablePeerAccess = 1 ) - flexkv_logger.debug( - f"import CUDA IPC handle: device={device}, dev_ptr={hex(dev_ptr.value)}" - ) + # Print GPU memory address for comparison with C++ side + if result != cudaSuccess: error_msg = f"cudaIpcOpenMemHandle failed with error code {result} for device {device_id}" flexkv_logger.error(error_msg) - # flexkv_logger.error(f"IPC handle bytes (first 16): {ipc_handle[:16].hex()}") - flexkv_logger.error(f"IPC handle bytes (first 16): {ipc_handle.hex()}") + flexkv_logger.error(f"IPC handle bytes (full): {ipc_handle.hex()}") flexkv_logger.error(f"Current CUDA device: {torch.cuda.current_device()}") flexkv_logger.error(f"Target device: {device_id}") raise RuntimeError(error_msg) - # Create tensor from pointer - numel = 1 - for dim in shape: - numel *= dim + # Calculate the actual data pointer: base_ptr + offset + data_ptr = base_ptr.value + offset + if offset > 0: + data_ptr_hex = hex(data_ptr) + base_ptr_hex = hex(base_ptr.value) + flexkv_logger.info( + f"_import_cuda_ipc_handle: Opened IPC handle: device={device}, base_gpu_ptr={base_ptr_hex}, offset={offset}, actual data_ptr={data_ptr_hex}" + ) + + # Create tensor from pointer using helper function + tensor = TensorSharedHandle._create_tensor_from_cuda_ptr( + data_ptr, shape, dtype, device + ) - class CudaArrayInterface: - def __init__(self, data_ptr, shape, dtype, strides=None): - self.__cuda_array_interface__ = { - "data": (data_ptr, False), # (data_ptr, read_only) - "shape": tuple(shape), - "typestr": { - torch.float32: "