26 changes: 24 additions & 2 deletions vllm/v1/attention/backends/mla/common.py
@@ -190,7 +190,7 @@
import functools
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Generic, Optional, TypeVar, Union
from typing import ClassVar, Generic, Optional, TypeVar, Union

import torch
from tqdm import tqdm
@@ -454,6 +454,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
understand this class
"""

# Whether the backend supports reordering the batch such that
# short sequences (e.g. verification queries for speculative decoding) are
# classified as decode requests.
# If True, this will increase `reorder_batch_threshold` (below) when
# speculative decoding is enabled, and set `require_uniform=True`
# when reordering the batch. Non-uniform decode requests will
# fall back to prefill in this case.
supports_uniform_spec_as_decode: ClassVar[bool] = False

# The threshold for reordering the batch into decode and prefill requests.
# If > 1, the batch will be reordered such that requests with
# query length <= threshold are classified as decode requests.
# Use `supports_uniform_spec_as_decode` (above) to set this automatically
# when speculative decoding is enabled.
reorder_batch_threshold: int = 1

@staticmethod
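
A minimal sketch (hypothetical helper, not part of this diff) of what the threshold-based split does. The batch is assumed to be already reordered so that decode candidates come first; with require_uniform=True, only a leading run of equal query lengths is classified as decode, and everything else falls back to the prefill path.

def split_by_threshold(
    query_lens: list[int], threshold: int, require_uniform: bool
) -> tuple[int, int]:
    """Return (num_decodes, num_prefills) for an already-reordered batch."""
    num_decodes = 0
    first_len = query_lens[0] if query_lens else 0
    for q_len in query_lens:
        if q_len > threshold:
            break  # everything from here on is a prefill request
        if require_uniform and q_len != first_len:
            break  # non-uniform short request: falls back to prefill
        num_decodes += 1
    return num_decodes, len(query_lens) - num_decodes

For example, with threshold=4 and query_lens=[4, 4, 4, 1, 37], require_uniform=True yields 3 decodes: the length-1 request is short enough but not uniform, so it and the length-37 request go down the prefill path.
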
@@ -503,6 +517,7 @@ def __init__(
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.compilation_config = vllm_config.compilation_config
self.vllm_config = vllm_config
self.device = device

self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
@@ -578,6 +593,11 @@ def __init__(
device=device,
)

supports_spec_as_decode = self.supports_uniform_spec_as_decode
self._init_reorder_batch_threshold(
self.reorder_batch_threshold, supports_spec_as_decode
)

def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc

@@ -714,7 +734,9 @@ def build(

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
require_uniform=self.supports_uniform_spec_as_decode,
)
)

16 changes: 15 additions & 1 deletion vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -22,6 +22,9 @@


class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable spec-as-decode optimization
supports_uniform_spec_as_decode: ClassVar[bool] = True

# enable full CUDA Graph support for decode-only capture
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

@@ -111,7 +114,15 @@ def _forward_decode(
q = torch.cat([q_nope, q_pe], dim=-1)

# trtllm API requires extra dimension q_len_per_request for MTP
q = q.unsqueeze(1)
if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
logger.warning_once(
"""FlashInferMLAImpl got a query of uneven length.
This usually indicates an issue in batch reordering
or incorrect setup in dummy_run."""
)
q = q.unsqueeze(1)
else:
q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])

if self.bmm1_scale is None:
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
@@ -132,6 +143,9 @@ def _forward_decode(
bmm2_scale=self.bmm2_scale,
)

# Flatten the output for consistent shape
o = o.view(-1, o.shape[-2], o.shape[-1])

# TODO: Return LSE pending support from Flashinfer API:
# https://github.com/flashinfer-ai/flashinfer/pull/1566
return o, None
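
For illustration only, a toy reshape showing the shape handling above under assumed sizes (4 decode requests, 3 query tokens each, 16 heads, head dim 576); none of these numbers come from the diff.

import torch

num_decodes, q_len_per_req, H, D = 4, 3, 16, 576
q = torch.randn(num_decodes * q_len_per_req, H, D)  # token-major layout

# Uniform case: expose the per-request query length the trtllm API expects.
q_uniform = q.view(num_decodes, -1, q.shape[-2], q.shape[-1])  # [4, 3, 16, 576]

# Fallback for uneven lengths: one "request" per token.
q_fallback = q.unsqueeze(1)  # [12, 1, 16, 576]

# Either way, the attention output is flattened back to token-major layout.
o = q_uniform  # stand-in for the attention output with the same shape
o = o.view(-1, o.shape[-2], o.shape[-1])  # [12, 16, 576]
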
5 changes: 3 additions & 2 deletions vllm/v1/attention/backends/utils.py
@@ -275,8 +275,9 @@ def _init_reorder_batch_threshold(
speculative_config is not None
and speculative_config.num_speculative_tokens is not None
):
self.reorder_batch_threshold = (
1 + speculative_config.num_speculative_tokens
self.reorder_batch_threshold = max(
self.reorder_batch_threshold,
1 + speculative_config.num_speculative_tokens,
)

@abstractmethod
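
A one-line worked example (values assumed, not from the diff): with num_speculative_tokens = 3 and a backend default threshold of 1, the max() keeps whichever value is larger, so the threshold becomes max(1, 1 + 3) = 4 and verification batches of up to 4 query tokens are classified as decode; a backend that already requested a larger threshold keeps it.

num_speculative_tokens = 3
reorder_batch_threshold = max(1, 1 + num_speculative_tokens)
assert reorder_batch_threshold == 4
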