Skip to content

Commit 939dfa4

Browse files
authored
[BugFix][Cherry-Pick] Cp fix eb5 prefix cache(#5879) (#5881)
* fix eb5 prefix bug * update code * update code * update code * update code
1 parent ed3db9d commit 939dfa4

4 files changed

Lines changed: 124 additions & 389 deletions

File tree

fastdeploy/cache_manager/prefix_cache_manager.py

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,66 +1272,6 @@ def hash_block_features(self, input_ids, extra_keys: list = []):
12721272
"""
12731273
return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()
12741274

1275-
def _revert_match_blocks(
    self,
    request,
    matched_token_num: int,
    block_size: int,
    chunk_idx: int,
    match_node_ids: list,
    matche_nodes: list,
    match_gpu_block_ids: list,
    match_cpu_block_ids: list,
    gpu_match_token_num: int,
    cpu_match_token_num: int,
    swap_node_ids: list,
):
    """
    Undo previously matched prefix-cache blocks, newest first.

    The list arguments (`match_node_ids`, `matche_nodes`,
    `match_gpu_block_ids`, `match_cpu_block_ids`, `swap_node_ids`) are
    modified in place; the token counters are plain ints and are handed
    back through the return value.

    Args:
        request: request being matched; only `request_id` is read (for logs).
        matched_token_num: total matched tokens; all of them are reverted.
        block_size: number of tokens covered by one block.
        chunk_idx: currently unused (see the TODO below).
        match_node_ids: node ids of the matched nodes.
        matche_nodes: matched nodes, in match order; popped from the end.
        match_gpu_block_ids: block ids matched on GPU.
        match_cpu_block_ids: block ids matched on CPU.
        gpu_match_token_num: tokens matched on GPU before the revert.
        cpu_match_token_num: tokens matched on CPU before the revert.
        swap_node_ids: ids scheduled for swapping.

    Returns:
        Tuple `(gpu_match_token_num, cpu_match_token_num, current_node)`
        where `current_node` is the last remaining matched node, or
        `self.radix_tree_root` when no matched node is left.
    """
    # TODO(chengyanfu): fix when is_chunked_mm_input=True, revert all matched tokens
    # (the intended behaviour is to revert only the tokens past the start of
    # the chunked multimodal item selected by chunk_idx).
    revert_tokens = matched_token_num
    match_block_ids = [node.block_id for node in matche_nodes]
    logger.warning(
        f"match_block: req_id {request.request_id} revert tokens: {revert_tokens} from matched nodes: {match_block_ids}"
    )

    # Drop whole blocks from the tail of the match.
    while revert_tokens >= block_size:
        if not matche_nodes:
            logger.error(f"req_id {request.request_id} revert nodes error, tokens: {revert_tokens}")
            break
        revert_tokens -= block_size
        revert_block = matche_nodes.pop()
        revert_block_id = revert_block.block_id
        if revert_block_id in match_gpu_block_ids:
            match_gpu_block_ids.remove(revert_block_id)
            match_node_ids.remove(revert_block.node_id)
            gpu_match_token_num -= block_size
        elif revert_block_id in match_cpu_block_ids:
            match_cpu_block_ids.remove(revert_block_id)
            match_node_ids.remove(revert_block.node_id)
            cpu_match_token_num -= block_size
        else:
            logger.error(
                f"req_id {request.request_id} revert nodes error, nodes: {revert_block_id}, "
                f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}"
            )
            break
        # NOTE(review): this compares a *block* id against swap_node_ids, whose
        # name suggests it holds *node* ids - confirm which id is intended.
        if revert_block_id in swap_node_ids:
            swap_node_ids.remove(revert_block_id)

    # A remainder smaller than one block is charged to the last surviving node.
    if revert_tokens > 0:
        if not matche_nodes:
            # Nothing left to charge the remainder to; indexing [-1] here
            # would raise IndexError, so report the inconsistency instead.
            logger.error(f"req_id {request.request_id} revert nodes error, tokens: {revert_tokens}")
        else:
            last_block_id = matche_nodes[-1].block_id
            if last_block_id in match_gpu_block_ids:
                gpu_match_token_num -= revert_tokens
            elif last_block_id in match_cpu_block_ids:
                cpu_match_token_num -= revert_tokens
            else:
                logger.error(
                    f"req_id {request.request_id} revert nodes error, revert_tokens: {revert_tokens}, nodes: {last_block_id}, "
                    f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}"
                )

    current_node = matche_nodes[-1] if matche_nodes else self.radix_tree_root
    return gpu_match_token_num, cpu_match_token_num, current_node
1334-
13351275
def mm_match_block(self, request, block_size):
13361276
"""
13371277
Match and retrieve cached blocks for multimodal requests using a radix tree structure.
@@ -1420,28 +1360,6 @@ def mm_match_block(self, request, block_size):
14201360
if has_modified_cpu_lru_leaf_heap:
14211361
heapq.heapify(self.cpu_lru_leaf_heap)
14221362

1423-
if self.cache_config.disable_chunked_mm_input:
1424-
matched_token_num = gpu_match_token_num + cpu_match_token_num
1425-
is_chunked, chunk_idx = self.is_chunked_mm_input(request.multimodal_inputs, matched_token_num)
1426-
if is_chunked:
1427-
(
1428-
gpu_match_token_num,
1429-
cpu_match_token_num,
1430-
current_match_node,
1431-
) = self._revert_match_blocks(
1432-
request=request,
1433-
matched_token_num=matched_token_num,
1434-
block_size=block_size,
1435-
chunk_idx=chunk_idx,
1436-
match_node_ids=match_node_ids,
1437-
matche_nodes=matche_nodes,
1438-
match_gpu_block_ids=match_gpu_block_ids,
1439-
match_cpu_block_ids=match_cpu_block_ids,
1440-
gpu_match_token_num=gpu_match_token_num,
1441-
cpu_match_token_num=cpu_match_token_num,
1442-
swap_node_ids=swap_node_ids,
1443-
)
1444-
14451363
logger.info(f"match_block: req_id {request.request_id} matched nodes: {match_node_ids}")
14461364
return (
14471365
match_gpu_block_ids,

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,21 @@ def _is_mm_request(self, request):
353353

354354
return False
355355

356+
def revert_chunked_mm_input(self, mm_inputs, matched_token_num):
    """
    Roll a matched-token count back so it does not end inside a multimodal item.

    Args:
        mm_inputs: mapping that may contain an "mm_positions" sequence whose
            entries expose `offset` and `length` attributes; may be None.
        matched_token_num: number of tokens matched so far.

    Returns:
        The `offset` of the first multimodal item that `matched_token_num`
        falls strictly inside (offset < matched < offset + length);
        otherwise `matched_token_num` unchanged.
    """
    if mm_inputs is None or "mm_positions" not in mm_inputs:
        return matched_token_num

    positions = mm_inputs["mm_positions"]
    # A missing (None) or empty position list means there is nothing to revert.
    if positions is None or len(positions) == 0:
        return matched_token_num

    for position in positions:
        start = position.offset
        if matched_token_num < start:
            # The early exit presumes positions are ordered by ascending
            # offset - TODO confirm against the producer of mm_positions.
            break
        if start < matched_token_num < start + position.length:
            # The match stops part-way through this item: fall back to its start.
            return start
    return matched_token_num
370+
356371
def _get_num_new_tokens(self, request, token_budget):
357372
# TODO: set condition to new _get_num_new_tokens
358373
num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
@@ -904,11 +919,20 @@ def get_prefix_cached_blocks(self, request: Request):
904919
main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
905920
main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
906921

907-
if matched_token_num == request.need_prefill_tokens:
908-
request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
909-
request.skip_allocate = True
922+
if self.config.cache_config.disable_chunked_mm_input:
923+
if matched_token_num == request.need_prefill_tokens:
924+
matched_token_num = matched_token_num - self.config.cache_config.block_size
925+
request.skip_allocate = True
926+
request.num_computed_tokens = self.revert_chunked_mm_input(
927+
request.multimodal_inputs, matched_token_num
928+
)
910929
else:
911-
request.num_computed_tokens = matched_token_num
930+
if matched_token_num == request.need_prefill_tokens:
931+
request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
932+
request.skip_allocate = True
933+
else:
934+
request.num_computed_tokens = matched_token_num
935+
llm_logger.info(f"request {request.request_id} num_computed_tokens: {request.num_computed_tokens}")
912936
request.cache_prepare_time = time.time() - cache_prepare_time
913937
return True
914938
except Exception as e:

0 commit comments

Comments
 (0)