|
75 | 75 | from vllm.model_executor.models.utils import sequence_parallel_chunk |
76 | 76 | from vllm.platforms import current_platform |
77 | 77 | from vllm.sequence import IntermediateTensors |
78 | | -from vllm.utils import cdiv, direct_register_custom_op |
| 78 | +from vllm.utils import direct_register_custom_op |
79 | 79 | from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits |
80 | 80 | from vllm.v1.attention.backends.mla.indexer import ( |
81 | 81 | DeepseekV32IndexerBackend, |
@@ -483,69 +483,6 @@ def get_attn_backend(self) -> AttentionBackend: |
483 | 483 | return DeepseekV32IndexerBackend |
484 | 484 |
|
485 | 485 |
|
486 | | -@torch.inference_mode() |
487 | | -def cp_gather_indexer_k_quant_cache( |
488 | | - kv_cache, # [num_blocks, block_size, head_dim + 1] |
489 | | - dst_value, # [cu_seq_lens[-1], head_dim] |
490 | | - dst_scale, # [cu_seq_lens[-1], 4] |
491 | | - block_table, # [batch_size, num_blocks] |
492 | | - cu_seq_lens, # [batch_size + 1, ] |
493 | | - batch_size, |
494 | | -): |
495 | | - num_blocks, block_size, _ = kv_cache.shape |
496 | | - head_dim = dst_value.shape[-1] |
497 | | - kv_cache = kv_cache.view(num_blocks, -1) |
498 | | - |
499 | | - expected_value = [] |
500 | | - expected_scale = [] |
501 | | - for b in range(batch_size): |
502 | | - s = cu_seq_lens[b + 1] - cu_seq_lens[b] |
503 | | - if s == 0: |
504 | | - continue |
505 | | - tot = cdiv(s, block_size) |
506 | | - blocks = block_table[b, :tot] |
507 | | - |
508 | | - value = [] |
509 | | - scale = [] |
510 | | - full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32) |
511 | | - non_remaining_value = kv_cache[ |
512 | | - blocks[full_block], : block_size * head_dim |
513 | | - ].view(-1, head_dim) |
514 | | - non_remaining_scale = kv_cache[ |
515 | | - blocks[full_block], block_size * head_dim : |
516 | | - ].view(-1, 4) |
517 | | - |
518 | | - remaining = s - (tot - 1) * block_size |
519 | | - |
520 | | - value = torch.cat( |
521 | | - [ |
522 | | - non_remaining_value, |
523 | | - kv_cache[blocks[-1], : remaining * head_dim].view(-1, head_dim), |
524 | | - ], |
525 | | - dim=0, |
526 | | - ) |
527 | | - scale = torch.cat( |
528 | | - [ |
529 | | - non_remaining_scale, |
530 | | - kv_cache[ |
531 | | - blocks[-1], |
532 | | - block_size * head_dim : block_size * head_dim + remaining * 4, |
533 | | - ].view(-1, 4), |
534 | | - ], |
535 | | - dim=0, |
536 | | - ) |
537 | | - |
538 | | - expected_value.append(value) |
539 | | - expected_scale.append(scale) |
540 | | - |
541 | | - gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim) |
542 | | - gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4) |
543 | | - gather_value = gather_value.view(torch.float8_e4m3fn) |
544 | | - gather_scale = gather_scale.view(torch.float32) |
545 | | - dst_value.copy_(gather_value) |
546 | | - dst_scale.copy_(gather_scale) |
547 | | - |
548 | | - |
549 | 486 | def sparse_attn_indexer( |
550 | 487 | hidden_states: torch.Tensor, |
551 | 488 | k_cache_prefix: str, |
@@ -605,19 +542,20 @@ def sparse_attn_indexer( |
605 | 542 | dtype=torch.float8_e4m3fn, |
606 | 543 | ) |
607 | 544 | k_scale = torch.empty( |
608 | | - [chunk.total_seq_lens, 1], device=k.device, dtype=torch.float32 |
| 545 | + [chunk.total_seq_lens, 4], |
| 546 | + device=k.device, |
| 547 | + dtype=torch.uint8, |
609 | 548 | ) |
610 | | - cp_gather_indexer_k_quant_cache( |
| 549 | + ops.cp_gather_indexer_k_quant_cache( |
611 | 550 | kv_cache, |
612 | 551 | k_fp8, |
613 | 552 | k_scale, |
614 | 553 | chunk.block_table, |
615 | 554 | chunk.cu_seq_lens, |
616 | | - chunk.num_reqs, |
617 | 555 | ) |
618 | 556 | logits = fp8_mqa_logits( |
619 | 557 | q_fp8[chunk.token_start : chunk.token_end], |
620 | | - (k_fp8, k_scale), |
| 558 | + (k_fp8, k_scale.view(torch.float32)), |
621 | 559 | weights[chunk.token_start : chunk.token_end], |
622 | 560 | chunk.cu_seqlen_ks, |
623 | 561 | chunk.cu_seqlen_ke, |
|
0 commit comments