Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flexkv/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

@dataclass
class ModelConfig:
num_layers: int = 0
num_kv_heads: int = 0
head_size: int = 0
num_layers: int = 1
num_kv_heads: int = 1
head_size: int = 1
use_mla: bool = False
dtype: torch.dtype = torch.bfloat16

Expand Down
46 changes: 46 additions & 0 deletions flexkv/integration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class FlexKVConfig:
#base config
server_recv_port: str = ""

gpu_register_port: str = ""

# cache config
cache_config: CacheConfig = field(default_factory=CacheConfig)

Expand All @@ -35,6 +37,8 @@ class FlexKVConfig:
def __post_init__(self):
    """Backfill unset ports and sync model/cache config with the user config."""
    if self.server_recv_port == "":
        # No port configured: take the environment-level default.
        self.server_recv_port = GLOBAL_CONFIG_FROM_ENV.server_recv_port
    if self.gpu_register_port == "":
        # Derive the GPU-register port from the (now final) server port.
        self.gpu_register_port = f"{self.server_recv_port}_gpu_register"
    update_default_config_from_user_config(
        self.model_config,
        self.cache_config,
        self.user_config,
    )

@classmethod
Expand Down Expand Up @@ -68,3 +72,45 @@ def post_init_from_vllm_config(
self.model_config.dp_size = vllm_config.parallel_config.data_parallel_size

self.__post_init__()

def post_init_from_sglang_config(
    self,
    sglang_config,
    tp_size: int,
    page_size: int,
) -> None:
    """
    Initialize FlexKVConfig fields from an sglang config.

    Args:
        sglang_config: sglang.srt.configs.model_config.ModelConfig-like object
        tp_size: tensor parallel size used by sglang
        page_size: KV block size (tokens per block) used by sglang
    """
    # cache config
    self.cache_config.tokens_per_block = int(page_size)

    # model config
    self.model_config.num_layers = int(getattr(sglang_config, "num_hidden_layers", 0))

    # Prefer sglang's accessor (it accounts for TP sharding); fall back to the
    # raw attribute when the accessor is absent or fails for an expected reason.
    # NOTE: narrowed from a bare `except Exception`, which silently swallowed
    # every error (including real bugs such as a bad tp_size type); unexpected
    # exceptions now propagate to the caller.
    num_kv_heads = None
    if hasattr(sglang_config, "get_num_kv_heads"):
        try:
            num_kv_heads = int(sglang_config.get_num_kv_heads(tp_size))
        except (AttributeError, TypeError, ValueError):
            num_kv_heads = None  # best-effort: use the raw attribute below
    if num_kv_heads is None:
        num_kv_heads = int(getattr(sglang_config, "num_key_value_heads", 0))
    self.model_config.num_kv_heads = num_kv_heads

    self.model_config.head_size = int(getattr(sglang_config, "head_dim", 0))

    # NOTE(review): assumes sglang exposes dtype as a torch.dtype — confirm;
    # some configs store dtype as a string name instead.
    self.model_config.dtype = getattr(sglang_config, "dtype", torch.bfloat16)

    # attention_arch may be an enum-like object (has .name) or a plain string.
    attn_arch = getattr(sglang_config, "attention_arch", None)
    if hasattr(attn_arch, "name"):
        use_mla = attn_arch.name.upper() == "MLA"
    elif isinstance(attn_arch, str):
        use_mla = attn_arch.upper() == "MLA"
    else:
        use_mla = False
    self.model_config.use_mla = use_mla

    # parallelism
    self.model_config.tp_size = int(tp_size)
    self.model_config.dp_size = int(getattr(sglang_config, "dp_size", 1))

    # Re-run derived-field initialization now that fields are populated.
    self.__post_init__()
8 changes: 6 additions & 2 deletions flexkv/kvmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def __init__(self,
model_config: ModelConfig,
cache_config: CacheConfig,
dp_client_id: int = 0,
server_recv_port: str = ""):
server_recv_port: str = "",
gpu_register_port: str = ""):
flexkv_logger.info(f"{model_config = }")
flexkv_logger.info(f"{cache_config = }")
flexkv_logger.info(f"{GLOBAL_CONFIG_FROM_ENV = }")
Expand All @@ -42,7 +43,10 @@ def __init__(self,
self.server_recv_port = server_recv_port
else:
self.server_recv_port = GLOBAL_CONFIG_FROM_ENV.server_recv_port
self.gpu_register_port = self.server_recv_port + "_gpu_register"
if gpu_register_port != "":
self.gpu_register_port = gpu_register_port
else:
self.gpu_register_port = self.server_recv_port + "_gpu_register"

self.server_client_mode = model_config.dp_size > 1 or GLOBAL_CONFIG_FROM_ENV.server_client_mode
self.dp_client_id = dp_client_id
Expand Down