3434 logger .debug (f"flash_attention_v3_varlen not available: { e } " )
3535 flash_attention_v3_varlen = None
3636
37- # Enable verbose debug logging for MLA prefix-cache / chunk-prefill paths when MLA_CHUNK_DEBUG=1.
38- # All logger.debug messages in this file are silent by default.
39- if os .environ .get ("MLA_CHUNK_DEBUG" , "0" ) == "1" :
40- # paddleformers logger exposes set_level (not the standard logging.Logger.setLevel)
41- logger .set_level ("DEBUG" )
42-
4337from fastdeploy .model_executor .layers .attention .ops import (
4438 get_block_shape_and_split_kv_block ,
4539 init_kv_signal_per_query ,
7872# ============================================================================
7973
8074
@enable_compat_on_triton_kernel
@triton.jit()
def read_latent_from_cache_kernel(
    latent_cache,
    block_tables,
    cache_kv_lens,
    cu_seqlens_cached_kv,
    output_kv_c,
    output_k_pe,
    block_size: tl.constexpr,
    kv_lora_rank: tl.constexpr,
    qk_rope_head_dim: tl.constexpr,
    LATENT_DIM: tl.constexpr,
):
    """
    Gather latent vectors (kv_c and k_pe) for cached tokens out of the paged
    latent cache into two dense output buffers.

    Launch: one program instance per cached token (grid axis 0 spans
    total_cached_tokens).

    Args:
        latent_cache: [num_blocks, 1, block_size, kv_lora_rank + qk_rope_head_dim]
        block_tables: [batch_size, max_blocks_per_seq]
        cache_kv_lens: [batch_size] - cached KV length for each request.
            NOTE(review): this argument is never read in the kernel body —
            either dead or kept for launch-signature compatibility; confirm.
        cu_seqlens_cached_kv: [batch_size + 1] - cumulative sequence lengths for cached KV
        output_kv_c: [total_cached_tokens, kv_lora_rank]
        output_k_pe: [total_cached_tokens, qk_rope_head_dim]

    LATENT_DIM is expected to equal kv_lora_rank + qk_rope_head_dim (it is used
    as the row stride of latent_cache below) — TODO confirm at the call site.

    NOTE(review): this body uses several constructs stock triton.jit does not
    accept — `cu_seqlens_cached_kv.shape[0]` / `block_tables.shape[1]` on what
    are raw pointers inside a kernel, and a Python `break` guarded by a
    runtime (tensor) condition. Presumably the enable_compat_on_triton_kernel
    decorator rewrites or interprets these; verify before reusing this kernel
    with plain Triton.
    """
    # Global token index in the output
    token_idx = tl.program_id(axis=0)

    # Find which batch this token belongs to using binary search on cu_seqlens
    # For simplicity, we use a linear scan (could be optimized with binary search)
    batch_id = 0
    for i in range(cu_seqlens_cached_kv.shape[0] - 1):
        if token_idx >= tl.load(cu_seqlens_cached_kv + i) and token_idx < tl.load(cu_seqlens_cached_kv + i + 1):
            batch_id = i
            break

    # Local token index within the batch
    cu_start = tl.load(cu_seqlens_cached_kv + batch_id)
    local_token_idx = token_idx - cu_start

    # Get the physical block and offset
    block_idx = local_token_idx // block_size
    block_offset = local_token_idx % block_size

    # Get physical block id from block_tables (row-major: row batch_id, col block_idx)
    physical_block_id = tl.load(block_tables + batch_id * block_tables.shape[1] + block_idx)

    # Load latent vector from cache
    # latent_cache shape: [num_blocks, 1, block_size, kv_lora_rank + qk_rope_head_dim]
    # Flat addressing assumes contiguous layout with row stride LATENT_DIM.
    latent_base = latent_cache + physical_block_id * LATENT_DIM * block_size + block_offset * LATENT_DIM

    # Read kv_c (first kv_lora_rank dimensions)
    # NOTE(review): mask `kv_c_offsets < kv_lora_rank` is always true by
    # construction of tl.arange(0, kv_lora_rank) — it is a no-op.
    kv_c_offsets = tl.arange(0, kv_lora_rank)
    kv_c_value = tl.load(latent_base + kv_c_offsets, mask=kv_c_offsets < kv_lora_rank)

    # Read k_pe (last qk_rope_head_dim dimensions)
    # NOTE(review): likewise, `k_pe_offsets < LATENT_DIM` is always true when
    # LATENT_DIM == kv_lora_rank + qk_rope_head_dim.
    k_pe_offsets = tl.arange(kv_lora_rank, kv_lora_rank + qk_rope_head_dim)
    k_pe_value = tl.load(latent_base + k_pe_offsets, mask=k_pe_offsets < LATENT_DIM)

    # Store outputs: token row token_idx of each dense output buffer.
    output_kv_c_base = output_kv_c + token_idx * kv_lora_rank
    tl.store(output_kv_c_base + kv_c_offsets, kv_c_value, mask=kv_c_offsets < kv_lora_rank)

    # k_pe was loaded at offsets [kv_lora_rank, kv_lora_rank + qk_rope_head_dim)
    # but is stored at [0, qk_rope_head_dim) in the output row.
    output_k_pe_base = output_k_pe + token_idx * qk_rope_head_dim
    k_pe_out_offsets = tl.arange(0, qk_rope_head_dim)
    tl.store(output_k_pe_base + k_pe_out_offsets, k_pe_value)
149-
15075def read_latent_from_cache_naive (
15176 latent_cache : paddle .Tensor ,
15277 block_tables : paddle .Tensor ,
@@ -188,10 +113,6 @@ def read_latent_from_cache_naive(
188113 bsz = cu_seqlens_cached_kv .shape [0 ] - 1
189114 output_idx = 0
190115
191- logger .debug (f"[read_latent_from_cache] total_cached_tokens={ total_cached_tokens } , bsz={ bsz } " )
192- logger .debug (f"[read_latent_from_cache] cu_seqlens_cached_kv={ cu_seqlens_cached_kv .tolist ()} " )
193- logger .debug (f"[read_latent_from_cache] block_tables shape={ block_tables .shape } " )
194-
195116 for batch_id in range (bsz ):
196117 # Get the number of cached tokens for this batch from cu_seqlens_cached_kv
197118 cu_start = (
@@ -209,9 +130,6 @@ def read_latent_from_cache_naive(
209130 if cache_len <= 0 :
210131 continue
211132
212- # Debug: Print cache reading info
213- logger .debug (f"[read_latent_from_cache] batch_id={ batch_id } , cache_len={ cache_len } " )
214-
215133 # Read tokens from multiple blocks if cache_len > block_size
216134 local_idx = 0
217135 while local_idx < cache_len :
@@ -221,11 +139,6 @@ def read_latent_from_cache_naive(
221139
222140 physical_block_id = block_tables [batch_id , block_idx ].item ()
223141
224- # Debug: Print block access info
225- logger .debug (
226- f"[read_latent_from_cache] block_idx={ block_idx } , block_offset={ block_offset } , physical_block_id={ physical_block_id } "
227- )
228-
229142 # Load latent vectors from this block
230143 for offset in range (tokens_to_read ):
231144 latent_vec = latent_cache [physical_block_id , 0 , block_offset + offset , :]
@@ -237,7 +150,6 @@ def read_latent_from_cache_naive(
237150
238151 local_idx += tokens_to_read
239152
240- logger .debug (f"[read_latent_from_cache] Total cached tokens read: { output_idx } " )
241153 assert (
242154 output_idx == total_cached_tokens
243155 ), f"read_latent_from_cache_naive: wrote { output_idx } tokens, expected { total_cached_tokens } "
@@ -293,10 +205,6 @@ def interleave_cached_and_new_latent_naive(
293205 new_idx = 0
294206 out_position = 0 # Track output position for each batch
295207
296- logger .debug (
297- f"[interleave_cached_and_new_latent] bsz={ bsz } , total_cached={ total_cached } , total_new={ total_new } , total_tokens={ total_tokens } "
298- )
299-
300208 for batch_id in range (bsz ):
301209 # Number of cached tokens for this batch
302210 cu_cached_start = (
@@ -322,10 +230,6 @@ def interleave_cached_and_new_latent_naive(
322230 )
323231 num_new = cu_new_end - cu_new_start
324232
325- logger .debug (
326- f"[interleave] batch_id={ batch_id } , num_cached={ num_cached } , num_new={ num_new } , cached_idx={ cached_idx } , out_position={ out_position } "
327- )
328-
329233 # Output position for this batch (sequential, no gaps)
330234 out_start = out_position
331235
@@ -348,7 +252,6 @@ def interleave_cached_and_new_latent_naive(
348252 # Update output position for next batch
349253 out_position += num_cached + num_new
350254
351- logger .debug (f"[interleave] Final: cached_idx={ cached_idx } , new_idx={ new_idx } , out_position={ out_position } " )
352255 assert (
353256 cached_idx == total_cached
354257 ), f"interleave_cached_and_new_latent_naive: cached_idx={ cached_idx } != total_cached={ total_cached } "
@@ -852,12 +755,12 @@ def __init__(
852755 is_paddle_supported = any (num >= 90 for num in paddle .version .cuda_archs ())
853756 if is_current_sm_supported and is_paddle_supported :
854757 self .flash_attn_func = flash_attention_v3_varlen
855- print ("The current platform supports Flash Attention V3." )
758+ logger . info ("The current platform supports Flash Attention V3." )
856759 self .flash_attn_kwargs = {"softmax_scale" : self .attn_softmax_scale }
857760 else :
858761 self .flash_attn_func = flash_attn_unpadded
859762 self .flash_attn_kwargs = {"scale" : self .attn_softmax_scale , "training" : False }
860- print (
763+ logger . info (
861764 "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
862765 )
863766
@@ -918,11 +821,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
918821 # Prefix cache exists when seq_lens_decoder > 0
919822 # seq_lens_decoder stores the cached KV length for chunked prefill/prefix cache
920823 for i in range (bsz ):
921- enc_len = (
922- forward_meta .seq_lens_encoder [i ].item ()
923- if hasattr (forward_meta .seq_lens_encoder [i ], "item" )
924- else forward_meta .seq_lens_encoder [i ]
925- )
824+ # enc_len = (
825+ # forward_meta.seq_lens_encoder[i].item()
826+ # if hasattr(forward_meta.seq_lens_encoder[i], "item")
827+ # else forward_meta.seq_lens_encoder[i]
828+ # )
926829 dec_len = (
927830 forward_meta .seq_lens_decoder [i ].item ()
928831 if hasattr (forward_meta .seq_lens_decoder [i ], "item" )
@@ -966,11 +869,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
966869 # cu_seqlens_k_with_cache must reflect this sum per batch.
967870 # cu_seqlens_cached_kv tracks only the cached portion for read_latent_from_cache().
968871 for i in range (bsz ):
969- enc_len = (
970- forward_meta .seq_lens_encoder [i ].item ()
971- if hasattr (forward_meta .seq_lens_encoder [i ], "item" )
972- else forward_meta .seq_lens_encoder [i ]
973- )
872+ # enc_len = (
873+ # forward_meta.seq_lens_encoder[i].item()
874+ # if hasattr(forward_meta.seq_lens_encoder[i], "item")
875+ # else forward_meta.seq_lens_encoder[i]
876+ # )
974877 dec_len = (
975878 forward_meta .seq_lens_decoder [i ].item ()
976879 if hasattr (forward_meta .seq_lens_decoder [i ], "item" )
@@ -981,9 +884,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
981884 if hasattr (forward_meta .seq_lens_this_time [i ], "item" )
982885 else forward_meta .seq_lens_this_time [i ]
983886 )
984- logger .debug (
985- f"[init_attn_meta] batch { i } : enc_len={ enc_len } , dec_len={ dec_len } , seq_this={ seq_this_time } , cumsum_cached={ cumsum_cached } , cumsum_total={ cumsum_total } "
986- )
987887 if dec_len > 0 :
988888 cumsum_cached += dec_len
989889 cumsum_total += dec_len
@@ -992,8 +892,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
992892 cumsum_total += seq_this_time
993893 cu_seqlens_cached_kv [i + 1 ] = cumsum_cached
994894 cu_seqlens_k_with_cache [i + 1 ] = cumsum_total
995- logger .debug (f"[init_attn_meta] Final cu_seqlens_cached_kv: { cu_seqlens_cached_kv .tolist ()} " )
996- logger .debug (f"[init_attn_meta] Final cu_seqlens_k_with_cache: { cu_seqlens_k_with_cache .tolist ()} " )
997895 # Consistency checks: starts at 0, monotonic non-decreasing, final equals cumulative.
998896 assert cu_seqlens_cached_kv [0 ].item () == 0 , "cu_seqlens_cached_kv must start at 0"
999897 assert cu_seqlens_k_with_cache [0 ].item () == 0 , "cu_seqlens_k_with_cache must start at 0"
@@ -1245,37 +1143,7 @@ def forward_mixed(
12451143
12461144 # Prefill branch: k is not None
12471145 if k is not None :
1248- # Debug: Verify tensor shapes and sequence lengths
12491146 bsz = forward_meta .cu_seqlens_q .shape [0 ] - 1
1250- total_q_tokens = q .shape [0 ]
1251- total_k_tokens = k .shape [0 ]
1252-
1253- # Calculate expected cu_seqlens_k_with_cache
1254- if metadata .has_prefix_cache and metadata .cu_seqlens_k_with_cache is not None :
1255- expected_k_len = (
1256- metadata .cu_seqlens_k_with_cache [bsz ].item ()
1257- if hasattr (metadata .cu_seqlens_k_with_cache [bsz ], "item" )
1258- else metadata .cu_seqlens_k_with_cache [bsz ]
1259- )
1260- else :
1261- expected_k_len = (
1262- forward_meta .cu_seqlens_k [bsz ].item ()
1263- if hasattr (forward_meta .cu_seqlens_k [bsz ], "item" )
1264- else forward_meta .cu_seqlens_k [bsz ]
1265- )
1266-
1267- # Debug output
1268- logger .debug (
1269- f"[forward_mixed] bsz={ bsz } , total_q={ total_q_tokens } , total_k={ total_k_tokens } , expected_k={ expected_k_len } "
1270- )
1271- logger .debug (f"[forward_mixed] has_prefix_cache={ metadata .has_prefix_cache } " )
1272- logger .debug (
1273- f"[forward_mixed] cu_seqlens_q={ forward_meta .cu_seqlens_q .tolist () if hasattr (forward_meta .cu_seqlens_q , 'tolist' ) else forward_meta .cu_seqlens_q } "
1274- )
1275- if metadata .has_prefix_cache and metadata .cu_seqlens_k_with_cache is not None :
1276- logger .debug (
1277- f"[forward_mixed] cu_seqlens_k_with_cache={ metadata .cu_seqlens_k_with_cache .tolist () if hasattr (metadata .cu_seqlens_k_with_cache , 'tolist' ) else metadata .cu_seqlens_k_with_cache } "
1278- )
12791147
12801148 # Write cache only for new tokens of prefill/chunked-prefill batches.
12811149 # Decode batches (seq_lens_encoder == 0) are intentionally skipped here — they
0 commit comments