@@ -2,6 +2,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 
+#include "cuda_utils.h"
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 
@@ -570,3 +571,161 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
     TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
   }
 }
+
+namespace vllm {
+
+// grid is launched with dimensions (batch, num_splits)
+template <typename scalar_t>
+__global__ void gather_cache(
+    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
+                                              //  ENTRIES...]
+    scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
+    const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
+    const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
+    const int32_t block_size, const int32_t entry_size,
+    const int64_t block_table_stride, const int64_t cache_block_stride,
+    const int64_t cache_entry_stride, const int64_t dst_entry_stride,
+    const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets
+                                               // per batch
+
+  const int64_t bid = blockIdx.x;  // Batch ID
+  const int32_t num_splits = gridDim.y;
+  const int32_t split = blockIdx.y;
+  const int32_t seq_start = cu_seq_lens[bid];
+  const int32_t seq_end = cu_seq_lens[bid + 1];
+  const int32_t seq_len = seq_end - seq_start;
+  const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
+  const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
+
+  const int32_t split_start = split * split_blocks;
+  const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
+
+  const bool is_active_split = (split_start < tot_blocks);
+  const bool is_last_split = (split_end == tot_blocks);
+
+  if (!is_active_split) return;
+
+  int32_t full_blocks_end = split_end;
+  int32_t partial_block_size = 0;
+
+  // Adjust the pointer for the block_table for this batch.
+  // If seq_starts is provided, compute an offset based on
+  // (seq_starts[bid] / block_size).
+  const int32_t batch_offset = bid * block_table_stride;
+  int32_t offset = 0;
+  if (seq_starts != nullptr) {
+    offset = seq_starts[bid] / block_size;
+  }
+  const int32_t* batch_block_table = block_table + batch_offset + offset;
+
+  // Adjust dst pointer based on the cumulative sequence lengths.
+  dst += seq_start * dst_entry_stride;
+
+  if (is_last_split) {
+    partial_block_size = seq_len % block_size;
+    if (partial_block_size) full_blocks_end -= 1;
+  }
+
+  auto copy_entry = [&](const scalar_t* __restrict__ _src,
+                        scalar_t* __restrict__ _dst) {
+    for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
+      _dst[i] = _src[i];
+  };
+
+  for (int pid = split_start; pid < full_blocks_end; ++pid) {
+    auto block_id = batch_block_table[pid];
+    auto block_start_ptr = src_cache + block_id * cache_block_stride;
+    auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
+    for (int eid = 0; eid < block_size; ++eid) {
+      copy_entry(block_start_ptr + eid * cache_entry_stride,
+                 block_dst_ptr + eid * dst_entry_stride);
+    }
+  }
+
+  if (partial_block_size) {
+    auto block_id = batch_block_table[full_blocks_end];
+    auto block_start_ptr = src_cache + block_id * cache_block_stride;
+    auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
+    for (int eid = 0; eid < partial_block_size; ++eid) {
+      copy_entry(block_start_ptr + eid * cache_entry_stride,
+                 block_dst_ptr + eid * dst_entry_stride);
+    }
+  }
+}
+
+}  // namespace vllm
+
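To make the split bookkeeping in the kernel above concrete, the short host-side sketch below (not part of the commit; names and values are made up for illustration) reproduces how one sequence's blocks are partitioned across the grid's y-dimension, with a local `ceil_div` standing in for `cuda_utils::ceil_div`:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Mirrors cuda_utils::ceil_div (integer ceiling division).
static int32_t ceil_div(int32_t a, int32_t b) { return (a + b - 1) / b; }

int main() {
  // Illustration only: 300 tokens, block_size 16, gathered with 4 splits.
  const int32_t seq_len = 300, block_size = 16, num_splits = 4;
  const int32_t tot_blocks = ceil_div(seq_len, block_size);       // 19
  const int32_t split_blocks = ceil_div(tot_blocks, num_splits);  // 5

  for (int32_t split = 0; split < num_splits; ++split) {
    const int32_t split_start = split * split_blocks;
    const int32_t split_end = std::min((split + 1) * split_blocks, tot_blocks);
    if (split_start >= tot_blocks) continue;  // inactive split: nothing to copy

    // Only the split owning the final block copies the partial tail.
    int32_t full_blocks_end = split_end;
    int32_t partial_block_size = 0;
    if (split_end == tot_blocks) {
      partial_block_size = seq_len % block_size;
      if (partial_block_size) full_blocks_end -= 1;
    }
    std::printf("split %d: full blocks [%d, %d), partial tail %d tokens\n",
                split, split_start, full_blocks_end, partial_block_size);
  }
  return 0;
}

With these numbers the first three splits each copy five full blocks, while the last split copies three full blocks plus a 12-token partial block.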
+// Macro to dispatch the kernel based on the data type.
+#define CALL_GATHER_CACHE(CPY_DTYPE)                                     \
+  vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>(             \
+      reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()),                \
+      reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()),                      \
+      block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(),  \
+      block_size, entry_size, block_table_stride, cache_block_stride,    \
+      cache_entry_stride, dst_entry_stride, seq_starts_ptr);
+
+// Gather sequences from the cache into the destination tensor.
+// - cu_seq_lens contains the cumulative sequence lengths for each batch
+// - block_table contains the cache block indices for each sequence
+// - Optionally, seq_starts (if provided) offsets the starting block index
+//   by (seq_starts[bid] / block_size)
+void gather_cache(
+    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
+    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
+    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
+    int64_t batch_size,
+    std::optional<torch::Tensor> seq_starts = std::nullopt) {
+  at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  int32_t block_size = src_cache.size(1);
+  int32_t entry_size = src_cache.flatten(2, -1).size(2);
+
+  TORCH_CHECK(block_table.dtype() == torch::kInt32,
+              "block_table must be int32");
+  TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
+              "cu_seq_lens must be int32");
+  if (seq_starts.has_value()) {
+    TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
+                "seq_starts must be int32");
+  }
+
+  TORCH_CHECK(src_cache.device() == dst.device(),
+              "src_cache and dst must be on the same device");
+  TORCH_CHECK(src_cache.device() == block_table.device(),
+              "src_cache and block_table must be on the same device");
+  TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
+              "src_cache and cu_seq_lens must be on the same device");
+  if (seq_starts.has_value()) {
+    TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
+                "src_cache and seq_starts must be on the same device");
+  }
+
+  int64_t block_table_stride = block_table.stride(0);
+  int64_t cache_block_stride = src_cache.stride(0);
+  int64_t cache_entry_stride = src_cache.stride(1);
+  int64_t dst_entry_stride = dst.stride(0);
+
+  // Decide on the number of splits based on the batch size.
+  int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
+  dim3 grid(batch_size, num_splits);
+  dim3 block(1024);
+
+  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
+              "src_cache and dst must have the same dtype");
+
+  const int dtype_bits = src_cache.element_size() * 8;
+  const int32_t* seq_starts_ptr =
+      seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
+
+  if (dtype_bits == 32) {
+    CALL_GATHER_CACHE(uint32_t);
+  } else if (dtype_bits == 16) {
+    CALL_GATHER_CACHE(uint16_t);
+  } else if (dtype_bits == 8) {
+    CALL_GATHER_CACHE(uint8_t);
+  } else {
+    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
+  }
+}
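As a usage sketch only (not part of the commit), the new entry point could be exercised from C++ roughly as follows. The shapes follow the comments above; the example assumes the gather_cache declaration, including its defaulted seq_starts argument, is visible to the caller (the header it lives in is an assumption), and all sizes and values below are illustrative.

// Minimal sketch, assuming gather_cache(...) above is declared in a header
// visible to this translation unit (that header is an assumption).
#include <torch/torch.h>

void gather_cache_example() {
  const int64_t num_blocks = 64, block_size = 16, entry_size = 256;
  const int64_t batch_size = 2;

  auto val_opts =
      torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);
  auto idx_opts =
      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);

  // Paged cache: [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
  auto src_cache = torch::randn({num_blocks, block_size, entry_size}, val_opts);

  // Two sequences of 24 and 40 tokens -> cumulative lengths [0, 24, 64].
  auto cu_seq_lens = torch::tensor({0, 24, 64}, idx_opts);

  // Each row lists the cache blocks backing one sequence (2 and 3 blocks are
  // needed here; the table is padded to a fixed width).
  auto block_table = torch::randint(0, num_blocks, {batch_size, 4}, idx_opts);

  // Destination is flat over tokens: [TOT_TOKENS, ENTRIES...]
  auto dst = torch::empty({64, entry_size}, val_opts);

  // seq_starts is optional; omitting it gathers from each sequence's first block.
  gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size);
}

In the full project this entry point would typically be reached through a torch op registration rather than a direct call; the direct call here just keeps the sketch self-contained.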