
Commit 414215a

naive post process for HND
Signed-off-by: Chendi Xue <[email protected]>
1 parent 87b1de8 commit 414215a

1 file changed: vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Lines changed: 77 additions & 1 deletion
@@ -1316,6 +1316,75 @@ def permute_device_kv(self, block_ids: list[int]):
                 )
                 cache.index_copy_(0, indices, permuted_blocks)
 
+    def blocksize_post_process(self, block_ids: list[int]):
+        def _process_local_gt_remote(blocks_to_update):
+            n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
+            remote_block_size = int(block_size * self.block_size_ratio)
+            n_blocks = int(1 / self.block_size_ratio)
+            # Handles local blocksize > remote blocksize,
+            # e.g. local blocksize = 16 tokens, remote blocksize = 4 tokens,
+            # so local block0 = remote [block0, 1, 2, 3].
+            # remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
+            # local is  |h0-b0..................|h1-b0..................|...
+            # The permute:
+            # 1. view => read the block as n_blocks remote chunks of (H, remoteN, D)
+            # 2. permute => (H, n_blocks, remoteN, D)
+            # 3. flatten => (H, n_blocks * remoteN, D)
+            permuted_blocks = (
+                blocks_to_update.reshape(
+                    -1, n_blocks, n_kv_heads, remote_block_size, head_size
+                )
+                .permute(0, 2, 1, 3, 4)
+                .flatten(2, 3)
+            )
+            return permuted_blocks
+
+        def _process_local_lt_remote(blocks_to_update):
+            # Placeholder for local blocksize < remote blocksize. As of this
+            # commit it mirrors _process_local_gt_remote and is never invoked,
+            # since blocksize_post_process is only called when
+            # self.block_size_ratio < 1.
+            n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
+            remote_block_size = int(block_size * self.block_size_ratio)
+            n_blocks = int(1 / self.block_size_ratio)
+            permuted_blocks = (
+                blocks_to_update.reshape(
+                    -1, n_blocks, n_kv_heads, remote_block_size, head_size
+                )
+                .permute(0, 2, 1, 3, 4)
+                .flatten(2, 3)
+            )
+            return permuted_blocks
+
+        split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
+        sample_cache = list(self.device_kv_caches.values())[0][0]
+        indices = torch.tensor(block_ids, device=sample_cache.device)
+        fn = (
+            _process_local_gt_remote
+            if self.block_size_ratio < 1
+            else _process_local_lt_remote
+        )
+
+        for _, cache_or_caches in self.device_kv_caches.items():
+            cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
+            for cache in cache_list:
+                blocks_to_update = cache.index_select(0, indices)
+                # The kv_cache always uses the original NHD layout as its
+                # virtual shape, while the strides may be either HND or NHD
+                # at initialization, so first take the physical view of the
+                # tensor, fix it up, then permute back to the virtual shape.
+                cache.index_copy_(
+                    0,
+                    indices,
+                    fn(blocks_to_update.permute(0, 2, 1, 3)).permute(0, 2, 1, 3),
+                )
+
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
         Get requests that are done sending or recving on this specific worker.
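For intuition, here is a minimal standalone sketch (not part of the commit) that checks the gt-case permute on a toy tensor. It assumes block_size_ratio = remote_block_size / local_block_size and a physical HND view of one local block, matching _process_local_gt_remote above; all sizes and names are illustrative:

import torch

# Toy sizes: one local block of 8 tokens arriving as 2 remote blocks of
# 4 tokens, 2 KV heads, head_size 1 (block_size_ratio assumed remote/local).
n_kv_heads, block_size, head_size = 2, 8, 1
block_size_ratio = 0.5
remote_block_size = int(block_size * block_size_ratio)  # 4
n_blocks = int(1 / block_size_ratio)                    # 2

# Desired local HND layout: token t of head h holds the value 10 * h + t.
expected = (
    torch.stack([torch.arange(block_size) + 10 * h for h in range(n_kv_heads)])
    .reshape(1, n_kv_heads, block_size, head_size)
    .float()
)

# Simulate what the remote wrote: n_blocks HND chunks of remote_block_size
# tokens each, concatenated, then read back through the local HND view.
as_written = (
    expected.reshape(1, n_kv_heads, n_blocks, remote_block_size, head_size)
    .permute(0, 2, 1, 3, 4)
    .reshape(1, n_kv_heads, block_size, head_size)
)

# The commit's fix-up: view as remote chunks, regroup per head, flatten.
fixed = (
    as_written.reshape(-1, n_blocks, n_kv_heads, remote_block_size, head_size)
    .permute(0, 2, 1, 3, 4)
    .flatten(2, 3)
)
assert torch.equal(fixed, expected)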
@@ -1339,11 +1408,18 @@ def get_finished(self) -> tuple[set[str], set[str]]:
         )
 
         # clean up metadata for completed requests
+        block_ids_for_blocksize_post_process = []
         for req_id in done_recving:
+            meta = self._recving_metadata.pop(req_id)
             if self.use_host_buffer:
-                meta = self._recving_metadata.pop(req_id)
                 self.sync_recved_kv_to_device(req_id, meta)
 
+            # post processing for hetero blocksize
+            if self.block_size_ratio < 1 and self.kv_cache_layout == "HND":
+                block_ids_for_blocksize_post_process += meta.local_block_ids
+        if len(block_ids_for_blocksize_post_process) > 0:
+            self.blocksize_post_process(block_ids_for_blocksize_post_process)
+
         # Handle timeout to avoid stranding blocks on remote.
         now = time.perf_counter()
         while self._reqs_to_send:
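
As an aside on the inline comment about virtual vs. physical layout: below is a minimal sketch (illustrative only, not vllm's actual allocation code) of a cache whose physical layout is HND but which is addressed through an NHD virtual shape, showing how permute(0, 2, 1, 3) recovers the contiguous physical view that blocksize_post_process operates on:

import torch

# Assumed toy dimensions; the real cache shape comes from the model config.
num_blocks, block_size, n_kv_heads, head_size = 4, 16, 2, 8

# Allocate physically as HND...
physical_hnd = torch.zeros(num_blocks, n_kv_heads, block_size, head_size)

# ...but expose the NHD virtual shape the rest of the code expects.
virtual_nhd = physical_hnd.permute(0, 2, 1, 3)
assert virtual_nhd.shape == (num_blocks, block_size, n_kv_heads, head_size)
assert not virtual_nhd.is_contiguous()  # strides are still HND

# permute(0, 2, 1, 3) round-trips to the contiguous physical HND view,
# which is the view that fn(...) in the diff reshapes and flattens.
assert virtual_nhd.permute(0, 2, 1, 3).is_contiguous()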
