diff --git a/flexkv/common/config.py b/flexkv/common/config.py
index c79c1e5773..bbd941cd43 100644
--- a/flexkv/common/config.py
+++ b/flexkv/common/config.py
@@ -12,9 +12,9 @@
 
 @dataclass
 class ModelConfig:
-    num_layers: int = 0
-    num_kv_heads: int = 0
-    head_size: int = 0
+    num_layers: int = 1
+    num_kv_heads: int = 1
+    head_size: int = 1
     use_mla: bool = False
     dtype: torch.dtype = torch.bfloat16
 
diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py
index 12d4d8ef14..9f40436b60 100644
--- a/flexkv/integration/config.py
+++ b/flexkv/integration/config.py
@@ -23,6 +23,8 @@ class FlexKVConfig:
 
     #base config
     server_recv_port: str = ""
+    gpu_register_port: str = ""
+
     # cache config
     cache_config: CacheConfig = field(default_factory=CacheConfig)
 
@@ -35,6 +37,8 @@ class FlexKVConfig:
     def __post_init__(self):
         if self.server_recv_port == "":
             self.server_recv_port = GLOBAL_CONFIG_FROM_ENV.server_recv_port
+        if self.gpu_register_port == "":
+            self.gpu_register_port = self.server_recv_port + "_gpu_register"
         update_default_config_from_user_config(self.model_config, self.cache_config, self.user_config)
 
     @classmethod
@@ -68,3 +72,45 @@
         self.model_config.dp_size = vllm_config.parallel_config.data_parallel_size
 
         self.__post_init__()
+
+    def post_init_from_sglang_config(
+        self,
+        sglang_config,
+        tp_size: int,
+        page_size: int,
+    ):
+        """
+        Initialize FlexKVConfig fields from sglang config.
+        Args:
+            sglang_config: sglang.srt.configs.model_config.ModelConfig-like object
+            tp_size: tensor parallel size used by sglang
+            page_size: KV block size (tokens per block) used by sglang
+        """
+        # cache config
+        self.cache_config.tokens_per_block = int(page_size)
+
+        self.model_config.num_layers = int(getattr(sglang_config, "num_hidden_layers", 0))
+
+        if hasattr(sglang_config, "get_num_kv_heads"):
+            try:
+                self.model_config.num_kv_heads = int(sglang_config.get_num_kv_heads(tp_size))
+            except Exception:
+                self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0))
+        else:
+            self.model_config.num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0))
+        self.model_config.head_size = int(getattr(sglang_config, "head_dim", 0))
+
+        self.model_config.dtype = getattr(sglang_config, "dtype", torch.bfloat16)
+
+        attn_arch = getattr(sglang_config, "attention_arch", None)
+        use_mla = False
+        if hasattr(attn_arch, "name"):
+            use_mla = (attn_arch.name.upper() == "MLA")
+        elif isinstance(attn_arch, str):
+            use_mla = (attn_arch.upper() == "MLA")
+        self.model_config.use_mla = use_mla
+
+        self.model_config.tp_size = int(tp_size)
+        self.model_config.dp_size = int(getattr(sglang_config, "dp_size", 1))
+
+        self.__post_init__()
diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py
index cbf795648b..06aa79e618 100644
--- a/flexkv/kvmanager.py
+++ b/flexkv/kvmanager.py
@@ -31,7 +31,8 @@ def __init__(self,
                  model_config: ModelConfig,
                  cache_config: CacheConfig,
                  dp_client_id: int = 0,
-                 server_recv_port: str = ""):
+                 server_recv_port: str = "",
+                 gpu_register_port: str = ""):
         flexkv_logger.info(f"{model_config = }")
         flexkv_logger.info(f"{cache_config = }")
         flexkv_logger.info(f"{GLOBAL_CONFIG_FROM_ENV = }")
@@ -42,7 +43,10 @@ def __init__(self,
             self.server_recv_port = server_recv_port
         else:
             self.server_recv_port = GLOBAL_CONFIG_FROM_ENV.server_recv_port
-        self.gpu_register_port = self.server_recv_port + "_gpu_register"
+        if gpu_register_port != "":
+            self.gpu_register_port = gpu_register_port
+        else:
+            self.gpu_register_port = self.server_recv_port + "_gpu_register"
         self.server_client_mode = model_config.dp_size > 1 or GLOBAL_CONFIG_FROM_ENV.server_client_mode
         self.dp_client_id = dp_client_id
 