19 changes: 17 additions & 2 deletions ucm/sparse/kvcomp/ham_dist/paged_ham_dist_mla.cu
@@ -59,6 +59,21 @@
         { \
             __VA_ARGS__ \
         } \
+    } else if ((val) == 2) { \
+        constexpr int NumKVHead = 2; \
+        { \
+            __VA_ARGS__ \
+        } \
+    } else if ((val) == 4) { \
+        constexpr int NumKVHead = 4; \
+        { \
+            __VA_ARGS__ \
+        } \
+    } else if ((val) == 8) { \
+        constexpr int NumKVHead = 8; \
+        { \
+            __VA_ARGS__ \
+        } \
     } else { \
         LOG(FATAL) << "NumKVHead is not support"; \
     } \
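The new branches extend the dispatch macro so that `NumKVHead` is a compile-time constant for head counts 2, 4, and 8 in addition to 1; any other value still falls through to the `LOG(FATAL)` branch. As a rough Python analogue (a sketch only; the kernel names below are hypothetical and not part of this patch), the macro acts as a lookup from a runtime head count to a kernel specialized for that constant:

```python
# Hypothetical analogue of the macro's dispatch: each supported head count
# selects a kernel instantiated at compile time for that constant.
SPECIALIZED_KERNELS = {n: f"paged_ham_dist_kernel<NumKVHead={n}>" for n in (1, 2, 4, 8)}

def dispatch_num_kv_head(num_kv_head: int) -> str:
    # Mirrors the macro's else-branch: unsupported head counts fail loudly.
    if num_kv_head not in SPECIALIZED_KERNELS:
        raise ValueError(f"NumKVHead {num_kv_head} is not supported")
    return SPECIALIZED_KERNELS[num_kv_head]
```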
@@ -295,7 +310,7 @@ torch::Tensor HammingScoreContiCUDA(torch::Tensor& key_codes,
     bool is_block_mode = block_table_opt.has_value();
 
     int32_t bsz = query_code.size(0);
-    int32_t num_kv_head = is_block_mode ? key_codes.size(1) : key_codes.size(2);
+    int32_t num_kv_head = key_codes.size(2);
     int32_t num_chunk = key_codes.size(3);
 
     int32_t num_head = query_code.size(2);
@@ -309,7 +324,7 @@

     if(is_block_mode) {
         int32_t num_blocks = key_codes.size(0);
-        int32_t block_size = key_codes.size(2);
+        int32_t block_size = key_codes.size(1);
         const auto& block_table = block_table_opt.value(); // *block_table_opt;
         int32_t max_num_block_per_seq = block_table.size(1);
         TORCH_CHECK(bsz == block_table.size(0), "batch size mismatch between query_code and block_table");
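Taken together, the two indexing fixes imply that in block mode key_codes is laid out as [num_blocks, block_size, num_kv_head, num_chunk]: num_kv_head now always comes from size(2), matching contiguous mode, and block_size from size(1). A minimal shape check under that assumed layout (the concrete sizes are illustrative; num_chunk = 18 matches the hash-width asserts in hamming_topk.py):

```python
import torch

num_blocks, block_size, num_kv_head, num_chunk = 128, 16, 4, 18
key_codes = torch.zeros(num_blocks, block_size, num_kv_head, num_chunk, dtype=torch.int32)

assert key_codes.size(1) == block_size   # previously misread as num_kv_head
assert key_codes.size(2) == num_kv_head  # previously misread as block_size
assert key_codes.size(3) == num_chunk
```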
32 changes: 19 additions & 13 deletions ucm/sparse/kvcomp/hamming_topk.py
@@ -20,16 +20,16 @@ def cuda_hamming_topk(
     topk_token,
     sink_token,
     recent_token,
+    is_mla,
 ):
     q_hash = q_hash.view(torch.int32)
     k_hash = k_hash.view(torch.int32)
-    assert k_hash.shape[1] == 1
+    # assert k_hash.shape[1] == 1
     # assert k_hash.shape[-1] == 18 and q_hash.shape[-1] == 18
-    block_size = k_hash.shape[2]
+    block_size = k_hash.shape[1]
     assert topk_token % block_size == 0
     assert recent_token > 0 and topk_token > (sink_token + recent_token)
     max_seqlen = block_size * block_table.shape[1]
-
     output = hamming.hamming_score(
         k_hash,
         q_hash,
@@ -40,17 +40,23 @@
         recent_token,
     )
 
-    block_output = torch.min(
-        output.view(output.shape[0], output.shape[-1] // block_size, block_size), dim=-1
-    )[0]
+    k_blocks = topk_token // block_size
+    B, Hk, S = output.shape
+    num_blocks = S // block_size
 
-    ind = torch.topk(block_output, k=(topk_token // block_size), dim=-1, largest=False)[
-        1
-    ]
-    ind = torch.sort(ind, dim=-1, descending=False)[0]
+    # block_output: [B, Hk, num_blocks]
+    block_output = output.view(B, Hk, num_blocks, block_size).amin(dim=-1)
 
-    new_block_table = torch.gather(block_table, dim=-1, index=ind)
-    return new_block_table
+    if is_mla:
+        block_score = block_output[:, 0, :]
+        ind = torch.topk(block_score, k=k_blocks, dim=-1, largest=False).indices
+        ind = ind.sort(dim=-1).values
+        return torch.gather(block_table, dim=-1, index=ind)
+
+    block_score = block_output.amin(dim=1)  # [B, num_blocks]
+    ind = torch.topk(block_score, k=k_blocks, dim=-1, largest=False).indices
+    ind = ind.sort(dim=-1).values
+    return torch.gather(block_table, dim=-1, index=ind)


def fake_hamming_topk(
@@ -66,7 +72,7 @@ def fake_hamming_topk(
     k_hash = k_hash.view(torch.int32)
     assert k_hash.shape[1] == 1
     assert k_hash.shape[-1] == 18 and q_hash.shape[-1] == 18
-    block_size = k_hash.shape[2]
+    block_size = k_hash.shape[1]
     assert topk_token % block_size == 0
     assert recent_token > 0 and topk_token > (sink_token + recent_token)
     max_seqlen = block_size * block_table.shape[1]
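The rewrite keeps an explicit KV-head axis through block selection: per-token scores are min-pooled within each block, reduced over heads (or read from the single head in the MLA case), and the top-k lowest-scoring blocks are gathered in sequence order. A self-contained sketch of the non-MLA path with toy sizes (all shapes below are illustrative):

```python
import torch

B, Hk, num_blocks, block_size, k_blocks = 2, 4, 8, 16, 4
scores = torch.randint(0, 64, (B, Hk, num_blocks * block_size))  # stand-in for hamming_score output
block_table = torch.arange(num_blocks).repeat(B, 1)              # [B, num_blocks]

# Min over tokens in a block, then over KV heads: a block survives if it is
# a close match for any head.
block_score = scores.view(B, Hk, num_blocks, block_size).amin(dim=-1).amin(dim=1)
ind = torch.topk(block_score, k=k_blocks, dim=-1, largest=False).indices
ind = ind.sort(dim=-1).values                                    # keep sequence order
new_block_table = torch.gather(block_table, dim=-1, index=ind)
assert new_block_table.shape == (B, k_blocks)
```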
8 changes: 6 additions & 2 deletions ucm/sparse/kvcomp/kvcomp_hbm.py
@@ -34,6 +34,8 @@ def kvcomp_config_path_for_model(vllm_config) -> str:
rel = "ucm/sparse/kvcomp/configs/kvcomp_deepseek_r1_awq_config.json"
elif "qwen3" in model and "32b" in model:
rel = "ucm/sparse/kvcomp/configs/kvcomp_qwen3_32B_config.json"
elif "deepseek" in model and "v2" in model:
rel = "ucm/sparse/kvcomp/configs/kvcomp_deepseek_v2_lite_config.json"
else:
raise ValueError(f"[KvCompOnDevice] Unsupported model for kvcomp: {model}")

@@ -270,12 +272,13 @@ def attention_begin(
         topk_token = self.hash_topk_tokens
         block_table = cuda_hamming_topk(
             q_hash.unsqueeze(1),
-            k_hash.unsqueeze(1),
+            k_hash.unsqueeze(2),
             attn_metadata.decode.block_table,
             attn_metadata.decode.seq_lens,
             topk_token=topk_token,
             sink_token=64,
             recent_token=512,
+            is_mla=self.is_mla,
         )
         attn_metadata.decode.topk_block_table = block_table

@@ -324,12 +327,13 @@
         )
         block_table_decode = cuda_hamming_topk(
             q_hash.unsqueeze(1),
-            k_hash.unsqueeze(1),
+            k_hash,
             block_table_decode,
             seq_len_decode,
             topk_token=topk_token,
             sink_token=64,
             recent_token=512,
+            is_mla=self.is_mla,
         )
         # update topk_block_table
         topk = block_table_decode.shape[1]
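Both call-site edits follow from the new key_codes layout: for MLA, k_hash gains its singleton KV-head axis at dim 2 instead of dim 1, while the multi-head path passes k_hash through with its real head axis intact, and is_mla tells cuda_hamming_topk which head reduction to apply. A quick check of the MLA reshape, assuming k_hash is [num_blocks, block_size, num_chunk] (an inference from these edits, not stated in the patch):

```python
import torch

num_blocks, block_size, num_chunk = 64, 16, 18
k_hash = torch.zeros(num_blocks, block_size, num_chunk, dtype=torch.int32)

# New: KV-head axis lands where the kernel now reads num_kv_head (size(2)).
assert k_hash.unsqueeze(2).shape == (num_blocks, block_size, 1, num_chunk)

# Old: matched only the pre-fix indexing that read num_kv_head from size(1).
assert k_hash.unsqueeze(1).shape == (num_blocks, 1, block_size, num_chunk)
```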
2 changes: 1 addition & 1 deletion ucm/sparse/state.py
@@ -41,7 +41,7 @@ def ensure_ucm_sparse_initialized(

     # Check if UCM sparse is enabled
     ucm_config = Config(vllm_config.kv_transfer_config)
-    ucm_sparse_config = ucm_config.get_config().get("ucm_sparse_method")
+    ucm_sparse_config = ucm_config.get_config().get("ucm_sparse_config")
     if not ucm_sparse_config:
         return

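With the old key, the lookup presumably never matched what the config actually carries, so the guard returned early and UCM sparse was never initialized even when configured. A minimal sketch of the guard's behavior, assuming get_config() yields a plain dict (the payload below is hypothetical):

```python
cfg = {"ucm_sparse_config": {"kvcomp": {}}}  # hypothetical payload

# Old key: always absent, so the early return always fired.
assert cfg.get("ucm_sparse_method") is None

# New key: truthy when sparse attention is configured, so init proceeds.
assert cfg.get("ucm_sparse_config")
```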