From 06fe87293715e79aac88dfe75963fa1ec600f81f Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 27 Feb 2024 22:55:16 -0800 Subject: [PATCH 01/28] [1/n] Support efficient reshape caching. --- csrc/cache.h | 8 ++++ csrc/cache_kernels.cu | 89 +++++++++++++++++++++++++++++++++++++ csrc/pybind.cpp | 4 ++ tests/kernels/test_cache.py | 85 +++++++++++++++++++++++++++++++++++ vllm/utils.py | 13 +++++- 5 files changed, 197 insertions(+), 2 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 765e231abd26..ce3f96d59e83 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -23,6 +23,14 @@ void reshape_and_cache( torch::Tensor& slot_mapping, const std::string& kv_cache_dtype); +void reshape_and_cache_flash( + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + torch::Tensor& num_tokens); + // Just for unittest void convert_fp8_e5m2( torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 7254010b8e3a..21057984e1e3 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -269,6 +269,95 @@ void reshape_and_cache( namespace vllm { +// flash-attention style cache funciton where key/value caches +// has the same shape of [num_blocks, block_size, num_heads, head_size] +template +__global__ void reshape_and_cache_flash_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, head_size] + scalar_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int64_t* __restrict__ num_tokens, // [1] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size) { + const int64_t num_tokens_ = num_tokens[0]; + const int64_t token_idx = blockIdx.x; + if (token_idx >= num_tokens_) { + return; + } + const int64_t slot_idx = slot_mapping[token_idx]; + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + + // Target idx for kv are the same. + const int64_t tgt_idx = block_idx * block_size * num_heads * head_size + + block_offset * num_heads * head_size + + head_idx * head_size + + head_offset; + + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + // TODO(sang): Support ENABLE_FP8_E5M2. 
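+    // Because key_cache and value_cache share the flash-style layout
+    // [num_blocks, block_size, num_heads, head_size], both writes below use
+    // the same flattened target index computed above.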
+ key_cache[tgt_idx] = tgt_key; + value_cache[tgt_idx] = tgt_value; + } +} + +} // namespace vllm + +// TODO(sang): Support kv_cache_dtype +void reshape_and_cache_flash( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + torch::Tensor& num_tokens) // [1] +{ + int num_tokens_padded = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(1); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + dim3 grid(num_tokens_padded); + dim3 block(std::min(num_heads * head_size, 512)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), + "reshape_and_cache_flash_kernel", + [&] { + vllm::reshape_and_cache_flash_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + slot_mapping.data_ptr(), + num_tokens.data_ptr(), + key_stride, + value_stride, + num_heads, + head_size, + block_size); + }); +} + +namespace vllm { + template __global__ void convert_fp8_e5m2_kernel( const Tin* __restrict__ src_cache, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 5d062bb5700b..3b242327652f 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -79,6 +79,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "reshape_and_cache", &reshape_and_cache, "Reshape the key and value tensors and cache them"); + cache_ops.def( + "reshape_and_cache_flash", + &reshape_and_cache_flash, + "Reshape the flash-style key and value tensors and cache them."); cache_ops.def( "convert_fp8_e5m2", &convert_fp8_e5m2, diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index d8dc74bc7b00..1430b76270c5 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -7,6 +7,7 @@ from vllm._C import cache_ops from vllm.utils import is_hip +import torch.nn.functional as F COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -25,6 +26,7 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] +PADDINGS = [8, 16, 0] @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @@ -224,3 +226,86 @@ def test_swap_blocks( dist_key_caches[0][dst].cpu()) assert torch.allclose(src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("padding", PADDINGS) +@torch.inference_mode() +def test_reshape_and_cache_flash( + kv_cache_factory, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + padding: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Create a random slot mapping. 
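+    # A slot is a flattened cache position: slot = block_idx * block_size +
+    # block_offset. Sampling without replacement keeps the mapping free of
+    # duplicate slots.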
+ num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device='cuda') + _, key, value = qkv.unbind(dim=1) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(num_blocks, + block_size, + 1, + num_heads, + head_size, + dtype, + seed, + flash_style=True) + assert len(key_caches) == 1 and len(value_caches) == 0 + key_cache, value_cache = key_caches[0], value_caches[0] + + # Clone the KV caches. + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + num_tokens = torch.tensor([num_tokens], dtype=torch.long, device="cuda") + + def pad_key_value(key: torch.Tensor, value: torch.Tensor, + pad_size: int) -> Tuple[torch.Tensor, torch.Tensor]: + if pad_size == 0: + return key, value + return F.pad(key, (0, 0, 0, 0, 0, pad_size)),\ + F.pad(value, (0, 0, 0, 0, 0, pad_size)) + + # kv shapes: (num_blocks, block_size, num_heads, head_size) + # pad tokens. + padded_key, padded_value = pad_key_value(key, value, padding) + # Call the reshape_and_cache kernel. + cache_ops.reshape_and_cache_flash(padded_key, padded_value, key_cache, + value_cache, slot_mapping, num_tokens) + + # Run the reference implementation. + block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') + block_indicies = block_indicies.cpu().tolist() + block_offsets = slot_mapping % block_size + block_offsets = block_offsets.cpu().tolist() + for i in range(num_tokens): + block_idx = block_indicies[i] + block_offset = block_offsets[i] + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) diff --git a/vllm/utils.py b/vllm/utils.py index c8ac57de6f5f..03ebcc42524f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -245,6 +245,7 @@ def create_kv_caches_with_random( model_dtype: Optional[Union[str, torch.dtype]] = None, seed: Optional[int] = 0, device: Optional[str] = "cuda", + flash_style: bool = False, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -271,7 +272,11 @@ def create_kv_caches_with_random( scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + if flash_style: + key_cache_shape = (num_blocks, block_size, num_heads, head_size) + else: + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, + x) key_caches = [] for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, @@ -286,7 +291,11 @@ def create_kv_caches_with_random( f"Does not support key cache of type {cache_dtype}") key_caches.append(key_cache) - value_cache_shape = (num_blocks, num_heads, head_size, block_size) + if flash_style: + value_cache_shape = (num_blocks, block_size, num_heads, head_size) + else: + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches = [] for _ in range(num_layers): value_cache = torch.empty(size=value_cache_shape, From 9a0b6bea2e88ec37899a2aecc4aee62a0d388f0a Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 27 Feb 2024 23:36:19 -0800 Subject: [PATCH 02/28] [2/n] support flash attention kernel --- tests/kernels/test_flash_attention.py | 460 ++++++++++++++++++++++++ 
vllm/model_executor/layers/attention.py | 145 ++++++++ 2 files changed, 605 insertions(+) create mode 100644 tests/kernels/test_flash_attention.py diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py new file mode 100644 index 000000000000..6081bfcc0886 --- /dev/null +++ b/tests/kernels/test_flash_attention.py @@ -0,0 +1,460 @@ +import random +from typing import List, Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + +from vllm.model_executor.layers.attention import ( + flash_single_query_cached_kv_attention, + flash_multi_query_cached_kv_attention_varlen, +) +from vllm.utils import get_max_shared_memory_bytes + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +NUM_BLOCKS = 128 # Arbitrary values for testing +PARTITION_SIZE = 512 + +DTYPES = [torch.half, torch.bfloat16] +NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing +NUM_HEADS_SMALL = NUM_HEADS +HEAD_SIZES = [64, 80, 96, 112, 128, 256] +BLOCK_SIZES = [32] +USE_ALIBI = [False, True] +SEEDS = [0] +PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] + + +def pad_attention_inputs( + pad_config: Tuple[int, int], + block_size: int, + query: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Pad the attention inputs to the specified batch size and context length. 
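+
+    pad_config is a (pad_batch_size, pad_max_context_len) tuple; a
+    pad_batch_size of 0 disables padding and returns the inputs unchanged.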
+ """ + pad_batch_size, pad_max_context_len = pad_config + if pad_batch_size == 0: + return query, block_tables, context_lens, max_context_len + target_batch_size = ( + (query.shape[0] - 1) % pad_batch_size + 1) * pad_batch_size + target_block_size = pad_max_context_len // block_size + 1 + padded_query = F.pad(query, + (0, 0, 0, 0, 0, target_batch_size - query.shape[0])) + padded_block_table = F.pad(block_tables, + (0, target_block_size - block_tables.shape[1], + 0, target_batch_size - block_tables.shape[0])) + padded_context_lens = F.pad(context_lens, + (0, target_batch_size - context_lens.shape[0])) + return padded_query, padded_block_table, padded_context_lens, pad_max_context_len + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], + flash_style: bool = False, +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[-2] + head_size = value_cache.shape[-1] + block_size = value_cache.shape[-3] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + context_lens = context_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + context_len = int(context_lens[i]) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + if flash_style: + k = key_cache[block_number, block_offset, :, :] + else: + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + if flash_style: + v = value_cache[block_number, block_offset, :, :] + else: + v = value_cache[block_number, :, :, block_offset] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. 
+ position_ids = torch.arange(context_len, device="cuda").int() + alibi_bias = (position_ids - context_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +# @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +# @pytest.mark.parametrize("num_heads", NUM_HEADS) +# @pytest.mark.parametrize("head_size", HEAD_SIZES) +# @pytest.mark.parametrize("use_alibi", [False]) +# @pytest.mark.parametrize("block_size", [32]) +# @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +# @pytest.mark.parametrize("seed", SEEDS) +# @pytest.mark.parametrize("pad_config", PAD_CONFIGS) +@pytest.mark.parametrize("num_seqs", [3]) +@pytest.mark.parametrize("num_heads", [(40, 40), (64, 8)]) +@pytest.mark.parametrize("head_size", [80, 96]) +@pytest.mark.parametrize("use_alibi", [False]) +@pytest.mark.parametrize("block_size", [32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("pad_config", [(0, 0)]) +@torch.inference_mode() +def test_flash_paged_attention( + kv_cache_factory, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + seed: int, + pad_config: Tuple[int, int], +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), + num_queries_per_kv) + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, + dtype=torch.float, + device="cuda") + + max_seq_len = MAX_SEQ_LEN if not pad_config[0] else (pad_config[1] - 1000) + context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] + context_lens[-1] = max_seq_len + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + dtype, + seed, + flash_style=True) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Call the paged attention kernel. + num_valid_tokens = torch.cuda.IntTensor([num_seqs]) + output = torch.empty_like(query) + + padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ + pad_attention_inputs(pad_config, block_size, query, + block_tables, context_lens, max_context_len) + + flash_single_query_cached_kv_attention( + output, + padded_query, + key_cache, + value_cache, + scale, + padded_block_table, + padded_context_lens, + block_size, + alibi_slopes, + ) + + # Run the reference implementation. 
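+    # The reference gathers K/V token by token from the flash-style cache
+    # ([num_blocks, block_size, num_heads, head_size]) and computes the
+    # softmax in float32, so only small numerical differences are expected
+    # versus the kernel.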
+ ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + context_lens, + scale, + alibi_slopes, + flash_style=True, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +def ref_multi_query_kv_attention_padded( + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + cu_seq_lens: List[int], + context_lens: List[int], + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + block_size = value_cache.shape[-3] + ref_outputs = [] + + for i in range(num_seqs): + q_start_idx = cu_seq_lens[i] + q_end_idx = cu_seq_lens[i + 1] + seq_len = q_end_idx - q_start_idx + + context_len = context_lens[i] + + block_table = block_tables[i] + keys = [] + values = [] + + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, block_offset, :, :] + keys.append(k) + + v = value_cache[block_number, block_offset, :, :] + values.append(v) + + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + q = query[q_start_idx:q_end_idx, :, :] + k = keys[:context_len, :, :] + v = values[:context_len, :, :] + + assert seq_len <= context_len + + # pad q if seq_len is less than context_len + # this is for correct calculation of attention. + if seq_len < context_len: + indices = [i % seq_len for i in range(context_len - seq_len)] + q_left_pad = q[indices, :, :] + q = torch.cat([q_left_pad, q], dim=0) + + # Create attention mask. + attn_mask = torch.triu(torch.ones(context_len, + context_len, + dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype, device="cuda") + + ref_output = ref_masked_attention( + q, + k, + v, + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output[-seq_len:, :, :]) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +def is_a100(): + return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 + + +if not is_a100(): + NUM_HEADS_SMALL = [(16, 16), (16, 8)] + MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) + +NUM_BLOCKS = 1024 +BLOCK_SIZE = 32 + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS_SMALL) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("version", ["flash"]) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_multi_query_kv_attention( + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + version: str, + seed: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. + # As the xformers library is already tested with its own tests, we can use + # a smaller MAX_SEQ_LEN here. 
+ max_len = min(MAX_SEQ_LEN, 4096) + + seq_lens = [random.randint(1, max_len / 2) for i in range(num_seqs)] + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device="cuda") + + context_lens = random.sample(range(max_seq_len, max_len), num_seqs) + max_context_len = max(context_lens) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device="cuda") + + num_tokens = sum(seq_lens) + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + + cu_context_lens = [0] + for context_len in context_lens: + cu_context_lens.append(cu_context_lens[-1] + context_len) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + num_queries_per_kv = num_query_heads // num_kv_heads + + value_cache = torch.empty(NUM_BLOCKS, + BLOCK_SIZE, + num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + key_cache = torch.empty(NUM_BLOCKS, + BLOCK_SIZE, + num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + query = torch.empty(num_tokens, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + value_cache.uniform_(-scale, scale) + key_cache.uniform_(-scale, scale) + query.uniform_(-scale, scale) + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + BLOCK_SIZE - 1) // BLOCK_SIZE + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + output = torch.empty_like(query) + + if version == "flash": + flash_multi_query_cached_kv_attention_varlen( + output, + query, + key_cache, + value_cache, + scale, + block_tables, + torch.cuda.IntTensor(cu_seq_lens), + torch.cuda.IntTensor(cu_context_lens), + BLOCK_SIZE, + max_seq_len, + max_context_len, + None, + ) + else: + assert False, f"{version=} is not supported" + + ref_output = ref_multi_query_kv_attention_padded( + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + cu_seq_lens, + context_lens, + scale, + dtype, + ) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) \ No newline at end of file diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 2a82325b8021..2cb8bdf0a7c8 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -15,6 +15,13 @@ context_attention_fwd) from vllm.utils import is_hip +try: + from flash_attn import (flash_attn_with_page_attention, + flash_attn_varlen_with_page_attention) +except Exception as e: + flash_attn_with_page_attention = e + flash_attn_varlen_with_page_attention = e + _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. _PARTITION_SIZE = 512 @@ -347,3 +354,141 @@ def _paged_attention( input_metadata.kv_cache_dtype, ) return output + + +def flash_single_query_cached_kv_attention( + output: Optional[torch.Tensor], + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + alibi_slopes: Optional[torch.Tensor], + actual_batch_size: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Similar to vLLM's page attention, caclulates a single token's attention + based on key/value caches. The main difference is this uses flash attention + sytle key-value caches. 
+ Arguments: + output: [num_padded_tokens, num_heads, head_size], output tensor + to write. if None an new output tensor will be created. + query: [num_padded_tokens, num_heads, head_size], query tensor. + key_cache: [num_blocks, block_size, num_heads, head_size], key cache. + value_cache: [num_blocks, block_size, num_heads, head_size], value + cache. + scale: attention scale. + block_tables: [num_padded_tokens, max_context_len / block_size], + block tables. + context_lens: [num_padded_tokens], context lengths. + block_size: block size. + alibi_slopes: unused. + actual_batch_size: [1] actual batch size. + Returns: + output: [num_padded_tokens, num_heads, head_size] + """ + block_size = value_cache.shape[1] + assert block_size == 32, "only support block_size 32 for flash attention" + # TODO: support alibi_slopes + assert alibi_slopes is None, "doesn't support alibi_slopes" + num_tokens, num_heads, head_size = query.shape + out = flash_attn_with_page_attention( + query.view(num_tokens, 1, num_heads, head_size), + key_cache, + value_cache, + block_tables, + None, # key + None, # value + None, # cos + None, # sin + context_lens, + None, # cache_batch_idx + softmax_scale=scale, + causal=True, + window_size=(-1, -1), + rotary_interleaved=False, + num_splits=0, + actual_batch_size=actual_batch_size, + ) + if output is not None: + # in case that output is padded, only copy the valid part. + output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) + return out.view(num_tokens, num_heads, head_size) + + +def flash_multi_query_cached_kv_attention_varlen( + output: Optional[torch.Tensor], + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + scale: float, + block_tables: torch.Tensor, + cum_seqlens_q: torch.Tensor, + cum_context_len: torch.Tensor, + block_size: int, + max_query_len: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + actual_batch_size: Optional[torch.Tensor] = None, +): + """Efficient multi-query paged attention based on flash attention. + It calculates attentions of list of sequences packed in a single batch, + indexed by cum_seqlens_q where the seq_i's index is + [cum_seqlens_q[i], cum_seqlensq[i+1]]. + Similarlly, the length of context is stored in cum_seqlens_k with similar + fashions. + It also supports calculating attention incrementally, where context length + is longer than sequence length. + Arguments: + output: [num_padded_tokens, num_heads, head_size], output tensor to + write to. if None an new output tensor will be created. + query: [num_padded_tokens, num_heads, head_size], query tensor. + key_cache: [num_blocks, block_size, num_heads, head_size], key cache. + value_cache: [num_blocks, block_size, num_heads, head_size], + value cache. + scale: attention scale. + block_tables: [num_padded_tokens, max_context_len / block_size], + block tables. + cum_seqlens_q: [padded_batch_size + 1], cumulative sequence lengths + of query. + cum_context_len: [padded_batch_size + 1], cumulative lengths + of context. + block_size: block size. + max_query_len: max query length. + max_context_len: max context length. + alibi_slopes: unused. + actual_batch_size: [1] actual batch size. 
+ Returns: + output: [num_padded_tokens, num_heads, head_size] + """ + block_size = value_cache.shape[1] + assert block_size == 32, "only support block_size 32 for flash attention" + # TODO: support alibi_slopes + assert alibi_slopes is None, "doesn't support alibi_slopes" + + num_tokens, _, _ = query.shape + out = flash_attn_varlen_with_page_attention( + query, + key_cache, + value_cache, + block_tables, + cum_seqlens_q, + cum_context_len, + max_query_len, + max_context_len, + None, # key + None, # value + None, # cos_cache + None, # sin_cache + None, # cache_batch_idx + softmax_scale=scale, + causal=True, + window_size=(-1, -1), + rotary_interleaved=False, + num_splits=0, + actual_batch_size=actual_batch_size, + ) + if output is not None: + output[:num_tokens].copy_(out) + return out \ No newline at end of file From 6947167ea42f592df37ead195dbfb5c6906609bd Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 28 Feb 2024 01:47:52 -0800 Subject: [PATCH 03/28] oss flash attention works --- tests/kernels/test_flash_attention.py | 250 ++---------------------- vllm/model_executor/layers/attention.py | 135 +++---------- 2 files changed, 43 insertions(+), 342 deletions(-) diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index 6081bfcc0886..82c4350489d5 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -1,16 +1,12 @@ import random -from typing import List, Optional, Tuple +from typing import Optional, Tuple import pytest import torch import torch.nn.functional as F -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from vllm.model_executor.layers.attention import ( - flash_single_query_cached_kv_attention, - flash_multi_query_cached_kv_attention_varlen, -) + flash_attn_with_kvcache_paged, ) from vllm.utils import get_max_shared_memory_bytes FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 @@ -25,11 +21,16 @@ NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS_SMALL = NUM_HEADS -HEAD_SIZES = [64, 80, 96, 112, 128, 256] -BLOCK_SIZES = [32] +# head size should be bigger than or equal to block size. +HEAD_SIZES = [256] +# TODO(sang): https://github.com/Dao-AILab/flash-attention/pull/824 +# should fix the block size. But right now, the block size should be +# divisible by 256. 
+BLOCK_SIZES = [256] USE_ALIBI = [False, True] SEEDS = [0] -PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] +# PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] +PAD_CONFIGS = [(0, 0)] def pad_attention_inputs( @@ -137,22 +138,14 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -# @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -# @pytest.mark.parametrize("num_heads", NUM_HEADS) -# @pytest.mark.parametrize("head_size", HEAD_SIZES) -# @pytest.mark.parametrize("use_alibi", [False]) -# @pytest.mark.parametrize("block_size", [32]) -# @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) -# @pytest.mark.parametrize("seed", SEEDS) -# @pytest.mark.parametrize("pad_config", PAD_CONFIGS) -@pytest.mark.parametrize("num_seqs", [3]) -@pytest.mark.parametrize("num_heads", [(40, 40), (64, 8)]) -@pytest.mark.parametrize("head_size", [80, 96]) -@pytest.mark.parametrize("use_alibi", [False]) -@pytest.mark.parametrize("block_size", [32]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", [False, True]) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("pad_config", [(0, 0)]) +@pytest.mark.parametrize("pad_config", PAD_CONFIGS) @torch.inference_mode() def test_flash_paged_attention( kv_cache_factory, @@ -180,9 +173,6 @@ def test_flash_paged_attention( assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads - head_mapping = torch.repeat_interleave( - torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), - num_queries_per_kv) alibi_slopes = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, @@ -218,15 +208,13 @@ def test_flash_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. - num_valid_tokens = torch.cuda.IntTensor([num_seqs]) output = torch.empty_like(query) padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ pad_attention_inputs(pad_config, block_size, query, block_tables, context_lens, max_context_len) - flash_single_query_cached_kv_attention( - output, + output = flash_attn_with_kvcache_paged( padded_query, key_cache, value_cache, @@ -256,205 +244,3 @@ def test_flash_paged_attention( # implementations, there is a small numerical difference in the two # outputs. Thus, we use a relaxed tolerance for the test. 
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) - - -def ref_multi_query_kv_attention_padded( - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - cu_seq_lens: List[int], - context_lens: List[int], - scale: float, - dtype: torch.dtype, -) -> torch.Tensor: - num_seqs = len(cu_seq_lens) - 1 - block_size = value_cache.shape[-3] - ref_outputs = [] - - for i in range(num_seqs): - q_start_idx = cu_seq_lens[i] - q_end_idx = cu_seq_lens[i + 1] - seq_len = q_end_idx - q_start_idx - - context_len = context_lens[i] - - block_table = block_tables[i] - keys = [] - values = [] - - for j in range(context_len): - block_number = int(block_table[j // block_size]) - block_offset = j % block_size - - k = key_cache[block_number, block_offset, :, :] - keys.append(k) - - v = value_cache[block_number, block_offset, :, :] - values.append(v) - - keys = torch.stack(keys, dim=0) - values = torch.stack(values, dim=0) - - if num_queries_per_kv > 1: - # Handle MQA and GQA - keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) - values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - - q = query[q_start_idx:q_end_idx, :, :] - k = keys[:context_len, :, :] - v = values[:context_len, :, :] - - assert seq_len <= context_len - - # pad q if seq_len is less than context_len - # this is for correct calculation of attention. - if seq_len < context_len: - indices = [i % seq_len for i in range(context_len - seq_len)] - q_left_pad = q[indices, :, :] - q = torch.cat([q_left_pad, q], dim=0) - - # Create attention mask. - attn_mask = torch.triu(torch.ones(context_len, - context_len, - dtype=dtype), - diagonal=1) - attn_mask = attn_mask * torch.finfo(dtype).min - attn_mask = attn_mask.to(dtype=dtype, device="cuda") - - ref_output = ref_masked_attention( - q, - k, - v, - scale, - attn_mask=attn_mask, - ) - ref_outputs.append(ref_output[-seq_len:, :, :]) - ref_output = torch.cat(ref_outputs, dim=0) - return ref_output - - -def is_a100(): - return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 - - -if not is_a100(): - NUM_HEADS_SMALL = [(16, 16), (16, 8)] - MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) - -NUM_BLOCKS = 1024 -BLOCK_SIZE = 32 - - -@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS_SMALL) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("version", ["flash"]) -@pytest.mark.parametrize("seed", SEEDS) -@torch.inference_mode() -def test_multi_query_kv_attention( - num_seqs: int, - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - version: str, - seed: int, -) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. - # As the xformers library is already tested with its own tests, we can use - # a smaller MAX_SEQ_LEN here. 
- max_len = min(MAX_SEQ_LEN, 4096) - - seq_lens = [random.randint(1, max_len / 2) for i in range(num_seqs)] - max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device="cuda") - - context_lens = random.sample(range(max_seq_len, max_len), num_seqs) - max_context_len = max(context_lens) - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device="cuda") - - num_tokens = sum(seq_lens) - cu_seq_lens = [0] - for seq_len in seq_lens: - cu_seq_lens.append(cu_seq_lens[-1] + seq_len) - - cu_context_lens = [0] - for context_len in context_lens: - cu_context_lens.append(cu_context_lens[-1] + context_len) - - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - num_queries_per_kv = num_query_heads // num_kv_heads - - value_cache = torch.empty(NUM_BLOCKS, - BLOCK_SIZE, - num_kv_heads, - head_size, - dtype=dtype, - device="cuda") - key_cache = torch.empty(NUM_BLOCKS, - BLOCK_SIZE, - num_kv_heads, - head_size, - dtype=dtype, - device="cuda") - query = torch.empty(num_tokens, - num_query_heads, - head_size, - dtype=dtype, - device="cuda") - value_cache.uniform_(-scale, scale) - key_cache.uniform_(-scale, scale) - query.uniform_(-scale, scale) - - # Create the block tables. - max_num_blocks_per_seq = (max_context_len + BLOCK_SIZE - 1) // BLOCK_SIZE - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - - output = torch.empty_like(query) - - if version == "flash": - flash_multi_query_cached_kv_attention_varlen( - output, - query, - key_cache, - value_cache, - scale, - block_tables, - torch.cuda.IntTensor(cu_seq_lens), - torch.cuda.IntTensor(cu_context_lens), - BLOCK_SIZE, - max_seq_len, - max_context_len, - None, - ) - else: - assert False, f"{version=} is not supported" - - ref_output = ref_multi_query_kv_attention_padded( - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - cu_seq_lens, - context_lens, - scale, - dtype, - ) - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) \ No newline at end of file diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 2cb8bdf0a7c8..31ead4841139 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -15,12 +15,11 @@ context_attention_fwd) from vllm.utils import is_hip +# TODO(sang): Support varlen API. try: - from flash_attn import (flash_attn_with_page_attention, - flash_attn_varlen_with_page_attention) -except Exception as e: - flash_attn_with_page_attention = e - flash_attn_varlen_with_page_attention = e + from flash_attn import flash_attn_with_kvcache +except ImportError: + flash_attn_with_kvcache = None _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
@@ -356,7 +355,7 @@ def _paged_attention( return output -def flash_single_query_cached_kv_attention( +def flash_attn_with_kvcache_paged( output: Optional[torch.Tensor], query: torch.Tensor, key_cache: torch.Tensor, @@ -365,130 +364,46 @@ def flash_single_query_cached_kv_attention( block_tables: torch.Tensor, context_lens: torch.Tensor, block_size: int, - alibi_slopes: Optional[torch.Tensor], - actual_batch_size: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Similar to vLLM's page attention, caclulates a single token's attention + """Similar to vLLM's page attention, calculates a single token's attention based on key/value caches. The main difference is this uses flash attention - sytle key-value caches. + style key-value caches. + Arguments: output: [num_padded_tokens, num_heads, head_size], output tensor to write. if None an new output tensor will be created. - query: [num_padded_tokens, num_heads, head_size], query tensor. - key_cache: [num_blocks, block_size, num_heads, head_size], key cache. - value_cache: [num_blocks, block_size, num_heads, head_size], value - cache. - scale: attention scale. - block_tables: [num_padded_tokens, max_context_len / block_size], - block tables. - context_lens: [num_padded_tokens], context lengths. - block_size: block size. - alibi_slopes: unused. - actual_batch_size: [1] actual batch size. + See https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py + for other arguments. + Returns: output: [num_padded_tokens, num_heads, head_size] """ block_size = value_cache.shape[1] - assert block_size == 32, "only support block_size 32 for flash attention" - # TODO: support alibi_slopes - assert alibi_slopes is None, "doesn't support alibi_slopes" + assert block_size % 256 == 0, "only support block_size divisible by 256." num_tokens, num_heads, head_size = query.shape - out = flash_attn_with_page_attention( + out = flash_attn_with_kvcache( query.view(num_tokens, 1, num_heads, head_size), key_cache, value_cache, - block_tables, - None, # key - None, # value - None, # cos - None, # sin - context_lens, - None, # cache_batch_idx + # Inplace update is slow. We don't use it. + # We assume kvcache is already updated before + # calling this API. + None, + None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens=context_lens, + cache_batch_idx=None, + block_table=block_tables, softmax_scale=scale, causal=True, window_size=(-1, -1), rotary_interleaved=False, + alibi_slopes=alibi_slopes, num_splits=0, - actual_batch_size=actual_batch_size, ) if output is not None: # in case that output is padded, only copy the valid part. output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) return out.view(num_tokens, num_heads, head_size) - - -def flash_multi_query_cached_kv_attention_varlen( - output: Optional[torch.Tensor], - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - scale: float, - block_tables: torch.Tensor, - cum_seqlens_q: torch.Tensor, - cum_context_len: torch.Tensor, - block_size: int, - max_query_len: int, - max_context_len: int, - alibi_slopes: Optional[torch.Tensor], - actual_batch_size: Optional[torch.Tensor] = None, -): - """Efficient multi-query paged attention based on flash attention. - It calculates attentions of list of sequences packed in a single batch, - indexed by cum_seqlens_q where the seq_i's index is - [cum_seqlens_q[i], cum_seqlensq[i+1]]. 
- Similarlly, the length of context is stored in cum_seqlens_k with similar - fashions. - It also supports calculating attention incrementally, where context length - is longer than sequence length. - Arguments: - output: [num_padded_tokens, num_heads, head_size], output tensor to - write to. if None an new output tensor will be created. - query: [num_padded_tokens, num_heads, head_size], query tensor. - key_cache: [num_blocks, block_size, num_heads, head_size], key cache. - value_cache: [num_blocks, block_size, num_heads, head_size], - value cache. - scale: attention scale. - block_tables: [num_padded_tokens, max_context_len / block_size], - block tables. - cum_seqlens_q: [padded_batch_size + 1], cumulative sequence lengths - of query. - cum_context_len: [padded_batch_size + 1], cumulative lengths - of context. - block_size: block size. - max_query_len: max query length. - max_context_len: max context length. - alibi_slopes: unused. - actual_batch_size: [1] actual batch size. - Returns: - output: [num_padded_tokens, num_heads, head_size] - """ - block_size = value_cache.shape[1] - assert block_size == 32, "only support block_size 32 for flash attention" - # TODO: support alibi_slopes - assert alibi_slopes is None, "doesn't support alibi_slopes" - - num_tokens, _, _ = query.shape - out = flash_attn_varlen_with_page_attention( - query, - key_cache, - value_cache, - block_tables, - cum_seqlens_q, - cum_context_len, - max_query_len, - max_context_len, - None, # key - None, # value - None, # cos_cache - None, # sin_cache - None, # cache_batch_idx - softmax_scale=scale, - causal=True, - window_size=(-1, -1), - rotary_interleaved=False, - num_splits=0, - actual_batch_size=actual_batch_size, - ) - if output is not None: - output[:num_tokens].copy_(out) - return out \ No newline at end of file From 4769a2636392d4ac1f25b2af758d008de4533f88 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 28 Feb 2024 06:47:18 -0800 Subject: [PATCH 04/28] in progress --- .buildkite/test-pipeline.yaml | 4 + benchmarks/benchmark_latency.py | 40 ++++++++- csrc/cache_kernels.cu | 9 +- requirements.txt | 1 + tests/chunked_prefill/test_correctness.py | 82 ++++++++++++++++++ tests/conftest.py | 3 + tests/kernels/test_cache.py | 2 +- tests/kernels/test_flash_attention.py | 1 - vllm/config.py | 14 +++ vllm/engine/arg_utils.py | 9 +- vllm/model_executor/input_metadata.py | 46 ++++++++++ vllm/model_executor/layers/attention.py | 101 ++++++++++++++-------- vllm/model_executor/models/llama.py | 13 ++- vllm/worker/cache_engine.py | 36 +++++--- vllm/worker/model_runner.py | 1 + vllm/worker/worker.py | 11 +++ 16 files changed, 310 insertions(+), 63 deletions(-) create mode 100644 tests/chunked_prefill/test_correctness.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index efcc4d2d07a1..c16bb4e3da24 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -43,6 +43,10 @@ steps: commands: - pytest -v -s prefix_caching +- label: Chunked Prefill Test + commands: + - pytest -v -s chunked_prefill + - label: Samplers Test command: pytest -v -s samplers --forked diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 6e3b679cb81b..d8083a826911 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -10,6 +10,13 @@ from vllm import LLM, SamplingParams +SAMPLE_PROMPTS = [ + "The president of the United States is", + "Hello, my name is", + "The capital of France is", + "The future of AI is", +] + def main(args: argparse.Namespace): 
print(args) @@ -57,10 +64,24 @@ def run_to_completion(profile_dir: Optional[str] = None): print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(prompt_token_ids=dummy_prompt_token_ids, - sampling_params=sampling_params, - use_tqdm=False) + if args.use_sample: + batch = ( + SAMPLE_PROMPTS * + (args.batch_size // len(SAMPLE_PROMPTS) + 1))[:args.batch_size] + outputs = llm.generate(prompts=batch, + sampling_params=sampling_params, + use_tqdm=False) + else: + outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids, + sampling_params=sampling_params, + use_tqdm=False) end_time = time.perf_counter() + if args.verbose: + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print( + f"Prompt: {prompt!r}, Generated text: {generated_text!r}") latency = end_time - start_time return latency @@ -145,5 +166,18 @@ def run_to_completion(profile_dir: Optional[str] = None): default="cuda", choices=["cuda"], help='device type for vLLM execution, supporting CUDA only currently.') + parser.add_argument('--flash-style', + action='store_true', + help='enable flash attention') + parser.add_argument('--block-size', + type=int, + default=16, + help='block size of key/value cache') + parser.add_argument('--use-sample', + action='store_true', + help='use sample input instead of dummy input') + parser.add_argument('--verbose', + action='store_true', + help='print generated text') args = parser.parse_args() main(args) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 21057984e1e3..7c44deafea6a 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -278,17 +278,12 @@ __global__ void reshape_and_cache_flash_kernel( scalar_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, head_size] scalar_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, head_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int64_t* __restrict__ num_tokens, // [1] const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size) { - const int64_t num_tokens_ = num_tokens[0]; const int64_t token_idx = blockIdx.x; - if (token_idx >= num_tokens_) { - return; - } const int64_t slot_idx = slot_mapping[token_idx]; const int64_t block_idx = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; @@ -323,8 +318,7 @@ void reshape_and_cache_flash( torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping, // [num_tokens] - torch::Tensor& num_tokens) // [1] + torch::Tensor& slot_mapping) // [num_tokens] { int num_tokens_padded = key.size(0); int num_heads = key.size(1); @@ -347,7 +341,6 @@ void reshape_and_cache_flash( key_cache.data_ptr(), value_cache.data_ptr(), slot_mapping.data_ptr(), - num_tokens.data_ptr(), key_stride, value_stride, num_heads, diff --git a/requirements.txt b/requirements.txt index d4599ec95d94..384993a4e4c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ prometheus_client >= 0.18.0 pynvml == 11.5.0 triton >= 2.1.0 cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. +flash-attn >= 2.5.0 # Required for chunked prefill. 
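For reference, a hypothetical invocation of the updated latency benchmark that exercises the flags added above; the model name and batch size are placeholders rather than values from this patch, and the block size of 256 follows the flash-style constraint introduced in vllm/config.py:

    python benchmarks/benchmark_latency.py --model meta-llama/Llama-2-7b-hf \
        --flash-style --block-size 256 --batch-size 8 --use-sample --verbose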
diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py new file mode 100644 index 000000000000..abd0e1f7199a --- /dev/null +++ b/tests/chunked_prefill/test_correctness.py @@ -0,0 +1,82 @@ +import gc + +from typing import List + +import pytest +import torch + +from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel + +MODELS = [ + "JackFram/llama-68m", +] + +# SANG-TODO Read it from example.txt +TEST_PROMPTS = [ + # pylint: disable=line-too-long + "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", + "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", + "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", + # Different between page attention and flash attention. + # "Describe the basic components of a neural network and how it can be trained.", + "Write a short story about a robot that dreams for the first time.", + "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.", + "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.", + "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'", +] + + +# TODO(sang): Add chunked prefill parameters. +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_models( + vllm_runner, + model: str, + dtype: str, + max_tokens: int, +) -> None: + """ verify the flash attention has the same output + as page attention """ + print("loading page attention models..") + pg_model = vllm_runner(model, dtype=dtype) + expected_outputs = [] + + print("generating tokens...") + expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) + print("generating tokens finished") + + del pg_model + + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + + flash_attn_model = vllm_runner( + model, + dtype=dtype, + enable_cuda_graph=False, + flash_style=True, + ) + flash_attn_output_by_batchs = [] + for i in range(10): + prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] + flash_attn_output_by_batchs.append( + flash_attn_model.generate_greedy(prompts, max_tokens)) + + del flash_attn_model + + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + + for flash_attn_outputs in flash_attn_output_by_batchs: + for i in range(len(flash_attn_outputs)): + fa_output_ids, fa_output_str = flash_attn_outputs[i] + vllm_output_ids, vllm_output_str = expected_outputs[ + i % len(expected_outputs)] + print() + assert fa_output_ids == vllm_output_ids, ( + f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" + f"Test{i}:\nflash ouput: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" + ) diff --git a/tests/conftest.py b/tests/conftest.py index 30a3df89d9f1..b579bd961bb9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -165,6 +165,7 @@ def __init__( dtype: str = "half", disable_log_stats: bool = True, tensor_parallel_size: int = 1, + flash_style: bool = False, **kwargs, ) -> None: self.model = LLM( @@ -175,6 +176,8 @@ def __init__( swap_space=0, disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, + flash_style=flash_style, + block_size=32, **kwargs, ) diff --git a/tests/kernels/test_cache.py 
b/tests/kernels/test_cache.py index 1430b76270c5..4b7bb67b6ee6 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -294,7 +294,7 @@ def pad_key_value(key: torch.Tensor, value: torch.Tensor, padded_key, padded_value = pad_key_value(key, value, padding) # Call the reshape_and_cache kernel. cache_ops.reshape_and_cache_flash(padded_key, padded_value, key_cache, - value_cache, slot_mapping, num_tokens) + value_cache, slot_mapping) # Run the reference implementation. block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index 82c4350489d5..a9c43e31edc0 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -221,7 +221,6 @@ def test_flash_paged_attention( scale, padded_block_table, padded_context_lens, - block_size, alibi_slopes, ) diff --git a/vllm/config.py b/vllm/config.py index bd0dc89b585f..697801b439bc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -60,6 +60,7 @@ class ModelConfig: max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. + flash_style: Enable flash style page attention. """ def __init__( @@ -79,6 +80,7 @@ def __init__( quantization: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + flash_style: bool = False, ) -> None: self.model = model self.tokenizer = tokenizer @@ -93,6 +95,7 @@ def __init__( self.quantization = quantization self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture + self.flash_style = flash_style if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, @@ -295,12 +298,14 @@ def __init__( swap_space: int, cache_dtype: str, sliding_window: Optional[int] = None, + flash_style: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB self.cache_dtype = cache_dtype self.sliding_window = sliding_window + self.flash_style = flash_style self._verify_args() self._verify_cache_dtype() @@ -314,6 +319,15 @@ def _verify_args(self) -> None: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + if self.flash_style: + logger.info("Flash attention enabled.") + if self.block_size < 256: + # Flash style attention only supports block size >=256 for now. + # https://github.com/Dao-AILab/flash-attention/pull/824 will fix it. + raise ValueError( + "Flash style attention only supports block size >= 256. Got" + f"{self.block_size }") + def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a4efd171b871..e064f357e069 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -45,6 +45,7 @@ class EngineArgs: lora_dtype = 'auto' max_cpu_loras: Optional[int] = None device: str = 'cuda' + flash_style: bool = False def __post_init__(self): if self.tokenizer is None: @@ -271,6 +272,9 @@ def add_cli_args( choices=["cuda"], help=('Device type for vLLM execution. 
' 'Currently, only CUDA-compatible devices are supported.')) + parser.add_argument('--flash-style', + action='store_true', + help='use flash attention.') return parser @classmethod @@ -291,11 +295,12 @@ def create_engine_configs( self.trust_remote_code, self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, - self.enforce_eager, self.max_context_len_to_capture) + self.enforce_eager, self.max_context_len_to_capture, self.flash_style) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, - model_config.get_sliding_window()) + model_config.get_sliding_window(), + self.flash_style) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index f0a88ac8e27f..1beccd380355 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -13,6 +13,10 @@ class InputMetadata: context_lens: the length of attention context for each sequence. block_tables: The block tables. (Seq id -> list of physical block) kv_cache_dtype: Data type to store kv cache. + num_prompt_tokens: The number of tokens in the prompts. This might + include padding. + num_generation_tokens: The number of tokens in the generation sequences. + This might include padding. """ def __init__( @@ -27,6 +31,9 @@ def __init__( block_tables: Optional[torch.Tensor], use_cuda_graph: bool, kv_cache_dtype: str, + # SANG-TODO + # num_prompt_tokens: int, + # num_generation_tokens: int, ) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens @@ -43,6 +50,45 @@ def __init__( # FIXME(woosuk): This is a hack. self.attn_bias = None + # SANG-TODO + # # Prompt related metadata + # # This value might include padding if CudaGraph is enabled. + # self.num_prompts = len(prompt_lens) + # # This value is the source of truth. + # self.num_prompts_tensor = torch.cuda.IntTensor([self.num_prompts]) + # # This value might include padding if CudaGraph is enabled. + # self.num_prompt_tokens = num_prompt_tokens + # self.prompt_lens_tensor = torch.cuda.IntTensor(self.prompt_lens) + # self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 + + # # Cumulative prompt lengths for each prompt in the input + # # tensor. + # self.cum_prompt_query_lens = torch.zeros( + # self.num_prompts + 1, + # device=self.prompt_lens_tensor.device, + # dtype=torch.int32) + # # Cumulative context lengths. + # self.cum_prompt_context_lens = torch.zeros( + # self.num_prompts + 1, + # device=self.prompt_lens_tensor.device, + # dtype=torch.int32) + + # torch.cumsum(self.prompt_lens_tensor, + # dim=0, + # dtype=self.cum_prompt_query_lens.dtype, + # out=self.cum_prompt_query_lens[1:]) + + # # TODO: this will be different once we support chunked prefills. + # self.cum_prompt_context_lens = self.cum_prompt_query_lens + # self.max_context_len = max(self.max_context_len, self.max_prompt_len) + + # # Generation related metadata + # # This value might include padding if CudaGraph is enabled. + # self.num_generation_tokens = num_generation_tokens + # # This is the source of truth for the number of generation tokens. 
+ # self.num_generation_tokens_tensor = torch.cuda.IntTensor( + # [num_generation_tokens]) + def __repr__(self) -> str: return ("InputMetadata(" f"is_prompt={self.is_prompt}, " diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 31ead4841139..3c0a7b54ab2b 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -14,6 +14,8 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( context_attention_fwd) from vllm.utils import is_hip +from vllm.model_executor.layers.attention import ( + flash_attn_with_kvcache_paged, ) # TODO(sang): Support varlen API. try: @@ -47,6 +49,7 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + flash_style: bool = False, ) -> None: super().__init__() self.num_heads = num_heads @@ -57,6 +60,7 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) + self.flash_style = flash_style assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -133,14 +137,23 @@ def forward( # vectors will not be cached. This happens during the initial memory # profiling run. if key_cache is not None and value_cache is not None: - cache_ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - input_metadata.slot_mapping.flatten(), - input_metadata.kv_cache_dtype, - ) + if self.flash_style: + cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping.flatten() + ) + else: + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping.flatten(), + input_metadata.kv_cache_dtype, + ) if input_metadata.is_prompt: # normal attention @@ -215,32 +228,56 @@ def forward( else: # prefix-enabled attention output = torch.empty_like(query) - context_attention_fwd( + if self.flash_attention: + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + output = flash_attn_with_kvcache_paged( + query + key_cache: torch.Tensor, + value_cache: torch.Tensor, + self.scale + input_metadata.block_tables, + input_metadata.context_lens, + self.alibi_slopes, + ) + else: + context_attention_fwd( + query, + key, + value, + output, + key_cache, + value_cache, + input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.start_loc, + input_metadata.prompt_lens, + input_metadata.context_lens, + input_metadata.max_seq_len, + getattr(self, "alibi_slopes", None), + ) + + else: + # Decoding run. + if self.flash_style: + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + output = flash_attn_with_kvcache_paged( query, - key, - value, - output, key_cache, value_cache, - input_metadata.block_tables, # [BS, max_block_per_request] - input_metadata.start_loc, - input_metadata.prompt_lens, + self.scale, + input_metadata.block_tables, input_metadata.context_lens, - input_metadata.max_seq_len, - getattr(self, "alibi_slopes", None), + self.alibi_slopes + ) + else: + output = _paged_attention( + query, + key_cache, + value_cache, + input_metadata, + self.num_kv_heads, + self.scale, + self.alibi_slopes, ) - - else: - # Decoding run. - output = _paged_attention( - query, - key_cache, - value_cache, - input_metadata, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - ) # Reshape the output tensor. 
return output.view(batch_size, seq_len, hidden_size) @@ -305,6 +342,7 @@ def _paged_attention( # For context len > 8192, use V2 kernel to avoid shared memory shortage. use_v1 = input_metadata.max_context_len <= 8192 and ( max_num_partitions == 1 or num_seqs * num_heads > 512) + if use_v1: # Run PagedAttention V1. ops.paged_attention_v1( @@ -356,14 +394,12 @@ def _paged_attention( def flash_attn_with_kvcache_paged( - output: Optional[torch.Tensor], query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, scale: float, block_tables: torch.Tensor, context_lens: torch.Tensor, - block_size: int, alibi_slopes: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Similar to vLLM's page attention, calculates a single token's attention @@ -403,7 +439,4 @@ def flash_attn_with_kvcache_paged( alibi_slopes=alibi_slopes, num_splits=0, ) - if output is not None: - # in case that output is padded, only copy the valid part. - output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) return out.view(num_tokens, num_heads, head_size) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b7f6b8f3ec37..9fd4425bca02 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -93,6 +93,7 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, bias: bool = False, sliding_window: Optional[int] = None, + flash_style: bool = False, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -143,7 +144,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + sliding_window=sliding_window, + flash_style=flash_style) def forward( self, @@ -151,6 +153,7 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, + flash_style: bool = False, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -167,6 +170,7 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + flash_style: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -186,6 +190,7 @@ def __init__( linear_method=linear_method, bias=getattr(config, "bias", False), sliding_window=sliding_window, + flash_style=flash_style, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -234,6 +239,7 @@ def __init__( config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, lora_config: Optional[LoRAConfig] = None, + flash_style: bool = False, ) -> None: super().__init__() self.config = config @@ -248,7 +254,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method) + LlamaDecoderLayer(config, linear_method, flash_style) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -308,11 +314,12 @@ def __init__( config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, lora_config: Optional[LoRAConfig] = None, + flash_style: bool = False, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = LlamaModel(config, linear_method, lora_config=lora_config) + self.model = LlamaModel(config, linear_method, lora_config=lora_config, flash_style=flash_style) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/worker/cache_engine.py 
b/vllm/worker/cache_engine.py index bbe33989fc2a..80d7b68b884b 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -57,19 +57,33 @@ def __init__( def get_key_block_shape(self) -> Tuple[int, int, int, int]: element_size = torch.tensor([], dtype=self.dtype).element_size() x = 16 // element_size - return ( - self.num_heads, - self.head_size // x, - self.block_size, - x, - ) + if self.cache_config.flash_style: + return ( + self.block_size, + self.num_heads, + self.head_size, + ) + else: + return ( + self.num_heads, + self.head_size // x, + self.block_size, + x, + ) def get_value_block_shape(self) -> Tuple[int, int, int]: - return ( - self.num_heads, - self.head_size, - self.block_size, - ) + if self.cache_config.flash_style: + return ( + self.block_size, + self.num_heads, + self.head_size, + ) + else: + return ( + self.num_heads, + self.head_size, + self.block_size, + ) def allocate_gpu_cache(self) -> List[KVCache]: gpu_cache: List[KVCache] = [] diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b99a409e02d1..467ad28226b4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -475,6 +475,7 @@ def prepare_input_tensors( # NOTE: We assume that all sequences in the group are all prompts or # all decodes. is_prompt = seq_group_metadata_list[0].is_prompt + # SANG-TODO set num prompt tokens and generations? # Prepare input tensors. if is_prompt: (input_tokens, input_positions, input_metadata, prompt_lens, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9df518d155ec..870c7b7edd40 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -21,6 +21,8 @@ from vllm.lora.request import LoRARequest from vllm.utils import is_hip +MAX_INT_32 = 2**31 - 1 + class Worker: """A worker class that executes (a partition of) the model on a GPU. @@ -143,6 +145,15 @@ def profile_num_available_blocks( self.model_runner.remove_all_loras() gc.collect() torch.cuda.empty_cache() + + # Flash attention only allows max(int32) slots. + # TODO: Remove this once we support int64 in flash-attention. + if self.model_config.flash_style: + num_gpu_blocks = min(num_gpu_blocks, + MAX_INT_32 // cache_block_size) + num_cpu_blocks = min(num_cpu_blocks, + MAX_INT_32 // cache_block_size) + return num_gpu_blocks, num_cpu_blocks def init_cache_engine(self, cache_config: CacheConfig) -> None: From 963db44683aabeb8ac78774b8618b874403cc7cb Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 00:46:13 -0800 Subject: [PATCH 05/28] flash attn enabled. 
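This drops the extra num_tokens tensor from reshape_and_cache_flash: the kernel now skips
padding tokens by checking for negative slot_mapping entries instead. For reviewers, a rough
Python sketch of the intended write semantics (illustrative only, not the CUDA kernel; the
helper name reshape_and_cache_flash_ref is made up for this note):

    import torch

    def reshape_and_cache_flash_ref(key, value, key_cache, value_cache, slot_mapping):
        # key/value:             [num_tokens, num_heads, head_size]
        # key_cache/value_cache: [num_blocks, block_size, num_heads, head_size]
        # slot_mapping:          [num_tokens]; negative entries mark padding tokens.
        block_size = key_cache.shape[1]
        for token_idx, slot in enumerate(slot_mapping.tolist()):
            if slot < 0:
                # Padding token that should be ignored.
                continue
            block_idx, block_offset = divmod(slot, block_size)
            key_cache[block_idx, block_offset] = key[token_idx]
            value_cache[block_idx, block_offset] = value[token_idx]

The updated unit test below checks the real kernel against the same contract.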
--- benchmarks/benchmark_latency.py | 17 ++++--- csrc/cache.h | 3 +- csrc/cache_kernels.cu | 7 ++- tests/chunked_prefill/test_correctness.py | 14 +++--- tests/conftest.py | 3 +- tests/kernels/test_cache.py | 22 ++------ vllm/engine/arg_utils.py | 3 +- vllm/model_executor/input_metadata.py | 28 ++++------- vllm/model_executor/layers/attention.py | 61 +++++++++-------------- vllm/model_executor/model_loader.py | 2 +- vllm/model_executor/models/__init__.py | 22 +++++++- vllm/model_executor/models/llama.py | 13 ++--- vllm/worker/model_runner.py | 8 +-- 13 files changed, 93 insertions(+), 110 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index d8083a826911..b915f913e7d1 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -65,23 +65,24 @@ def run_to_completion(profile_dir: Optional[str] = None): else: start_time = time.perf_counter() if args.use_sample: - batch = ( - SAMPLE_PROMPTS * - (args.batch_size // len(SAMPLE_PROMPTS) + 1))[:args.batch_size] + batch = (SAMPLE_PROMPTS * + (args.batch_size // len(SAMPLE_PROMPTS) + + 1))[:args.batch_size] outputs = llm.generate(prompts=batch, - sampling_params=sampling_params, - use_tqdm=False) + sampling_params=sampling_params, + use_tqdm=False) else: outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids, - sampling_params=sampling_params, - use_tqdm=False) + sampling_params=sampling_params, + use_tqdm=False) end_time = time.perf_counter() if args.verbose: for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print( - f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + f"Prompt: {prompt!r}, Generated text: {generated_text!r}" + ) latency = end_time - start_time return latency diff --git a/csrc/cache.h b/csrc/cache.h index ce3f96d59e83..1bca7e4e39a9 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -28,8 +28,7 @@ void reshape_and_cache_flash( torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - torch::Tensor& num_tokens); + torch::Tensor& slot_mapping); // Just for unittest void convert_fp8_e5m2( diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 7c44deafea6a..06da455117b0 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -174,6 +174,7 @@ __global__ void reshape_and_cache_kernel( const int64_t block_idx = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size; for (int i = threadIdx.x; i < n; i += blockDim.x) { const int64_t src_key_idx = token_idx * key_stride + i; @@ -285,6 +286,10 @@ __global__ void reshape_and_cache_flash_kernel( const int block_size) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. 
+ return; + } const int64_t block_idx = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; @@ -318,7 +323,7 @@ void reshape_and_cache_flash( torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping) // [num_tokens] + torch::Tensor& slot_mapping) { int num_tokens_padded = key.size(0); int num_heads = key.size(1); diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index abd0e1f7199a..1cabbbb9e3e1 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -11,7 +11,6 @@ "JackFram/llama-68m", ] -# SANG-TODO Read it from example.txt TEST_PROMPTS = [ # pylint: disable=line-too-long "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", @@ -30,11 +29,13 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("block_size", [256]) def test_models( vllm_runner, model: str, dtype: str, max_tokens: int, + block_size: int, ) -> None: """ verify the flash attention has the same output as page attention """ @@ -52,12 +53,10 @@ def test_models( gc.collect() torch.cuda.empty_cache() - flash_attn_model = vllm_runner( - model, - dtype=dtype, - enable_cuda_graph=False, - flash_style=True, - ) + flash_attn_model = vllm_runner(model, + dtype=dtype, + flash_style=True, + block_size=block_size) flash_attn_output_by_batchs = [] for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] @@ -75,7 +74,6 @@ def test_models( fa_output_ids, fa_output_str = flash_attn_outputs[i] vllm_output_ids, vllm_output_str = expected_outputs[ i % len(expected_outputs)] - print() assert fa_output_ids == vllm_output_ids, ( f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" f"Test{i}:\nflash ouput: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" diff --git a/tests/conftest.py b/tests/conftest.py index b579bd961bb9..4ccac964b999 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -165,6 +165,7 @@ def __init__( dtype: str = "half", disable_log_stats: bool = True, tensor_parallel_size: int = 1, + block_size: int = 32, flash_style: bool = False, **kwargs, ) -> None: @@ -177,7 +178,7 @@ def __init__( disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, flash_style=flash_style, - block_size=32, + block_size=block_size, **kwargs, ) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 4b7bb67b6ee6..23adc10df7cf 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -26,7 +26,6 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] -PADDINGS = [8, 16, 0] @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @@ -235,7 +234,6 @@ def test_swap_blocks( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("padding", PADDINGS) @torch.inference_mode() def test_reshape_and_cache_flash( kv_cache_factory, @@ -246,7 +244,6 @@ def test_reshape_and_cache_flash( num_blocks: int, dtype: torch.dtype, seed: int, - padding: int, ) -> None: random.seed(seed) torch.random.manual_seed(seed) @@ -274,27 +271,16 @@ def 
test_reshape_and_cache_flash( dtype, seed, flash_style=True) - assert len(key_caches) == 1 and len(value_caches) == 0 + assert len(key_caches) == 1 and len(value_caches) == 1 key_cache, value_cache = key_caches[0], value_caches[0] # Clone the KV caches. cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() - num_tokens = torch.tensor([num_tokens], dtype=torch.long, device="cuda") - - def pad_key_value(key: torch.Tensor, value: torch.Tensor, - pad_size: int) -> Tuple[torch.Tensor, torch.Tensor]: - if pad_size == 0: - return key, value - return F.pad(key, (0, 0, 0, 0, 0, pad_size)),\ - F.pad(value, (0, 0, 0, 0, 0, pad_size)) - - # kv shapes: (num_blocks, block_size, num_heads, head_size) - # pad tokens. - padded_key, padded_value = pad_key_value(key, value, padding) + # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache_flash(padded_key, padded_value, key_cache, - value_cache, slot_mapping) + cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping) # Run the reference implementation. block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e064f357e069..00c9f688e2a2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -295,7 +295,8 @@ def create_engine_configs( self.trust_remote_code, self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, - self.enforce_eager, self.max_context_len_to_capture, self.flash_style) + self.enforce_eager, self.max_context_len_to_capture, + self.flash_style) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 1beccd380355..97c23cdda069 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -9,6 +9,7 @@ class InputMetadata: Args: prompt_lens: Lengths of prompts. slot_mapping: The address to write the new KV to of each token. + index: token_id, value: address within kv_cache. max_context_len: The maximum context length. context_lens: the length of attention context for each sequence. block_tables: The block tables. (Seq id -> list of physical block) @@ -19,22 +20,13 @@ class InputMetadata: This might include padding. """ - def __init__( - self, - is_prompt: bool, - slot_mapping: torch.Tensor, - prompt_lens: Optional[torch.Tensor], - max_seq_len: Optional[int], - start_loc: Optional[torch.Tensor], - max_context_len: Optional[int], - context_lens: Optional[torch.Tensor], - block_tables: Optional[torch.Tensor], - use_cuda_graph: bool, - kv_cache_dtype: str, - # SANG-TODO - # num_prompt_tokens: int, - # num_generation_tokens: int, - ) -> None: + def __init__(self, is_prompt: bool, slot_mapping: torch.Tensor, + prompt_lens: Optional[torch.Tensor], + max_seq_len: Optional[int], start_loc: Optional[torch.Tensor], + max_context_len: Optional[int], + context_lens: Optional[torch.Tensor], + block_tables: Optional[torch.Tensor], use_cuda_graph: bool, + kv_cache_dtype: str, flash_style: bool) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens self.max_seq_len = max_seq_len @@ -45,6 +37,7 @@ def __init__( self.block_tables = block_tables self.use_cuda_graph = use_cuda_graph self.kv_cache_dtype = kv_cache_dtype + self.flash_style = flash_style # Set during the execution of the first attention op. 
# FIXME(woosuk): This is a hack. @@ -97,4 +90,5 @@ def __repr__(self) -> str: f"context_lens={self.context_lens}, " f"block_tables={self.block_tables}, " f"use_cuda_graph={self.use_cuda_graph}, " - f"kv_cache_dtype={self.kv_cache_dtype})") + f"kv_cache_dtype={self.kv_cache_dtype}), " + f"flash_style={self.flash_style}") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 3c0a7b54ab2b..51c47368e831 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -14,10 +14,7 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( context_attention_fwd) from vllm.utils import is_hip -from vllm.model_executor.layers.attention import ( - flash_attn_with_kvcache_paged, ) -# TODO(sang): Support varlen API. try: from flash_attn import flash_attn_with_kvcache except ImportError: @@ -49,7 +46,6 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, - flash_style: bool = False, ) -> None: super().__init__() self.num_heads = num_heads @@ -60,7 +56,6 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) - self.flash_style = flash_style assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -137,14 +132,10 @@ def forward( # vectors will not be cached. This happens during the initial memory # profiling run. if key_cache is not None and value_cache is not None: - if self.flash_style: + if input_metadata.flash_style: cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - input_metadata.slot_mapping.flatten() - ) + key, value, key_cache, value_cache, + input_metadata.slot_mapping.flatten()) else: cache_ops.reshape_and_cache( key, @@ -226,20 +217,20 @@ def forward( ) output = out.view_as(query) else: - # prefix-enabled attention - output = torch.empty_like(query) - if self.flash_attention: - query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + if input_metadata.flash_style: output = flash_attn_with_kvcache_paged( - query - key_cache: torch.Tensor, - value_cache: torch.Tensor, - self.scale + query.view(batch_size, seq_len, self.num_heads, + self.head_size), + key_cache, + value_cache, + self.scale, input_metadata.block_tables, input_metadata.context_lens, self.alibi_slopes, ) else: + # prefix-enabled attention + output = torch.empty_like(query) context_attention_fwd( query, key, @@ -247,7 +238,8 @@ def forward( output, key_cache, value_cache, - input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata. + block_tables, # [BS, max_block_per_request] input_metadata.start_loc, input_metadata.prompt_lens, input_metadata.context_lens, @@ -257,17 +249,12 @@ def forward( else: # Decoding run. - if self.flash_style: - query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + if input_metadata.flash_style: output = flash_attn_with_kvcache_paged( - query, - key_cache, - value_cache, - self.scale, - input_metadata.block_tables, - input_metadata.context_lens, - self.alibi_slopes - ) + query.view(batch_size, seq_len, self.num_heads, + self.head_size), key_cache, value_cache, + self.scale, input_metadata.block_tables, + input_metadata.context_lens, self.alibi_slopes) else: output = _paged_attention( query, @@ -407,19 +394,17 @@ def flash_attn_with_kvcache_paged( style key-value caches. 
Arguments: - output: [num_padded_tokens, num_heads, head_size], output tensor - to write. if None an new output tensor will be created. See https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py for other arguments. Returns: - output: [num_padded_tokens, num_heads, head_size] + output: [num_tokens, num_heads, head_size] """ block_size = value_cache.shape[1] assert block_size % 256 == 0, "only support block_size divisible by 256." - num_tokens, num_heads, head_size = query.shape + _, _, num_heads, head_size = query.shape out = flash_attn_with_kvcache( - query.view(num_tokens, 1, num_heads, head_size), + query, key_cache, value_cache, # Inplace update is slow. We don't use it. @@ -439,4 +424,6 @@ def flash_attn_with_kvcache_paged( alibi_slopes=alibi_slopes, num_splits=0, ) - return out.view(num_tokens, num_heads, head_size) + + # num_tokens == batch_size * seqlen + return out.view(-1, num_heads, head_size) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index ebe092b5d62b..40a0a9cf394b 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -29,7 +29,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: architectures = ["QuantMixtralForCausalLM"] for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) + model_cls = ModelRegistry.load_model_cls(arch, model_config) if model_cls is not None: return model_cls raise ValueError( diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 66d28207d664..5ed1cb5dbc5c 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -5,6 +5,7 @@ from vllm.logger import init_logger from vllm.utils import is_hip +from vllm.config import ModelConfig logger = init_logger(__name__) @@ -61,11 +62,18 @@ "Sliding window attention is not yet supported in ROCm's flash attention", } +_MODEL_CLASSES_SUPPORT_FLASH_ATTN = { + arch_and_class[1] + for _, arch_and_class in _MODELS.items() + if arch_and_class[1] == "LlamaForCausalLM" +} + class ModelRegistry: @staticmethod - def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + def load_model_cls(model_arch: str, + model_config: ModelConfig) -> Optional[Type[nn.Module]]: if model_arch not in _MODELS: return None if is_hip(): @@ -81,7 +89,17 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: module_name, model_cls_name = _MODELS[model_arch] module = importlib.import_module( f"vllm.model_executor.models.{module_name}") - return getattr(module, model_cls_name, None) + model_cls = getattr(module, model_cls_name, None) + + if (model_config.flash_style + and model_cls_name not in _MODEL_CLASSES_SUPPORT_FLASH_ATTN): + raise ValueError( + f"{model_config.model} doesn't support " + "flash attention in vLLM, but " + "flash_style=True is given. 
Choose one of models, " + f"{_MODEL_CLASSES_SUPPORT_FLASH_ATTN} to use " + "flash_style=True.") + return model_cls @staticmethod def get_supported_archs() -> List[str]: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 9fd4425bca02..b7f6b8f3ec37 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -93,7 +93,6 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, bias: bool = False, sliding_window: Optional[int] = None, - flash_style: bool = False, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -144,8 +143,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window, - flash_style=flash_style) + sliding_window=sliding_window) def forward( self, @@ -153,7 +151,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - flash_style: bool = False, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -170,7 +167,6 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, - flash_style: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -190,7 +186,6 @@ def __init__( linear_method=linear_method, bias=getattr(config, "bias", False), sliding_window=sliding_window, - flash_style=flash_style, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -239,7 +234,6 @@ def __init__( config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, lora_config: Optional[LoRAConfig] = None, - flash_style: bool = False, ) -> None: super().__init__() self.config = config @@ -254,7 +248,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method, flash_style) + LlamaDecoderLayer(config, linear_method) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -314,12 +308,11 @@ def __init__( config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, lora_config: Optional[LoRAConfig] = None, - flash_style: bool = False, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = LlamaModel(config, linear_method, lora_config=lora_config, flash_style=flash_style) + self.model = LlamaModel(config, linear_method, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 467ad28226b4..84470d0fa114 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -249,7 +249,7 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, - ) + flash_style=self.model_config.flash_style) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -377,7 +377,7 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, - ) + flash_style=self.model_config.flash_style) return (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -539,7 +539,7 @@ def prepare_input_tensors( block_tables=metadata_dict["block_tables"], use_cuda_graph=metadata_dict["use_cuda_graph"], 
kv_cache_dtype=metadata_dict["kv_cache_dtype"], - ) + flash_style=self.model_config.flash_style) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -723,7 +723,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: block_tables=block_tables[:batch_size], use_cuda_graph=True, kv_cache_dtype=self.kv_cache_dtype, - ) + flash_style=self.model_config.flash_style) if self.lora_config: lora_mapping = LoRAMapping( From 2c1bb6c9e6887c0e8efc1cada38c638bc4004529 Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 00:51:14 -0800 Subject: [PATCH 06/28] support every model --- tests/chunked_prefill/test_correctness.py | 2 +- vllm/model_executor/models/__init__.py | 15 --------------- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index b79156d2d48a..d63ac9f5429e 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -8,7 +8,7 @@ from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel MODELS = [ - "JackFram/llama-68m", + # "JackFram/llama-68m", "facebook/opt-125m", ] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index b624395c27b5..ea150ef9c6f5 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -62,13 +62,6 @@ "Sliding window attention is not yet supported in ROCm's flash attention", } -# TODO(sang): We may not need this. I think it should work with everything. -_MODEL_CLASSES_SUPPORT_FLASH_ATTN = { - arch_and_class[1] - for _, arch_and_class in _MODELS.items() - if arch_and_class[1] == "LlamaForCausalLM" -} - # Models not supported by Neuron. _NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"} @@ -102,14 +95,6 @@ def load_model_cls(model_arch: str, f"vllm.model_executor.models.{module_name}") model_cls = getattr(module, model_cls_name, None) - if (model_config.flash_style - and model_cls_name not in _MODEL_CLASSES_SUPPORT_FLASH_ATTN): - raise ValueError( - f"{model_config.model} doesn't support " - "flash attention in vLLM, but " - "flash_style=True is given. Choose one of models, " - f"{_MODEL_CLASSES_SUPPORT_FLASH_ATTN} to use " - "flash_style=True.") return model_cls @staticmethod From 2bb5e624a4ed843613b7439ef418eb21dac2fa50 Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 01:28:13 -0800 Subject: [PATCH 07/28] Fixed broken tests. 
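Beyond the test fixes, this threads the runner-level flash_style flag into every InputMetadata
that ModelRunner builds. End to end, the feature is expected to be enabled roughly like this
(hypothetical usage on this branch only; the flash_style and block_size engine arguments come
from earlier patches in this series and do not exist upstream):

    from vllm import LLM, SamplingParams

    llm = LLM(model="JackFram/llama-68m",
              flash_style=True,   # flash-style paged KV caches
              block_size=256)     # flash style currently needs block_size >= 256

    params = SamplingParams(temperature=0.0, max_tokens=32)
    out = llm.generate(["vLLM is a high-throughput inference engine."], params)
    print(out[0].outputs[0].text)

tests/chunked_prefill/test_correctness.py below runs the equivalent comparison against the
default paged-attention path.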
--- tests/chunked_prefill/test_correctness.py | 2 +- tests/worker/test_model_runner.py | 2 + vllm/worker/model_runner.py | 53 ++++++++++++----------- 3 files changed, 30 insertions(+), 27 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index d63ac9f5429e..6659b83bfcb6 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -77,5 +77,5 @@ def test_models( i % len(expected_outputs)] assert fa_output_ids == vllm_output_ids, ( f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" - f"Test{i}:\nflash ouput: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" + f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" ) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index f44895a728c7..94c9b45157ec 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -3,6 +3,8 @@ from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + DeviceConfig) def test_prepare_prompt(): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fa337a68f87d..e95b147cb25d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -55,6 +55,9 @@ def __init__( # FIXME(woosuk): This is a hack to make the tests work. Refactor this. self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None) + self.flash_style = (self.model_config.flash_style + if model_config is not None else False) + self.device_config = (device_config if device_config is not None else DeviceConfig()) self.device = self.device_config.device @@ -245,18 +248,17 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - input_metadata = InputMetadata( - is_prompt=True, - slot_mapping=slot_mapping, - prompt_lens=prompt_lens_tensor, - max_seq_len=max_prompt_len, - start_loc=start_loc_tensor, - max_context_len=None, - context_lens=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.model_config.flash_style) + input_metadata = InputMetadata(is_prompt=True, + slot_mapping=slot_mapping, + prompt_lens=prompt_lens_tensor, + max_seq_len=max_prompt_len, + start_loc=start_loc_tensor, + max_context_len=None, + context_lens=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + kv_cache_dtype=self.kv_cache_dtype, + flash_style=self.flash_style) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -373,18 +375,17 @@ def _prepare_decode( _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping ] - input_metadata = InputMetadata( - is_prompt=False, - slot_mapping=slot_mapping, - prompt_lens=None, - max_seq_len=None, - start_loc=None, - max_context_len=max_context_len, - context_lens=context_lens, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.model_config.flash_style) + input_metadata = InputMetadata(is_prompt=False, + slot_mapping=slot_mapping, + prompt_lens=None, + max_seq_len=None, + start_loc=None, + max_context_len=max_context_len, + context_lens=context_lens, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + kv_cache_dtype=self.kv_cache_dtype, + 
flash_style=self.flash_style) return (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -547,7 +548,7 @@ def prepare_input_tensors( block_tables=metadata_dict["block_tables"], use_cuda_graph=metadata_dict["use_cuda_graph"], kv_cache_dtype=metadata_dict["kv_cache_dtype"], - flash_style=self.model_config.flash_style) + flash_style=self.flash_style) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -731,7 +732,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: block_tables=block_tables[:batch_size], use_cuda_graph=True, kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.model_config.flash_style) + flash_style=self.flash_style) if self.lora_config: lora_mapping = LoRAMapping( From 78bb887bc625c473dc3a5ce5369617a743048c3e Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 08:49:47 -0800 Subject: [PATCH 08/28] ip --- benchmarks/benchmark_latency.py | 2 + .../kernels/benchmark_paged_attention.py | 5 +- tests/chunked_prefill/test_correctness.py | 8 +-- tests/kernels/test_attention.py | 2 + tests/kernels/test_flash_attention.py | 52 +++++++++---------- vllm/model_executor/layers/attention.py | 2 + 6 files changed, 38 insertions(+), 33 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index b915f913e7d1..7edad46a6172 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -33,6 +33,8 @@ def main(args: argparse.Namespace): enforce_eager=args.enforce_eager, kv_cache_dtype=args.kv_cache_dtype, device=args.device, + block_size=args.block_size, + flash_style=args.flash_style, ) sampling_params = SamplingParams( diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index d921dea1220e..14ca569a3d66 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -7,6 +7,7 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random from vllm._C import ops +from vllm.model_executor.layers.attention import flash_attn_with_kvcache_paged NUM_BLOCKS = 1024 PARTITION_SIZE = 512 @@ -131,6 +132,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: alibi_slopes, kv_cache_dtype, ) + elif version == "flash": + else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() @@ -158,7 +161,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: description="Benchmark the paged attention kernel.") parser.add_argument("--version", type=str, - choices=["v1", "v2"], + choices=["v1", "v2", "flash"], default="v2") parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--context-len", type=int, default=4096) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index 6659b83bfcb6..e7ed28b86e04 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -8,7 +8,7 @@ from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel MODELS = [ - # "JackFram/llama-68m", + "JackFram/llama-68m", "facebook/opt-125m", ] @@ -58,10 +58,10 @@ def test_models( dtype=dtype, flash_style=True, block_size=block_size) - flash_attn_output_by_batchs = [] + flash_attn_output_by_batches = [] for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] - flash_attn_output_by_batchs.append( + flash_attn_output_by_batches.append( 
flash_attn_model.generate_greedy(prompts, max_tokens)) del flash_attn_model @@ -70,7 +70,7 @@ def test_models( gc.collect() torch.cuda.empty_cache() - for flash_attn_outputs in flash_attn_output_by_batchs: + for flash_attn_outputs in flash_attn_output_by_batches: for i in range(len(flash_attn_outputs)): fa_output_ids, fa_output_str = flash_attn_outputs[i] vllm_output_ids, vllm_output_str = expected_outputs[ diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index fb571de63d4e..e2f59dff1ca2 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -280,6 +280,7 @@ def ref_multi_query_kv_attention( num_seqs = len(cu_seq_lens) - 1 ref_outputs = [] for i in range(num_seqs): + breakpoint() start_idx = cu_seq_lens[i] end_idx = cu_seq_lens[i + 1] seq_len = end_idx - start_idx @@ -299,6 +300,7 @@ def ref_multi_query_kv_attention( ) ref_outputs.append(ref_output) ref_output = torch.cat(ref_outputs, dim=0) + breakpoint() return ref_output diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index a9c43e31edc0..7a98a0db9ca5 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -1,5 +1,5 @@ import random -from typing import Optional, Tuple +from typing import Optional, Tuple, List import pytest import torch @@ -19,7 +19,7 @@ DTYPES = [torch.half, torch.bfloat16] NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing -NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing +NUM_HEADS = [(1, 40), (40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS_SMALL = NUM_HEADS # head size should be bigger than or equal to block size. HEAD_SIZES = [256] @@ -29,8 +29,7 @@ BLOCK_SIZES = [256] USE_ALIBI = [False, True] SEEDS = [0] -# PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] -PAD_CONFIGS = [(0, 0)] +PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] def pad_attention_inputs( @@ -84,10 +83,8 @@ def ref_single_query_cached_kv_attention( context_lens: torch.Tensor, scale: float, alibi_slopes: Optional[torch.Tensor], - flash_style: bool = False, ) -> None: num_query_heads = query.shape[1] - num_kv_heads = value_cache.shape[-2] head_size = value_cache.shape[-1] block_size = value_cache.shape[-3] num_seqs = query.shape[0] @@ -105,18 +102,11 @@ def ref_single_query_cached_kv_attention( block_number = int(block_table[j // block_size]) block_offset = j % block_size - if flash_style: - k = key_cache[block_number, block_offset, :, :] - else: - k = key_cache[block_number, :, :, block_offset, :] - k = k.reshape(num_kv_heads, head_size) + k = key_cache[block_number, block_offset, :, :] keys.append(k) v = value_cache[block_number, :, :, block_offset] - if flash_style: - v = value_cache[block_number, block_offset, :, :] - else: - v = value_cache[block_number, :, :, block_offset] + v = value_cache[block_number, block_offset, :, :] values.append(v) keys = torch.stack(keys, dim=0) values = torch.stack(values, dim=0) @@ -138,12 +128,20 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("use_alibi", [False, True]) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +# 
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +# @pytest.mark.parametrize("num_heads", NUM_HEADS) +# @pytest.mark.parametrize("head_size", HEAD_SIZES) +# @pytest.mark.parametrize("use_alibi", [False, True]) +# @pytest.mark.parametrize("block_size", BLOCK_SIZES) +# @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +# @pytest.mark.parametrize("seed", SEEDS) +# @pytest.mark.parametrize("pad_config", PAD_CONFIGS) +@pytest.mark.parametrize("num_seqs", [3]) +@pytest.mark.parametrize("num_heads", [(40,40)]) +@pytest.mark.parametrize("head_size", [256]) +@pytest.mark.parametrize("use_alibi", [True]) +@pytest.mark.parametrize("block_size", [256]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("pad_config", PAD_CONFIGS) @torch.inference_mode() @@ -164,14 +162,15 @@ def test_flash_paged_attention( scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, + query = torch.empty(num_seqs, # batch_size + 1, # seqlen num_query_heads, head_size, dtype=dtype, device="cuda") query.uniform_(-scale, scale) - assert num_query_heads % num_kv_heads == 0 + # assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads alibi_slopes = None if use_alibi: @@ -195,6 +194,7 @@ def test_flash_paged_attention( ] block_tables.append(block_table) block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + breakpoint() # Create the KV caches. key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, @@ -210,7 +210,7 @@ def test_flash_paged_attention( # Call the paged attention kernel. output = torch.empty_like(query) - padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ + padded_query, padded_block_table, padded_context_lens, _ = \ pad_attention_inputs(pad_config, block_size, query, block_tables, context_lens, max_context_len) @@ -236,10 +236,6 @@ def test_flash_paged_attention( context_lens, scale, alibi_slopes, - flash_style=True, ) - # NOTE(woosuk): Due to the kernel-level differences in the two - # implementations, there is a small numerical difference in the two - # outputs. Thus, we use a relaxed tolerance for the test. assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 51c47368e831..64b1beaa6205 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -229,6 +229,7 @@ def forward( self.alibi_slopes, ) else: + breakpoint() # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -360,6 +361,7 @@ def _paged_attention( device=output.device, ) max_logits = torch.empty_like(exp_sums) + breakpoint() ops.paged_attention_v2( output, exp_sums, From 74ac900cbf56dedea8d93d2b948b81e1aeaa5085 Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 21:32:59 -0800 Subject: [PATCH 09/28] seems to work. 
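Adds a working "flash" mode to the paged-attention benchmark and removes the leftover
breakpoints. The decode-path call shape used by both the benchmark and the attention layer
looks roughly like this (a sketch under assumed shapes; it needs a CUDA device plus the
flash-attn build this branch targets):

    import torch
    from vllm.model_executor.layers.attention import flash_attn_with_kvcache_paged

    num_seqs, num_heads, head_size = 4, 8, 256
    num_blocks, block_size, max_blocks_per_seq = 128, 256, 4

    query = torch.randn(num_seqs, 1, num_heads, head_size,
                        dtype=torch.half, device="cuda")   # one new token per sequence
    key_cache = torch.randn(num_blocks, block_size, num_heads, head_size,
                            dtype=torch.half, device="cuda")
    value_cache = torch.randn_like(key_cache)
    block_tables = torch.randint(0, num_blocks, (num_seqs, max_blocks_per_seq),
                                 dtype=torch.int, device="cuda")
    context_lens = torch.randint(1, block_size * max_blocks_per_seq, (num_seqs,),
                                 dtype=torch.int, device="cuda")

    out = flash_attn_with_kvcache_paged(query, key_cache, value_cache,
                                        1.0 / head_size**0.5,
                                        block_tables, context_lens)
    print(out.shape)  # [num_seqs, num_heads, head_size]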
--- .../kernels/benchmark_paged_attention.py | 34 +++++++++++++------ tests/chunked_prefill/test_correctness.py | 2 -- tests/kernels/test_attention.py | 2 -- tests/kernels/test_cache.py | 1 - tests/kernels/test_flash_attention.py | 15 ++++---- tests/worker/test_model_runner.py | 2 -- vllm/model_executor/layers/attention.py | 5 ++- 7 files changed, 33 insertions(+), 28 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 14ca569a3d66..7b3c6405e259 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -65,14 +65,17 @@ def main( block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) # Create the KV cache. - key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, - block_size, - 1, - num_kv_heads, - head_size, - kv_cache_dtype, - dtype, - device=device) + flash_style = version == "flash" + key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + flash_style=flash_style) key_cache, value_cache = key_caches[0], value_caches[0] # Prepare for the paged attention kernel. @@ -133,7 +136,15 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: kv_cache_dtype, ) elif version == "flash": - + flash_attn_with_kvcache_paged( + query.view(num_seqs, 1, num_query_heads, head_size), + key_cache, + value_cache, + scale, + block_tables, + context_lens, + alibi_slopes=alibi_slopes, + ) else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() @@ -171,7 +182,10 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: type=int, choices=[64, 80, 96, 112, 128, 256], default=128) - parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) + parser.add_argument("--block-size", + type=int, + choices=[16, 32, 256], + default=16) parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--dtype", type=str, diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index e7ed28b86e04..72fb9f73766c 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -1,7 +1,5 @@ import gc -from typing import List - import pytest import torch diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index e2f59dff1ca2..fb571de63d4e 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -280,7 +280,6 @@ def ref_multi_query_kv_attention( num_seqs = len(cu_seq_lens) - 1 ref_outputs = [] for i in range(num_seqs): - breakpoint() start_idx = cu_seq_lens[i] end_idx = cu_seq_lens[i + 1] seq_len = end_idx - start_idx @@ -300,7 +299,6 @@ def ref_multi_query_kv_attention( ) ref_outputs.append(ref_output) ref_output = torch.cat(ref_outputs, dim=0) - breakpoint() return ref_output diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 23adc10df7cf..d2de4105b7f3 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -7,7 +7,6 @@ from vllm._C import cache_ops from vllm.utils import is_hip -import torch.nn.functional as F COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index 7a98a0db9ca5..74c9d3008f17 100644 --- 
a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -1,5 +1,5 @@ import random -from typing import Optional, Tuple, List +from typing import Optional, Tuple import pytest import torch @@ -125,7 +125,8 @@ def ref_single_query_cached_kv_attention( out = ref_masked_attention(q, keys, values, scale, alibi_bias) out = out.view(num_query_heads, head_size) - output[i].copy_(out, non_blocking=True) + # output[i].copy_(out, non_blocking=True) + output[i].copy_(out) # @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @@ -137,13 +138,13 @@ def ref_single_query_cached_kv_attention( # @pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("pad_config", PAD_CONFIGS) @pytest.mark.parametrize("num_seqs", [3]) -@pytest.mark.parametrize("num_heads", [(40,40)]) +@pytest.mark.parametrize("num_heads", [(40, 40)]) @pytest.mark.parametrize("head_size", [256]) @pytest.mark.parametrize("use_alibi", [True]) @pytest.mark.parametrize("block_size", [256]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("pad_config", PAD_CONFIGS) +@pytest.mark.parametrize("pad_config", [(0, 0)]) @torch.inference_mode() def test_flash_paged_attention( kv_cache_factory, @@ -162,8 +163,7 @@ def test_flash_paged_attention( scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, # batch_size - 1, # seqlen + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype, @@ -194,7 +194,6 @@ def test_flash_paged_attention( ] block_tables.append(block_table) block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - breakpoint() # Create the KV caches. key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, @@ -215,7 +214,7 @@ def test_flash_paged_attention( block_tables, context_lens, max_context_len) output = flash_attn_with_kvcache_paged( - padded_query, + padded_query.view(num_seqs, 1, num_query_heads, head_size), key_cache, value_cache, scale, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 94c9b45157ec..f44895a728c7 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -3,8 +3,6 @@ from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner -from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, - DeviceConfig) def test_prepare_prompt(): diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 64b1beaa6205..e04732d67d13 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -229,7 +229,6 @@ def forward( self.alibi_slopes, ) else: - breakpoint() # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -361,7 +360,6 @@ def _paged_attention( device=output.device, ) max_logits = torch.empty_like(exp_sums) - breakpoint() ops.paged_attention_v2( output, exp_sums, @@ -403,7 +401,8 @@ def flash_attn_with_kvcache_paged( output: [num_tokens, num_heads, head_size] """ block_size = value_cache.shape[1] - assert block_size % 256 == 0, "only support block_size divisible by 256." + assert block_size % 256 == 0, ("only support block_size divisible by 256. 
" + f"Current block size: {block_size}") _, _, num_heads, head_size = query.shape out = flash_attn_with_kvcache( query, From 71bdada54fa2294e1d6483d88e4990a943c6817d Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 1 Mar 2024 06:38:28 -0800 Subject: [PATCH 10/28] . --- tests/chunked_prefill/test_correctness.py | 4 +-- tests/models/test_models.py | 26 +++++++++--------- vllm/engine/llm_engine.py | 2 ++ vllm/entrypoints/llm.py | 1 + vllm/model_executor/layers/attention.py | 32 ++++++++++++++++++++++- vllm/worker/model_runner.py | 14 +++++++++- vllm/worker/worker.py | 3 +++ 7 files changed, 65 insertions(+), 17 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index 72fb9f73766c..b9b6353e04e6 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -7,7 +7,7 @@ MODELS = [ "JackFram/llama-68m", - "facebook/opt-125m", + # "facebook/opt-125m", ] TEST_PROMPTS = [ @@ -57,7 +57,7 @@ def test_models( flash_style=True, block_size=block_size) flash_attn_output_by_batches = [] - for i in range(10): + for i in range(2): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] flash_attn_output_by_batches.append( flash_attn_model.generate_greedy(prompts, max_tokens)) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index e44452e9893c..02afd3ac9868 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,19 +6,19 @@ MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", - "mistralai/Mistral-7B-v0.1", - "Deci/DeciLM-7b", - "tiiuae/falcon-7b", - "gpt2", - "bigcode/tiny_starcoder_py", - "EleutherAI/gpt-j-6b", - "EleutherAI/pythia-70m", - "bigscience/bloom-560m", - "mosaicml/mpt-7b", - "microsoft/phi-2", - "stabilityai/stablelm-3b-4e1t", - "allenai/OLMo-1B", + # "meta-llama/Llama-2-7b-hf", + # "mistralai/Mistral-7B-v0.1", + # "Deci/DeciLM-7b", + # "tiiuae/falcon-7b", + # "gpt2", + # "bigcode/tiny_starcoder_py", + # "EleutherAI/gpt-j-6b", + # "EleutherAI/pythia-70m", + # "bigscience/bloom-560m", + # "mosaicml/mpt-7b", + # "microsoft/phi-2", + # "stabilityai/stablelm-3b-4e1t", + # "allenai/OLMo-1B", ] diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6f5af71426d7..9f303e4b8209 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -831,6 +831,8 @@ def step(self) -> List[RequestOutput]: >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + print("SANG-TODO step seq_group_metadata_list length: ", + len(seq_group_metadata_list)) if not scheduler_outputs.is_empty(): # Execute the model. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fc82018d18eb..d29fabb3f47a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -152,6 +152,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. 
""" + print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index e04732d67d13..b9b60cd8c9ad 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -122,6 +122,7 @@ def forward( shape = [batch_size, seq_len, num_heads * head_size] """ batch_size, seq_len, hidden_size = query.shape + print("SANG-TODO query size: ", query.size()) # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -133,10 +134,12 @@ def forward( # profiling run. if key_cache is not None and value_cache is not None: if input_metadata.flash_style: + print("SANG-TODO reshape cache flash.") cache_ops.reshape_and_cache_flash( key, value, key_cache, value_cache, input_metadata.slot_mapping.flatten()) else: + print("SANG-TODO reshape cache.") cache_ops.reshape_and_cache( key, value, @@ -150,6 +153,33 @@ def forward( # normal attention if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): + print("SANG-TODO flash attn is used.") + print( + "SANG-TODO query size: ", + query.view(batch_size, seq_len, self.num_heads, + self.head_size).size()) + # if key_cache is not None and value_cache is not None: + # output2 = flash_attn_with_kvcache_paged( + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size), + # key_cache, + # value_cache, + # self.scale, + # input_metadata.block_tables, + # input_metadata.context_lens + seq_len, + # self.alibi_slopes, + # ) + # from flash_attn import flash_attn_func + # breakpoint() + # output3 = flash_attn_func( + # q=query.view(batch_size, seq_len, self.num_heads, + # self.head_size), + # k=key.view(batch_size, seq_len, self.num_kv_heads, self.head_size), + # v=value.view(batch_size, seq_len, self.num_kv_heads, self.head_size), + # softmax_scale=self.scale, + # causal=True, + # alibi_slopes=self.alibi_slopes, + # ) if self.num_kv_heads != self.num_heads: # As of Nov 2023, xformers only supports MHA. 
For MQA/GQA, # project the key and value tensors to the desired number of @@ -217,6 +247,7 @@ def forward( ) output = out.view_as(query) else: + # prefix-enabled attention if input_metadata.flash_style: output = flash_attn_with_kvcache_paged( query.view(batch_size, seq_len, self.num_heads, @@ -229,7 +260,6 @@ def forward( self.alibi_slopes, ) else: - # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( query, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e95b147cb25d..855daeddd1e2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -138,6 +138,8 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] + print("SANG-TODO # of requests (seq_group_metadata_list): ", + len(seq_group_metadata_list)) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -150,14 +152,22 @@ def _prepare_prompt( prompt_lens.append(prompt_len) prefix_len = 0 prefix = seq_group_metadata.prefix + print("SANG-TODO prefix, ", prefix) if prefix is not None and prefix.computed: prefix_len = prefix.get_length() prompt_tokens = prompt_tokens[prefix_len:] prefix_block_tables.append(prefix.get_block_numbers()) + context_len = prefix_len else: prefix_block_tables.append([]) + # if seq_group_metadata.block_tables is None: + # prefix_block_tables.append([]) + # else: + # prefix_block_tables.append( + # seq_group_metadata.block_tables[seq_id]) + context_len = prefix_len # actual prompt lens - context_lens.append(prefix_len) + context_lens.append(context_len) subquery_lens.append(prompt_len - prefix_len) input_tokens.append(prompt_tokens) @@ -487,10 +497,12 @@ def prepare_input_tensors( # SANG-TODO set num prompt tokens and generations? # Prepare input tensors. if is_prompt: + print("SANG-TODO execute model prompt.") (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_prompt(seq_group_metadata_list) else: + print("SANG-TODO execute model decode.") (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_decode(seq_group_metadata_list) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 870c7b7edd40..a2c5f8e0b1e0 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -117,6 +117,7 @@ def profile_num_available_blocks( gpu_memory_utilization: The fraction of the total GPU memory to use. cpu_swap_space: The size of the CPU swap space in bytes. """ + print("SANG-TODO profile_num_available_blocks") # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. 
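# A simplified sketch of the sizing this function performs (an illustrative
# assumption, not the exact vLLM implementation): run the profiling forward
# pass, then carve the KV cache out of whatever memory remains under the
# utilization cap.
import torch

def estimate_num_gpu_blocks(gpu_memory_utilization: float,
                            cache_block_size: int) -> int:
    free_memory, total_memory = torch.cuda.mem_get_info()
    peak_usage = total_memory - free_memory
    budget = total_memory * gpu_memory_utilization - peak_usage
    return max(int(budget // cache_block_size), 0)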
torch.cuda.empty_cache() @@ -153,6 +154,7 @@ def profile_num_available_blocks( MAX_INT_32 // cache_block_size) num_cpu_blocks = min(num_cpu_blocks, MAX_INT_32 // cache_block_size) + print("SANG-TODO profile_num_available_blocks done") return num_gpu_blocks, num_cpu_blocks @@ -205,6 +207,7 @@ def execute_model( blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, ) -> Optional[SamplerOutput]: + print("SANG-TODO execute model.") if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups = len(seq_group_metadata_list) From d4c3b5ddb1a66eae256aee8257b4a4bdeb1dad54 Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 1 Mar 2024 06:46:29 -0800 Subject: [PATCH 11/28] ip? --- vllm/worker/model_runner.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 855daeddd1e2..406e110550fa 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -159,13 +159,12 @@ def _prepare_prompt( prefix_block_tables.append(prefix.get_block_numbers()) context_len = prefix_len else: - prefix_block_tables.append([]) - # if seq_group_metadata.block_tables is None: - # prefix_block_tables.append([]) - # else: - # prefix_block_tables.append( - # seq_group_metadata.block_tables[seq_id]) - context_len = prefix_len + if seq_group_metadata.block_tables is None: + prefix_block_tables.append([]) + else: + prefix_block_tables.append( + seq_group_metadata.block_tables[seq_id]) + context_len = prompt_len # actual prompt lens context_lens.append(context_len) subquery_lens.append(prompt_len - prefix_len) From baef7c66e7c22698d3b86f0a8029cac073eb1bee Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 1 Mar 2024 07:10:20 -0800 Subject: [PATCH 12/28] block tables updated correctly --- tests/chunked_prefill/test_correctness.py | 2 +- vllm/model_executor/input_metadata.py | 4 +++- vllm/model_executor/layers/attention.py | 6 ++++-- vllm/worker/model_runner.py | 15 +++++++++++---- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index b9b6353e04e6..561ef24713dd 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -57,7 +57,7 @@ def test_models( flash_style=True, block_size=block_size) flash_attn_output_by_batches = [] - for i in range(2): + for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] flash_attn_output_by_batches.append( flash_attn_model.generate_greedy(prompts, max_tokens)) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 97c23cdda069..00acf71a10b5 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -26,7 +26,8 @@ def __init__(self, is_prompt: bool, slot_mapping: torch.Tensor, max_context_len: Optional[int], context_lens: Optional[torch.Tensor], block_tables: Optional[torch.Tensor], use_cuda_graph: bool, - kv_cache_dtype: str, flash_style: bool) -> None: + kv_cache_dtype: str, flash_style: bool, + prefix_enabled: bool) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens self.max_seq_len = max_seq_len @@ -38,6 +39,7 @@ def __init__(self, is_prompt: bool, slot_mapping: torch.Tensor, self.use_cuda_graph = use_cuda_graph self.kv_cache_dtype = kv_cache_dtype self.flash_style = flash_style + self.prefix_enabled = prefix_enabled # Set during the execution of the first 
attention op. # FIXME(woosuk): This is a hack. diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index b9b60cd8c9ad..051472a9d09b 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -152,7 +152,8 @@ def forward( if input_metadata.is_prompt: # normal attention if (key_cache is None or value_cache is None - or input_metadata.block_tables.numel() == 0): + # or input_metadata.block_tables.numel() == 0): + or not input_metadata.prefix_enabled): print("SANG-TODO flash attn is used.") print( "SANG-TODO query size: ", @@ -256,7 +257,8 @@ def forward( value_cache, self.scale, input_metadata.block_tables, - input_metadata.context_lens, + # input_metadata.context_lens, + # seq_len, self.alibi_slopes, ) else: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 406e110550fa..db4165141b2a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -159,6 +159,7 @@ def _prepare_prompt( prefix_block_tables.append(prefix.get_block_numbers()) context_len = prefix_len else: + prefix_block_tables.append([]) if seq_group_metadata.block_tables is None: prefix_block_tables.append([]) else: @@ -267,7 +268,9 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.flash_style) + flash_style=self.flash_style, + prefix_enabled=prefix is not None + and prefix.computed) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -394,7 +397,8 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.flash_style) + flash_style=self.flash_style, + prefix_enabled=False) return (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -540,6 +544,7 @@ def prepare_input_tensors( sampling_metadata.selected_token_indices, "lora_requests": lora_requests, "lora_mapping": lora_mapping, + "prefix_enabled": input_metadata.prefix_enabled } broadcast_tensor_dict(metadata_dict, src=0) else: @@ -559,7 +564,8 @@ def prepare_input_tensors( block_tables=metadata_dict["block_tables"], use_cuda_graph=metadata_dict["use_cuda_graph"], kv_cache_dtype=metadata_dict["kv_cache_dtype"], - flash_style=self.flash_style) + flash_style=self.flash_style, + prefix_enabled=metadata_dict["prefix_enabled"]) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -743,7 +749,8 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: block_tables=block_tables[:batch_size], use_cuda_graph=True, kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.flash_style) + flash_style=self.flash_style, + prefix_enabled=False) if self.lora_config: lora_mapping = LoRAMapping( From a12ec689d410dc15970cda7ad922dfeb528b1957 Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 1 Mar 2024 07:12:15 -0800 Subject: [PATCH 13/28] hopefully tests pass --- vllm/engine/llm_engine.py | 4 +-- vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/attention.py | 38 ++++++------------------- vllm/worker/model_runner.py | 10 +++---- vllm/worker/worker.py | 6 ++-- 5 files changed, 19 insertions(+), 41 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9f303e4b8209..a7784a2ca443 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -831,8 +831,8 @@ def step(self) -> List[RequestOutput]: >>> break """ 
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - print("SANG-TODO step seq_group_metadata_list length: ", - len(seq_group_metadata_list)) + # print("SANG-TODO step seq_group_metadata_list length: ", + # len(seq_group_metadata_list)) if not scheduler_outputs.is_empty(): # Execute the model. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d29fabb3f47a..2af5a47f66e8 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -152,7 +152,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ - print("SANG-TODO generate: ", prompts, prompt_token_ids) + # print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 051472a9d09b..b99efdcc160a 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -122,7 +122,7 @@ def forward( shape = [batch_size, seq_len, num_heads * head_size] """ batch_size, seq_len, hidden_size = query.shape - print("SANG-TODO query size: ", query.size()) + # print("SANG-TODO query size: ", query.size()) # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -134,12 +134,12 @@ def forward( # profiling run. if key_cache is not None and value_cache is not None: if input_metadata.flash_style: - print("SANG-TODO reshape cache flash.") + # print("SANG-TODO reshape cache flash.") cache_ops.reshape_and_cache_flash( key, value, key_cache, value_cache, input_metadata.slot_mapping.flatten()) else: - print("SANG-TODO reshape cache.") + # print("SANG-TODO reshape cache.") cache_ops.reshape_and_cache( key, value, @@ -154,33 +154,11 @@ def forward( if (key_cache is None or value_cache is None # or input_metadata.block_tables.numel() == 0): or not input_metadata.prefix_enabled): - print("SANG-TODO flash attn is used.") - print( - "SANG-TODO query size: ", - query.view(batch_size, seq_len, self.num_heads, - self.head_size).size()) - # if key_cache is not None and value_cache is not None: - # output2 = flash_attn_with_kvcache_paged( - # query.view(batch_size, seq_len, self.num_heads, - # self.head_size), - # key_cache, - # value_cache, - # self.scale, - # input_metadata.block_tables, - # input_metadata.context_lens + seq_len, - # self.alibi_slopes, - # ) - # from flash_attn import flash_attn_func - # breakpoint() - # output3 = flash_attn_func( - # q=query.view(batch_size, seq_len, self.num_heads, - # self.head_size), - # k=key.view(batch_size, seq_len, self.num_kv_heads, self.head_size), - # v=value.view(batch_size, seq_len, self.num_kv_heads, self.head_size), - # softmax_scale=self.scale, - # causal=True, - # alibi_slopes=self.alibi_slopes, - # ) + # print("SANG-TODO flash attn is used.") + # print( + # "SANG-TODO query size: ", + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size).size()) if self.num_kv_heads != self.num_heads: # As of Nov 2023, xformers only supports MHA. 
For MQA/GQA, # project the key and value tensors to the desired number of diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index db4165141b2a..4251a5a6b61f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -138,8 +138,8 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] - print("SANG-TODO # of requests (seq_group_metadata_list): ", - len(seq_group_metadata_list)) + # print("SANG-TODO # of requests (seq_group_metadata_list): ", + # len(seq_group_metadata_list)) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -152,7 +152,7 @@ def _prepare_prompt( prompt_lens.append(prompt_len) prefix_len = 0 prefix = seq_group_metadata.prefix - print("SANG-TODO prefix, ", prefix) + # print("SANG-TODO prefix, ", prefix) if prefix is not None and prefix.computed: prefix_len = prefix.get_length() prompt_tokens = prompt_tokens[prefix_len:] @@ -500,12 +500,12 @@ def prepare_input_tensors( # SANG-TODO set num prompt tokens and generations? # Prepare input tensors. if is_prompt: - print("SANG-TODO execute model prompt.") + # print("SANG-TODO execute model prompt.") (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_prompt(seq_group_metadata_list) else: - print("SANG-TODO execute model decode.") + # print("SANG-TODO execute model decode.") (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_decode(seq_group_metadata_list) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index a2c5f8e0b1e0..d8a1c67bb123 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -117,7 +117,7 @@ def profile_num_available_blocks( gpu_memory_utilization: The fraction of the total GPU memory to use. cpu_swap_space: The size of the CPU swap space in bytes. """ - print("SANG-TODO profile_num_available_blocks") + # print("SANG-TODO profile_num_available_blocks") # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() @@ -154,7 +154,7 @@ def profile_num_available_blocks( MAX_INT_32 // cache_block_size) num_cpu_blocks = min(num_cpu_blocks, MAX_INT_32 // cache_block_size) - print("SANG-TODO profile_num_available_blocks done") + # print("SANG-TODO profile_num_available_blocks done") return num_gpu_blocks, num_cpu_blocks @@ -207,7 +207,7 @@ def execute_model( blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, ) -> Optional[SamplerOutput]: - print("SANG-TODO execute model.") + # print("SANG-TODO execute model.") if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups = len(seq_group_metadata_list) From 08c8541ca04622c2e00811ef4f86caa993342631 Mon Sep 17 00:00:00 2001 From: sang Date: Sun, 3 Mar 2024 08:37:25 -0800 Subject: [PATCH 14/28] . 
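Among other cleanups, this patch trims the flash_attn_with_kvcache_paged helper down to the arguments it actually passes to flash-attn. A minimal sketch of that decode-time call, assuming flash-attn >= 2.5 where flash_attn_with_kvcache accepts a block_table for a paged KV cache (this mirrors the helper updated below rather than reproducing it):

import torch
from flash_attn import flash_attn_with_kvcache

def paged_decode_attention(query, key_cache, value_cache, block_tables,
                           context_lens, scale, alibi_slopes=None):
    # query:           [batch, 1, num_heads, head_size], one new token per sequence
    # key/value cache: [num_blocks, block_size, num_kv_heads, head_size]
    # block_tables:    [batch, max_blocks_per_seq] int32 physical block ids
    # context_lens:    [batch] int32 count of valid cached tokens per sequence
    return flash_attn_with_kvcache(
        query,
        key_cache,
        value_cache,
        cache_seqlens=context_lens,
        block_table=block_tables,
        softmax_scale=scale,
        causal=True,
        alibi_slopes=alibi_slopes,
    )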
--- tests/chunked_prefill/test_correctness.py | 6 ++--- vllm/core/block_manager.py | 2 ++ vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/attention.py | 28 ++++++++++++++++------- vllm/worker/worker.py | 4 ++-- 5 files changed, 28 insertions(+), 14 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index 561ef24713dd..17ce58149fab 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -53,9 +53,9 @@ def test_models( torch.cuda.empty_cache() flash_attn_model = vllm_runner(model, - dtype=dtype, - flash_style=True, - block_size=block_size) + dtype=dtype,) + # flash_style=True, + # block_size=block_size) flash_attn_output_by_batches = [] for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 08d519ab767a..b5ae90e217ba 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -200,6 +200,8 @@ def allocate(self, seq_group: SequenceGroup) -> None: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) + print("allcate a new block for seq_group", seq_group) + print("block", block) block_table.append(block) # Assign the block table for each sequence. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 93adc98adda5..e1932bb761ef 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -146,7 +146,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ - # print("SANG-TODO generate: ", prompts, prompt_token_ids) + print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 17a0bedf9e0c..f80a47565e95 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -153,6 +153,7 @@ def forward( # normal attention if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): + # or not input_metadata.prefix_enabled): # print("SANG-TODO flash attn is used.") # print( # "SANG-TODO query size: ", @@ -212,7 +213,6 @@ def forward( query = query.unflatten(0, (batch_size, seq_len)) key = key.unflatten(0, (batch_size, seq_len)) value = value.unflatten(0, (batch_size, seq_len)) - out = xops.memory_efficient_attention_forward( query, key, @@ -224,6 +224,19 @@ def forward( (is_hip()) else None, ) output = out.view_as(query) + # if key_cache is not None and value_cache is not None: + # breakpoint() + # output2 = flash_attn_with_kvcache_paged( + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size), + # key_cache, + # value_cache, + # self.scale, + # input_metadata.block_tables, + # input_metadata.context_lens, + # self.alibi_slopes, + # ) + # breakpoint() else: if input_metadata.flash_style: output = flash_attn_with_kvcache_paged( @@ -237,6 +250,8 @@ def forward( self.alibi_slopes, ) else: + print("SANG-TODO context attention") + breakpoint() # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -257,6 +272,8 @@ def forward( else: # Decoding run. 
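# Roughly, this forward() now splits into three cases: (1) a prompt with no
# usable cached context, handled by the plain prefill attention above;
# (2) a prompt whose prefix is already cached, handled by
# flash_attn_with_kvcache_paged or context_attention_fwd over the cached
# blocks; and (3) the decode run below, where each sequence's single new
# token attends to the paged KV cache through its block table.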
+ # breakpoint() + print("SANG-TODO decoding") if input_metadata.flash_style: output = flash_attn_with_kvcache_paged( query.view(batch_size, seq_len, self.num_heads, @@ -419,17 +436,12 @@ def flash_attn_with_kvcache_paged( # Inplace update is slow. We don't use it. # We assume kvcache is already updated before # calling this API. - None, - None, - rotary_cos=None, - rotary_sin=None, + None, # key + None, # value cache_seqlens=context_lens, - cache_batch_idx=None, block_table=block_tables, softmax_scale=scale, causal=True, - window_size=(-1, -1), - rotary_interleaved=False, alibi_slopes=alibi_slopes, num_splits=0, ) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d8a1c67bb123..02975bcd5326 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -117,7 +117,7 @@ def profile_num_available_blocks( gpu_memory_utilization: The fraction of the total GPU memory to use. cpu_swap_space: The size of the CPU swap space in bytes. """ - # print("SANG-TODO profile_num_available_blocks") + print("SANG-TODO profile_num_available_blocks") # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() @@ -154,7 +154,7 @@ def profile_num_available_blocks( MAX_INT_32 // cache_block_size) num_cpu_blocks = min(num_cpu_blocks, MAX_INT_32 // cache_block_size) - # print("SANG-TODO profile_num_available_blocks done") + print("SANG-TODO profile_num_available_blocks done") return num_gpu_blocks, num_cpu_blocks From 3bac9af4f4166b0c059a355c9ece5bfe56e9613d Mon Sep 17 00:00:00 2001 From: sang Date: Sun, 3 Mar 2024 12:02:03 -0800 Subject: [PATCH 15/28] ip --- tests/chunked_prefill/test_correctness.py | 9 +- tests/kernels/test_flash_attention.py | 252 +++++++++++++++++- vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/attention.py | 150 ++++++----- .../layers/triton_kernel/prefix_prefill.py | 2 +- vllm/worker/worker.py | 4 +- 6 files changed, 344 insertions(+), 75 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index 17ce58149fab..d6cb0a4dc0ac 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -53,10 +53,12 @@ def test_models( torch.cuda.empty_cache() flash_attn_model = vllm_runner(model, - dtype=dtype,) + dtype=dtype) # flash_style=True, # block_size=block_size) flash_attn_output_by_batches = [] + # flash_attn_output_by_batches.extend(flash_attn_model.generate_greedy(TEST_PROMPTS, max_tokens)) + flash_attn_output_by_batches = [] for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] flash_attn_output_by_batches.append( @@ -64,6 +66,11 @@ def test_models( del flash_attn_model + # for e, f in zip(expected_outputs, flash_attn_output_by_batches): + # # print("expected: ", e[1]) + # # print("flash: ", f[1]) + # assert e[1] == f[1] + destroy_model_parallel() gc.collect() torch.cuda.empty_cache() diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index 74c9d3008f17..2938dff7cac8 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -1,5 +1,5 @@ import random -from typing import Optional, Tuple +from typing import Optional, Tuple, List import pytest import torch @@ -238,3 +238,253 @@ def test_flash_paged_attention( ) assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +def ref_multi_query_kv_attention( + cu_seq_lens: List[int], + query: 
torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + ref_outputs = [] + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + + # Create attention mask. + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype, device="cuda") + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +def ref_multi_query_kv_attention_padded( + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + cu_seq_lens: List[int], + context_lens: List[int], + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + block_size = value_cache.shape[-3] + ref_outputs = [] + + for i in range(num_seqs): + q_start_idx = cu_seq_lens[i] + q_end_idx = cu_seq_lens[i + 1] + seq_len = q_end_idx - q_start_idx + + context_len = context_lens[i] + + block_table = block_tables[i] + keys = [] + values = [] + + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, block_offset, :, :] + keys.append(k) + + v = value_cache[block_number, block_offset, :, :] + values.append(v) + + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + q = query[q_start_idx:q_end_idx, :, :] + k = keys[:context_len, :, :] + v = values[:context_len, :, :] + + assert seq_len <= context_len + + # pad q if seq_len is less than context_len + # this is for correct calculation of attention. + if seq_len < context_len: + indices = [i % seq_len for i in range(context_len - seq_len)] + q_left_pad = q[indices, :, :] + q = torch.cat([q_left_pad, q], dim=0) + + # Create attention mask. 
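# A concrete illustration of the left-padding above (hypothetical sizes):
# with seq_len = 2 new query tokens and context_len = 5 cached tokens, q is
# left-padded to 5 rows so that the (context_len, context_len) causal mask
# built below lines up with the full context. Each query row is attended
# independently, so the padded rows exist only to square the mask; their
# outputs are discarded by the ref_output[-seq_len:, :, :] slice at the end.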
+ attn_mask = torch.triu(torch.ones(context_len, + context_len, + dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype, device="cuda") + ref_output = ref_masked_attention( + q, + k, + v, + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output[-seq_len:, :, :]) + breakpoint + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +def is_a100(): + return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 + + +if not is_a100(): + NUM_HEADS_SMALL = [(16, 16), (16, 8)] + MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) + + +@pytest.mark.parametrize("num_seqs", [17]) +@pytest.mark.parametrize("num_heads", [(40, 40)]) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("version", ["flash"]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("block_size", [256]) +@torch.inference_mode() +def test_multi_query_kv_attention( + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + version: str, + seed: int, + block_size: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. + # As the xformers library is already tested with its own tests, we can use + # a smaller MAX_SEQ_LEN here. + max_len = min(MAX_SEQ_LEN, 4096) + + seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] + max_seq_len = max(seq_lens) + context_lens = seq_lens + max_context_len = max(context_lens) + + num_tokens = sum(seq_lens) + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + + cu_context_lens = [0] + for context_len in context_lens: + cu_context_lens.append(cu_context_lens[-1] + context_len) + + print(f"cu_seq_lens={cu_seq_lens}, cu_context_lens={cu_context_lens}") + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + num_queries_per_kv = num_query_heads // num_kv_heads + + value_cache = torch.empty(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + key_cache = torch.empty(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + query = torch.empty(max_seq_len * num_seqs, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + value_cache.uniform_(-scale, scale) + key_cache.uniform_(-scale, scale) + query.uniform_(-scale, scale) + + # Create the block tables. 
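# Note on the randomly generated block tables below: the caches are prefilled
# with random values and never written in this test, so blocks may repeat or
# alias across sequences. Correctness only requires that the kernel and the
# reference gather keys/values through the same block_table entries.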
+ max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + output = torch.empty_like(query) + + if version == "flash": + # flash_multi_query_cached_kv_attention_varlen( + # output, + # query, + # key_cache, + # value_cache, + # scale, + # block_tables, + # torch.cuda.IntTensor(cu_seq_lens), + # torch.cuda.IntTensor(cu_context_lens), + # block_size, + # max_seq_len, + # max_context_len, + # None, + # ) + from flash_attn import flash_attn_func + breakpoint() + # output = flash_attn_func( + # query.unsqueeze(0), + # k=key, + # v=value, + # softmax_scale=scale, + # causal=True, + # alibi_slopes=alibi_slopes, + # ) + output = flash_attn_with_kvcache_paged( + query.view(num_seqs, max_seq_len, num_query_heads, head_size), + key_cache, + value_cache, + scale, + block_tables, + torch.tensor(context_lens, dtype=torch.int, device="cuda"), + None, + ) + else: + assert False, f"{version=} is not supported" + + ref_output = ref_multi_query_kv_attention_padded( + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + cu_seq_lens, + context_lens, + scale, + dtype, + ) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e1932bb761ef..93adc98adda5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -146,7 +146,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ - print("SANG-TODO generate: ", prompts, prompt_token_ids) + # print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index f80a47565e95..db5e5f44ad42 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -154,76 +154,90 @@ def forward( if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # or not input_metadata.prefix_enabled): - # print("SANG-TODO flash attn is used.") - # print( - # "SANG-TODO query size: ", - # query.view(batch_size, seq_len, self.num_heads, - # self.head_size).size()) - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - # Set attention bias if not provided. This typically happens at - # the very attention layer of every iteration. - # FIXME(woosuk): This is a hack. 
- if input_metadata.attn_bias is None: + # if key_cache is not None and value_cache is not None and input_metadata.flash_style: + if False: + print("SANG-TODO flash attention.") + output = flash_attn_with_kvcache_paged( + query.view(batch_size, seq_len, self.num_heads, + self.head_size), + key_cache, + value_cache, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + self.alibi_slopes, + ) + else: + # print("SANG-TODO flash attn is used.") + # print( + # "SANG-TODO query size: ", + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size).size()) + if self.num_kv_heads != self.num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # TODO(woosuk): Use MQA/GQA kernels for higher performance. + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], + self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * batch_size) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = attn_bias + else: + input_metadata.attn_bias = _make_alibi_bias( + self.alibi_slopes, self.num_kv_heads, batch_size, + seq_len, query.dtype) + + if self.use_ref_attention: + output = self.ref_masked_attention( + query, + key, + value, + ) + # Using view got RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use reshape instead + return output.reshape(batch_size, seq_len, hidden_size) + + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. if self.alibi_slopes is None: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size) - if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention( - self.sliding_window) - input_metadata.attn_bias = attn_bias + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) else: - input_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, batch_size, - seq_len, query.dtype) - - if self.use_ref_attention: - output = self.ref_masked_attention( + query = query.unflatten(0, (batch_size, seq_len)) + key = key.unflatten(0, (batch_size, seq_len)) + value = value.unflatten(0, (batch_size, seq_len)) + out = xops.memory_efficient_attention_forward( query, key, value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, ) - # Using view got RuntimeError: view size is not compatible with input tensor's size and stride - # (at least one dimension spans across two contiguous subspaces). Use reshape instead - return output.reshape(batch_size, seq_len, hidden_size) - - # TODO(woosuk): Too many view operations. Let's try to reduce - # them in the future for code readability. 
- if self.alibi_slopes is None: - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - else: - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) - out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ) - output = out.view_as(query) + output = out.view_as(query) # if key_cache is not None and value_cache is not None: # breakpoint() # output2 = flash_attn_with_kvcache_paged( @@ -239,6 +253,7 @@ def forward( # breakpoint() else: if input_metadata.flash_style: + assert False output = flash_attn_with_kvcache_paged( query.view(batch_size, seq_len, self.num_heads, self.head_size), @@ -250,8 +265,7 @@ def forward( self.alibi_slopes, ) else: - print("SANG-TODO context attention") - breakpoint() + # print("SANG-TODO context attention") # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -261,8 +275,7 @@ def forward( output, key_cache, value_cache, - input_metadata. - block_tables, # [BS, max_block_per_request] + input_metadata.block_tables, # [BS, max_block_per_request] input_metadata.start_loc, input_metadata.prompt_lens, input_metadata.context_lens, @@ -273,7 +286,6 @@ def forward( else: # Decoding run. # breakpoint() - print("SANG-TODO decoding") if input_metadata.flash_style: output = flash_attn_with_kvcache_paged( query.view(batch_size, seq_len, self.num_heads, diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index 70f09224f1cf..c6054de2b718 100644 --- a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -696,7 +696,7 @@ def context_attention_fwd(q, ) return - _fwd_kernel[grid]( + _fwd_kernel_flash_attn_v2[grid]( q, k, v, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 02975bcd5326..d8a1c67bb123 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -117,7 +117,7 @@ def profile_num_available_blocks( gpu_memory_utilization: The fraction of the total GPU memory to use. cpu_swap_space: The size of the CPU swap space in bytes. """ - print("SANG-TODO profile_num_available_blocks") + # print("SANG-TODO profile_num_available_blocks") # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. 
torch.cuda.empty_cache() @@ -154,7 +154,7 @@ def profile_num_available_blocks( MAX_INT_32 // cache_block_size) num_cpu_blocks = min(num_cpu_blocks, MAX_INT_32 // cache_block_size) - print("SANG-TODO profile_num_available_blocks done") + # print("SANG-TODO profile_num_available_blocks done") return num_gpu_blocks, num_cpu_blocks From 31aa92052d43ce31aa93faa545db01a639c206c4 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 4 Mar 2024 00:41:48 -0800 Subject: [PATCH 16/28] ip --- tests/chunked_prefill/test_correctness.py | 49 +- tests/kernels/test_flash_attention.py | 779 ++++++++++++++++++++-- tests/models/test_models.py | 28 +- vllm/config.py | 4 +- vllm/core/block_manager.py | 4 +- vllm/engine/llm_engine.py | 4 +- vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/attention.py | 447 +++++++++---- vllm/worker/model_runner.py | 1 + 9 files changed, 1085 insertions(+), 233 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index d6cb0a4dc0ac..efd1d5610f2c 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -28,49 +28,51 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("block_size", [256]) +@pytest.mark.parametrize("block_size", [32]) +@pytest.mark.parametrize("max_num_prompt_seqs", [1]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) def test_models( vllm_runner, model: str, dtype: str, max_tokens: int, block_size: int, + max_num_prompt_seqs: int, + tensor_parallel_size: int, ) -> None: """ verify the flash attention has the same output as page attention """ - print("loading page attention models..") - pg_model = vllm_runner(model, dtype=dtype) - expected_outputs = [] + if torch.cuda.device_count() < tensor_parallel_size: + pytest.skip( + f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" + ) + # print("loading page attention models..") + # pg_model = vllm_runner(model, dtype=dtype) + # expected_outputs = [] - print("generating tokens...") - expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) - print("generating tokens finished") + # print("generating tokens...") + # expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) + # print("generating tokens finished") - del pg_model + # del pg_model - destroy_model_parallel() - gc.collect() - torch.cuda.empty_cache() + # destroy_model_parallel() + # gc.collect() + # torch.cuda.empty_cache() - flash_attn_model = vllm_runner(model, - dtype=dtype) - # flash_style=True, - # block_size=block_size) - flash_attn_output_by_batches = [] - # flash_attn_output_by_batches.extend(flash_attn_model.generate_greedy(TEST_PROMPTS, max_tokens)) flash_attn_output_by_batches = [] + flash_attn_model = vllm_runner( + model, + dtype=dtype, + block_size=block_size, + flash_style=True, + tensor_parallel_size=tensor_parallel_size) for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] flash_attn_output_by_batches.append( flash_attn_model.generate_greedy(prompts, max_tokens)) del flash_attn_model - - # for e, f in zip(expected_outputs, flash_attn_output_by_batches): - # # print("expected: ", e[1]) - # # print("flash: ", f[1]) - # assert e[1] == f[1] - destroy_model_parallel() gc.collect() torch.cuda.empty_cache() @@ -80,6 +82,7 @@ def test_models( fa_output_ids, fa_output_str = flash_attn_outputs[i] vllm_output_ids, vllm_output_str = 
expected_outputs[ i % len(expected_outputs)] + print(vllm_output_str) assert fa_output_ids == vllm_output_ids, ( f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index 2938dff7cac8..52f9ad3ccc9f 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -1,12 +1,505 @@ +# import random +# from typing import Optional, Tuple, List + +# import pytest +# import torch +# import torch.nn.functional as F + +# from vllm.model_executor.layers.attention import ( +# flash_attn_with_kvcache_paged, ) +# from vllm.utils import get_max_shared_memory_bytes + +# FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# # This will change depending on the compute capability. +# # - 512 as a buffer +# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# NUM_BLOCKS = 128 # Arbitrary values for testing +# PARTITION_SIZE = 512 + +# DTYPES = [torch.half, torch.bfloat16] +# NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing +# NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing +# NUM_HEADS = [(1, 40), (40, 40), (64, 8)] # Arbitrary values for testing +# NUM_HEADS_SMALL = NUM_HEADS +# # head size should be bigger than or equal to block size. +# HEAD_SIZES = [256] +# # TODO(sang): https://github.com/Dao-AILab/flash-attention/pull/824 +# # should fix the block size. But right now, the block size should be +# # divisible by 256. +# BLOCK_SIZES = [256] +# USE_ALIBI = [False, True] +# SEEDS = [0] +# PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] + + +# def pad_attention_inputs( +# pad_config: Tuple[int, int], +# block_size: int, +# query: torch.Tensor, +# block_tables: torch.Tensor, +# context_lens: torch.Tensor, +# max_context_len: int, +# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: +# """Pad the attention inputs to the specified batch size and context length. 
+# """ +# pad_batch_size, pad_max_context_len = pad_config +# if pad_batch_size == 0: +# return query, block_tables, context_lens, max_context_len +# target_batch_size = ( +# (query.shape[0] - 1) % pad_batch_size + 1) * pad_batch_size +# target_block_size = pad_max_context_len // block_size + 1 +# padded_query = F.pad(query, +# (0, 0, 0, 0, 0, target_batch_size - query.shape[0])) +# padded_block_table = F.pad(block_tables, +# (0, target_block_size - block_tables.shape[1], +# 0, target_batch_size - block_tables.shape[0])) +# padded_context_lens = F.pad(context_lens, +# (0, target_batch_size - context_lens.shape[0])) +# return padded_query, padded_block_table, padded_context_lens, pad_max_context_len + + +# def ref_masked_attention( +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# scale: float, +# attn_mask: Optional[torch.Tensor] = None, +# ) -> torch.Tensor: +# attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() +# if attn_mask is not None: +# attn_weights = attn_weights + attn_mask.float() +# attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) +# out = torch.einsum("hqk,khd->qhd", attn_weights, value) +# return out + + +# def ref_single_query_cached_kv_attention( +# output: torch.Tensor, +# query: torch.Tensor, +# num_queries_per_kv: int, +# key_cache: torch.Tensor, +# value_cache: torch.Tensor, +# block_tables: torch.Tensor, +# context_lens: torch.Tensor, +# scale: float, +# alibi_slopes: Optional[torch.Tensor], +# ) -> None: +# num_query_heads = query.shape[1] +# head_size = value_cache.shape[-1] +# block_size = value_cache.shape[-3] +# num_seqs = query.shape[0] + +# block_tables = block_tables.cpu().tolist() +# context_lens = context_lens.cpu().tolist() +# for i in range(num_seqs): +# q = query[i].unsqueeze(0) +# block_table = block_tables[i] +# context_len = int(context_lens[i]) + +# keys = [] +# values = [] +# for j in range(context_len): +# block_number = int(block_table[j // block_size]) +# block_offset = j % block_size + +# k = key_cache[block_number, block_offset, :, :] +# keys.append(k) + +# v = value_cache[block_number, :, :, block_offset] +# v = value_cache[block_number, block_offset, :, :] +# values.append(v) +# keys = torch.stack(keys, dim=0) +# values = torch.stack(values, dim=0) +# if num_queries_per_kv > 1: +# # Handle MQA and GQA +# keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) +# values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + +# alibi_bias = None +# if alibi_slopes is not None: +# # Create the ALiBi bias used in the paged attention kernel. 
+# position_ids = torch.arange(context_len, device="cuda").int() +# alibi_bias = (position_ids - context_len + 1).float() +# alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( +# 1, 1, -1) + +# out = ref_masked_attention(q, keys, values, scale, alibi_bias) +# out = out.view(num_query_heads, head_size) +# # output[i].copy_(out, non_blocking=True) +# output[i].copy_(out) + + +# # @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +# # @pytest.mark.parametrize("num_heads", NUM_HEADS) +# # @pytest.mark.parametrize("head_size", HEAD_SIZES) +# # @pytest.mark.parametrize("use_alibi", [False, True]) +# # @pytest.mark.parametrize("block_size", BLOCK_SIZES) +# # @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +# # @pytest.mark.parametrize("seed", SEEDS) +# # @pytest.mark.parametrize("pad_config", PAD_CONFIGS) +# @pytest.mark.parametrize("num_seqs", [3]) +# @pytest.mark.parametrize("num_heads", [(40, 40)]) +# @pytest.mark.parametrize("head_size", [256]) +# @pytest.mark.parametrize("use_alibi", [True]) +# @pytest.mark.parametrize("block_size", [256]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("seed", SEEDS) +# @pytest.mark.parametrize("pad_config", [(0, 0)]) +# @torch.inference_mode() +# def test_flash_paged_attention( +# kv_cache_factory, +# num_seqs: int, +# num_heads: Tuple[int, int], +# head_size: int, +# use_alibi: bool, +# block_size: int, +# dtype: torch.dtype, +# seed: int, +# pad_config: Tuple[int, int], +# ) -> None: +# random.seed(seed) +# torch.random.manual_seed(seed) +# torch.cuda.manual_seed(seed) + +# scale = float(1.0 / (head_size**0.5)) +# num_query_heads, num_kv_heads = num_heads +# query = torch.empty(num_seqs, +# num_query_heads, +# head_size, +# dtype=dtype, +# device="cuda") +# query.uniform_(-scale, scale) + +# # assert num_query_heads % num_kv_heads == 0 +# num_queries_per_kv = num_query_heads // num_kv_heads +# alibi_slopes = None +# if use_alibi: +# alibi_slopes = torch.randn(num_query_heads, +# dtype=torch.float, +# device="cuda") + +# max_seq_len = MAX_SEQ_LEN if not pad_config[0] else (pad_config[1] - 1000) +# context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] +# context_lens[-1] = max_seq_len +# max_context_len = max(context_lens) +# context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + +# # Create the block tables. +# max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size +# block_tables = [] +# for _ in range(num_seqs): +# block_table = [ +# random.randint(0, NUM_BLOCKS - 1) +# for _ in range(max_num_blocks_per_seq) +# ] +# block_tables.append(block_table) +# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + +# # Create the KV caches. +# key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, +# block_size, +# 1, +# num_kv_heads, +# head_size, +# dtype, +# seed, +# flash_style=True) +# key_cache, value_cache = key_caches[0], value_caches[0] + +# # Call the paged attention kernel. +# output = torch.empty_like(query) + +# padded_query, padded_block_table, padded_context_lens, _ = \ +# pad_attention_inputs(pad_config, block_size, query, +# block_tables, context_lens, max_context_len) + +# output = flash_attn_with_kvcache_paged( +# padded_query.view(num_seqs, 1, num_query_heads, head_size), +# key_cache, +# value_cache, +# scale, +# padded_block_table, +# padded_context_lens, +# alibi_slopes, +# ) + +# # Run the reference implementation. 
+# ref_output = torch.empty_like(query) +# ref_single_query_cached_kv_attention( +# ref_output, +# query, +# num_queries_per_kv, +# key_cache, +# value_cache, +# block_tables, +# context_lens, +# scale, +# alibi_slopes, +# ) + +# assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +# def ref_multi_query_kv_attention( +# cu_seq_lens: List[int], +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# scale: float, +# dtype: torch.dtype, +# ) -> torch.Tensor: +# num_seqs = len(cu_seq_lens) - 1 +# ref_outputs = [] +# for i in range(num_seqs): +# start_idx = cu_seq_lens[i] +# end_idx = cu_seq_lens[i + 1] +# seq_len = end_idx - start_idx + +# # Create attention mask. +# attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), +# diagonal=1) +# attn_mask = attn_mask * torch.finfo(dtype).min +# attn_mask = attn_mask.to(dtype=dtype, device="cuda") + +# ref_output = ref_masked_attention( +# query[start_idx:end_idx], +# key[start_idx:end_idx], +# value[start_idx:end_idx], +# scale, +# attn_mask=attn_mask, +# ) +# ref_outputs.append(ref_output) +# ref_output = torch.cat(ref_outputs, dim=0) +# return ref_output + + +# def ref_multi_query_kv_attention_padded( +# query: torch.Tensor, +# num_queries_per_kv: int, +# key_cache: torch.Tensor, +# value_cache: torch.Tensor, +# block_tables: torch.Tensor, +# cu_seq_lens: List[int], +# context_lens: List[int], +# scale: float, +# dtype: torch.dtype, +# ) -> torch.Tensor: +# num_seqs = len(cu_seq_lens) - 1 +# block_size = value_cache.shape[-3] +# ref_outputs = [] + +# for i in range(num_seqs): +# q_start_idx = cu_seq_lens[i] +# q_end_idx = cu_seq_lens[i + 1] +# seq_len = q_end_idx - q_start_idx + +# context_len = context_lens[i] + +# block_table = block_tables[i] +# keys = [] +# values = [] + +# for j in range(context_len): +# block_number = int(block_table[j // block_size]) +# block_offset = j % block_size + +# k = key_cache[block_number, block_offset, :, :] +# keys.append(k) + +# v = value_cache[block_number, block_offset, :, :] +# values.append(v) + +# keys = torch.stack(keys, dim=0) +# values = torch.stack(values, dim=0) + +# if num_queries_per_kv > 1: +# # Handle MQA and GQA +# keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) +# values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + +# q = query[q_start_idx:q_end_idx, :, :] +# k = keys[:context_len, :, :] +# v = values[:context_len, :, :] + +# assert seq_len <= context_len + +# # pad q if seq_len is less than context_len +# # this is for correct calculation of attention. +# if seq_len < context_len: +# indices = [i % seq_len for i in range(context_len - seq_len)] +# q_left_pad = q[indices, :, :] +# q = torch.cat([q_left_pad, q], dim=0) + +# # Create attention mask. 
+# attn_mask = torch.triu(torch.ones(context_len, +# context_len, +# dtype=dtype), +# diagonal=1) +# attn_mask = attn_mask * torch.finfo(dtype).min +# attn_mask = attn_mask.to(dtype=dtype, device="cuda") +# ref_output = ref_masked_attention( +# q, +# k, +# v, +# scale, +# attn_mask=attn_mask, +# ) +# ref_outputs.append(ref_output[-seq_len:, :, :]) +# breakpoint +# ref_output = torch.cat(ref_outputs, dim=0) +# return ref_output + + +# def is_a100(): +# return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 + + +# if not is_a100(): +# NUM_HEADS_SMALL = [(16, 16), (16, 8)] +# MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) + + +# @pytest.mark.parametrize("num_seqs", [17]) +# @pytest.mark.parametrize("num_heads", [(40, 40)]) +# @pytest.mark.parametrize("head_size", HEAD_SIZES) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("version", ["flash"]) +# @pytest.mark.parametrize("seed", SEEDS) +# @pytest.mark.parametrize("block_size", [256]) +# @torch.inference_mode() +# def test_multi_query_kv_attention( +# num_seqs: int, +# num_heads: Tuple[int, int], +# head_size: int, +# dtype: torch.dtype, +# version: str, +# seed: int, +# block_size: int, +# ) -> None: +# random.seed(seed) +# torch.random.manual_seed(seed) +# torch.cuda.manual_seed(seed) + +# # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. +# # As the xformers library is already tested with its own tests, we can use +# # a smaller MAX_SEQ_LEN here. +# max_len = min(MAX_SEQ_LEN, 4096) + +# seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] +# max_seq_len = max(seq_lens) +# context_lens = seq_lens +# max_context_len = max(context_lens) + +# num_tokens = sum(seq_lens) +# cu_seq_lens = [0] +# for seq_len in seq_lens: +# cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + +# cu_context_lens = [0] +# for context_len in context_lens: +# cu_context_lens.append(cu_context_lens[-1] + context_len) + +# print(f"cu_seq_lens={cu_seq_lens}, cu_context_lens={cu_context_lens}") + +# scale = float(1.0 / (head_size**0.5)) +# num_query_heads, num_kv_heads = num_heads +# num_queries_per_kv = num_query_heads // num_kv_heads + +# value_cache = torch.empty(NUM_BLOCKS, +# block_size, +# num_kv_heads, +# head_size, +# dtype=dtype, +# device="cuda") +# key_cache = torch.empty(NUM_BLOCKS, +# block_size, +# num_kv_heads, +# head_size, +# dtype=dtype, +# device="cuda") +# query = torch.empty(max_seq_len * num_seqs, +# num_query_heads, +# head_size, +# dtype=dtype, +# device="cuda") +# value_cache.uniform_(-scale, scale) +# key_cache.uniform_(-scale, scale) +# query.uniform_(-scale, scale) + +# # Create the block tables. 
+# max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size +# block_tables = [] +# for _ in range(num_seqs): +# block_table = [ +# random.randint(0, NUM_BLOCKS - 1) +# for _ in range(max_num_blocks_per_seq) +# ] +# block_tables.append(block_table) +# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + +# output = torch.empty_like(query) + +# if version == "flash": +# # flash_multi_query_cached_kv_attention_varlen( +# # output, +# # query, +# # key_cache, +# # value_cache, +# # scale, +# # block_tables, +# # torch.cuda.IntTensor(cu_seq_lens), +# # torch.cuda.IntTensor(cu_context_lens), +# # block_size, +# # max_seq_len, +# # max_context_len, +# # None, +# # ) +# from flash_attn import flash_attn_func +# breakpoint() +# # output = flash_attn_func( +# # query.unsqueeze(0), +# # k=key, +# # v=value, +# # softmax_scale=scale, +# # causal=True, +# # alibi_slopes=alibi_slopes, +# # ) +# output = flash_attn_with_kvcache_paged( +# query.view(num_seqs, max_seq_len, num_query_heads, head_size), +# key_cache, +# value_cache, +# scale, +# block_tables, +# torch.tensor(context_lens, dtype=torch.int, device="cuda"), +# None, +# ) +# else: +# assert False, f"{version=} is not supported" + +# ref_output = ref_multi_query_kv_attention_padded( +# query, +# num_queries_per_kv, +# key_cache, +# value_cache, +# block_tables, +# cu_seq_lens, +# context_lens, +# scale, +# dtype, +# ) +# assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + import random -from typing import Optional, Tuple, List +from typing import List, Optional, Tuple import pytest import torch import torch.nn.functional as F from vllm.model_executor.layers.attention import ( - flash_attn_with_kvcache_paged, ) + flash_single_query_cached_kv_attention, + flash_multi_query_cached_kv_attention_varlen, +) from vllm.utils import get_max_shared_memory_bytes FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 @@ -19,14 +512,10 @@ DTYPES = [torch.half, torch.bfloat16] NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing -NUM_HEADS = [(1, 40), (40, 40), (64, 8)] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS_SMALL = NUM_HEADS -# head size should be bigger than or equal to block size. -HEAD_SIZES = [256] -# TODO(sang): https://github.com/Dao-AILab/flash-attention/pull/824 -# should fix the block size. But right now, the block size should be -# divisible by 256. 
-BLOCK_SIZES = [256] +HEAD_SIZES = [32, 64, 128, 256] +BLOCK_SIZES = [32, 64, 256] USE_ALIBI = [False, True] SEEDS = [0] PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] @@ -83,8 +572,10 @@ def ref_single_query_cached_kv_attention( context_lens: torch.Tensor, scale: float, alibi_slopes: Optional[torch.Tensor], + flash_style: bool = False, ) -> None: num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[-2] head_size = value_cache.shape[-1] block_size = value_cache.shape[-3] num_seqs = query.shape[0] @@ -102,11 +593,17 @@ def ref_single_query_cached_kv_attention( block_number = int(block_table[j // block_size]) block_offset = j % block_size - k = key_cache[block_number, block_offset, :, :] + if flash_style: + k = key_cache[block_number, block_offset, :, :] + else: + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) keys.append(k) - v = value_cache[block_number, :, :, block_offset] - v = value_cache[block_number, block_offset, :, :] + if flash_style: + v = value_cache[block_number, block_offset, :, :] + else: + v = value_cache[block_number, :, :, block_offset] values.append(v) keys = torch.stack(keys, dim=0) values = torch.stack(values, dim=0) @@ -125,24 +622,15 @@ def ref_single_query_cached_kv_attention( out = ref_masked_attention(q, keys, values, scale, alibi_bias) out = out.view(num_query_heads, head_size) - # output[i].copy_(out, non_blocking=True) - output[i].copy_(out) + output[i].copy_(out, non_blocking=True) -# @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -# @pytest.mark.parametrize("num_heads", NUM_HEADS) -# @pytest.mark.parametrize("head_size", HEAD_SIZES) -# @pytest.mark.parametrize("use_alibi", [False, True]) -# @pytest.mark.parametrize("block_size", BLOCK_SIZES) -# @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) -# @pytest.mark.parametrize("seed", SEEDS) -# @pytest.mark.parametrize("pad_config", PAD_CONFIGS) -@pytest.mark.parametrize("num_seqs", [3]) -@pytest.mark.parametrize("num_heads", [(40, 40)]) -@pytest.mark.parametrize("head_size", [256]) -@pytest.mark.parametrize("use_alibi", [True]) -@pytest.mark.parametrize("block_size", [256]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", [False]) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("pad_config", [(0, 0)]) @torch.inference_mode() @@ -170,8 +658,11 @@ def test_flash_paged_attention( device="cuda") query.uniform_(-scale, scale) - # assert num_query_heads % num_kv_heads == 0 + assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), + num_queries_per_kv) alibi_slopes = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, @@ -207,14 +698,16 @@ def test_flash_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. 
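The `flash_style` branch added to the reference implementation above boils down to a different indexing scheme for the same per-token key. A minimal sketch of the two layouts (CPU tensors; the packing factor `x` of the legacy key cache is an assumed value here, chosen only for illustration):

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64
x = 8  # assumed packing factor for the legacy layout

flash_key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)
legacy_key_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x)

slot = 37                          # global slot index of one cached token
block_number, block_offset = divmod(slot, block_size)

# Flash style: the token's key is stored contiguously.
k_flash = flash_key_cache[block_number, block_offset, :, :]       # [num_kv_heads, head_size]

# Legacy style: gather across the packed dimension, then restore head_size.
k_legacy = legacy_key_cache[block_number, :, :, block_offset, :]  # [num_kv_heads, head_size//x, x]
k_legacy = k_legacy.reshape(num_kv_heads, head_size)

assert k_flash.shape == k_legacy.shape == (num_kv_heads, head_size)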
+ num_valid_tokens = torch.cuda.IntTensor([num_seqs]) output = torch.empty_like(query) - padded_query, padded_block_table, padded_context_lens, _ = \ + padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ pad_attention_inputs(pad_config, block_size, query, block_tables, context_lens, max_context_len) - output = flash_attn_with_kvcache_paged( - padded_query.view(num_seqs, 1, num_query_heads, head_size), + flash_single_query_cached_kv_attention( + output, + padded_query, key_cache, value_cache, scale, @@ -235,8 +728,12 @@ def test_flash_paged_attention( context_lens, scale, alibi_slopes, + flash_style=True, ) + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) @@ -337,6 +834,7 @@ def ref_multi_query_kv_attention_padded( diagonal=1) attn_mask = attn_mask * torch.finfo(dtype).min attn_mask = attn_mask.to(dtype=dtype, device="cuda") + ref_output = ref_masked_attention( q, k, @@ -345,7 +843,6 @@ def ref_multi_query_kv_attention_padded( attn_mask=attn_mask, ) ref_outputs.append(ref_output[-seq_len:, :, :]) - breakpoint ref_output = torch.cat(ref_outputs, dim=0) return ref_output @@ -359,13 +856,14 @@ def is_a100(): MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) -@pytest.mark.parametrize("num_seqs", [17]) -@pytest.mark.parametrize("num_heads", [(40, 40)]) +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS_SMALL) @pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("version", ["flash"]) +@pytest.mark.parametrize("chunked_prefill", [False, True]) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("block_size", [256]) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, @@ -374,6 +872,7 @@ def test_multi_query_kv_attention( dtype: torch.dtype, version: str, seed: int, + chunked_prefill: bool, block_size: int, ) -> None: random.seed(seed) @@ -387,8 +886,18 @@ def test_multi_query_kv_attention( seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] max_seq_len = max(seq_lens) - context_lens = seq_lens + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device="cuda") + + if chunked_prefill: + # context length will be different from seq_len if chunked_prefill is + # true. 
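+        # In that case every context is at least as long as its query chunk:
+        # the leading (context_len - seq_len) tokens are treated as already
+        # cached from earlier chunks, and only the trailing seq_len tokens
+        # are new queries.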
+ context_lens = random.sample(range(max_seq_len, max_len), num_seqs) + else: + context_lens = seq_lens max_context_len = max(context_lens) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device="cuda") num_tokens = sum(seq_lens) cu_seq_lens = [0] @@ -417,7 +926,7 @@ def test_multi_query_kv_attention( head_size, dtype=dtype, device="cuda") - query = torch.empty(max_seq_len * num_seqs, + query = torch.empty(num_tokens, num_query_heads, head_size, dtype=dtype, @@ -440,37 +949,18 @@ def test_multi_query_kv_attention( output = torch.empty_like(query) if version == "flash": - # flash_multi_query_cached_kv_attention_varlen( - # output, - # query, - # key_cache, - # value_cache, - # scale, - # block_tables, - # torch.cuda.IntTensor(cu_seq_lens), - # torch.cuda.IntTensor(cu_context_lens), - # block_size, - # max_seq_len, - # max_context_len, - # None, - # ) - from flash_attn import flash_attn_func - breakpoint() - # output = flash_attn_func( - # query.unsqueeze(0), - # k=key, - # v=value, - # softmax_scale=scale, - # causal=True, - # alibi_slopes=alibi_slopes, - # ) - output = flash_attn_with_kvcache_paged( - query.view(num_seqs, max_seq_len, num_query_heads, head_size), + flash_multi_query_cached_kv_attention_varlen( + output, + query, key_cache, value_cache, scale, block_tables, - torch.tensor(context_lens, dtype=torch.int, device="cuda"), + torch.cuda.IntTensor(cu_seq_lens), + torch.cuda.IntTensor(cu_context_lens), + block_size, + max_seq_len, + max_context_len, None, ) else: @@ -488,3 +978,160 @@ def test_multi_query_kv_attention( dtype, ) assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +@pytest.mark.parametrize("num_heads", [40]) +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("block_size", [32]) +@pytest.mark.parametrize("num_blocks", [128]) +@pytest.mark.parametrize("dtype", [torch.half]) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_e2e( + kv_cache_factory, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, +) -> None: + from vllm._C import cache_ops + batch_size = 2 + seqlen = 29 + num_tokens = batch_size * seqlen + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + block_tables = [] + for _ in range(batch_size): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + # Create a random slot mapping. + slot_mapping = [] + for i in range(0, 29): + block_number = int(block_tables[i // block_size]) + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + # for _ in range(23, 29): + # slot_mapping.append(-1) + for i in range(0, 29): + block_number = int(block_tables[1]) + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + # slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device='cuda') + query, key, value = qkv.unbind(dim=1) + # query[23:29] = 0 + # key[23:29] = 0 + # value[23:29] = 0 + # Create the KV caches. 
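For reference, the slot-mapping arithmetic that the test above relies on (and that the reference check below inverts) can be written as a small, self-contained sketch with assumed sizes:

import torch

block_size = 32
block_tables = torch.tensor([[3], [7]], dtype=torch.int)  # one block per sequence
tokens_per_seq = 29

slot_mapping = []
for seq_idx in range(block_tables.shape[0]):
    for token_idx in range(tokens_per_seq):
        block_number = int(block_tables[seq_idx][token_idx // block_size])
        block_offset = token_idx % block_size
        # A slot encodes (physical block, offset within the block).
        slot_mapping.append(block_number * block_size + block_offset)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)

# The cache kernel (and the reference check) recovers the pair with div/mod.
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size
assert torch.equal(block_indices * block_size + block_offsets, slot_mapping)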
+ + key_caches, value_caches = kv_cache_factory(num_blocks, + block_size, + 1, + num_heads, + head_size, + dtype, + seed, + flash_style=True) + assert len(key_caches) == 1 and len(value_caches) == 1 + key_cache, value_cache = key_caches[0], value_caches[0] + # Call the reshape_and_cache kernel. + cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping) + + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Run the reference implementation. + block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') + block_indicies = block_indicies.cpu().tolist() + block_offsets = slot_mapping % block_size + block_offsets = block_offsets.cpu().tolist() + for i in range(num_tokens): + block_idx = block_indicies[i] + block_offset = block_offsets[i] + print("block_idx", block_idx) + print("block_offset", block_offset) + if block_idx != -1: + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) + + for i in range(58): + print(i) + block_idx = block_indicies[i] + block_offset = block_offsets[i] + torch.allclose(key[i], key_cache[block_idx][block_offset]) + + from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + from xformers import ops as xops + seqlen = query.shape[1] + attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seqlen] * batch_size) + scale = float(1.0 / (head_size**0.5)) + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=None, + p=0.0, + scale=scale, + ) + + num_tokens, num_heads, head_size = query.shape + from flash_attn import flash_attn_func + output1 = flash_single_query_cached_kv_attention( + None, + query.view(batch_size, num_tokens // batch_size, num_heads, head_size), + key_cache, + value_cache, + scale, + block_tables, + torch.tensor([23, 29], dtype=torch.int, device='cuda'), + alibi_slopes=None, + ) + output2 = flash_attn_func( + # query.view(batch_size, num_tokens // batch_size, num_heads, head_size), + # key.view(batch_size, num_tokens // batch_size, num_heads, head_size), + # value.view(batch_size, num_tokens // batch_size, num_heads, head_size), + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + # key_cache, + # value_cache, + softmax_scale=scale, + # block_tables, + # torch.tensor([23, 29], dtype=torch.int, device='cuda'), + # alibi_slopes=None, + causal=True, + ) + output3 = flash_attn_func( + query.view(batch_size, num_tokens // batch_size, num_heads, head_size), + key.view(batch_size, num_tokens // batch_size, num_heads, head_size), + value.view(batch_size, num_tokens // batch_size, num_heads, head_size), + # key_cache, + # value_cache, + softmax_scale=scale, + # block_tables, + # torch.tensor([23, 29], dtype=torch.int, device='cuda'), + # alibi_slopes=None, + causal=True, + ) + + breakpoint() diff --git a/tests/models/test_models.py b/tests/models/test_models.py index fb567e837d28..b8ee03759284 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,20 +6,20 @@ MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", - "mistralai/Mistral-7B-v0.1", - "Deci/DeciLM-7b", - "tiiuae/falcon-7b", - "gpt2", - "bigcode/tiny_starcoder_py", - "EleutherAI/gpt-j-6b", - "EleutherAI/pythia-70m", - "bigscience/bloom-560m", - "mosaicml/mpt-7b", - "microsoft/phi-2", - "stabilityai/stablelm-3b-4e1t", - 
"allenai/OLMo-1B", - "bigcode/starcoder2-3b", + # "meta-llama/Llama-2-7b-hf", + # "mistralai/Mistral-7B-v0.1", + # "Deci/DeciLM-7b", + # "tiiuae/falcon-7b", + # "gpt2", + # "bigcode/tiny_starcoder_py", + # "EleutherAI/gpt-j-6b", + # "EleutherAI/pythia-70m", + # "bigscience/bloom-560m", + # "mosaicml/mpt-7b", + # "microsoft/phi-2", + # "stabilityai/stablelm-3b-4e1t", + # "allenai/OLMo-1B", + # "bigcode/starcoder2-3b", ] diff --git a/vllm/config.py b/vllm/config.py index 859a4ee734ad..bbafd4c06de3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -335,11 +335,11 @@ def _verify_args(self) -> None: if self.flash_style: logger.info("Flash attention enabled.") - if self.block_size < 256: + if self.block_size > 32: # Flash style attention only supports block size >=256 for now. # https://github.com/Dao-AILab/flash-attention/pull/824 will fix it. raise ValueError( - "Flash style attention only supports block size >= 256. Got" + "Flash style attention only supports block size <= 32. Got" f"{self.block_size }") def _verify_cache_dtype(self) -> None: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index b5ae90e217ba..465704f9a1c3 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -200,8 +200,8 @@ def allocate(self, seq_group: SequenceGroup) -> None: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) - print("allcate a new block for seq_group", seq_group) - print("block", block) + # print("allcate a new block for seq_group", seq_group) + # print("block", block) block_table.append(block) # Assign the block table for each sequence. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3f1b5eca5d05..f59a5605dcd2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -823,8 +823,8 @@ def step(self) -> List[RequestOutput]: >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - # print("SANG-TODO step seq_group_metadata_list length: ", - # len(seq_group_metadata_list)) + print("SANG-TODO step seq_group_metadata_list length: ", + len(seq_group_metadata_list)) if not scheduler_outputs.is_empty(): # Execute the model. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 93adc98adda5..e1932bb761ef 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -146,7 +146,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. 
""" - # print("SANG-TODO generate: ", prompts, prompt_token_ids) + print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index db5e5f44ad42..1a5fd49cf025 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -14,11 +14,18 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( context_attention_fwd) from vllm.utils import is_hip +from flash_attn import flash_attn_func +# try: +# from flash_attn import flash_attn_with_kvcache +# except ImportError: +# flash_attn_with_kvcache = None try: - from flash_attn import flash_attn_with_kvcache -except ImportError: - flash_attn_with_kvcache = None + from flash_attn import (flash_attn_with_page_attention, + flash_attn_varlen_with_page_attention) +except Exception as e: + flash_attn_with_page_attention = e + flash_attn_varlen_with_page_attention = e _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. @@ -152,120 +159,145 @@ def forward( if input_metadata.is_prompt: # normal attention if (key_cache is None or value_cache is None - or input_metadata.block_tables.numel() == 0): - # or not input_metadata.prefix_enabled): - # if key_cache is not None and value_cache is not None and input_metadata.flash_style: - if False: - print("SANG-TODO flash attention.") - output = flash_attn_with_kvcache_paged( - query.view(batch_size, seq_len, self.num_heads, - self.head_size), + # or input_metadata.block_tables.numel() == 0): + or not input_metadata.prefix_enabled): + # print("SANG-TODO flash attn is used.") + # print( + # "SANG-TODO query size: ", + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size).size()) + if self.num_kv_heads != self.num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # TODO(woosuk): Use MQA/GQA kernels for higher performance. + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], + self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * batch_size) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = attn_bias + else: + input_metadata.attn_bias = _make_alibi_bias( + self.alibi_slopes, self.num_kv_heads, batch_size, + seq_len, query.dtype) + + if self.use_ref_attention: + output = self.ref_masked_attention( + query, + key, + value, + ) + # Using view got RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use reshape instead + return output.reshape(batch_size, seq_len, hidden_size) + + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. 
+ if self.alibi_slopes is None: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + else: + query = query.unflatten(0, (batch_size, seq_len)) + key = key.unflatten(0, (batch_size, seq_len)) + value = value.unflatten(0, (batch_size, seq_len)) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + output = out.view_as(query) + if key_cache is not None and value_cache is not None: + output2 = flash_single_query_cached_kv_attention( + None, + query.view(batch_size, seq_len, self.num_heads, self.head_size), key_cache, value_cache, self.scale, input_metadata.block_tables, input_metadata.context_lens, - self.alibi_slopes, + alibi_slopes=self.alibi_slopes, ) - else: - # print("SANG-TODO flash attn is used.") - # print( - # "SANG-TODO query size: ", - # query.view(batch_size, seq_len, self.num_heads, - # self.head_size).size()) - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - # Set attention bias if not provided. This typically happens at - # the very attention layer of every iteration. - # FIXME(woosuk): This is a hack. - if input_metadata.attn_bias is None: - if self.alibi_slopes is None: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size) - if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention( - self.sliding_window) - input_metadata.attn_bias = attn_bias - else: - input_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, batch_size, - seq_len, query.dtype) - - if self.use_ref_attention: - output = self.ref_masked_attention( - query, - key, - value, - ) - # Using view got RuntimeError: view size is not compatible with input tensor's size and stride - # (at least one dimension spans across two contiguous subspaces). Use reshape instead - return output.reshape(batch_size, seq_len, hidden_size) - - # TODO(woosuk): Too many view operations. Let's try to reduce - # them in the future for code readability. 
- if self.alibi_slopes is None: - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - else: - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) - out = xops.memory_efficient_attention_forward( + output3 = flash_attn_func( + query.view(batch_size, seq_len, self.num_heads, self.head_size), + key.view(batch_size, seq_len, self.num_heads, self.head_size), + value.view(batch_size, seq_len, self.num_heads, self.head_size), + softmax_scale=self.scale, + causal=True, + ).view_as(query) + output4 = flash_attn_func( query, key, value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ) - output = out.view_as(query) - # if key_cache is not None and value_cache is not None: - # breakpoint() - # output2 = flash_attn_with_kvcache_paged( - # query.view(batch_size, seq_len, self.num_heads, - # self.head_size), - # key_cache, - # value_cache, - # self.scale, - # input_metadata.block_tables, - # input_metadata.context_lens, - # self.alibi_slopes, - # ) - # breakpoint() + softmax_scale=self.scale, + causal=True, + ).view_as(query) + breakpoint() + # if key_cache is not None and value_cache is not None: + # breakpoint() + # output2 = flash_attn_with_kvcache_paged( + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size), + # key_cache, + # value_cache, + # self.scale, + # input_metadata.block_tables, + # input_metadata.context_lens, + # self.alibi_slopes, + # ) + # breakpoint() else: if input_metadata.flash_style: - assert False - output = flash_attn_with_kvcache_paged( - query.view(batch_size, seq_len, self.num_heads, - self.head_size), + print("SANG-TODO flash prefill") + output = flash_single_query_cached_kv_attention( + None, + query.view(batch_size, seq_len, self.num_heads, self.head_size), key_cache, value_cache, self.scale, input_metadata.block_tables, input_metadata.context_lens, - self.alibi_slopes, + alibi_slopes=self.alibi_slopes, ) + # output = flash_multi_query_cached_kv_attention_varlen( + # None, + # query, + # key_cache, + # value_cache, + # self.scale, + # self.block_tables, + # input_metadata.start_loc, + # input_metadata.start_loc, + # max_query_len, + # max_context_len, + # alibi_slopes=self.alibi_slopes, + # ) else: - # print("SANG-TODO context attention") + print("SANG-TODO context attention") # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -287,11 +319,21 @@ def forward( # Decoding run. 
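For orientation, the decoding branch below ultimately computes, per sequence, standard single-token attention over that sequence's cached context. A plain-PyTorch sketch of that math (assuming MHA, i.e. equal query and KV head counts, and no ALiBi; this is a reference illustration, not the kernel):

import torch

def decode_reference(q, key_cache, value_cache, block_table, context_len, scale):
    """q: [num_heads, head_size]; caches: [num_blocks, block_size, num_heads, head_size]."""
    block_size = key_cache.shape[1]
    keys, values = [], []
    for pos in range(context_len):
        block_number = int(block_table[pos // block_size])
        block_offset = pos % block_size
        keys.append(key_cache[block_number, block_offset])
        values.append(value_cache[block_number, block_offset])
    k = torch.stack(keys)    # [context_len, num_heads, head_size]
    v = torch.stack(values)
    attn = torch.einsum("hd,khd->hk", q, k) * scale
    attn = attn.softmax(dim=-1)
    return torch.einsum("hk,khd->hd", attn, v)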
# breakpoint() if input_metadata.flash_style: - output = flash_attn_with_kvcache_paged( - query.view(batch_size, seq_len, self.num_heads, - self.head_size), key_cache, value_cache, - self.scale, input_metadata.block_tables, - input_metadata.context_lens, self.alibi_slopes) + # output = flash_attn_with_kvcache_paged( + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size), key_cache, value_cache, + # self.scale, input_metadata.block_tables, + # input_metadata.context_lens, self.alibi_slopes) + output = flash_single_query_cached_kv_attention( + None, + query.view(batch_size, 1, self.num_heads, self.head_size), + key_cache, + value_cache, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + alibi_slopes=self.alibi_slopes, + ) else: output = _paged_attention( query, @@ -417,46 +459,205 @@ def _paged_attention( return output -def flash_attn_with_kvcache_paged( +# def flash_attn_with_kvcache_paged( +# query: torch.Tensor, +# key_cache: torch.Tensor, +# value_cache: torch.Tensor, +# scale: float, +# block_tables: torch.Tensor, +# context_lens: torch.Tensor, +# alibi_slopes: Optional[torch.Tensor] = None, +# ) -> torch.Tensor: +# """Similar to vLLM's page attention, calculates a single token's attention +# based on key/value caches. The main difference is this uses flash attention +# style key-value caches. + +# Arguments: +# See https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py +# for other arguments. + +# Returns: +# output: [num_tokens, num_heads, head_size] +# """ +# block_size = value_cache.shape[1] +# assert block_size % 256 == 0, ("only support block_size divisible by 256. " +# f"Current block size: {block_size}") +# _, _, num_heads, head_size = query.shape +# out = flash_attn_with_kvcache( +# query, +# key_cache, +# value_cache, +# # Inplace update is slow. We don't use it. +# # We assume kvcache is already updated before +# # calling this API. +# None, # key +# None, # value +# cache_seqlens=context_lens, +# block_table=block_tables, +# softmax_scale=scale, +# causal=True, +# alibi_slopes=alibi_slopes, +# num_splits=0, +# ) + +# # num_tokens == batch_size * seqlen +# return out.view(-1, num_heads, head_size) + + +def flash_single_query_cached_kv_attention( + output: Optional[torch.Tensor], query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, scale: float, block_tables: torch.Tensor, context_lens: torch.Tensor, - alibi_slopes: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor], + actual_batch_size: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Similar to vLLM's page attention, calculates a single token's attention + """Similar to vLLM's page attention, caclulates a single token's attention based on key/value caches. The main difference is this uses flash attention - style key-value caches. + sytle key-value caches. Arguments: - See https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py - for other arguments. + output: [num_padded_tokens, num_heads, head_size], output tensor + to write. if None an new output tensor will be created. + query: [batch_size, num_padded_tokens, num_heads, head_size], query tensor. + key_cache: [num_blocks, block_size, num_heads, head_size], key cache. + value_cache: [num_blocks, block_size, num_heads, head_size], value + cache. + scale: attention scale. + block_tables: [num_padded_tokens, max_context_len / block_size], + block tables. 
+ context_lens: [num_padded_tokens], context lengths. + block_size: block size. + alibi_slopes: unused. + actual_batch_size: [1] actual batch size. Returns: - output: [num_tokens, num_heads, head_size] + output: [num_padded_tokens, num_heads, head_size] """ block_size = value_cache.shape[1] - assert block_size % 256 == 0, ("only support block_size divisible by 256. " - f"Current block size: {block_size}") - _, _, num_heads, head_size = query.shape - out = flash_attn_with_kvcache( + assert block_size >= 32, "only support block_size >= 32 for flash attention" + # TODO: support alibi_slopes + assert alibi_slopes is None, "doesn't support alibi_slopes" + batch_size, seqlen, num_heads, head_size = query.shape + num_tokens = batch_size * seqlen + out = flash_attn_with_page_attention( + query, + key_cache, + value_cache, + block_tables, + None, # key + None, # value + None, # cos + None, # sin + context_lens, + None, # cache_batch_idx + softmax_scale=scale, + causal=True, + window_size=(-1, -1), + rotary_interleaved=False, + num_splits=0, + actual_batch_size=actual_batch_size, + ) + return out + # if output is not None: + # # in case that output is padded, only copy the valid part. + # output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) + # return out.view(num_tokens, num_heads, head_size) + + +def flash_multi_query_cached_kv_attention_varlen( + output: Optional[torch.Tensor], + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + scale: float, + block_tables: torch.Tensor, + cum_seqlens_q: torch.Tensor, + cum_context_len: torch.Tensor, + max_query_len: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + actual_batch_size: Optional[torch.Tensor] = None, +): + """Efficient multi-query paged attention based on flash attention. + It calculates attentions of list of sequences packed in a single batch, + indexed by cum_seqlens_q where the seq_i's index is + [cum_seqlens_q[i], cum_seqlensq[i+1]]. + Similarlly, the length of context is stored in cum_seqlens_k with similar + fashions. + It also supports calculating attention incrementally, where context length + is longer than sequence length. + + Arguments: + output: [num_padded_tokens, num_heads, head_size], output tensor to + write to. if None an new output tensor will be created. + + prefill -> always 1 batch + query, head_size, head_dim + + varlen -> provides cumulative lengths of queries + + decoding -> always 1 batch + 1 * number_of_batch, head_size, head_dim + + query: [num_padded_tokens, num_heads, head_size], query tensor. + key_cache: [num_blocks, block_size, num_heads, head_size], key cache. + value_cache: [num_blocks, block_size, num_heads, head_size], + value cache. + scale: attention scale. + block_tables: [num_padded_tokens, max_context_len / block_size], + block tables. + - these two are the same if no chunked prefill (for prefill) + - If you do chunked prefill, it may be different + - when? see attention mask setting + - Each iteration, it can have smaller # of queries than (because it is chunked) + tokens we should attend to (context len). + - cum_seqlens_q: actual query length (actual prompt length). + - cum_context_len: context len that needs to be attended (subquery). + cum_seqlens_q: [padded_batch_size + 1], cumulative sequence lengths + of query. + cum_context_len: [padded_batch_size + 1], cumulative lengths + of context. + block_size: block size. + max_query_len: max query length. + max_context_len: max context length. + alibi_slopes: unused. 
+ actual_batch_size: [1] actual batch size. + + Returns: + output: [num_padded_tokens, num_heads, head_size] + """ + block_size = value_cache.shape[1] + assert block_size >= 32, "only support block_size >= 32 for flash attention" + # TODO: support alibi_slopes + assert alibi_slopes is None, "doesn't support alibi_slopes" + + num_tokens, _, _ = query.shape + out = flash_attn_varlen_with_page_attention( query, key_cache, value_cache, - # Inplace update is slow. We don't use it. - # We assume kvcache is already updated before - # calling this API. + block_tables, + cum_seqlens_q, + cum_context_len, + max_query_len, + max_context_len, None, # key None, # value - cache_seqlens=context_lens, - block_table=block_tables, + None, # cos_cache + None, # sin_cache + None, # cache_batch_idx softmax_scale=scale, causal=True, - alibi_slopes=alibi_slopes, + window_size=(-1, -1), + rotary_interleaved=False, num_splits=0, + actual_batch_size=actual_batch_size, ) - # num_tokens == batch_size * seqlen - return out.view(-1, num_heads, head_size) + if output is not None: + output[:num_tokens].copy_(out) + return out \ No newline at end of file diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 84fc0a0e5528..5cddeed35538 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -224,6 +224,7 @@ def _prepare_prompt( slot_mapping[-1].append(slot) max_prompt_len = max(subquery_lens) + breakpoint() input_tokens = _make_tensor_with_pad(input_tokens, max_prompt_len, pad=0, From ef679d74a2296f941571c0c30dfaa0fd338783de Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 4 Mar 2024 05:09:04 -0800 Subject: [PATCH 17/28] . --- tests/kernels/test_flash_attention.py | 297 ++++++++++++------------ vllm/model_executor/layers/attention.py | 54 ++--- vllm/worker/model_runner.py | 1 - 3 files changed, 173 insertions(+), 179 deletions(-) diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index 52f9ad3ccc9f..eb999504690a 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -980,158 +980,157 @@ def test_multi_query_kv_attention( assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) -@pytest.mark.parametrize("num_heads", [40]) -@pytest.mark.parametrize("head_size", [64]) -@pytest.mark.parametrize("block_size", [32]) -@pytest.mark.parametrize("num_blocks", [128]) -@pytest.mark.parametrize("dtype", [torch.half]) -@pytest.mark.parametrize("seed", SEEDS) -@torch.inference_mode() -def test_e2e( - kv_cache_factory, - num_heads: int, - head_size: int, - block_size: int, - num_blocks: int, - dtype: torch.dtype, - seed: int, -) -> None: - from vllm._C import cache_ops - batch_size = 2 - seqlen = 29 - num_tokens = batch_size * seqlen - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) +# @pytest.mark.parametrize("num_heads", [40]) +# @pytest.mark.parametrize("head_size", [64]) +# @pytest.mark.parametrize("block_size", [32]) +# @pytest.mark.parametrize("num_blocks", [128]) +# @pytest.mark.parametrize("dtype", [torch.half]) +# @pytest.mark.parametrize("seed", SEEDS) +# @torch.inference_mode() +# def test_e2e( +# kv_cache_factory, +# num_heads: int, +# head_size: int, +# block_size: int, +# num_blocks: int, +# dtype: torch.dtype, +# seed: int, +# ) -> None: +# from vllm._C import cache_ops +# batch_size = 2 +# seqlen = 29 +# num_tokens = batch_size * seqlen +# random.seed(seed) +# torch.random.manual_seed(seed) +# torch.cuda.manual_seed(seed) - block_tables = [] - for _ in 
range(batch_size): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - # Create a random slot mapping. - slot_mapping = [] - for i in range(0, 29): - block_number = int(block_tables[i // block_size]) - block_offset = i % block_size - slot = block_number * block_size + block_offset - slot_mapping.append(slot) - # for _ in range(23, 29): - # slot_mapping.append(-1) - for i in range(0, 29): - block_number = int(block_tables[1]) - block_offset = i % block_size - slot = block_number * block_size + block_offset - slot_mapping.append(slot) - # slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') - - qkv = torch.randn(num_tokens, - 3, - num_heads, - head_size, - dtype=dtype, - device='cuda') - query, key, value = qkv.unbind(dim=1) - # query[23:29] = 0 - # key[23:29] = 0 - # value[23:29] = 0 - # Create the KV caches. +# block_tables = [] +# for _ in range(batch_size): +# block_table = [ +# random.randint(0, NUM_BLOCKS - 1) +# ] +# block_tables.append(block_table) +# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") +# # Create a random slot mapping. +# slot_mapping = [] +# for i in range(0, 29): +# block_number = int(block_tables[i // block_size]) +# block_offset = i % block_size +# slot = block_number * block_size + block_offset +# slot_mapping.append(slot) +# # for _ in range(23, 29): +# # slot_mapping.append(-1) +# for i in range(0, 29): +# block_number = int(block_tables[1]) +# block_offset = i % block_size +# slot = block_number * block_size + block_offset +# slot_mapping.append(slot) +# # slot_mapping = random.sample(range(num_slots), num_tokens) +# slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + +# qkv = torch.randn(num_tokens, +# 3, +# num_heads, +# head_size, +# dtype=dtype, +# device='cuda') +# query, key, value = qkv.unbind(dim=1) +# # query[23:29] = 0 +# # key[23:29] = 0 +# # value[23:29] = 0 +# # Create the KV caches. - key_caches, value_caches = kv_cache_factory(num_blocks, - block_size, - 1, - num_heads, - head_size, - dtype, - seed, - flash_style=True) - assert len(key_caches) == 1 and len(value_caches) == 1 - key_cache, value_cache = key_caches[0], value_caches[0] - # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping) +# key_caches, value_caches = kv_cache_factory(num_blocks, +# block_size, +# 1, +# num_heads, +# head_size, +# dtype, +# seed, +# flash_style=True) +# assert len(key_caches) == 1 and len(value_caches) == 1 +# key_cache, value_cache = key_caches[0], value_caches[0] +# # Call the reshape_and_cache kernel. +# cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, +# slot_mapping) - cloned_key_cache = key_cache.clone() - cloned_value_cache = value_cache.clone() +# cloned_key_cache = key_cache.clone() +# cloned_value_cache = value_cache.clone() - # Run the reference implementation. 
- block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') - block_indicies = block_indicies.cpu().tolist() - block_offsets = slot_mapping % block_size - block_offsets = block_offsets.cpu().tolist() - for i in range(num_tokens): - block_idx = block_indicies[i] - block_offset = block_offsets[i] - print("block_idx", block_idx) - print("block_offset", block_offset) - if block_idx != -1: - cloned_key_cache[block_idx, block_offset, :, :] = key[i] - cloned_value_cache[block_idx, block_offset, :, :] = value[i] - assert torch.allclose(key_cache, cloned_key_cache) - assert torch.allclose(value_cache, cloned_value_cache) - - for i in range(58): - print(i) - block_idx = block_indicies[i] - block_offset = block_offsets[i] - torch.allclose(key[i], key_cache[block_idx][block_offset]) - - from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask - from xformers import ops as xops - seqlen = query.shape[1] - attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seqlen] * batch_size) - scale = float(1.0 / (head_size**0.5)) - output = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=None, - p=0.0, - scale=scale, - ) +# # Run the reference implementation. +# block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') +# block_indicies = block_indicies.cpu().tolist() +# block_offsets = slot_mapping % block_size +# block_offsets = block_offsets.cpu().tolist() +# for i in range(num_tokens): +# block_idx = block_indicies[i] +# block_offset = block_offsets[i] +# print("block_idx", block_idx) +# print("block_offset", block_offset) +# if block_idx != -1: +# cloned_key_cache[block_idx, block_offset, :, :] = key[i] +# cloned_value_cache[block_idx, block_offset, :, :] = value[i] +# assert torch.allclose(key_cache, cloned_key_cache) +# assert torch.allclose(value_cache, cloned_value_cache) + +# for i in range(58): +# print(i) +# block_idx = block_indicies[i] +# block_offset = block_offsets[i] +# torch.allclose(key[i], key_cache[block_idx][block_offset]) + +# from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask +# from xformers import ops as xops +# seqlen = query.shape[1] +# attn_bias = BlockDiagonalCausalMask.from_seqlens( +# [seqlen] * batch_size) +# scale = float(1.0 / (head_size**0.5)) +# output = xops.memory_efficient_attention_forward( +# query.unsqueeze(0), +# key.unsqueeze(0), +# value.unsqueeze(0), +# attn_bias=None, +# p=0.0, +# scale=scale, +# ) - num_tokens, num_heads, head_size = query.shape - from flash_attn import flash_attn_func - output1 = flash_single_query_cached_kv_attention( - None, - query.view(batch_size, num_tokens // batch_size, num_heads, head_size), - key_cache, - value_cache, - scale, - block_tables, - torch.tensor([23, 29], dtype=torch.int, device='cuda'), - alibi_slopes=None, - ) - output2 = flash_attn_func( - # query.view(batch_size, num_tokens // batch_size, num_heads, head_size), - # key.view(batch_size, num_tokens // batch_size, num_heads, head_size), - # value.view(batch_size, num_tokens // batch_size, num_heads, head_size), - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - # key_cache, - # value_cache, - softmax_scale=scale, - # block_tables, - # torch.tensor([23, 29], dtype=torch.int, device='cuda'), - # alibi_slopes=None, - causal=True, - ) - output3 = flash_attn_func( - query.view(batch_size, num_tokens // batch_size, num_heads, head_size), - key.view(batch_size, num_tokens // batch_size, num_heads, head_size), - value.view(batch_size, num_tokens 
// batch_size, num_heads, head_size), - # key_cache, - # value_cache, - softmax_scale=scale, - # block_tables, - # torch.tensor([23, 29], dtype=torch.int, device='cuda'), - # alibi_slopes=None, - causal=True, - ) +# num_tokens, num_heads, head_size = query.shape +# from flash_attn import flash_attn_func +# output1 = flash_single_query_cached_kv_attention( +# None, +# query.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# key_cache, +# value_cache, +# scale, +# block_tables, +# torch.tensor([23, 29], dtype=torch.int, device='cuda'), +# alibi_slopes=None, +# ) +# output2 = flash_attn_func( +# # query.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# # key.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# # value.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# query.unsqueeze(0), +# key.unsqueeze(0), +# value.unsqueeze(0), +# # key_cache, +# # value_cache, +# softmax_scale=scale, +# # block_tables, +# # torch.tensor([23, 29], dtype=torch.int, device='cuda'), +# # alibi_slopes=None, +# causal=True, +# ) +# output3 = flash_attn_func( +# query.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# key.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# value.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# # key_cache, +# # value_cache, +# softmax_scale=scale, +# # block_tables, +# # torch.tensor([23, 29], dtype=torch.int, device='cuda'), +# # alibi_slopes=None, +# causal=True, +# ) - breakpoint() diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 1a5fd49cf025..74c574e2903f 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -231,34 +231,32 @@ def forward( (is_hip()) else None, ) output = out.view_as(query) - if key_cache is not None and value_cache is not None: - output2 = flash_single_query_cached_kv_attention( - None, - query.view(batch_size, seq_len, self.num_heads, self.head_size), - key_cache, - value_cache, - self.scale, - input_metadata.block_tables, - input_metadata.context_lens, - alibi_slopes=self.alibi_slopes, - ) - output3 = flash_attn_func( - query.view(batch_size, seq_len, self.num_heads, self.head_size), - key.view(batch_size, seq_len, self.num_heads, self.head_size), - value.view(batch_size, seq_len, self.num_heads, self.head_size), - softmax_scale=self.scale, - causal=True, - ).view_as(query) - output4 = flash_attn_func( - query, - key, - value, - softmax_scale=self.scale, - causal=True, - ).view_as(query) - breakpoint() + # if key_cache is not None and value_cache is not None: + # output2 = flash_single_query_cached_kv_attention( + # None, + # query.view(batch_size, seq_len, self.num_heads, self.head_size), + # key_cache, + # value_cache, + # self.scale, + # input_metadata.block_tables, + # input_metadata.context_lens, + # alibi_slopes=self.alibi_slopes, + # ) + # output3 = flash_attn_func( + # query.view(batch_size, seq_len, self.num_heads, self.head_size), + # key.view(batch_size, seq_len, self.num_heads, self.head_size), + # value.view(batch_size, seq_len, self.num_heads, self.head_size), + # softmax_scale=self.scale, + # causal=True, + # ).view_as(query) + # output4 = flash_attn_func( + # query, + # key, + # value, + # softmax_scale=self.scale, + # causal=True, + # ).view_as(query) # if key_cache is not None and value_cache is not None: - # breakpoint() # output2 = flash_attn_with_kvcache_paged( # query.view(batch_size, seq_len, self.num_heads, # 
self.head_size), @@ -269,7 +267,6 @@ def forward( # input_metadata.context_lens, # self.alibi_slopes, # ) - # breakpoint() else: if input_metadata.flash_style: print("SANG-TODO flash prefill") @@ -317,7 +314,6 @@ def forward( else: # Decoding run. - # breakpoint() if input_metadata.flash_style: # output = flash_attn_with_kvcache_paged( # query.view(batch_size, seq_len, self.num_heads, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5cddeed35538..84fc0a0e5528 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -224,7 +224,6 @@ def _prepare_prompt( slot_mapping[-1].append(slot) max_prompt_len = max(subquery_lens) - breakpoint() input_tokens = _make_tensor_with_pad(input_tokens, max_prompt_len, pad=0, From 71bda972692059492c56223e5d713ea8fc9f0759 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 4 Mar 2024 05:18:34 -0800 Subject: [PATCH 18/28] . --- tests/chunked_prefill/test_correctness.py | 33 +- tests/kernels/test_flash_attention.py | 503 +--------------------- vllm/core/block_manager.py | 2 - vllm/engine/llm_engine.py | 4 +- vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/attention.py | 47 +- 6 files changed, 45 insertions(+), 546 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index efd1d5610f2c..6b38bb97b41d 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -29,7 +29,6 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("block_size", [32]) -@pytest.mark.parametrize("max_num_prompt_seqs", [1]) @pytest.mark.parametrize("tensor_parallel_size", [1]) def test_models( vllm_runner, @@ -37,7 +36,6 @@ def test_models( dtype: str, max_tokens: int, block_size: int, - max_num_prompt_seqs: int, tensor_parallel_size: int, ) -> None: """ verify the flash attention has the same output @@ -46,27 +44,26 @@ def test_models( pytest.skip( f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" ) - # print("loading page attention models..") - # pg_model = vllm_runner(model, dtype=dtype) - # expected_outputs = [] + print("loading page attention models..") + pg_model = vllm_runner(model, dtype=dtype) + expected_outputs = [] - # print("generating tokens...") - # expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) - # print("generating tokens finished") + print("generating tokens...") + expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) + print("generating tokens finished") - # del pg_model + del pg_model - # destroy_model_parallel() - # gc.collect() - # torch.cuda.empty_cache() + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() flash_attn_output_by_batches = [] - flash_attn_model = vllm_runner( - model, - dtype=dtype, - block_size=block_size, - flash_style=True, - tensor_parallel_size=tensor_parallel_size) + flash_attn_model = vllm_runner(model, + dtype=dtype, + block_size=block_size, + flash_style=True, + tensor_parallel_size=tensor_parallel_size) for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] flash_attn_output_by_batches.append( diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index eb999504690a..b5b43be50a64 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -1,494 +1,3 @@ -# import random -# from typing import Optional, Tuple, List - -# import pytest -# 
import torch -# import torch.nn.functional as F - -# from vllm.model_executor.layers.attention import ( -# flash_attn_with_kvcache_paged, ) -# from vllm.utils import get_max_shared_memory_bytes - -# FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 -# # This will change depending on the compute capability. -# # - 512 as a buffer -# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -# NUM_BLOCKS = 128 # Arbitrary values for testing -# PARTITION_SIZE = 512 - -# DTYPES = [torch.half, torch.bfloat16] -# NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing -# NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing -# NUM_HEADS = [(1, 40), (40, 40), (64, 8)] # Arbitrary values for testing -# NUM_HEADS_SMALL = NUM_HEADS -# # head size should be bigger than or equal to block size. -# HEAD_SIZES = [256] -# # TODO(sang): https://github.com/Dao-AILab/flash-attention/pull/824 -# # should fix the block size. But right now, the block size should be -# # divisible by 256. -# BLOCK_SIZES = [256] -# USE_ALIBI = [False, True] -# SEEDS = [0] -# PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] - - -# def pad_attention_inputs( -# pad_config: Tuple[int, int], -# block_size: int, -# query: torch.Tensor, -# block_tables: torch.Tensor, -# context_lens: torch.Tensor, -# max_context_len: int, -# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: -# """Pad the attention inputs to the specified batch size and context length. -# """ -# pad_batch_size, pad_max_context_len = pad_config -# if pad_batch_size == 0: -# return query, block_tables, context_lens, max_context_len -# target_batch_size = ( -# (query.shape[0] - 1) % pad_batch_size + 1) * pad_batch_size -# target_block_size = pad_max_context_len // block_size + 1 -# padded_query = F.pad(query, -# (0, 0, 0, 0, 0, target_batch_size - query.shape[0])) -# padded_block_table = F.pad(block_tables, -# (0, target_block_size - block_tables.shape[1], -# 0, target_batch_size - block_tables.shape[0])) -# padded_context_lens = F.pad(context_lens, -# (0, target_batch_size - context_lens.shape[0])) -# return padded_query, padded_block_table, padded_context_lens, pad_max_context_len - - -# def ref_masked_attention( -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# scale: float, -# attn_mask: Optional[torch.Tensor] = None, -# ) -> torch.Tensor: -# attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() -# if attn_mask is not None: -# attn_weights = attn_weights + attn_mask.float() -# attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) -# out = torch.einsum("hqk,khd->qhd", attn_weights, value) -# return out - - -# def ref_single_query_cached_kv_attention( -# output: torch.Tensor, -# query: torch.Tensor, -# num_queries_per_kv: int, -# key_cache: torch.Tensor, -# value_cache: torch.Tensor, -# block_tables: torch.Tensor, -# context_lens: torch.Tensor, -# scale: float, -# alibi_slopes: Optional[torch.Tensor], -# ) -> None: -# num_query_heads = query.shape[1] -# head_size = value_cache.shape[-1] -# block_size = value_cache.shape[-3] -# num_seqs = query.shape[0] - -# block_tables = block_tables.cpu().tolist() -# context_lens = context_lens.cpu().tolist() -# for i in range(num_seqs): -# q = query[i].unsqueeze(0) -# block_table = block_tables[i] -# context_len = int(context_lens[i]) - -# keys = [] -# values = [] -# for j in range(context_len): -# block_number = int(block_table[j // block_size]) -# block_offset = j % block_size - -# k = key_cache[block_number, 
block_offset, :, :] -# keys.append(k) - -# v = value_cache[block_number, :, :, block_offset] -# v = value_cache[block_number, block_offset, :, :] -# values.append(v) -# keys = torch.stack(keys, dim=0) -# values = torch.stack(values, dim=0) -# if num_queries_per_kv > 1: -# # Handle MQA and GQA -# keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) -# values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - -# alibi_bias = None -# if alibi_slopes is not None: -# # Create the ALiBi bias used in the paged attention kernel. -# position_ids = torch.arange(context_len, device="cuda").int() -# alibi_bias = (position_ids - context_len + 1).float() -# alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( -# 1, 1, -1) - -# out = ref_masked_attention(q, keys, values, scale, alibi_bias) -# out = out.view(num_query_heads, head_size) -# # output[i].copy_(out, non_blocking=True) -# output[i].copy_(out) - - -# # @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -# # @pytest.mark.parametrize("num_heads", NUM_HEADS) -# # @pytest.mark.parametrize("head_size", HEAD_SIZES) -# # @pytest.mark.parametrize("use_alibi", [False, True]) -# # @pytest.mark.parametrize("block_size", BLOCK_SIZES) -# # @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) -# # @pytest.mark.parametrize("seed", SEEDS) -# # @pytest.mark.parametrize("pad_config", PAD_CONFIGS) -# @pytest.mark.parametrize("num_seqs", [3]) -# @pytest.mark.parametrize("num_heads", [(40, 40)]) -# @pytest.mark.parametrize("head_size", [256]) -# @pytest.mark.parametrize("use_alibi", [True]) -# @pytest.mark.parametrize("block_size", [256]) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("seed", SEEDS) -# @pytest.mark.parametrize("pad_config", [(0, 0)]) -# @torch.inference_mode() -# def test_flash_paged_attention( -# kv_cache_factory, -# num_seqs: int, -# num_heads: Tuple[int, int], -# head_size: int, -# use_alibi: bool, -# block_size: int, -# dtype: torch.dtype, -# seed: int, -# pad_config: Tuple[int, int], -# ) -> None: -# random.seed(seed) -# torch.random.manual_seed(seed) -# torch.cuda.manual_seed(seed) - -# scale = float(1.0 / (head_size**0.5)) -# num_query_heads, num_kv_heads = num_heads -# query = torch.empty(num_seqs, -# num_query_heads, -# head_size, -# dtype=dtype, -# device="cuda") -# query.uniform_(-scale, scale) - -# # assert num_query_heads % num_kv_heads == 0 -# num_queries_per_kv = num_query_heads // num_kv_heads -# alibi_slopes = None -# if use_alibi: -# alibi_slopes = torch.randn(num_query_heads, -# dtype=torch.float, -# device="cuda") - -# max_seq_len = MAX_SEQ_LEN if not pad_config[0] else (pad_config[1] - 1000) -# context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] -# context_lens[-1] = max_seq_len -# max_context_len = max(context_lens) -# context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") - -# # Create the block tables. -# max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size -# block_tables = [] -# for _ in range(num_seqs): -# block_table = [ -# random.randint(0, NUM_BLOCKS - 1) -# for _ in range(max_num_blocks_per_seq) -# ] -# block_tables.append(block_table) -# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - -# # Create the KV caches. 
-# key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, -# block_size, -# 1, -# num_kv_heads, -# head_size, -# dtype, -# seed, -# flash_style=True) -# key_cache, value_cache = key_caches[0], value_caches[0] - -# # Call the paged attention kernel. -# output = torch.empty_like(query) - -# padded_query, padded_block_table, padded_context_lens, _ = \ -# pad_attention_inputs(pad_config, block_size, query, -# block_tables, context_lens, max_context_len) - -# output = flash_attn_with_kvcache_paged( -# padded_query.view(num_seqs, 1, num_query_heads, head_size), -# key_cache, -# value_cache, -# scale, -# padded_block_table, -# padded_context_lens, -# alibi_slopes, -# ) - -# # Run the reference implementation. -# ref_output = torch.empty_like(query) -# ref_single_query_cached_kv_attention( -# ref_output, -# query, -# num_queries_per_kv, -# key_cache, -# value_cache, -# block_tables, -# context_lens, -# scale, -# alibi_slopes, -# ) - -# assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) - - -# def ref_multi_query_kv_attention( -# cu_seq_lens: List[int], -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# scale: float, -# dtype: torch.dtype, -# ) -> torch.Tensor: -# num_seqs = len(cu_seq_lens) - 1 -# ref_outputs = [] -# for i in range(num_seqs): -# start_idx = cu_seq_lens[i] -# end_idx = cu_seq_lens[i + 1] -# seq_len = end_idx - start_idx - -# # Create attention mask. -# attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), -# diagonal=1) -# attn_mask = attn_mask * torch.finfo(dtype).min -# attn_mask = attn_mask.to(dtype=dtype, device="cuda") - -# ref_output = ref_masked_attention( -# query[start_idx:end_idx], -# key[start_idx:end_idx], -# value[start_idx:end_idx], -# scale, -# attn_mask=attn_mask, -# ) -# ref_outputs.append(ref_output) -# ref_output = torch.cat(ref_outputs, dim=0) -# return ref_output - - -# def ref_multi_query_kv_attention_padded( -# query: torch.Tensor, -# num_queries_per_kv: int, -# key_cache: torch.Tensor, -# value_cache: torch.Tensor, -# block_tables: torch.Tensor, -# cu_seq_lens: List[int], -# context_lens: List[int], -# scale: float, -# dtype: torch.dtype, -# ) -> torch.Tensor: -# num_seqs = len(cu_seq_lens) - 1 -# block_size = value_cache.shape[-3] -# ref_outputs = [] - -# for i in range(num_seqs): -# q_start_idx = cu_seq_lens[i] -# q_end_idx = cu_seq_lens[i + 1] -# seq_len = q_end_idx - q_start_idx - -# context_len = context_lens[i] - -# block_table = block_tables[i] -# keys = [] -# values = [] - -# for j in range(context_len): -# block_number = int(block_table[j // block_size]) -# block_offset = j % block_size - -# k = key_cache[block_number, block_offset, :, :] -# keys.append(k) - -# v = value_cache[block_number, block_offset, :, :] -# values.append(v) - -# keys = torch.stack(keys, dim=0) -# values = torch.stack(values, dim=0) - -# if num_queries_per_kv > 1: -# # Handle MQA and GQA -# keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) -# values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - -# q = query[q_start_idx:q_end_idx, :, :] -# k = keys[:context_len, :, :] -# v = values[:context_len, :, :] - -# assert seq_len <= context_len - -# # pad q if seq_len is less than context_len -# # this is for correct calculation of attention. -# if seq_len < context_len: -# indices = [i % seq_len for i in range(context_len - seq_len)] -# q_left_pad = q[indices, :, :] -# q = torch.cat([q_left_pad, q], dim=0) - -# # Create attention mask. 
-# attn_mask = torch.triu(torch.ones(context_len, -# context_len, -# dtype=dtype), -# diagonal=1) -# attn_mask = attn_mask * torch.finfo(dtype).min -# attn_mask = attn_mask.to(dtype=dtype, device="cuda") -# ref_output = ref_masked_attention( -# q, -# k, -# v, -# scale, -# attn_mask=attn_mask, -# ) -# ref_outputs.append(ref_output[-seq_len:, :, :]) -# breakpoint -# ref_output = torch.cat(ref_outputs, dim=0) -# return ref_output - - -# def is_a100(): -# return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 - - -# if not is_a100(): -# NUM_HEADS_SMALL = [(16, 16), (16, 8)] -# MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) - - -# @pytest.mark.parametrize("num_seqs", [17]) -# @pytest.mark.parametrize("num_heads", [(40, 40)]) -# @pytest.mark.parametrize("head_size", HEAD_SIZES) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("version", ["flash"]) -# @pytest.mark.parametrize("seed", SEEDS) -# @pytest.mark.parametrize("block_size", [256]) -# @torch.inference_mode() -# def test_multi_query_kv_attention( -# num_seqs: int, -# num_heads: Tuple[int, int], -# head_size: int, -# dtype: torch.dtype, -# version: str, -# seed: int, -# block_size: int, -# ) -> None: -# random.seed(seed) -# torch.random.manual_seed(seed) -# torch.cuda.manual_seed(seed) - -# # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. -# # As the xformers library is already tested with its own tests, we can use -# # a smaller MAX_SEQ_LEN here. -# max_len = min(MAX_SEQ_LEN, 4096) - -# seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] -# max_seq_len = max(seq_lens) -# context_lens = seq_lens -# max_context_len = max(context_lens) - -# num_tokens = sum(seq_lens) -# cu_seq_lens = [0] -# for seq_len in seq_lens: -# cu_seq_lens.append(cu_seq_lens[-1] + seq_len) - -# cu_context_lens = [0] -# for context_len in context_lens: -# cu_context_lens.append(cu_context_lens[-1] + context_len) - -# print(f"cu_seq_lens={cu_seq_lens}, cu_context_lens={cu_context_lens}") - -# scale = float(1.0 / (head_size**0.5)) -# num_query_heads, num_kv_heads = num_heads -# num_queries_per_kv = num_query_heads // num_kv_heads - -# value_cache = torch.empty(NUM_BLOCKS, -# block_size, -# num_kv_heads, -# head_size, -# dtype=dtype, -# device="cuda") -# key_cache = torch.empty(NUM_BLOCKS, -# block_size, -# num_kv_heads, -# head_size, -# dtype=dtype, -# device="cuda") -# query = torch.empty(max_seq_len * num_seqs, -# num_query_heads, -# head_size, -# dtype=dtype, -# device="cuda") -# value_cache.uniform_(-scale, scale) -# key_cache.uniform_(-scale, scale) -# query.uniform_(-scale, scale) - -# # Create the block tables. 
-# max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size -# block_tables = [] -# for _ in range(num_seqs): -# block_table = [ -# random.randint(0, NUM_BLOCKS - 1) -# for _ in range(max_num_blocks_per_seq) -# ] -# block_tables.append(block_table) -# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - -# output = torch.empty_like(query) - -# if version == "flash": -# # flash_multi_query_cached_kv_attention_varlen( -# # output, -# # query, -# # key_cache, -# # value_cache, -# # scale, -# # block_tables, -# # torch.cuda.IntTensor(cu_seq_lens), -# # torch.cuda.IntTensor(cu_context_lens), -# # block_size, -# # max_seq_len, -# # max_context_len, -# # None, -# # ) -# from flash_attn import flash_attn_func -# breakpoint() -# # output = flash_attn_func( -# # query.unsqueeze(0), -# # k=key, -# # v=value, -# # softmax_scale=scale, -# # causal=True, -# # alibi_slopes=alibi_slopes, -# # ) -# output = flash_attn_with_kvcache_paged( -# query.view(num_seqs, max_seq_len, num_query_heads, head_size), -# key_cache, -# value_cache, -# scale, -# block_tables, -# torch.tensor(context_lens, dtype=torch.int, device="cuda"), -# None, -# ) -# else: -# assert False, f"{version=} is not supported" - -# ref_output = ref_multi_query_kv_attention_padded( -# query, -# num_queries_per_kv, -# key_cache, -# value_cache, -# block_tables, -# cu_seq_lens, -# context_lens, -# scale, -# dtype, -# ) -# assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) - import random from typing import List, Optional, Tuple @@ -660,9 +169,6 @@ def test_flash_paged_attention( assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads - head_mapping = torch.repeat_interleave( - torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), - num_queries_per_kv) alibi_slopes = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, @@ -698,7 +204,6 @@ def test_flash_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. - num_valid_tokens = torch.cuda.IntTensor([num_seqs]) output = torch.empty_like(query) padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ @@ -707,7 +212,7 @@ def test_flash_paged_attention( flash_single_query_cached_kv_attention( output, - padded_query, + padded_query.view(num_seqs, 1, num_query_heads, head_size), key_cache, value_cache, scale, @@ -731,9 +236,6 @@ def test_flash_paged_attention( flash_style=True, ) - # NOTE(woosuk): Due to the kernel-level differences in the two - # implementations, there is a small numerical difference in the two - # outputs. Thus, we use a relaxed tolerance for the test. assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) @@ -1053,7 +555,7 @@ def test_multi_query_kv_attention( # # Call the reshape_and_cache kernel. 
# cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, # slot_mapping) - + # cloned_key_cache = key_cache.clone() # cloned_value_cache = value_cache.clone() @@ -1133,4 +635,3 @@ def test_multi_query_kv_attention( # # alibi_slopes=None, # causal=True, # ) - diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index a7c38fc4c0f2..daf83827a7e5 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -200,8 +200,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) - # print("allcate a new block for seq_group", seq_group) - # print("block", block) block_table.append(block) # Assign the block table for each sequence. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5283654d1a34..50bedbe89cd1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -836,8 +836,8 @@ def step(self) -> List[RequestOutput]: >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - print("SANG-TODO step seq_group_metadata_list length: ", - len(seq_group_metadata_list)) + # print("SANG-TODO step seq_group_metadata_list length: ", + # len(seq_group_metadata_list)) if not scheduler_outputs.is_empty(): # Execute the model. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e1932bb761ef..93adc98adda5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -146,7 +146,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ - print("SANG-TODO generate: ", prompts, prompt_token_ids) + # print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 74c574e2903f..03941f9f6c9a 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -172,17 +172,17 @@ def forward( # heads. # TODO(woosuk): Use MQA/GQA kernels for higher performance. query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) + self.num_queries_per_kv, + query.shape[-1]) key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + key.shape[-1]) value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) + None, :].expand(value.shape[0], + self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. 
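The re-indented block above is where MQA/GQA queries are regrouped and the KV heads are broadcast to match them before the attention call. A minimal sketch of that expansion with made-up sizes (the tensor names and shapes here are illustrative, not the vLLM API):

    import torch

    num_tokens, num_kv_heads, num_queries_per_kv, head_size = 6, 2, 4, 16
    key = torch.randn(num_tokens, num_kv_heads, head_size)

    # Add a singleton "query group" axis and broadcast it, so each KV head
    # lines up with the num_queries_per_kv query heads it serves.
    expanded = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                         num_queries_per_kv, head_size)

    assert expanded.shape == (num_tokens, num_kv_heads,
                              num_queries_per_kv, head_size)
    # expand() returns a broadcast view, so no data is copied.
    assert expanded.data_ptr() == key.data_ptr()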
@@ -269,10 +269,10 @@ def forward( # ) else: if input_metadata.flash_style: - print("SANG-TODO flash prefill") output = flash_single_query_cached_kv_attention( None, - query.view(batch_size, seq_len, self.num_heads, self.head_size), + query.view(batch_size, seq_len, self.num_heads, + self.head_size), key_cache, value_cache, self.scale, @@ -294,7 +294,7 @@ def forward( # alibi_slopes=self.alibi_slopes, # ) else: - print("SANG-TODO context attention") + # print("SANG-TODO context attention") # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -304,7 +304,8 @@ def forward( output, key_cache, value_cache, - input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata. + block_tables, # [BS, max_block_per_request] input_metadata.start_loc, input_metadata.prompt_lens, input_metadata.context_lens, @@ -455,6 +456,7 @@ def _paged_attention( return output +# OSS version. # def flash_attn_with_kvcache_paged( # query: torch.Tensor, # key_cache: torch.Tensor, @@ -501,7 +503,7 @@ def _paged_attention( def flash_single_query_cached_kv_attention( - output: Optional[torch.Tensor], + output: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -511,9 +513,9 @@ def flash_single_query_cached_kv_attention( alibi_slopes: Optional[torch.Tensor], actual_batch_size: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Similar to vLLM's page attention, caclulates a single token's attention + """Similar to vLLM's page attention, calculates a single token's attention based on key/value caches. The main difference is this uses flash attention - sytle key-value caches. + style key-value caches. Arguments: output: [num_padded_tokens, num_heads, head_size], output tensor @@ -538,6 +540,8 @@ def flash_single_query_cached_kv_attention( # TODO: support alibi_slopes assert alibi_slopes is None, "doesn't support alibi_slopes" batch_size, seqlen, num_heads, head_size = query.shape + assert seqlen == 1, ( + "Single query attention can be only used for decoding phase.") num_tokens = batch_size * seqlen out = flash_attn_with_page_attention( query, @@ -557,11 +561,10 @@ def flash_single_query_cached_kv_attention( num_splits=0, actual_batch_size=actual_batch_size, ) - return out - # if output is not None: - # # in case that output is padded, only copy the valid part. - # output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) - # return out.view(num_tokens, num_heads, head_size) + if output is not None: + # in case that output is padded, only copy the valid part. + output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) + return out.view(num_tokens, num_heads, head_size) def flash_multi_query_cached_kv_attention_varlen( @@ -656,4 +659,4 @@ def flash_multi_query_cached_kv_attention_varlen( if output is not None: output[:num_tokens].copy_(out) - return out \ No newline at end of file + return out From 4e00e7f00f2278045ad8970a70d50c4ac3b4241a Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 4 Mar 2024 05:18:44 -0800 Subject: [PATCH 19/28] done? 
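Besides dropping a stray flash_attn import and some unused tensors, this commit replaces a bare `assert False` in the unsupported-version branch with an explicit raise: `assert` statements are stripped when the interpreter runs with optimizations (`python -O`), so an unreachable-branch guard should not rely on them. A rough sketch of the pattern (the helper name is illustrative, not part of the codebase):

    def pick_backend(version: str) -> str:
        if version == "flash":
            return "flash"
        # `assert False, ...` would disappear under `python -O`; raising keeps
        # the guard active regardless of interpreter flags.
        raise AssertionError(f"{version=} is not supported")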
--- tests/kernels/test_flash_attention.py | 6 +----- vllm/model_executor/layers/attention.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index b5b43be50a64..66a668cb7dd5 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -388,7 +388,6 @@ def test_multi_query_kv_attention( seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device="cuda") if chunked_prefill: # context length will be different from seq_len if chunked_prefill is @@ -397,9 +396,6 @@ def test_multi_query_kv_attention( else: context_lens = seq_lens max_context_len = max(context_lens) - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device="cuda") num_tokens = sum(seq_lens) cu_seq_lens = [0] @@ -466,7 +462,7 @@ def test_multi_query_kv_attention( None, ) else: - assert False, f"{version=} is not supported" + raise AssertionError(f"{version=} is not supported") ref_output = ref_multi_query_kv_attention_padded( query, diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 03941f9f6c9a..be4343fde453 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( context_attention_fwd) from vllm.utils import is_hip -from flash_attn import flash_attn_func # try: # from flash_attn import flash_attn_with_kvcache From c0384a42dc2ed266ff269f946b6482221a5b746a Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 6 Mar 2024 05:37:42 -0800 Subject: [PATCH 20/28] Refactor 2d query to 1d query --- tests/core/test_scheduler.py | 18 +-- tests/models/test_models.py | 32 ++-- tests/worker/test_model_runner.py | 146 ++++++++++++++++- vllm/config.py | 3 - vllm/core/scheduler.py | 13 +- vllm/engine/arg_utils.py | 8 +- vllm/model_executor/input_metadata.py | 19 ++- vllm/model_executor/layers/activation.py | 4 +- vllm/model_executor/layers/attention.py | 152 ++++++++++-------- vllm/model_executor/layers/sampler.py | 1 - vllm/worker/model_runner.py | 189 +++++++++++++++-------- 11 files changed, 403 insertions(+), 182 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index ebfeb8ba0481..397101fa8610 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,7 +10,7 @@ def test_scheduler_add_seq_group(): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1, 256) + scheduler_config = SchedulerConfig(100, 64, 1) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -26,7 +26,7 @@ def test_scheduler_add_seq_group(): def test_scheduler_abort_seq_group(): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1, 256) + scheduler_config = SchedulerConfig(100, 64, 1) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -50,7 +50,7 @@ def test_scheduler_schedule_simple(): block_size = 4 num_seq_group = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len, 256) + scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -64,10 +64,10 @@ def 
test_scheduler_schedule_simple(): running.append(seq_group) # Schedule seq groups prompts. + num_tokens = block_size * num_seq_group seq_group_meta, out = scheduler.schedule() assert set(out.scheduled_seq_groups) == set(running) - assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs( - )[0].get_len() + assert out.num_batched_tokens == num_tokens assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == num_seq_group @@ -84,7 +84,7 @@ def test_scheduler_schedule_simple(): def test_scheduler_schedule_preempt_abort(): block_size = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, 2, max_model_len, 256) + scheduler_config = SchedulerConfig(64, 2, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 @@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort(): # Schedule seq groups prompts. seq_group_meta, out = scheduler.schedule() assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] - assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2 + assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == 2 @@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort(): scheduler.abort_seq_group("1") seq_group_meta, out = scheduler.schedule() assert out.scheduled_seq_groups == [seq_group_b] - assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len() + assert out.num_batched_tokens == 5 # 4 prompt + 1 generation. assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == 1 @@ -136,7 +136,7 @@ def test_scheduler_max_seqs(): num_seq_group = 4 max_seq_group = 2 max_model_len = 16 - scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len, 256) + scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 diff --git a/tests/models/test_models.py b/tests/models/test_models.py index fb567e837d28..c268b6fd4868 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,26 +6,27 @@ MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", - "mistralai/Mistral-7B-v0.1", - "Deci/DeciLM-7b", - "tiiuae/falcon-7b", - "gpt2", - "bigcode/tiny_starcoder_py", - "EleutherAI/gpt-j-6b", - "EleutherAI/pythia-70m", - "bigscience/bloom-560m", - "mosaicml/mpt-7b", - "microsoft/phi-2", - "stabilityai/stablelm-3b-4e1t", - "allenai/OLMo-1B", - "bigcode/starcoder2-3b", + # "meta-llama/Llama-2-7b-hf", + # "mistralai/Mistral-7B-v0.1", + # "Deci/DeciLM-7b", + # "tiiuae/falcon-7b", + # "gpt2", + # "bigcode/tiny_starcoder_py", + # "EleutherAI/gpt-j-6b", + # "EleutherAI/pythia-70m", + # "bigscience/bloom-560m", + # "mosaicml/mpt-7b", + # "microsoft/phi-2", + # "stabilityai/stablelm-3b-4e1t", + # "allenai/OLMo-1B", + # "bigcode/starcoder2-3b", ] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("enforce_eager", [False, True]) def test_models( hf_runner, vllm_runner, @@ -33,12 +34,13 @@ def test_models( model: str, dtype: str, max_tokens: int, + enforce_eager: bool, ) -> None: hf_model = hf_runner(model, dtype=dtype) hf_outputs = 
hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype) + vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index f44895a728c7..7fed21bc2aa9 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -2,7 +2,14 @@ import torch from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT +from vllm.config import ModelConfig + + +# Make sure the result is aligned. +def round_up_to_next_multiple_of_batch_size(n): + batch_size = _BATCH_SIZE_ALIGNMENT + return ((n + 7) // batch_size) * batch_size def test_prepare_prompt(): @@ -28,21 +35,148 @@ def test_prepare_prompt(): expected_selected_token_indices = [] selected_token_start_idx = 0 - max_seq_len = max(prompt_lens) for prompt_len in prompt_lens: expected_selected_token_indices.append(selected_token_start_idx + prompt_len - 1) - selected_token_start_idx += max_seq_len - input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = ( + selected_token_start_idx += prompt_len + input_tokens, input_positions, input_metadata, return_prompt_lens, _, _, _, _ = ( model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens + + # Verify input metadata is correct for prompts. + device = model_runner.device + assert input_metadata.is_prompt is True + assert torch.allclose(input_metadata.prompt_lens, + torch.tensor(prompt_lens, device=device)) + assert input_metadata.num_prompt_tokens == sum(prompt_lens) + assert input_metadata.num_generation_tokens == 0 + assert input_metadata.max_seq_len == max(prompt_lens) + # build start_loc + start_idx = 0 + start_loc = [start_idx] + # start_loc is padded. + for prompt_len in prompt_lens: + start_idx += prompt_len + start_loc.append(start_idx) + assert torch.allclose( + input_metadata.start_loc, + torch.tensor(start_loc, dtype=torch.long, device=device)) + assert input_metadata.max_context_len == max(prompt_lens) + assert torch.allclose( + input_metadata.context_lens, + torch.tensor(prompt_lens, dtype=torch.int, device=device)) + # assert input_metadata.slot_mapping == max(prompt_lens) + # block_tables + # Cuda graph should not be used for prerill. 
+ assert input_metadata.use_cuda_graph is False + assert input_metadata.kv_cache_dtype == "auto" + assert input_metadata.num_valid_tokens == round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)) + + assert input_tokens.shape == (round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)), ) + assert input_positions.shape == (round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)), ) + torch.testing.assert_close(input_tokens, input_positions) + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - assert input_tokens.shape == (batch_size, max_seq_len) - assert input_positions.shape == (batch_size, max_seq_len) + assert input_tokens.shape == (round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)), ) + assert input_positions.shape == (round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)), ) + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) + torch.testing.assert_close(input_tokens, input_positions) + + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) + + +def test_prepare_decode_cuda_graph(): + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=False, + ) + model_runner = ModelRunner(model_config, None, None, None, None) + model_runner.set_block_size(16) + + # Make sure the result is aligned. + def round_up_to_next_multiple_of_batch_size(n): + batch_size = _BATCH_SIZE_ALIGNMENT + return ((n + 7) // batch_size) * batch_size + + batch_size = random.randint(1, 256) + prompt_lens = [] + seq_group_metadata_list = [] + for i in range(batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_lens.append(prompt_len) + seq_data = list(range(prompt_len)) + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: SequenceData(seq_data)}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + )) + + input_tokens, input_positions, input_metadata, _, _, _ = ( + model_runner._prepare_decode(seq_group_metadata_list)) + + # Verify input metadata is correct for prompts. + device = model_runner.device + assert input_metadata.is_prompt is False + assert input_metadata.prompt_lens is None + assert input_metadata.num_prompt_tokens == 0 + assert input_metadata.num_generation_tokens == ( + round_up_to_next_multiple_of_batch_size(len(seq_group_metadata_list))) + assert input_metadata.max_seq_len is None + assert input_metadata.start_loc is None + assert input_metadata.max_context_len == max(prompt_lens) + assert torch.allclose( + input_metadata.context_lens[:len(prompt_lens)], + torch.tensor(prompt_lens, dtype=torch.int, device=device)) + # assert input_metadata.slot_mapping == max(prompt_lens) + # block_tables + # Cuda graph should not be used for prerill. 
+ assert input_metadata.use_cuda_graph is True + assert input_metadata.kv_cache_dtype == "auto" + assert input_metadata.num_valid_tokens == ( + round_up_to_next_multiple_of_batch_size(len(seq_group_metadata_list))) + + assert input_tokens.shape == (round_up_to_next_multiple_of_batch_size( + len(seq_group_metadata_list)), ) + assert input_positions.shape == (round_up_to_next_multiple_of_batch_size( + len(seq_group_metadata_list)), ) torch.testing.assert_close(input_tokens, input_positions) + # Verify Sampling + expected_selected_token_indices = [] + selected_token_start_idx = 0 + for prompt_len in prompt_lens: + expected_selected_token_indices.append(selected_token_start_idx) + selected_token_start_idx += 1 + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, diff --git a/vllm/config.py b/vllm/config.py index ef9a920f29c2..9867ca9a2c1e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -456,7 +456,6 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). - max_paddings: Maximum number of paddings to be added to a batch. """ def __init__( @@ -464,7 +463,6 @@ def __init__( max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, - max_paddings: int, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -474,7 +472,6 @@ def __init__( self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len - self.max_paddings = max_paddings self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c96c6d62ef19..3ec5216a0f18 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -173,12 +173,12 @@ def _schedule(self) -> SchedulerOutputs: curr_loras = set( seq_group.lora_int_id for seq_group in self.running) if self.lora_enabled else None - seq_lens: List[int] = [] # Optimization: We do not sort the waiting queue since the preempted # sequence groups are added to the front and the new sequence groups # are added to the back. leftover_waiting_sequences = deque() + num_batched_tokens = 0 while self.waiting: seq_group = self.waiting[0] waiting_seqs = seq_group.get_seqs( @@ -223,8 +223,7 @@ def _schedule(self) -> SchedulerOutputs: continue # If the number of batched tokens exceeds the limit, stop. 
- new_seq_lens = seq_lens + [num_prompt_tokens] - num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) + num_batched_tokens += num_prompt_tokens if (num_batched_tokens > self.scheduler_config.max_num_batched_tokens): break @@ -236,11 +235,6 @@ def _schedule(self) -> SchedulerOutputs: self.scheduler_config.max_num_seqs): break - num_paddings = num_batched_tokens - sum(new_seq_lens) - if num_paddings > self.scheduler_config.max_paddings: - break - seq_lens = new_seq_lens - if lora_int_id > 0: curr_loras.add(lora_int_id) self.waiting.popleft() @@ -255,8 +249,7 @@ def _schedule(self) -> SchedulerOutputs: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, prompt_run=True, - num_batched_tokens=len(seq_lens) * - max(seq_lens) if seq_lens else 0, + num_batched_tokens=num_batched_tokens, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c3dccdd5bb50..fb4462c0adff 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -30,7 +30,6 @@ class EngineArgs: gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 - max_paddings: int = 256 max_logprobs: int = 5 # OpenAI default value disable_log_stats: bool = False revision: Optional[str] = None @@ -209,10 +208,6 @@ def add_cli_args( type=int, default=EngineArgs.max_num_seqs, help='maximum number of sequences per iteration') - parser.add_argument('--max-paddings', - type=int, - default=EngineArgs.max_paddings, - help='maximum number of paddings in a batch') parser.add_argument( '--max-logprobs', type=int, @@ -322,8 +317,7 @@ def create_engine_configs( self.ray_workers_use_nsight) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, - model_config.max_model_len, - self.max_paddings) + model_config.max_model_len) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index f0a88ac8e27f..e7d45a6cbf3f 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -8,9 +8,16 @@ class InputMetadata: Args: prompt_lens: Lengths of prompts. - slot_mapping: The address to write the new KV to of each token. + slot_mapping: The index of each token mapped into a physical block + in block tables. E.g., if block_size is 32, 35 means it is in + the block number 1, 3rd index. + num_prompt_tokens: The number of tokens in the prompts. This might + include padding. + num_generation_tokens: The number of tokens in the generation sequences. + This might include padding. max_context_len: The maximum context length. context_lens: the length of attention context for each sequence. + I.e., the number of tokens that have attended so far. block_tables: The block tables. (Seq id -> list of physical block) kv_cache_dtype: Data type to store kv cache. 
""" @@ -20,6 +27,8 @@ def __init__( is_prompt: bool, slot_mapping: torch.Tensor, prompt_lens: Optional[torch.Tensor], + num_prompt_tokens: int, + num_generation_tokens: int, max_seq_len: Optional[int], start_loc: Optional[torch.Tensor], max_context_len: Optional[int], @@ -30,6 +39,8 @@ def __init__( ) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens + self.num_prompt_tokens = num_prompt_tokens + self.num_generation_tokens = num_generation_tokens self.max_seq_len = max_seq_len self.start_loc = start_loc self.max_context_len = max_context_len @@ -42,13 +53,17 @@ def __init__( # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. self.attn_bias = None + self.num_valid_tokens = slot_mapping.shape[0] def __repr__(self) -> str: return ("InputMetadata(" f"is_prompt={self.is_prompt}, " f"max_context_len={self.max_context_len}, " + f"num_generation_tokens={self.num_generation_tokens}, " + f"num_prompt_tokens={self.num_prompt_tokens}, " f"slot_mapping={self.slot_mapping}, " f"context_lens={self.context_lens}, " f"block_tables={self.block_tables}, " f"use_cuda_graph={self.use_cuda_graph}, " - f"kv_cache_dtype={self.kv_cache_dtype})") + f"kv_cache_dtype={self.kv_cache_dtype}) " + f"num_valid_tokens={self.num_valid_tokens}") diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 5a3a7b2dbaee..6d829395883a 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -20,8 +20,8 @@ class SiluAndMul(nn.Module): The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. Shapes: - x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) - return: (batch_size, seq_len, d) or (num_tokens, d) + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) """ def _forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 2a82325b8021..bfcdecc8876b 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -23,14 +23,28 @@ class PagedAttention(nn.Module): """MHA/MQA/GQA layer with PagedAttention. - This class takes query, key, and value tensors as input. The input tensors + This class takes flattened 1D query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + + If the input tensors contain prompt tokens, the layout is as follows: + |<---------------------- num_valid_tokens ---------------------->| + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->| + Otherwise, the layout is as follows: + |<------------------ num_valid_tokens ------------------->| + |<------- num_generation_tokens (M) ------->| + |<--generation_0-->|...|<--generation_M-1-->|<--padding-->| + The prompts might have different lengths, while the generation tokens always + have length 1. The paddings are appended to make the input length a multiple + of 8, which is desirable for Tensor Cores. + The class does the following: 1. Reshape and store the input key and value tensors in the KV cache. 2. Perform (multi-head/multi-query/grouped-query) attention using either xformers or the PagedAttention custom op. - 3. Return the output tensor. + 3. Output a flattened 1D tensor. """ def __init__( @@ -105,38 +119,50 @@ def forward( """PagedAttention forward pass. 
Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] + query: shape = [num_tokens num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] input_metadata: metadata for the inputs. Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ - batch_size, seq_len, hidden_size = query.shape + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) + output = torch.empty_like(query) # Reshape the keys and values and store them in the cache. # If key_cache and value_cache are not provided, the new key and value # vectors will not be cached. This happens during the initial memory # profiling run. - if key_cache is not None and value_cache is not None: + num_valid_tokens = input_metadata.num_valid_tokens + if (num_valid_tokens > 0 and key_cache is not None + and value_cache is not None): + key_to_cache = key[:num_valid_tokens] + value_to_cache = value[:num_valid_tokens] cache_ops.reshape_and_cache( - key, - value, + key_to_cache, + value_to_cache, key_cache, value_cache, input_metadata.slot_mapping.flatten(), input_metadata.kv_cache_dtype, ) - if input_metadata.is_prompt: + num_prompt_tokens = input_metadata.num_prompt_tokens + num_generation_tokens = input_metadata.num_generation_tokens + + if num_prompt_tokens > 0: + assert num_generation_tokens == 0 + query = query[:num_prompt_tokens] + key = key[:num_prompt_tokens] + value = value[:num_prompt_tokens] # normal attention if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): @@ -164,25 +190,25 @@ def forward( if input_metadata.attn_bias is None: if self.alibi_slopes is None: attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size) + input_metadata.prompt_lens.tolist()) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) input_metadata.attn_bias = attn_bias else: input_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, batch_size, - seq_len, query.dtype) + self.alibi_slopes, self.num_kv_heads, query.dtype, + input_metadata) if self.use_ref_attention: - output = self.ref_masked_attention( + output[:num_prompt_tokens] = self.ref_masked_attention( query, key, value, ) # Using view got RuntimeError: view size is not compatible with input tensor's size and stride # (at least one dimension spans across two contiguous subspaces). Use reshape instead - return output.reshape(batch_size, seq_len, hidden_size) + return output.reshape(num_tokens, hidden_size) # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. 
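With queries flattened to a single token dimension, the causal structure of the prompt run comes from a block-diagonal mask built from the real prompt lengths instead of one [seq_len] mask per batch entry. A dense illustration of what that mask means for two packed prompts (pure PyTorch, illustration only; the fused kernels never materialize this matrix):

    import torch

    prompt_lens = [2, 3]                  # two prompts packed back-to-back
    total = sum(prompt_lens)

    # -inf everywhere: tokens from different prompts never see each other.
    bias = torch.full((total, total), float("-inf"))

    start = 0
    for n in prompt_lens:
        # Within a prompt, standard causal masking: position i attends to 0..i.
        block = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
        bias[start:start + n, start:start + n] = block
        start += n

    # Row 2 is the first token of the second prompt: it only sees itself.
    assert bias[2, 2] == 0 and bias[2, 1] == float("-inf")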
@@ -191,21 +217,22 @@ def forward( key = key.unsqueeze(0) value = value.unsqueeze(0) else: - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) + query = query.unflatten(0, (num_tokens)) + key = key.unflatten(0, (num_tokens)) + value = value.unflatten(0, (num_tokens)) - out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ) - output = out.view_as(query) + output[: + num_prompt_tokens] = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha. + MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ).view_as(query) else: # prefix-enabled attention output = torch.empty_like(query) @@ -224,10 +251,12 @@ def forward( getattr(self, "alibi_slopes", None), ) - else: + if num_generation_tokens > 0: + assert num_prompt_tokens == 0 # Decoding run. output = _paged_attention( - query, + output[num_prompt_tokens:num_valid_tokens], + query[num_prompt_tokens:num_valid_tokens], key_cache, value_cache, input_metadata, @@ -237,45 +266,44 @@ def forward( ) # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) + return output.view(-1, self.num_heads * self.head_size) def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, - batch_size: int, - seq_len: int, dtype: torch.dtype, + input_metadata: InputMetadata, ) -> LowerTriangularMaskWithTensorBias: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - - # When using custom attention bias, xformers requires the bias to - # be sliced from a tensor whose length is a multiple of 8. - padded_len = (seq_len + 7) // 8 * 8 - num_heads = alibi_slopes.shape[0] - bias = torch.empty( - batch_size, - num_heads, - seq_len, - padded_len, - device=alibi_slopes.device, - dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias.mul_(alibi_slopes[:, None, None]) - if num_heads != num_kv_heads: - bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) - attn_bias = LowerTriangularMaskWithTensorBias(bias) - return attn_bias + for prompt_len in input_metadata.prompt_lens: + bias = torch.arange(prompt_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. 
+ bias = bias[None, :] - bias[:, None] + + padded_len = (prompt_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + prompt_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :prompt_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + attn_bias = LowerTriangularMaskWithTensorBias(bias) + return attn_bias def _paged_attention( - query: torch.Tensor, + output: torch.Tensor, # [num_tokens, num_heads, head_size] + query: torch.Tensor, # [num_tokens, num_heads, head_size] key_cache: torch.Tensor, value_cache: torch.Tensor, input_metadata: InputMetadata, @@ -283,8 +311,6 @@ def _paged_attention( scale: float, alibi_slopes: Optional[torch.Tensor], ) -> torch.Tensor: - output = torch.empty_like(query) - block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = ( diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 320cb443524c..bb5aa281568f 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -125,7 +125,6 @@ def _prune_hidden_states( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) return hidden_states.index_select(0, sampling_metadata.selected_token_indices) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index aff8ebc90362..6c207a1abbe1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -28,9 +28,14 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] _PAD_SLOT_ID = -1 LORA_WARMUP_RANK = 8 -# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. +_BATCH_SIZE_ALIGNMENT = 8 +# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. -_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] +# Note that cuda graph is only used for decoding because it speeds up +# the performance when num_tokens < 200. Batch here means a single token. +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) +] class ModelRunner: @@ -124,9 +129,9 @@ def _prepare_prompt( ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() @@ -155,16 +160,18 @@ def _prepare_prompt( computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) + context_len = computed_len else: prefix_block_tables.append([]) + context_len = prompt_len # actual prompt lens - context_lens.append(computed_len) + context_lens.append(context_len) subquery_lens.append(prompt_len - computed_len) - input_tokens.append(prompt_tokens) + input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. 
- input_positions.append( + input_positions.extend( list(range(computed_len, computed_len + len(prompt_tokens)))) lora_id = seq_group_metadata.lora_int_id @@ -181,11 +188,10 @@ def _prepare_prompt( if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. - slot_mapping.append([_PAD_SLOT_ID] * prompt_len) + slot_mapping.extend([_PAD_SLOT_ID] * prompt_len) continue # Compute the slot mapping. - slot_mapping.append([]) block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, prompt_len - sliding_window). @@ -200,30 +206,26 @@ def _prepare_prompt( start_idx = max(0, prompt_len - self.sliding_window) for i in range(computed_len, prompt_len): if i < start_idx: - slot_mapping[-1].append(_PAD_SLOT_ID) + slot_mapping.append(_PAD_SLOT_ID) continue block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping[-1].append(slot) + slot_mapping.append(slot) max_prompt_len = max(subquery_lens) - input_tokens = _make_tensor_with_pad(input_tokens, - max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - input_positions = _make_tensor_with_pad(input_positions, - max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) + num_prompt_tokens = len(input_tokens) + input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, + pad=0, + dtype=torch.long, + device=self.device) + input_positions = _make_tensor_with_pad_for_alignment( + input_positions, pad=0, dtype=torch.long, device=self.device) + slot_mapping = _make_tensor_with_pad_for_alignment(slot_mapping, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) lora_index_mapping = [ _pad_to_max(mapping, max_prompt_len, pad=0) for mapping in lora_index_mapping @@ -240,22 +242,30 @@ def _prepare_prompt( dtype=torch.int, device=self.device, ) - start_loc_tensor = torch.arange(0, - len(prompt_lens) * max_prompt_len, - max_prompt_len, - dtype=torch.long, - device=self.device) prompt_lens_tensor = torch.tensor(prompt_lens, dtype=torch.long, device=self.device) + # Cumulative index of each prompt. [prompt_lens + 1] + # [0, 0+1th, 0+1th+2nd, ...] 
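The start_loc tensor assembled in the lines just below is an exclusive prefix sum over the prompt lengths, so start_loc[i]:start_loc[i + 1] slices prompt i out of the flattened token tensors. A standalone sketch with made-up lengths:

    import torch

    prompt_lens = torch.tensor([4, 2, 3], dtype=torch.long)

    # One extra leading slot so that start_loc[i + 1] - start_loc[i]
    # recovers the length of prompt i.
    start_loc = torch.zeros(prompt_lens.shape[0] + 1, dtype=torch.long)
    torch.cumsum(prompt_lens, dim=0, out=start_loc[1:])

    assert start_loc.tolist() == [0, 4, 6, 9]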
+ start_loc_tensor = torch.zeros(prompt_lens_tensor.shape[0] + 1, + dtype=torch.long, + device=self.device) + + torch.cumsum(prompt_lens_tensor, + dim=0, + dtype=start_loc_tensor.dtype, + out=start_loc_tensor[1:]) + input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping, prompt_lens=prompt_lens_tensor, + num_prompt_tokens=num_prompt_tokens, + num_generation_tokens=0, max_seq_len=max_prompt_len, start_loc=start_loc_tensor, - max_context_len=None, + max_context_len=max(context_lens), context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -271,9 +281,9 @@ def _prepare_decode( ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] @@ -292,11 +302,11 @@ def _prepare_decode( for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) + input_tokens.append(generation_token) seq_len = seq_data.get_len() position = seq_len - 1 - input_positions.append([position]) + input_positions.append(position) context_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) @@ -306,7 +316,7 @@ def _prepare_decode( block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping.append([slot]) + slot_mapping.append(slot) lora_index_mapping.append([lora_id]) lora_prompt_mapping.append(lora_id) @@ -316,6 +326,9 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + # vLLM uses cuda graph only for decoding requests. + # See `capture_model` API for more details. + # For decoding requests, batch_size == input_tokens. batch_size = len(input_tokens) max_context_len = max(context_lens) use_captured_graph = ( @@ -328,32 +341,32 @@ def _prepare_decode( graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size for _ in range(graph_batch_size - batch_size): - input_tokens.append([]) - input_positions.append([]) - slot_mapping.append([]) + input_tokens.append(0) + input_positions.append(0) + slot_mapping.append(_PAD_SLOT_ID) context_lens.append(1) block_tables.append([]) batch_size = graph_batch_size - input_tokens = _make_tensor_with_pad(input_tokens, - max_len=1, - pad=0, - dtype=torch.long, - device=self.device) - input_positions = _make_tensor_with_pad(input_positions, - max_len=1, - pad=0, - dtype=torch.long, - device=self.device) - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_len=1, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) + input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, + pad=0, + dtype=torch.long, + device=self.device) + input_positions = _make_tensor_with_pad_for_alignment( + input_positions, pad=0, dtype=torch.long, device=self.device) + slot_mapping = _make_tensor_with_pad_for_alignment(slot_mapping, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) + # The first dimension corresponds to number of tokens. 
+ assert context_lens.shape[0] == input_tokens.shape[0] + assert context_lens.shape[0] == input_positions.shape[0] + assert context_lens.shape[0] == slot_mapping.shape[0] + if use_captured_graph: # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -381,6 +394,8 @@ def _prepare_decode( is_prompt=False, slot_mapping=slot_mapping, prompt_lens=None, + num_prompt_tokens=0, + num_generation_tokens=len(input_tokens), max_seq_len=None, start_loc=None, max_context_len=max_context_len, @@ -406,7 +421,6 @@ def _prepare_sample( categorized_sample_indices_start_idx = 0 pin_memory = not self.in_wsl and not self.device_config.is_neuron - max_subquery_len = max(subquery_lens) if subquery_lens else 1 for i, seq_group_metadata in enumerate(seq_group_metadata_list): seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params @@ -431,7 +445,7 @@ def _prepare_sample( selected_token_start_idx + subquery_len - 1)) selected_token_indices.append(selected_token_start_idx + subquery_len - 1) - selected_token_start_idx += max_subquery_len + selected_token_start_idx += subquery_len if sampling_params.seed is not None: seq_group_metadata.state.generator = torch.Generator( @@ -520,6 +534,8 @@ def prepare_input_tensors( "is_prompt": input_metadata.is_prompt, "slot_mapping": input_metadata.slot_mapping, "prompt_lens": input_metadata.prompt_lens, + "num_prompt_tokens": input_metadata.num_prompt_tokens, + "num_generation_tokens": input_metadata.num_generation_tokens, "max_seq_len": input_metadata.max_seq_len, "start_loc": input_metadata.start_loc, "max_context_len": input_metadata.max_context_len, @@ -543,6 +559,8 @@ def prepare_input_tensors( is_prompt=metadata_dict["is_prompt"], slot_mapping=metadata_dict["slot_mapping"], prompt_lens=metadata_dict["prompt_lens"], + num_prompt_tokens=metadata_dict["num_prompt_tokens"], + num_generation_tokens=metadata_dict["num_generation_tokens"], max_seq_len=metadata_dict["max_seq_len"], start_loc=metadata_dict["start_loc"], max_context_len=metadata_dict["max_context_len"], @@ -579,6 +597,9 @@ def execute_model( # Execute the model. if input_metadata.use_cuda_graph: + # NOTE: We use cuda graph only when there are only + # decoding requests, which means the number of batch + # size is equivalent to number of input tokens. graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: @@ -680,6 +701,18 @@ def list_loras(self) -> Set[int]: @torch.inference_mode() def capture_model(self, kv_caches: List[KVCache]) -> None: + """Cuda graph capture a model. + + Note that cuda graph performance gain is negligible if number + of batched tokens are less than 200. And since Cuda graph + requires fixed sized tensors, supporting large/variable batch + size requires high GPU memory overhead. Thus, vLLM only captures + decoding requests. Mixed batch (chunked prefill + decoding) or + prefill requests are not captured. + + this API assumes the shape is N batches of tokens flattened to 1D + tensor, where is token's seqlen is 1. + """ # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never # deleted before the CUDA graphs. self.cupy_nccl_backend = cupy_utils.get_nccl_backend() @@ -698,10 +731,9 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # Prepare dummy inputs. These will be reused for all batch sizes. 
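Since CUDA graphs replay fixed shapes, decode batches are padded up to the nearest captured size: 1, 2, 4, and then multiples of _BATCH_SIZE_ALIGNMENT. A quick sketch of the rounding rule, assuming the alignment of 8 this patch defines:

    _BATCH_SIZE_ALIGNMENT = 8  # as in model_runner.py in this patch

    def graph_batch_size(batch_size: int) -> int:
        # Mirrors _get_graph_batch_size: batches of 1-2 are kept as-is,
        # 3-4 round to 4, larger ones round up to a multiple of the alignment.
        if batch_size <= 2:
            return batch_size
        if batch_size <= 4:
            return 4
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1)
                // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

    assert [graph_batch_size(n) for n in (1, 3, 5, 9, 17)] == [1, 4, 8, 16, 24]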
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() - input_positions = torch.zeros(max_batch_size, 1, - dtype=torch.long).cuda() - slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() + input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() + input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() + slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() @@ -727,6 +759,8 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: is_prompt=False, slot_mapping=slot_mapping[:batch_size], prompt_lens=None, + num_prompt_tokens=0, + num_generation_tokens=0, max_seq_len=None, start_loc=None, max_context_len=self.max_context_len_to_capture, @@ -862,11 +896,32 @@ def _maybe_cupy_nccl(): yield +def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]: + return x + [pad] * ((-len(x)) % multiple_of) + + def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: assert len(x) <= max_len return x + [pad] * (max_len - len(x)) +def _make_tensor_with_pad_for_alignment( + x: List[int], + pad: int, + dtype: torch.dtype, + device: Optional[Union[str, torch.device]], +) -> torch.Tensor: + """Create a tensor of a given list x with padding. + It adds paddings to align with graph batch size. See + _get_graph_batch_size for more details. + # NOTE: This API is only for decoding requests. + """ + batch_size = len(x) + batch_size = _get_graph_batch_size(batch_size) + padded_x = _pad_to_alignment(x, batch_size, pad) + return torch.tensor(padded_x, dtype=dtype, device=device) + + def _make_tensor_with_pad( x: List[List[int]], max_len: int, @@ -879,12 +934,18 @@ def _make_tensor_with_pad( def _get_graph_batch_size(batch_size: int) -> int: + """Returns the padded batch size given actual batch size. + + Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, + 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... + """ if batch_size <= 2: return batch_size elif batch_size <= 4: return 4 else: - return (batch_size + 7) // 8 * 8 + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) def _async_h2d( From 6032edf76fd7c31c3fd5617c13bd7c0b70d227fc Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 6 Mar 2024 08:01:21 -0800 Subject: [PATCH 21/28] ., --- tests/models/test_models.py | 2 +- tests/prompts/example.txt | 9 +-------- vllm/engine/llm_engine.py | 2 +- vllm/model_executor/layers/attention.py | 2 ++ vllm/worker/model_runner.py | 12 +++++++----- vllm/worker/worker.py | 4 ++-- 6 files changed, 14 insertions(+), 17 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index c268b6fd4868..f3f416e63838 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -26,7 +26,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("enforce_eager", [False]) def test_models( hf_runner, vllm_runner, diff --git a/tests/prompts/example.txt b/tests/prompts/example.txt index e1b97bc6eee7..6e8c45b673ee 100644 --- a/tests/prompts/example.txt +++ b/tests/prompts/example.txt @@ -1,8 +1 @@ -vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. 
-Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020. -Compare and contrast artificial intelligence with human intelligence in terms of processing information. -Describe the basic components of a neural network and how it can be trained. -Write a short story about a robot that dreams for the first time. -Analyze the impact of the COVID-19 pandemic on global economic structures and future business models. -Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies. -Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' +vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. \ No newline at end of file diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8484014c9a13..4c25602c6e96 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -798,7 +798,7 @@ def _process_model_outputs( # Log stats. if self.log_stats: self.stat_logger.log(self._get_stats(scheduler_outputs)) - + # breakpoint() return request_outputs def step(self) -> List[RequestOutput]: diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index bfcdecc8876b..db67577fe881 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -157,6 +157,7 @@ def forward( num_prompt_tokens = input_metadata.num_prompt_tokens num_generation_tokens = input_metadata.num_generation_tokens + print(num_generation_tokens) if num_prompt_tokens > 0: assert num_generation_tokens == 0 @@ -252,6 +253,7 @@ def forward( ) if num_generation_tokens > 0: + breakpoint() assert num_prompt_tokens == 0 # Decoding run. output = _paged_attention( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6c207a1abbe1..fa27a62ceee8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -265,7 +265,7 @@ def _prepare_prompt( num_generation_tokens=0, max_seq_len=max_prompt_len, start_loc=start_loc_tensor, - max_context_len=max(context_lens), + max_context_len=None, context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -344,10 +344,11 @@ def _prepare_decode( input_tokens.append(0) input_positions.append(0) slot_mapping.append(_PAD_SLOT_ID) - context_lens.append(1) + context_lens.append(0) block_tables.append([]) batch_size = graph_batch_size + # Q: should we not pad when cuda graph is disabled? input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, pad=0, dtype=torch.long, @@ -610,6 +611,7 @@ def execute_model( kv_caches=kv_caches, input_metadata=input_metadata, ) + breakpoint() # Sample the next token. output = self.model.sample( @@ -717,7 +719,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # deleted before the CUDA graphs. self.cupy_nccl_backend = cupy_utils.get_nccl_backend() - assert not self.model_config.enforce_eager + # assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " "unexpected consequences if the model is not static. 
To " "run the model in eager mode, set 'enforce_eager=True' or " @@ -735,7 +737,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) - context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + context_lens = torch.zeros(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() graph_batch_size = _get_graph_batch_size( @@ -876,7 +878,7 @@ def forward( non_blocking=True) self.input_buffers["block_tables"].copy_(input_metadata.block_tables, non_blocking=True) - + breakpoint() # Run the graph. self.graph.replay() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 157e8c45836b..90ab9f421b7f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -153,8 +153,8 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: self.model_runner.set_block_size(self.cache_engine.block_size) def warm_up_model(self) -> None: - if not self.model_config.enforce_eager: - self.model_runner.capture_model(self.gpu_cache) + # if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.gpu_cache) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. set_random_seed(self.model_config.seed) From c1ab0b0bedf0e25f3d35c998f5216eb33b4275d1 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 6 Mar 2024 08:06:53 -0800 Subject: [PATCH 22/28] done --- tests/models/test_models.py | 2 +- vllm/engine/llm_engine.py | 1 - vllm/model_executor/layers/attention.py | 53 +++++++++---------------- vllm/worker/model_runner.py | 2 - 4 files changed, 20 insertions(+), 38 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index f3f416e63838..c268b6fd4868 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -26,7 +26,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("enforce_eager", [False, True]) def test_models( hf_runner, vllm_runner, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4c25602c6e96..ea10f3e18897 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -798,7 +798,6 @@ def _process_model_outputs( # Log stats. if self.log_stats: self.stat_logger.log(self._get_stats(scheduler_outputs)) - # breakpoint() return request_outputs def step(self) -> List[RequestOutput]: diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index db67577fe881..8c4619a11f1d 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -141,29 +141,17 @@ def forward( # If key_cache and value_cache are not provided, the new key and value # vectors will not be cached. This happens during the initial memory # profiling run. 
- num_valid_tokens = input_metadata.num_valid_tokens - if (num_valid_tokens > 0 and key_cache is not None - and value_cache is not None): - key_to_cache = key[:num_valid_tokens] - value_to_cache = value[:num_valid_tokens] + if (key_cache is not None and value_cache is not None): cache_ops.reshape_and_cache( - key_to_cache, - value_to_cache, + key, + value, key_cache, value_cache, input_metadata.slot_mapping.flatten(), input_metadata.kv_cache_dtype, ) - num_prompt_tokens = input_metadata.num_prompt_tokens - num_generation_tokens = input_metadata.num_generation_tokens - print(num_generation_tokens) - - if num_prompt_tokens > 0: - assert num_generation_tokens == 0 - query = query[:num_prompt_tokens] - key = key[:num_prompt_tokens] - value = value[:num_prompt_tokens] + if input_metadata.is_prompt: # normal attention if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): @@ -202,7 +190,7 @@ def forward( input_metadata) if self.use_ref_attention: - output[:num_prompt_tokens] = self.ref_masked_attention( + output = self.ref_masked_attention( query, key, value, @@ -222,18 +210,17 @@ def forward( key = key.unflatten(0, (num_tokens)) value = value.unflatten(0, (num_tokens)) - output[: - num_prompt_tokens] = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha. - MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ).view_as(query) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + output = out.view_as(query) else: # prefix-enabled attention output = torch.empty_like(query) @@ -252,13 +239,11 @@ def forward( getattr(self, "alibi_slopes", None), ) - if num_generation_tokens > 0: - breakpoint() - assert num_prompt_tokens == 0 + else: # Decoding run. output = _paged_attention( - output[num_prompt_tokens:num_valid_tokens], - query[num_prompt_tokens:num_valid_tokens], + output, + query, key_cache, value_cache, input_metadata, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fa27a62ceee8..2ad4ee534dd5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -611,7 +611,6 @@ def execute_model( kv_caches=kv_caches, input_metadata=input_metadata, ) - breakpoint() # Sample the next token. output = self.model.sample( @@ -878,7 +877,6 @@ def forward( non_blocking=True) self.input_buffers["block_tables"].copy_(input_metadata.block_tables, non_blocking=True) - breakpoint() # Run the graph. self.graph.replay() From f48dc72f5058f661cee56a1f34ccae314f9606d1 Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 7 Mar 2024 01:01:39 -0800 Subject: [PATCH 23/28] Addressed code review. 
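
The alignment padding that this series applies to decode batches can be
summarized with a small, self-contained sketch. This is illustration only:
`graph_batch_size` and `pad_to_alignment` below are stand-ins re-derived from
the `_get_graph_batch_size` and `_pad_to_alignment` helpers in
vllm/worker/model_runner.py, not new APIs.

    # Illustration only; mirrors the private helpers in vllm/worker/model_runner.py.
    _BATCH_SIZE_ALIGNMENT = 8  # graphs are captured for 1, 2, 4, then multiples of 8

    def graph_batch_size(batch_size: int) -> int:
        # Round the actual decode batch size up to the nearest captured size.
        if batch_size <= 2:
            return batch_size
        if batch_size <= 4:
            return 4
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1)
                // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

    def pad_to_alignment(x: list, multiple_of: int, pad: int) -> list:
        # Append `pad` until len(x) is a multiple of `multiple_of`.
        return x + [pad] * ((-len(x)) % multiple_of)

    # A decode batch of 5 tokens is padded up to the captured size 8.
    tokens = [11, 22, 33, 44, 55]
    padded = pad_to_alignment(tokens, graph_batch_size(len(tokens)), pad=0)
    assert len(padded) == 8 and padded[5:] == [0, 0, 0]

With this, any decode batch size maps onto one of the captured graph batch
sizes, which is why the shape asserts in _prepare_decode only need to hold
when use_captured_graph is set.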
--- tests/kernels/test_prefix_prefill.py | 1 - tests/lora/test_worker.py | 2 +- tests/models/test_models.py | 28 ++++++++++++++-------------- tests/prompts/example.txt | 9 ++++++++- tests/worker/test_model_runner.py | 13 +++++++++++-- vllm/worker/model_runner.py | 15 ++++++++------- vllm/worker/worker.py | 4 ++-- 7 files changed, 44 insertions(+), 28 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index c068b38a6691..76dd9bdf3a19 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -113,7 +113,6 @@ def test_contexted_kv_attention( # to V_cache[num_blocks, num_kv_heads, head_size, block_size] v_cache = v_cache.view(-1, block_size, num_kv_heads, head_size).permute(0, 2, 3, 1).contiguous() - # Warm up the Triton kernel by calling it once before actually measuring generation time context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, b_start_loc, b_seq_len, b_ctx_len, max_input_len) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 31a7c716afbf..e4538de35169 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -25,7 +25,7 @@ def test_worker_apply_lora(sql_lora_files): revision=None, ), parallel_config=ParallelConfig(1, 1, False), - scheduler_config=SchedulerConfig(32, 32, 32, 256), + scheduler_config=SchedulerConfig(32, 32, 32), device_config=DeviceConfig("cuda"), local_rank=0, rank=0, diff --git a/tests/models/test_models.py b/tests/models/test_models.py index c268b6fd4868..7bb93a519bf0 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,20 +6,20 @@ MODELS = [ "facebook/opt-125m", - # "meta-llama/Llama-2-7b-hf", - # "mistralai/Mistral-7B-v0.1", - # "Deci/DeciLM-7b", - # "tiiuae/falcon-7b", - # "gpt2", - # "bigcode/tiny_starcoder_py", - # "EleutherAI/gpt-j-6b", - # "EleutherAI/pythia-70m", - # "bigscience/bloom-560m", - # "mosaicml/mpt-7b", - # "microsoft/phi-2", - # "stabilityai/stablelm-3b-4e1t", - # "allenai/OLMo-1B", - # "bigcode/starcoder2-3b", + "meta-llama/Llama-2-7b-hf", + "mistralai/Mistral-7B-v0.1", + "Deci/DeciLM-7b", + "tiiuae/falcon-7b", + "gpt2", + "bigcode/tiny_starcoder_py", + "EleutherAI/gpt-j-6b", + "EleutherAI/pythia-70m", + "bigscience/bloom-560m", + "mosaicml/mpt-7b", + "microsoft/phi-2", + "stabilityai/stablelm-3b-4e1t", + "allenai/OLMo-1B", + "bigcode/starcoder2-3b", ] diff --git a/tests/prompts/example.txt b/tests/prompts/example.txt index 6e8c45b673ee..cef4d1d76873 100644 --- a/tests/prompts/example.txt +++ b/tests/prompts/example.txt @@ -1 +1,8 @@ -vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. \ No newline at end of file +vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. +Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020. +Compare and contrast artificial intelligence with human intelligence in terms of processing information. +Describe the basic components of a neural network and how it can be trained. +Write a short story about a robot that dreams for the first time. +Analyze the impact of the COVID-19 pandemic on global economic structures and future business models. +Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies. +Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' 
\ No newline at end of file diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 7fed21bc2aa9..6ae0c5eddd04 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -61,10 +61,17 @@ def test_prepare_prompt(): assert torch.allclose( input_metadata.start_loc, torch.tensor(start_loc, dtype=torch.long, device=device)) - assert input_metadata.max_context_len == max(prompt_lens) + assert input_metadata.max_context_len is None + # TODO(sang): The current definition of context_lens is the + # number of k/v that are already cached (before this run). + # It is inconsistent with decoding. assert torch.allclose( input_metadata.context_lens, - torch.tensor(prompt_lens, dtype=torch.int, device=device)) + torch.zeros(input_metadata.context_lens.shape[0], + dtype=torch.int, + device=device)) + + # SANG-TODO # assert input_metadata.slot_mapping == max(prompt_lens) # block_tables # Cuda graph should not be used for prerill. @@ -154,6 +161,8 @@ def round_up_to_next_multiple_of_batch_size(n): assert torch.allclose( input_metadata.context_lens[:len(prompt_lens)], torch.tensor(prompt_lens, dtype=torch.int, device=device)) + + # SANG-TODO # assert input_metadata.slot_mapping == max(prompt_lens) # block_tables # Cuda graph should not be used for prerill. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2ad4ee534dd5..04713914497f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -109,8 +109,7 @@ def load_model(self) -> None: ), "Model does not have embedding_padding_modules" self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens + - self.scheduler_config.max_paddings, vocab_size, + self.scheduler_config.max_num_batched_tokens, vocab_size, self.lora_config, self.device, self.model.embedding_modules, self.model.embedding_padding_modules) self.model = self.lora_manager.create_lora_manager(self.model) @@ -163,7 +162,7 @@ def _prepare_prompt( context_len = computed_len else: prefix_block_tables.append([]) - context_len = prompt_len + context_len = 0 # actual prompt lens context_lens.append(context_len) subquery_lens.append(prompt_len - computed_len) @@ -363,10 +362,12 @@ def _prepare_decode( dtype=torch.int, device=self.device) - # The first dimension corresponds to number of tokens. - assert context_lens.shape[0] == input_tokens.shape[0] - assert context_lens.shape[0] == input_positions.shape[0] - assert context_lens.shape[0] == slot_mapping.shape[0] + if use_captured_graph: + # When using cuda-graph all these tensors should be + # padded. + assert context_lens.shape[0] == input_tokens.shape[0] + assert context_lens.shape[0] == input_positions.shape[0] + assert context_lens.shape[0] == slot_mapping.shape[0] if use_captured_graph: # The shape of graph_block_tables is diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 90ab9f421b7f..157e8c45836b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -153,8 +153,8 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: self.model_runner.set_block_size(self.cache_engine.block_size) def warm_up_model(self) -> None: - # if not self.model_config.enforce_eager: - self.model_runner.capture_model(self.gpu_cache) + if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.gpu_cache) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
set_random_seed(self.model_config.seed) From 769b2b491939f9e461486fecd2ac97e80e15eb0c Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 7 Mar 2024 01:14:49 -0800 Subject: [PATCH 24/28] working --- vllm/model_executor/input_metadata.py | 9 +++++++++ vllm/model_executor/layers/attention.py | 9 +++++---- vllm/worker/model_runner.py | 14 ++++++-------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index e7d45a6cbf3f..be4f71445276 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -45,6 +45,13 @@ def __init__( self.start_loc = start_loc self.max_context_len = max_context_len self.slot_mapping = slot_mapping + # Index: The batched sequence's index. + # Value: The length of attention context. + # NOTE(sang): When it is prefill/decoding, + # the definition is different. For prefill, + # it means the the length of KV that are cached + # excluding the new KVs. In decoding, this + # includes a new KV. self.context_lens = context_lens self.block_tables = block_tables self.use_cuda_graph = use_cuda_graph @@ -53,6 +60,8 @@ def __init__( # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. self.attn_bias = None + # Number of valid tokens. It includes paddings. + # See attention.py for precise definition. self.num_valid_tokens = slot_mapping.shape[0] def __repr__(self) -> str: diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 8c4619a11f1d..a6fea54e9815 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -242,7 +242,6 @@ def forward( else: # Decoding run. output = _paged_attention( - output, query, key_cache, value_cache, @@ -289,15 +288,17 @@ def _make_alibi_bias( def _paged_attention( - output: torch.Tensor, # [num_tokens, num_heads, head_size] query: torch.Tensor, # [num_tokens, num_heads, head_size] - key_cache: torch.Tensor, - value_cache: torch.Tensor, + key_cache: torch. + Tensor, # [num_total_blocks, block_size, num_heads, head_size] + value_cache: torch. + Tensor, # [num_total_blocks, block_size, num_heads, head_size] input_metadata: InputMetadata, num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], ) -> torch.Tensor: + output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = ( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 04713914497f..99a5602ca967 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -31,8 +31,6 @@ _BATCH_SIZE_ALIGNMENT = 8 # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. -# Note that cuda graph is only used for decoding because it speeds up -# the performance when num_tokens < 200. Batch here means a single token. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) ] @@ -215,6 +213,9 @@ def _prepare_prompt( max_prompt_len = max(subquery_lens) num_prompt_tokens = len(input_tokens) + + # Pad tokens to better utilize tensor cores although + # cuda graph is not enabled. input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, pad=0, dtype=torch.long, @@ -347,7 +348,8 @@ def _prepare_decode( block_tables.append([]) batch_size = graph_batch_size - # Q: should we not pad when cuda graph is disabled? 
+ # Pad tokens to better utilize tensor cores although + # cuda graph is not enabled. input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, pad=0, dtype=torch.long, @@ -599,9 +601,6 @@ def execute_model( # Execute the model. if input_metadata.use_cuda_graph: - # NOTE: We use cuda graph only when there are only - # decoding requests, which means the number of batch - # size is equivalent to number of input tokens. graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: @@ -719,7 +718,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # deleted before the CUDA graphs. self.cupy_nccl_backend = cupy_utils.get_nccl_backend() - # assert not self.model_config.enforce_eager + assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " "unexpected consequences if the model is not static. To " "run the model in eager mode, set 'enforce_eager=True' or " @@ -915,7 +914,6 @@ def _make_tensor_with_pad_for_alignment( """Create a tensor of a given list x with padding. It adds paddings to align with graph batch size. See _get_graph_batch_size for more details. - # NOTE: This API is only for decoding requests. """ batch_size = len(x) batch_size = _get_graph_batch_size(batch_size) From f7347b8c704a587040873496c5398d0683b9decd Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 7 Mar 2024 06:42:08 -0800 Subject: [PATCH 25/28] working --- vllm/model_executor/input_metadata.py | 5 +++ .../layers/attention/attention.py | 7 ++-- .../layers/attention/backends/flash_attn.py | 42 +++++++++---------- .../layers/attention/backends/xformers.py | 1 + 4 files changed, 29 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index be4f71445276..d992f6b7d207 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -6,6 +6,11 @@ class InputMetadata: """Metadata for input sequences. Used in PagedAttention. + NOTE: Any python object stored here is not updated when it is + cuda-graph replays. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + Args: prompt_lens: Lengths of prompts. slot_mapping: The index of each token mapped into a physical block diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 4d288dbb81e7..58f1bbd99aad 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -5,7 +5,7 @@ import torch.nn as nn from vllm.model_executor.input_metadata import InputMetadata -from vllm.utils import is_hip +# from vllm.utils import is_hip class Attention(nn.Module): @@ -46,8 +46,9 @@ def __init__( sliding_window: Optional[int] = None, ) -> None: super().__init__() - if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and - torch.get_default_dtype() in (torch.float16, torch.bfloat16)): + # if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and + # torch.get_default_dtype() in (torch.float16, torch.bfloat16)): + if False: # Ampere or later NVIDIA GPUs. # NOTE(woosuk): FlashAttention does not support FP32. 
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index 512f4e49c7eb..20a3a072bb3b 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -53,18 +53,18 @@ def forward( """Forward pass with FlashAttention and PagedAttention. Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] input_metadata: metadata for the inputs. Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ - batch_size, seq_len, hidden_size = query.shape + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -74,27 +74,25 @@ def forward( # If key_cache and value_cache are not provided, the new key and value # vectors will not be cached. This happens during the initial memory # profiling run. - if key_cache is not None and value_cache is not None: - PagedAttentionImpl.reshape_and_cache(key, value, key_cache, - value_cache, input_metadata) + PagedAttentionImpl.reshape_and_cache(key, value, key_cache, + value_cache, input_metadata) if input_metadata.is_prompt: # Prompt run. if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # normal attention - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) - output = flash_attn_func( - query, - key, - value, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) + output = torch.empty_like(query) + query = query.unflatten(0, (num_tokens, )) + key = key.unflatten(0, (num_tokens, )) + value = value.unflatten(0, (num_tokens, )) + output = flash_attn_func(query, + key, + value, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes) else: # prefix-enabled attention output = PagedAttentionImpl.forward_prefix( @@ -104,8 +102,6 @@ def forward( key_cache, value_cache, input_metadata, - self.num_heads, - self.num_kv_heads, self.alibi_slopes, ) else: @@ -121,4 +117,4 @@ def forward( ) # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) + return output.view(num_tokens, hidden_size) diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention/backends/xformers.py index 5206959de333..269b73303fb0 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -83,6 +83,7 @@ def forward( if input_metadata.is_prompt: # Prompt run. + # Unless there's a prefix, context lens is all 0 for prefill. 
if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # normal attention From f91d73e40e4e60143eeddc904ef080d88aea4b07 Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 7 Mar 2024 23:04:29 -0800 Subject: [PATCH 26/28] fix lora --- vllm/worker/model_runner.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d6be22ec89fb..060828346c3f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -182,7 +182,7 @@ def _prepare_prompt( if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping.append([lora_id] * (prompt_len - computed_len)) + lora_index_mapping += [lora_id] * (prompt_len - computed_len) lora_prompt_mapping.extend( [lora_id] * (prompt_len - computed_len @@ -232,10 +232,11 @@ def _prepare_prompt( pad=_PAD_SLOT_ID, dtype=torch.long, device=self.device) - lora_index_mapping = [ - _pad_to_max(mapping, max_prompt_len, pad=0) - for mapping in lora_index_mapping - ] + lora_index_mapping = _pad_to_alignment(lora_index_mapping, + _get_graph_batch_size( + len(lora_index_mapping)), + pad=0) + context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -323,7 +324,7 @@ def _prepare_decode( block_offset = position % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - lora_index_mapping.append([lora_id]) + lora_index_mapping.append(lora_id) lora_prompt_mapping.append(lora_id) if self.sliding_window is not None: @@ -396,9 +397,10 @@ def _prepare_decode( device=self.device, ) - lora_index_mapping = [ - _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping - ] + lora_index_mapping = _pad_to_alignment(lora_index_mapping, + _get_graph_batch_size( + len(lora_index_mapping)), + pad=0) input_metadata = InputMetadata( is_prompt=False, @@ -527,11 +529,8 @@ def prepare_input_tensors( subquery_lens) if self.lora_config: - flat_lora_index_mapping = [ - item for sublist in lora_index_mapping for item in sublist - ] lora_mapping = LoRAMapping( - flat_lora_index_mapping, + lora_index_mapping, lora_prompt_mapping, ) else: From f7d79dad9b61de77f0d1df5f20f86752adbb7f58 Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 8 Mar 2024 00:06:10 -0800 Subject: [PATCH 27/28] fixed --- .../spec_decode/test_multi_step_worker.py | 6 +++--- tests/worker/test_model_runner.py | 18 ++++++++++-------- vllm/worker/model_runner.py | 9 ++++++--- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/tests/worker/spec_decode/test_multi_step_worker.py b/tests/worker/spec_decode/test_multi_step_worker.py index ea5480290357..15a0546813eb 100644 --- a/tests/worker/spec_decode/test_multi_step_worker.py +++ b/tests/worker/spec_decode/test_multi_step_worker.py @@ -90,8 +90,8 @@ def test_same_output_for_single_step(): num_gpu_blocks, seed, ) - multi_step_worker.model_runner = worker.model_runner - multi_step_worker.cache_engine = worker.cache_engine + # multi_step_worker.model_runner = worker.model_runner + # multi_step_worker.cache_engine = worker.cache_engine num_steps = 1 @@ -258,4 +258,4 @@ def test_same_output_for_multi_step(): for multi_step_logprobs, single_step_logprobs in zip( multi_step_output_logprobs, single_step_output_logprobs): assert_logprobs_dict_allclose(multi_step_logprobs, - single_step_logprobs) + single_step_logprobs) \ No newline at end of file diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 
6ae0c5eddd04..514eedd9a415 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -19,6 +19,7 @@ def test_prepare_prompt():
     batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
+    block_tables = {0: [1]}
     for i in range(batch_size):
         # make sure all tokens fit into one block
         prompt_len = i % (model_runner.block_size - 1) + 1
@@ -30,7 +31,7 @@ def test_prepare_prompt():
                 is_prompt=True,
                 seq_data={0: SequenceData(seq_data)},
                 sampling_params=SamplingParams(temperature=0),
-                block_tables={0: [1]},
+                block_tables=block_tables,
             ))

     expected_selected_token_indices = []
@@ -70,10 +71,7 @@ def test_prepare_prompt():
         torch.zeros(input_metadata.context_lens.shape[0],
                     dtype=torch.int,
                     device=device))
-
-    # SANG-TODO
-    # assert input_metadata.slot_mapping == max(prompt_lens)
-    # block_tables
+    assert input_metadata.block_tables is None
     # Cuda graph should not be used for prefill.
     assert input_metadata.use_cuda_graph is False
     assert input_metadata.kv_cache_dtype == "auto"
@@ -162,9 +160,13 @@ def round_up_to_next_multiple_of_batch_size(n):
         input_metadata.context_lens[:len(prompt_lens)],
         torch.tensor(prompt_lens, dtype=torch.int, device=device))

-    # SANG-TODO
-    # assert input_metadata.slot_mapping == max(prompt_lens)
-    # block_tables
+    # The block table's first dim corresponds to each entry in the batch;
+    # in decoding, each entry is a single token.
+    assert input_metadata.block_tables.shape[0] == len(input_tokens)
+    # Block table's second dim corresponds to each token's block number.
+    # It is padded up to the max number of blocks per batch.
+    assert input_metadata.block_tables.shape[1] == (
+        model_runner.get_max_block_per_batch())
     # Cuda graph should not be used for prefill.
     assert input_metadata.use_cuda_graph is True
     assert input_metadata.kv_cache_dtype == "auto"
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 060828346c3f..2edb21c4900f 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -121,10 +121,13 @@ def load_model(self) -> None:

     def set_block_size(self, block_size: int) -> None:
         self.block_size = block_size
-        max_num_blocks = (self.max_context_len_to_capture + block_size -
-                          1) // block_size
         self.graph_block_tables = np.zeros(
-            (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
+            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
+            dtype=np.int32)
+
+    def get_max_block_per_batch(self):
+        block_size = self.block_size
+        return (self.max_context_len_to_capture + block_size - 1) // block_size

     def _prepare_prompt(
         self,
From 406f1d40cde7069233c2cde0ac8cfffc50b2c81e Mon Sep 17 00:00:00 2001
From: sang
Date: Fri, 8 Mar 2024 00:38:30 -0800
Subject: [PATCH 28/28] fix

---
 tests/worker/test_model_runner.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index 514eedd9a415..55a078230b46 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -71,7 +71,11 @@ def test_prepare_prompt():
         torch.zeros(input_metadata.context_lens.shape[0],
                     dtype=torch.int,
                     device=device))
-    assert input_metadata.block_tables is None
+
+    expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
+                            dtype=torch.int32,
+                            device=model_runner.device)
+    assert torch.allclose(input_metadata.block_tables, expected)
     # Cuda graph should not be used for prefill.
     assert input_metadata.use_cuda_graph is False
     assert input_metadata.kv_cache_dtype == "auto"
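
To make the block-table sizing used for CUDA graph capture concrete, here is a
minimal, self-contained sketch. Illustration only: `max_block_per_batch` is a
stand-in for the `get_max_block_per_batch` method added above, and the context
length and block size are assumed example values rather than values taken from
this series.

    import numpy as np

    # Captured decode batch sizes: 1, 2, 4, then multiples of 8 up to 256.
    _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

    def max_block_per_batch(max_context_len_to_capture: int, block_size: int) -> int:
        # Ceiling division: blocks needed to hold the longest capturable context.
        return (max_context_len_to_capture + block_size - 1) // block_size

    max_context_len_to_capture = 8192  # assumed example value
    block_size = 16                    # assumed example value
    graph_block_tables = np.zeros(
        (max(_BATCH_SIZES_TO_CAPTURE),
         max_block_per_batch(max_context_len_to_capture, block_size)),
        dtype=np.int32)
    # One row per (padded) decode token, one column per possible KV-cache block.
    assert graph_block_tables.shape == (256, 512)

Because the buffer is sized for the largest captured batch and the longest
capturable context, one shared block-table buffer can back the graphs for
every captured batch size.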