Skip to content

Commit f1706a1

Browse files
committed
lse support (#7)
Add LSE support in all attention kernels
1 parent 9b1478f commit f1706a1

6 files changed

Lines changed: 117 additions & 34 deletions

File tree

csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class TllmGenFmhaRunnerCache {
7575
void trtllm_paged_attention_launcher(
7676
void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache,
7777
void* workspace_buffer, int* block_tables, int* seq_lens, int* cum_seq_lens_q,
78-
int* cum_seq_lens_kv, float* attention_sinks,
78+
int* cum_seq_lens_kv, float* attention_sinks, float* lse,
7979
void* k_cache_scales, void* v_cache_scales,
8080
Data_type q_data_type, Data_type kv_data_type,
8181
Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len,
@@ -84,8 +84,9 @@ void trtllm_paged_attention_launcher(
8484
int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
8585
double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr,
8686
const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index,
87-
int64_t window_left, int64_t sum_seq_q, int64_t sm_count, bool enable_pdl,
88-
int64_t workspace_size, cudaStream_t stream) {
87+
int64_t window_left, int64_t sum_seq_q,
88+
int64_t lse_stride_tokens, int64_t lse_stride_heads,
89+
int64_t sm_count, bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
8990
if (num_qo_heads % num_kv_heads != 0) {
9091
std::ostringstream err_msg;
9192
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
@@ -173,6 +174,12 @@ void trtllm_paged_attention_launcher(
173174
runner_params.multiCtasKvScratchPtr =
174175
float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
175176
}
177+
runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
178+
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
179+
"trtllm_gen_softmax_workspace");
180+
runner_params.lsePtr = lse;
181+
runner_params.lseStrideTokens = lse_stride_tokens;
182+
runner_params.lseStrideHeads = lse_stride_heads;
176183

177184
auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params);
178185
if (!foundKernels) {
@@ -214,7 +221,7 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
214221
int64_t o_sf_vec_size, int64_t o_sf_start_index,
215222
int64_t window_left, int64_t sm_count, bool enable_pdl,
216223
int64_t workspace_size, Optional<TensorView> attention_sinks,
217-
Optional<TensorView> k_cache_scales, Optional<TensorView> v_cache_scales) {
224+
Optional<TensorView> k_cache_scales, Optional<TensorView> v_cache_scales, Optional<TensorView> lse) {
218225
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
219226
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
220227
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
@@ -287,6 +294,16 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
287294
? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
288295
: nullptr;
289296

297+
float* lse_ptr = nullptr;
298+
int lse_stride_tokens = 0;
299+
int lse_stride_heads = 0;
300+
if (lse.has_value()) {
301+
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
302+
lse_ptr = static_cast<float*>(lse.value().data_ptr());
303+
lse_stride_tokens = lse.value().stride(0);
304+
lse_stride_heads = lse.value().stride(2);
305+
}
306+
290307
void* k_cache_scales_ptr = nullptr;
291308
void* v_cache_scales_ptr = nullptr;
292309
if (k_cache_scales.has_value()) {
@@ -301,15 +318,15 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
301318
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
302319
static_cast<int*>(seq_lens.data_ptr()),
303320
/*cum_seq_lens_q=*/nullptr,
304-
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr,
321+
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, lse_ptr,
305322
k_cache_scales_ptr, v_cache_scales_ptr,
306323
q_data_type, kv_data_type, o_data_type,
307324
TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
308325
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
309326
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
310327
bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
311-
o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, enable_pdl, workspace_size,
312-
stream);
328+
o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, lse_stride_tokens, lse_stride_heads,
329+
sm_count, enable_pdl, workspace_size, stream);
313330
}
314331

315332
void trtllm_paged_attention_context(
@@ -320,7 +337,7 @@ void trtllm_paged_attention_context(
320337
double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
321338
int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
322339
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
323-
Optional<TensorView> k_cache_scales, Optional<TensorView> v_cache_scales) {
340+
Optional<TensorView> k_cache_scales, Optional<TensorView> v_cache_scales, Optional<TensorView> lse) {
324341

325342
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
326343
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
@@ -385,6 +402,16 @@ void trtllm_paged_attention_context(
385402
? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
386403
: nullptr;
387404

405+
float* lse_ptr = nullptr;
406+
int lse_stride_tokens = 0;
407+
int lse_stride_heads = 0;
408+
if (lse.has_value()) {
409+
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
410+
lse_ptr = static_cast<float*>(lse.value().data_ptr());
411+
lse_stride_tokens = lse.value().stride(0);
412+
lse_stride_heads = lse.value().stride(1);
413+
}
414+
388415
void* k_cache_scales_ptr = nullptr;
389416
void* v_cache_scales_ptr = nullptr;
390417
if (k_cache_scales.has_value()) {
@@ -399,14 +426,14 @@ void trtllm_paged_attention_context(
399426
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
400427
static_cast<int*>(seq_lens.data_ptr()),
401428
/*cum_seq_lens_q=*/static_cast<int*>(cum_seq_lens_q.data_ptr()),
402-
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr,
429+
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, lse_ptr,
403430
k_cache_scales_ptr, v_cache_scales_ptr,
404431
q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
405432
max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
406433
head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
407434
max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
408-
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
409-
enable_pdl, workspace_size, stream);
435+
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
436+
lse_stride_tokens, lse_stride_heads, sm_count, enable_pdl, workspace_size, stream);
410437
}
411438

412439
void trtllm_ragged_attention_launcher(
@@ -419,6 +446,7 @@ void trtllm_ragged_attention_launcher(
419446
int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal,
420447
int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
421448
int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch,
449+
int64_t lse_stride_tokens, int64_t lse_stride_heads,
422450
int64_t workspace_size, cudaStream_t stream) {
423451
if (num_qo_heads % num_kv_heads != 0) {
424452
std::ostringstream err_msg;
@@ -475,6 +503,8 @@ void trtllm_ragged_attention_launcher(
475503
runner_params.mMaskType =
476504
is_causal ? TrtllmGenAttentionMaskType::Causal : TrtllmGenAttentionMaskType::Dense;
477505
runner_params.lsePtr = lse;
506+
runner_params.lseStrideTokens = lse_stride_tokens;
507+
runner_params.lseStrideHeads = lse_stride_heads;
478508

479509
AlignedAllocator float_allocator(workspace_buffer, workspace_size);
480510
size_t max_batch_size = 8192;
@@ -516,9 +546,13 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
516546
attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
517547
}
518548
float* lse_ptr = nullptr;
549+
int lse_stride_tokens = 0;
550+
int lse_stride_heads = 0;
519551
if (lse.has_value()) {
520552
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
521553
lse_ptr = static_cast<float*>(lse.value().data_ptr());
554+
lse_stride_tokens = lse.value().stride(0);
555+
lse_stride_heads = lse.value().stride(1);
522556
}
523557
TVM_FFI_ICHECK_EQ(out.ndim(), 3) << "out must be a 3D tensor";
524558
TVM_FFI_ICHECK_EQ(query.ndim(), 3) << "query must be a 3D tensor";
@@ -569,7 +603,7 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
569603
num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale_value,
570604
bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left,
571605
sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch,
572-
v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream);
606+
v_stride_keys_values, v_stride_heads, v_stride_batch, lse_stride_tokens, lse_stride_heads, workspace_size, stream);
573607
}
574608

575609
namespace trtllm_cubin_loader {

flashinfer/decode.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1920,6 +1920,9 @@ def _paged_run(
19201920
enable_pdl,
19211921
workspace_size,
19221922
sinks,
1923+
None,
1924+
None,
1925+
None,
19231926
)
19241927
return out
19251928

@@ -2080,6 +2083,8 @@ def trtllm_batch_decode_with_kv_cache(
20802083
q_len_per_req: Optional[int] = 1,
20812084
o_scale: Optional[float] = 1.0,
20822085
kv_cache_scales: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
2086+
return_lse: bool = False,
2087+
lse: Optional[torch.Tensor] = None,
20832088
) -> Union[torch.Tensor, FP4Tensor]:
20842089
"""
20852090
Parameters
@@ -2306,6 +2311,14 @@ def trtllm_batch_decode_with_kv_cache(
23062311
if kv_cache_scales is not None:
23072312
k_cache_scale, v_cache_scale = kv_cache_scales
23082313

2314+
if return_lse and lse is None:
2315+
lse = torch.empty(
2316+
query.shape[0],
2317+
query.shape[1],
2318+
device=query.device,
2319+
dtype=torch.float32,
2320+
)
2321+
23092322
run_func(
23102323
out,
23112324
out_scale_factor,
@@ -2333,13 +2346,18 @@ def trtllm_batch_decode_with_kv_cache(
23332346
sinks,
23342347
k_cache_scale,
23352348
v_cache_scale,
2349+
lse
23362350
)
23372351

2338-
return (
2352+
out = (
23392353
out
23402354
if out_dtype != "nvfp4"
23412355
else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
23422356
)
2357+
if return_lse:
2358+
return out, lse
2359+
else:
2360+
return out
23432361
else:
23442362
raise KeyError(f"Backend {backend} not supported")
23452363

@@ -2550,6 +2568,8 @@ def trtllm_batch_decode_with_kv_cache_mla(
25502568
bmm1_scale: Union[float, torch.Tensor] = 1.0,
25512569
bmm2_scale: Union[float, torch.Tensor] = 1.0,
25522570
sinks: Optional[List[torch.Tensor]] = None,
2571+
return_lse: bool = False,
2572+
lse: Optional[torch.Tensor] = None,
25532573
enable_pdl: bool = None,
25542574
backend: str = "auto",
25552575
) -> torch.Tensor:
@@ -2661,10 +2681,10 @@ def trtllm_batch_decode_with_kv_cache_mla(
26612681
out_shape = query.shape[:-1] + (kv_lora_rank,)
26622682
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
26632683
else:
2664-
batch_size, _, num_q_heads, _ = query.shape
2684+
batch_size, q_len_per_request, num_q_heads, _ = query.shape
26652685
check_shape_dtype_device(
26662686
out,
2667-
[batch_size, num_q_heads, kv_lora_rank],
2687+
[batch_size, q_len_per_request, num_q_heads, kv_lora_rank],
26682688
torch.bfloat16,
26692689
query.device,
26702690
"out",
@@ -2676,6 +2696,15 @@ def trtllm_batch_decode_with_kv_cache_mla(
26762696
if isinstance(bmm2_scale, torch.Tensor):
26772697
assert bmm2_scale.dtype == torch.float32
26782698

2699+
if return_lse and lse is None:
2700+
lse = torch.empty(
2701+
query.shape[0],
2702+
query.shape[1],
2703+
query.shape[2],
2704+
device=query.device,
2705+
dtype=torch.float32,
2706+
)
2707+
26792708
run_func(
26802709
out,
26812710
None, # fp4 output not supported in wrapper api yet.
@@ -2696,9 +2725,15 @@ def trtllm_batch_decode_with_kv_cache_mla(
26962725
enable_pdl,
26972726
workspace_buffer.numel() * workspace_buffer.element_size(),
26982727
sinks,
2728+
None,
2729+
None,
2730+
lse,
26992731
)
27002732

2701-
return out
2733+
if return_lse:
2734+
return out, lse
2735+
else:
2736+
return out
27022737
else:
27032738
raise ValueError(f"Backend {backend} not supported")
27042739

flashinfer/prefill.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3349,6 +3349,8 @@ def trtllm_batch_context_with_kv_cache(
33493349
enable_pdl: Optional[bool] = None,
33503350
sinks: Optional[List[torch.Tensor]] = None,
33513351
kv_cache_scales: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
3352+
return_lse: bool = False,
3353+
lse: Optional[torch.Tensor] = None,
33523354
) -> Union[torch.Tensor, FP4Tensor]:
33533355
"""
33543356
Parameters
@@ -3505,26 +3507,25 @@ def trtllm_batch_context_with_kv_cache(
35053507
else:
35063508
raise ValueError(f"Invalid out_dtype: {out_dtype}")
35073509

3508-
<<<<<<< HEAD
35093510
if isinstance(bmm1_scale, torch.Tensor):
35103511
assert bmm1_scale.dtype == torch.float32
35113512
bmm1_scale = bmm1_scale * log2e
35123513
if isinstance(bmm2_scale, torch.Tensor):
35133514
assert bmm2_scale.dtype == torch.float32
3514-
=======
3515-
bmm1_scale = (
3516-
bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale
3517-
)
3518-
bmm2_scale = (
3519-
bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale
3520-
)
35213515

35223516
k_cache_scale = None
35233517
v_cache_scale = None
35243518
if kv_cache_scales is not None:
35253519
k_cache_scale, v_cache_scale = kv_cache_scales
35263520

3527-
>>>>>>> 263e0743 (fp4 kv cache support)
3521+
if return_lse and lse is None:
3522+
lse = torch.empty(
3523+
query.shape[0],
3524+
query.shape[1],
3525+
device=query.device,
3526+
dtype=torch.float32,
3527+
)
3528+
35283529
workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()
35293530
run_func(
35303531
out,
@@ -3552,9 +3553,14 @@ def trtllm_batch_context_with_kv_cache(
35523553
sinks,
35533554
k_cache_scale,
35543555
v_cache_scale,
3556+
lse
35553557
)
3556-
return (
3558+
out = (
35573559
out
35583560
if out_dtype != "nvfp4"
35593561
else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
35603562
)
3563+
if return_lse:
3564+
return out, lse
3565+
else:
3566+
return out

include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,9 @@ class TllmGenFmhaKernel {
264264

265265
if (params.lsePtr != nullptr) {
266266
flashinfer::ComputeLSEFromMD(params.softmaxStatsPtr, params.lsePtr,
267-
params.mSumOfSeqLensQ * params.mNumHeadsQ, params.enable_pdl,
268-
params.stream);
267+
params.mSumOfSeqLensQ, params.mNumHeadsQ,
268+
params.lseStrideTokens, params.lseStrideHeads,
269+
params.enable_pdl, params.stream);
269270
}
270271
// Break the while op.
271272
break;

include/flashinfer/trtllm/fmha/fmhaRunnerParams.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,9 @@ struct TllmGenFmhaRunnerParams {
225225
// The LSE buffer.
226226
float* lsePtr;
227227

228+
int lseStrideTokens;
229+
int lseStrideHeads;
230+
228231
// Attention sink
229232
float const* ptrAttentionSinks{nullptr};
230233
// The output buffer.

include/flashinfer/trtllm/fmha/lse.cuh

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,27 @@ limitations under the License.
2323

2424
namespace flashinfer {
2525

26-
__global__ void ComputeLSEFromMDKernel(float2* __restrict__ md, float* __restrict__ lse, int n) {
26+
__global__ void ComputeLSEFromMDKernel(float2* __restrict__ md, float* __restrict__ lse, int num_tokens, int num_heads, int lse_stride_tokens, int lse_stride_heads) {
2727
int elem_idx = blockIdx.x * blockDim.x + threadIdx.x;
28-
if (elem_idx >= n) return;
28+
if (elem_idx >= num_tokens * num_heads) return;
2929
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
3030
asm volatile("griddepcontrol.wait;");
3131
#endif
3232
float2 md_elem = md[elem_idx];
3333
float m = md_elem.x;
3434
float d = md_elem.y;
35-
lse[elem_idx] = math::log2e * m + math::ptx_log2(d);
35+
int token_idx = elem_idx / num_heads;
36+
int head_idx = elem_idx % num_heads;
37+
int elem_idx_lse = token_idx * lse_stride_tokens + head_idx * lse_stride_heads;
38+
lse[elem_idx_lse] = m + math::loge2 * math::ptx_log2(d);
3639
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
3740
asm volatile("griddepcontrol.launch_dependents;");
3841
#endif
3942
}
4043

41-
inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool launch_with_pdl,
42-
cudaStream_t stream) {
44+
inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int num_tokens, int num_heads, int lse_stride_tokens, int lse_stride_heads,
45+
bool launch_with_pdl, cudaStream_t stream) {
46+
int n = num_tokens * num_heads;
4347
int num_threads = std::min(1024, UpPowerOfTwo(n));
4448
int num_blocks = ceil_div(n, num_threads);
4549
cudaLaunchConfig_t config;
@@ -53,7 +57,7 @@ inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool launch_w
5357
config.numAttrs = 1;
5458
config.attrs = attrs;
5559

56-
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, ComputeLSEFromMDKernel, md, lse, n));
60+
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, ComputeLSEFromMDKernel, md, lse, num_tokens, num_heads, lse_stride_tokens, lse_stride_heads));
5761
return cudaSuccess;
5862
}
5963

0 commit comments

Comments (0)