33
44# user interface
55
6+ import functools
7+
68import torch
7- import aiter
8- from aiter import dtypes
99import triton
1010import triton .language as tl
11- import functools
11+
12+ import aiter
13+ from aiter import dtypes
1214from aiter .jit .utils .chip_info import get_cu_num
13- from aiter .ops .triton .utils .types import get_fp8_e4m3_dtype
1415
1516
1617@triton .jit
@@ -21,11 +22,11 @@ def _fwd_kernel_stage2_asm(
2122 qo_indptr ,
2223 kv_indptr ,
2324 num_kv_splits_indptr ,
24- stride_mid_ob ,
25- stride_mid_oh ,
26- stride_mid_os ,
27- stride_obs ,
28- stride_oh ,
25+ stride_mid_ob : tl . int64 ,
26+ stride_mid_oh : tl . int64 ,
27+ stride_mid_os : tl . int64 ,
28+ stride_obs : tl . int64 ,
29+ stride_oh : tl . int64 ,
2930 MAYBE_FINAL_OUT : tl .constexpr ,
3031 BATCH_NUM : tl .constexpr ,
3132 BLOCK_DV : tl .constexpr ,
@@ -96,7 +97,7 @@ def _fwd_kernel_stage2_asm(
9697 )
9798
9899
99- @functools .lru_cache ()
100+ @functools .lru_cache (maxsize = 1 )
100101def get_meta_param (num_kv_splits , bs , total_kv , nhead , max_seqlen_q , dtype ):
101102 if num_kv_splits is None :
102103 cu_num = get_cu_num ()
@@ -128,7 +129,7 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
128129 512 : 32 ,
129130 }
130131
131- if dtype == get_fp8_e4m3_dtype () :
132+ if dtype == dtypes . fp8 :
132133 min_block_n = get_block_n_fp8 [int (nhead * max_seqlen_q )]
133134 num_kv_splits = min (
134135 num_kv_splits , int (total_kv / bs + min_block_n - 1 ) // min_block_n
@@ -138,7 +139,12 @@ def get_meta_param(num_kv_splits, bs, total_kv, nhead, max_seqlen_q, dtype):
138139 mgc = get_mgc [nhead ]
139140 if max_seqlen_q == 1 and nhead == 16 :
140141 mgc = 64
141- return num_kv_splits , mgc
142+
143+ num_kv_splits_indptr = torch .arange (
144+ 0 , (bs + 1 ) * num_kv_splits , num_kv_splits , dtype = torch .int , device = "cuda"
145+ )
146+
147+ return num_kv_splits , mgc , num_kv_splits_indptr
142148
143149
144150def mla_decode_fwd (
@@ -176,30 +182,34 @@ def mla_decode_fwd(
176182
177183 persistent_mode = work_meta_data is not None
178184
179- if num_kv_splits_indptr is None and not persistent_mode :
180- num_kv_splits , mgc = get_meta_param (
181- None , bs , total_kv , nhead , max_seqlen_q , q .dtype
182- )
183- num_kv_splits_indptr = torch .arange (
184- 0 , (bs + 1 ) * num_kv_splits , num_kv_splits , dtype = torch .int , device = device
185- )
186-
187- if num_kv_splits is None :
188- num_kv_splits = get_cu_num ()
189-
190185 io_transformed = False
191186
192187 if not persistent_mode :
188+ num_kv_splits , mgc , num_kv_splits_indptr = get_meta_param (
189+ num_kv_splits , bs , total_kv , nhead , max_seqlen_q , q .dtype
190+ )
191+
193192 MAYBE_FINAL_OUT = True
194193
195194 if nhead == 16 and max_seqlen_q == 1 :
196195 MAYBE_FINAL_OUT = False
197196
198- logits = torch .empty (
199- (total_s , num_kv_splits , nhead , v_head_dim ),
200- dtype = dtypes .fp32 ,
201- device = device ,
197+ logits = (
198+ o .view ((total_s , num_kv_splits , nhead , v_head_dim ))
199+ if (
200+ num_kv_splits == 1
201+ and (
202+ q .dtype == dtypes .fp8
203+ or (q .dtype == dtypes .bf16 and max_seqlen_q == 4 )
204+ )
205+ )
206+ else torch .empty (
207+ (total_s , num_kv_splits , nhead , v_head_dim ),
208+ dtype = dtypes .fp32 ,
209+ device = device ,
210+ )
202211 )
212+
203213 attn_lse = torch .empty (
204214 (total_s , num_kv_splits , nhead , 1 ), dtype = dtypes .fp32 , device = device
205215 )
@@ -225,7 +235,9 @@ def mla_decode_fwd(
225235 kv_scale ,
226236 )
227237
228- if num_kv_splits == 1 and q .dtype != torch .bfloat16 :
238+ if num_kv_splits == 1 and (
239+ q .dtype == dtypes .fp8 or (q .dtype == dtypes .bf16 and max_seqlen_q == 4 )
240+ ):
229241 return logits .view (total_s , nhead , v_head_dim ), attn_lse
230242
231243 Lv = v_head_dim
@@ -255,7 +267,9 @@ def mla_decode_fwd(
255267 ** extra_kargs ,
256268 )
257269 else :
258- if nhead == 16 or (nhead == 128 and kv_buffer .dtype == get_fp8_e4m3_dtype ()):
270+ if num_kv_splits is None :
271+ num_kv_splits = get_cu_num ()
272+ if nhead == 16 or (nhead == 128 and kv_buffer .dtype == dtypes .fp8 ):
259273 # Natively support cases
260274 pass
261275 elif nhead in range (32 , 512 + 1 , 16 ) and persistent_mode and max_seqlen_q == 1 :
0 commit comments