
Commit ab31312

Xin Li (leex404) authored and committed
[feat] support bf16 cp_gather_indexer_k_cache kernel
Signed-off-by: Xin Li <[email protected]>
1 parent fb17679 commit ab31312

File tree

4 files changed: +263 -45 lines changed


csrc/cache.h

Lines changed: 15 additions & 0 deletions
@@ -70,3 +70,18 @@ void indexer_k_quant_and_cache(
    torch::Tensor& slot_mapping,   // [num_tokens]
    int64_t quant_block_size,      // quantization block size
    const std::string& scale_fmt);
+
+// Gather function for the unquantized (bf16/fp16) K cache
+void cp_gather_indexer_k_cache(
+    const torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
+    torch::Tensor& dst_k,               // [num_tokens, head_dim]
+    const torch::Tensor& block_table,   // [batch_size, num_blocks]
+    const torch::Tensor& cu_seq_lens);  // [batch_size + 1]
+
+// Gather function for the quantized K cache
+void cp_gather_indexer_k_quant_cache(
+    const torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
+    torch::Tensor& dst_k,               // [num_tokens, head_dim]
+    torch::Tensor& dst_scale,           // [num_tokens, head_dim / quant_block_size * 4]
+    const torch::Tensor& block_table,   // [batch_size, num_blocks]
+    const torch::Tensor& cu_seq_lens);  // [batch_size + 1]
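Both declarations share the same gather semantics: for every token in the flattened batch, cu_seq_lens identifies the owning request and block_table maps the in-request position to a physical cache block. A minimal pure-PyTorch sketch of that indexing follows; the helper name gather_k_reference and the assumption that each cache slot holds exactly head_dim elements are illustrative, not part of the commit.

import torch

def gather_k_reference(kv_cache: torch.Tensor,     # [num_blocks, block_size, head_dim]
                       block_table: torch.Tensor,  # [batch_size, max_blocks], int32
                       cu_seq_lens: torch.Tensor   # [batch_size + 1], int32
                       ) -> torch.Tensor:          # [num_tokens, head_dim]
    block_size = kv_cache.shape[1]
    rows = []
    for b in range(block_table.shape[0]):
        seq_len = int(cu_seq_lens[b + 1]) - int(cu_seq_lens[b])
        for tok in range(seq_len):
            # In-request position -> physical block via the block table,
            # then the offset inside that block.
            blk = int(block_table[b, tok // block_size])
            rows.append(kv_cache[blk, tok % block_size])
    if not rows:
        return kv_cache.new_empty((0, kv_cache.shape[-1]))
    return torch.stack(rows, dim=0)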

csrc/cache_kernels.cu

Lines changed: 228 additions & 0 deletions
@@ -938,6 +938,124 @@ __global__ void indexer_k_cache_kernel(
    kv_cache[dst_offset + i] = k_val_ptr[i];
  }
}
+
+template <int BLOCK_Y_SIZE>
+__global__ void cp_gather_indexer_k_quant_cache_kernel(
+    const char* __restrict__ kv_cache,    // [num_blocks, block_size,
+                                          //  cache_stride]
+    char* __restrict__ dst_k,             // [num_tokens, head_dim]
+    char* __restrict__ dst_scale,         // [num_tokens,
+                                          //  head_dim / quant_block_size * 4]
+    const int* __restrict__ block_table,  // [batch_size, num_blocks]
+    const int* __restrict__ cu_seq_lens,  // [batch_size + 1]
+    const int batch_size,                 // batch size
+    const int64_t token_stride,           // stride for each token in dst_k
+    const int64_t head_dim,               // dimension of each head
+    const int64_t block_stride,           // stride for each block in kv_cache
+    const int64_t cache_token_stride,     // stride for each token in kv_cache
+    const int64_t cache_block_size,       // num_tokens per block in kv_cache
+    const int num_blocks,                 // number of blocks
+    const int num_tokens,                 // number of tokens
+    const int quant_block_size            // quantization block size
+) {
+  constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
+  const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
+  // Find the batch index this token belongs to via cu_seq_lens.
+  __shared__ int batch_idx[BLOCK_Y_SIZE];
+  for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x));
+       iter++) {
+    int tid = iter * blockDim.x + threadIdx.x;
+    if (tid < batch_size) {
+      const int seq_start = cu_seq_lens[tid];
+      const int seq_end = cu_seq_lens[tid + 1];
+      if (token_idx >= seq_start && token_idx < seq_end) {
+        batch_idx[threadIdx.y] = tid;
+      }
+    }
+  }
+
+#ifndef USE_ROCM
+  __syncwarp();
+#endif
+
+  if (head_idx >= head_dim || token_idx >= num_tokens) {
+    return;
+  }
+  const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]];
+  const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks +
+                                    inbatch_seq_idx / cache_block_size];
+  const int64_t src_block_offset = block_idx * block_stride;
+  const int64_t cache_inblock_offset =
+      (inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
+  const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset;
+  const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;
+
+  // Copy 16 bytes of K data per thread.
+  reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
+      reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];
+  if (threadIdx.x == 0) {
+    // Copy the per-quant-block scale stored after the K data of the block.
+    const int64_t src_scale_offset =
+        src_block_offset + cache_block_size * head_dim +
+        cache_inblock_offset * 4 / quant_block_size;
+    reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
+        reinterpret_cast<const float*>(kv_cache)[src_scale_offset / 4];
+  }
+}
+
+template <int BLOCK_Y_SIZE>
+__global__ void cp_gather_indexer_k_cache_kernel(
+    const char* __restrict__ kv_cache,    // [num_blocks, block_size,
+                                          //  cache_stride]
+    char* __restrict__ dst_k,             // [num_tokens, head_dim]
+    const int* __restrict__ block_table,  // [batch_size, num_blocks]
+    const int* __restrict__ cu_seq_lens,  // [batch_size + 1]
+    const int batch_size,                 // batch size
+    const int64_t token_stride,           // stride for each token in dst_k
+    const int64_t head_dim,               // dimension of each head
+    const int64_t block_stride,           // stride for each block in kv_cache
+    const int64_t cache_token_stride,     // stride for each token in kv_cache
+    const int64_t cache_block_size,       // num_tokens per block in kv_cache
+    const int num_blocks,                 // number of blocks
+    const int num_tokens                  // number of tokens
+) {
+  constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
+  const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
+  // Find the batch index this token belongs to via cu_seq_lens.
+  __shared__ int batch_idx[BLOCK_Y_SIZE];
+  for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x));
+       iter++) {
+    int tid = iter * blockDim.x + threadIdx.x;
+    if (tid < batch_size) {
+      const int seq_start = cu_seq_lens[tid];
+      const int seq_end = cu_seq_lens[tid + 1];
+      if (token_idx >= seq_start && token_idx < seq_end) {
+        batch_idx[threadIdx.y] = tid;
+      }
+    }
+  }
+
+#ifndef USE_ROCM
+  __syncwarp();
+#endif
+
+  if (head_idx >= head_dim || token_idx >= num_tokens) {
+    return;
+  }
+  const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]];
+  const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks +
+                                    inbatch_seq_idx / cache_block_size];
+  const int64_t src_block_offset = block_idx * block_stride;
+  const int64_t cache_inblock_offset =
+      (inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
+  const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset;
+  const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;
+
+  // Copy 16 bytes of K data per thread; the bf16 path has no scale payload.
+  reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
+      reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];
+}
+
 }  // namespace vllm

 // Macro to dispatch the kernel based on the data type.
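The shared-memory scan over cu_seq_lens in both kernels just recovers, for each flattened token index, which request owns it and the token's position inside that request. In PyTorch the same mapping is a searchsorted over the cumulative lengths; the values below are illustrative, not taken from the commit.

import torch

# Three requests with lengths 3, 0, 4 (illustrative).
cu_seq_lens = torch.tensor([0, 3, 3, 7])
token_idx = torch.arange(int(cu_seq_lens[-1]))

# right=True returns the first entry strictly greater than the token index,
# so subtracting 1 lands on the request that owns the token.
batch_idx = torch.searchsorted(cu_seq_lens, token_idx, right=True) - 1
inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx]

print(batch_idx.tolist())        # [0, 0, 0, 2, 2, 2, 2]
print(inbatch_seq_idx.tolist())  # [0, 1, 2, 0, 1, 2, 3]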
@@ -1083,4 +1201,114 @@ void indexer_k_cache(
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "auto", CALL_INDEXER_K_CACHE);
+}
+
+// Macro to dispatch the kernel based on the data amount.
+#define CALL_CP_GATHER_INDEXER_K_CACHE(BLOCK_Y_SIZE)                        \
+  vllm::cp_gather_indexer_k_cache_kernel<BLOCK_Y_SIZE>                      \
+      <<<dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE,               \
+              (head_dim + 8 * vec_size - 1) / (8 * vec_size)),              \
+         dim3(8, BLOCK_Y_SIZE), 0, stream>>>(                               \
+          reinterpret_cast<char*>(kv_cache.data_ptr()),                     \
+          reinterpret_cast<char*>(dst_k.data_ptr()),                        \
+          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
+          batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0),   \
+          kv_cache.stride(1), kv_cache.size(1), block_table.size(1),        \
+          num_tokens);
+
+void cp_gather_indexer_k_cache(
+    const torch::Tensor& kv_cache,     // [num_blocks, block_size, cache_stride]
+    torch::Tensor& dst_k,              // [num_tokens, head_dim]
+    const torch::Tensor& block_table,  // [batch_size, num_blocks]
+    const torch::Tensor& cu_seq_lens   // [batch_size + 1]
+) {
+  int batch_size = block_table.size(0);
+  int num_tokens = dst_k.size(0);
+  int head_dim = dst_k.size(1);
+  // int quant_block_size = head_dim * 4 / dst_scale.size(1);
+
+  TORCH_CHECK(kv_cache.device() == dst_k.device(),
+              "kv_cache and dst_k must be on the same device");
+  // TORCH_CHECK(kv_cache.device() == dst_scale.device(),
+  //             "kv_cache and dst_scale must be on the same device");
+  TORCH_CHECK(kv_cache.device() == block_table.device(),
+              "kv_cache and block_table must be on the same device");
+  TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(),
+              "kv_cache and cu_seq_lens must be on the same device");
+  // TORCH_CHECK(head_dim % quant_block_size == 0,
+  //             "head_dim must be divisible by quant_block_size");
+
+  constexpr int vec_size = 16;
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (num_tokens < 32) {
+    CALL_CP_GATHER_INDEXER_K_CACHE(1);
+  } else if (num_tokens < 64) {
+    CALL_CP_GATHER_INDEXER_K_CACHE(2);
+  } else if (num_tokens < 128) {
+    CALL_CP_GATHER_INDEXER_K_CACHE(4);
+  } else if (num_tokens < 256) {
+    CALL_CP_GATHER_INDEXER_K_CACHE(8);
+  } else if (num_tokens < 512) {
+    CALL_CP_GATHER_INDEXER_K_CACHE(16);
+  } else {
+    CALL_CP_GATHER_INDEXER_K_CACHE(32);
+  }
+}
+
+// Macro to dispatch the kernel based on the data amount.
+#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE)                  \
+  vllm::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE>                \
+      <<<dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE,               \
+              (head_dim + 8 * vec_size - 1) / (8 * vec_size)),              \
+         dim3(8, BLOCK_Y_SIZE), 0, stream>>>(                               \
+          reinterpret_cast<char*>(kv_cache.data_ptr()),                     \
+          reinterpret_cast<char*>(dst_k.data_ptr()),                        \
+          reinterpret_cast<char*>(dst_scale.data_ptr()),                    \
+          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
+          batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0),   \
+          kv_cache.stride(1), kv_cache.size(1), block_table.size(1),        \
+          num_tokens, quant_block_size);
+
+void cp_gather_indexer_k_quant_cache(
+    const torch::Tensor& kv_cache,     // [num_blocks, block_size, cache_stride]
+    torch::Tensor& dst_k,              // [num_tokens, head_dim]
+    torch::Tensor& dst_scale,  // [num_tokens, head_dim / quant_block_size * 4]
+    const torch::Tensor& block_table,  // [batch_size, num_blocks]
+    const torch::Tensor& cu_seq_lens   // [batch_size + 1]
+) {
+  int batch_size = block_table.size(0);
+  int num_tokens = dst_k.size(0);
+  int head_dim = dst_k.size(1);
+  int quant_block_size = head_dim * 4 / dst_scale.size(1);
+
+  TORCH_CHECK(kv_cache.device() == dst_k.device(),
+              "kv_cache and dst_k must be on the same device");
+  TORCH_CHECK(kv_cache.device() == dst_scale.device(),
+              "kv_cache and dst_scale must be on the same device");
+  TORCH_CHECK(kv_cache.device() == block_table.device(),
+              "kv_cache and block_table must be on the same device");
+  TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(),
+              "kv_cache and cu_seq_lens must be on the same device");
+  TORCH_CHECK(head_dim % quant_block_size == 0,
+              "head_dim must be divisible by quant_block_size");
+
+  constexpr int vec_size = 16;
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (num_tokens < 32) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1);
+  } else if (num_tokens < 64) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2);
+  } else if (num_tokens < 128) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4);
+  } else if (num_tokens < 256) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8);
+  } else if (num_tokens < 512) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16);
+  } else {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
+  }
 }
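Both host functions above launch a 2-D grid: blockDim is (8, BLOCK_Y_SIZE), so eight threads cooperate on one token and each thread moves a 16-byte float4, while BLOCK_Y_SIZE grows with the token count. A small sketch of that grid arithmetic; the helper name launch_geometry is an assumption for illustration, and head_dim is counted in bytes here to match the char-typed kernels.

def launch_geometry(num_tokens: int, head_dim: int, vec_size: int = 16):
    # BLOCK_Y_SIZE is chosen from num_tokens, capped at 32.
    for threshold, block_y in ((32, 1), (64, 2), (128, 4), (256, 8), (512, 16)):
        if num_tokens < threshold:
            break
    else:
        block_y = 32
    grid_x = (num_tokens + block_y - 1) // block_y            # BLOCK_Y_SIZE tokens per block
    grid_y = (head_dim + 8 * vec_size - 1) // (8 * vec_size)  # 8 threads x 16 bytes per block
    block = (8, block_y)                                      # (threadIdx.x, threadIdx.y)
    return (grid_x, grid_y), block

# Example: 300 tokens, 128 bytes of K data per token
# -> BLOCK_Y_SIZE = 16, grid (19, 1), block (8, 16)
print(launch_geometry(300, 128))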

vllm_metax/_custom_ops.py

Lines changed: 17 additions & 0 deletions
@@ -54,3 +54,20 @@ def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor,
    else:
        torch.ops._C_cache_ops.indexer_k_quant_and_cache(
            k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype)
+
+
+def cp_gather_indexer_k_quant_cache(
+    kv_cache: torch.Tensor,
+    dst_k: torch.Tensor,
+    dst_scale: torch.Tensor,  # may be None on the bf16/fp16 path
+    block_table: torch.Tensor,
+    cu_seq_lens: torch.Tensor,
+) -> None:
+    # Route to the unquantized gather kernel when dst_k is bf16/fp16 or no
+    # scale tensor is provided; otherwise use the quantized variant.
+    if dst_k.dtype in (torch.bfloat16, torch.float16) or dst_scale is None:
+        torch.ops._C_cache_ops.cp_gather_indexer_k_cache(
+            kv_cache, dst_k, block_table, cu_seq_lens
+        )
+    else:
+        torch.ops._C_cache_ops.cp_gather_indexer_k_quant_cache(
+            kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
+        )
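A quick sketch of the wrapper's routing rule only; the shapes are placeholders on the meta device so no kernel is launched, and the helper name routes_to_bf16_gather is an assumption for illustration, not part of the commit.

import torch

def routes_to_bf16_gather(dst_k: torch.Tensor, dst_scale) -> bool:
    # Mirrors the dtype/None check in cp_gather_indexer_k_quant_cache above.
    return dst_k.dtype in (torch.bfloat16, torch.float16) or dst_scale is None

dst_k_bf16 = torch.empty(100, 128, dtype=torch.bfloat16, device="meta")
dst_k_u8 = torch.empty(100, 128, dtype=torch.uint8, device="meta")
dst_scale = torch.empty(100, 32, dtype=torch.uint8, device="meta")

print(routes_to_bf16_gather(dst_k_bf16, None))     # True  -> cp_gather_indexer_k_cache
print(routes_to_bf16_gather(dst_k_u8, dst_scale))  # False -> cp_gather_indexer_k_quant_cache
print(routes_to_bf16_gather(dst_k_u8, None))       # True  -> bf16 path when no scales are given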

vllm_metax/models/deepseek_v2.py

Lines changed: 3 additions & 45 deletions
@@ -506,49 +506,6 @@ def get_attn_backend(self) -> AttentionBackend:
        return DeepseekV32IndexerBackend


-@torch.inference_mode()
-def cp_gather_indexer_k_quant_cache(
-    kv_cache,  # [num_blocks, block_size, head_dim + 1]
-    dst_value,  # [cu_seq_lens[-1], head_dim]
-    block_table,  # [batch_size, num_blocks]
-    cu_seq_lens,  # [batch_size + 1, ]
-    batch_size,
-):
-    num_blocks, block_size, _ = kv_cache.shape
-    head_dim = dst_value.shape[-1]
-    kv_cache = kv_cache.view(num_blocks, -1)
-
-    expected_value = []
-    for b in range(batch_size):
-        s = cu_seq_lens[b + 1] - cu_seq_lens[b]
-        if s == 0:
-            continue
-        tot = cdiv(s, block_size)
-        blocks = block_table[b, :tot]
-
-        value = []
-        full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32)
-        non_remaining_value = kv_cache[
-            blocks[full_block], : block_size * head_dim
-        ].view(-1, head_dim)
-
-        remaining = s - (tot - 1) * block_size
-
-        value = torch.cat(
-            [
-                non_remaining_value,
-                kv_cache[blocks[-1], : remaining * head_dim].view(-1, head_dim),
-            ],
-            dim=0,
-        )
-
-        expected_value.append(value)
-
-    gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
-    gather_value = gather_value.view(torch.bfloat16)
-    dst_value.copy_(gather_value)
-
-
 def sparse_attn_indexer(
     hidden_states: torch.Tensor,
     k_cache_prefix: str,
@@ -607,12 +564,13 @@ def sparse_attn_indexer(
                device=k_bf16.device,
                dtype=torch.bfloat16,
            )
-            cp_gather_indexer_k_quant_cache(
+            k_scale = None
+            mx_ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                _k_bf16,
+                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
-                chunk.num_reqs,
            )
            logits = bf16_mqa_logits(
                q_bf16[chunk.token_start : chunk.token_end],