@@ -1316,6 +1316,75 @@ def permute_device_kv(self, block_ids: list[int]):
13161316 )
13171317 cache .index_copy_ (0 , indices , permuted_blocks )
13181318
def blocksize_post_process(self, block_ids: list[int]):
    """Re-group freshly received KV-cache blocks whose remote peer uses a
    different block size than the local cache.

    For a local block size of 16 tokens and a remote block size of 4, one
    local block holds remote blocks [b0, b1, b2, b3].  The received bytes
    arrive in remote order |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|..., while the
    local HND layout wants each head contiguous: |h0-b0..|h1-b0..|....
    This rewrites every cache in ``self.device_kv_caches`` in place for
    the given blocks.

    Args:
        block_ids: local block indices just filled by a remote transfer.
    """

    def _regroup(blocks_to_update):
        # Convert remote layout -> local layout:
        #   1. view    => (N, n_blocks, H, remote_block_size, D)
        #   2. permute => (N, H, n_blocks, remote_block_size, D)
        #   3. flatten => (N, H, block_size, D)
        n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
        remote_block_size = int(block_size * self.block_size_ratio)
        n_blocks = int(1 / self.block_size_ratio)
        return (
            blocks_to_update.reshape(
                -1, n_blocks, n_kv_heads, remote_block_size, head_size
            )
            .permute(0, 2, 1, 3, 4)
            .flatten(2, 3)
        )

    # NOTE(review): the original selected between two helpers on
    # ``block_size_ratio < 1`` (local>remote vs. local<remote), but both
    # helpers were byte-identical.  The local<remote path presumably needs
    # the inverse permutation — TODO confirm with the remote-layout spec.
    # Behavior is preserved here: both cases use the same regrouping.
    split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
    sample_cache = list(self.device_kv_caches.values())[0][0]
    indices = torch.tensor(block_ids, device=sample_cache.device)

    for cache_or_caches in self.device_kv_caches.values():
        cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
        for cache in cache_list:
            blocks_to_update = cache.index_select(0, indices)
            # kv_cache always presents the virtual NHD shape, while the
            # physical stride may be HND; permute to the physical view
            # before regrouping, then permute back before writing.
            cache.index_copy_(
                0,
                indices,
                _regroup(blocks_to_update.permute(0, 2, 1, 3)).permute(0, 2, 1, 3),
            )
1387+
13191388 def get_finished (self ) -> tuple [set [str ], set [str ]]:
13201389 """
13211390 Get requests that are done sending or recving on this specific worker.
@@ -1339,11 +1408,18 @@ def get_finished(self) -> tuple[set[str], set[str]]:
13391408 )
13401409
13411410 # clean up metadata for completed requests
1411+ block_ids_for_blocksize_post_process = []
13421412 for req_id in done_recving :
1413+ meta = self ._recving_metadata .pop (req_id )
13431414 if self .use_host_buffer :
1344- meta = self ._recving_metadata .pop (req_id )
13451415 self .sync_recved_kv_to_device (req_id , meta )
13461416
1417+ # post-processing for heterogeneous block sizes (local vs. remote)
1418+ if self .block_size_ratio < 1 and self .kv_cache_layout == "HND" :
1419+ block_ids_for_blocksize_post_process += meta .local_block_ids
1420+ if len (block_ids_for_blocksize_post_process ) > 0 :
1421+ self .blocksize_post_process (block_ids_for_blocksize_post_process )
1422+
13471423 # Handle timeout to avoid stranding blocks on remote.
13481424 now = time .perf_counter ()
13491425 while self ._reqs_to_send :
0 commit comments