
Commit d3ea501

[V1][Minor] Print KV cache size in token counts (#13596)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 34aad51 commit d3ea501

File tree: 1 file changed (+6 -4 lines)

vllm/v1/core/kv_cache_utils.py

Lines changed: 6 additions & 4 deletions
@@ -519,11 +519,13 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
             "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
         num_blocks = num_gpu_blocks_override
 
-    logger.info("# GPU blocks: %d", num_blocks)
-    max_concurrency = (num_blocks * vllm_config.cache_config.block_size /
-                       vllm_config.model_config.max_model_len)
+    num_tokens = num_blocks * vllm_config.cache_config.block_size
+    num_tokens_str = f"{num_tokens:,}"
+    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
+    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
+    max_concurrency = num_tokens / vllm_config.model_config.max_model_len
     logger.info("Maximum concurrency for %s tokens per request: %.2fx",
-                vllm_config.model_config.max_model_len, max_concurrency)
+                max_model_len_str, max_concurrency)
 
     per_layer_size = page_size * num_blocks
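
For context, here is a minimal standalone sketch of the arithmetic the changed lines perform: the KV cache size is now reported in tokens (blocks times block size) rather than in blocks, and the maximum concurrency is that token count divided by the model's maximum sequence length. The numeric values below (num_blocks, block_size, max_model_len) are illustrative assumptions, not values from the commit.

    # Sketch of the computation behind the new log messages.
    # All numbers are assumed for illustration only.
    num_blocks = 27_000        # KV cache blocks that fit on the GPU (assumed)
    block_size = 16            # tokens per KV cache block (assumed)
    max_model_len = 32_768     # maximum tokens per request (assumed)

    # New: report cache capacity in tokens instead of blocks.
    num_tokens = num_blocks * block_size
    print(f"GPU KV cache size: {num_tokens:,} tokens")

    # Maximum concurrency if every request used the full context length.
    max_concurrency = num_tokens / max_model_len
    print(f"Maximum concurrency for {max_model_len:,} tokens per request: "
          f"{max_concurrency:.2f}x")

With these assumed numbers the cache holds 432,000 tokens, so roughly 13.18 full-length requests could be resident at once, which is the "Maximum concurrency ... x" figure the logger prints.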