Triton mla#7804
Pull request overview
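This PR ports SGLang's Triton-based MLA (Multi-head Latent Attention) decode attention to FastDeploy as a new attention backend, `TRITON_MLA_ATTN`, and wires up the corresponding backend selection, KV cache write, and decode kernel logic.
Changes:
- Adds `TritonMLAAttentionBackend`: extend reuses FlashAttention, decode runs a two-stage Triton split-KV kernel, and CUDA Graph related metadata and buffers are pre-allocated.
- Adds Triton kernels: paged KV cache write (`mla_cache_kernel.py`) and decode attention (`decode_attention.py`), both exported from `triton_ops/__init__.py`.
- Platform, config, and runtime adaptation: adds `_Backend.TRITON_MLA_ATTN`, CUDA platform routing, `use_mla_cache` recognition, the GPU runner's MLA cache check, and an empty-batch guard for DeepSeek-V3.
Points to note (not line-level code comments):
- The PR title is currently "Triton mla", which does not follow the repository's `[CLASS]Title` convention; consider, for example, `[Feature] Add Triton MLA attention backend` (or adjust to the actual category).
- The PR introduces a new backend and a new environment variable value (`FD_ATTENTION_BACKEND=TRITON_MLA_ATTN`); the related usage documentation (such as the environment variable reference) should be checked and updated so users do not omit or misconfigure it.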
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| fastdeploy/worker/gpu_model_runner.py | Recognizes TRITON_MLA_ATTN as an MLA cache path and extends the position_ids/slot_mapping computation to the new backend |
| fastdeploy/platforms/cuda.py | Adds TRITON_MLA_ATTN routing on the CUDA platform and updates the log and error messages |
| fastdeploy/platforms/base.py | Adds the _Backend.TRITON_MLA_ATTN enum value |
| fastdeploy/config.py | CacheConfig.use_mla_cache recognizes TRITON_MLA_ATTN |
| fastdeploy/model_executor/layers/attention/\_\_init\_\_.py | Registers and exports TritonMLAAttentionBackend |
| fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py | Core implementation of the Triton MLA backend (extend / decode / mixed, plus metadata and buffer pre-allocation) |
| fastdeploy/model_executor/layers/attention/triton_ops/\_\_init\_\_.py | Exports the core functions under triton_ops |
| fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py | Adds the split-KV decode attention Triton kernel (paged KV addressing) |
| fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py | Adds the Triton KV cache write kernel (writes into the paged latent cache) |
| fastdeploy/model_executor/models/deepseek_v3.py | Guards attn_out on an empty batch so that None is not passed to o_proj |
| custom_ops/gpu_ops/helper.h | C++ side checkAttentionBackend() recognizes TRITON_MLA_ATTN |
| tl.store( | ||
| O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, | ||
| acc / e_sum, |
| raise ValueError( | ||
| "Invalid attention backend you specified.\n" | ||
| "Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN] in cuda place." | ||
| "Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN, TRITON_MLA_ATTN] in cuda place." | ||
| ) |
| elif selected_backend == _Backend.TRITON_MLA_ATTN: | ||
| logger.info("Using TRITON MLA ATTN backend.") | ||
| return "fastdeploy.model_executor.layers.attention.TritonMLAAttentionBackend" |
CI report generated from the following code (updated every 30 minutes):
1 Task overview
❌ 2 Required tasks failed; they must be resolved before the PR can be merged.
2 Task status summary
2.1 Required tasks: 8/10 passed
2.2 Optional tasks: 29/32 passed
3 Failure details (Required only)
Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage: coverage below the threshold (confidence: medium)
Fix summary: add unit tests for the new triton_mla files, or request an exemption.
Approval: approval check failed (confidence: high)
Fix summary: ask xyxinyang or zyyzghb to approve the PR.
| attn_out = paddle.zeros( | ||
| [hidden_states.shape[0], self.num_attention_heads_tp * self.v_head_dim], | ||
| dtype=hidden_states.dtype, | ||
| ) |
If this is None, simply raising an error would be fine; is this extra logic really necessary?
| @@ -0,0 +1,368 @@ | |||
| """ | |||
| @@ -0,0 +1,499 @@ | |||
| """ | |||
| @@ -0,0 +1,147 @@ | |||
| """ | |||
| self.causal: bool = getattr(fd_config.model_config, "causal", True) | ||
|
|
||
| self.num_heads: int = num_heads | ||
| self.head_dim: int = fd_config.model_config.head_dim |
| self.max_kv_splits: int = 32 | ||
|
|
||
| self.rank, self.device_id = init_rank_and_device_id(fd_config) | ||
| self.useless_tensor = paddle.randn([1]).cast("int32") |
| seq_lens_decoder = forward_meta.seq_lens_decoder | ||
| seq_lens_this_time = forward_meta.seq_lens_this_time | ||
| decode_mask = seq_lens_decoder > 0 | ||
| decode_bs = int(decode_mask.sum().item()) | ||
| metadata.decode_bs = decode_bs | ||
|
|
||
| if decode_bs > 0: | ||
| decode_seq_lens = (seq_lens_decoder + seq_lens_this_time)[decode_mask] | ||
| decode_block_tables = forward_meta.block_tables[decode_mask] | ||
| total_kv_len = int(paddle.sum(decode_seq_lens).item()) | ||
|
|
| total_tokens = q.shape[0] | ||
| Lv = self.kv_lora_rank | ||
|
|
||
| # Decode tokens are at positions cu_seqlens_q[i] for sequences with seq_lens_decoder > 0 | ||
| cu_seqlens = forward_meta.cu_seqlens_q | ||
| seq_lens_decoder = forward_meta.seq_lens_decoder | ||
| decode_mask = seq_lens_decoder > 0 | ||
| max_num_seqs = seq_lens_decoder.shape[0] | ||
| seq_indices = paddle.arange(max_num_seqs, dtype="int32") | ||
| decode_seq_indices = seq_indices[decode_mask] | ||
| decode_token_positions = cu_seqlens[decode_seq_indices] | ||
|
|
||
| q_decode = q[decode_token_positions] | ||
| decode_out = self._run_decode_kernel(q_decode, latent_cache, metadata) | ||
|
|
||
| output = paddle.zeros([total_tokens, self.num_heads * Lv], dtype=q.dtype) | ||
| output[decode_token_positions] = decode_out | ||
| return output |
| bs = q.shape[0] | ||
| Lv = self.kv_lora_rank | ||
| latent_dim = self.kv_lora_rank + self.qk_rope_head_dim | ||
| q_reshaped = q.reshape([bs, self.num_heads, latent_dim]) | ||
|
|
||
| attn_logits = paddle.empty([bs, self.num_heads, self.max_kv_splits, Lv], dtype="float32") | ||
| attn_lse = paddle.empty([bs, self.num_heads, self.max_kv_splits], dtype="float32") | ||
| o = paddle.empty([bs, self.num_heads, Lv], dtype=q.dtype) |
Pull request overview
Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.
Comments suppressed due to low confidence (4)
fastdeploy/model_executor/models/deepseek_v3.py:1065
- The early exit returns a single `hidden_states`, but the caller (`hidden_states, residual = self.layers[i](...)` in DeepSeekV3Model.forward) unpacks a 2-tuple. When both `need_do_prefill` and `need_do_decode` are False, this raises a ValueError. It should return `(hidden_states, residual)`.
if not need_do_prefill and not need_do_decode:
return hidden_states
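The return-shape mismatch described above can be illustrated with a minimal sketch. The function below is a hypothetical stand-in for the decoder layer (plain lists replace tensors, and `decoder_layer_forward` is not a FastDeploy name); it only shows that every exit path hands back the same 2-tuple the caller unpacks.

```python
def decoder_layer_forward(hidden_states, residual, need_do_prefill, need_do_decode):
    """Hypothetical stand-in for a decoder layer; lists stand in for tensors."""
    if not need_do_prefill and not need_do_decode:
        # Early exit: still return a 2-tuple so the caller's
        # `hidden_states, residual = layer(...)` unpacking keeps working.
        return hidden_states, residual
    # Placeholder for the real attention + MLP work (assumption, not the PR's code).
    new_hidden = [h + 1.0 for h in hidden_states]
    return new_hidden, hidden_states


# Empty-batch step: nothing to prefill or decode, inputs pass through unchanged.
hidden, res = decoder_layer_forward([0.0, 1.0], [0.5, 0.5], False, False)
```

Keeping one return shape on all paths avoids a failure that only shows up on the rare step where the batch has no work.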
fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:335
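In the `k is None` branch, `forward_mixed` only runs the decode kernel for a mixed extend+decode batch; the output at extend-token positions is always filled with 0 (lines 333-334), which effectively discards the attention output of the extend tokens. If this branch really has to support mixed batches, extend must be executed as well; if the convention is that extend never appears when this backend is called in mixed mode, it should at least assert `decode_bs == batch_size`, otherwise it will fail silently.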
# Mixed batch (no CUDAGraph): q has all tokens (extend + decode).
# Extract decode tokens (1 per decode sequence), run kernel, scatter back.
decode_bs = metadata.decode_bs
if decode_bs == 0:
Lv = self.kv_lora_rank
return paddle.zeros([q.shape[0], self.num_heads * Lv], dtype=q.dtype)
total_tokens = q.shape[0]
Lv = self.kv_lora_rank
# Decode tokens are at positions cu_seqlens_q[i] for sequences with seq_lens_decoder > 0
cu_seqlens = forward_meta.cu_seqlens_q
seq_lens_decoder = forward_meta.seq_lens_decoder
decode_mask = seq_lens_decoder > 0
max_num_seqs = seq_lens_decoder.shape[0]
seq_indices = paddle.arange(max_num_seqs, dtype="int32")
decode_seq_indices = seq_indices[decode_mask]
decode_token_positions = cu_seqlens[decode_seq_indices]
q_decode = q[decode_token_positions]
decode_out = self._run_decode_kernel(q_decode, latent_cache, metadata)
output = paddle.zeros([total_tokens, self.num_heads * Lv], dtype=q.dtype)
output[decode_token_positions] = decode_out
return output
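A small sketch of the scatter-back step makes the concern concrete. It uses plain lists instead of Paddle tensors, and both helper names are hypothetical; the guard mirrors the `decode_bs == batch_size` assertion suggested in the comment.

```python
def scatter_decode_only(total_tokens, decode_positions, decode_out, width):
    """Mimic the scatter-back: rows not listed in decode_positions stay zero."""
    output = [[0.0] * width for _ in range(total_tokens)]
    for pos, row in zip(decode_positions, decode_out):
        output[pos] = list(row)
    return output


def require_decode_only(decode_bs, batch_size):
    """Hypothetical guard: fail loudly instead of returning zeros for extend tokens."""
    if decode_bs != batch_size:
        raise AssertionError(
            f"mixed batch not supported here: decode_bs={decode_bs}, batch_size={batch_size}"
        )


# Sequence 0 is an extend of 3 tokens (positions 0..2), sequence 1 decodes at position 3.
out = scatter_decode_only(4, [3], [[9.0, 9.0]], width=2)
```

In this example the three extend rows come back as zeros, which is exactly the silent data loss the guard is meant to surface.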
fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:351
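`_run_decode_kernel` allocates `attn_logits`, `attn_lse`, and `o` afresh via `paddle.empty` on every call. CUDA Graph requires capture and replay to use the same memory addresses, so per-call allocation undermines the stability of graph capture (and contradicts the intent of lines 129-134 of this file, which deliberately pre-allocate buffers such as `_kv_indptr_buf`). Consider pre-allocating these intermediate buffers for `max_num_seqs` as well and slicing them by the actual bs at use.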
attn_logits = paddle.empty([bs, self.num_heads, self.max_kv_splits, Lv], dtype="float32")
attn_lse = paddle.empty([bs, self.num_heads, self.max_kv_splits], dtype="float32")
o = paddle.empty([bs, self.num_heads, Lv], dtype=q.dtype)
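The pre-allocate-then-slice pattern suggested here can be sketched with the standard library alone. This is an analogy, not Paddle code: `array` plus `memoryview` stand in for a device tensor and a view of it, and the `DecodeScratch` class is a hypothetical name.

```python
from array import array


class DecodeScratch:
    """Hypothetical scratch pool: allocate once for max_num_seqs, hand out views."""

    def __init__(self, max_num_seqs, num_heads, max_kv_splits):
        self.max_num_seqs = max_num_seqs
        self.num_heads = num_heads
        self.max_kv_splits = max_kv_splits
        # One allocation for the worst case; it is never reallocated afterwards.
        self._lse = array("f", [0.0]) * (max_num_seqs * num_heads * max_kv_splits)

    def lse_view(self, bs):
        """Return a view over the first `bs` sequences of the same storage."""
        if bs > self.max_num_seqs:
            raise ValueError("batch size exceeds the pre-allocated capacity")
        n = bs * self.num_heads * self.max_kv_splits
        return memoryview(self._lse)[:n]


scratch = DecodeScratch(max_num_seqs=4, num_heads=2, max_kv_splits=3)
view_a = scratch.lse_view(2)
view_b = scratch.lse_view(3)
```

Both views are backed by the same storage, so a write through one is visible through the other, which is the address-stability property graph replay depends on.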
fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:262
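`flash_attention_v3_varlen` and `flash_attn_unpadded` use different keyword argument names (the former typically takes `softmax_scale`, the latter `scale`), and their return structures differ (v3 returns an `(out, softmax_lse)`-style result; `flash_attn_unpadded` also returns a tuple, but its elements mean different things). Taking the first element via `[0]` on both paths needs confirmation that it is the output tensor in each case. Also consider adding a fallback for the case where `flash_attention_v3_varlen is None` but SM>=90 (the current try/except only covers import time; if the import succeeds but the call fails at runtime, it will raise). Please confirm that the two API calls are indeed compatible.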
fmha_out = self.flash_attn_func(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
causal=self.causal,
**self.flash_attn_kwargs,
)[0]
| # Pre-compute decode kv_indptr/kv_indices into stable pre-allocated buffers. | ||
| # CUDAGraph requires tensors at the same memory address between capture and replay. | ||
| seq_lens_decoder = forward_meta.seq_lens_decoder | ||
| seq_lens_this_time = forward_meta.seq_lens_this_time | ||
| decode_mask = seq_lens_decoder > 0 | ||
| decode_bs = int(decode_mask.sum().item()) | ||
| metadata.decode_bs = decode_bs | ||
|
|
||
| if decode_bs > 0: | ||
| decode_seq_lens = (seq_lens_decoder + seq_lens_this_time)[decode_mask] | ||
| decode_block_tables = forward_meta.block_tables[decode_mask] | ||
| total_kv_len = int(paddle.sum(decode_seq_lens).item()) |
| def tanh(x): | ||
| return 2 * tl.sigmoid(2 * x) - 1 | ||
|
|
||
|
|
||
| @enable_compat_on_triton_kernel | ||
| @triton.jit |
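The `tanh` helper in this snippet relies on the identity tanh(x) = 2·sigmoid(2x) − 1, which follows from 2/(1 + e^(−2x)) − 1 = (1 − e^(−2x))/(1 + e^(−2x)). A quick host-side check in plain Python (not Triton) is sketched below; the function names are illustrative only.

```python
import math


def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))


def tanh_via_sigmoid(x):
    """Same expression as the kernel helper: 2 * sigmoid(2x) - 1."""
    return 2.0 * sigmoid(2.0 * x) - 1.0


# Largest deviation from math.tanh over a few sample points.
max_err = max(abs(tanh_via_sigmoid(x) - math.tanh(x)) for x in (-3.0, -0.5, 0.0, 0.5, 3.0))
```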
| need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0 | ||
| need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0 | ||
|
|
||
| if not need_do_prefill and not need_do_decode: | ||
| return hidden_states |
Pull request overview
Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.
Comments suppressed due to low confidence (1)
fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py:335
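The CUDA Graph branch condition `q.shape[0] == metadata.decode_bs` in `forward_mixed` depends on `decode_bs`, which is synchronized out via `.item()` in `init_attention_metadata`. Together these mean CUDA Graph cannot actually be used with this backend: (a) there is a device→host sync during capture, and (b) the `decode_bs` value of each step influences branching through a Python int, which breaks graph consistency. Consider explicitly taking a single path based on the capture batch size (for example, always `_run_decode_kernel`) and handling invalid batch entries inside the kernel with a mask over pre-allocated padded buffers.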
# CUDAGraph path: q contains exactly the captured batch size of decode tokens.
# Must always take this path during CUDAGraph capture/replay to keep the
# execution trace identical (same kernel launches, same tensor shapes).
if forward_meta.step_use_cudagraph or q.shape[0] == metadata.decode_bs:
return self._run_decode_kernel(q, latent_cache, metadata)
# Mixed batch (no CUDAGraph): q has all tokens (extend + decode).
# Extract decode tokens (1 per decode sequence), run kernel, scatter back.
decode_bs = metadata.decode_bs
if decode_bs == 0:
Lv = self.kv_lora_rank
return paddle.zeros([q.shape[0], self.num_heads * Lv], dtype=q.dtype)
total_tokens = q.shape[0]
Lv = self.kv_lora_rank
# Decode tokens are at positions cu_seqlens_q[i] for sequences with seq_lens_decoder > 0
cu_seqlens = forward_meta.cu_seqlens_q
seq_lens_decoder = forward_meta.seq_lens_decoder
decode_mask = seq_lens_decoder > 0
max_num_seqs = seq_lens_decoder.shape[0]
seq_indices = paddle.arange(max_num_seqs, dtype="int32")
decode_seq_indices = seq_indices[decode_mask]
decode_token_positions = cu_seqlens[decode_seq_indices]
q_decode = q[decode_token_positions]
decode_out = self._run_decode_kernel(q_decode, latent_cache, metadata)
output = paddle.zeros([total_tokens, self.num_heads * Lv], dtype=q.dtype)
output[decode_token_positions] = decode_out
return output
| attn_logits = paddle.empty([bs, self.num_heads, self.max_kv_splits, Lv], dtype="float32") | ||
| attn_lse = paddle.empty([bs, self.num_heads, self.max_kv_splits], dtype="float32") | ||
| o = paddle.empty([bs, self.num_heads, Lv], dtype=q.dtype) |
| decode_bs = int(decode_mask.sum().item()) | ||
| metadata.decode_bs = decode_bs | ||
|
|
||
| if decode_bs > 0: | ||
| decode_seq_lens = (seq_lens_decoder + seq_lens_this_time)[decode_mask] | ||
| decode_block_tables = forward_meta.block_tables[decode_mask] | ||
| total_kv_len = int(paddle.sum(decode_seq_lens).item()) | ||
|
|
||
| build_kv_indices_from_block_tables( | ||
| decode_block_tables, decode_seq_lens, self.block_size, decode_bs, | ||
| total_kv_len=total_kv_len, | ||
| kv_indptr_buf=self._kv_indptr_buf, | ||
| kv_indices_buf=self._kv_indices_buf, | ||
| ) |
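To make the role of `kv_indptr` and `kv_indices` concrete, here is a pure-Python reference sketch. It assumes that `kv_indices` holds flat slot ids of the form `block_id * block_size + offset`, consistent with the `kv_loc // block_size` and `kv_loc % block_size` decoding mentioned in the PR description; the function names are hypothetical and this is not the PR's `build_kv_indices_from_block_tables`.

```python
def build_kv_indices_reference(block_tables, seq_lens, block_size):
    """Flatten per-sequence block tables into CSR-style (indptr, indices)."""
    kv_indptr = [0]
    kv_indices = []
    for table, seq_len in zip(block_tables, seq_lens):
        for t in range(seq_len):
            block_id = table[t // block_size]  # which physical block holds token t
            kv_indices.append(block_id * block_size + t % block_size)
        kv_indptr.append(len(kv_indices))  # end offset of this sequence
    return kv_indptr, kv_indices


def decode_kv_loc(kv_loc, block_size):
    """Recover (block_id, offset_in_block) from a flat slot id."""
    return kv_loc // block_size, kv_loc % block_size


# Two sequences: 5 tokens spread over blocks [5, 2], and 2 tokens in block 7.
indptr, indices = build_kv_indices_reference([[5, 2], [7, 0]], [5, 2], block_size=4)
```

Sequence `i` then owns the slice `indices[indptr[i]:indptr[i + 1]]`, so a kernel can walk its KV entries without knowing the block table layout.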
| self.max_kv_splits: int = 32 | ||
|
|
||
| self.rank, self.device_id = init_rank_and_device_id(fd_config) | ||
| self.useless_tensor = paddle.randn([1]).cast("int32") |
| fmha_out = self.flash_attn_func( | ||
| q, | ||
| k, | ||
| v, | ||
| forward_meta.cu_seqlens_q, | ||
| forward_meta.cu_seqlens_k, | ||
| metadata.max_enc_len_this_time, | ||
| metadata.max_enc_len_this_time, | ||
| causal=self.causal, | ||
| **self.flash_attn_kwargs, | ||
| )[0] |
| decode_attention_fwd( | ||
| q_reshaped, | ||
| latent_cache, | ||
| latent_cache[:, :, :, :self.kv_lora_rank], | ||
| o, | ||
| metadata.kv_indptr, | ||
| metadata.kv_indices, | ||
| attn_logits, | ||
| attn_lse, | ||
| metadata.num_kv_splits, | ||
| self.max_kv_splits, | ||
| self.attn_softmax_scale, | ||
| self.block_size, | ||
| ) |
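The two-stage split-KV scheme behind this call can be sketched for a single query and head in plain Python. This is a reference under simplifying assumptions (no paging, no batching, contiguous equal-sized splits), and the function names are illustrative. Stage 1 computes a normalized partial output plus a log-sum-exp per KV split, and stage 2 merges the partials with weights `exp(lse - max_lse)`.

```python
import math


def _attend(q, keys, values, scale):
    """Softmax attention over one KV chunk; returns (output, log-sum-exp)."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    e_sum = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / e_sum
           for d in range(len(values[0]))]
    return out, m + math.log(e_sum)


def split_kv_stage1(q, keys, values, scale, num_splits):
    """Stage 1: independent partial attention per KV split."""
    chunk = -(-len(keys) // num_splits)  # ceil division
    return [_attend(q, keys[lo:lo + chunk], values[lo:lo + chunk], scale)
            for lo in range(0, len(keys), chunk)]


def split_kv_stage2(partials):
    """Stage 2: merge partial outputs using their log-sum-exp values."""
    m = max(lse for _, lse in partials)
    acc = [0.0] * len(partials[0][0])
    e_sum = 0.0
    for out, lse in partials:
        w = math.exp(lse - m)
        e_sum += w
        acc = [a + w * o for a, o in zip(acc, out)]
    return [a / e_sum for a in acc]


q = [0.3, -1.2]
keys = [[0.1, 0.2], [1.0, -0.5], [-0.7, 0.4], [0.9, 0.9], [0.0, -1.1]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.5, -0.5]]
merged = split_kv_stage2(split_kv_stage1(q, keys, values, 0.5, num_splits=2))
direct, _ = _attend(q, keys, values, 0.5)
```

The merged result agrees with single-pass attention up to floating-point error, which is why the split count can be tuned for parallelism without changing the output.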
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## develop #7804 +/- ##
==========================================
Coverage ? 62.76%
==========================================
Files ? 466
Lines ? 64651
Branches ? 9884
==========================================
Hits ? 40581
Misses ? 21309
Partials ? 2761
|
|
|
||
| self.prefix = prefix | ||
|
|
||
| prop = paddle.device.cuda.get_device_properties() |
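CI report generated from the following code (updated every 30 minutes):
1 Task overview
1 Required task failed (Approval was not granted); it must be resolved before the PR can be merged.
2 Task status summary
2.1 Required tasks: 2/3 passed
2.2 Optional tasks: 19/23 passed
3 Failure details (Required only)
Approval: code standards (confidence: high)
Fix summary: contact xyxinyang or zyyzghb to approve the review on the PR page.
Related changes: the PR adds Triton MLA related …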
PaddlePaddle-bot
left a comment
🤖 Paddle-CI-Agent | pr_review |
2026-05-18 15:10:30
📋 Review summary
PR overview: Ports SGLang's Triton MLA decode attention to FastDeploy, adding the TRITON_MLA_ATTN attention backend, adapted to the paged KV cache and supporting CUDAGraph.
Scope of changes: layers/attention/, triton_ops/, model_executor/models/deepseek_v3.py, config.py, worker/gpu_model_runner.py, platforms/, custom_ops/gpu_ops/helper.h
Impact tags: [OP] [Models] [FDConfig]
Issues
| Severity | File | Summary |
|---|---|---|
| 🟡 Suggestion | triton_mla_attention_backend.py:157 | init_attention_metadata triggers two D→H syncs (.item()) per step, which affects inference latency |
| ❓ Question | triton_mla_attention_backend.py:216 | forward_extend always uses max_enc_len_this_time for max_seqlen_k, so chunked prefill appears to be unsupported |
| 📝 PR conventions | - | The title lacks a [Tag]; the description lacks the ## Accuracy Tests and ## Checklist sections |
📝 PR convention check
Title issue: "Triton mla" lacks an official tag; suggested replacement:
[Feature] Add Triton MLA attention backend (TRITON_MLA_ATTN)
Suggested PR description (can be copied directly; ## Accuracy Tests and ## Checklist are missing, and the ## Usage or Command code block is not closed):
## Motivation
Port the Triton-based MLA (Multi-head Latent Attention) decode attention from SGLang to FastDeploy as a new attention backend (`TRITON_MLA_ATTN`).
## Modifications
### New files
| File | Description |
|------|------|
| `fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py` | Core implementation of the Triton MLA backend, including the extend/decode/mixed forward paths and CUDA Graph compatible metadata initialization |
| `fastdeploy/model_executor/layers/attention/triton_ops/__init__.py` | Exports for the triton_ops package |
| `fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py` | Two-stage split-KV decode attention Triton kernel (ported from SGLang and adapted to FastDeploy's paged KV cache addressing) |
| `fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py` | Triton KV cache write kernel that writes `[compressed_kv \|\| k_pe]` into the paged cache |
### Modified files
| File | Description |
|------|------|
| `fastdeploy/platforms/base.py` | Adds the `_Backend.TRITON_MLA_ATTN` enum value |
| `fastdeploy/platforms/cuda.py` | Adds backend routing to `TritonMLAAttentionBackend` |
| `fastdeploy/config.py` | `CacheConfig.use_mla_cache` recognizes `TRITON_MLA_ATTN` |
| `fastdeploy/worker/gpu_model_runner.py` | MLA cache path recognition, and `_apply_position_ids_if_needed` supports the new backend |
| `fastdeploy/model_executor/layers/attention/__init__.py` | Registers and exports `TritonMLAAttentionBackend` |
| `fastdeploy/model_executor/models/deepseek_v3.py` | Null guard for `attn_out`, so that None is not passed to `o_proj` on an empty batch |
| `custom_ops/gpu_ops/helper.h` | C++ layer `checkAttentionBackend()` recognizes the new backend |
## Usage or Command
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
export FD_ATTENTION_BACKEND="TRITON_MLA_ATTN"
export FLAGS_flash_attn_version=3
export FD_SAMPLING_CLASS=rejection
python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/GLM-4.7-Flash \
--port 8380 \
--tensor-parallel-size 4 \
--max-model-len 32768 \
--max-num-seqs 32
```
## Accuracy Tests
N/A (Triton kernel correctness has been verified against a numpy reference implementation via `tests/deterministic/test_triton_decode_attention.py` and `tests/deterministic/test_triton_mla_cache_kernel.py`; end-to-end accuracy comparison data is still to be added)
## Checklist
- [x] Add at least a tag in the PR title.
- Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
- You can add new tags based on the PR content, but the semantics must be clear.
- [x] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
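- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.
Overall assessment
The overall implementation quality is high: the CUDAGraph address-stability design is sound (pre-allocated buffers plus a fixed grid dim), the paged KV cache adaptation is correct, and the Triton kernel's stage1/stage2 split-KV logic and MLA semantics (kv_heads=1, equivalent to MQA) are correct. The main points to watch are the per-step D→H syncs in init_attention_metadata and their effect on latency, and compatibility with chunked prefill.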
| self.flash_attn_kwargs = {"scale": self.attn_softmax_scale, "training": False} | ||
|
|
||
| def init_attention_metadata(self, forward_meta: ForwardMeta): | ||
| metadata = TritonMLAAttentionMetadata() |
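🟡 Suggestion: init_attention_metadata performs two Device→Host synchronizations
decode_bs = int(decode_mask.sum().item()) # D→H sync #1
# ...
total_kv_len = int(paddle.sum(decode_seq_lens).item()) # D→H sync #2
init_attention_metadata is called at every inference step, and these two .item() calls force a GPU→CPU synchronization that stalls the CUDA pipeline and adds latency to each step (especially noticeable in high-concurrency decode scenarios).
Suggested fix: use CPU-side information already available in forward_meta (such as a CPU copy of seq_lens_decoder or batch statistics passed in by the scheduler) instead of a GPU sum plus .item(), or compute these two values directly on the CPU (the scheduler already knows decode_bs and each sequence length).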
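A host-side computation along these lines could look like the sketch below. It assumes CPU copies of the two length arrays are available as plain Python sequences; the function and argument names are hypothetical, and whether `forward_meta` actually exposes such copies is not confirmed by this page.

```python
def decode_stats_on_host(seq_lens_decoder_cpu, seq_lens_this_time_cpu):
    """Compute (decode_bs, total_kv_len) from host-side lengths, no device sync."""
    decode_bs = 0
    total_kv_len = 0
    for dec_len, cur_len in zip(seq_lens_decoder_cpu, seq_lens_this_time_cpu):
        if dec_len > 0:  # same predicate as decode_mask = seq_lens_decoder > 0
            decode_bs += 1
            total_kv_len += dec_len + cur_len
    return decode_bs, total_kv_len


# Four slots: one prefill (0 cached tokens), two decodes, one empty slot.
stats = decode_stats_on_host([0, 7, 12, 0], [5, 1, 1, 0])
```

Because both values come from data that never left the host, the metadata step no longer has to wait for the GPU stream to drain.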
|
|
||
| # Compute num_kv_splits into the pre-allocated buffer | ||
| compute_num_kv_splits(decode_seq_lens, decode_bs, self.max_kv_splits, out_buf=self._num_kv_splits_buf) | ||
| # Padded entries must be >= 1 to avoid division by zero in kernel |
❓ Question: forward_extend always uses max_enc_len_this_time for max_seqlen_k
fmha_out = self.flash_attn_func(
q, k, v,
forward_meta.cu_seqlens_q,
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time, # max_seqlen_q ✓
metadata.max_enc_len_this_time, # max_seqlen_k ← questionable
causal=self.causal,
**self.flash_attn_kwargs,
)[0]
metadata.max_kv_len_this_time is already assigned in init_attention_metadata but is not used here. In the non-chunked prefill case, Q and K have the same length, so using max_enc_len_this_time in both places is correct; but if chunked prefill is enabled in the future, the K sequence will be longer than the Q length of the current chunk, and max_kv_len_this_time should then be used as max_seqlen_k.
Please confirm whether this backend intentionally does not support chunked prefill (if so, consider adding a comment in the code or an assert not enable_chunked_prefill guard in init_attention_metadata), or switch to metadata.max_kv_len_this_time to allow for later extension.
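The difference between the two maxima is easy to see from the cumulative length arrays themselves. The sketch below derives each maximum from a `cu_seqlens`-style prefix-sum list; the helper name and the sample numbers are illustrative, not taken from the PR.

```python
def max_seqlen(cu_seqlens):
    """Largest per-sequence length implied by a cumulative-length array."""
    return max((end - start for start, end in zip(cu_seqlens, cu_seqlens[1:])), default=0)


# Non-chunked prefill: Q and K cover the same tokens, so the maxima agree.
plain_q = max_seqlen([0, 4, 6])
plain_k = max_seqlen([0, 4, 6])

# Chunked prefill: the second chunk of sequence 0 has 4 new Q tokens,
# but its K side also includes 8 previously cached tokens (12 in total).
chunk_q = max_seqlen([0, 4, 6])
chunk_k = max_seqlen([0, 12, 14])
```

In the chunked case the K maximum exceeds the Q maximum, which is the situation where passing the Q-side value for `max_seqlen_k` would understate the key length.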
Motivation
Port the Triton-based MLA (Multi-head Latent Attention) decode attention from SGLang to FastDeploy as a new attention backend (`TRITON_MLA_ATTN`).
Modifications
New files
- `fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py`
- `fastdeploy/model_executor/layers/attention/triton_ops/__init__.py`
- `fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py`
- `fastdeploy/model_executor/layers/attention/triton_ops/mla_cache_kernel.py`
Modified files
| File | Description |
|------|------|
| `fastdeploy/platforms/base.py` | `_Backend.TRITON_MLA_ATTN` enum value |
| `fastdeploy/platforms/cuda.py` | `TritonMLAAttentionBackend` |
| `fastdeploy/config.py` | `CacheConfig.use_mla_cache` recognizes `TRITON_MLA_ATTN` |
| `fastdeploy/worker/gpu_model_runner.py` | `_apply_position_ids_if_needed` supports the new backend |
| `fastdeploy/model_executor/layers/attention/__init__.py` | `TritonMLAAttentionBackend` |
| `fastdeploy/model_executor/models/deepseek_v3.py` | `attn_out` null guard, so that None is not passed to `o_proj` on an empty batch |
| `custom_ops/gpu_ops/helper.h` | `checkAttentionBackend()` recognizes the new backend |
Key design
- … block addressing is decoded with `kv_loc // block_size` and `kv_loc % block_size`.
- … (`_kv_indptr_buf`, `_kv_indices_buf`, `_num_kv_splits_buf`); a Triton cumsum kernel replaces thrust (avoiding cudaMalloc), and padding keeps the kernel grid dim constant.
- … `flash_attention_v3_varlen`; SM80 uses `flash_attn_unpadded`, rather than reinventing the wheel.
Usage or Command