4 changes: 2 additions & 2 deletions aiter/ops/cache.py
@@ -39,8 +39,8 @@ def reshape_and_cache_flash(
value_cache: Tensor,
slot_mapping: Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: Tensor,
v_scale: Tensor,
): ...

@compile_ops("module_cache")
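Note on the cache.py change above: `k_scale` and `v_scale` are now tensors rather than Python floats, and the HIP kernel (see csrc/kernels/cache_kernels.cu below) dereferences them as `float*`, so callers are expected to pass single-element float32 tensors resident on the device. A minimal call sketch under that assumption — the dummy shapes, the `"auto"` cache dtype string, and the fp16 dtypes are illustrative choices, not taken from this PR:

```python
import torch
from aiter.ops.cache import reshape_and_cache_flash  # module shown in this diff

num_tokens, num_heads, head_size = 16, 8, 128
num_blocks, block_size = 8, 16

key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.zeros_like(key_cache)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")

# Previously plain Python floats; after this PR the scales are passed as
# single-element float32 tensors so the kernel can read them on-device.
k_scale = torch.ones(1, dtype=torch.float32, device="cuda")
v_scale = torch.ones(1, dtype=torch.float32, device="cuda")

reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping,
                        "auto", k_scale, v_scale)
```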
45 changes: 31 additions & 14 deletions aiter/ops/mha.py
@@ -1,12 +1,13 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from torch import Tensor, Generator
from typing import Optional, Tuple
from ..jit.core import compile_ops, CK_DIR, AITER_CSRC_DIR, AITER_ROOT_DIR
from ..utility import dtypes
import torch


@compile_ops("module_mha_fwd", fc_name="mha_fwd")
def mha_fwd(
q: Tensor,
@@ -48,7 +49,7 @@
bias: Optional[Tensor] = None,
alibi_slopes: Optional[Tensor] = None,
gen: Optional[Generator] = None,
): ...
) -> list[Tensor]: ...


@compile_ops("module_mha_bwd", fc_name="mha_bwd")
@@ -419,7 +420,9 @@
# bwd_hd64_bf16_causal_a32_rtz_pssk
# bwd_hd64_fp16_a32_pssk
# bwd_hd64_fp16_causal_a32_pssk
ret = is_v3_atomic_fp32 == True # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
ret = (
is_v3_atomic_fp32 == True

) # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
ret &= hdim_q == 64
ret &= nmask or (
mask and seqlen_q == seqlen_k
@@ -474,7 +477,9 @@
# bwd_hd192_bf16_causal_a32_rtz_psskddv
ret = is_v3_atomic_fp32 == True
ret &= hdim_q > 64 and hdim_q <= 192
ret &= nmask or (mask and seqlen_q == seqlen_k) # TODO: or (seqlen_q != seqlen_k and mask_type == top_left)
ret &= nmask or (
mask and seqlen_q == seqlen_k
) # TODO: or (seqlen_q != seqlen_k and mask_type == top_left)

return ret

@@ -759,6 +764,7 @@
return_lse: bool = False,
return_softmax: bool = False,
block_table: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
zero_tensors: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# causal=true is the same as causal=false in this case
@@ -878,7 +884,7 @@
window_size_right,
return_lse,
return_softmax,
None,
out,
block_table,
bias,
alibi_slopes,
@@ -963,7 +969,9 @@
]

(_, nhead_q, hdim_q) = q.shape
(_, nhead_k, hdim_v) = v.shape

nhead_k = v.shape[-2]
hdim_v = v.shape[-1]

# mask
window_size_left = -1 if window_size_left >= max_seqlen_k else window_size_left
@@ -994,12 +1002,14 @@
# bwd_hd128_bf16_causal_a32_rtz_pssk_group
# bwd_hd128_fp16_a32_pssk_group
# bwd_hd128_fp16_causal_a32_pssk_group
ret = is_v3_atomic_fp32 == True # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
ret = (
is_v3_atomic_fp32 == True

) # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
ret &= hdim_q == 64 or hdim_q == 128
ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)
ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)

return ret

def psskddv():
# bwd_hd128_bf16_a32_rtne_psskddv_group
# bwd_hd128_bf16_a32_rtna_psskddv_group
@@ -1009,9 +1019,11 @@
# bwd_hd128_bf16_causal_a32_rtz_psskddv_group
# bwd_hd128_fp16_a32_psskddv_group
# bwd_hd128_fp16_causal_a32_psskddv_group
ret = is_v3_atomic_fp32 == True # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
ret = (
is_v3_atomic_fp32 == True

) # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
ret &= hdim_q > 64 and hdim_q < 128
ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)
ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)

return ret

@@ -1027,7 +1039,7 @@
ret &= hdim_q >= 64 and hdim_q <= 128 and hdim_q % 8 == 0
ret &= mask or nmask
ret &= pssk() or psskddv()
ret &= 'gfx942' in torch.cuda.get_device_properties("cuda").gcnArchName
ret &= "gfx942" in torch.cuda.get_device_properties("cuda").gcnArchName

return ret

@@ -1122,15 +1134,16 @@
return_lse,
return_softmax,
block_table,
out,
is_grad_enabled,
is_v3_atomic_fp32: Optional[bool] = True,
how_v3_bf16_cvt: Optional[int] = 1,
):
is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
head_size_q_og = q.size(2)
head_size_v_og = v.size(2)
head_size_q_og = q.size(-1)
head_size_v_og = v.size(-1)
if head_size_q_og % 8 != 0:
q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8])
k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8])
@@ -1154,6 +1167,7 @@
return_lse=return_lse,
return_softmax=return_softmax and dropout_p > 0,
block_table=block_table,
out=out,
)
if is_grad:
ctx.save_for_backward(
@@ -1243,6 +1257,7 @@
None,
None,
None,
None,
)


@@ -1264,6 +1279,7 @@
return_lse=False,
return_attn_probs=False,
block_table=None,
out=None,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1338,5 +1354,6 @@
return_lse,
return_attn_probs,
block_table,
out,
torch.is_grad_enabled(),
)
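Note on the mha.py changes above: an optional `out` tensor is now threaded from the public varlen forward wrapper down to `mha_varlen_fwd` (previously hard-coded to `None`), so a preallocated output buffer can be reused across calls. A rough usage sketch — the entry-point name `flash_attn_varlen_func`, the keyword names other than `out`, and the output shape `[total_q, nhead_q, hdim_v]` are assumptions inferred from the surrounding code, not guaranteed by this diff:

```python
import torch
from aiter.ops.mha import flash_attn_varlen_func  # assumed name; not visible in this hunk

total_q = total_k = 512
nheads, hdim = 8, 128

q = torch.randn(total_q, nheads, hdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(total_k, nheads, hdim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(total_k, nheads, hdim, dtype=torch.bfloat16, device="cuda")
cu_seqlens_q = torch.tensor([0, 256, 512], dtype=torch.int32, device="cuda")
cu_seqlens_k = cu_seqlens_q.clone()

# Preallocated output buffer; with this PR it is forwarded to mha_varlen_fwd
# instead of the previous hard-coded None.
out = torch.empty(total_q, nheads, hdim, dtype=q.dtype, device=q.device)

result = flash_attn_varlen_func(
    q, k, v, cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q=256, max_seqlen_k=256,
    out=out,
)
```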
2 changes: 1 addition & 1 deletion csrc/include/cache.h
@@ -27,7 +27,7 @@ void reshape_and_cache_flash(torch::Tensor &key, torch::Tensor &value,
torch::Tensor &value_cache,
torch::Tensor &slot_mapping,
const std::string &kv_cache_dtype,
const double k_scale, const double v_scale);
torch::Tensor& k_scale, torch::Tensor& v_scale);

void reshape_and_cache_with_pertoken_quant(torch::Tensor &key, torch::Tensor &value,
torch::Tensor &key_cache, torch::Tensor &value_cache,
4 changes: 2 additions & 2 deletions csrc/include/rocm_ops.hpp
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#define ACTIVATION_PYBIND \
m.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); \
@@ -573,4 +573,4 @@
.value("No", ActivationType::No) \
.value("Silu", ActivationType::Silu) \
.value("Gelu", ActivationType::Gelu) \
.export_values();
.export_values();
4 changes: 2 additions & 2 deletions csrc/include/torch/mha_varlen_fwd.h
@@ -1,6 +1,6 @@
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <torch/extension.h>

namespace aiter {
@@ -10,7 +10,7 @@ mha_varlen_fwd(at::Tensor& q, // [total_q, hq, d]
const at::Tensor& k, // [total_k, hk, d]
const at::Tensor& v, // [total_k, hk, d]
const at::Tensor& cu_seqlens_q, // [b+1]
const at::Tensor& cu_seqlens_k, // [b+1]
std::optional<const at::Tensor> &cu_seqlens_k, // [b+1]
int max_seqlen_q,
int max_seqlen_k,
float p_dropout,
13 changes: 7 additions & 6 deletions csrc/kernels/cache_kernels.cu
@@ -274,7 +274,7 @@ namespace vllm
const int64_t *__restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size,
const float k_scale, const float v_scale)
const float* k_scale, const float* v_scale)
{
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
@@ -305,9 +305,9 @@
else
{
key_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
value_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
}
}
}
@@ -873,7 +873,7 @@ void reshape_and_cache(
reinterpret_cast<CACHE_T *>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T *>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
value_stride, num_heads, head_size, block_size, k_scale, v_scale);
value_stride, num_heads, head_size, block_size, k_scale.data_ptr<float>(), v_scale.data_ptr<float>());

void reshape_and_cache_flash(
torch::Tensor &key, // [num_tokens, num_heads, head_size]
@@ -882,8 +882,9 @@ void reshape_and_cache_flash(
torch::Tensor &
value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor &slot_mapping, // [num_tokens]
const std::string &kv_cache_dtype, const double k_scale,
const double v_scale)
const std::string &kv_cache_dtype,
torch::Tensor& k_scale,
torch::Tensor& v_scale)
{
int num_tokens = key.size(0);
int num_heads = key.size(1);