
Commit 1757460

[KVCache][Feature] implement storage prefetch ZMQ pipeline in Scheduler and Worker
Storage prefetch (Storage → CPU) previously had no runtime execution path in Scheduler/Worker: the Scheduler prepared host-block metadata, but the actual data transfer was never triggered. Workers also had no mechanism to receive prefetch commands or report completion, leaving LOADING_FROM_STORAGE blocks permanently stuck and never promoted to HOST.

CacheManager:
- Add `_prefetch_node_map: Dict[int, BlockNode]` to track in-flight blocks by host_block_id for O(1) status lookup.
- `prepare_prefetch_metadata`: register returned nodes into `_prefetch_node_map`.
- New `update_storage_blocks_to_host(host_block_ids)`: transition LOADING_FROM_STORAGE → HOST after all TP workers confirm the transfer is done.
- New `abort_prefetch_blocks(host_block_ids)`: remove nodes from the RadixTree and release host pool blocks on transfer failure.

Scheduler:
- Add per-worker ZMQ PUSH/PULL servers (`_prefetch_cmd_servers`, `_prefetch_done_servers`), one pair per TP worker, keyed by local_rank.
- `_init_prefetch_zmq_servers()`: initialize servers at startup when a storage backend is configured.
- `_prefetch_storage_cache()`: after inserting host blocks, serialize `StorageMetadata` and broadcast it to all TP workers via ZMQ PUSH; then poll the PULL done sockets until all workers reply, calling `update_storage_blocks_to_host` on success or `abort_prefetch_blocks` on failure.

ZMQ client:
- Add `receive_pyobj_once(block=False)`: non-blocking (or blocking) receive helper returning an `(error, data)` tuple; used by the Scheduler to poll done messages and by the Worker in the prefetch loop.

Worker:
- Add `init_prefetch_zmq_clients()`: connect ZMQ PULL/PUSH clients to the Scheduler servers for this worker's local_rank; start a daemon `_prefetch_loop` thread.
- `_prefetch_loop()`: background thread receiving `StorageMetadata` commands, calling `cache_controller.prefetch_from_storage`, waiting for `AsyncTaskHandler.wait`, and replying with ok/error status.

Tests:
- Add `TestUpdateStorageBlocksToHost` with 6 test cases covering: status transition, multi-block, unknown id, empty list, wrong status, and initial-empty-map assertions.

No additional build steps. Enable storage prefetch via existing config:

```bash
python -m fastdeploy.entrypoints.openai.api_server \
  --kvcache-storage-backend <backend> \
  --enable-prefix-caching \
  ...
```
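For orientation, here is a minimal sketch of the two message shapes exchanged over the prefetch channels. The field names are taken from the Scheduler-side diff below; the dummy values are illustrative only, and the reply shape is inferred from what `_prefetch_storage_cache` polls for, since the Worker-side file is not part of this view.

```python
# Sketch only, not part of the commit. Command (Scheduler PUSH → Worker PULL),
# broadcast to every TP rank of the DP group; the import path and constructor
# arguments match the Scheduler-side diff, the concrete values are made up.
from fastdeploy.cache_manager.v1.metadata import StorageMetadata

cmd_payload = {
    "request_id": "req-0",                # lets the Scheduler match replies to requests
    "metadata": StorageMetadata(
        hash_values=[123456789, 987654321],  # content hashes of the matched storage blocks
        block_ids=[17, 18],                  # destination host block ids
        direction="load",                    # storage → CPU
    ),
}

# Reply (Worker PUSH → Scheduler PULL), one per command per rank,
# as inferred from the Scheduler's polling code.
done_ok = {"request_id": "req-0", "status": "ok"}
done_failed = {"request_id": "req-0", "status": "error", "error": "transfer failed"}
```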
1 parent ed97063 commit 1757460

5 files changed

Lines changed: 412 additions & 10 deletions


fastdeploy/cache_manager/v1/cache_manager.py

Lines changed: 72 additions & 0 deletions
@@ -107,6 +107,10 @@ def __init__(
         self._pending_backup: List[Tuple[List[BlockNode], List[int]]] = []
         self._pending_block_ids: List[int] = []

+        # Mapping from host_block_id -> BlockNode for LOADING_FROM_STORAGE blocks,
+        # used to quickly update status to HOST once prefetch completes.
+        self._prefetch_node_map: Dict[int, BlockNode] = {}
+
         # Storage scheduler (create using factory method if backend is configured)
         self._storage_scheduler = create_storage_scheduler(self.cache_config)

@@ -1004,11 +1008,79 @@ def prepare_prefetch_metadata(
             if wasted_block_ids:
                 self._host_pool.release(wasted_block_ids)

+            # Register nodes in prefetch_node_map for fast status update on done
+            for node in prefetch_nodes:
+                self._prefetch_node_map[node.block_id] = node
+
             return prefetch_nodes
         except Exception as e:
             logger.error(f"prepare_prefetch_metadata error: {e}, {str(traceback.format_exc())}")
             return []

+    def update_storage_blocks_to_host(self, host_block_ids: List[int]) -> None:
+        """
+        Mark storage-prefetched blocks as HOST after data transfer completes.
+
+        Called by Scheduler when all TP workers report prefetch done for a batch
+        of blocks. Transitions block status LOADING_FROM_STORAGE → HOST so that
+        these blocks become eligible for swap-in scheduling.
+
+        Args:
+            host_block_ids: List of host block IDs that finished loading.
+        """
+        if not host_block_ids:
+            return
+        try:
+            with self._lock:
+                updated = 0
+                for block_id in host_block_ids:
+                    node = self._prefetch_node_map.pop(block_id, None)
+                    if node is None:
+                        logger.warning(
+                            f"[StoragePrefetch] update_storage_blocks_to_host: "
+                            f"block_id={block_id} not found in prefetch_node_map"
+                        )
+                        continue
+                    if node.cache_status == CacheStatus.LOADING_FROM_STORAGE:
+                        node.cache_status = CacheStatus.HOST
+                        updated += 1
+                    else:
+                        logger.warning(
+                            f"[StoragePrefetch] update_storage_blocks_to_host: "
+                            f"block_id={block_id} unexpected status={node.cache_status}"
+                        )
+                logger.info(
+                    f"[StoragePrefetch] update_storage_blocks_to_host: "
+                    f"requested={len(host_block_ids)}, updated={updated}"
+                )
+        except Exception as e:
+            logger.error(f"update_storage_blocks_to_host error: {e}, {str(traceback.format_exc())}")
+
+    def abort_prefetch_blocks(self, host_block_ids: List[int]) -> None:
+        """
+        Abort in-flight prefetch blocks on failure.
+
+        Removes nodes from the prefetch_node_map, deletes them from the RadixTree,
+        and releases their host pool blocks. Called when the storage→CPU transfer
+        fails so that LOADING_FROM_STORAGE blocks do not leak.
+
+        Args:
+            host_block_ids: List of host block IDs whose prefetch should be aborted.
+        """
+        if not host_block_ids:
+            return
+        try:
+            with self._lock:
+                for block_id in host_block_ids:
+                    node = self._prefetch_node_map.pop(block_id, None)
+                    if node is None:
+                        continue
+                    self._radix_tree._remove_node_from_tree(node)
+                self._host_pool.release(host_block_ids)
+                logger.warning(f"[StoragePrefetch] abort_prefetch_blocks: released {len(host_block_ids)} host blocks")
+        except Exception as e:
+            logger.error(f"abort_prefetch_blocks error: {e}, {str(traceback.format_exc())}")
+
     # ============ Reset Methods ============

     def reset_cache(self) -> bool:

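As a usage note, the intended lifecycle of the two new CacheManager methods can be sketched as follows; `settle_prefetch` is a hypothetical helper (not part of this commit) standing in for the Scheduler-side call site.

```python
def settle_prefetch(cache_manager, prefetch_nodes, transfer_ok: bool) -> None:
    """Drive the LOADING_FROM_STORAGE → HOST (or abort) transition for one batch.

    `prefetch_nodes` is whatever prepare_prefetch_metadata returned, so the nodes
    are already registered in _prefetch_node_map; `transfer_ok` is the aggregated
    result of the per-rank done replies.
    """
    host_block_ids = [node.block_id for node in prefetch_nodes]
    if transfer_ok:
        # All TP workers confirmed the copy: promote blocks so they become swap-in candidates.
        cache_manager.update_storage_blocks_to_host(host_block_ids)
    else:
        # Any failure: drop the nodes from the RadixTree and return the host blocks to the pool.
        cache_manager.abort_prefetch_blocks(host_block_ids)
```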
fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 134 additions & 10 deletions
@@ -22,17 +22,18 @@
 from collections.abc import Iterable
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field
-from typing import List, Union
+from typing import Dict, List, Set, Union

 import numpy as np
 import paddle
+import zmq

 from fastdeploy import envs
 from fastdeploy.cache_manager.multimodal_cache_manager import (
     EncoderCacheManager,
     ProcessorCacheManager,
 )
-from fastdeploy.cache_manager.v1.metadata import CacheSwapMetadata
+from fastdeploy.cache_manager.v1.metadata import CacheSwapMetadata, StorageMetadata
 from fastdeploy.engine.request import (
     BatchRequest,
     ImagePosition,
@@ -44,6 +45,7 @@
 from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.input.utils import IDS_TYPE_FLAG
 from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.inter_communicator.zmq_server import ZmqIpcServer
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.multimodal.hasher import MultimodalHasher
 from fastdeploy.platforms import current_platform
@@ -252,6 +254,16 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
         # Scheduler-side requests that have not been moved into resource manager waiting queue yet.
         self.scheduler_unhandled_request_num = 0

+        # ---- Storage Prefetch ZMQ channels (Scheduler side) ----
+        # Initialized only when storage backend is configured.
+        # One PUSH cmd socket + one PULL done socket per worker local_rank.
+        # local_rank = dp_rank * tp_size + tp_rank
+        self._prefetch_cmd_servers: Dict[int, ZmqIpcServer] = {}
+        self._prefetch_done_servers: Dict[int, ZmqIpcServer] = {}
+
+        if self.config.cache_config.kvcache_storage_backend and self.enable_cache_manager_v1:
+            self._init_prefetch_zmq_servers()
+
     def allocated_slots(self, request: Request):
         return len(request.block_tables) * self.config.cache_config.block_size

@@ -1264,6 +1276,29 @@ def apply_async_preprocess(self, request: Request) -> None:
                 self.async_preprocess_pool.submit(self._prefetch_storage_cache, request)
             )

+    def _init_prefetch_zmq_servers(self) -> None:
+        """
+        Initialize per-worker-rank ZMQ PUSH/PULL sockets for storage prefetch.
+
+        Called once during __init__ when storage backend is enabled.
+        Creates:
+            - prefetch_cmd_server[local_rank]: PUSH → Worker (send StorageMetadata)
+            - prefetch_done_server[local_rank]: PULL ← Worker (receive done notification)
+
+        local_rank = dp_rank * tp_size + tp_rank, covers all workers in this DP group.
+        """
+        tp_size = self.config.parallel_config.tensor_parallel_size
+        dp_rank = self.config.parallel_config.local_data_parallel_id
+        port = self.config.parallel_config.local_engine_worker_queue_port
+
+        for tp_rank in range(tp_size):
+            local_rank = dp_rank * tp_size + tp_rank
+            cmd_name = f"prefetch_cmd_rank{local_rank}_{port}"
+            done_name = f"prefetch_done_rank{local_rank}_{port}"
+            self._prefetch_cmd_servers[local_rank] = ZmqIpcServer(cmd_name, zmq.PUSH)
+            self._prefetch_done_servers[local_rank] = ZmqIpcServer(done_name, zmq.PULL)
+            llm_logger.info(f"[StoragePrefetch] init ZMQ servers: cmd={cmd_name}, done={done_name}")
+
     def _prefetch_storage_cache(self, request: Request) -> None:
         """
         Asynchronously prefetch KV cache blocks from storage to host memory.
@@ -1274,29 +1309,118 @@ def _prefetch_storage_cache(self, request: Request) -> None:
         2. Allocate host blocks for them.
         3. Insert those blocks into the RadixTree with LOADING_FROM_STORAGE status.

-        The actual data transfer (storage → host memory) is handled by the Worker
-        via cache_controller.prefetch_from_storage once the batch is dispatched.
+        Then immediately sends a StorageMetadata message to all TP Workers via ZMQ,
+        so Workers can start the actual storage→CPU transfer independently of forward.

         Args:
             request: The request to prefetch cache for.
         """
+        host_block_ids: List[int] = []
         try:
             if not self.cache_manager.enable_prefix_caching:
                 return
             llm_logger.debug(f"[StoragePrefetch] start async prefetch for request_id={request.request_id}")
             self.cache_manager.match_prefix(request, skip_storage=False)
             match_result = request.match_result
-            if match_result is not None:
-                request.match_result = None
+            request.match_result = None
+            if match_result is None or match_result.matched_storage_nums == 0:
+                return

-                llm_logger.info(
-                    f"[StoragePrefetch] request_id={request.request_id} "
-                    f"storage_matched={match_result.matched_storage_nums} blocks"
+            # Collect host_block_ids and hash_values from matched storage nodes
+            storage_nodes = match_result.storage_nodes
+            host_block_ids = [node.block_id for node in storage_nodes]
+            hash_values = [node.hash_value for node in storage_nodes]
+
+            llm_logger.info(
+                f"[StoragePrefetch] request_id={request.request_id} "
+                f"storage_matched={match_result.matched_storage_nums} blocks, "
+                f"host_block_ids={host_block_ids}"
+            )
+
+            if not self._prefetch_cmd_servers:
+                return
+
+            metadata = StorageMetadata(
+                hash_values=hash_values,
+                block_ids=host_block_ids,
+                direction="load",
+            )
+
+            # Build the payload with request_id for done matching
+            payload = {
+                "request_id": request.request_id,
+                "metadata": metadata,
+            }
+
+            # Send to all TP workers in this DP group
+            for local_rank, cmd_server in self._prefetch_cmd_servers.items():
+                try:
+                    cmd_server.send_pyobj(payload)
+                except Exception as e:
+                    llm_logger.error(f"[StoragePrefetch] failed to send cmd to rank={local_rank}: {e}")
+
+            # Block in this thread until all TP workers report done.
+            # This mirrors _download_features: the future is considered complete only
+            # when the actual storage→CPU transfer has finished on every worker.
+            expected_count = len(self._prefetch_cmd_servers)
+            done_ranks: Set[int] = set()
+            failed_ranks: Set[int] = set()
+            poll_interval = 0.001  # 1ms
+
+            while len(done_ranks) + len(failed_ranks) < expected_count:
+                for local_rank, done_server in self._prefetch_done_servers.items():
+                    if local_rank in done_ranks or local_rank in failed_ranks:
+                        continue
+                    err, msg = done_server.receive_pyobj_once(block=False)
+                    if err is not None:
+                        llm_logger.warning(
+                            f"[StoragePrefetch] done_server rank={local_rank} socket error: {err}, "
+                            f"request_id={request.request_id}"
+                        )
+                        failed_ranks.add(local_rank)
+                        continue
+                    if msg is None:
+                        continue
+                    recv_req_id = msg.get("request_id", "")
+                    if recv_req_id != request.request_id:
+                        # Message for a different request; skip and let that request's
+                        # thread poll its own done message. This should not normally happen
+                        # since each worker sends done to the same socket, but guard anyway.
+                        llm_logger.warning(
+                            f"[StoragePrefetch] rank={local_rank} received done for unexpected "
+                            f"request_id={recv_req_id}, expected={request.request_id}, skipping"
+                        )
+                        continue
+                    if msg.get("status") != "ok":
+                        llm_logger.warning(
+                            f"[StoragePrefetch] rank={local_rank} worker reported prefetch failure for "
+                            f"request_id={request.request_id}: {msg.get('error')}"
+                        )
+                        failed_ranks.add(local_rank)
+                        continue
+                    done_ranks.add(local_rank)
+
+                if len(done_ranks) + len(failed_ranks) < expected_count:
+                    time.sleep(poll_interval)
+
+            if failed_ranks:
+                llm_logger.warning(
+                    f"[StoragePrefetch] request_id={request.request_id} prefetch failed on "
+                    f"ranks={failed_ranks}, aborting {len(host_block_ids)} host blocks"
                 )
-                # TODO: check if any of the block is still LOADING_FROM_STORAGE, if so, request.async_process_futures.append(self._prefetch_storage_cache)
+                self.cache_manager.abort_prefetch_blocks(host_block_ids)
+                return
+
+            # All workers done successfully: update CacheManager block status to HOST
+            self.cache_manager.update_storage_blocks_to_host(host_block_ids)
+            llm_logger.info(
+                f"[StoragePrefetch] request_id={request.request_id} all {expected_count} TP workers done, "
+                f"updated {len(host_block_ids)} blocks to HOST"
+            )

         except Exception as e:
             llm_logger.error(f"[StoragePrefetch] request_id={request.request_id} error: {e}")
+            self.cache_manager.abort_prefetch_blocks(host_block_ids)

     def _has_features_info(self, task):
         inputs = task.multimodal_inputs

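The Worker-side counterpart is not shown in this section. As a rough, hedged sketch of the behaviour the commit message describes (`init_prefetch_zmq_clients()` plus a daemon `_prefetch_loop`), the following uses raw pyzmq instead of the project's ZMQ client wrappers; the `ipc://` endpoint paths, the `start_prefetch_loop` name, and the assumption that `prefetch_from_storage` returns a waitable handler are all illustrative, and only the PULL-command → prefetch → PUSH-reply flow follows the commit text.

```python
import threading

import zmq


def start_prefetch_loop(cache_controller, local_rank: int, port: int) -> threading.Thread:
    """Hypothetical Worker-side loop matching the Scheduler's prefetch protocol."""
    ctx = zmq.Context.instance()
    cmd_sock = ctx.socket(zmq.PULL)    # receives StorageMetadata commands from the Scheduler
    done_sock = ctx.socket(zmq.PUSH)   # reports per-request completion back to the Scheduler
    cmd_sock.connect(f"ipc:///tmp/prefetch_cmd_rank{local_rank}_{port}.ipc")    # assumed endpoint format
    done_sock.connect(f"ipc:///tmp/prefetch_done_rank{local_rank}_{port}.ipc")  # assumed endpoint format

    def _prefetch_loop():
        while True:
            msg = cmd_sock.recv_pyobj()  # blocking receive of {"request_id", "metadata"}
            reply = {"request_id": msg["request_id"], "status": "ok"}
            try:
                # Assumed to return an AsyncTaskHandler-style object with a wait() method.
                handler = cache_controller.prefetch_from_storage(msg["metadata"])
                handler.wait()           # block until the storage → CPU copy finishes
            except Exception as e:       # report any failure upstream instead of crashing the loop
                reply = {"request_id": msg["request_id"], "status": "error", "error": str(e)}
            done_sock.send_pyobj(reply)

    thread = threading.Thread(target=_prefetch_loop, daemon=True, name=f"prefetch-loop-{local_rank}")
    thread.start()
    return thread
```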
fastdeploy/inter_communicator/zmq_client.py

Lines changed: 22 additions & 0 deletions
@@ -143,6 +143,28 @@ def recv_pyobj(self, flags: int = 0):
             return envelope["data"]
         return envelope

+    def receive_pyobj_once(self, block=False):
+        """
+        Receive a single Pickle-serializable message from the socket.
+
+        Args:
+            block: If True, block until a message arrives. If False, return immediately.
+
+        Returns:
+            Tuple of (error, data). error is None on success, data is None if no message.
+        """
+        self._ensure_socket()
+        if self.socket is None or self.socket.closed:
+            return "zmq socket has closed", None
+        try:
+            flags = 0 if block else zmq.NOBLOCK
+            return None, self.recv_pyobj(flags=flags)
+        except zmq.Again:
+            return None, None
+        except Exception as e:
+            llm_logger.warning(f"[ZmqClient] receive_pyobj_once error: {e}")
+            return str(e), None
+
     @abstractmethod
     def close(self):
         pass

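A short usage sketch of the `(error, data)` contract: `client` stands for any ZMQ client instance exposing `receive_pyobj_once`, and the polling loop mirrors how the Scheduler waits for done messages.

```python
import time


def poll_once(client):
    """Return a message if one is queued, None otherwise; raise on socket errors."""
    err, msg = client.receive_pyobj_once(block=False)
    if err is not None:
        raise RuntimeError(f"socket error: {err}")  # closed socket or unexpected failure
    return msg                                      # None means "nothing queued yet"


def wait_for_message(client, poll_interval: float = 0.001):
    """Poll until a message arrives, sleeping briefly between attempts."""
    while True:
        msg = poll_once(client)
        if msg is not None:
            return msg
        time.sleep(poll_interval)  # avoid a busy-wait between polls
```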