Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions flexkv/integration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,13 @@ def post_init_from_trt_config(
self.model_config.head_size = hf_config.kv_lora_rank + hf_config.qk_rope_head_dim
self.model_config.num_kv_heads = 1
else:
self.model_config.head_size = hf_config.hidden_size // hf_config.num_key_value_heads // self.model_config.tp_size
self.model_config.num_kv_heads = hf_config.num_key_value_heads
if hasattr(hf_config, 'num_key_value_heads'):
assert hf_config.num_attention_heads != hf_config.num_key_value_heads, f"{hf_config.num_attention_heads=}, {hf_config.num_key_value_heads=}"
self.model_config.head_size = hf_config.head_dim
self.model_config.num_kv_heads = hf_config.num_key_value_heads
else:
self.model_config.head_size = hf_config.hidden_size // hf_config.num_attention_heads
self.model_config.num_kv_heads = hf_config.num_attention_heads

except Exception as e:
flexkv_logger.error(f"Failed to load config from {model_path}: {e}")
Expand Down
7 changes: 1 addition & 6 deletions flexkv/integration/tensorrt_llm/trtllm_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,6 @@ def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
flexkv_logger.debug(f"self.tp_client.device_id (from init): {self.tp_client.device_id}")

# Get physical GPU ID (in case CUDA_VISIBLE_DEVICES is set)
import os
cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
if cuda_visible_devices:
# Map logical ID to physical ID
Expand All @@ -505,11 +504,6 @@ def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
if self.flexkv_config.model_config.use_mla:
assert kv_dim == 1, (f"expect kv_dim eqals to 1 when using MLA but get kv_dim={kv_dim}")

assert num_kv_heads * head_size * block_size == kv_cache_tensor.shape[3], \
(f"expect kv cached tensor last dim equals to num_kv_heads*head_size*block_size, " \
f"but get last_dim = {kv_cache_tensor.shape[3]}, " \
f"num_kv_heads = {num_kv_heads}, head_size = {head_size}, block_size = {block_size}")

gpu_blocks = [kv_cache_tensor] # convert to list for flexkv register

gpu_layout = KVCacheLayout(
Expand All @@ -521,6 +515,7 @@ def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
head_size=head_size,
is_mla=self.flexkv_config.model_config.use_mla,
)
flexkv_logger.info(f"gpu_layout: {gpu_layout}")
# Use correct device_id from tensor's actual device
self.tp_client.register_to_server(gpu_blocks, gpu_layout, override_device_id=correct_device_id)
flexkv_logger.info(f"Finish register kv_caches on device {correct_device_id}")
Expand Down