diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index 292352649163..f8ae4b65d438 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -1287,7 +1287,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
                                            // max_num_partitions, head_size]
     const int* __restrict__ context_lens,  // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   const auto num_heads = gridDim.x;
   const auto head_idx = blockIdx.x;
   const auto seq_idx = blockIdx.y;
@@ -1465,8 +1465,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
   const float inv_global_exp_sum =
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
+  const float out_scale =
+      (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
   acc *= inv_global_exp_sum;
-
+  acc *= out_scale;
   const int64_t query_start_off = static_cast<int64_t>(
       query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
   OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
@@ -1548,7 +1550,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
     const int* __restrict__ context_lens,  // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 // clang-format on
@@ -1582,7 +1584,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
                                       PARTITION_SIZE, NPAR_LOOPS>    \
       <<<reduce_grid, reduce_block, 0, stream>>>(                    \
           out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,        \
-          context_lens_ptr, query_start_loc_ptr, max_num_partitions);
+          context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
+          fp8_out_scale_ptr);
 
 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE,
@@ -1594,7 +1597,7 @@ void paged_attention_custom_launcher(
     torch::Tensor& block_tables, torch::Tensor& context_lens,
     const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale) {
   int num_seqs = block_tables.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -1626,6 +1629,11 @@ void paged_attention_custom_launcher(
   int* context_lens_ptr = context_lens.data_ptr<int>();
   const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
   const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
+  // NOTE: fp8_out_scale is optional.
+  const auto fp8_out_scale_ptr =
+      fp8_out_scale
+          ? static_cast<const float*>(fp8_out_scale.value().data_ptr())
+          : nullptr;
   OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
 
   const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
@@ -1736,33 +1744,54 @@ void paged_attention_custom_launcher(
   }
 }
 
-#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE,   \
-                             ALIBI_ENABLED)                                  \
-  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,  \
-                                  PSIZE, ALIBI_ENABLED>(                     \
-      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,     \
-      num_kv_heads, scale, block_tables, context_lens, query_start_loc,      \
-      max_context_len, alibi_slopes, k_scale, v_scale);
-
-#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,    \
-                                   PSIZE)                                    \
-  if (alibi_slopes) {                                                        \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true); \
-  } else {                                                                   \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \
+#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT,    \
+                             PSIZE, ALIBI_ENABLED)                           \
+  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,     \
+                                  OUTT, PSIZE, ALIBI_ENABLED>(               \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,     \
+      num_kv_heads, scale, block_tables, context_lens, query_start_loc,      \
+      max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);
+
+#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,    \
+                                   OUTT, PSIZE)                              \
+  if (alibi_slopes) {                                                        \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
+                         true);                                              \
+  } else {                                                                   \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
+                         false);                                             \
   }
 
-#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)           \
-  switch (block_size) {                                                 \
-    case 16:                                                            \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256); \
-      break;                                                            \
-    case 32:                                                            \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256); \
-      break;                                                            \
-    default:                                                            \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);       \
-      break;                                                            \
+#if defined(__HIPCC__) && defined(__gfx90a__)
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)    \
+    if (fp8_out_scale) {                                                     \
+      TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a");            \
+    } else {                                                                 \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,   \
+                                 256);                                       \
+    }
+#else
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)    \
+    if (fp8_out_scale) {                                                     \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,      \
+                                 uint8_t, 256);                              \
+    } else {                                                                 \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,   \
+                                 256);                                       \
+    }
+#endif
+
+#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)         \
+  switch (block_size) {                                               \
+    case 16:                                                          \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE);      \
+      break;                                                          \
+    case 32:                                                          \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE);      \
+      break;                                                          \
+    default:                                                          \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
+      break;                                                          \
   }
 
 #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
@@ -1795,7 +1824,8 @@ void paged_attention(
     int64_t block_size, int64_t max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale,
+    const c10::optional<torch::Tensor>& fp8_out_scale) {
   // clang-format on
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {
diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h
index b90cfdc617af..2252ea717220 100644
--- a/csrc/rocm/ops.h
+++ b/csrc/rocm/ops.h
@@ -11,14 +11,12 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
 void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
                at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
 
-void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
-                     torch::Tensor& max_logits, torch::Tensor& tmp_out,
-                     torch::Tensor& query, torch::Tensor& key_cache,
-                     torch::Tensor& value_cache, int64_t num_kv_heads,
-                     double scale, torch::Tensor& block_tables,
-                     torch::Tensor& context_lens,
-                     const std::optional<torch::Tensor>& query_start_loc,
-                     int64_t block_size, int64_t max_context_len,
-                     const std::optional<torch::Tensor>& alibi_slopes,
-                     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-                     torch::Tensor& v_scale);
+void paged_attention(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& context_lens,
+    const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
+    int64_t max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale);
diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp
index 4ac6fd1e9940..34575477bcc9 100644
--- a/csrc/rocm/torch_bindings.cpp
+++ b/csrc/rocm/torch_bindings.cpp
@@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       "                      int max_context_len,"
       "                      Tensor? alibi_slopes,"
       "                      str kv_cache_dtype,"
-      "                      Tensor k_scale, Tensor v_scale) -> ()");
+      "                      Tensor k_scale, Tensor v_scale,"
+      "                      Tensor? fp8_out_scale) -> ()");
   rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
 }
 
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 44377ccb2959..33989fc9438c 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -117,13 +117,14 @@ def paged_attention_rocm(
     kv_cache_dtype: str,
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
+    fp8_out_scale: Optional[torch.Tensor] = None,
 ) -> None:
     torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                       key_cache, value_cache, num_kv_heads,
                                       scale, block_tables, seq_lens,
                                       query_start_loc, block_size, max_seq_len,
                                       alibi_slopes, kv_cache_dtype, k_scale,
-                                      v_scale)
+                                      v_scale, fp8_out_scale)
 
 
 def mla_decode_kvcache_cpu(
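# --- Reference sketch (not part of the patch above) --------------------------
# A minimal, self-contained PyTorch emulation of the output-scaling step this
# diff adds to the ll4mi reduce kernel: the softmax-normalized accumulator is
# additionally multiplied by 1 / fp8_out_scale when the optional per-tensor
# fp8 output scale is supplied (acc *= inv_global_exp_sum; acc *= out_scale),
# before the result is stored to the (possibly fp8) output buffer. Names and
# shapes here are illustrative assumptions, not vLLM APIs; fp8 quantization of
# the scaled values is omitted.
from typing import Optional

import torch


def reduce_output_reference(
        acc: torch.Tensor,
        global_exp_sum: torch.Tensor,
        fp8_out_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Mirrors the kernel: inv_global_exp_sum = 1 / (sum + 1e-6), and
    # out_scale = 1 / *fp8_out_scale_ptr when the pointer is non-null, else 1.
    inv_global_exp_sum = 1.0 / (global_exp_sum + 1e-6)
    out_scale = 1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0
    return acc * inv_global_exp_sum * out_scale


if __name__ == "__main__":
    acc = torch.randn(8, 128)          # per-head accumulator [num_heads, head_size]
    exp_sum = torch.full((8, 1), 3.0)  # summed softmax denominators per head
    unscaled = reduce_output_reference(acc, exp_sum)
    scaled = reduce_output_reference(acc, exp_sum, torch.tensor(0.5))
    # An fp8 output scale of 0.5 doubles the values handed to the fp8 store.
    assert torch.allclose(scaled, unscaled * 2.0)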