diff --git a/.gitignore b/.gitignore
index 80ef837fd..b606d2e41 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,9 @@
 # aiter version
 aiter/_version.py
 
+# aiter install mode
+aiter/install_mode
+
 # Prerequisites
 *.d
 
diff --git a/aiter/mla.py b/aiter/mla.py
index 1beafe499..1930fe978 100644
--- a/aiter/mla.py
+++ b/aiter/mla.py
@@ -3,14 +3,15 @@
 
 # user interface
 
+import functools
+
 import torch
-import aiter
-from aiter import dtypes
 import triton
 import triton.language as tl
-import functools
+
+import aiter
+from aiter import dtypes
 from aiter.jit.utils.chip_info import get_cu_num
-from aiter.ops.triton.utils.types import get_fp8_e4m3_dtype
 
 
@@ -21,11 +22,11 @@ def _fwd_kernel_stage2_asm(
     qo_indptr,
     kv_indptr,
     num_kv_splits_indptr,
-    stride_mid_ob,
-    stride_mid_oh,
-    stride_mid_os,
-    stride_obs,
-    stride_oh,
+    stride_mid_ob: tl.int64,
+    stride_mid_oh: tl.int64,
+    stride_mid_os: tl.int64,
+    stride_obs: tl.int64,
+    stride_oh: tl.int64,
     MAYBE_FINAL_OUT: tl.constexpr,
     BATCH_NUM: tl.constexpr,
     BLOCK_DV: tl.constexpr,
@@ -96,7 +97,7 @@ def _fwd_kernel_stage2_asm(
     )
 
 
-@functools.lru_cache()
+@functools.lru_cache(maxsize=1)
 def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
     if num_kv_splits is None:
         cu_num = get_cu_num()
@@ -128,7 +129,7 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
         512: 32,
     }
 
-    if dtype == get_fp8_e4m3_dtype():
+    if dtype == dtypes.fp8:
         min_block_n = get_block_n_fp8[int(nhead * max_seqlen_q)]
         num_kv_splits = min(
             num_kv_splits, int(total_kv / bs + min_block_n - 1) // min_block_n
@@ -138,7 +139,12 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
     mgc = get_mgc[nhead]
     if max_seqlen_q == 1 and nhead == 16:
         mgc = 64
-    return num_kv_splits, mgc
+
+    num_kv_splits_indptr = torch.arange(
+        0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int, device="cuda"
+    )
+
+    return num_kv_splits, mgc, num_kv_splits_indptr
 
 
 def mla_decode_fwd(
@@ -176,30 +182,34 @@ def mla_decode_fwd(
 
     persistent_mode = work_meta_data is not None
 
-    if num_kv_splits_indptr is None and not persistent_mode:
-        num_kv_splits, mgc = get_meta_param(
-            None, bs, total_kv, nhead, max_seqlen_q, q.dtype
-        )
-        num_kv_splits_indptr = torch.arange(
-            0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int, device=device
-        )
-
-    if num_kv_splits is None:
-        num_kv_splits = get_cu_num()
-
     io_transformed = False
 
     if not persistent_mode:
+        num_kv_splits, mgc, num_kv_splits_indptr = get_meta_param(
+            num_kv_splits, bs, total_kv, nhead, max_seqlen_q, q.dtype
+        )
+
         MAYBE_FINAL_OUT = True
         if nhead == 16 and max_seqlen_q == 1:
            MAYBE_FINAL_OUT = False
-        logits = torch.empty(
-            (total_s, num_kv_splits, nhead, v_head_dim),
-            dtype=dtypes.fp32,
-            device=device,
+
+        logits = (
+            o.view((total_s, num_kv_splits, nhead, v_head_dim))
+            if (
+                num_kv_splits == 1
+                and (
+                    q.dtype == dtypes.fp8
+                    or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
+                )
+            )
+            else torch.empty(
+                (total_s, num_kv_splits, nhead, v_head_dim),
+                dtype=dtypes.fp32,
+                device=device,
+            )
         )
+
         attn_lse = torch.empty(
             (total_s, num_kv_splits, nhead, 1), dtype=dtypes.fp32, device=device
         )
@@ -225,7 +235,9 @@ def mla_decode_fwd(
             kv_scale,
         )
 
-        if num_kv_splits == 1 and q.dtype != torch.bfloat16:
+        if num_kv_splits == 1 and (
+            q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
+        ):
             return logits.view(total_s, nhead, v_head_dim), attn_lse
 
         Lv = v_head_dim
@@ -255,7 +267,9 @@ def mla_decode_fwd(
             **extra_kargs,
         )
     else:
-        if nhead == 16 or (nhead == 128 and kv_buffer.dtype == get_fp8_e4m3_dtype()):
+        if num_kv_splits is None:
+            num_kv_splits = get_cu_num()
+        if nhead == 16 or (nhead == 128 and kv_buffer.dtype == dtypes.fp8):
             # Natively support cases
             pass
         elif nhead in range(32, 512 + 1, 16) and persistent_mode and max_seqlen_q == 1:
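
# For reference, a minimal CPU-only sketch of the num_kv_splits_indptr table that
# get_meta_param() now builds and returns (the patch allocates it on the CUDA
# device); bs=2 and num_kv_splits=4 are illustrative values, not defaults:
import torch

bs, num_kv_splits = 2, 4
num_kv_splits_indptr = torch.arange(
    0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int
)
print(num_kv_splits_indptr)  # tensor([0, 4, 8], dtype=torch.int32): batch i owns splits [4*i, 4*i + 4)
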
diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py
index 03ea08464..291d5de86 100644
--- a/aiter/ops/attention.py
+++ b/aiter/ops/attention.py
@@ -404,6 +404,7 @@ def get_mla_metadata_v1(
     uni_seqlen_qo: int = -1,
     fast_mode: bool = True,
     topk: int = -1,
+    max_split_per_batch: int = -1,
 ) -> None:
     """
     Inputs:
diff --git a/csrc/include/mla.h b/csrc/include/mla.h
index 249ab8b25..a234e1727 100644
--- a/csrc/include/mla.h
+++ b/csrc/include/mla.h
@@ -50,7 +50,8 @@ void get_mla_metadata_v1(const torch::Tensor& seqlens_qo_indptr, // [batch size
                         const int32_t max_seqlen_qo,
                         const int32_t uni_seqlen_qo,
                         const bool fast_mode,
-                         const int32_t topk);
+                         const int32_t topk,
+                         const int32_t max_split_per_batch);
 
 std::vector<torch::Tensor>
 get_mla_metadata_v1_no_redundant(const torch::Tensor& seqlens_qo_indptr, // [batch size + 1]
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index 6aca38674..b92d1268b 100644
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -1350,26 +1350,27 @@ namespace py = pybind11;
           py::arg("stride0"), \
           py::arg("stride1"));
 
-#define MLA_METADATA_PYBIND \
-    m.def("get_mla_metadata_v1", \
-          &get_mla_metadata_v1, \
-          "get_mla_metadata_v1", \
-          py::arg("seqlens_qo_indptr"), \
-          py::arg("seqlens_kv_indptr"), \
-          py::arg("num_heads_per_head_k"), \
-          py::arg("num_heads_k"), \
-          py::arg("is_causal"), \
-          py::arg("work_metadata_ptrs"), \
-          py::arg("work_info_set"), \
-          py::arg("work_indptr"), \
-          py::arg("reduce_indptr"), \
-          py::arg("reduce_final_map"), \
-          py::arg("reduce_partial_map"), \
-          py::arg("kv_granularity") = 16, \
-          py::arg("max_seqlen_qo") = -1, \
-          py::arg("uni_seqlen_qo") = -1, \
-          py::arg("fast_mode") = true, \
-          py::arg("topk") = -1); \
+#define MLA_METADATA_PYBIND \
+    m.def("get_mla_metadata_v1", \
+          &get_mla_metadata_v1, \
+          "get_mla_metadata_v1", \
+          py::arg("seqlens_qo_indptr"), \
+          py::arg("seqlens_kv_indptr"), \
+          py::arg("num_heads_per_head_k"), \
+          py::arg("num_heads_k"), \
+          py::arg("is_causal"), \
+          py::arg("work_metadata_ptrs"), \
+          py::arg("work_info_set"), \
+          py::arg("work_indptr"), \
+          py::arg("reduce_indptr"), \
+          py::arg("reduce_final_map"), \
+          py::arg("reduce_partial_map"), \
+          py::arg("kv_granularity") = 16, \
+          py::arg("max_seqlen_qo") = -1, \
+          py::arg("uni_seqlen_qo") = -1, \
+          py::arg("fast_mode") = true, \
+          py::arg("topk") = -1, \
+          py::arg("max_split_per_batch") = -1); \
     m.def("get_mla_metadata_v1_no_redundant", &get_mla_metadata_v1_no_redundant);
 
 #define MLA_REDUCE_PYBIND \
diff --git a/csrc/kernels/mla/metadata.cu b/csrc/kernels/mla/metadata.cu
index 2e56d4604..8d3078e7b 100644
--- a/csrc/kernels/mla/metadata.cu
+++ b/csrc/kernels/mla/metadata.cu
@@ -50,7 +50,8 @@ void get_mla_metadata_v1(
     const int32_t max_seqlen_qo,
     const int32_t uni_seqlen_qo,
     const bool fast_mode,
-    const int32_t topk)
+    const int32_t topk,
+    const int32_t max_split_per_batch)
 {
     const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(seqlens_kv_indptr));
 
@@ -77,6 +78,7 @@ void get_mla_metadata_v1(
         max_seqlen_qo,
         uni_seqlen_qo,
         topk,
+        max_split_per_batch,
         work_metadata_ptrs,
         work_info_set,
         work_indptr,
diff --git a/csrc/kernels/mla/metadata/v1_2_device.cuh b/csrc/kernels/mla/metadata/v1_2_device.cuh
index 9340d0541..80adf485e 100644
--- a/csrc/kernels/mla/metadata/v1_2_device.cuh
+++ b/csrc/kernels/mla/metadata/v1_2_device.cuh
@@ -84,7 +84,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
 
     // expected payload handled by each cu part.
     const int32_t payload =
-        ck_tile::integer_divide_ceil(sum_blocks, params.num_cu) + Traits::kFixedOverheadNumBlocks;
+        ck_tile::integer_divide_ceil(sum_blocks, params.num_splits) + Traits::kFixedOverheadNumBlocks;
 
     int32_t curr_batch = 0;    // batch ID of the batch which is under review
     int32_t curr_kv_block = 0; // #blocks handled by previous cu part(s)
@@ -358,6 +358,7 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
                                  const int32_t max_seqlen_qo,
                                  const int32_t ori_uni_seqlen_qo,
                                  const int32_t topk,
+                                  const int32_t max_split_per_batch,
                                  torch::Tensor& work_metadata_ptrs,
                                  torch::Tensor& work_info_set,
                                  torch::Tensor& work_indptr,
@@ -404,6 +405,8 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
                ": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
                "N is in [2, 8).")
 
+    int32_t num_splits = max_split_per_batch < 0 ? num_clusters : min(num_clusters, max_split_per_batch * num_batches);
+
     MlaMetadataV1KernelParameter params = {};
     params.p_work_metadata_ptrs = work_metadata_ptrs.data_ptr();
     params.p_work_indptr = work_indptr.data_ptr();
@@ -416,6 +419,7 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
     params.num_batches = num_batches;
     params.num_heads = num_heads_k * num_heads_per_head_k;
     params.num_cu = num_clusters;
+    params.num_splits = num_splits;
     params.reduce_indptr_size = reduce_indptr.size(0);
     params.kv_granularity = kv_granularity;
     params.kv_granularity_log2 = __builtin_ctz(kv_granularity);
diff --git a/csrc/kernels/mla/metadata/v1_comm.cuh b/csrc/kernels/mla/metadata/v1_comm.cuh
index e166d4a36..43b5c3975 100644
--- a/csrc/kernels/mla/metadata/v1_comm.cuh
+++ b/csrc/kernels/mla/metadata/v1_comm.cuh
@@ -65,6 +65,7 @@ struct MlaMetadataV1KernelParameter
     int32_t ori_seqlen_qo;
     int32_t topk;
     int32_t qk_batch_ratio;
+    int32_t num_splits;
     bool is_causal;
 };
 
diff --git a/csrc/py_itfs_cu/asm_mla.cu b/csrc/py_itfs_cu/asm_mla.cu
index 4bbb3ec45..192137659 100644
--- a/csrc/py_itfs_cu/asm_mla.cu
+++ b/csrc/py_itfs_cu/asm_mla.cu
@@ -45,6 +45,8 @@ struct __attribute__((packed)) KernelArgs
     p2 _p21;
     void* ptr_KVSCALE;
     p2 _p22;
+    unsigned int out_16_nosplit;
+    p3 _p23;
 };
 
 void mla_decode_stage1_asm_fwd(
@@ -98,6 +100,7 @@ void mla_decode_stage1_asm_fwd(
     args.s_Q_Bs = stride_Q;
     args.s_Bs = stride_Page;
     args.s_log2_plen = log2_page;
+    args.out_16_nosplit = kv_split;
 
     if (persistent)
     {
@@ -126,8 +129,8 @@ void mla_decode_stage1_asm_fwd(
     {
         args.ptr_STP = num_kv_splits_indptr.value().data_ptr();
     }
-    args.ptr_RP = output.data_ptr();
-
+    args.ptr_RP = output.data_ptr(); // final output
+
 
     // std::cout << "mla args" << std::endl;
     // std::cout << "ptr_R: " << args.ptr_R << std::endl;
@@ -146,6 +149,7 @@ void mla_decode_stage1_asm_fwd(
     // std::cout << "ptr_RP: " << args.ptr_RP << std::endl;
     // std::cout << "ptr_QTP: " << args.ptr_QTP << std::endl;
     // std::cout << "ptr_STP: " << args.ptr_STP << std::endl;
+    // std::cout << "out_16_nosplit: " << args.out_16_nosplit << std::endl;
 
     const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Q));
     const hipStream_t stream = at::hip::getCurrentHIPStream();
@@ -234,11 +238,6 @@ void mla_decode_stage1_asm_fwd(
     args.ptr_QSCALE = q_scale.value().data_ptr();
     args.ptr_KVSCALE = kv_scale.value().data_ptr();
 
-    if(!persistent && kv_split == 1)
-    {
-        args.ptr_R = output.data_ptr();
-    }
-
     if(gqa_ratio == 16)
     {
         if(persistent)
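
# For reference, a pure-Python sketch of the num_splits budget introduced in
# v1_2_device.cuh above (helper name and sample values are illustrative, not
# part of aiter):
def resolve_num_splits(num_clusters: int, num_batches: int, max_split_per_batch: int) -> int:
    # max_split_per_batch < 0 keeps the previous behaviour: one split budget per cluster
    if max_split_per_batch < 0:
        return num_clusters
    # otherwise the total budget is capped at max_split_per_batch splits per batch
    return min(num_clusters, max_split_per_batch * num_batches)

print(resolve_num_splits(80, 4, -1))  # 80
print(resolve_num_splits(80, 4, 16))  # 64 = min(80, 16 * 4)
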
diff --git a/hsa/gfx950/mla/mla_a8w8_qh16_qseqlen4_gqaratio16.co b/hsa/gfx950/mla/mla_a8w8_qh16_qseqlen4_gqaratio16.co
old mode 100755
new mode 100644
index 49aeaece3..c62cb3882
Binary files a/hsa/gfx950/mla/mla_a8w8_qh16_qseqlen4_gqaratio16.co and b/hsa/gfx950/mla/mla_a8w8_qh16_qseqlen4_gqaratio16.co differ
diff --git a/op_tests/test_mla.py b/op_tests/test_mla.py
index d5a639266..efe8b47f7 100644
--- a/op_tests/test_mla.py
+++ b/op_tests/test_mla.py
@@ -1,14 +1,15 @@
 # SPDX-License-Identifier: MIT
 # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+import argparse
+import itertools
+import random
+
 import torch
+
 import aiter
-from aiter.test_common import checkAllclose, benchmark, run_perftest
 from aiter import dtypes
-import random
-import itertools
-import argparse
-from aiter.ops.triton.utils.types import get_fp8_e4m3_dtype
+from aiter.test_common import benchmark, checkAllclose, run_perftest
 
 torch.set_default_device("cuda")
 torch.set_printoptions(sci_mode=False)
@@ -124,6 +125,7 @@ def test_mla(
     page_size,
     varlen,
     decode_qlen,
+    split_per_batch=None,
 ):
     ret = {}
 
@@ -377,6 +379,7 @@ def test_absorb_decode_bf16():
             kv_last_page_lens,
             max_seqlen_qo,
             sm_scale,
+            num_kv_splits=split_per_batch,
         )
 
         # print(f"{out_ref.view(total_q, -1)=}")
@@ -423,6 +426,7 @@ def test_absorb_decode_fp8():
             sm_scale,
             q_scale=q_scale,
             kv_scale=kv_scale,
+            num_kv_splits=split_per_batch,
         )
 
         # print(f"{out_ref.view(total_q, -1)=}")
@@ -566,6 +570,15 @@ def test_absorb_decode_fp8():
     help="""Number of nhead and decode_qlen.
     e.g.: -n 16,1""",
 )
+parser.add_argument(
+    "-splits",
+    "--split_per_batch",
+    type=int,
+    nargs="*",
+    default=[None],
+    help="""Number of kv splits per batch.
+    e.g.: -splits 32""",
+)
 parser.add_argument(
     "--varlen",
     action="store_true",
@@ -583,8 +596,8 @@ def test_absorb_decode_fp8():
 
 for nhead, decode_qlen in list_nhead:
     df = []
-    for dtype, kvtype, ctx_len, batch_size in itertools.product(
-        list_dtype, l_kv_dtype, args.ctxLen, args.batchSize
+    for dtype, kvtype, ctx_len, batch_size, split_per_batch in itertools.product(
+        list_dtype, l_kv_dtype, args.ctxLen, args.batchSize, args.split_per_batch
     ):
         ret = test_mla(
             ctx_len,
@@ -599,6 +612,7 @@ def test_absorb_decode_fp8():
             args.block_size,
             varlen=args.varlen,
             decode_qlen=decode_qlen,
+            split_per_batch=split_per_batch,
         )
         df.append(ret)
     df = pd.DataFrame(df)
diff --git a/op_tests/test_mla_persistent.py b/op_tests/test_mla_persistent.py
index 548ef811d..aab94ec57 100644
--- a/op_tests/test_mla_persistent.py
+++ b/op_tests/test_mla_persistent.py
@@ -149,7 +149,8 @@ def test_mla(
     kvtype,
     page_size,
     varlen,
-    mtp,
+    decode_qlen,
+    max_split_per_batch,
 ):
     ret = {}
 
@@ -199,9 +200,9 @@ def test_mla(
 
     # ############################## absorb: decode
     # seq_lens_qo = torch.randint(1, 5, (batch_size,), dtype=torch.int)
-    # if nhead == 16 and mtp != 1:
+    # if nhead == 16 and decode_qlen != 1:
    #     return
-    seq_lens_qo.fill_(mtp)
+    seq_lens_qo.fill_(decode_qlen)
    max_seqlen_qo = seq_lens_qo.max().item()
    qo_indptr[1 : batch_size + 1] = torch.cumsum(seq_lens_qo, dim=0)
@@ -274,8 +275,9 @@ def test_mla(
         reduce_partial_map,
         kv_granularity=max(page_size, 16),
         max_seqlen_qo=int(max_seqlen_qo),
-        uni_seqlen_qo=mtp,
+        uni_seqlen_qo=decode_qlen,
         fast_mode=True,
+        max_split_per_batch=max_split_per_batch,
     )
 
     def test_absorb_decode_bf16():
@@ -293,6 +295,7 @@ def test_absorb_decode_bf16():
             kv_last_page_lens,
             max_seqlen_qo,
             sm_scale,
+            num_kv_splits=max_split_per_batch,
             work_meta_data=work_meta_data,
             work_indptr=work_indptr,
             work_info_set=work_info_set,
@@ -353,6 +356,7 @@ def test_absorb_decode_fp8():
             kv_last_page_lens,
             max_seqlen_qo,
             sm_scale,
+            num_kv_splits=max_split_per_batch,
             q_scale=q_scale,
             kv_scale=kv_scale,
             work_meta_data=work_meta_data,
@@ -385,7 +389,7 @@ def test_absorb_decode_fp8():
     err = None
     us_asm_decode = 1e12
     if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and (
-        nhead == 16 or (nhead in range(32, 128, 16) and mtp == 1)
+        nhead == 16 or (nhead in range(32, 128, 16) and decode_qlen == 1)
     ):
         err, us_asm_decode = test_absorb_decode_bf16()
     elif kvtype == dtypes.fp8 and nhead in [16, 128]:
@@ -393,7 +397,7 @@ def test_absorb_decode_fp8():
 
     ret["decode:err"] = err
     ret["decode:asm_576"] = us_asm_decode
-    flops = mtp * total_kv * nhead * (qk_head_dim + v_head_dim) * 2
+    flops = decode_qlen * total_kv * nhead * (qk_head_dim + v_head_dim) * 2
     bytes = (
         total_kv * nhead_kv * qk_head_dim * (torch.finfo(kvtype).bits // 8)
         + total_q * nhead * qk_head_dim * (torch.finfo(dtype).bits // 8)
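
# For reference, the decode FLOP estimate from the hunk above, evaluated with
# illustrative sizes (decode_qlen=1, total_kv=4096, nhead=16, and the 576/512
# head dims suggested by the "asm_576" label):
decode_qlen, total_kv, nhead = 1, 4096, 16
qk_head_dim, v_head_dim = 576, 512
flops = decode_qlen * total_kv * nhead * (qk_head_dim + v_head_dim) * 2
print(flops)  # 142606336, roughly 0.14 GFLOP per decode step at this size
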
@@ -509,6 +513,15 @@ def test_absorb_decode_fp8():
     help="""Number of heads.
     e.g.: -n 16,1""",
 )
+parser.add_argument(
+    "-ms",
+    "--max_split_per_batch",
+    type=int,
+    nargs="*",
+    default=[16, 32],
+    help="""Maximum number of kv splits per batch.
+    e.g.: -ms 32""",
+)
 parser.add_argument(
     "--varlen",
     action="store_true",
@@ -524,10 +537,10 @@ def test_absorb_decode_fp8():
 if args.nhead is not None:
     list_nhead = [args.nhead]
 
-for nhead, mtp in list_nhead:
+for nhead, decode_qlen in list_nhead:
     df = []
-    for dtype, kvtype, ctx_len, batch_size in itertools.product(
-        list_dtype, l_kv_dtype, args.ctxLen, args.batchSize
+    for dtype, kvtype, ctx_len, batch_size, max_split_per_batch in itertools.product(
+        list_dtype, l_kv_dtype, args.ctxLen, args.batchSize, args.max_split_per_batch
     ):
         ret = test_mla(
             ctx_len,
@@ -541,9 +554,10 @@ def test_absorb_decode_fp8():
             kvtype,
             args.block_size,
             varlen=args.varlen,
-            mtp=mtp,
+            decode_qlen=decode_qlen,
+            max_split_per_batch=max_split_per_batch,
         )
         df.append(ret)
     df = pd.DataFrame(df)
-    # df.to_csv(f"mla_nhead{nhead}mtp{mtp}.csv")
+    # df.to_csv(f"mla_nhead{nhead}decode_qlen{decode_qlen}.csv")
     aiter.logger.info(f"summary:\n{df}")
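
# For reference, a standalone sketch of how the new -ms/--max_split_per_batch
# flag fans out into test cases (flag name and defaults taken from the patch;
# the batch_size list is illustrative):
import argparse
import itertools

parser = argparse.ArgumentParser()
parser.add_argument("-ms", "--max_split_per_batch", type=int, nargs="*", default=[16, 32])
args = parser.parse_args([])  # no CLI input -> sweep the defaults 16 and 32

for batch_size, max_split_per_batch in itertools.product([1, 8], args.max_split_per_batch):
    # the real test forwards this value to get_mla_metadata_v1(max_split_per_batch=...)
    # and, as num_kv_splits, to the mla decode call
    print(batch_size, max_split_per_batch)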