diff --git a/aiter/ops/cache.py b/aiter/ops/cache.py
index cff1b7a4da..dc4d9ddd7d 100644
--- a/aiter/ops/cache.py
+++ b/aiter/ops/cache.py
@@ -39,8 +39,8 @@ def reshape_and_cache_flash(
     value_cache: Tensor,
     slot_mapping: Tensor,
     kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
+    k_scale: Tensor,
+    v_scale: Tensor,
 ): ...

 @compile_ops("module_cache")
diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py
index 71cb50c22e..44d04b20d9 100644
--- a/aiter/ops/mha.py
+++ b/aiter/ops/mha.py
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: MIT
-# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 from torch import Tensor, Generator
 from typing import Optional, Tuple
@@ -7,6 +7,7 @@
 from ..utility import dtypes
 import torch

+
 @compile_ops("module_mha_fwd", fc_name="mha_fwd")
 def mha_fwd(
     q: Tensor,
@@ -48,7 +49,7 @@ def mha_varlen_fwd(
     bias: Optional[Tensor] = None,
     alibi_slopes: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
-): ...
+) -> list[Tensor]: ...

 @compile_ops("module_mha_bwd", fc_name="mha_bwd")
@@ -419,7 +420,9 @@ def pssk():
         # bwd_hd64_bf16_causal_a32_rtz_pssk
         # bwd_hd64_fp16_a32_pssk
         # bwd_hd64_fp16_causal_a32_pssk
-        ret = is_v3_atomic_fp32 == True # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
+        ret = (
+            is_v3_atomic_fp32 == True
+        )  # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
         ret &= hdim_q == 64
         ret &= nmask or (
             mask and seqlen_q == seqlen_k
@@ -474,7 +477,9 @@ def psskddv():
         # bwd_hd192_bf16_causal_a32_rtz_psskddv
         ret = is_v3_atomic_fp32 == True
         ret &= hdim_q > 64 and hdim_q <= 192
-        ret &= nmask or (mask and seqlen_q == seqlen_k) # TODO: or (seqlen_q != seqlen_k and mask_type == top_left)
+        ret &= nmask or (
+            mask and seqlen_q == seqlen_k
+        )  # TODO: or (seqlen_q != seqlen_k and mask_type == top_left)
         return ret
@@ -759,6 +764,7 @@ def _flash_attn_varlen_forward(
     return_lse: bool = False,
     return_softmax: bool = False,
     block_table: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
     zero_tensors: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     # causal=true is the same as causal=false in this case
@@ -878,7 +884,7 @@ def _flash_attn_varlen_forward(
         window_size_right,
         return_lse,
         return_softmax,
-        None,
+        out,
         block_table,
         bias,
         alibi_slopes,
@@ -963,7 +969,9 @@ def _flash_attn_varlen_backward(
     ]

     (_, nhead_q, hdim_q) = q.shape
-    (_, nhead_k, hdim_v) = v.shape
+
+    nhead_k = v.shape[-2]
+    hdim_v = v.shape[-1]

     # mask
     window_size_left = -1 if window_size_left >= max_seqlen_k else window_size_left
@@ -994,12 +1002,14 @@ def pssk():
         # bwd_hd128_bf16_causal_a32_rtz_pssk_group
         # bwd_hd128_fp16_a32_pssk_group
         # bwd_hd128_fp16_causal_a32_pssk_group
-        ret = is_v3_atomic_fp32 == True # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
+        ret = (
+            is_v3_atomic_fp32 == True
+        )  # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
         ret &= hdim_q == 64 or hdim_q == 128
-        ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)
+        ret &= nmask  # TODO: or (mask and mask_type == mask_enum::mask_top_left)
         return ret
-
+
     def psskddv():
@@ -1009,9 +1019,11 @@ def psskddv():
         # bwd_hd128_bf16_a32_rtne_psskddv_group
         # bwd_hd128_bf16_a32_rtna_psskddv_group
         # bwd_hd128_bf16_a32_rtz_psskddv_group
         # bwd_hd128_bf16_causal_a32_rtne_psskddv_group
         # bwd_hd128_bf16_causal_a32_rtna_psskddv_group
         # bwd_hd128_bf16_causal_a32_rtz_psskddv_group
         # bwd_hd128_fp16_a32_psskddv_group
         # bwd_hd128_fp16_causal_a32_psskddv_group
-        ret = is_v3_atomic_fp32 == True # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
+        ret = (
+            is_v3_atomic_fp32 == True
+        )  # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
         ret &= hdim_q > 64 and hdim_q < 128
-        ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)
+        ret &= nmask  # TODO: or (mask and mask_type == mask_enum::mask_top_left)
         return ret
@@ -1027,7 +1039,7 @@ def can_impl_fmha_v3_bwd():
         ret &= hdim_q >= 64 and hdim_q <= 128 and hdim_q % 8 == 0
         ret &= mask or nmask
         ret &= pssk() or psskddv()
-        ret &= 'gfx942' in torch.cuda.get_device_properties("cuda").gcnArchName
+        ret &= "gfx942" in torch.cuda.get_device_properties("cuda").gcnArchName
         return ret
@@ -1122,6 +1134,7 @@ def forward(
         return_lse,
         return_softmax,
         block_table,
+        out,
         is_grad_enabled,
         is_v3_atomic_fp32: Optional[bool] = True,
         how_v3_bf16_cvt: Optional[int] = 1,
@@ -1129,8 +1142,8 @@
         is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        head_size_q_og = q.size(2)
-        head_size_v_og = v.size(2)
+        head_size_q_og = q.size(-1)
+        head_size_v_og = v.size(-1)
         if head_size_q_og % 8 != 0:
             q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8])
             k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8])
@@ -1154,6 +1167,7 @@ def forward(
             return_lse=return_lse,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            out=out,
         )
         if is_grad:
             ctx.save_for_backward(
@@ -1243,6 +1257,7 @@ def backward(ctx, dout, *args):
             None,
             None,
             None,
+            None,
         )
@@ -1264,6 +1279,7 @@ def flash_attn_varlen_func(
     return_lse=False,
     return_attn_probs=False,
     block_table=None,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1338,5 +1354,6 @@ def flash_attn_varlen_func(
         return_lse,
         return_attn_probs,
         block_table,
+        out,
         torch.is_grad_enabled(),
     )
diff --git a/csrc/include/cache.h b/csrc/include/cache.h
index d21dc3530d..494b0a7143 100644
--- a/csrc/include/cache.h
+++ b/csrc/include/cache.h
@@ -27,7 +27,7 @@ void reshape_and_cache_flash(torch::Tensor &key, torch::Tensor &value,
                              torch::Tensor &value_cache,
                              torch::Tensor &slot_mapping,
                              const std::string &kv_cache_dtype,
-                             const double k_scale, const double v_scale);
+                             torch::Tensor& k_scale, torch::Tensor& v_scale);

 void reshape_and_cache_with_pertoken_quant(torch::Tensor &key, torch::Tensor &value,
                                            torch::Tensor &key_cache, torch::Tensor &value_cache,
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index 2c9bf263f3..9849ec91e8 100644
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #define ACTIVATION_PYBIND                                                        \
     m.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); \
@@ -573,4 +573,4 @@
     .value("No", ActivationType::No)     \
     .value("Silu", ActivationType::Silu) \
     .value("Gelu", ActivationType::Gelu) \
-    .export_values();
\ No newline at end of file
+    .export_values();
diff --git a/csrc/include/torch/mha_varlen_fwd.h b/csrc/include/torch/mha_varlen_fwd.h
index abc522b348..e920061db8 100644
--- a/csrc/include/torch/mha_varlen_fwd.h
+++ b/csrc/include/torch/mha_varlen_fwd.h
@@ -1,6 +1,6 @@
 #pragma once
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #include

 namespace aiter {
@@ -10,7 +10,7 @@ mha_varlen_fwd(at::Tensor& q,                  // [total_q, hq, d]
                const at::Tensor& k,            // [total_k, hk, d]
                const at::Tensor& v,            // [total_k, hk, d]
                const at::Tensor& cu_seqlens_q, // [b+1]
-               const at::Tensor& cu_seqlens_k, // [b+1]
+               std::optional<at::Tensor> &cu_seqlens_k, // [b+1]
                int max_seqlen_q,
                int max_seqlen_k,
                float p_dropout,
diff --git a/csrc/kernels/cache_kernels.cu b/csrc/kernels/cache_kernels.cu
index 45bba25cf8..3a54727b89 100644
--- a/csrc/kernels/cache_kernels.cu
+++ b/csrc/kernels/cache_kernels.cu
@@ -274,7 +274,7 @@ namespace vllm
         const int64_t *__restrict__ slot_mapping, // [num_tokens]
         const int block_stride, const int key_stride, const int value_stride,
         const int num_heads, const int head_size, const int block_size,
-        const float k_scale, const float v_scale)
+        const float* k_scale, const float* v_scale)
     {
         const int64_t token_idx = blockIdx.x;
         const int64_t slot_idx = slot_mapping[token_idx];
@@ -305,9 +305,9 @@ namespace vllm
             else
             {
                 key_cache[tgt_key_value_idx] =
-                    fp8::scaled_convert(tgt_key, k_scale);
+                    fp8::scaled_convert(tgt_key, *k_scale);
                 value_cache[tgt_key_value_idx] =
-                    fp8::scaled_convert(tgt_value, v_scale);
+                    fp8::scaled_convert(tgt_value, *v_scale);
             }
         }
     }
@@ -873,7 +873,7 @@ void reshape_and_cache(
         reinterpret_cast(key_cache.data_ptr()),            \
         reinterpret_cast(value_cache.data_ptr()),          \
         slot_mapping.data_ptr(), block_stride, key_stride, \
-        value_stride, num_heads, head_size, block_size, k_scale, v_scale);
+        value_stride, num_heads, head_size, block_size, k_scale.data_ptr(), v_scale.data_ptr());

 void reshape_and_cache_flash(
     torch::Tensor &key,   // [num_tokens, num_heads, head_size]
@@ -882,8 +882,9 @@ void reshape_and_cache_flash(
     torch::Tensor &value, // [num_tokens, num_heads, head_size]
     torch::Tensor
         &value_cache, // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor &slot_mapping, // [num_tokens]
-    const std::string &kv_cache_dtype, const double k_scale,
-    const double v_scale)
+    const std::string &kv_cache_dtype,
+    torch::Tensor& k_scale,
+    torch::Tensor& v_scale)
 {
     int num_tokens = key.size(0);
     int num_heads = key.size(1);
diff --git a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu
index 08927fd2f2..8bb2911412 100644
--- a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu
+++ b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #include
 #include
@@ -24,8 +24,9 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                                           const at::Tensor q,
                                           const at::Tensor k,
                                           const at::Tensor v,
-                                          const at::Tensor seqlens_q,
-                                          const at::Tensor seqlens_k,
+                                          const at::Tensor cu_seqlens_q,
+                                          std::optional<at::Tensor> &cu_seqlens_k,
+                                          std::optional<at::Tensor> &seqlens_k,
                                           std::optional<at::Tensor> &bias_,
                                           std::optional<at::Tensor> &alibi_slopes_,
                                           at::Tensor out,
@@ -98,9 +99,9 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
         has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
         has_lse ? softmax_lse.data_ptr() : nullptr,
         out.data_ptr(),
-        seqlens_q.data_ptr(), // seqstart_q
-        seqlens_k.data_ptr(), // seqstart_k
-        nullptr,              // seqlen_kpads
+        cu_seqlens_q.data_ptr(), // seqstart_q
+        cu_seqlens_k.has_value() ? cu_seqlens_k.value().data_ptr() : nullptr, // seqstart_k
+        seqlens_k.has_value() ? seqlens_k.value().data_ptr() : nullptr, // seqlen_kpads
         total_q,
         total_k,
         b,
@@ -155,8 +156,9 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
                                                           const at::Tensor q,
                                                           const at::Tensor k,
                                                           const at::Tensor v,
-                                                          const at::Tensor seqlens_q,
-                                                          const at::Tensor seqlens_k,
+                                                          const at::Tensor cu_seqlens_q,
+                                                          std::optional<at::Tensor> &cu_seqlens_k,
+                                                          std::optional<at::Tensor> &seqlens_k,
                                                           std::optional<at::Tensor> &block_table_,
                                                           std::optional<at::Tensor> &bias_,
                                                           std::optional<at::Tensor> &alibi_slopes_,
@@ -204,9 +206,19 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
     args.is_gappy = false;
     args.cache_batch_idx = nullptr;

-    args.seqstart_q_ptr = seqlens_q.data_ptr();
-    args.seqstart_k_ptr = seqlens_k.data_ptr();
-    args.seqlen_k_ptr = nullptr;
+    args.seqstart_q_ptr = cu_seqlens_q.data_ptr();
+    if (cu_seqlens_k.has_value()) {
+        args.seqstart_k_ptr = cu_seqlens_k.value().data_ptr();
+    }
+    else {
+        args.seqstart_k_ptr = nullptr;
+    }
+    if (seqlens_k.has_value()) {
+        args.seqlen_k_ptr = seqlens_k.value().data_ptr();
+    }
+    else {
+        args.seqlen_k_ptr = nullptr;
+    }

     args.batch = b;
     args.max_seqlen_q = max_seqlen_q;
@@ -281,7 +293,7 @@ mha_varlen_fwd(at::Tensor &q,               // [total_q, hq, d]
               const at::Tensor &k,          // [total_k, hk, d]
               const at::Tensor &v,          // [total_k, hk, d]
               const at::Tensor &cu_seqlens_q, // [b+1]
-              const at::Tensor &cu_seqlens_k, // [b+1]
+              std::optional<at::Tensor> &cu_seqlens_k, // [b+1]
               int max_seqlen_q,
               int max_seqlen_k,
               float p_dropout,
@@ -305,13 +317,17 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
-    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
+    if (cu_seqlens_k.has_value()) {
+        TORCH_CHECK(cu_seqlens_k.value().dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
+    }
     std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
     CHECK_DEVICE(q);
     CHECK_DEVICE(k);
     CHECK_DEVICE(v);
     CHECK_DEVICE(cu_seqlens_q);
-    CHECK_DEVICE(cu_seqlens_k);
+    if (cu_seqlens_k.has_value()) {
+        CHECK_DEVICE(cu_seqlens_k.value());
+    }
     at::Tensor block_table;
     const bool paged_KV = block_table_.has_value();
@@ -326,15 +342,17 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     CHECK_CONTIGUOUS(cu_seqlens_q);
-    CHECK_CONTIGUOUS(cu_seqlens_k);
+    if (cu_seqlens_k.has_value()) {
+        CHECK_CONTIGUOUS(cu_seqlens_k.value());
+    }

     const auto sizes = q.sizes();

     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
-    const int head_size_q = sizes[2];
-    const int head_size_v = v.size(2);
-    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+    const int head_size_q = q.size(-1);
+    const int head_size_v = v.size(-1);
+    const int num_heads_k = k.size(-2);

     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
     const int num_blocks = !paged_KV ? 0 : k.size(0);
@@ -353,7 +371,6 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]

     const int total_q = q.size(0);

-    TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(head_size_q <= 256, "CK only supports head dimension at most 256");
     TORCH_CHECK(head_size_v <= 256, "CK only supports head dimension at most 256");
@@ -393,7 +410,9 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
     }

     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
-    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (cu_seqlens_k.has_value()) {
+        CHECK_SHAPE(cu_seqlens_k.value(), batch_size + 1);
+    }

     auto opts = q.options();
     at::Tensor out;
@@ -461,7 +480,7 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
         hipLaunchKernelGGL(
             aiter::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
     }
-
+    std::optional<at::Tensor> seqlens_k = std::nullopt;
     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};
@@ -486,6 +505,7 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
                 v,
                 cu_seqlens_q,
                 cu_seqlens_k,
+                seqlens_k,
                 block_table_,
                 bias_,
                 alibi_slopes_,
@@ -505,8 +525,8 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
     }
     else
     {
+        TORCH_CHECK(cu_seqlens_k.has_value(), "cu_seqlens_k must be provided if paged_KV is false");
         auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
-
         auto args = get_ck_fmha_varlen_fwd_args(
             has_lse,
@@ -523,6 +543,7 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
                 v,
                 cu_seqlens_q,
                 cu_seqlens_k,
+                seqlens_k,
                 bias_,
                 alibi_slopes_,
                 out,
diff --git a/op_tests/test_mha_varlen.py b/op_tests/test_mha_varlen.py
index 20c5387884..02c909b886 100644
--- a/op_tests/test_mha_varlen.py
+++ b/op_tests/test_mha_varlen.py
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: MIT
-# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 from einops import rearrange, repeat
 import torch
@@ -282,9 +282,7 @@ def test_flash_attn_varlen_func(
             requires_grad=True,
         )
     elif bias_type == "alibi":
-        alibi_slopes = torch.rand(
-            batch_size, nheads, device="cuda", dtype=dtypes.fp32
-        )
+        alibi_slopes = torch.rand(batch_size, nheads, device="cuda", dtype=dtypes.fp32)
     dout = torch.randn(
         batch_size,