Skip to content

Commit a0d5179

Browse files
zyongye authored and bogdanminko committed
[Deepseek-V3.2][Kernel] Integrate cuda indexer k cache gather (vllm-project#26456)
Signed-off-by: Yongye Zhu <[email protected]> Signed-off-by: bogdan01m <[email protected]>
1 parent efdef57 commit a0d5179

1 file changed

Lines changed: 6 additions & 68 deletions

File tree

vllm/model_executor/models/deepseek_v2.py

Lines changed: 6 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
from vllm.model_executor.models.utils import sequence_parallel_chunk
7676
from vllm.platforms import current_platform
7777
from vllm.sequence import IntermediateTensors
78-
from vllm.utils import cdiv, direct_register_custom_op
78+
from vllm.utils import direct_register_custom_op
7979
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
8080
from vllm.v1.attention.backends.mla.indexer import (
8181
DeepseekV32IndexerBackend,
@@ -483,69 +483,6 @@ def get_attn_backend(self) -> AttentionBackend:
483483
return DeepseekV32IndexerBackend
484484

485485

486-
@torch.inference_mode()
487-
def cp_gather_indexer_k_quant_cache(
488-
kv_cache, # [num_blocks, block_size, head_dim + 1]
489-
dst_value, # [cu_seq_lens[-1], head_dim]
490-
dst_scale, # [cu_seq_lens[-1], 4]
491-
block_table, # [batch_size, num_blocks]
492-
cu_seq_lens, # [batch_size + 1, ]
493-
batch_size,
494-
):
495-
num_blocks, block_size, _ = kv_cache.shape
496-
head_dim = dst_value.shape[-1]
497-
kv_cache = kv_cache.view(num_blocks, -1)
498-
499-
expected_value = []
500-
expected_scale = []
501-
for b in range(batch_size):
502-
s = cu_seq_lens[b + 1] - cu_seq_lens[b]
503-
if s == 0:
504-
continue
505-
tot = cdiv(s, block_size)
506-
blocks = block_table[b, :tot]
507-
508-
value = []
509-
scale = []
510-
full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32)
511-
non_remaining_value = kv_cache[
512-
blocks[full_block], : block_size * head_dim
513-
].view(-1, head_dim)
514-
non_remaining_scale = kv_cache[
515-
blocks[full_block], block_size * head_dim :
516-
].view(-1, 4)
517-
518-
remaining = s - (tot - 1) * block_size
519-
520-
value = torch.cat(
521-
[
522-
non_remaining_value,
523-
kv_cache[blocks[-1], : remaining * head_dim].view(-1, head_dim),
524-
],
525-
dim=0,
526-
)
527-
scale = torch.cat(
528-
[
529-
non_remaining_scale,
530-
kv_cache[
531-
blocks[-1],
532-
block_size * head_dim : block_size * head_dim + remaining * 4,
533-
].view(-1, 4),
534-
],
535-
dim=0,
536-
)
537-
538-
expected_value.append(value)
539-
expected_scale.append(scale)
540-
541-
gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
542-
gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
543-
gather_value = gather_value.view(torch.float8_e4m3fn)
544-
gather_scale = gather_scale.view(torch.float32)
545-
dst_value.copy_(gather_value)
546-
dst_scale.copy_(gather_scale)
547-
548-
549486
def sparse_attn_indexer(
550487
hidden_states: torch.Tensor,
551488
k_cache_prefix: str,
@@ -605,19 +542,20 @@ def sparse_attn_indexer(
605542
dtype=torch.float8_e4m3fn,
606543
)
607544
k_scale = torch.empty(
608-
[chunk.total_seq_lens, 1], device=k.device, dtype=torch.float32
545+
[chunk.total_seq_lens, 4],
546+
device=k.device,
547+
dtype=torch.uint8,
609548
)
610-
cp_gather_indexer_k_quant_cache(
549+
ops.cp_gather_indexer_k_quant_cache(
611550
kv_cache,
612551
k_fp8,
613552
k_scale,
614553
chunk.block_table,
615554
chunk.cu_seq_lens,
616-
chunk.num_reqs,
617555
)
618556
logits = fp8_mqa_logits(
619557
q_fp8[chunk.token_start : chunk.token_end],
620-
(k_fp8, k_scale),
558+
(k_fp8, k_scale.view(torch.float32)),
621559
weights[chunk.token_start : chunk.token_end],
622560
chunk.cu_seqlen_ks,
623561
chunk.cu_seqlen_ke,

0 commit comments

Comments (0)