@@ -75,7 +75,7 @@ class TllmGenFmhaRunnerCache {
7575void trtllm_paged_attention_launcher (
7676 void * out, void * out_scale_factor, void * query, void * key_cache, void * value_cache,
7777 void * workspace_buffer, int * block_tables, int * seq_lens, int * cum_seq_lens_q,
78- int * cum_seq_lens_kv, float * attention_sinks,
78+ int * cum_seq_lens_kv, float * attention_sinks, float * lse,
7979 void * k_cache_scales, void * v_cache_scales,
8080 Data_type q_data_type, Data_type kv_data_type,
8181 Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len,
@@ -84,8 +84,9 @@ void trtllm_paged_attention_launcher(
8484 int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
8585 double bmm1_scale, double bmm2_scale, const float * bmm1_scale_log2_ptr,
8686 const float * bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index,
87- int64_t window_left, int64_t sum_seq_q, int64_t sm_count, bool enable_pdl,
88- int64_t workspace_size, cudaStream_t stream) {
87+ int64_t window_left, int64_t sum_seq_q,
88+ int64_t lse_stride_tokens, int64_t lse_stride_heads,
89+ int64_t sm_count, bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
8990 if (num_qo_heads % num_kv_heads != 0 ) {
9091 std::ostringstream err_msg;
9192 err_msg << " num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
@@ -173,6 +174,12 @@ void trtllm_paged_attention_launcher(
173174 runner_params.multiCtasKvScratchPtr =
174175 float_allocator.aligned_alloc <void >(0 , 16 , " trtllm_gen_scratch_workspace" );
175176 }
177+ runner_params.softmaxStatsPtr = float_allocator.aligned_alloc <float2 >(
178+ sizeof (float2 ) * num_qo_heads * runner_params.mSumOfSeqLensQ , 16 ,
179+ " trtllm_gen_softmax_workspace" );
180+ runner_params.lsePtr = lse;
181+ runner_params.lseStrideTokens = lse_stride_tokens;
182+ runner_params.lseStrideHeads = lse_stride_heads;
176183
177184 auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo (runner_params);
178185 if (!foundKernels) {
@@ -214,7 +221,7 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
214221 int64_t o_sf_vec_size, int64_t o_sf_start_index,
215222 int64_t window_left, int64_t sm_count, bool enable_pdl,
216223 int64_t workspace_size, Optional<TensorView> attention_sinks,
217- Optional<TensorView> k_cache_scales, Optional<TensorView> v_cache_scales) {
224+ Optional<TensorView> k_cache_scales, Optional<TensorView> v_cache_scales, Optional<TensorView> lse ) {
218225 auto q_data_type = dl_dtype_to_tllm_data_type (query.dtype ());
219226 auto kv_data_type = dl_dtype_to_tllm_data_type (key_cache.dtype ());
220227 TVM_FFI_ICHECK_EQ (key_cache.ndim (), value_cache.ndim ());
@@ -287,6 +294,16 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
287294 ? static_cast <float *>(maybe_bmm2_scale_tensor.value ().data_ptr ())
288295 : nullptr ;
289296
297+ float * lse_ptr = nullptr ;
298+ int lse_stride_tokens = 0 ;
299+ int lse_stride_heads = 0 ;
300+ if (lse.has_value ()) {
301+ TVM_FFI_ICHECK_EQ (lse.value ().dtype (), dl_float32) << " lse must be a float tensor" ;
302+ lse_ptr = static_cast <float *>(lse.value ().data_ptr ());
303+ lse_stride_tokens = lse.value ().stride (0 );
304+ lse_stride_heads = lse.value ().stride (2 );
305+ }
306+
290307 void * k_cache_scales_ptr = nullptr ;
291308 void * v_cache_scales_ptr = nullptr ;
292309 if (k_cache_scales.has_value ()) {
@@ -301,15 +318,15 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
301318 workspace_buffer.data_ptr (), static_cast <int *>(block_tables.data_ptr ()),
302319 static_cast <int *>(seq_lens.data_ptr ()),
303320 /* cum_seq_lens_q=*/ nullptr ,
304- /* cum_seq_lens_kv=*/ nullptr , attention_sinks_ptr,
321+ /* cum_seq_lens_kv=*/ nullptr , attention_sinks_ptr, lse_ptr,
305322 k_cache_scales_ptr, v_cache_scales_ptr,
306323 q_data_type, kv_data_type, o_data_type,
307324 TllmPagedAttentionMode::ForGen, batch_size, /* max_q_len=*/ q_len_per_request, max_kv_len,
308325 num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
309326 kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
310327 bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
311- o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, enable_pdl, workspace_size,
312- stream);
328+ o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, lse_stride_tokens, lse_stride_heads,
329+ sm_count, enable_pdl, workspace_size, stream);
313330}
314331
315332void trtllm_paged_attention_context (
@@ -320,7 +337,7 @@ void trtllm_paged_attention_context(
320337 double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
321338 int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
322339 bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
323- Optional<TensorView> k_cache_scales, Optional<TensorView> v_cache_scales) {
340+ Optional<TensorView> k_cache_scales, Optional<TensorView> v_cache_scales, Optional<TensorView> lse ) {
324341
325342 auto q_data_type = dl_dtype_to_tllm_data_type (query.dtype ());
326343 auto kv_data_type = dl_dtype_to_tllm_data_type (key_cache.dtype ());
@@ -385,6 +402,16 @@ void trtllm_paged_attention_context(
385402 ? static_cast <float *>(maybe_bmm2_scale_tensor.value ().data_ptr ())
386403 : nullptr ;
387404
405+ float * lse_ptr = nullptr ;
406+ int lse_stride_tokens = 0 ;
407+ int lse_stride_heads = 0 ;
408+ if (lse.has_value ()) {
409+ TVM_FFI_ICHECK_EQ (lse.value ().dtype (), dl_float32) << " lse must be a float tensor" ;
410+ lse_ptr = static_cast <float *>(lse.value ().data_ptr ());
411+ lse_stride_tokens = lse.value ().stride (0 );
412+ lse_stride_heads = lse.value ().stride (1 );
413+ }
414+
388415 void * k_cache_scales_ptr = nullptr ;
389416 void * v_cache_scales_ptr = nullptr ;
390417 if (k_cache_scales.has_value ()) {
@@ -399,14 +426,14 @@ void trtllm_paged_attention_context(
399426 workspace_buffer.data_ptr (), static_cast <int *>(block_tables.data_ptr ()),
400427 static_cast <int *>(seq_lens.data_ptr ()),
401428 /* cum_seq_lens_q=*/ static_cast <int *>(cum_seq_lens_q.data_ptr ()),
402- /* cum_seq_lens_kv=*/ static_cast <int *>(cum_seq_lens_kv.data_ptr ()), attention_sinks_ptr,
429+ /* cum_seq_lens_kv=*/ static_cast <int *>(cum_seq_lens_kv.data_ptr ()), attention_sinks_ptr, lse_ptr,
403430 k_cache_scales_ptr, v_cache_scales_ptr,
404431 q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
405432 max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
406433 head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
407434 max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
408- bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
409- enable_pdl, workspace_size, stream);
435+ bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
436+ lse_stride_tokens, lse_stride_heads, sm_count, enable_pdl, workspace_size, stream);
410437}
411438
412439void trtllm_ragged_attention_launcher (
@@ -419,6 +446,7 @@ void trtllm_ragged_attention_launcher(
419446 int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal,
420447 int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
421448 int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch,
449+ int64_t lse_stride_tokens, int64_t lse_stride_heads,
422450 int64_t workspace_size, cudaStream_t stream) {
423451 if (num_qo_heads % num_kv_heads != 0 ) {
424452 std::ostringstream err_msg;
@@ -475,6 +503,8 @@ void trtllm_ragged_attention_launcher(
475503 runner_params.mMaskType =
476504 is_causal ? TrtllmGenAttentionMaskType::Causal : TrtllmGenAttentionMaskType::Dense;
477505 runner_params.lsePtr = lse;
506+ runner_params.lseStrideTokens = lse_stride_tokens;
507+ runner_params.lseStrideHeads = lse_stride_heads;
478508
479509 AlignedAllocator float_allocator (workspace_buffer, workspace_size);
480510 size_t max_batch_size = 8192 ;
@@ -516,9 +546,13 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
516546 attention_sinks_ptr = static_cast <float *>(attention_sinks.value ().data_ptr ());
517547 }
518548 float * lse_ptr = nullptr ;
549+ int lse_stride_tokens = 0 ;
550+ int lse_stride_heads = 0 ;
519551 if (lse.has_value ()) {
520552 TVM_FFI_ICHECK_EQ (lse.value ().dtype (), dl_float32) << " lse must be a float tensor" ;
521553 lse_ptr = static_cast <float *>(lse.value ().data_ptr ());
554+ lse_stride_tokens = lse.value ().stride (0 );
555+ lse_stride_heads = lse.value ().stride (1 );
522556 }
523557 TVM_FFI_ICHECK_EQ (out.ndim (), 3 ) << " out must be a 3D tensor" ;
524558 TVM_FFI_ICHECK_EQ (query.ndim (), 3 ) << " query must be a 3D tensor" ;
@@ -569,7 +603,7 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
569603 num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale_value,
570604 bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left,
571605 sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch,
572- v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream);
606+ v_stride_keys_values, v_stride_heads, v_stride_batch, lse_stride_tokens, lse_stride_heads, workspace_size, stream);
573607}
574608
575609namespace trtllm_cubin_loader {
0 commit comments