From 9b6848201714c251f93418dff9b0b7973ae298f7 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 19 Feb 2025 23:59:14 -0800
Subject: [PATCH] [V1][Minor] Print KV cache size in token counts

---
 vllm/v1/core/kv_cache_utils.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 6dec87d4dd20..e3eb6b24c195 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -519,11 +519,13 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                     "num_gpu_blocks_override=%d", num_blocks,
                     num_gpu_blocks_override)
         num_blocks = num_gpu_blocks_override
 
-    logger.info("# GPU blocks: %d", num_blocks)
-    max_concurrency = (num_blocks * vllm_config.cache_config.block_size /
-                       vllm_config.model_config.max_model_len)
+    num_tokens = num_blocks * vllm_config.cache_config.block_size
+    num_tokens_str = f"{num_tokens:,}"
+    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
+    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
+    max_concurrency = num_tokens / vllm_config.model_config.max_model_len
     logger.info("Maximum concurrency for %s tokens per request: %.2fx",
-                vllm_config.model_config.max_model_len, max_concurrency)
+                max_model_len_str, max_concurrency)
 
     per_layer_size = page_size * num_blocks