Commit aaa5924: optimized mla chunk prefill
Author: chang-wenbin
1 parent c346a24

2 files changed: 12 additions & 150 deletions

fastdeploy/model_executor/layers/attention/mla_attention_backend.py
Lines changed: 12 additions & 144 deletions
@@ -34,12 +34,6 @@
     logger.debug(f"flash_attention_v3_varlen not available: {e}")
     flash_attention_v3_varlen = None
 
-# Enable verbose debug logging for MLA prefix-cache / chunk-prefill paths when MLA_CHUNK_DEBUG=1.
-# All logger.debug messages in this file are silent by default.
-if os.environ.get("MLA_CHUNK_DEBUG", "0") == "1":
-    # paddleformers logger exposes set_level (not the standard logging.Logger.setLevel)
-    logger.set_level("DEBUG")
-
 from fastdeploy.model_executor.layers.attention.ops import (
     get_block_shape_and_split_kv_block,
     init_kv_signal_per_query,
@@ -78,75 +72,6 @@
 # ============================================================================
 
 
-@enable_compat_on_triton_kernel
-@triton.jit()
-def read_latent_from_cache_kernel(
-    latent_cache,
-    block_tables,
-    cache_kv_lens,
-    cu_seqlens_cached_kv,
-    output_kv_c,
-    output_k_pe,
-    block_size: tl.constexpr,
-    kv_lora_rank: tl.constexpr,
-    qk_rope_head_dim: tl.constexpr,
-    LATENT_DIM: tl.constexpr,
-):
-    """
-    Kernel to read latent vectors (kv_c and k_pe) from paged latent cache.
-    Each program instance handles one cached token.
-
-    Args:
-        latent_cache: [num_blocks, 1, block_size, kv_lora_rank + qk_rope_head_dim]
-        block_tables: [batch_size, max_blocks_per_seq]
-        cache_kv_lens: [batch_size] - cached KV length for each request
-        cu_seqlens_cached_kv: [batch_size + 1] - cumulative sequence lengths for cached KV
-        output_kv_c: [total_cached_tokens, kv_lora_rank]
-        output_k_pe: [total_cached_tokens, qk_rope_head_dim]
-    """
-    # Global token index in the output
-    token_idx = tl.program_id(axis=0)
-
-    # Find which batch this token belongs to using binary search on cu_seqlens
-    # For simplicity, we use a linear scan (could be optimized with binary search)
-    batch_id = 0
-    for i in range(cu_seqlens_cached_kv.shape[0] - 1):
-        if token_idx >= tl.load(cu_seqlens_cached_kv + i) and token_idx < tl.load(cu_seqlens_cached_kv + i + 1):
-            batch_id = i
-            break
-
-    # Local token index within the batch
-    cu_start = tl.load(cu_seqlens_cached_kv + batch_id)
-    local_token_idx = token_idx - cu_start
-
-    # Get the physical block and offset
-    block_idx = local_token_idx // block_size
-    block_offset = local_token_idx % block_size
-
-    # Get physical block id from block_tables
-    physical_block_id = tl.load(block_tables + batch_id * block_tables.shape[1] + block_idx)
-
-    # Load latent vector from cache
-    # latent_cache shape: [num_blocks, 1, block_size, kv_lora_rank + qk_rope_head_dim]
-    latent_base = latent_cache + physical_block_id * LATENT_DIM * block_size + block_offset * LATENT_DIM
-
-    # Read kv_c (first kv_lora_rank dimensions)
-    kv_c_offsets = tl.arange(0, kv_lora_rank)
-    kv_c_value = tl.load(latent_base + kv_c_offsets, mask=kv_c_offsets < kv_lora_rank)
-
-    # Read k_pe (last qk_rope_head_dim dimensions)
-    k_pe_offsets = tl.arange(kv_lora_rank, kv_lora_rank + qk_rope_head_dim)
-    k_pe_value = tl.load(latent_base + k_pe_offsets, mask=k_pe_offsets < LATENT_DIM)
-
-    # Store outputs
-    output_kv_c_base = output_kv_c + token_idx * kv_lora_rank
-    tl.store(output_kv_c_base + kv_c_offsets, kv_c_value, mask=kv_c_offsets < kv_lora_rank)
-
-    output_k_pe_base = output_k_pe + token_idx * qk_rope_head_dim
-    k_pe_out_offsets = tl.arange(0, qk_rope_head_dim)
-    tl.store(output_k_pe_base + k_pe_out_offsets, k_pe_value)
-
-
 def read_latent_from_cache_naive(
     latent_cache: paddle.Tensor,
     block_tables: paddle.Tensor,
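
Note: both the deleted Triton kernel and the retained read_latent_from_cache_naive path use the same paged-cache addressing: a token's position within its sequence splits into a logical block index and an intra-block offset, and block_tables maps the logical index to a physical block. A minimal NumPy sketch of that lookup, with toy sizes that are illustrative only (the real layout is [num_blocks, 1, block_size, kv_lora_rank + qk_rope_head_dim]):

import numpy as np

# Toy sizes for illustration; real values come from the model config.
block_size, kv_lora_rank, qk_rope_head_dim = 4, 8, 2
latent_dim = kv_lora_rank + qk_rope_head_dim

# Paged latent cache: [num_blocks, 1, block_size, latent_dim].
latent_cache = np.random.rand(6, 1, block_size, latent_dim).astype(np.float32)
# One sequence whose logical blocks 0 and 1 live in physical blocks 5 and 2.
block_tables = np.array([[5, 2]])

def read_cached_token(batch_id, local_token_idx):
    """Fetch (kv_c, k_pe) for one cached token, mirroring the kernel's indexing."""
    block_idx, block_offset = divmod(local_token_idx, block_size)
    physical_block_id = block_tables[batch_id, block_idx]
    latent_vec = latent_cache[physical_block_id, 0, block_offset, :]
    return latent_vec[:kv_lora_rank], latent_vec[kv_lora_rank:]

# Token 5 of request 0 -> logical block 1, offset 1 -> physical block 2.
kv_c, k_pe = read_cached_token(0, 5)
assert kv_c.shape == (kv_lora_rank,) and k_pe.shape == (qk_rope_head_dim,)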
@@ -188,10 +113,6 @@ def read_latent_from_cache_naive(
     bsz = cu_seqlens_cached_kv.shape[0] - 1
     output_idx = 0
 
-    logger.debug(f"[read_latent_from_cache] total_cached_tokens={total_cached_tokens}, bsz={bsz}")
-    logger.debug(f"[read_latent_from_cache] cu_seqlens_cached_kv={cu_seqlens_cached_kv.tolist()}")
-    logger.debug(f"[read_latent_from_cache] block_tables shape={block_tables.shape}")
-
     for batch_id in range(bsz):
         # Get the number of cached tokens for this batch from cu_seqlens_cached_kv
         cu_start = (
@@ -209,9 +130,6 @@ def read_latent_from_cache_naive(
         if cache_len <= 0:
             continue
 
-        # Debug: Print cache reading info
-        logger.debug(f"[read_latent_from_cache] batch_id={batch_id}, cache_len={cache_len}")
-
         # Read tokens from multiple blocks if cache_len > block_size
         local_idx = 0
         while local_idx < cache_len:
@@ -221,11 +139,6 @@ def read_latent_from_cache_naive(
 
             physical_block_id = block_tables[batch_id, block_idx].item()
 
-            # Debug: Print block access info
-            logger.debug(
-                f"[read_latent_from_cache] block_idx={block_idx}, block_offset={block_offset}, physical_block_id={physical_block_id}"
-            )
-
             # Load latent vectors from this block
             for offset in range(tokens_to_read):
                 latent_vec = latent_cache[physical_block_id, 0, block_offset + offset, :]
@@ -237,7 +150,6 @@ def read_latent_from_cache_naive(
 
             local_idx += tokens_to_read
 
-    logger.debug(f"[read_latent_from_cache] Total cached tokens read: {output_idx}")
     assert (
         output_idx == total_cached_tokens
     ), f"read_latent_from_cache_naive: wrote {output_idx} tokens, expected {total_cached_tokens}"
@@ -293,10 +205,6 @@ def interleave_cached_and_new_latent_naive(
     new_idx = 0
     out_position = 0  # Track output position for each batch
 
-    logger.debug(
-        f"[interleave_cached_and_new_latent] bsz={bsz}, total_cached={total_cached}, total_new={total_new}, total_tokens={total_tokens}"
-    )
-
     for batch_id in range(bsz):
         # Number of cached tokens for this batch
         cu_cached_start = (
@@ -322,10 +230,6 @@ def interleave_cached_and_new_latent_naive(
         )
         num_new = cu_new_end - cu_new_start
 
-        logger.debug(
-            f"[interleave] batch_id={batch_id}, num_cached={num_cached}, num_new={num_new}, cached_idx={cached_idx}, out_position={out_position}"
-        )
-
         # Output position for this batch (sequential, no gaps)
         out_start = out_position
 
@@ -348,7 +252,6 @@ def interleave_cached_and_new_latent_naive(
         # Update output position for next batch
         out_position += num_cached + num_new
 
-    logger.debug(f"[interleave] Final: cached_idx={cached_idx}, new_idx={new_idx}, out_position={out_position}")
     assert (
         cached_idx == total_cached
     ), f"interleave_cached_and_new_latent_naive: cached_idx={cached_idx} != total_cached={total_cached}"
@@ -852,12 +755,12 @@ def __init__(
         is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
         if is_current_sm_supported and is_paddle_supported:
             self.flash_attn_func = flash_attention_v3_varlen
-            print("The current platform supports Flash Attention V3.")
+            logger.info("The current platform supports Flash Attention V3.")
             self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale}
         else:
             self.flash_attn_func = flash_attn_unpadded
             self.flash_attn_kwargs = {"scale": self.attn_softmax_scale, "training": False}
-            print(
+            logger.info(
                 "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
             )
 
@@ -918,11 +821,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
         # Prefix cache exists when seq_lens_decoder > 0
         # seq_lens_decoder stores the cached KV length for chunked prefill/prefix cache
         for i in range(bsz):
-            enc_len = (
-                forward_meta.seq_lens_encoder[i].item()
-                if hasattr(forward_meta.seq_lens_encoder[i], "item")
-                else forward_meta.seq_lens_encoder[i]
-            )
+            # enc_len = (
+            #     forward_meta.seq_lens_encoder[i].item()
+            #     if hasattr(forward_meta.seq_lens_encoder[i], "item")
+            #     else forward_meta.seq_lens_encoder[i]
+            # )
             dec_len = (
                 forward_meta.seq_lens_decoder[i].item()
                 if hasattr(forward_meta.seq_lens_decoder[i], "item")
@@ -966,11 +869,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
         # cu_seqlens_k_with_cache must reflect this sum per batch.
         # cu_seqlens_cached_kv tracks only the cached portion for read_latent_from_cache().
         for i in range(bsz):
-            enc_len = (
-                forward_meta.seq_lens_encoder[i].item()
-                if hasattr(forward_meta.seq_lens_encoder[i], "item")
-                else forward_meta.seq_lens_encoder[i]
-            )
+            # enc_len = (
+            #     forward_meta.seq_lens_encoder[i].item()
+            #     if hasattr(forward_meta.seq_lens_encoder[i], "item")
+            #     else forward_meta.seq_lens_encoder[i]
+            # )
             dec_len = (
                 forward_meta.seq_lens_decoder[i].item()
                 if hasattr(forward_meta.seq_lens_decoder[i], "item")
@@ -981,9 +884,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
                 if hasattr(forward_meta.seq_lens_this_time[i], "item")
                 else forward_meta.seq_lens_this_time[i]
             )
-            logger.debug(
-                f"[init_attn_meta] batch {i}: enc_len={enc_len}, dec_len={dec_len}, seq_this={seq_this_time}, cumsum_cached={cumsum_cached}, cumsum_total={cumsum_total}"
-            )
             if dec_len > 0:
                 cumsum_cached += dec_len
                 cumsum_total += dec_len
@@ -992,8 +892,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
                 cumsum_total += seq_this_time
             cu_seqlens_cached_kv[i + 1] = cumsum_cached
             cu_seqlens_k_with_cache[i + 1] = cumsum_total
-        logger.debug(f"[init_attn_meta] Final cu_seqlens_cached_kv: {cu_seqlens_cached_kv.tolist()}")
-        logger.debug(f"[init_attn_meta] Final cu_seqlens_k_with_cache: {cu_seqlens_k_with_cache.tolist()}")
         # Consistency checks: starts at 0, monotonic non-decreasing, final equals cumulative.
         assert cu_seqlens_cached_kv[0].item() == 0, "cu_seqlens_cached_kv must start at 0"
         assert cu_seqlens_k_with_cache[0].item() == 0, "cu_seqlens_k_with_cache must start at 0"
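
Note: the loop above derives both prefix-sum arrays from per-request lengths, where seq_lens_decoder holds the cached KV length and seq_lens_this_time the current chunk size. The diff skips original lines 990-991, so the exact branching around cumsum_total += seq_this_time is not visible; the toy reconstruction below assumes every request's chunk length is added to the total:

# Hypothetical batch: request 0 is a fresh prefill (no cache),
# request 1 is a chunked prefill with 16 cached tokens and an 8-token chunk.
seq_lens_decoder = [0, 16]    # cached KV per request
seq_lens_this_time = [32, 8]  # new tokens per request

cu_cached, cu_total = [0], [0]
cumsum_cached = cumsum_total = 0
for dec_len, seq_this_time in zip(seq_lens_decoder, seq_lens_this_time):
    if dec_len > 0:  # cached prefix counts toward both arrays
        cumsum_cached += dec_len
        cumsum_total += dec_len
    cumsum_total += seq_this_time  # assumed unconditional (see note above)
    cu_cached.append(cumsum_cached)
    cu_total.append(cumsum_total)

assert cu_cached == [0, 0, 16]  # only request 1 has cached KV
assert cu_total == [0, 32, 56]  # keys per request = cached + new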
@@ -1245,37 +1143,7 @@ def forward_mixed(
 
         # Prefill branch: k is not None
         if k is not None:
-            # Debug: Verify tensor shapes and sequence lengths
             bsz = forward_meta.cu_seqlens_q.shape[0] - 1
-            total_q_tokens = q.shape[0]
-            total_k_tokens = k.shape[0]
-
-            # Calculate expected cu_seqlens_k_with_cache
-            if metadata.has_prefix_cache and metadata.cu_seqlens_k_with_cache is not None:
-                expected_k_len = (
-                    metadata.cu_seqlens_k_with_cache[bsz].item()
-                    if hasattr(metadata.cu_seqlens_k_with_cache[bsz], "item")
-                    else metadata.cu_seqlens_k_with_cache[bsz]
-                )
-            else:
-                expected_k_len = (
-                    forward_meta.cu_seqlens_k[bsz].item()
-                    if hasattr(forward_meta.cu_seqlens_k[bsz], "item")
-                    else forward_meta.cu_seqlens_k[bsz]
-                )
-
-            # Debug output
-            logger.debug(
-                f"[forward_mixed] bsz={bsz}, total_q={total_q_tokens}, total_k={total_k_tokens}, expected_k={expected_k_len}"
-            )
-            logger.debug(f"[forward_mixed] has_prefix_cache={metadata.has_prefix_cache}")
-            logger.debug(
-                f"[forward_mixed] cu_seqlens_q={forward_meta.cu_seqlens_q.tolist() if hasattr(forward_meta.cu_seqlens_q, 'tolist') else forward_meta.cu_seqlens_q}"
-            )
-            if metadata.has_prefix_cache and metadata.cu_seqlens_k_with_cache is not None:
-                logger.debug(
-                    f"[forward_mixed] cu_seqlens_k_with_cache={metadata.cu_seqlens_k_with_cache.tolist() if hasattr(metadata.cu_seqlens_k_with_cache, 'tolist') else metadata.cu_seqlens_k_with_cache}"
-                )
 
             # Write cache only for new tokens of prefill/chunked-prefill batches.
             # Decode batches (seq_lens_encoder == 0) are intentionally skipped here — they
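
Note: the retained comment describes the write-side counterpart of the paged read above: only prefill/chunked-prefill requests append their new tokens to the cache, at absolute positions after any cached prefix. A sketch under the same toy layout as the read example (names and sizes are illustrative only):

import numpy as np

block_size, latent_dim = 4, 10
latent_cache = np.zeros((6, 1, block_size, latent_dim), dtype=np.float32)
block_tables = np.array([[5, 2]])

def write_new_token(batch_id, seq_pos, latent_vec):
    """Store one new token's latent at absolute position seq_pos (cached + new)."""
    block_idx, block_offset = divmod(seq_pos, block_size)
    physical_block_id = block_tables[batch_id, block_idx]
    latent_cache[physical_block_id, 0, block_offset, :] = latent_vec

# A chunked-prefill request with 3 cached tokens appends its 2-token chunk after them.
num_cached, chunk = 3, np.ones((2, latent_dim), dtype=np.float32)
for j in range(chunk.shape[0]):
    write_new_token(0, num_cached + j, chunk[j])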

fastdeploy/model_executor/models/deepseek_v3.py
Lines changed: 0 additions & 6 deletions
@@ -464,12 +464,6 @@ def forward(
             forward_meta=forward_meta,
         )
 
-        # Gated by MLA_CHUNK_DEBUG=1 via logger.debug (see mla_attention_backend.py).
-        logger.debug(
-            f"[deepseek_v3 forward] key.shape={key.shape}, value.shape={value.shape}, "
-            f"full_k_pe.shape={full_k_pe.shape}"
-        )
-
         fmha_out.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
         fmha_out = fmha_out[:, :, : self.v_head_dim]
         fmha_out.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
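
Note: the reshape/slice/reshape sequence trims each head's output from qk_head_dim back to v_head_dim: MLA attends at the wider query/key head size (nope + rope dims), but values are narrower, so the tail channels are dropped before the output projection. An equivalent NumPy sketch with hypothetical dimensions:

import numpy as np

num_heads, qk_head_dim, v_head_dim = 4, 192, 128  # hypothetical sizes
fmha_out = np.random.rand(10, num_heads * qk_head_dim).astype(np.float32)

out = fmha_out.reshape(-1, num_heads, qk_head_dim)  # split per head
out = out[:, :, :v_head_dim]                        # keep only the value-sized channels
out = out.reshape(-1, num_heads * v_head_dim)       # flatten for the o-projection
assert out.shape == (10, num_heads * v_head_dim)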
