Merged

Changes from 11 commits

Commits (55)
269f965
[misc] remove deprecated call to `end_forward` in flashinfer backend
abmfy Dec 14, 2024
8c375a3
[flashinfer] upgrade to flashinfer 0.2.0
abmfy Dec 20, 2024
a62b854
[style] fix yapf check
abmfy Dec 20, 2024
b37ff55
[FlashInfer] Pass infered global hyperparameters to `plan`
abmfy Dec 31, 2024
72bdf7e
[FlashInfer] Cache inferred global hyperparameters
abmfy Dec 31, 2024
97dcedc
[Misc] Use `typing.Optional` for Python 3.9 compatability
abmfy Dec 31, 2024
56798c5
[Style] Fix lint errors
abmfy Dec 31, 2024
706a6f6
Merge branch 'main' into flashinfer-0.2
abmfy Jan 22, 2025
dacb6af
[FlashInfer] Cache global hyperparameters in AttentionMetadataBuilder…
abmfy Jan 22, 2025
06fa7cc
[Style] Fix ruff
abmfy Jan 22, 2025
bc480b0
[FlashInfer] Get per layer params from vllm config
abmfy Jan 23, 2025
5a70aac
[FlashInfer] Store vllm config in attention state
abmfy Jan 23, 2025
e0397e9
[CI] Update FlashInfer version
abmfy Jan 23, 2025
ec49257
format
youkaichao Jan 23, 2025
500ff5b
Merge branch 'main' into flashinfer-0.2
abmfy Jan 24, 2025
bde6807
[Misc] Add space in assert message
abmfy Jan 24, 2025
69d7c8d
[FlashInfer] Warn on models with interleaved attention
abmfy Jan 24, 2025
d4d63dc
[Test] Change backend to flash_attn for gemma in compile tests
abmfy Jan 24, 2025
6e7e933
fix inconsistent vllm config
youkaichao Jan 25, 2025
0b47067
Merge branch 'flashinfer-0.2' of github.com:abmfy/vllm-flashinfer int…
abmfy Jan 25, 2025
f6e33a7
[Test] Skip tests for Gemma2 with FlashInfer backend
abmfy Jan 25, 2025
847a4d6
[CI] Build FlashInfer from source
abmfy Jan 25, 2025
5b0fe64
[CI] Fix FlashInfer build command
abmfy Jan 25, 2025
69445cd
[CI] Fix Dockerfile
abmfy Jan 25, 2025
963aff7
[CI] Fix FlashInfer AOT build in Dockerfile
abmfy Jan 25, 2025
ae9da66
fix flashinfer docker build
youkaichao Jan 26, 2025
afa377c
Merge branch 'main' into flashinfer-0.2
youkaichao Jan 26, 2025
269e1eb
fix build command
youkaichao Jan 26, 2025
2e50ab8
move command
youkaichao Jan 26, 2025
0fe979d
unify to use setup.py
youkaichao Jan 26, 2025
3dd209c
fix cd
youkaichao Jan 26, 2025
bcd04fd
fix recursive clone
youkaichao Jan 26, 2025
bb44221
comment
youkaichao Jan 26, 2025
5ca67ae
[CI] Use precompiled FlashInfer AOT wheel
abmfy Jan 26, 2025
3c89bfb
[CI] Temporarily switch to CUDA develop image for vllm-base
abmfy Jan 26, 2025
293fdd6
Merge branch 'main' into flashinfer-0.2
abmfy Jan 26, 2025
5d8ad22
also install jit build dependency
youkaichao Jan 26, 2025
4d57ef9
[FlashInfer] Fix type of k_scale and v_scale
abmfy Jan 26, 2025
33ff07b
Merge branch 'main' into flashinfer-0.2
abmfy Jan 26, 2025
ef15977
Merge branch 'flashinfer-0.2' of github.com:abmfy/vllm-flashinfer int…
abmfy Jan 26, 2025
21efc67
fix reshape_and_cache_flash
youkaichao Jan 27, 2025
a6b6fe8
use new flashinfer
youkaichao Jan 27, 2025
1f13235
Merge branch 'main' into flashinfer-0.2
youkaichao Jan 27, 2025
f17dbc3
update v1 tests
youkaichao Jan 27, 2025
506b641
refactor test
youkaichao Jan 27, 2025
2e476a2
revert
youkaichao Jan 27, 2025
95b5493
add comments
youkaichao Jan 27, 2025
55b55d3
only check compile when loading
youkaichao Jan 27, 2025
1f80aee
test in ci?
youkaichao Jan 27, 2025
5be3783
fix one test
youkaichao Jan 27, 2025
071a68e
fix test_flashinfer_prefill_with_paged_kv
youkaichao Jan 27, 2025
0e0f57f
relax test for prefill
youkaichao Jan 27, 2025
2134e77
fix test_flashinfer_prefill_with_paged_fp8_kv
youkaichao Jan 27, 2025
8e42297
relax test for prefill
youkaichao Jan 27, 2025
b4a7992
fix test_flashinfer_decode_with_paged_fp8_kv
youkaichao Jan 27, 2025
181 changes: 160 additions & 21 deletions vllm/attention/backends/flashinfer.py
@@ -1,3 +1,4 @@
import dataclasses
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
@@ -13,9 +14,11 @@
from vllm.vllm_flash_attn import flash_attn_varlen_func
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
# Avoid turning these types into variables during type checking
if not TYPE_CHECKING:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

import torch
@@ -30,7 +33,9 @@
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)

@@ -99,6 +104,72 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")


@dataclass
class PerLayerParameters:
"""
Currently, the FlashInfer backend only supports models in which all layers
share the same values for the following hyperparameters.
"""

window_left: int
logits_soft_cap: Optional[float]
sm_scale: float


def get_per_layer_parameters(
vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
"""
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""

layers = vllm_config.compilation_config.static_forward_context
per_layer_params: Dict[str, PerLayerParameters] = {}

for key, layer in layers.items():
assert isinstance(layer, Attention)

impl = layer.impl
assert isinstance(impl, FlashInferImpl)

# Infer hyperparameters from the attention layer
window_size = impl.sliding_window
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = impl.logits_soft_cap
sm_scale = impl.scale

per_layer_params[key] = PerLayerParameters(window_left,
logits_soft_cap, sm_scale)

return per_layer_params


def infer_global_hyperparameters(
per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
Currently, the FlashInfer backend only supports models in which all layers
share the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`

So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""

assert len(per_layer_params) > 0, "No attention layers found in the model."

param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
assert params == global_params, (
"FlashInfer backend currently only supports models in which all"
"layers share the same values for the following hyperparameters:"
"`window_left`, `logits_soft_cap`, `sm_scale`.")

return global_params
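
(Aside, not part of the diff: a minimal standalone sketch of the uniformity check that `infer_global_hyperparameters` performs, using a simplified stand-in for `PerLayerParameters` and made-up layer names and values.)

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class Params:  # simplified stand-in for PerLayerParameters
    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float


def infer_global(per_layer: Dict[str, Params]) -> Params:
    assert len(per_layer) > 0, "No attention layers found in the model."
    values = list(per_layer.values())
    first = values[0]
    for p in values:
        # Layers with mixed values (e.g. interleaved sliding-window models)
        # would trip this assertion, mirroring the check above.
        assert p == first
    return first


# All layers agree, so the single shared set of hyperparameters is returned.
uniform = {f"layers.{i}.attn": Params(-1, None, 0.125) for i in range(4)}
print(infer_global(uniform))  # window_left=-1, no soft cap, sm_scale=1/sqrt(64)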


class FlashInferState(AttentionState):

def __init__(self, runner):
@@ -108,6 +179,9 @@ def __init__(self, runner):
self._decode_wrapper = None
self._prefill_wrapper = None

# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
Review comment (Member): remember the vllm_config here?

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
@@ -215,6 +289,9 @@ def graph_capture_get_metadata_for_batch(
batch_size + 1,
dtype=torch.int32)

global_params = infer_global_hyperparameters(
get_per_layer_parameters(get_current_vllm_config()))

attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size],
@@ -237,7 +314,9 @@
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None)
prefill_wrapper=None,
**dataclasses.asdict(global_params),
)
attn_metadata.begin_forward()
return attn_metadata

@@ -324,9 +403,28 @@ class FlashInferMetadata(AttentionMetadata):
data_type: torch.dtype = None
# The data type of the query
q_data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
# FlashInfer 0.2 encourages passing host tensors
device: torch.device = torch.device("cpu")
is_profile_run: bool = False

# The FlashInfer backend currently supports only models in which all layers
# share the same values for the following hyperparameters:

# The left (inclusive) window size for the attention window. When
# set to `-1`, the window size will be set to the full length of
# the sequence. Defaults to `-1`.
window_left: int = -1
# The attention logits soft capping value (used in Gemini, Grok and
# Gemma-2, etc.), if not provided, will be set to `0`. If greater
# than 0, the logits will be capped according to formula:
# $$\texttt{logits\_soft\_cap} \times
# \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$,
# where $x$ is the input logits.
logits_soft_cap: Optional[float] = None
# The scale used in softmax, if not provided, will be set to
# `1.0 / sqrt(head_dim)`.
sm_scale: Optional[float] = None

def __post_init__(self):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
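
(Aside: the two scalar knobs documented above are easy to see numerically. `sm_scale` defaults to 1/sqrt(head_dim), and a positive `logits_soft_cap` bounds the attention logits via logits_soft_cap * tanh(x / logits_soft_cap). The numbers below are made-up example values.)

import math

head_dim = 128                          # assumed example head size
sm_scale = 1.0 / math.sqrt(head_dim)    # default softmax scale, ~0.0884

logits_soft_cap = 30.0                  # example cap; the real value comes from the model config
for x in (5.0, 50.0, 500.0):            # raw pre-softmax logits
    capped = logits_soft_cap * math.tanh(x / logits_soft_cap)
    print(f"x={x:6.1f} -> capped logit {capped:6.2f}")  # stays within +/- 30.0
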
@@ -362,14 +460,21 @@ def begin_forward(self):
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.prefill_wrapper.plan(
self.query_start_loc,
self.paged_kv_indptr[:self.num_prefills + 1],
self.paged_kv_indices,
self.paged_kv_last_page_len[:self.num_prefills],
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.data_type)
if self.num_decode_tokens > 0:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
@@ -385,8 +490,7 @@ def begin_forward(self):
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
self.decode_wrapper.begin_forward(
self.decode_wrapper.plan(
self.paged_kv_indptr[self.num_prefills:],
self.paged_kv_indices,
self.paged_kv_last_page_len[self.num_prefills:],
@@ -396,8 +500,11 @@
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
sm_scale=self.sm_scale,
# kv-cache data type.
data_type=self.data_type,
kv_data_type=self.data_type,
# query data type.
q_data_type=self.q_data_type)
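
(Aside: the `paged_kv_indptr` / `paged_kv_indices` / `paged_kv_last_page_len` arguments passed to `plan()` describe the paged KV cache in CSR form. A rough, standalone illustration with toy sequence lengths and block tables — not vLLM's actual builder code:)

page_size = 16
seq_lens = [19, 16]                 # toy per-sequence token counts
block_tables = [[0, 3], [7]]        # hypothetical physical page ids per sequence

paged_kv_indptr = [0]               # CSR offsets into paged_kv_indices
paged_kv_indices = []               # flattened page ids for the whole batch
paged_kv_last_page_len = []         # valid slots in each sequence's last page

for seq_len, blocks in zip(seq_lens, block_tables):
    paged_kv_indices.extend(blocks)
    paged_kv_indptr.append(len(paged_kv_indices))
    paged_kv_last_page_len.append((seq_len - 1) % page_size + 1)

print(paged_kv_indptr)              # [0, 2, 3]
print(paged_kv_indices)             # [0, 3, 7]
print(paged_kv_last_page_len)       # [3, 16]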

@@ -495,6 +602,11 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
Review comment (Member): you can remember the vllm_config here by calling get_current_vllm_config()


# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None

self.vllm_config = get_current_vllm_config()

def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
@@ -527,6 +639,20 @@ def prepare(self):
self.total_blocks = 0
self.is_profile_run: bool = False

if self.global_hyperparameters is None:
# Infer global hyperparameters, since currently we only support
# models in which all layers share the same values for the
# following hyperparameters:
# - `window_left`
# - `logits_soft_cap`
# - `sm_scale`
inferred_params = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config))
self.global_hyperparameters = inferred_params
self.window_left = inferred_params.window_left
self.logits_soft_cap = inferred_params.logits_soft_cap
self.sm_scale = inferred_params.sm_scale

def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
@@ -754,7 +880,11 @@ def build(self, seq_lens: List[int], query_lens: List[int],
data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=use_captured_graph,
is_profile_run=self.is_profile_run)
is_profile_run=self.is_profile_run,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
sm_scale=self.sm_scale,
)


class FlashInferImpl(AttentionImpl):
@@ -883,25 +1013,34 @@ def forward(
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
prefill_output = prefill_meta.prefill_wrapper.forward(

assert prefill_meta.prefill_wrapper._causal
assert prefill_meta.prefill_wrapper._window_left == window_left
assert prefill_meta.prefill_wrapper._logits_soft_cap == (
logits_soft_cap or 0.0)
assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale

prefill_output = prefill_meta.prefill_wrapper.run(
query,
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
window_left=window_left)
)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
assert decode_meta.decode_wrapper is not None
decode_output = decode_meta.decode_wrapper.forward(

assert decode_meta.decode_wrapper._window_left == window_left
assert decode_meta.decode_wrapper._logits_soft_cap == (
logits_soft_cap or 0.0)
assert decode_meta.decode_wrapper._sm_scale == softmax_scale

decode_output = decode_meta.decode_wrapper.run(
decode_query,
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
window_left=window_left)
)

if prefill_output is None and decode_output is not None:
# Decode only batch.
16 changes: 10 additions & 6 deletions vllm/worker/model_runner.py
@@ -20,7 +20,7 @@
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import CompilationLevel, VllmConfig
from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_kv_transfer_group, get_pp_group
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
@@ -1498,11 +1498,15 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
) if get_tensor_model_parallel_rank() == 0 else
self.vllm_config.compilation_config.capture_sizes)
for batch_size in capture_sizes:
attn_metadata = (
self.attn_state.graph_capture_get_metadata_for_batch(
batch_size,
is_encoder_decoder_model=self.model_config.
is_encoder_decoder))
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during
# CUDA graph capture
attn_metadata = (self.attn_state.
graph_capture_get_metadata_for_batch(
batch_size,
is_encoder_decoder_model=self.
model_config.is_encoder_decoder,
))
Review comment (Member): then we don't need this change.


if self.lora_config:
lora_mapping = LoRAMapping(
8 changes: 5 additions & 3 deletions vllm/worker/worker_base.py
@@ -8,7 +8,7 @@
import torch
import torch.nn as nn

from vllm.config import ObservabilityConfig, VllmConfig
from vllm.config import ObservabilityConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -546,8 +546,10 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
bytes)
worker_class = cloudpickle.loads(
self.vllm_config.parallel_config.worker_cls)
self.worker = worker_class(**kwargs)
assert self.worker is not None
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
assert self.worker is not None

def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try:
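
(Aside: the model_runner.py and worker_base.py hunks wrap CUDA-graph capture and worker initialization in `set_current_vllm_config(...)` so that `get_current_vllm_config()` works inside the FlashInfer backend. A minimal standalone sketch of that context-manager pattern — not vLLM's actual implementation — could look like:)

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class Config:                 # stand-in for VllmConfig
    model: str


_current: Optional[Config] = None


@contextmanager
def set_current_config(cfg: Config) -> Iterator[None]:
    global _current
    previous, _current = _current, cfg
    try:
        yield
    finally:
        _current = previous


def get_current_config() -> Config:
    assert _current is not None, "called outside a set_current_config() block"
    return _current


class Worker:                 # stands in for the real worker / attention state
    def __init__(self) -> None:
        # The constructor can read the config without it being passed in,
        # which is what the FlashInfer metadata builder relies on.
        self.cfg = get_current_config()


with set_current_config(Config(model="example")):
    print(Worker().cfg)       # Config(model='example')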