From bef36294bce40e54a4e7c8c4f9cc4219ec1a5f8d Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 22 Apr 2024 19:22:53 -0700 Subject: [PATCH 01/19] draft --- csrc/cache.h | 8 ++ csrc/cache_kernels.cu | 79 +++++++++++++++ csrc/pybind.cpp | 4 + vllm/attention/backends/flashinfer.py | 135 ++++++++++++++++++++++++++ vllm/config.py | 3 + vllm/sequence.py | 3 +- vllm/worker/model_runner.py | 35 +++++++ 7 files changed, 266 insertions(+), 1 deletion(-) create mode 100644 vllm/attention/backends/flashinfer.py diff --git a/csrc/cache.h b/csrc/cache.h index 718a5f6cfd7f..4c142ce17f1b 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -24,6 +24,14 @@ void reshape_and_cache( const std::string& kv_cache_dtype, const float kv_scale); +void reshape_and_cache_flash( + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype); + // Just for unittest void convert_fp8( torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 24aaa2ff3e26..a34a6466469c 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel( } } +template +__global__ void reshape_and_cache_flash_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] + scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int64_t tgt_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + + head_offset; + k_cache[tgt_value_idx] = key[src_key_idx]; + v_cache[tgt_value_idx] = value[src_value_idx]; + } +} } // namespace vllm #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ @@ -275,6 +310,50 @@ void reshape_and_cache( } } +void reshape_and_cache_flash( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype) +{ + if (kv_cache_dtype != "auto") { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = k_cache.size(1); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + int block_stride = k_cache.stride(0); + TORCH_CHECK(k_cache.stride(0) == 
v_cache.stride(0)); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), + "reshape_and_cache_flash", + [&] { + vllm::reshape_and_cache_flash_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + k_cache.data_ptr(), + v_cache.data_ptr(), + slot_mapping.data_ptr(), + block_stride, + key_stride, + value_stride, + num_heads, + head_size, + block_size); + }); +} + namespace vllm { template diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index de02afc16211..d35a2a08bbdc 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -90,6 +90,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "reshape_and_cache", &reshape_and_cache, "Reshape the key and value tensors and cache them"); + cache_ops.def( + "reshape_and_cache_flash", + &reshape_and_cache_flash, + "Reshape the key and value tensors and cache them"); cache_ops.def( "convert_fp8", &convert_fp8, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py new file mode 100644 index 000000000000..004cb9e3a9dc --- /dev/null +++ b/vllm/attention/backends/flashinfer.py @@ -0,0 +1,135 @@ +from typing import Type, Tuple, List, Dict, Optional +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) + +import torch +import flashinfer +from vllm._C import cache_ops + +class FlashInferBackend(AttentionBackend): + @staticmethod + def get_impl_cls() -> Type["FlashInferImpl"]: + return FlashInferImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "FlashInferMetadata": + return FlashInferMetadata(*args, **kwargs) + + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + raise NotImplementedError + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + raise NotImplementedError + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + raise NotImplementedError + +@dataclass +class FlashInferMetadata(AttentionMetadata): + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + + # The indptr of the paged kv-cache, shape: [batch_size + 1]. + # Please follow the definition in the FlashInfer documentation: https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout. + paged_kv_indptr: List[int] + + # The indices of the paged kv-cache of all sequences. + paged_kv_indices: List[int] + + # The last page length of the paged kv-cache of all sequences, shape: [batch_size]. + paged_kv_last_page_len: List[int] + + # The number of query/output heads. + num_qo_heads: int + + # The number of key/value heads. + num_kv_heads: int + + # The dimension of the heads + head_dim: int + + # The wrapper for the prefill or decode operation. + wrapper = None + + # The indptr of the query/output sequence, shape: [batch_size + 1]. + # This is only used for the prefill operation. + subquery_start_loc: Optional[torch.Tensor] = None + + # The block size for the decode operation. + block_size: Optional[int] = None + + use_cuda_graph: bool = False + + + def __post_init__(self): + assert not self.use_cuda_graph, "CUDA graph is not supported yet." 
+ # Allocate 16MB workspace buffer + # Follow the example: https://docs.flashinfer.ai/api/python/prefill.html#batch-prefill-append-attention + workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") + if self.is_prompt: + self.wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") + self.wrapper.begin_forward( + self.subquery_start_loc, + self.paged_kv_indptr, + self.paged_kv_indices, + self.paged_kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim + ) + else: + self.wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD") + self.wrapper.begin_forward( + self.paged_kv_indptr, + self.paged_kv_indices, + self.paged_kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.block_size, + pos_encoding_mode="NONE", # FIXME: Add support for pos_encoding_mode + data_type=torch.float16 # FIXME: Add support for data_type + ) + +class FlashInferImpl(AttentionImpl): + def __init__(self, metadata: FlashInferMetadata): + self.prefill_wrapper = metadata.prefill_wrapper + + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata[FlashInferMetadata]): + if kv_cache is not None: + # Use the same reshape and cache kernel as flash attention. + cache_ops.reshape_and_cache_flash(key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + attn_metadata.kv_cache_dtype,) + + if attn_metadata.is_prompt: + assert kv_cache is None, "Does not support prefix caching yet." + attn_metadata.prefill_metadata.wrapper.forward(query, kv_cache, causal=True) + + else: + attn_metadata.decode_metadata.wrapper.forward(query, kv_cache) diff --git a/vllm/config.py b/vllm/config.py index 97ede0faa21a..276fa3882c5a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -271,6 +271,9 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) + def get_num_attention_heads(self) -> int: + return self.hf_text_config.num_attention_heads + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers return total_num_hidden_layers // parallel_config.pipeline_parallel_size diff --git a/vllm/sequence.py b/vllm/sequence.py index 7dcacab6f2ab..c36b9567b549 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -556,8 +556,9 @@ class SequenceGroupMetadata: numbers) token_chunk_size: The number of tokens to be processed (per sequence). None if chunking is not required. - state: Internal state tied to this sequence group. lora_request: LoRA request. + computed_block_nums: The block numbers that are already computed, used in prefix caching. + state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. """ diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 31e08789dfd1..b83201a09699 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -236,6 +236,9 @@ def _prepare_prompt( subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] multi_modal_input_list: List[torch.Tensor] = [] + paged_kv_indices: List[int] = [] + paged_kv_indptr: List[int] = [0] + paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: return PreparePromptMetadata.empty() @@ -318,6 +321,13 @@ def _prepare_prompt( # Compute the slot mapping. 
block_table = seq_group_metadata.block_tables[seq_id] + paged_kv_indices.extend(block_table) + paged_kv_indptr.append(len(block_table)) + last_len = seq_data.get_len() % self.block_size + if last_len == 0: + last_len = self.block_size + paged_kv_last_page_len.append(last_len) + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, prompt_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and @@ -405,6 +415,13 @@ def _prepare_prompt( context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, + paged_kv_indices = paged_kv_indices, + paged_kv_indptr = paged_kv_indptr, + paged_kv_last_page_len = paged_kv_last_page_len, + num_qo_heads = self.model_config.get_num_attention_heads(), + num_kv_heads = self.model_config.get_num_kv_heads(), + head_dim = self.model_config.get_head_size(), + block_size = self.block_size ) return PreparePromptMetadata( @@ -432,6 +449,9 @@ def _prepare_decode( lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + paged_kv_indices: List[int] = [] + paged_kv_indptr: List[int] = [0] + paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: return PrepareDecodeMetadata.empty() @@ -473,6 +493,14 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + paged_kv_indices.extend(block_table) + paged_kv_indptr.append(len(block_table)) + last_len = seq_data.get_len() % self.block_size + if last_len == 0: + last_len = self.block_size + paged_kv_last_page_len.append(last_len) + + # vLLM uses cuda graph only for decoding requests. # See `capture_model` API for more details. # For decoding requests, batch_size == input_tokens. 
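For orientation, a small standalone sketch (not part of the patch) of how the three paged-KV arrays built in the hunks above describe sequences in FlashInfer's page layout (see the KV-layout link in the metadata comments). The block tables and sequence lengths below are invented for illustration, and the cumulative paged_kv_indptr form shown here is the one a later commit in this series switches to.

# Illustration only: two hypothetical sequences sharing a paged KV cache
# with block_size = 16. Sequence 0 owns physical blocks [0, 3] and holds
# 20 tokens; sequence 1 owns block [5] and holds 7 tokens.
block_size = 16
block_tables = [[0, 3], [5]]
seq_lens = [20, 7]

paged_kv_indices = []        # flattened physical block ids of all sequences
paged_kv_indptr = [0]        # CSR-style offsets into paged_kv_indices
paged_kv_last_page_len = []  # valid tokens in each sequence's last block

for block_table, seq_len in zip(block_tables, seq_lens):
    paged_kv_indices.extend(block_table)
    paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
    last_len = seq_len % block_size or block_size
    paged_kv_last_page_len.append(last_len)

assert paged_kv_indices == [0, 3, 5]
assert paged_kv_indptr == [0, 2, 3]
assert paged_kv_last_page_len == [4, 7]
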
@@ -535,6 +563,13 @@ def _prepare_decode( context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, + paged_kv_indices = paged_kv_indices, + paged_kv_indptr = paged_kv_indptr, + paged_kv_last_page_len = paged_kv_last_page_len, + num_qo_heads = self.model_config.get_num_attention_heads(), + num_kv_heads = self.model_config.get_num_kv_heads(), + head_dim = self.model_config.get_head_size(), + block_size = self.block_size ) return PrepareDecodeMetadata( input_tokens=input_tokens, From c1360e63e2e0baf8ca9b68ba89f8bc6cc3aaf367 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 23 Apr 2024 06:28:12 +0000 Subject: [PATCH 02/19] fix --- vllm/attention/backends/abstract.py | 8 ++ vllm/attention/backends/flash_attn.py | 4 +- vllm/attention/backends/flashinfer.py | 90 +++++++++++++--------- vllm/attention/backends/rocm_flash_attn.py | 4 +- vllm/attention/backends/torch_sdpa.py | 4 +- vllm/attention/backends/xformers.py | 4 +- vllm/attention/selector.py | 12 +++ vllm/config.py | 2 +- vllm/worker/model_runner.py | 31 ++++---- 9 files changed, 97 insertions(+), 62 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 7a4ccecf702f..96e3bb98fb6f 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -59,6 +59,14 @@ def asdict_zerocopy(self) -> Dict[str, Any]: for field in fields(self) } + @classmethod + def new(cls, **kwargs) -> "AttentionMetadataPerStage": + """Create a new instance with updated attributes.""" + # filtering + cls_fields = [field.name for field in fields(cls)] + filtered_kwargs = {k: v for k, v in kwargs.items() if k in cls_fields} + return cls(**filtered_kwargs) + T = TypeVar("T", bound=AttentionMetadataPerStage) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 12e8c4404b94..862fbe83de2e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -24,8 +24,8 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]: return FlashAttentionImpl @staticmethod - def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata": - return FlashAttentionMetadata(*args, **kwargs) + def make_metadata(**kwargs) -> "FlashAttentionMetadata": + return FlashAttentionMetadata.new(**kwargs) @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 004cb9e3a9dc..e33162eaa09f 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,20 +1,23 @@ from typing import Type, Tuple, List, Dict, Optional from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) import torch import flashinfer from vllm._C import cache_ops +from dataclasses import dataclass + class FlashInferBackend(AttentionBackend): + @staticmethod def get_impl_cls() -> Type["FlashInferImpl"]: return FlashInferImpl @staticmethod - def make_metadata(*args, **kwargs) -> "FlashInferMetadata": - return FlashInferMetadata(*args, **kwargs) - + def make_metadata(**kwargs) -> "FlashInferMetadata": + return FlashInferMetadata.new(**kwargs) @staticmethod def get_kv_cache_shape( @@ -24,7 +27,7 @@ def get_kv_cache_shape( head_size: int, ) -> Tuple[int, ...]: raise NotImplementedError - + @staticmethod def swap_blocks( src_kv_cache: torch.Tensor, @@ -32,16 +35,17 @@ def swap_blocks( src_to_dst: Dict[int, int], ) -> None: raise 
NotImplementedError - + @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: Dict[int, List[int]], ) -> None: raise NotImplementedError - + + @dataclass -class FlashInferMetadata(AttentionMetadata): +class FlashInferMetadata(AttentionMetadataPerStage): # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool @@ -77,25 +81,25 @@ class FlashInferMetadata(AttentionMetadata): use_cuda_graph: bool = False - def __post_init__(self): assert not self.use_cuda_graph, "CUDA graph is not supported yet." # Allocate 16MB workspace buffer # Follow the example: https://docs.flashinfer.ai/api/python/prefill.html#batch-prefill-append-attention - workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") + workspace_buffer = torch.empty(16 * 1024 * 1024, + dtype=torch.uint8, + device="cuda:0") if self.is_prompt: - self.wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") - self.wrapper.begin_forward( - self.subquery_start_loc, - self.paged_kv_indptr, - self.paged_kv_indices, - self.paged_kv_last_page_len, - self.num_qo_heads, - self.num_kv_heads, - self.head_dim - ) + self.wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, "NHD") + self.wrapper.begin_forward(self.subquery_start_loc, + self.paged_kv_indptr, + self.paged_kv_indices, + self.paged_kv_last_page_len, + self.num_qo_heads, self.num_kv_heads, + self.head_dim) else: - self.wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD") + self.wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, "NHD") self.wrapper.begin_forward( self.paged_kv_indptr, self.paged_kv_indices, @@ -104,32 +108,44 @@ def __post_init__(self): self.num_kv_heads, self.head_dim, self.block_size, - pos_encoding_mode="NONE", # FIXME: Add support for pos_encoding_mode - data_type=torch.float16 # FIXME: Add support for data_type + pos_encoding_mode= + "NONE", # FIXME: Add support for pos_encoding_mode + data_type=torch.float16 # FIXME: Add support for data_type ) - + + class FlashInferImpl(AttentionImpl): - def __init__(self, metadata: FlashInferMetadata): - self.prefill_wrapper = metadata.prefill_wrapper - - def forward(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + pass + + def forward(self, query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata[FlashInferMetadata]): if kv_cache is not None: # Use the same reshape and cache kernel as flash attention. - cache_ops.reshape_and_cache_flash(key, + cache_ops.reshape_and_cache_flash( + key, value, kv_cache[:, 0], kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), - attn_metadata.kv_cache_dtype,) - + attn_metadata.kv_cache_dtype, + ) + if attn_metadata.is_prompt: assert kv_cache is None, "Does not support prefix caching yet." 
- attn_metadata.prefill_metadata.wrapper.forward(query, kv_cache, causal=True) - + attn_metadata.prefill_metadata.wrapper.forward(query, + kv_cache, + causal=True) + else: attn_metadata.decode_metadata.wrapper.forward(query, kv_cache) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index dbaa71fd16ad..182dc855890b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -22,8 +22,8 @@ def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: return ROCmFlashAttentionImpl @staticmethod - def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata": - return ROCmFlashAttentionMetadata(*args, **kwargs) + def make_metadata(**kwargs) -> "ROCmFlashAttentionMetadata": + return ROCmFlashAttentionMetadata.new(**kwargs) @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index d21b54b16db4..dc0dcc690cc6 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -20,8 +20,8 @@ def get_impl_cls() -> Type["TorchSDPABackendImpl"]: return TorchSDPABackendImpl @staticmethod - def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata": - return TorchSDPAMetadata(*args, **kwargs) + def make_metadata(**kwargs) -> "TorchSDPAMetadata": + return TorchSDPAMetadata.new(**kwargs) @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b745a04a143b..b2c59f38b55b 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -25,8 +25,8 @@ def get_impl_cls() -> Type["XFormersImpl"]: return XFormersImpl @staticmethod - def make_metadata(*args, **kwargs) -> "XFormersMetadata": - return XFormersMetadata(*args, **kwargs) + def make_metadata(**kwargs) -> "XFormersMetadata": + return XFormersMetadata.new(**kwargs) @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 554e802cd551..d625b2756f97 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -19,6 +19,7 @@ class _Backend(enum.Enum): XFORMERS = enum.auto() ROCM_FLASH = enum.auto() TORCH_SDPA = enum.auto() + FLASHINER = enum.auto() @lru_cache(maxsize=None) @@ -43,6 +44,10 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: logger.info("Using Torch SDPA backend.") from vllm.attention.backends.torch_sdpa import TorchSDPABackend return TorchSDPABackend + elif backend == _Backend.FLASHINER: + logger.info("Using Flashinfer backend.") + from vllm.attention.backends.flashinfer import FlashInferBackend + return FlashInferBackend else: raise ValueError("Invalid attention backend.") @@ -71,6 +76,13 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS + try: + import flashinfer + return _Backend.FLASHINER + except ImportError: + logger.info( + "Cannot use Flashinfer backend because the flashinfer package " + "is not found. 
Please install it for better performance.") try: import flash_attn # noqa: F401 except ImportError: diff --git a/vllm/config.py b/vllm/config.py index 276fa3882c5a..c3bdf23e3bf7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -273,7 +273,7 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: def get_num_attention_heads(self) -> int: return self.hf_text_config.num_attention_heads - + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers return total_num_hidden_layers // parallel_config.pipeline_parallel_size diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b83201a09699..b6d361f7c624 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -415,13 +415,14 @@ def _prepare_prompt( context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - paged_kv_indices = paged_kv_indices, - paged_kv_indptr = paged_kv_indptr, - paged_kv_last_page_len = paged_kv_last_page_len, - num_qo_heads = self.model_config.get_num_attention_heads(), - num_kv_heads = self.model_config.get_num_kv_heads(), - head_dim = self.model_config.get_head_size(), - block_size = self.block_size + paged_kv_indices=paged_kv_indices, + paged_kv_indptr=paged_kv_indptr, + paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=self.model_config.get_num_attention_heads(), + num_kv_heads=self.model_config.get_num_kv_heads( + self.parallel_config), + head_dim=self.model_config.get_head_size(), + # block_size = self.block_size ) return PreparePromptMetadata( @@ -500,7 +501,6 @@ def _prepare_decode( last_len = self.block_size paged_kv_last_page_len.append(last_len) - # vLLM uses cuda graph only for decoding requests. # See `capture_model` API for more details. # For decoding requests, batch_size == input_tokens. 
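As context for the make_metadata call sites updated in the surrounding hunks: the kwarg-filtering helper this commit adds to AttentionMetadataPerStage (and which a later commit in the series removes again) lets the model runner pass one superset of keyword arguments to every backend. A minimal standalone sketch of that pattern follows; DummyMetadata is an invented stand-in, not a class from the patch.

# Sketch only: mirrors the `new` classmethod added to AttentionMetadataPerStage
# in this commit. Unknown keyword arguments (e.g. paged_kv_indptr) are dropped
# so backends that do not define them still construct cleanly.
from dataclasses import dataclass, fields


@dataclass
class DummyMetadata:  # hypothetical stand-in for a backend's metadata class
    is_prompt: bool
    use_cuda_graph: bool = False

    @classmethod
    def new(cls, **kwargs) -> "DummyMetadata":
        cls_fields = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in kwargs.items() if k in cls_fields})


meta = DummyMetadata.new(is_prompt=False,
                         use_cuda_graph=True,
                         paged_kv_indptr=[0, 2, 3])  # extra kwarg is ignored
assert meta.use_cuda_graph and not hasattr(meta, "paged_kv_indptr")
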
@@ -563,14 +563,13 @@ def _prepare_decode( context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, - paged_kv_indices = paged_kv_indices, - paged_kv_indptr = paged_kv_indptr, - paged_kv_last_page_len = paged_kv_last_page_len, - num_qo_heads = self.model_config.get_num_attention_heads(), - num_kv_heads = self.model_config.get_num_kv_heads(), - head_dim = self.model_config.get_head_size(), - block_size = self.block_size - ) + paged_kv_indices=paged_kv_indices, + paged_kv_indptr=paged_kv_indptr, + paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=self.model_config.get_num_attention_heads(), + num_kv_heads=self.model_config.get_num_kv_heads(), + head_dim=self.model_config.get_head_size(), + block_size=self.block_size) return PrepareDecodeMetadata( input_tokens=input_tokens, input_positions=input_positions, From ad9189c4743311c759d36e696549df2adbcec5a9 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 25 Apr 2024 05:45:27 +0000 Subject: [PATCH 03/19] draft, pass simple decode tests --- tests/kernels/conftest.py | 8 +- tests/kernels/test_cache.py | 77 ++++++++++++++++++ vllm/attention/backends/flashinfer.py | 113 ++++++++++---------------- vllm/attention/selector.py | 16 ++-- vllm/utils.py | 67 +++++++++++---- vllm/worker/model_runner.py | 70 ++++++++++------ 6 files changed, 232 insertions(+), 119 deletions(-) diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index d26da2c7fe4e..4fb04cc7e4f3 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -1,8 +1,14 @@ import pytest -from vllm.utils import create_kv_caches_with_random +from vllm.utils import (create_kv_caches_with_random, + create_kv_caches_with_random_flashinfer) @pytest.fixture() def kv_cache_factory(): return create_kv_caches_with_random + + +@pytest.fixture() +def kv_cache_factory_flashinfer(): + return create_kv_caches_with_random_flashinfer diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index d1051fd7e2f4..ca215bb75837 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -5,6 +5,7 @@ import torch from vllm import _custom_ops as ops +from vllm._C import cache_ops from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] @@ -191,6 +192,82 @@ def test_reshape_and_cache( assert torch.allclose(value_cache, cloned_value_cache) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@torch.inference_mode() +def test_reshape_and_cache_flash( + kv_cache_factory_flashinfer, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, + kv_cache_dtype: str, +) -> None: + if kv_cache_dtype == "fp8": + pytest.skip() + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Create a random slot mapping. 
+ num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device=device) + _, key, value = qkv.unbind(dim=1) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory_flashinfer( + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + ) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Clone the KV caches. + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Call the reshape_and_cache kernel. + cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) + + # Run the reference implementation. + block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') + block_indicies = block_indicies.cpu().tolist() + block_offsets = slot_mapping % block_size + block_offsets = block_offsets.cpu().tolist() + for i in range(num_tokens): + block_idx = block_indicies[i] + block_offset = block_offsets[i] + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) + + @pytest.mark.parametrize("direction", COPYING_DIRECTION) @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_heads", NUM_HEADS) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e33162eaa09f..3a841a5045cd 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -7,6 +7,7 @@ import flashinfer from vllm._C import cache_ops from dataclasses import dataclass +from flash_attn import flash_attn_varlen_func class FlashInferBackend(AttentionBackend): @@ -26,7 +27,7 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - raise NotImplementedError + return (num_blocks, 2, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -46,72 +47,16 @@ def copy_blocks( @dataclass class FlashInferMetadata(AttentionMetadataPerStage): + # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # The indptr of the paged kv-cache, shape: [batch_size + 1]. - # Please follow the definition in the FlashInfer documentation: https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout. - paged_kv_indptr: List[int] - - # The indices of the paged kv-cache of all sequences. - paged_kv_indices: List[int] - - # The last page length of the paged kv-cache of all sequences, shape: [batch_size]. - paged_kv_last_page_len: List[int] - - # The number of query/output heads. - num_qo_heads: int - - # The number of key/value heads. - num_kv_heads: int - - # The dimension of the heads - head_dim: int - - # The wrapper for the prefill or decode operation. - wrapper = None - - # The indptr of the query/output sequence, shape: [batch_size + 1]. - # This is only used for the prefill operation. - subquery_start_loc: Optional[torch.Tensor] = None - - # The block size for the decode operation. - block_size: Optional[int] = None - use_cuda_graph: bool = False - def __post_init__(self): - assert not self.use_cuda_graph, "CUDA graph is not supported yet." 
- # Allocate 16MB workspace buffer - # Follow the example: https://docs.flashinfer.ai/api/python/prefill.html#batch-prefill-append-attention - workspace_buffer = torch.empty(16 * 1024 * 1024, - dtype=torch.uint8, - device="cuda:0") - if self.is_prompt: - self.wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, "NHD") - self.wrapper.begin_forward(self.subquery_start_loc, - self.paged_kv_indptr, - self.paged_kv_indices, - self.paged_kv_last_page_len, - self.num_qo_heads, self.num_kv_heads, - self.head_dim) - else: - self.wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, "NHD") - self.wrapper.begin_forward( - self.paged_kv_indptr, - self.paged_kv_indices, - self.paged_kv_last_page_len, - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.block_size, - pos_encoding_mode= - "NONE", # FIXME: Add support for pos_encoding_mode - data_type=torch.float16 # FIXME: Add support for data_type - ) + wrapper: Optional[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = None + seq_start_loc: Optional[torch.Tensor] = None + max_prompt_len: Optional[int] = None class FlashInferImpl(AttentionImpl): @@ -125,11 +70,23 @@ def __init__( alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, ) -> None: - pass + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.alibi_slopes = alibi_slopes + self.scale = scale + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[FlashInferMetadata]): + attn_metadata: AttentionMetadata[FlashInferMetadata], + kv_scale: float): + num_tokens, hidden_size = query.shape + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + if kv_cache is not None: # Use the same reshape and cache kernel as flash attention. cache_ops.reshape_and_cache_flash( @@ -141,11 +98,27 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, attn_metadata.kv_cache_dtype, ) - if attn_metadata.is_prompt: - assert kv_cache is None, "Does not support prefix caching yet." 
- attn_metadata.prefill_metadata.wrapper.forward(query, - kv_cache, - causal=True) - + if attn_metadata.prefill_metadata: + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.prefill_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.prefill_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.prefill_metadata.max_prompt_len, + max_seqlen_k=attn_metadata.prefill_metadata.max_prompt_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) else: - attn_metadata.decode_metadata.wrapper.forward(query, kv_cache) + assert attn_metadata.decode_metadata is not None + query = query.contiguous( + ) # Flashinfer requires query to be contiguous + output = attn_metadata.decode_metadata.wrapper.forward( + query, + kv_cache, + sm_scale=self.scale, + ) + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index d625b2756f97..802bcfbded72 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -77,19 +77,23 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: return _Backend.XFORMERS try: - import flashinfer - return _Backend.FLASHINER + import flash_attn # noqa: F401 except ImportError: logger.info( - "Cannot use Flashinfer backend because the flashinfer package " + "Cannot use FlashAttention backend because the flash_attn package " "is not found. Please install it for better performance.") + return _Backend.XFORMERS + + # Move flahsinfer below flash_attn for now to avoid breaking tests try: - import flash_attn # noqa: F401 + import flashinfer + # Still use flash attention for the prefill stage. + import flash_attn + return _Backend.FLASHINER except ImportError: logger.info( - "Cannot use FlashAttention backend because the flash_attn package " + "Cannot use Flashinfer backend because the flashinfer package " "is not found. 
Please install it for better performance.") - return _Backend.XFORMERS backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) if backend_by_env_var is not None: diff --git a/vllm/utils.py b/vllm/utils.py index fbe86dacaeb9..93495be20bcc 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -341,21 +341,9 @@ def _generate_random_fp8( del tensor_tmp -def create_kv_caches_with_random( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: int = 0, - device: Optional[str] = "cuda", -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - +def _str_to_torch_dtype( + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: if isinstance(cache_dtype, str): if cache_dtype == "auto": if isinstance(model_dtype, str): @@ -374,6 +362,55 @@ def create_kv_caches_with_random( torch_dtype = cache_dtype else: raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def create_kv_caches_with_random_flashinfer( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + assert cache_dtype != "fp8" + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + torch_dtype = _str_to_torch_dtype(cache_dtype, model_dtype) + key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + scale = head_size**-0.5 + key_caches, value_caches = [], [] + for _ in range(num_layers): + key_value_cache = torch.empty(size=key_value_cache_shape, + dtype=torch_dtype, + device=device) + key_value_cache.uniform_(-scale, scale) + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + torch_dtype = _str_to_torch_dtype(cache_dtype, model_dtype) scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b6d361f7c624..0e5b0b05f727 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,6 +9,7 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) +from vllm.attention.backends.flashinfer import FlashInferBackend from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce @@ -157,6 +158,9 @@ def __init__( # (max batch size to capture, max context len to capture / block size). self.graph_block_tables: torch.Tensor # Set after initial profiling. + # Set if the backend is flashinfer. 
+ self.workspace_buffer = None + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -236,9 +240,6 @@ def _prepare_prompt( subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] multi_modal_input_list: List[torch.Tensor] = [] - paged_kv_indices: List[int] = [] - paged_kv_indptr: List[int] = [0] - paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: return PreparePromptMetadata.empty() @@ -321,12 +322,6 @@ def _prepare_prompt( # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] - paged_kv_indices.extend(block_table) - paged_kv_indptr.append(len(block_table)) - last_len = seq_data.get_len() % self.block_size - if last_len == 0: - last_len = self.block_size - paged_kv_last_page_len.append(last_len) # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, prompt_len - sliding_window). @@ -414,16 +409,7 @@ def _prepare_prompt( seq_start_loc=seq_start_loc, context_lens=context_lens_tensor, block_tables=block_tables, - use_cuda_graph=False, - paged_kv_indices=paged_kv_indices, - paged_kv_indptr=paged_kv_indptr, - paged_kv_last_page_len=paged_kv_last_page_len, - num_qo_heads=self.model_config.get_num_attention_heads(), - num_kv_heads=self.model_config.get_num_kv_heads( - self.parallel_config), - head_dim=self.model_config.get_head_size(), - # block_size = self.block_size - ) + use_cuda_graph=False) return PreparePromptMetadata( input_tokens=input_tokens, @@ -495,7 +481,7 @@ def _prepare_decode( block_tables.append(block_table) paged_kv_indices.extend(block_table) - paged_kv_indptr.append(len(block_table)) + paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table)) last_len = seq_data.get_len() % self.block_size if last_len == 0: last_len = self.block_size @@ -551,6 +537,42 @@ def _prepare_decode( device=self.device, ) + flashinfer_wrapper = None + if self.attn_backend is FlashInferBackend: + # Lazy import to avoid repetitive import + try: + flashinfer + except NameError: + import flashinfer + + if self.workspace_buffer is None: + self.workspace_buffer = torch.empty(16 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + flashinfer_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, "NHD") + paged_kv_indptr = torch.tensor(paged_kv_indptr, + dtype=torch.int, + device=self.device) + paged_kv_indices = torch.tensor(paged_kv_indices, + dtype=torch.int, + device=self.device) + paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len, + dtype=torch.int, + device=self.device) + flashinfer_wrapper.begin_forward( + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + self.model_config.get_num_attention_heads(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_head_size(), + self.block_size, + pos_encoding_mode= + "NONE", # FIXME: Add support for pos_encoding_mode + data_type=torch.float16 # FIXME: Add support for data_type + ) + attn_metadata = self.attn_backend.make_metadata( is_prompt=False, prompt_lens=None, @@ -563,13 +585,7 @@ def _prepare_decode( context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, - paged_kv_indices=paged_kv_indices, - paged_kv_indptr=paged_kv_indptr, - paged_kv_last_page_len=paged_kv_last_page_len, - num_qo_heads=self.model_config.get_num_attention_heads(), - num_kv_heads=self.model_config.get_num_kv_heads(), - head_dim=self.model_config.get_head_size(), - block_size=self.block_size) + 
wrapper=flashinfer_wrapper) return PrepareDecodeMetadata( input_tokens=input_tokens, input_positions=input_positions, From 6b26d3b874b32460fbd0ef4d39d1ce61c058c1db Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 25 Apr 2024 06:00:48 +0000 Subject: [PATCH 04/19] minor --- vllm/attention/selector.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 802bcfbded72..906743d9e409 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -76,15 +76,6 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS - try: - import flash_attn # noqa: F401 - except ImportError: - logger.info( - "Cannot use FlashAttention backend because the flash_attn package " - "is not found. Please install it for better performance.") - return _Backend.XFORMERS - - # Move flahsinfer below flash_attn for now to avoid breaking tests try: import flashinfer # Still use flash attention for the prefill stage. @@ -95,6 +86,14 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: "Cannot use Flashinfer backend because the flashinfer package " "is not found. Please install it for better performance.") + try: + import flash_attn # noqa: F401 + except ImportError: + logger.info( + "Cannot use FlashAttention backend because the flash_attn package " + "is not found. Please install it for better performance.") + return _Backend.XFORMERS + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) if backend_by_env_var is not None: return _Backend[backend_by_env_var] From e7817989c3e9627073820fc817d930419c355c3b Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 25 Apr 2024 22:55:24 +0000 Subject: [PATCH 05/19] remove attn backend interface change --- vllm/attention/backends/abstract.py | 8 --- vllm/attention/backends/flash_attn.py | 4 +- vllm/attention/backends/flashinfer.py | 21 +++--- vllm/attention/backends/rocm_flash_attn.py | 4 +- vllm/attention/backends/torch_sdpa.py | 4 +- vllm/attention/backends/xformers.py | 4 +- vllm/attention/selector.py | 4 +- vllm/sequence.py | 3 +- vllm/utils.py | 6 +- vllm/worker/model_runner.py | 78 +++++++++++++--------- 10 files changed, 73 insertions(+), 63 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 96e3bb98fb6f..7a4ccecf702f 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -59,14 +59,6 @@ def asdict_zerocopy(self) -> Dict[str, Any]: for field in fields(self) } - @classmethod - def new(cls, **kwargs) -> "AttentionMetadataPerStage": - """Create a new instance with updated attributes.""" - # filtering - cls_fields = [field.name for field in fields(cls)] - filtered_kwargs = {k: v for k, v in kwargs.items() if k in cls_fields} - return cls(**filtered_kwargs) - T = TypeVar("T", bound=AttentionMetadataPerStage) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 862fbe83de2e..12e8c4404b94 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -24,8 +24,8 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]: return FlashAttentionImpl @staticmethod - def make_metadata(**kwargs) -> "FlashAttentionMetadata": - return FlashAttentionMetadata.new(**kwargs) + def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata": + return FlashAttentionMetadata(*args, **kwargs) @staticmethod def get_kv_cache_shape( diff --git 
a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 3a841a5045cd..2ccb41073c4c 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,14 +1,15 @@ -from typing import Type, Tuple, List, Dict, Optional -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type -import torch import flashinfer -from vllm._C import cache_ops -from dataclasses import dataclass +import torch from flash_attn import flash_attn_varlen_func +from vllm._C import cache_ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataPerStage) + class FlashInferBackend(AttentionBackend): @@ -17,8 +18,8 @@ def get_impl_cls() -> Type["FlashInferImpl"]: return FlashInferImpl @staticmethod - def make_metadata(**kwargs) -> "FlashInferMetadata": - return FlashInferMetadata.new(**kwargs) + def make_metadata(*args, **kwargs) -> "FlashInferMetadata": + return FlashInferMetadata(*args, **kwargs) @staticmethod def get_kv_cache_shape( @@ -55,6 +56,8 @@ class FlashInferMetadata(AttentionMetadataPerStage): use_cuda_graph: bool = False wrapper: Optional[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = None + + # Metadata for prefill stage since we still use flash attention for prefill. seq_start_loc: Optional[torch.Tensor] = None max_prompt_len: Optional[int] = None diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 182dc855890b..dbaa71fd16ad 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -22,8 +22,8 @@ def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: return ROCmFlashAttentionImpl @staticmethod - def make_metadata(**kwargs) -> "ROCmFlashAttentionMetadata": - return ROCmFlashAttentionMetadata.new(**kwargs) + def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata": + return ROCmFlashAttentionMetadata(*args, **kwargs) @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index dc0dcc690cc6..d21b54b16db4 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -20,8 +20,8 @@ def get_impl_cls() -> Type["TorchSDPABackendImpl"]: return TorchSDPABackendImpl @staticmethod - def make_metadata(**kwargs) -> "TorchSDPAMetadata": - return TorchSDPAMetadata.new(**kwargs) + def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata": + return TorchSDPAMetadata(*args, **kwargs) @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b2c59f38b55b..b745a04a143b 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -25,8 +25,8 @@ def get_impl_cls() -> Type["XFormersImpl"]: return XFormersImpl @staticmethod - def make_metadata(**kwargs) -> "XFormersMetadata": - return XFormersMetadata.new(**kwargs) + def make_metadata(*args, **kwargs) -> "XFormersMetadata": + return XFormersMetadata(*args, **kwargs) @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 906743d9e409..c0ab3e49154a 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -77,9 +77,9 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: return _Backend.XFORMERS 
try: - import flashinfer # Still use flash attention for the prefill stage. - import flash_attn + import flash_attn # noqa: F401 + import flashinfer # noqa: F401 return _Backend.FLASHINER except ImportError: logger.info( diff --git a/vllm/sequence.py b/vllm/sequence.py index c36b9567b549..df2038d15413 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -557,7 +557,8 @@ class SequenceGroupMetadata: token_chunk_size: The number of tokens to be processed (per sequence). None if chunking is not required. lora_request: LoRA request. - computed_block_nums: The block numbers that are already computed, used in prefix caching. + computed_block_nums: The block numbers that are already computed, + used in prefix caching. state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. """ diff --git a/vllm/utils.py b/vllm/utils.py index 93495be20bcc..e162ad856bfc 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -341,7 +341,7 @@ def _generate_random_fp8( del tensor_tmp -def _str_to_torch_dtype( +def get_kv_cache_torch_dtype( cache_dtype: Optional[Union[str, torch.dtype]], model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: if isinstance(cache_dtype, str): @@ -381,7 +381,7 @@ def create_kv_caches_with_random_flashinfer( if torch.cuda.is_available(): torch.cuda.manual_seed(seed) - torch_dtype = _str_to_torch_dtype(cache_dtype, model_dtype) + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) scale = head_size**-0.5 key_caches, value_caches = [], [] @@ -410,7 +410,7 @@ def create_kv_caches_with_random( if torch.cuda.is_available(): torch.cuda.manual_seed(seed) - torch_dtype = _str_to_torch_dtype(cache_dtype, model_dtype) + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0e5b0b05f727..c643b87900f4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -24,7 +24,8 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip, +from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, + get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad, maybe_expand_dim) @@ -398,18 +399,26 @@ def _prepare_prompt( dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - prompt_lens=prompt_lens, - prompt_lens_tensor=prompt_lens_tensor, - max_subquery_len=max_subquery_len, - max_context_len=None, - max_prompt_len=max_prompt_len, - subquery_start_loc=subquery_start_loc, - seq_start_loc=seq_start_loc, - context_lens=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False) + if self.attn_backend is FlashInferBackend: + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + use_cuda_graph=False, + wrapper=None, + seq_start_loc=seq_start_loc, + max_prompt_len=max_prompt_len) + else: + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + prompt_lens=prompt_lens, + prompt_lens_tensor=prompt_lens_tensor, + max_subquery_len=max_subquery_len, + max_context_len=None, + max_prompt_len=max_prompt_len, + subquery_start_loc=subquery_start_loc, + seq_start_loc=seq_start_loc, + 
context_lens=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False) return PreparePromptMetadata( input_tokens=input_tokens, @@ -541,7 +550,7 @@ def _prepare_decode( if self.attn_backend is FlashInferBackend: # Lazy import to avoid repetitive import try: - flashinfer + flashinfer # noqa: B018 except NameError: import flashinfer @@ -560,6 +569,8 @@ def _prepare_decode( paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len, dtype=torch.int, device=self.device) + kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, + self.model_config.dtype) flashinfer_wrapper.begin_forward( paged_kv_indptr, paged_kv_indices, @@ -568,24 +579,27 @@ def _prepare_decode( self.model_config.get_num_kv_heads(self.parallel_config), self.model_config.get_head_size(), self.block_size, - pos_encoding_mode= - "NONE", # FIXME: Add support for pos_encoding_mode - data_type=torch.float16 # FIXME: Add support for data_type - ) - - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - prompt_lens=None, - prompt_lens_tensor=None, - max_subquery_len=None, - max_context_len=max_context_len, - max_prompt_len=None, - subquery_start_loc=None, - seq_start_loc=None, - context_lens=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - wrapper=flashinfer_wrapper) + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + data_type=kv_cache_dtype) + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + use_cuda_graph=use_captured_graph, + wrapper=flashinfer_wrapper) + else: + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + prompt_lens=None, + prompt_lens_tensor=None, + max_subquery_len=None, + max_context_len=max_context_len, + max_prompt_len=None, + subquery_start_loc=None, + seq_start_loc=None, + context_lens=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + wrapper=flashinfer_wrapper) return PrepareDecodeMetadata( input_tokens=input_tokens, input_positions=input_positions, From 5d469b7e1e25ab3eb09df4d4564cd47e6067d6c9 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 25 Apr 2024 23:40:23 +0000 Subject: [PATCH 06/19] basic test --- tests/basic_correctness/test_flashinfer.py | 49 ++++++++++++++++++++++ vllm/attention/backends/flashinfer.py | 1 + vllm/attention/selector.py | 3 +- 3 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 tests/basic_correctness/test_flashinfer.py diff --git a/tests/basic_correctness/test_flashinfer.py b/tests/basic_correctness/test_flashinfer.py new file mode 100644 index 000000000000..9a2fd5a82949 --- /dev/null +++ b/tests/basic_correctness/test_flashinfer.py @@ -0,0 +1,49 @@ +"""Compare the short outputs of HF and vLLM when using greedy sampling. + +Run `pytest tests/basic_correctness/test_flashinfer.py`. +""" +import pytest +import os + +MODELS = [ + "facebook/opt-125m", + "meta-llama/Llama-2-7b-hf", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("enforce_eager", [True]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + enforce_eager: bool, +) -> None: + try: + import flash_attn # noqa: F401 + import flashinfer # noqa: F401 + except ImportError: + pytest.skip("Cannot use Flashinfer backend because the flashinfer package " + "is not found. 
Please install both flashinfer and flash attention " + "for running the test.") + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 2ccb41073c4c..7dd61ef96e19 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -117,6 +117,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, ) else: assert attn_metadata.decode_metadata is not None + assert attn_metadata.decode_metadata.wrapper is not None query = query.contiguous( ) # Flashinfer requires query to be contiguous output = attn_metadata.decode_metadata.wrapper.forward( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index c0ab3e49154a..bf6b94b6a10d 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -84,7 +84,8 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: except ImportError: logger.info( "Cannot use Flashinfer backend because the flashinfer package " - "is not found. Please install it for better performance.") + "is not found. Please install both flashinfer and flash attention " + "for better performance.") try: import flash_attn # noqa: F401 From 883be3b7b5f050cda5c65143c67495703bb88511 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 28 Apr 2024 06:09:41 +0000 Subject: [PATCH 07/19] fix distributed tests --- tests/basic_correctness/test_flashinfer.py | 8 +-- .../test_flashinfer_distributed.py | 65 +++++++++++++++++++ vllm/attention/backends/abstract.py | 11 ++-- vllm/attention/backends/flashinfer.py | 53 +++++++++++++-- vllm/config.py | 6 +- vllm/worker/model_runner.py | 35 ++++------ 6 files changed, 141 insertions(+), 37 deletions(-) create mode 100644 tests/distributed/test_flashinfer_distributed.py diff --git a/tests/basic_correctness/test_flashinfer.py b/tests/basic_correctness/test_flashinfer.py index 9a2fd5a82949..ed65d062c89c 100644 --- a/tests/basic_correctness/test_flashinfer.py +++ b/tests/basic_correctness/test_flashinfer.py @@ -3,7 +3,6 @@ Run `pytest tests/basic_correctness/test_flashinfer.py`. """ import pytest -import os MODELS = [ "facebook/opt-125m", @@ -28,9 +27,10 @@ def test_models( import flash_attn # noqa: F401 import flashinfer # noqa: F401 except ImportError: - pytest.skip("Cannot use Flashinfer backend because the flashinfer package " - "is not found. Please install both flashinfer and flash attention " - "for running the test.") + pytest.skip( + "Cannot use Flashinfer backend because the flashinfer package " + "is not found. 
Please install both flashinfer and flash attention " + "for running the test.") hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/distributed/test_flashinfer_distributed.py b/tests/distributed/test_flashinfer_distributed.py new file mode 100644 index 000000000000..c6bdccded07b --- /dev/null +++ b/tests/distributed/test_flashinfer_distributed.py @@ -0,0 +1,65 @@ +"""Compare the outputs of HF and distributed vLLM when using greedy sampling. +vLLM will allocate all the available memory, so we need to run the tests one +by one. The solution is to pass arguments (model name) by environment +variables. +Run: +```sh +TEST_DIST_MODEL=facebook/opt-125m pytest \ + test_flashinfer_distributed.py +TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ + test_flashinfer_distributed.py +``` +""" +import os + +import pytest +import torch + +MODELS = [ + os.environ["TEST_DIST_MODEL"], +] + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("enforce_eager", [True]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + enforce_eager: bool +) -> None: + try: + import flash_attn # noqa: F401 + import flashinfer # noqa: F401 + except ImportError: + pytest.skip("Cannot use Flashinfer backend because the flashinfer package " + "is not found. Please install both flashinfer and flash attention " + "for running the test.") + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + enforce_eager=enforce_eager + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 7a4ccecf702f..83b9350e58e5 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar +from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, + TypeVar) import torch @@ -15,7 +16,7 @@ def get_impl_cls() -> Type["AttentionImpl"]: @staticmethod @abstractmethod - def make_metadata(*args, **kwargs) -> "AttentionMetadata": + def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage": raise NotImplementedError @staticmethod @@ -50,13 +51,15 @@ def copy_blocks( class AttentionMetadataPerStage: """Attention metadata for a specific stage. I.e., prefill or decode.""" - def asdict_zerocopy(self) -> Dict[str, Any]: + def asdict_zerocopy(self, skip_fileds: Set[str] = None) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" + if skip_fileds is None: + skip_fileds = set() # Note that if we add dataclasses as fields, they will need # similar handling. 
return { field.name: getattr(self, field.name) - for field in fields(self) + for field in fields(self) if field.name not in skip_fileds } diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 7dd61ef96e19..e4314c849fb2 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import flashinfer import torch @@ -55,12 +55,55 @@ class FlashInferMetadata(AttentionMetadataPerStage): use_cuda_graph: bool = False - wrapper: Optional[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = None + decode_wrapper: Optional[ + flashinfer.BatchDecodeWithPagedKVCacheWrapper] = None - # Metadata for prefill stage since we still use flash attention for prefill. + # Metadata for the prefill stage since we still + # use flash attention for prefill. seq_start_loc: Optional[torch.Tensor] = None max_prompt_len: Optional[int] = None + # Metadata for the decode stage + # Workspace buffer required by the kernel, the buffer should not + # be allocated/deacollated by the FalshInfermetadata + workspace_buffer: Optional[torch.Tensor] = None + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads + num_qo_heads: Optional[int] = None + # The number of key/value heads + num_kv_heads: Optional[int] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + # Block size of vllm + page_size: Optional[int] = None + # The data type of the paged kv cache + data_type: torch.dtype = None + + def __post_init__(self): + if not self.is_prompt: + self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, "NHD") + self.decode_wrapper.begin_forward( + self.paged_kv_indptr, + self.paged_kv_indices, + self.paged_kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. 
+ pos_encoding_mode="NONE", + data_type=self.data_type) + + def asdict_zerocopy(self) -> Dict[str, Any]: + return super().asdict_zerocopy({'decode_wrapper'}) + class FlashInferImpl(AttentionImpl): @@ -117,10 +160,10 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, ) else: assert attn_metadata.decode_metadata is not None - assert attn_metadata.decode_metadata.wrapper is not None + assert attn_metadata.decode_metadata.decode_wrapper is not None query = query.contiguous( ) # Flashinfer requires query to be contiguous - output = attn_metadata.decode_metadata.wrapper.forward( + output = attn_metadata.decode_metadata.decode_wrapper.forward( query, kv_cache, sm_scale=self.scale, diff --git a/vllm/config.py b/vllm/config.py index c3bdf23e3bf7..67c01f257214 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -271,8 +271,10 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) - def get_num_attention_heads(self) -> int: - return self.hf_text_config.num_attention_heads + def get_num_attention_heads(self, + parallel_config: "ParallelConfig") -> int: + return self.hf_text_config.num_attention_heads // \ + parallel_config.tensor_parallel_size def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c643b87900f4..6e04f0f3c4de 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -403,7 +403,6 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, use_cuda_graph=False, - wrapper=None, seq_start_loc=seq_start_loc, max_prompt_len=max_prompt_len) else: @@ -548,18 +547,10 @@ def _prepare_decode( flashinfer_wrapper = None if self.attn_backend is FlashInferBackend: - # Lazy import to avoid repetitive import - try: - flashinfer # noqa: B018 - except NameError: - import flashinfer - if self.workspace_buffer is None: self.workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device=self.device) - flashinfer_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - self.workspace_buffer, "NHD") paged_kv_indptr = torch.tensor(paged_kv_indptr, dtype=torch.int, device=self.device) @@ -571,21 +562,21 @@ def _prepare_decode( device=self.device) kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, self.model_config.dtype) - flashinfer_wrapper.begin_forward( - paged_kv_indptr, - paged_kv_indices, - paged_kv_last_page_len, - self.model_config.get_num_attention_heads(), - self.model_config.get_num_kv_heads(self.parallel_config), - self.model_config.get_head_size(), - self.block_size, - # Disable flashinfer's pos encoding and use vllm's rope. 
- pos_encoding_mode="NONE", - data_type=kv_cache_dtype) + attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - use_cuda_graph=use_captured_graph, - wrapper=flashinfer_wrapper) + use_cuda_graph=False, + workspace_buffer=self.workspace_buffer, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=self.model_config.get_num_attention_heads( + self.parallel_config), + num_kv_heads=self.model_config.get_num_kv_heads( + self.parallel_config), + head_dim=self.model_config.get_head_size(), + page_size=self.block_size, + data_type=kv_cache_dtype) else: attn_metadata = self.attn_backend.make_metadata( is_prompt=False, From 710962211e8fdd4cd3d1a6653741be74f3a52170 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 30 Apr 2024 01:02:09 +0000 Subject: [PATCH 08/19] try to fix ci --- csrc/cache_kernels.cu | 2 +- requirements-cuda.txt | 2 ++ tests/kernels/conftest.py | 4 +-- vllm/_custom_ops.py | 12 +++++++++ vllm/attention/backends/flashinfer.py | 30 +++++++++++++-------- vllm/utils.py | 2 +- vllm/worker/model_runner.py | 39 ++++++++++++++++++--------- 7 files changed, 63 insertions(+), 28 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index a34a6466469c..5e7dc2db909f 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -318,7 +318,7 @@ void reshape_and_cache_flash( torch::Tensor& slot_mapping, // [num_tokens] const std::string& kv_cache_dtype) { - if (kv_cache_dtype != "auto") { + if (kv_cache_dtype != "auto" || kv_cache_dtype == "fp8" ) { TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); } int num_tokens = key.size(0); diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 1bddae4c6f40..534fd8c51265 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,3 +7,5 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 xformers == 0.0.25 # Requires PyTorch 2.2.1 +--extra-index-url https://flashinfer.ai/whl/cu121/torch2.2/ +flashinfer \ No newline at end of file diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index 4fb04cc7e4f3..4f2f9cc3dac7 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -1,7 +1,7 @@ import pytest from vllm.utils import (create_kv_caches_with_random, - create_kv_caches_with_random_flashinfer) + create_kv_caches_with_random_flash) @pytest.fixture() @@ -11,4 +11,4 @@ def kv_cache_factory(): @pytest.fixture() def kv_cache_factory_flashinfer(): - return create_kv_caches_with_random_flashinfer + return create_kv_caches_with_random_flash diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 5ba104bada7a..a79e0fc486eb 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -204,6 +204,18 @@ def reshape_and_cache( slot_mapping, kv_cache_dtype, kv_scale) +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, +) -> None: + vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) + + def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, block_mapping: torch.Tensor) -> None: vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 41d982077c35..bc672644b3d9 100644 --- a/vllm/attention/backends/flashinfer.py 
+++ b/vllm/attention/backends/flashinfer.py @@ -5,7 +5,7 @@ import torch from flash_attn import flash_attn_varlen_func -from vllm._C import cache_ops +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) @@ -49,8 +49,6 @@ def copy_blocks( @dataclass class FlashInferMetadata(AttentionMetadataPerStage): - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. is_prompt: bool use_cuda_graph: bool = False @@ -106,6 +104,8 @@ def asdict_zerocopy(self, ) -> Dict[str, Any]: if skip_fields is None: skip_fields = set() + # We need to skip the decode_wrapper field since it cannot be + # broadcasted with nccl when TP is enabled. skip_fields.add('decode_wrapper') return super().asdict_zerocopy(skip_fields) @@ -121,8 +121,9 @@ def __init__( alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, ) -> None: - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) + if sliding_window is not None: + raise ValueError("Sliding window is not supported in FlashInfer.") + self.sliding_window = (-1, -1) self.alibi_slopes = alibi_slopes self.scale = scale self.num_heads = num_heads @@ -138,9 +139,16 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) + if attn_metadata.num_prefill_tokens > 0: + assert attn_metadata.num_decode_tokens == 0, ( + "chunked prefill is not supported with flash infer yet") + if attn_metadata.num_decode_tokens > 0: + assert attn_metadata.num_prefill_tokens == 0, ( + "chunked prefill is not supported with flash infer yet") + if kv_cache is not None: # Use the same reshape and cache kernel as flash attention. - cache_ops.reshape_and_cache_flash( + ops.reshape_and_cache_flash( key, value, kv_cache[:, 0], @@ -149,15 +157,15 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, attn_metadata.kv_cache_dtype, ) - if attn_metadata.prefill_metadata: + if prefill_metadata := attn_metadata.prefill_metadata: output = flash_attn_varlen_func( q=query, k=key, v=value, - cu_seqlens_q=attn_metadata.prefill_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.prefill_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.prefill_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.prefill_metadata.max_prompt_len, + cu_seqlens_q=prefill_metadata.seq_start_loc, + cu_seqlens_k=prefill_metadata.seq_start_loc, + max_seqlen_q=prefill_metadata.max_prompt_len, + max_seqlen_k=prefill_metadata.max_prompt_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, diff --git a/vllm/utils.py b/vllm/utils.py index 62abf0d23354..7284250c7f63 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -377,7 +377,7 @@ def get_kv_cache_torch_dtype( return torch_dtype -def create_kv_caches_with_random_flashinfer( +def create_kv_caches_with_random_flash( num_blocks: int, block_size: int, num_layers: int, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c74befe67803..aeb70fc49c53 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -158,7 +158,7 @@ def __init__( self.graph_block_tables: torch.Tensor # Set after initial profiling. # Set if the backend is flashinfer. 
- self.workspace_buffer = None + self.flashinfer_workspace_buffer: torch.Tensor def load_model(self) -> None: with CudaMemoryProfiler() as m: @@ -443,7 +443,21 @@ def _prepare_decode( lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + + # The following fields are only for flashinfer + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # paged_kv_last_page_len is the length of the last page of each request paged_kv_indices: List[int] = [] + # 0 is always in paged_kv_indptr because the paged_kv_indptr: List[int] = [0] paged_kv_last_page_len: List[int] = [] @@ -489,10 +503,10 @@ def _prepare_decode( paged_kv_indices.extend(block_table) paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table)) - last_len = seq_data.get_len() % self.block_size - if last_len == 0: - last_len = self.block_size - paged_kv_last_page_len.append(last_len) + last_page_len = seq_data.get_len() % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + paged_kv_last_page_len.append(last_page_len) # vLLM uses cuda graph only for decoding requests. # See `capture_model` API for more details. @@ -544,12 +558,12 @@ def _prepare_decode( device=self.device, ) - flashinfer_wrapper = None if self.attn_backend is FlashInferBackend: - if self.workspace_buffer is None: - self.workspace_buffer = torch.empty(16 * 1024 * 1024, - dtype=torch.uint8, - device=self.device) + if not hasattr(self, "flashinfer_workspace_buffer"): + # Allocate 16MB workspace buffer + # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html + self.flashinfer_workspace_buffer = torch.empty( + 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) paged_kv_indptr = torch.tensor(paged_kv_indptr, dtype=torch.int, device=self.device) @@ -565,7 +579,7 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, use_cuda_graph=False, - workspace_buffer=self.workspace_buffer, + workspace_buffer=self.flashinfer_workspace_buffer, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, @@ -588,8 +602,7 @@ def _prepare_decode( seq_start_loc=None, context_lens=context_lens_tensor, block_tables=block_tables, - use_cuda_graph=use_captured_graph, - wrapper=flashinfer_wrapper) + use_cuda_graph=use_captured_graph) return PrepareDecodeMetadata( input_tokens=input_tokens, input_positions=input_positions, From ab00582d34ea5e425b88c39f9c20d35fc497d6fc Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 30 Apr 2024 16:53:44 +0000 Subject: [PATCH 09/19] fix comments --- .buildkite/test-pipeline.yaml | 3 + csrc/cache_kernels.cu | 3 +- .../test_basic_correctness.py | 12 +++- tests/basic_correctness/test_flashinfer.py | 49 ---------------- .../test_basic_distributed_correctness.py | 14 +++-- .../test_flashinfer_distributed.py | 57 ------------------- vllm/attention/selector.py | 15 +---- 7 files changed, 27 insertions(+), 126 deletions(-) delete mode 100644 tests/basic_correctness/test_flashinfer.py delete mode 100644 tests/distributed/test_flashinfer_distributed.py diff --git 
a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 11cda053260e..937b6728f602 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -15,6 +15,7 @@ steps: commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py @@ -34,6 +35,8 @@ steps: - pytest -v -s test_pynccl_library.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 5e7dc2db909f..1b44871b0051 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -318,7 +318,8 @@ void reshape_and_cache_flash( torch::Tensor& slot_mapping, // [num_tokens] const std::string& kv_cache_dtype) { - if (kv_cache_dtype != "auto" || kv_cache_dtype == "fp8" ) { + // FIXME: only suport auto datatype, does not support fp8 + if (kv_cache_dtype != "auto") { TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); } int num_tokens = key.size(0); diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 97cff623c5e1..d75279dd9cfa 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -2,12 +2,15 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ +import os + import pytest MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", ] +VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" @pytest.mark.parametrize("model", MODELS) @@ -23,11 +26,18 @@ def test_models( max_tokens: int, enforce_eager: bool, ) -> None: + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + if backend_by_env_var == "FLASHINFER" and enforce_eager is False: + pytest.skip("Skipping non-eager test for FlashInferBackend.") + hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) + vllm_model = vllm_runner(model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/basic_correctness/test_flashinfer.py b/tests/basic_correctness/test_flashinfer.py deleted file mode 100644 index ed65d062c89c..000000000000 --- a/tests/basic_correctness/test_flashinfer.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Compare the short outputs of HF and vLLM when using greedy sampling. - -Run `pytest tests/basic_correctness/test_flashinfer.py`. 
-""" -import pytest - -MODELS = [ - "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", -] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [True]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - enforce_eager: bool, -) -> None: - try: - import flash_attn # noqa: F401 - import flashinfer # noqa: F401 - except ImportError: - pytest.skip( - "Cannot use Flashinfer backend because the flashinfer package " - "is not found. Please install both flashinfer and flash attention " - "for running the test.") - - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - del hf_model - - vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - del vllm_model - - for i in range(len(example_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 77aa90b12bf8..527452630c9f 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -18,6 +18,7 @@ MODELS = [ os.environ["TEST_DIST_MODEL"], ] +VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -33,16 +34,19 @@ def test_models( dtype: str, max_tokens: int, ) -> None: + enforce_eager = False + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + if backend_by_env_var == "FLASHINFER": + enforce_eager = True hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - ) + vllm_model = vllm_runner(model, + dtype=dtype, + tensor_parallel_size=2, + enforce_eager=enforce_eager) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_flashinfer_distributed.py b/tests/distributed/test_flashinfer_distributed.py deleted file mode 100644 index 5552841acef2..000000000000 --- a/tests/distributed/test_flashinfer_distributed.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Compare the outputs of HF and distributed vLLM when using greedy sampling. -vLLM will allocate all the available memory, so we need to run the tests one -by one. The solution is to pass arguments (model name) by environment -variables. 
-Run: -```sh -TEST_DIST_MODEL=facebook/opt-125m pytest \ - test_flashinfer_distributed.py -TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ - test_flashinfer_distributed.py -``` -""" -import os - -import pytest -import torch - -MODELS = [ - os.environ["TEST_DIST_MODEL"], -] - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [True]) -def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, enforce_eager: bool) -> None: - try: - import flash_attn # noqa: F401 - import flashinfer # noqa: F401 - except ImportError: - pytest.skip( - "Cannot use Flashinfer backend because the flashinfer package " - "is not found. Please install both flashinfer and flash attention " - "for running the test.") - - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - del hf_model - - vllm_model = vllm_runner(model, - dtype=dtype, - tensor_parallel_size=2, - enforce_eager=enforce_eager) - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - del vllm_model - - for i in range(len(example_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index c0d3c43e45f0..ca9d16711f8b 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -19,7 +19,7 @@ class _Backend(enum.Enum): XFORMERS = enum.auto() ROCM_FLASH = enum.auto() TORCH_SDPA = enum.auto() - FLASHINER = enum.auto() + FLASHINFER = enum.auto() @lru_cache(maxsize=None) @@ -44,7 +44,7 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: logger.info("Using Torch SDPA backend.") from vllm.attention.backends.torch_sdpa import TorchSDPABackend return TorchSDPABackend - elif backend == _Backend.FLASHINER: + elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend @@ -76,17 +76,6 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS - try: - # Still use flash attention for the prefill stage. - import flash_attn # noqa: F401 - import flashinfer # noqa: F401 - return _Backend.FLASHINER - except ImportError: - logger.info( - "Cannot use Flashinfer backend because the flashinfer package " - "is not found. 
Please install both flashinfer and flash attention " - "for better performance.") - try: import flash_attn # noqa: F401 except ImportError: From 67fe4fd2dfd4661dae093d84408d13a592d65703 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 30 Apr 2024 16:55:55 +0000 Subject: [PATCH 10/19] typo --- csrc/cache_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1b44871b0051..42f884c76c62 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -318,7 +318,7 @@ void reshape_and_cache_flash( torch::Tensor& slot_mapping, // [num_tokens] const std::string& kv_cache_dtype) { - // FIXME: only suport auto datatype, does not support fp8 + // FIXME: only support auto datatype, does not support fp8 if (kv_cache_dtype != "auto") { TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); } From 9ed5be0b39ab6c8d2e54971bb1c663df9f597a37 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 30 Apr 2024 17:53:10 +0000 Subject: [PATCH 11/19] raise exception for prefix caching --- vllm/attention/backends/flashinfer.py | 54 ++++++++++++++++++--------- vllm/worker/model_runner.py | 8 ++-- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index bc672644b3d9..8db8e558b6bc 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -60,12 +60,21 @@ class FlashInferMetadata(AttentionMetadataPerStage): # use flash attention for prefill. seq_start_loc: Optional[torch.Tensor] = None max_prompt_len: Optional[int] = None + block_tables: Optional[torch.Tensor] = None # Metadata for the decode stage # Workspace buffer required by the kernel, the buffer should not # be allocated/deacollated by the FalshInfermetadata workspace_buffer: Optional[torch.Tensor] = None - # The indptr of the paged kv cache, shape: [batch_size + 1] + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] paged_kv_indptr: Optional[torch.Tensor] = None # The page indices of the paged kv cache paged_kv_indices: Optional[torch.Tensor] = None @@ -84,6 +93,12 @@ class FlashInferMetadata(AttentionMetadataPerStage): data_type: torch.dtype = None def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + if self.head_dim is not None and self.head_dim not in [64, 128, 256]: + raise ValueError("Only [64, 128, 256] are supported for head_dim,", + f"received {self.head_dim}.") + if not self.is_prompt: self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, "NHD") @@ -141,10 +156,10 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, if attn_metadata.num_prefill_tokens > 0: assert attn_metadata.num_decode_tokens == 0, ( - "chunked prefill is not supported with flash infer yet") + "Chunked prefill is not supported with flashinfer yet.") if attn_metadata.num_decode_tokens > 0: assert attn_metadata.num_prefill_tokens == 0, ( - "chunked prefill is not supported with flash infer yet") + "Chunked prefill is not supported 
with flashinfer yet.") if kv_cache is not None: # Use the same reshape and cache kernel as flash attention. @@ -157,20 +172,25 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, attn_metadata.kv_cache_dtype, ) - if prefill_metadata := attn_metadata.prefill_metadata: - output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_metadata.seq_start_loc, - cu_seqlens_k=prefill_metadata.seq_start_loc, - max_seqlen_q=prefill_metadata.max_prompt_len, - max_seqlen_k=prefill_metadata.max_prompt_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) + if prefill_meta := attn_metadata.prefill_metadata: + assert prefill_meta.block_tables is not None + if kv_cache is None or prefill_meta.block_tables.numel() == 0: + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prompt_len, + max_seqlen_k=prefill_meta.max_prompt_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + else: + raise NotImplementedError( + "Prefix caching is not supported with flashinfer yet.") else: assert attn_metadata.decode_metadata is not None assert attn_metadata.decode_metadata.decode_wrapper is not None diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index aeb70fc49c53..cfaa72bd03aa 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -403,7 +403,8 @@ def _prepare_prompt( is_prompt=True, use_cuda_graph=False, seq_start_loc=seq_start_loc, - max_prompt_len=max_prompt_len) + max_prompt_len=max_prompt_len, + block_tables=block_tables) else: attn_metadata = self.attn_backend.make_metadata( is_prompt=True, @@ -455,10 +456,11 @@ def _prepare_decode( # [0, 5, 8, 1, 6, 7, 3, 4] # paged_kv_indptr is used to index into paged_kv_indices: # [0, 3, 6, 8] - # paged_kv_last_page_len is the length of the last page of each request paged_kv_indices: List[int] = [] - # 0 is always in paged_kv_indptr because the + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. 
paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: From 0ffddd7edf6a5ce755113f2259a6e731ea1c8790 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 30 Apr 2024 18:48:09 +0000 Subject: [PATCH 12/19] try requirements --- requirements-cuda.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 60dd4326d46a..b20f571d7be1 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,5 +7,6 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 + --extra-index-url https://flashinfer.ai/whl/cu121/torch2.2/ flashinfer From 7f261df2695f49f3a60a53d936b958e5e77bd48d Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 1 May 2024 00:11:40 +0000 Subject: [PATCH 13/19] manual install flashinfer --- Dockerfile | 2 ++ requirements-cuda.txt | 3 --- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index e471a6e93b96..19bd483661d0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -94,6 +94,8 @@ WORKDIR /usr/src/flash-attention-v2 RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ --no-build-isolation --no-deps --no-cache-dir +# Flashinfer backend +RUN pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.2/ #################### FLASH_ATTENTION Build IMAGE #################### #################### vLLM installation IMAGE #################### diff --git a/requirements-cuda.txt b/requirements-cuda.txt index b20f571d7be1..6548d7a6684b 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,6 +7,3 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 - ---extra-index-url https://flashinfer.ai/whl/cu121/torch2.2/ -flashinfer From 5b7f02942475a73e2a59022d79321b3944dcc635 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 1 May 2024 04:30:12 +0000 Subject: [PATCH 14/19] change dockerfile package --- Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 19bd483661d0..a55a06bcb892 100644 --- a/Dockerfile +++ b/Dockerfile @@ -93,9 +93,6 @@ WORKDIR /usr/src/flash-attention-v2 # Download the wheel or build it if a pre-compiled release doesn't exist RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ --no-build-isolation --no-deps --no-cache-dir - -# Flashinfer backend -RUN pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.2/ #################### FLASH_ATTENTION Build IMAGE #################### #################### vLLM installation IMAGE #################### @@ -120,6 +117,9 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ --mount=type=cache,target=/root/.cache/pip \ pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir + +# Flashinfer backend +RUN pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.2/ #################### vLLM installation IMAGE #################### From 470df94e22c912bc7864bf4199e3a85fb561f08e Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 2 May 2024 18:08:37 +0000 Subject: [PATCH 15/19] remove test --- .buildkite/test-pipeline.yaml | 3 --- 1 file changed, 3 deletions(-) 
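For reference, the paged KV bookkeeping described in the comments above (`paged_kv_indices`, `paged_kv_indptr`, `paged_kv_last_page_len`) can be reproduced in isolation. The following standalone sketch is not part of the patches; it mirrors the loop in `_prepare_decode`, and the block tables, sequence lengths, and block size are made-up inputs chosen to match the three-request example in the comments.

```python
# Standalone sketch of the paged KV layout used by the FlashInfer decode path.
# `block_tables`, `seq_lens`, and `block_size` are illustrative inputs only.
from typing import List, Tuple


def build_paged_kv_metadata(
    block_tables: List[List[int]],  # page indices per request
    seq_lens: List[int],            # total tokens per request
    block_size: int,                # tokens per page
) -> Tuple[List[int], List[int], List[int]]:
    paged_kv_indices: List[int] = []
    paged_kv_indptr: List[int] = [0]   # 0 marks the start of the first request
    paged_kv_last_page_len: List[int] = []
    for block_table, seq_len in zip(block_tables, seq_lens):
        # Concatenate this request's page indices onto the flat list.
        paged_kv_indices.extend(block_table)
        paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
        # Number of valid tokens in the request's last (possibly partial) page.
        last_page_len = seq_len % block_size
        if last_page_len == 0:
            last_page_len = block_size
        paged_kv_last_page_len.append(last_page_len)
    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len


# Matches the example in the comments: three requests with page indices
# [0, 5, 8], [1, 6, 7], [3, 4], here with a block size of 16.
indices, indptr, last_lens = build_paged_kv_metadata(
    block_tables=[[0, 5, 8], [1, 6, 7], [3, 4]],
    seq_lens=[40, 47, 32],
    block_size=16,
)
assert indices == [0, 5, 8, 1, 6, 7, 3, 4]
assert indptr == [0, 3, 6, 8]
assert last_lens == [8, 15, 16]
```

The leading 0 in `paged_kv_indptr` is what lets `paged_kv_indices[paged_kv_indptr[i]:paged_kv_indptr[i + 1]]` slice out request `i`'s pages from the flat index list.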
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 937b6728f602..11cda053260e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -15,7 +15,6 @@ steps: commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py @@ -35,8 +34,6 @@ steps: - pytest -v -s test_pynccl_library.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py From 46ccb6310270c6003ec0eede8022157e4d7f1a90 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 2 May 2024 18:17:15 +0000 Subject: [PATCH 16/19] remove flashinfer from docker --- Dockerfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index a55a06bcb892..c9738ab97724 100644 --- a/Dockerfile +++ b/Dockerfile @@ -118,8 +118,6 @@ RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,ta --mount=type=cache,target=/root/.cache/pip \ pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir -# Flashinfer backend -RUN pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.2/ #################### vLLM installation IMAGE #################### From 14f7cefec49f12ccc2706c0f344a5916d7897de2 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 2 May 2024 18:24:54 +0000 Subject: [PATCH 17/19] revert docker change --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index c9738ab97724..e471a6e93b96 100644 --- a/Dockerfile +++ b/Dockerfile @@ -93,6 +93,7 @@ WORKDIR /usr/src/flash-attention-v2 # Download the wheel or build it if a pre-compiled release doesn't exist RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ --no-build-isolation --no-deps --no-cache-dir + #################### FLASH_ATTENTION Build IMAGE #################### #################### vLLM installation IMAGE #################### @@ -117,7 +118,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ --mount=type=cache,target=/root/.cache/pip \ pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir - #################### vLLM installation IMAGE #################### From e5245c06da29ef5bf324b9683f0d7510f718f821 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Fri, 3 May 2024 05:14:52 +0000 Subject: [PATCH 18/19] fix comments and import error --- vllm/attention/backends/flashinfer.py | 34 ++++++++++++++++++++------- vllm/attention/selector.py | 1 + 2 files changed, 26 insertions(+), 9 deletions(-) diff --git 
a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 8db8e558b6bc..4e4015049af9 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,9 +1,16 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Type -import flashinfer +try: + import flashinfer + from flash_attn import flash_attn_varlen_func + from flashinfer import BatchDecodeWithPagedKVCacheWrapper +except ImportError: + flashinfer = None + flash_attn_varlen_func = None + BatchDecodeWithPagedKVCacheWrapper = None + import torch -from flash_attn import flash_attn_varlen_func from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -45,6 +52,10 @@ def copy_blocks( ) -> None: raise NotImplementedError + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 128, 256] + @dataclass class FlashInferMetadata(AttentionMetadataPerStage): @@ -53,8 +64,7 @@ class FlashInferMetadata(AttentionMetadataPerStage): use_cuda_graph: bool = False - decode_wrapper: Optional[ - flashinfer.BatchDecodeWithPagedKVCacheWrapper] = None + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None # Metadata for the prefill stage since we still # use flash attention for prefill. @@ -64,7 +74,7 @@ class FlashInferMetadata(AttentionMetadataPerStage): # Metadata for the decode stage # Workspace buffer required by the kernel, the buffer should not - # be allocated/deacollated by the FalshInfermetadata + # be allocated/deacollated by the FalshInfermetadata object. workspace_buffer: Optional[torch.Tensor] = None # An example for paged_kv_indices, paged_kv_indptr: # request 1, page indices [0, 5, 8] @@ -95,10 +105,16 @@ class FlashInferMetadata(AttentionMetadataPerStage): def __post_init__(self): # Refer to # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 - if self.head_dim is not None and self.head_dim not in [64, 128, 256]: - raise ValueError("Only [64, 128, 256] are supported for head_dim,", - f"received {self.head_dim}.") - + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + # When using flashinfer, we are also creating the FlashInferMetadata, + # which will also call post_init by default, here we want to skip the + # post_init if it's the prefill phase. if not self.is_prompt: self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, "NHD") diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index ca9d16711f8b..e0190fd56685 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -46,6 +46,7 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: return TorchSDPABackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") + logger.warning("Eager mode is enforced for the Flashinfer backend. 
") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend else: From fb7615dfbaac2c1c77600e90b99d50d0759a41b6 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Fri, 3 May 2024 20:06:52 +0000 Subject: [PATCH 19/19] minor --- vllm/attention/backends/flashinfer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 4e4015049af9..8ab4b1f12ee3 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -69,7 +69,7 @@ class FlashInferMetadata(AttentionMetadataPerStage): # Metadata for the prefill stage since we still # use flash attention for prefill. seq_start_loc: Optional[torch.Tensor] = None - max_prompt_len: Optional[int] = None + max_seq_len: Optional[int] = None block_tables: Optional[torch.Tensor] = None # Metadata for the decode stage @@ -197,8 +197,8 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prompt_len, - max_seqlen_k=prefill_meta.max_prompt_len, + max_seqlen_q=prefill_meta.max_seq_len, + max_seqlen_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window,