
Commit 855ecfb

LucasWilkinson, pathorn, simon-mo, and tlrmchlsmth authored and committed
[Attention] MLA with chunked prefill (vllm-project#12639)
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Patrick Horn <[email protected]>
Co-authored-by: simon-mo <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
1 parent feaf88e commit 855ecfb

File tree: 18 files changed (+1910, -1275 lines)

csrc/cache.h

Lines changed: 7 additions & 0 deletions
@@ -39,3 +39,10 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);
+
+void gather_cache(
+    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
+    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
+    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
+    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);

csrc/cache_kernels.cu

Lines changed: 159 additions & 0 deletions
@@ -2,6 +2,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 
+#include "cuda_utils.h"
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 
@@ -570,3 +571,161 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
     TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
   }
 }
+
+namespace vllm {
+
+// grid is launched with dimensions (batch, num_splits)
+template <typename scalar_t>
+__global__ void gather_cache(
+    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
+                                              //  ENTRIES...]
+    scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
+    const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
+    const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
+    const int32_t block_size, const int32_t entry_size,
+    const int64_t block_table_stride, const int64_t cache_block_stride,
+    const int64_t cache_entry_stride, const int64_t dst_entry_stride,
+    const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets per
+                                               // batch
+
+  const int64_t bid = blockIdx.x;  // Batch ID
+  const int32_t num_splits = gridDim.y;
+  const int32_t split = blockIdx.y;
+  const int32_t seq_start = cu_seq_lens[bid];
+  const int32_t seq_end = cu_seq_lens[bid + 1];
+  const int32_t seq_len = seq_end - seq_start;
+  const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
+  const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
+
+  const int32_t split_start = split * split_blocks;
+  const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
+
+  const bool is_active_split = (split_start < tot_blocks);
+  const bool is_last_split = (split_end == tot_blocks);
+
+  if (!is_active_split) return;
+
+  int32_t full_blocks_end = split_end;
+  int32_t partial_block_size = 0;
+
+  // Adjust the pointer for the block_table for this batch.
+  // If seq_starts is provided, compute an offset based on (seq_starts[bid] /
+  // page_size)
+  const int32_t batch_offset = bid * block_table_stride;
+  int32_t offset = 0;
+  if (seq_starts != nullptr) {
+    offset = seq_starts[bid] / block_size;
+  }
+  const int32_t* batch_block_table = block_table + batch_offset + offset;
+
+  // Adjust dst pointer based on the cumulative sequence lengths.
+  dst += seq_start * dst_entry_stride;
+
+  if (is_last_split) {
+    partial_block_size = seq_len % block_size;
+    if (partial_block_size) full_blocks_end -= 1;
+  }
+
+  auto copy_entry = [&](const scalar_t* __restrict__ _src,
+                        scalar_t* __restrict__ _dst) {
+    for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
+      _dst[i] = _src[i];
+  };
+
+  for (int pid = split_start; pid < full_blocks_end; ++pid) {
+    auto block_id = batch_block_table[pid];
+    auto block_start_ptr = src_cache + block_id * cache_block_stride;
+    auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
+    for (int eid = 0; eid < block_size; ++eid) {
+      copy_entry(block_start_ptr + eid * cache_entry_stride,
+                 block_dst_ptr + eid * dst_entry_stride);
+    }
+  }
+
+  if (partial_block_size) {
+    auto block_id = batch_block_table[full_blocks_end];
+    auto block_start_ptr = src_cache + block_id * cache_block_stride;
+    auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
+    for (int eid = 0; eid < partial_block_size; ++eid) {
+      copy_entry(block_start_ptr + eid * cache_entry_stride,
+                 block_dst_ptr + eid * dst_entry_stride);
+    }
+  }
+}
+
+}  // namespace vllm
+
+// Macro to dispatch the kernel based on the data type.
+#define CALL_GATHER_CACHE(CPY_DTYPE)                                    \
+  vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>(            \
+      reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()),               \
+      reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()),                     \
+      block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
+      block_size, entry_size, block_table_stride, cache_block_stride,   \
+      cache_entry_stride, dst_entry_stride, seq_starts_ptr);
+
+// Gather sequences from the cache into the destination tensor.
+// - cu_seq_lens contains the cumulative sequence lengths for each batch
+// - block_table contains the cache block indices for each sequence
+// - Optionally, seq_starts (if provided) offsets the starting block index by
+//   (seq_starts[bid] / page_size)
+void gather_cache(
+    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
+    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
+    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
+    int64_t batch_size,
+    std::optional<torch::Tensor> seq_starts = std::nullopt) {
+  at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  int32_t block_size = src_cache.size(1);
+  int32_t entry_size = src_cache.flatten(2, -1).size(2);
+
+  TORCH_CHECK(block_table.dtype() == torch::kInt32,
+              "block_table must be int32");
+  TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
+              "cu_seq_lens must be int32");
+  if (seq_starts.has_value()) {
+    TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
+                "seq_starts must be int32");
+  }
+
+  TORCH_CHECK(src_cache.device() == dst.device(),
+              "src_cache and dst must be on the same device");
+  TORCH_CHECK(src_cache.device() == block_table.device(),
+              "src_cache and block_table must be on the same device");
+  TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
+              "src_cache and cu_seq_lens must be on the same device");
+  if (seq_starts.has_value()) {
+    TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
+                "src_cache and seq_starts must be on the same device");
+  }
+
+  int64_t block_table_stride = block_table.stride(0);
+  int64_t cache_block_stride = src_cache.stride(0);
+  int64_t cache_entry_stride = src_cache.stride(1);
+  int64_t dst_entry_stride = dst.stride(0);
+
+  // Decide on the number of splits based on the batch size.
+  int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
+  dim3 grid(batch_size, num_splits);
+  dim3 block(1024);
+
+  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
+              "src_cache and dst must have the same dtype");
+
+  const int dtype_bits = src_cache.element_size() * 8;
+  const int32_t* seq_starts_ptr =
+      seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
+
+  if (dtype_bits == 32) {
+    CALL_GATHER_CACHE(uint32_t);
+  } else if (dtype_bits == 16) {
+    CALL_GATHER_CACHE(uint16_t);
+  } else if (dtype_bits == 8) {
+    CALL_GATHER_CACHE(uint8_t);
+  } else {
+    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
  }
+}
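
For readers following the indexing: the kernel launches a (batch, num_splits) grid, each split copies a contiguous run of that sequence's cache blocks, and only the split that owns the tail handles a partial block. The following is a minimal CPU reference in Python that mirrors the gather for a contiguous [NUM_BLOCKS, BLOCK_SIZE, ENTRY_SIZE] cache; the function name and loop structure are illustrative only, not part of the commit.

from typing import Optional

import torch


def gather_cache_reference(src_cache: torch.Tensor,
                           block_table: torch.Tensor,
                           cu_seq_lens: torch.Tensor,
                           seq_starts: Optional[torch.Tensor] = None
                           ) -> torch.Tensor:
    """Illustrative CPU equivalent of the gather the kernel performs."""
    _, block_size, entry_size = src_cache.shape
    batch_size = cu_seq_lens.numel() - 1
    dst = torch.empty(int(cu_seq_lens[-1]), entry_size, dtype=src_cache.dtype)

    for b in range(batch_size):
        seq_start = int(cu_seq_lens[b])
        seq_len = int(cu_seq_lens[b + 1]) - seq_start
        # Optional per-sequence offset into the block table, in whole blocks,
        # matching the seq_starts[bid] / block_size computation above.
        offset = int(seq_starts[b]) // block_size if seq_starts is not None else 0
        tot_blocks = (seq_len + block_size - 1) // block_size
        for pid in range(tot_blocks):
            block_id = int(block_table[b, offset + pid])
            n = min(block_size, seq_len - pid * block_size)  # tail block may be partial
            row = seq_start + pid * block_size
            dst[row:row + n] = src_cache[block_id, :n, :]
    return dst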

csrc/core/math.hpp

Lines changed: 0 additions & 5 deletions
@@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
   if (num <= 1) return num;
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
-
-template <typename T>
-inline constexpr std::enable_if_t<std::is_integral_v<T>, T> ceil_div(T a, T b) {
-  return (a + b - 1) / b;
-}

csrc/cuda_utils.h

Lines changed: 18 additions & 4 deletions
@@ -2,10 +2,14 @@
 
 #include <stdio.h>
 
-#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
-#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
-#define DEVICE_INLINE __forceinline__ __device__
-#define HOST_INLINE __forceinline__ __host__
+#if defined(__HIPCC__)
+#define HOST_DEVICE_INLINE __host__ __device__
+#define DEVICE_INLINE __device__
+#define HOST_INLINE __host__
+#elif defined(__CUDACC__) || defined(_NVHPC_CUDA)
+#define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
+#define DEVICE_INLINE __device__ __forceinline__
+#define HOST_INLINE __host__ __forceinline__
 #else
 #define HOST_DEVICE_INLINE inline
 #define DEVICE_INLINE inline
@@ -25,3 +29,13 @@
 int64_t get_device_attribute(int64_t attribute, int64_t device_id);
 
 int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
+
+namespace cuda_utils {
+
+template <typename T>
+HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
+ceil_div(T a, T b) {
+  return (a + b - 1) / b;
+}
+
+};  // namespace cuda_utils

csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,7 @@
 #include <cudaTypedefs.h>
 #include "c3x/scaled_mm_kernels.hpp"
 
-#include "core/math.hpp"
+#include "cuda_utils.h"
 
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
@@ -33,7 +33,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
   auto make_group_shape = [](torch::Tensor const& x,
                              torch::Tensor const& s) -> GroupShape {
     TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
-    return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))};
+    return {cuda_utils::ceil_div(x.size(0), s.size(0)),
+            cuda_utils::ceil_div(x.size(1), s.size(1))};
   };
 
   GroupShape a_scale_group_shape = make_group_shape(a, a_scales);

csrc/torch_bindings.cpp

Lines changed: 6 additions & 0 deletions
@@ -493,6 +493,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
       "str kv_cache_dtype) -> ()");
   cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
+
+  // Gather cache blocks from src_cache to dst.
+  cache_ops.def(
+      "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
+      "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
+  cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {

tests/kernels/test_cache.py

Lines changed: 73 additions & 2 deletions
@@ -682,8 +682,6 @@ def test_swap_blocks_mla(
         torch.ops._C_cache_ops.swap_blocks,
         (src_cache, dst_cache, block_mapping_tensor),
         test_utils=DEFAULT_OPCHECK_TEST_UTILS,
-        cond=(kv_lora_rank == KV_LORA_RANKS[0]
-              and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]),
    )
 
     ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
@@ -694,3 +692,76 @@
             dst_cache[dst].cpu(),
             msg=f"Block {src} from src should have been swapped to block "
             f"{dst} in dst_cache.")
+
+
+@pytest.mark.parametrize("kv_lora_rank", [512])
+@pytest.mark.parametrize("qk_rope_head_dim", [64])
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("num_blocks", [1024])
+@pytest.mark.parametrize("max_seq_len", [512])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("dtype", [torch.float32])
+@pytest.mark.parametrize("kv_cache_dtype",
+                         ["auto"])  # You can also test "fp8" if needed.
+@pytest.mark.parametrize("align_cache", [True, False])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
+                          num_blocks, max_seq_len, batch_size, dtype,
+                          kv_cache_dtype, align_cache, device):
+    entry_size = kv_lora_rank + qk_rope_head_dim
+    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
+                                  kv_cache_dtype, device, align_cache)
+    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
+
+    seq_len_tensor = torch.randint(0,
+                                   max_seq_len + 1, (batch_size, ),
+                                   device=device)
+
+    total_tokens = seq_len_tensor.sum()
+    cu_seq_lens = torch.empty((batch_size + 1),
+                              dtype=torch.int32,
+                              device=device)
+    cu_seq_lens[0] = 0
+    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
+    print("seq_len_tensor", seq_len_tensor)
+
+    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
+    block_table = torch.empty((batch_size, num_blocks),
+                              dtype=torch.int32,
+                              device=device)
+
+    for b in range(batch_size):
+        perm = torch.randperm(num_blocks, device=device)
+        block_table[b, :] = perm
+
+    dst = torch.zeros((total_tokens, entry_size),
+                      dtype=src_cache.dtype,
+                      device=device)
+
+    expected_batches = []
+    for b in range(batch_size):
+        s = seq_len_tensor[b]
+        if s == 0:
+            continue
+        tot = tot_blocks_tensor[b]
+        blocks = block_table[b, :tot].tolist()
+
+        gathered_rows = []
+        for i in range(tot - 1):
+            gathered_rows.append(src_cache[blocks[i]])
+        remaining = s - (tot - 1) * block_size
+        gathered_rows.append(src_cache[blocks[-1], :remaining, :])
+
+        batch_expected = torch.cat(gathered_rows, dim=0)
+        expected_batches.append(batch_expected)
+    expected = torch.cat(expected_batches, dim=0)
+
+    opcheck(
+        torch.ops._C_cache_ops.gather_cache,
+        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
+        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
+    )
+
+    ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
+    torch.testing.assert_close(dst, expected)

vllm/_custom_ops.py

Lines changed: 10 additions & 0 deletions
@@ -1099,6 +1099,16 @@ def convert_fp8(output: torch.Tensor,
     torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
 
 
+def gather_cache(src_cache: torch.Tensor,
+                 dst: torch.Tensor,
+                 block_table: torch.Tensor,
+                 cu_seq_lens: torch.Tensor,
+                 batch_size: int,
+                 seq_starts: Optional[torch.Tensor] = None) -> None:
+    torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
+                                        cu_seq_lens, batch_size, seq_starts)
+
+
 def get_device_attribute(attribute: int, device: int) -> int:
     return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)

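Putting the pieces together, the new wrapper can be invoked much like the unit test does. Below is a minimal sketch, assuming a CUDA device and a vLLM build with the compiled `_C_cache_ops` extension; the sizes and tensor values are purely illustrative.

import torch

from vllm import _custom_ops as ops

# entry_size = kv_lora_rank + qk_rope_head_dim, as in the test above.
num_blocks, block_size, entry_size = 1024, 16, 512 + 64
src_cache = torch.randn(num_blocks, block_size, entry_size,
                        dtype=torch.float32, device="cuda")

seq_lens = torch.tensor([100, 37], dtype=torch.int32, device="cuda")
cu_seq_lens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32, device="cuda")
cu_seq_lens[1:] = seq_lens.cumsum(dim=0).to(dtype=torch.int32)

# One row of distinct block indices per sequence; only the first
# ceil(seq_len / block_size) entries of each row are read.
block_table = torch.stack([
    torch.randperm(num_blocks, device="cuda") for _ in range(seq_lens.numel())
]).to(torch.int32)

dst = torch.empty(int(cu_seq_lens[-1]), entry_size,
                  dtype=src_cache.dtype, device="cuda")
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens,
                 batch_size=seq_lens.numel())

# The raw op is equivalent, but the optional seq_starts argument must then be
# passed explicitly (None when unused), as in the opcheck call in the test:
# torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
#                                     cu_seq_lens, seq_lens.numel(), None)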
vllm/attention/__init__.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,12 @@
44
AttentionMetadata,
55
AttentionMetadataBuilder,
66
AttentionState, AttentionType)
7+
from vllm.attention.backends.utils import get_flash_attn_version
78
from vllm.attention.layer import Attention
89
from vllm.attention.selector import get_attn_backend
910

1011
__all__ = [
11-
"Attention",
12-
"AttentionBackend",
13-
"AttentionMetadata",
14-
"AttentionType",
15-
"AttentionMetadataBuilder",
16-
"Attention",
17-
"AttentionState",
18-
"get_attn_backend",
12+
"Attention", "AttentionBackend", "AttentionMetadata", "AttentionType",
13+
"AttentionMetadataBuilder", "Attention", "AttentionState",
14+
"get_attn_backend", "get_flash_attn_version"
1915
]
