81 changes: 46 additions & 35 deletions vllm/v1/attention/backends/flashinfer.py
@@ -8,8 +8,7 @@

import numpy as np
import torch
-from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
-                        BatchPrefillWithPagedKVCacheWrapper,
+from flashinfer import (BatchPrefillWithPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper)
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
@@ -228,8 +227,10 @@ class FlashInferMetadata:
# For cascade attention (CPU for planning).
use_cascade: bool

+# Use the prefill wrapper for both prefill and decode,
+# as the backend dispatches everything to the prefill kernels (with different tile sizes)
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
-decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
+decode_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

qo_indptr_gpu: Optional[torch.Tensor] = None
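The new comment captures the core of this PR: a decode step is a prefill with query length 1 per request, so the prefill wrapper (and its kernels) can serve both paths. A minimal, dependency-free sketch of the `qo_indptr` layout this relies on; the batch size and lengths below are illustrative, not taken from the PR:

```python
import torch

# Decode: every request contributes exactly one query token,
# so qo_indptr is simply [0, 1, ..., batch_size].
batch_size = 4
decode_qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32)
print(decode_qo_indptr)    # tensor([0, 1, 2, 3, 4], dtype=torch.int32)

# Prefill: the same indptr format encodes variable query lengths per request,
# which is why the prefill kernels subsume the decode case.
prefill_lens = torch.tensor([7, 3, 12, 1], dtype=torch.int32)
prefill_qo_indptr = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(prefill_lens, dim=0, dtype=torch.int32),
])
print(prefill_qo_indptr)   # tensor([ 0,  7, 10, 22, 23], dtype=torch.int32)
```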
@@ -262,7 +263,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
self._decode_wrappers_cudagraph: dict[
-int, BatchDecodeWithPagedKVCacheWrapper] = {}
+int, BatchPrefillWithPagedKVCacheWrapper] = {}
self._decode_cudagraph_max_bs = min(
max_num_reqs, self.compilation_config.max_capture_size)

@@ -314,6 +315,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
dtype=torch.int32,
device=self.device)
+# this is a placeholder buffer for initializing
+# the prefill wrapper with CUDA graphs
+self.qo_indptr_buf = torch.zeros(max_num_reqs + 1,
+dtype=torch.int32,
+device=self.device)
# host-side buffer
pin_memory = is_pin_memory_available()
self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1,
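The new `qo_indptr_buf`, like the existing paged-KV buffers, is allocated once at the maximum batch size because CUDA graph replay needs the wrapper's inputs to stay at fixed addresses; each step only slices and refreshes them in place. A hedged sketch of that pattern, with made-up sizes and a CPU fallback so it runs anywhere:

```python
import torch

max_num_reqs = 256
device = "cuda" if torch.cuda.is_available() else "cpu"

# Persistent, max-sized buffer whose address is baked into the captured graph.
qo_indptr_buf = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)

# At replay time, only the leading slice for the current batch is refreshed in place.
batch_size = 8
qo_indptr_buf[:batch_size + 1].copy_(
    torch.arange(batch_size + 1, dtype=torch.int32), non_blocking=True)
```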
@@ -359,25 +365,26 @@ def _get_decode_wrapper(self,

if decode_wrapper is None:
if use_cudagraph:
+qo_indptr = self.qo_indptr_buf[:batch_size + 1]
paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
paged_kv_indices = self.paged_kv_indices
paged_kv_last_page_len = self.paged_kv_last_page_len[:
batch_size]
else:
+qo_indptr = None
paged_kv_indptr = None
paged_kv_indices = None
paged_kv_last_page_len = None
-decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-self._get_workspace_buffer(),
-get_kv_cache_layout(),

+decode_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+float_workspace_buffer=self._get_workspace_buffer(),
+kv_layout=get_kv_cache_layout(),
use_cuda_graph=use_cudagraph,
-paged_kv_indptr_buffer=paged_kv_indptr,
-paged_kv_indices_buffer=paged_kv_indices,
-paged_kv_last_page_len_buffer=paged_kv_last_page_len,
-# Tensor cores are enabled by default because the perf would be
-# at least as good as cuda cores for all attention ops in latest
-# gpus.
-use_tensor_cores=True,
+qo_indptr_buf=qo_indptr,
+paged_kv_indptr_buf=paged_kv_indptr,
+paged_kv_indices_buf=paged_kv_indices,
+paged_kv_last_page_len_buf=paged_kv_last_page_len,
+backend="fa2" # use FA2 for backend
)

# save the decode wrapper
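For reference, the new constructor call can be exercised standalone. The sketch below mirrors the keyword arguments shown in the diff, but the workspace and buffer sizes are made up, and it assumes a CUDA device with `flashinfer` installed; it is an illustration, not the PR's code.

```python
import torch
from flashinfer import BatchPrefillWithPagedKVCacheWrapper

device = "cuda"
max_num_reqs, max_num_pages = 256, 8192

wrapper = BatchPrefillWithPagedKVCacheWrapper(
    # Workspace size is illustrative scratch space for the planner/kernels.
    float_workspace_buffer=torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device),
    kv_layout="NHD",
    use_cuda_graph=True,
    qo_indptr_buf=torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device),
    paged_kv_indptr_buf=torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device),
    paged_kv_indices_buf=torch.zeros(max_num_pages, dtype=torch.int32, device=device),
    paged_kv_last_page_len_buf=torch.zeros(max_num_reqs, dtype=torch.int32, device=device),
    backend="fa2",  # dispatch to the FlashAttention-2 prefill kernels
)
```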
@@ -1009,36 +1016,41 @@ def fast_plan_decode(
Parts of the code take inspiration from the original plan in the FlashInfer repo
and from the fast_decode_plan implementation for FlashInfer in the SGLang repo.
"""
+# Get the placeholder qo_indptr buffer for decode, where all query lengths are 1
+batch_size = len(last_page_len_cpu)
+qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

# Warm up with the original plan if it is first call, and always run the
# original plan if we run for dynamic shape. For fixed shape (cudagraph),
# this warm up is to generate the _cached_module for the decode wrapper.
if not self.is_cuda_graph_enabled or \
getattr(self, "vllm_first_call", True):
self.plan(
-indptr_cpu,
-indices,
-last_page_len_cpu,
-num_qo_heads,
-num_kv_heads,
-head_dim,
-page_size,
-pos_encoding_mode,
-window_left,
-logits_soft_cap,
-q_data_type,
-kv_data_type,
-data_type,
-sm_scale,
-rope_scale,
-rope_theta,
-non_blocking,
+qo_indptr=qo_indptr_host,
+paged_kv_indptr=indptr_cpu,
+paged_kv_indices=indices,
+paged_kv_last_page_len=last_page_len_cpu,
+num_qo_heads=num_qo_heads,
+num_kv_heads=num_kv_heads,
+head_dim_qk=head_dim,
+page_size=page_size,
+head_dim_vo=head_dim,
+causal=False,
+pos_encoding_mode=pos_encoding_mode,
+window_left=window_left,
+logits_soft_cap=logits_soft_cap,
+q_data_type=q_data_type,
+kv_data_type=kv_data_type,
+sm_scale=sm_scale,
+rope_scale=rope_scale,
+rope_theta=rope_theta,
+non_blocking=non_blocking,
)
self.vllm_first_call = False
return

assert self.is_cuda_graph_enabled, "Should be cudagraph only here"

-batch_size = len(last_page_len_cpu)
if logits_soft_cap is None:
logits_soft_cap = 0.0
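`_get_range_buf(batch_size + 1, "cpu")`, imported from `flashinfer.decode`, supplies the `[0, 1, ..., batch_size]` indptr used above, i.e. one query token per request. A rough, dependency-free stand-in for what the helper provides; the caching detail is an assumption about the real implementation, not a copy of it:

```python
import torch

_range_bufs: dict[tuple[int, str], torch.Tensor] = {}

def get_range_buf(size: int, device: str) -> torch.Tensor:
    """Return a cached int32 [0, 1, ..., size - 1] tensor on the given device."""
    key = (size, device)
    if key not in _range_bufs:
        _range_bufs[key] = torch.arange(size, dtype=torch.int32, device=device)
    return _range_bufs[key]

# For a decode batch of 4 requests, this doubles as the host-side qo_indptr.
print(get_range_buf(5, "cpu"))   # tensor([0, 1, 2, 3, 4], dtype=torch.int32)
```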

@@ -1074,10 +1086,8 @@
self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
non_blocking=True)

-qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

try:
-# Make sure we pass exactly 15 arguments for tensor core version
+# PrefillPlan parameters for batched decode
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
@@ -1098,6 +1108,7 @@
except Exception as e:
raise RuntimeError(f"Error in tensor core plan: {e}") from e

+self._causal = False
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap
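The stored `self._causal = False` (matching the `causal=False` passed to the warm-up plan) works because a length-1 query attending to the full paged KV cache is never restricted by a causal mask, so the mask can be skipped. A small self-contained check of that equivalence in plain PyTorch, with illustrative shapes:

```python
import torch

q_len, kv_len = 1, 6
scores = torch.randn(q_len, kv_len)

# Causal mask for the last q_len positions of a kv_len-long sequence:
# query i may attend to keys up to position kv_len - q_len + i.
causal_mask = torch.ones(q_len, kv_len).tril(diagonal=kv_len - q_len).bool()

masked = scores.masked_fill(~causal_mask, float("-inf"))
# With q_len == 1 the mask is all True, so causal and non-causal attention agree.
assert torch.equal(masked, scores)
```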