@@ -203,17 +203,18 @@ __global__ void reshape_and_cache_kernel(
203203 }
204204}
205205
206- template <typename scalar_t>
206+ template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
207207__global__ void reshape_and_cache_flash_kernel(
208208 const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
209209 const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
210- scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads,
210+ cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads,
211211 // head_size]
212- scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads,
212+ cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
213213 // head_size]
214214 const int64_t* __restrict__ slot_mapping, // [num_tokens]
215215 const int block_stride, const int key_stride, const int value_stride,
216- const int num_heads, const int head_size, const int block_size) {
216+ const int num_heads, const int head_size, const int block_size,
217+ const float k_scale, const float v_scale) {
217218 const int64_t token_idx = blockIdx.x;
218219 const int64_t slot_idx = slot_mapping[token_idx];
219220 // NOTE: slot_idx can be -1 if the token is padded
@@ -228,11 +229,20 @@ __global__ void reshape_and_cache_flash_kernel(
228229 const int64_t src_value_idx = token_idx * value_stride + i;
229230 const int head_idx = i / head_size;
230231 const int head_offset = i % head_size;
231- const int64_t tgt_value_idx = block_idx * block_stride +
232- block_offset * num_heads * head_size +
233- head_idx * head_size + head_offset;
234- k_cache[tgt_value_idx] = key[src_key_idx];
235- v_cache[tgt_value_idx] = value[src_value_idx];
232+ const int64_t tgt_key_value_idx = block_idx * block_stride +
233+ block_offset * num_heads * head_size +
234+ head_idx * head_size + head_offset;
235+ scalar_t tgt_key = key[src_key_idx];
236+ scalar_t tgt_value = value[src_value_idx];
237+ if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
238+ key_cache[tgt_key_value_idx] = tgt_key;
239+ value_cache[tgt_key_value_idx] = tgt_value;
240+ } else {
241+ key_cache[tgt_key_value_idx] =
242+ fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
243+ value_cache[tgt_key_value_idx] =
244+ fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
245+ }
236246 }
237247}
238248} // namespace vllm
@@ -278,40 +288,45 @@ void reshape_and_cache(
278288 CALL_RESHAPE_AND_CACHE)
279289}
280290
291+ // KV_T is the stored data type of kv-cache.
292+ // CACHE_T is the data type of key and value tensors.
293+ // KV_DTYPE is the real data type of kv-cache.
294+ #define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
295+ vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
296+ <<<grid, block, 0, stream>>>( \
297+ reinterpret_cast<KV_T*>(key.data_ptr()), \
298+ reinterpret_cast<KV_T*>(value.data_ptr()), \
299+ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
300+ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
301+ slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
302+ value_stride, num_heads, head_size, block_size, k_scale, v_scale);
303+
281304void reshape_and_cache_flash(
282- torch::Tensor& key, // [num_tokens, num_heads, head_size]
283- torch::Tensor& value, // [num_tokens, num_heads, head_size]
284- torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
285- torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
305+ torch::Tensor& key, // [num_tokens, num_heads, head_size]
306+ torch::Tensor& value, // [num_tokens, num_heads, head_size]
307+ torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
308+ torch::Tensor&
309+ value_cache, // [num_blocks, block_size, num_heads, head_size]
286310 torch::Tensor& slot_mapping, // [num_tokens]
287- const std::string& kv_cache_dtype) {
288- // FIXME: only support auto datatype, does not support fp8
289- if (kv_cache_dtype != "auto") {
290- TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
291- }
311+ const std::string& kv_cache_dtype, const double k_scale,
312+ const double v_scale) {
292313 int num_tokens = key.size(0);
293314 int num_heads = key.size(1);
294315 int head_size = key.size(2);
295- int block_size = k_cache.size(1);
316+ int block_size = key_cache.size(1);
296317
297318 int key_stride = key.stride(0);
298319 int value_stride = value.stride(0);
299- int block_stride = k_cache.stride(0);
300- TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
320+ int block_stride = key_cache.stride(0);
321+ TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
301322
302323 dim3 grid(num_tokens);
303324 dim3 block(std::min(num_heads * head_size, 512));
304325 const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
305326 const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
306- VLLM_DISPATCH_FLOATING_TYPES(
307- key.scalar_type(), "reshape_and_cache_flash", [&] {
308- vllm::reshape_and_cache_flash_kernel<scalar_t>
309- <<<grid, block, 0, stream>>>(
310- key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
311- k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
312- slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
313- value_stride, num_heads, head_size, block_size);
314- });
327+
328+ DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
329+ CALL_RESHAPE_AND_CACHE_FLASH);
315330}
316331
317332namespace vllm {
0 commit comments