3 changes: 3 additions & 0 deletions .gitignore
@@ -1,6 +1,9 @@
# aiter version
aiter/_version.py

# aiter install mode
aiter/install_mode

# Prerequisites
*.d

72 changes: 43 additions & 29 deletions aiter/mla.py
@@ -3,14 +3,15 @@

# user interface

import functools

import torch
import aiter
from aiter import dtypes
import triton
import triton.language as tl
import functools

import aiter
from aiter import dtypes
from aiter.jit.utils.chip_info import get_cu_num
from aiter.ops.triton.utils.types import get_fp8_e4m3_dtype


@triton.jit
@@ -21,11 +22,11 @@ def _fwd_kernel_stage2_asm(
qo_indptr,
kv_indptr,
num_kv_splits_indptr,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_obs,
stride_oh,
stride_mid_ob: tl.int64,
stride_mid_oh: tl.int64,
stride_mid_os: tl.int64,
stride_obs: tl.int64,
stride_oh: tl.int64,
MAYBE_FINAL_OUT: tl.constexpr,
BATCH_NUM: tl.constexpr,
BLOCK_DV: tl.constexpr,
@@ -96,7 +97,7 @@ def _fwd_kernel_stage2_asm(
)


@functools.lru_cache()
@functools.lru_cache(maxsize=1)
def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
if num_kv_splits is None:
cu_num = get_cu_num()
@@ -128,7 +129,7 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
512: 32,
}

if dtype == get_fp8_e4m3_dtype():
if dtype == dtypes.fp8:
min_block_n = get_block_n_fp8[int(nhead * max_seqlen_q)]
num_kv_splits = min(
num_kv_splits, int(total_kv / bs + min_block_n - 1) // min_block_n
@@ -138,7 +139,12 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
mgc = get_mgc[nhead]
if max_seqlen_q == 1 and nhead == 16:
mgc = 64
return num_kv_splits, mgc

num_kv_splits_indptr = torch.arange(
0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int, device="cuda"
)

return num_kv_splits, mgc, num_kv_splits_indptr


def mla_decode_fwd(
@@ -176,30 +182,34 @@ def mla_decode_fwd(

persistent_mode = work_meta_data is not None

if num_kv_splits_indptr is None and not persistent_mode:
num_kv_splits, mgc = get_meta_param(
None, bs, total_kv, nhead, max_seqlen_q, q.dtype
)
num_kv_splits_indptr = torch.arange(
0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int, device=device
)

if num_kv_splits is None:
num_kv_splits = get_cu_num()

io_transformed = False

if not persistent_mode:
num_kv_splits, mgc, num_kv_splits_indptr = get_meta_param(
num_kv_splits, bs, total_kv, nhead, max_seqlen_q, q.dtype
)

MAYBE_FINAL_OUT = True

if nhead == 16 and max_seqlen_q == 1:
MAYBE_FINAL_OUT = False

logits = torch.empty(
(total_s, num_kv_splits, nhead, v_head_dim),
dtype=dtypes.fp32,
device=device,
logits = (
o.view((total_s, num_kv_splits, nhead, v_head_dim))
if (
num_kv_splits == 1
and (
q.dtype == dtypes.fp8
or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
)
)
else torch.empty(
(total_s, num_kv_splits, nhead, v_head_dim),
dtype=dtypes.fp32,
device=device,
)
)

attn_lse = torch.empty(
(total_s, num_kv_splits, nhead, 1), dtype=dtypes.fp32, device=device
)
@@ -225,7 +235,9 @@ def mla_decode_fwd(
kv_scale,
)

if num_kv_splits == 1 and q.dtype != torch.bfloat16:
if num_kv_splits == 1 and (
q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
):
return logits.view(total_s, nhead, v_head_dim), attn_lse

Lv = v_head_dim
@@ -255,7 +267,9 @@ def mla_decode_fwd(
**extra_kargs,
)
else:
if nhead == 16 or (nhead == 128 and kv_buffer.dtype == get_fp8_e4m3_dtype()):
if num_kv_splits is None:
num_kv_splits = get_cu_num()
if nhead == 16 or (nhead == 128 and kv_buffer.dtype == dtypes.fp8):
# Natively support cases
pass
elif nhead in range(32, 512 + 1, 16) and persistent_mode and max_seqlen_q == 1:
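
A note on the `get_meta_param` change above: the helper now also constructs `num_kv_splits_indptr`, the per-batch split index pointer that was previously built inline in `mla_decode_fwd`. A minimal sketch of the resulting layout, using example values for `bs` and `num_kv_splits` (illustration only, not taken from the PR):

```python
import torch

bs, num_kv_splits = 3, 4  # example values; the real call uses the computed split count
num_kv_splits_indptr = torch.arange(
    0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int, device="cuda"
)  # requires a CUDA/ROCm device, matching the device used in the diff
# tensor([ 0,  4,  8, 12], dtype=torch.int32): batch i owns splits [i * 4, (i + 1) * 4)
```
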
1 change: 1 addition & 0 deletions aiter/ops/attention.py
@@ -404,6 +404,7 @@ def get_mla_metadata_v1(
uni_seqlen_qo: int = -1,
fast_mode: bool = True,
topk: int = -1,
max_split_per_batch: int = -1,
) -> None:
"""
Inputs:
3 changes: 2 additions & 1 deletion csrc/include/mla.h
@@ -50,7 +50,8 @@ void get_mla_metadata_v1(const torch::Tensor& seqlens_qo_indptr, // [batch size
const int32_t max_seqlen_qo,
const int32_t uni_seqlen_qo,
const bool fast_mode,
const int32_t topk);
const int32_t topk,
const int32_t max_split_per_batch);

std::vector<torch::Tensor>
get_mla_metadata_v1_no_redundant(const torch::Tensor& seqlens_qo_indptr, // [batch size + 1]
41 changes: 21 additions & 20 deletions csrc/include/rocm_ops.hpp
@@ -1350,26 +1350,27 @@ namespace py = pybind11;
py::arg("stride0"), \
py::arg("stride1"));

#define MLA_METADATA_PYBIND \
m.def("get_mla_metadata_v1", \
&get_mla_metadata_v1, \
"get_mla_metadata_v1", \
py::arg("seqlens_qo_indptr"), \
py::arg("seqlens_kv_indptr"), \
py::arg("num_heads_per_head_k"), \
py::arg("num_heads_k"), \
py::arg("is_causal"), \
py::arg("work_metadata_ptrs"), \
py::arg("work_info_set"), \
py::arg("work_indptr"), \
py::arg("reduce_indptr"), \
py::arg("reduce_final_map"), \
py::arg("reduce_partial_map"), \
py::arg("kv_granularity") = 16, \
py::arg("max_seqlen_qo") = -1, \
py::arg("uni_seqlen_qo") = -1, \
py::arg("fast_mode") = true, \
py::arg("topk") = -1); \
#define MLA_METADATA_PYBIND \
m.def("get_mla_metadata_v1", \
&get_mla_metadata_v1, \
"get_mla_metadata_v1", \
py::arg("seqlens_qo_indptr"), \
py::arg("seqlens_kv_indptr"), \
py::arg("num_heads_per_head_k"), \
py::arg("num_heads_k"), \
py::arg("is_causal"), \
py::arg("work_metadata_ptrs"), \
py::arg("work_info_set"), \
py::arg("work_indptr"), \
py::arg("reduce_indptr"), \
py::arg("reduce_final_map"), \
py::arg("reduce_partial_map"), \
py::arg("kv_granularity") = 16, \
py::arg("max_seqlen_qo") = -1, \
py::arg("uni_seqlen_qo") = -1, \
py::arg("fast_mode") = true, \
py::arg("topk") = -1, \
py::arg("max_split_per_batch") = -1); \
m.def("get_mla_metadata_v1_no_redundant", &get_mla_metadata_v1_no_redundant);

#define MLA_REDUCE_PYBIND \
4 changes: 3 additions & 1 deletion csrc/kernels/mla/metadata.cu
@@ -50,7 +50,8 @@ void get_mla_metadata_v1(
const int32_t max_seqlen_qo,
const int32_t uni_seqlen_qo,
const bool fast_mode,
const int32_t topk)
const int32_t topk,
const int32_t max_split_per_batch)
{
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(seqlens_kv_indptr));

@@ -77,6 +78,7 @@
max_seqlen_qo,
uni_seqlen_qo,
topk,
max_split_per_batch,
work_metadata_ptrs,
work_info_set,
work_indptr,
6 changes: 5 additions & 1 deletion csrc/kernels/mla/metadata/v1_2_device.cuh
@@ -84,7 +84,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__

// expected payload handled by each cu part.
const int32_t payload =
ck_tile::integer_divide_ceil(sum_blocks, params.num_cu) + Traits::kFixedOverheadNumBlocks;
ck_tile::integer_divide_ceil(sum_blocks, params.num_splits) + Traits::kFixedOverheadNumBlocks;

int32_t curr_batch = 0; // batch ID of the batch which is under review
int32_t curr_kv_block = 0; // #blocks handled by previous cu part(s)
@@ -358,6 +358,7 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
const int32_t max_seqlen_qo,
const int32_t ori_uni_seqlen_qo,
const int32_t topk,
const int32_t max_split_per_batch,
torch::Tensor& work_metadata_ptrs,
torch::Tensor& work_info_set,
torch::Tensor& work_indptr,
@@ -404,6 +405,8 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
"N is in [2, 8).")

int32_t num_splits = max_split_per_batch < 0 ? num_clusters : min(num_clusters, max_split_per_batch * num_batches);

MlaMetadataV1KernelParameter params = {};
params.p_work_metadata_ptrs = work_metadata_ptrs.data_ptr<uint64_t>();
params.p_work_indptr = work_indptr.data_ptr<int32_t>();
@@ -416,6 +419,7 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
params.num_batches = num_batches;
params.num_heads = num_heads_k * num_heads_per_head_k;
params.num_cu = num_clusters;
params.num_splits = num_splits;
params.reduce_indptr_size = reduce_indptr.size(0);
params.kv_granularity = kv_granularity;
params.kv_granularity_log2 = __builtin_ctz(kv_granularity);
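
For reference, the new `max_split_per_batch` knob and the switch from `params.num_cu` to `params.num_splits` can be summarized in a small host-side sketch (Python used purely for illustration; `fixed_overhead_blocks` stands in for `Traits::kFixedOverheadNumBlocks`):

```python
def num_splits(num_clusters: int, num_batches: int, max_split_per_batch: int) -> int:
    # max_split_per_batch < 0 (the default) keeps the previous behaviour:
    # the split budget equals the number of clusters.
    if max_split_per_batch < 0:
        return num_clusters
    # Otherwise the total split count is capped at max_split_per_batch per batch.
    return min(num_clusters, max_split_per_batch * num_batches)


def payload(sum_blocks: int, splits: int, fixed_overhead_blocks: int) -> int:
    # Mirrors ck_tile::integer_divide_ceil(sum_blocks, params.num_splits) + overhead.
    return (sum_blocks + splits - 1) // splits + fixed_overhead_blocks
```
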
1 change: 1 addition & 0 deletions csrc/kernels/mla/metadata/v1_comm.cuh
@@ -65,6 +65,7 @@ struct MlaMetadataV1KernelParameter
int32_t ori_seqlen_qo;
int32_t topk;
int32_t qk_batch_ratio;
int32_t num_splits;
bool is_causal;
};

13 changes: 6 additions & 7 deletions csrc/py_itfs_cu/asm_mla.cu
@@ -45,6 +45,8 @@ struct __attribute__((packed)) KernelArgs
p2 _p21;
void* ptr_KVSCALE;
p2 _p22;
unsigned int out_16_nosplit;
p3 _p23;
};

void mla_decode_stage1_asm_fwd(
@@ -98,6 +100,7 @@ void mla_decode_stage1_asm_fwd(
args.s_Q_Bs = stride_Q;
args.s_Bs = stride_Page;
args.s_log2_plen = log2_page;
args.out_16_nosplit = kv_split;

if (persistent)
{
@@ -126,8 +129,8 @@
{
args.ptr_STP = num_kv_splits_indptr.value().data_ptr();
}
args.ptr_RP = output.data_ptr();

args.ptr_RP = output.data_ptr(); //final output

// std::cout << "mla args" << std::endl;
// std::cout << "ptr_R: " << args.ptr_R << std::endl;
@@ -146,6 +149,7 @@
// std::cout << "ptr_RP: " << args.ptr_RP << std::endl;
// std::cout << "ptr_QTP: " << args.ptr_QTP << std::endl;
// std::cout << "ptr_STP: " << args.ptr_STP << std::endl;
// std::cout << "out_16_nosplit: " << args.out_16_nosplit << std::endl;

const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Q));
const hipStream_t stream = at::hip::getCurrentHIPStream();
@@ -234,11 +238,6 @@ void mla_decode_stage1_asm_fwd(
args.ptr_QSCALE = q_scale.value().data_ptr();
args.ptr_KVSCALE = kv_scale.value().data_ptr();

if(!persistent && kv_split == 1)
{
args.ptr_R = output.data_ptr();
}

if(gqa_ratio == 16)
{
if(persistent)
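
The block removed from `asm_mla.cu` (aliasing `args.ptr_R` to the output when `!persistent && kv_split == 1`) pairs with the Python-side change in `aiter/mla.py` above, where `logits` is now created as a view of the final output tensor. A condensed sketch of the condition as it now appears twice in `mla_decode_fwd`, assuming `dtypes.fp8` / `dtypes.bf16` as used in the diff:

```python
from aiter import dtypes


def single_split_writes_final_output(num_kv_splits, q_dtype, max_seqlen_q) -> bool:
    # With a single KV split and fp8 inputs (or bf16 with max_seqlen_q == 4),
    # stage 1 writes straight into the output tensor and the stage-2 reduction
    # kernel is skipped via the early return in mla_decode_fwd.
    return num_kv_splits == 1 and (
        q_dtype == dtypes.fp8 or (q_dtype == dtypes.bf16 and max_seqlen_q == 4)
    )
```
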
Binary file modified hsa/gfx950/mla/mla_a8w8_qh16_qseqlen4_gqaratio16.co
100755 → 100644