@@ -938,6 +938,124 @@ __global__ void indexer_k_cache_kernel(
     kv_cache[dst_offset + i] = k_val_ptr[i];
   }
 }
+
+template <int BLOCK_Y_SIZE>
+__global__ void cp_gather_indexer_k_quant_cache_kernel(
+    const char* __restrict__ kv_cache,  // [num_blocks, block_size,
+                                        // cache_stride]
+    char* __restrict__ dst_k,      // [num_tokens, head_dim]
+    char* __restrict__ dst_scale,  // [num_tokens, head_dim / quant_block_size *
+                                   // 4]
+    const int* __restrict__ block_table,  // [batch_size, num_blocks]
+    const int* __restrict__ cu_seq_lens,  // [batch_size + 1]
+    const int batch_size,                 // batch size
+    const int64_t token_stride,           // stride for each token in dst_k
+    const int64_t head_dim,               // dimension of each head
+    const int64_t block_stride,           // stride for each block in kv_cache
+    const int64_t cache_token_stride,  // stride for each token in kv_cache
+    const int64_t cache_block_size,    // num_tokens for each block in kv_cache
+    const int num_blocks,              // number of blocks
+    const int num_tokens,              // number of tokens
+    const int quant_block_size         // quantization block size
+) {
+  constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
+  const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
+  // Find batch index within a block
+  __shared__ int batch_idx[BLOCK_Y_SIZE];
+  for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x));
+       iter++) {
+    int tid = iter * blockDim.x + threadIdx.x;
+    if (tid < batch_size) {
+      const int seq_start = cu_seq_lens[tid];
+      const int seq_end = cu_seq_lens[tid + 1];
+      if (token_idx >= seq_start && token_idx < seq_end) {
+        batch_idx[threadIdx.y] = tid;
+      }
+    }
+  }
+
+#ifndef USE_ROCM
+  __syncwarp();
+#endif
+
+  if (head_idx >= head_dim || token_idx >= num_tokens) {
+    return;
+  }
+  const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]];
+  const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks +
+                                    inbatch_seq_idx / cache_block_size];
+  const int64_t src_block_offset = block_idx * block_stride;
+  const int64_t cache_inblock_offset =
+      (inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
+  const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset;
+  const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;
+
+  reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
+      reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];
+
+  if (threadIdx.x == 0) {
+    const int64_t src_scale_offset =
+        src_block_offset + cache_block_size * head_dim +
+        cache_inblock_offset * 4 / quant_block_size;
+    reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
+        reinterpret_cast<const float*>(kv_cache)[src_scale_offset / 4];
+  }
+}
+
+template <int BLOCK_Y_SIZE>
+__global__ void cp_gather_indexer_k_cache_kernel(
+    const char* __restrict__ kv_cache,  // [num_blocks, block_size,
+                                        // cache_stride]
+    char* __restrict__ dst_k,             // [num_tokens, head_dim]
+    const int* __restrict__ block_table,  // [batch_size, num_blocks]
+    const int* __restrict__ cu_seq_lens,  // [batch_size + 1]
+    const int batch_size,                 // batch size
+    const int64_t token_stride,           // stride for each token in dst_k
+    const int64_t head_dim,               // dimension of each head
+    const int64_t block_stride,           // stride for each block in kv_cache
+    const int64_t cache_token_stride,  // stride for each token in kv_cache
+    const int64_t cache_block_size,    // num_tokens for each block in kv_cache
+    const int num_blocks,              // number of blocks
+    const int num_tokens               // number of tokens
+) {
+  constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
+  const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
+  // Find batch index within a block
+  __shared__ int batch_idx[BLOCK_Y_SIZE];
+  for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x));
+       iter++) {
+    int tid = iter * blockDim.x + threadIdx.x;
+    if (tid < batch_size) {
+      const int seq_start = cu_seq_lens[tid];
+      const int seq_end = cu_seq_lens[tid + 1];
+      if (token_idx >= seq_start && token_idx < seq_end) {
+        batch_idx[threadIdx.y] = tid;
+      }
+    }
+  }
+
+#ifndef USE_ROCM
+  __syncwarp();
+#endif
+
+  if (head_idx >= head_dim || token_idx >= num_tokens) {
+    return;
+  }
+  const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]];
+  const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks +
+                                    inbatch_seq_idx / cache_block_size];
+  const int64_t src_block_offset = block_idx * block_stride;
+  const int64_t cache_inblock_offset =
+      (inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
+  const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset;
+  const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;
+
+  reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
+      reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];
+}
+
 }  // namespace vllm
 
 // Macro to dispatch the kernel based on the data type.
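For reference, both gather kernels implement the same mapping: a destination token is assigned to its sequence via cu_seq_lens, its logical block index is translated to a physical block through block_table, and head_dim bytes of K are copied out of that block; the quantized variant additionally copies one float scale per quant_block_size elements from the scale region stored after the cache_block_size * head_dim K bytes of each block. The sketch below is a sequential CPU reference of that mapping for illustration only; the name cp_gather_indexer_k_cache_ref and the plain-pointer interface are assumptions, not part of this diff.

#include <cstdint>
#include <cstring>

// Illustrative CPU reference of the gather done by
// cp_gather_indexer_k_cache_kernel (sketch only).
void cp_gather_indexer_k_cache_ref(
    const char* kv_cache, char* dst_k, const int* block_table,
    const int* cu_seq_lens, int batch_size, int64_t token_stride,
    int64_t head_dim, int64_t block_stride, int64_t cache_block_size,
    int num_blocks, int num_tokens) {
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    // Locate the sequence that owns this token via cumulative lengths.
    int b = 0;
    while (b < batch_size && token_idx >= cu_seq_lens[b + 1]) ++b;
    const int inbatch_seq_idx = token_idx - cu_seq_lens[b];
    // Translate the logical block index to a physical cache block.
    const int block_idx =
        block_table[b * num_blocks + inbatch_seq_idx / cache_block_size];
    const char* src = kv_cache + block_idx * block_stride +
                      (inbatch_seq_idx % cache_block_size) * head_dim;
    // Copy head_dim bytes of K for this token.
    std::memcpy(dst_k + token_idx * token_stride, src, head_dim);
  }
}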
@@ -1083,4 +1201,114 @@ void indexer_k_cache(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "auto", CALL_INDEXER_K_CACHE);
+}
+
+// Macro to dispatch the kernel based on the data amount.
+#define CALL_CP_GATHER_INDEXER_K_CACHE(BLOCK_Y_SIZE)                        \
+  vllm::cp_gather_indexer_k_cache_kernel<BLOCK_Y_SIZE>                      \
+      <<<dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE,               \
+              (head_dim + 8 * vec_size - 1) / (8 * vec_size)),              \
+         dim3(8, BLOCK_Y_SIZE), 0, stream>>>(                               \
+          reinterpret_cast<char*>(kv_cache.data_ptr()),                     \
+          reinterpret_cast<char*>(dst_k.data_ptr()),                        \
+          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
+          batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0),   \
+          kv_cache.stride(1), kv_cache.size(1), block_table.size(1),        \
+          num_tokens);
+
+void cp_gather_indexer_k_cache(
+    const torch::Tensor& kv_cache,  // [num_blocks, block_size, cache_stride]
+    torch::Tensor& dst_k,           // [num_tokens, head_dim]
+    const torch::Tensor& block_table,  // [batch_size, num_blocks]
+    const torch::Tensor& cu_seq_lens   // [batch_size + 1]
+) {
+  int batch_size = block_table.size(0);
+  int num_tokens = dst_k.size(0);
+  int head_dim = dst_k.size(1);
+  // int quant_block_size = head_dim * 4 / dst_scale.size(1);
+
+  TORCH_CHECK(kv_cache.device() == dst_k.device(),
+              "kv_cache and dst_k must be on the same device");
+  // TORCH_CHECK(kv_cache.device() == dst_scale.device(),
+  //             "kv_cache and dst_scale must be on the same device");
+  TORCH_CHECK(kv_cache.device() == block_table.device(),
+              "kv_cache and block_table must be on the same device");
+  TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(),
+              "kv_cache and cu_seq_lens must be on the same device");
+  // TORCH_CHECK(head_dim % quant_block_size == 0,
+  //             "head_dim must be divisible by quant_block_size");
+
+  constexpr int vec_size = 16;
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (num_tokens < 32) {
+    CALL_CP_GATHER_INDEXER_K_CACHE(1);
+  } else if (num_tokens < 64) {
+    CALL_CP_GATHER_INDEXER_K_CACHE(2);
+  } else if (num_tokens < 128) {
+    CALL_CP_GATHER_INDEXER_K_CACHE(4);
+  } else if (num_tokens < 256) {
+    CALL_CP_GATHER_INDEXER_K_CACHE(8);
+  } else if (num_tokens < 512) {
+    CALL_CP_GATHER_INDEXER_K_CACHE(16);
+  } else {
+    CALL_CP_GATHER_INDEXER_K_CACHE(32);
+  }
+}
+
+// Macro to dispatch the kernel based on the data amount.
+#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE)                  \
+  vllm::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE>                \
+      <<<dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE,               \
+              (head_dim + 8 * vec_size - 1) / (8 * vec_size)),              \
+         dim3(8, BLOCK_Y_SIZE), 0, stream>>>(                               \
+          reinterpret_cast<char*>(kv_cache.data_ptr()),                     \
+          reinterpret_cast<char*>(dst_k.data_ptr()),                        \
+          reinterpret_cast<char*>(dst_scale.data_ptr()),                    \
+          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
+          batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0),   \
+          kv_cache.stride(1), kv_cache.size(1), block_table.size(1),        \
+          num_tokens, quant_block_size);
+
+void cp_gather_indexer_k_quant_cache(
+    const torch::Tensor& kv_cache,  // [num_blocks, block_size, cache_stride]
+    torch::Tensor& dst_k,           // [num_tokens, head_dim]
+    torch::Tensor& dst_scale,  // [num_tokens, head_dim / quant_block_size * 4]
+    const torch::Tensor& block_table,  // [batch_size, num_blocks]
+    const torch::Tensor& cu_seq_lens   // [batch_size + 1]
+) {
+  int batch_size = block_table.size(0);
+  int num_tokens = dst_k.size(0);
+  int head_dim = dst_k.size(1);
+  int quant_block_size = head_dim * 4 / dst_scale.size(1);
+
+  TORCH_CHECK(kv_cache.device() == dst_k.device(),
+              "kv_cache and dst_k must be on the same device");
+  TORCH_CHECK(kv_cache.device() == dst_scale.device(),
+              "kv_cache and dst_scale must be on the same device");
+  TORCH_CHECK(kv_cache.device() == block_table.device(),
+              "kv_cache and block_table must be on the same device");
+  TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(),
+              "kv_cache and cu_seq_lens must be on the same device");
+  TORCH_CHECK(head_dim % quant_block_size == 0,
+              "head_dim must be divisible by quant_block_size");
+
+  constexpr int vec_size = 16;
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (num_tokens < 32) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1);
+  } else if (num_tokens < 64) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2);
+  } else if (num_tokens < 128) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4);
+  } else if (num_tokens < 256) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8);
+  } else if (num_tokens < 512) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16);
+  } else {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
+  }
+}
 }
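A hedged usage sketch of the new quantized host entry point follows. The forward declaration mirrors the wrapper signature in this diff (normally it would come from the project's cache-ops header); the block counts, head size, batch split, and the FP8-byte-plus-float-scale layout constants are illustrative assumptions, not values taken from the source.

#include <torch/torch.h>

// Declaration mirroring the wrapper added above (sketch only).
void cp_gather_indexer_k_quant_cache(const torch::Tensor& kv_cache,
                                     torch::Tensor& dst_k,
                                     torch::Tensor& dst_scale,
                                     const torch::Tensor& block_table,
                                     const torch::Tensor& cu_seq_lens);

int main() {
  // Illustrative sizes: 8 physical blocks of 64 tokens, 128-dim K vectors,
  // one 4-byte float scale per 128 quantized elements.
  const int64_t num_blocks = 8, block_size = 64, head_dim = 128;
  const int64_t quant_block_size = 128, batch_size = 2, num_tokens = 100;
  const int64_t cache_stride = head_dim + head_dim / quant_block_size * 4;

  auto byte_opts =
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
  auto int_opts =
      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);

  torch::Tensor kv_cache =
      torch::zeros({num_blocks, block_size, cache_stride}, byte_opts);
  // Two sequences of 70 and 30 tokens; each row of block_table lists the
  // physical block ids backing one sequence.
  torch::Tensor cu_seq_lens = torch::tensor({0, 70, 100}, int_opts);
  torch::Tensor block_table =
      torch::arange(batch_size * (num_blocks / batch_size), int_opts)
          .reshape({batch_size, num_blocks / batch_size});
  torch::Tensor dst_k = torch::empty({num_tokens, head_dim}, byte_opts);
  torch::Tensor dst_scale =
      torch::empty({num_tokens, head_dim / quant_block_size * 4}, byte_opts);

  // Gathers the quantized K bytes and their scales for all 100 tokens.
  cp_gather_indexer_k_quant_cache(kv_cache, dst_k, dst_scale, block_table,
                                  cu_seq_lens);
  return 0;
}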