diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index bafb3a3999..519e78caf5 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -65,12 +65,15 @@ def post_init_from_vllm_config( self.cache_config.tokens_per_block = vllm_config.cache_config.block_size self.model_config.num_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) - self.model_config.num_kv_heads = vllm_config.model_config.get_num_kv_heads(vllm_config.parallel_config) self.model_config.head_size = vllm_config.model_config.get_head_size() self.model_config.dtype = vllm_config.model_config.dtype self.model_config.use_mla = vllm_config.model_config.is_deepseek_mla self.model_config.tp_size = vllm_config.parallel_config.tensor_parallel_size self.model_config.dp_size = vllm_config.parallel_config.data_parallel_size + if self.model_config.use_mla: + self.model_config.num_kv_heads = 1 + else: + self.model_config.num_kv_heads = vllm_config.model_config.get_total_num_kv_heads() self.__post_init__()