diff --git a/benchmarks/benchmark_cache_engine.py b/benchmarks/benchmark_cache_engine.py index a75a2008b7..52b2e6f09c 100644 --- a/benchmarks/benchmark_cache_engine.py +++ b/benchmarks/benchmark_cache_engine.py @@ -4,6 +4,7 @@ import cProfile import pstats import torch +import numpy as np from flexkv.cache.cache_engine import GlobalCacheEngine from flexkv.common.config import CacheConfig, ModelConfig @@ -31,23 +32,27 @@ def main(args): num_put_requests = 0 request_id = 0 for req in reqs: - fake_slot_mapping = torch.arange(req.token_mask[req.token_mask].sum(), dtype=torch.int64) + token_ids_np = req.token_ids.numpy().astype(np.int64) + token_mask_np = req.token_mask.numpy().astype(np.int64) + fake_slot_mapping = torch.arange(req.token_mask[req.token_mask].sum().item(), dtype=torch.int64).numpy().astype(np.int64) local_vars = { 'cache_engine': cache_engine, 'req': req, 'fake_slot_mapping': fake_slot_mapping, 'request_id': request_id, + 'token_ids_np': token_ids_np, + 'token_mask_np': token_mask_np } if req.request_type == "get": num_get_requests += 1 if not args.only_put: profiler.runctx('graph, return_mask, transfer_call_back, finished_ops_ids = ' - 'cache_engine.get(request_id, req.token_ids, req.token_mask, ' + 'cache_engine.get(request_id, token_ids_np, token_mask_np, ' 'fake_slot_mapping, -1, -1)', globals(), local_vars) else: graph, return_mask, transfer_call_back, finished_ops_ids = \ - cache_engine.get(request_id, req.token_ids, req.token_mask, + cache_engine.get(request_id, token_ids_np, token_mask_np, fake_slot_mapping, -1, -1) local_vars.update({ 'graph': graph, @@ -67,11 +72,11 @@ def main(args): num_put_requests += 1 if not args.only_get: profiler.runctx('graph, return_mask, transfer_call_back, finished_ops_ids = ' - 'cache_engine.put(request_id, req.token_ids, req.token_mask, fake_slot_mapping)', + 'cache_engine.put(request_id, token_ids_np, token_mask_np, fake_slot_mapping)', globals(), local_vars) else: graph, return_mask, transfer_call_back, finished_ops_ids = \ - cache_engine.put(request_id, req.token_ids, req.token_mask, fake_slot_mapping) + cache_engine.put(request_id, token_ids_np, token_mask_np, fake_slot_mapping) local_vars.update({ 'graph': graph, 'return_mask': return_mask,