
Commit c3226f7

benchislett authored and alhridoy committed
[Spec Decode] Enable efficient speculative decoding with FlashInfer-MLA (vllm-project#25984)
Signed-off-by: Benjamin Chislett <[email protected]>
1 parent aac2fb8 commit c3226f7

File tree: 3 files changed (+42 −5 lines)


vllm/v1/attention/backends/mla/common.py

Lines changed: 24 additions & 2 deletions
@@ -190,7 +190,7 @@
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from typing import Generic, Optional, TypeVar, Union
+from typing import ClassVar, Generic, Optional, TypeVar, Union
 
 import torch
 from tqdm import tqdm
@@ -454,6 +454,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     understand this class
     """
 
+    # Whether the backend supports reordering the batch such that
+    # short sequences (i.e. verification for speculative decoding) are
+    # classified as decode requests.
+    # If True, this will increase `reorder_batch_threshold` (below) when
+    # speculative decoding is enabled, and set `require_uniform=True`
+    # when reordering the batch. Non-uniform decode requests will
+    # fall back to prefill in this case.
+    supports_uniform_spec_as_decode: ClassVar[bool] = False
+
+    # The threshold for reordering the batch into decode and prefill requests.
+    # If > 1, the batch will be reordered such that requests with
+    # query length <= threshold are classified as decode requests.
+    # Use `supports_uniform_spec_as_decode` (above) to set this automatically
+    # when speculative decoding is enabled.
     reorder_batch_threshold: int = 1
 
     @staticmethod
@@ -503,6 +517,7 @@ def __init__(
         self.model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
         self.compilation_config = vllm_config.compilation_config
+        self.vllm_config = vllm_config
         self.device = device
 
         self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
@@ -578,6 +593,11 @@ def __init__(
             device=device,
         )
 
+        supports_spec_as_decode = self.supports_uniform_spec_as_decode
+        self._init_reorder_batch_threshold(
+            self.reorder_batch_threshold, supports_spec_as_decode
+        )
+
     def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
         qo_indptr = prefill.query_start_loc
 
@@ -714,7 +734,9 @@ def build(
 
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
+                common_attn_metadata,
+                decode_threshold=self.reorder_batch_threshold,
+                require_uniform=self.supports_uniform_spec_as_decode,
             )
         )
 
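As context for the change above: `split_decodes_and_prefills` is now asked to require uniform query lengths when the backend opts into spec-as-decode. The following is a minimal, self-contained sketch of that classification rule; the helper name and list-based interface are illustrative assumptions, not the actual vLLM implementation.

# Hypothetical sketch only: not vLLM's split_decodes_and_prefills.
def split_uniform_decodes(
    query_lens: list[int], decode_threshold: int, require_uniform: bool
) -> tuple[int, int]:
    """Return (num_decodes, num_prefills) for a reorder-sorted batch."""
    num_decodes = 0
    first_len = query_lens[0] if query_lens else 0
    for q_len in query_lens:
        if q_len > decode_threshold:
            break  # too long to treat as decode
        if require_uniform and q_len != first_len:
            break  # non-uniform decode requests fall back to prefill
        num_decodes += 1
    return num_decodes, len(query_lens) - num_decodes

# Example: three 3-token spec-decode verifications followed by one prefill.
assert split_uniform_decodes([3, 3, 3, 17], decode_threshold=4,
                             require_uniform=True) == (3, 1)

In this sketch, the three uniform verification requests stay on the decode path while the long request (and anything that breaks uniformity) falls back to prefill, matching the comment added to `MLACommonMetadataBuilder`.
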
vllm/v1/attention/backends/mla/flashinfer_mla.py

Lines changed: 15 additions & 1 deletion
@@ -22,6 +22,9 @@
 
 
 class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    # enable spec-as-decode optimization
+    supports_uniform_spec_as_decode: ClassVar[bool] = True
+
     # enable full CUDA Graph support for decode-only capture
     cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
 
@@ -111,7 +114,15 @@ def _forward_decode(
         q = torch.cat([q_nope, q_pe], dim=-1)
 
         # trtllm API requires extra dimension q_len_per_request for MTP
-        q = q.unsqueeze(1)
+        if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
+            logger.warning_once(
+                """FlashInferMLAImpl got a query of uneven length.
+                This usually indicates an issue in batch reordering
+                or incorrect setup in dummy_run."""
+            )
+            q = q.unsqueeze(1)
+        else:
+            q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])
 
         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
@@ -132,6 +143,9 @@ def _forward_decode(
             bmm2_scale=self.bmm2_scale,
         )
 
+        # Flatten the output for consistent shape
+        o = o.view(-1, o.shape[-2], o.shape[-1])
+
         # TODO: Return LSE pending support from Flashinfer API:
         # https://github.com/flashinfer-ai/flashinfer/pull/1566
         return o, None
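
Shape-wise, the `_forward_decode` change replaces the single-token `unsqueeze(1)` with a view that exposes a per-request query length, and the output is flattened back afterwards. A small hedged sketch of just those tensor reshapes (with made-up sizes) follows:

import torch

# Assumed illustrative sizes; only the reshape pattern mirrors the diff.
num_decodes, q_len, num_heads, head_dim = 4, 3, 16, 576
q = torch.randn(num_decodes * q_len, num_heads, head_dim)  # flattened decode tokens

# Uniform case: expose q_len_per_request as the extra dimension the
# trtllm-style decode API expects.
q_mtp = q.view(num_decodes, -1, q.shape[-2], q.shape[-1])
assert q_mtp.shape == (num_decodes, q_len, num_heads, head_dim)

# After attention, flatten back so downstream code sees one row per token.
o = q_mtp.view(-1, q_mtp.shape[-2], q_mtp.shape[-1])
assert o.shape == (num_decodes * q_len, num_heads, head_dim)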

vllm/v1/attention/backends/utils.py

Lines changed: 3 additions & 2 deletions
@@ -275,8 +275,9 @@ def _init_reorder_batch_threshold(
             speculative_config is not None
             and speculative_config.num_speculative_tokens is not None
         ):
-            self.reorder_batch_threshold = (
-                1 + speculative_config.num_speculative_tokens
+            self.reorder_batch_threshold = max(
+                self.reorder_batch_threshold,
+                1 + speculative_config.num_speculative_tokens,
             )
 
     @abstractmethod
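
This utils.py change makes the threshold bump non-destructive: instead of overwriting `reorder_batch_threshold` with `1 + num_speculative_tokens`, it takes the max with the backend's existing value. A simplified, standalone restatement of the rule (omitting the speculative_config plumbing; not the vLLM method itself):

# Simplified restatement of the new rule; not the vLLM method itself.
def updated_threshold(current: int, num_speculative_tokens: int | None) -> int:
    if num_speculative_tokens is not None:
        # max() keeps a backend's larger pre-existing threshold intact.
        return max(current, 1 + num_speculative_tokens)
    return current

assert updated_threshold(1, 3) == 4    # spec decoding raises the default of 1
assert updated_threshold(16, 3) == 16  # a larger backend threshold is preserved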
