
Commit 288c82f

Zzz9990, valarLip, and fangche123 authored

max mla splits perbatch (#1390)

* fix issues
* add limit for split num per batch
* fix non-ps num kv split
* fix issue for big batch size
* fix logits alloc
* fix black code stype
* fix ut
* update git ignore & remove aiter/install_mode
* update qh16 qseqlen4 kernel
* update

---------

Co-authored-by: valarLip <[email protected]>
Co-authored-by: Fang.Che <[email protected]>
Co-authored-by: Lingpeng Jin <[email protected]>
1 parent 934d087 commit 288c82f

File tree

12 files changed: +131, -77 lines changed


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,9 @@
 # aiter version
 aiter/_version.py

+# aiter install mode
+aiter/install_mode
+
 # Prerequisites
 *.d

aiter/mla.py

Lines changed: 43 additions & 29 deletions
@@ -3,14 +3,15 @@

 # user interface

+import functools
+
 import torch
-import aiter
-from aiter import dtypes
 import triton
 import triton.language as tl
-import functools
+
+import aiter
+from aiter import dtypes
 from aiter.jit.utils.chip_info import get_cu_num
-from aiter.ops.triton.utils.types import get_fp8_e4m3_dtype


 @triton.jit
@@ -21,11 +22,11 @@ def _fwd_kernel_stage2_asm(
     qo_indptr,
     kv_indptr,
     num_kv_splits_indptr,
-    stride_mid_ob,
-    stride_mid_oh,
-    stride_mid_os,
-    stride_obs,
-    stride_oh,
+    stride_mid_ob: tl.int64,
+    stride_mid_oh: tl.int64,
+    stride_mid_os: tl.int64,
+    stride_obs: tl.int64,
+    stride_oh: tl.int64,
     MAYBE_FINAL_OUT: tl.constexpr,
     BATCH_NUM: tl.constexpr,
     BLOCK_DV: tl.constexpr,
@@ -96,7 +97,7 @@ def _fwd_kernel_stage2_asm(
     )


-@functools.lru_cache()
+@functools.lru_cache(maxsize=1)
 def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
     if num_kv_splits is None:
         cu_num = get_cu_num()
@@ -128,7 +129,7 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
         512: 32,
     }

-    if dtype == get_fp8_e4m3_dtype():
+    if dtype == dtypes.fp8:
         min_block_n = get_block_n_fp8[int(nhead * max_seqlen_q)]
         num_kv_splits = min(
             num_kv_splits, int(total_kv / bs + min_block_n - 1) // min_block_n
@@ -138,7 +139,12 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
     mgc = get_mgc[nhead]
     if max_seqlen_q == 1 and nhead == 16:
         mgc = 64
-    return num_kv_splits, mgc
+
+    num_kv_splits_indptr = torch.arange(
+        0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int, device="cuda"
+    )
+
+    return num_kv_splits, mgc, num_kv_splits_indptr


 def mla_decode_fwd(
@@ -176,30 +182,34 @@ def mla_decode_fwd(

     persistent_mode = work_meta_data is not None

-    if num_kv_splits_indptr is None and not persistent_mode:
-        num_kv_splits, mgc = get_meta_param(
-            None, bs, total_kv, nhead, max_seqlen_q, q.dtype
-        )
-        num_kv_splits_indptr = torch.arange(
-            0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int, device=device
-        )
-
-    if num_kv_splits is None:
-        num_kv_splits = get_cu_num()
-
     io_transformed = False

     if not persistent_mode:
+        num_kv_splits, mgc, num_kv_splits_indptr = get_meta_param(
+            num_kv_splits, bs, total_kv, nhead, max_seqlen_q, q.dtype
+        )
+
         MAYBE_FINAL_OUT = True

         if nhead == 16 and max_seqlen_q == 1:
             MAYBE_FINAL_OUT = False

-        logits = torch.empty(
-            (total_s, num_kv_splits, nhead, v_head_dim),
-            dtype=dtypes.fp32,
-            device=device,
+        logits = (
+            o.view((total_s, num_kv_splits, nhead, v_head_dim))
+            if (
+                num_kv_splits == 1
+                and (
+                    q.dtype == dtypes.fp8
+                    or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
+                )
+            )
+            else torch.empty(
+                (total_s, num_kv_splits, nhead, v_head_dim),
+                dtype=dtypes.fp32,
+                device=device,
+            )
         )
+
         attn_lse = torch.empty(
             (total_s, num_kv_splits, nhead, 1), dtype=dtypes.fp32, device=device
         )
@@ -225,7 +235,9 @@ def mla_decode_fwd(
             kv_scale,
         )

-        if num_kv_splits == 1 and q.dtype != torch.bfloat16:
+        if num_kv_splits == 1 and (
+            q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
+        ):
             return logits.view(total_s, nhead, v_head_dim), attn_lse

         Lv = v_head_dim
@@ -255,7 +267,9 @@ def mla_decode_fwd(
             **extra_kargs,
         )
     else:
-        if nhead == 16 or (nhead == 128 and kv_buffer.dtype == get_fp8_e4m3_dtype()):
+        if num_kv_splits is None:
+            num_kv_splits = get_cu_num()
+        if nhead == 16 or (nhead == 128 and kv_buffer.dtype == dtypes.fp8):
            # Natively support cases
            pass
        elif nhead in range(32, 512 + 1, 16) and persistent_mode and max_seqlen_q == 1:
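
For reference on the aiter/mla.py hunks above: get_meta_param now builds the per-batch split offsets itself and memoizes the result with functools.lru_cache(maxsize=1), so repeated decode calls with the same meta parameters reuse one indptr tensor. A minimal sketch of what that tensor holds, using assumed example values (bs = 4, num_kv_splits = 16) and CPU instead of the kernel path's device="cuda":

import torch

# Assumed example values; the real ones come from get_meta_param's caller.
bs, num_kv_splits = 4, 16

# Same arange as in get_meta_param above: one cumulative split offset per batch,
# bs + 1 entries in total.
num_kv_splits_indptr = torch.arange(
    0, (bs + 1) * num_kv_splits, num_kv_splits, dtype=torch.int
)
print(num_kv_splits_indptr)
# tensor([ 0, 16, 32, 48, 64], dtype=torch.int32)

Each consecutive pair of entries delimits one batch element's split slots; the stage-2 kernel shown above receives this tensor as num_kv_splits_indptr.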

aiter/ops/attention.py

Lines changed: 1 addition & 0 deletions
@@ -406,6 +406,7 @@ def get_mla_metadata_v1(
     uni_seqlen_qo: int = -1,
     fast_mode: bool = True,
     topk: int = -1,
+    max_split_per_batch: int = -1,
 ) -> None:
     """
     Inputs:

csrc/include/mla.h

Lines changed: 2 additions & 1 deletion
@@ -50,7 +50,8 @@ void get_mla_metadata_v1(const torch::Tensor& seqlens_qo_indptr, // [batch size
                         const int32_t max_seqlen_qo,
                         const int32_t uni_seqlen_qo,
                         const bool fast_mode,
-                        const int32_t topk);
+                        const int32_t topk,
+                        const int32_t max_split_per_batch);

 std::vector<torch::Tensor>
 get_mla_metadata_v1_no_redundant(const torch::Tensor& seqlens_qo_indptr, // [batch size + 1]

csrc/include/rocm_ops.hpp

Lines changed: 21 additions & 20 deletions
@@ -1350,26 +1350,27 @@ namespace py = pybind11;
           py::arg("stride0"), \
           py::arg("stride1"));

-#define MLA_METADATA_PYBIND \
-    m.def("get_mla_metadata_v1", \
-          &get_mla_metadata_v1, \
-          "get_mla_metadata_v1", \
-          py::arg("seqlens_qo_indptr"), \
-          py::arg("seqlens_kv_indptr"), \
-          py::arg("num_heads_per_head_k"), \
-          py::arg("num_heads_k"), \
-          py::arg("is_causal"), \
-          py::arg("work_metadata_ptrs"), \
-          py::arg("work_info_set"), \
-          py::arg("work_indptr"), \
-          py::arg("reduce_indptr"), \
-          py::arg("reduce_final_map"), \
-          py::arg("reduce_partial_map"), \
-          py::arg("kv_granularity") = 16, \
-          py::arg("max_seqlen_qo") = -1, \
-          py::arg("uni_seqlen_qo") = -1, \
-          py::arg("fast_mode") = true, \
-          py::arg("topk") = -1); \
+#define MLA_METADATA_PYBIND \
+    m.def("get_mla_metadata_v1", \
+          &get_mla_metadata_v1, \
+          "get_mla_metadata_v1", \
+          py::arg("seqlens_qo_indptr"), \
+          py::arg("seqlens_kv_indptr"), \
+          py::arg("num_heads_per_head_k"), \
+          py::arg("num_heads_k"), \
+          py::arg("is_causal"), \
+          py::arg("work_metadata_ptrs"), \
+          py::arg("work_info_set"), \
+          py::arg("work_indptr"), \
+          py::arg("reduce_indptr"), \
+          py::arg("reduce_final_map"), \
+          py::arg("reduce_partial_map"), \
+          py::arg("kv_granularity") = 16, \
+          py::arg("max_seqlen_qo") = -1, \
+          py::arg("uni_seqlen_qo") = -1, \
+          py::arg("fast_mode") = true, \
+          py::arg("topk") = -1, \
+          py::arg("max_split_per_batch") = -1); \
     m.def("get_mla_metadata_v1_no_redundant", &get_mla_metadata_v1_no_redundant);

 #define MLA_REDUCE_PYBIND \

csrc/kernels/mla/metadata.cu

Lines changed: 3 additions & 1 deletion
@@ -50,7 +50,8 @@ void get_mla_metadata_v1(
     const int32_t max_seqlen_qo,
     const int32_t uni_seqlen_qo,
     const bool fast_mode,
-    const int32_t topk)
+    const int32_t topk,
+    const int32_t max_split_per_batch)
 {
     const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(seqlens_kv_indptr));

@@ -77,6 +78,7 @@ void get_mla_metadata_v1(
         max_seqlen_qo,
         uni_seqlen_qo,
         topk,
+        max_split_per_batch,
         work_metadata_ptrs,
         work_info_set,
         work_indptr,

csrc/kernels/mla/metadata/v1_2_device.cuh

Lines changed: 5 additions & 1 deletion
@@ -84,7 +84,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__

    // expected payload handled by each cu part.
    const int32_t payload =
-       ck_tile::integer_divide_ceil(sum_blocks, params.num_cu) + Traits::kFixedOverheadNumBlocks;
+       ck_tile::integer_divide_ceil(sum_blocks, params.num_splits) + Traits::kFixedOverheadNumBlocks;

    int32_t curr_batch = 0; // batch ID of the batch which is under review
    int32_t curr_kv_block = 0; // #blocks handled by previous cu part(s)
@@ -358,6 +358,7 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
                                  const int32_t max_seqlen_qo,
                                  const int32_t ori_uni_seqlen_qo,
                                  const int32_t topk,
+                                 const int32_t max_split_per_batch,
                                  torch::Tensor& work_metadata_ptrs,
                                  torch::Tensor& work_info_set,
                                  torch::Tensor& work_indptr,
@@ -404,6 +405,8 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
               ": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
               "N is in [2, 8).")

+    int32_t num_splits = max_split_per_batch < 0 ? num_clusters : min(num_clusters, max_split_per_batch * num_batches);
+
     MlaMetadataV1KernelParameter params = {};
     params.p_work_metadata_ptrs = work_metadata_ptrs.data_ptr<uint64_t>();
     params.p_work_indptr = work_indptr.data_ptr<int32_t>();
@@ -416,6 +419,7 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
     params.num_batches = num_batches;
     params.num_heads = num_heads_k * num_heads_per_head_k;
     params.num_cu = num_clusters;
+    params.num_splits = num_splits;
     params.reduce_indptr_size = reduce_indptr.size(0);
     params.kv_granularity = kv_granularity;
     params.kv_granularity_log2 = __builtin_ctz(kv_granularity);
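
Taken together, the two v1_2_device.cuh hunks above cap the split count and then size the per-CU payload from that capped value: num_splits falls back to num_clusters when max_split_per_batch is negative (the old behaviour), and otherwise cannot exceed max_split_per_batch * num_batches. A rough Python rendering of that arithmetic, with assumed example numbers and a placeholder overhead standing in for Traits::kFixedOverheadNumBlocks (whose real value lives elsewhere in the kernel traits):

def integer_divide_ceil(a: int, b: int) -> int:
    # Same rounding as ck_tile::integer_divide_ceil.
    return (a + b - 1) // b

def split_plan(sum_blocks: int, num_clusters: int, num_batches: int,
               max_split_per_batch: int = -1, overhead: int = 5) -> tuple[int, int]:
    # Mirrors the new clamp in get_mla_metadata_v1_2_device:
    # num_splits = max_split_per_batch < 0 ? num_clusters
    #                                      : min(num_clusters, max_split_per_batch * num_batches)
    if max_split_per_batch < 0:
        num_splits = num_clusters
    else:
        num_splits = min(num_clusters, max_split_per_batch * num_batches)
    # Mirrors the payload change: divide by num_splits instead of num_cu.
    payload = integer_divide_ceil(sum_blocks, num_splits) + overhead
    return num_splits, payload

# Assumed example: 1024 KV blocks, 80 clusters, 4 batches, overhead = 5.
print(split_plan(1024, 80, 4))                          # (80, 18)  default, old behaviour
print(split_plan(1024, 80, 4, max_split_per_batch=8))   # (32, 37)  at most 8 splits per batch

With the cap in place each batch is spread across fewer work chunks, so every chunk carries a larger share of the KV blocks, which is the "limit for split num per batch" named in the commit message.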

csrc/kernels/mla/metadata/v1_comm.cuh

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ struct MlaMetadataV1KernelParameter
     int32_t ori_seqlen_qo;
     int32_t topk;
     int32_t qk_batch_ratio;
+    int32_t num_splits;
     bool is_causal;
 };

csrc/py_itfs_cu/asm_mla.cu

Lines changed: 6 additions & 7 deletions
@@ -45,6 +45,8 @@ struct __attribute__((packed)) KernelArgs
    p2 _p21;
    void* ptr_KVSCALE;
    p2 _p22;
+   unsigned int out_16_nosplit;
+   p3 _p23;
 };

 void mla_decode_stage1_asm_fwd(
@@ -98,6 +100,7 @@ void mla_decode_stage1_asm_fwd(
    args.s_Q_Bs = stride_Q;
    args.s_Bs = stride_Page;
    args.s_log2_plen = log2_page;
+   args.out_16_nosplit = kv_split;

    if (persistent)
    {
@@ -126,8 +129,8 @@
    {
        args.ptr_STP = num_kv_splits_indptr.value().data_ptr();
    }
-   args.ptr_RP = output.data_ptr();
-
+   args.ptr_RP = output.data_ptr(); //final output
+

    // std::cout << "mla args" << std::endl;
    // std::cout << "ptr_R: " << args.ptr_R << std::endl;
@@ -146,6 +149,7 @@ void mla_decode_stage1_asm_fwd(
    // std::cout << "ptr_RP: " << args.ptr_RP << std::endl;
    // std::cout << "ptr_QTP: " << args.ptr_QTP << std::endl;
    // std::cout << "ptr_STP: " << args.ptr_STP << std::endl;
+   // std::cout << "out_16_nosplit: " << args.out_16_nosplit << std::endl;

    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Q));
    const hipStream_t stream = at::hip::getCurrentHIPStream();
@@ -234,11 +238,6 @@ void mla_decode_stage1_asm_fwd(
    args.ptr_QSCALE = q_scale.value().data_ptr();
    args.ptr_KVSCALE = kv_scale.value().data_ptr();

-   if(!persistent && kv_split == 1)
-   {
-       args.ptr_R = output.data_ptr();
-   }
-
    if(gqa_ratio == 16)
    {
        if(persistent)

hsa/gfx950/mla/mla_a8w8_qh16_qseqlen4_gqaratio16.co

File mode changed from 100755 to 100644 (4.11 KB)
Binary file not shown.
