From 06fe87293715e79aac88dfe75963fa1ec600f81f Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 27 Feb 2024 22:55:16 -0800 Subject: [PATCH 01/88] [1/n] Support efficient reshape caching. --- csrc/cache.h | 8 ++++ csrc/cache_kernels.cu | 89 +++++++++++++++++++++++++++++++++++++ csrc/pybind.cpp | 4 ++ tests/kernels/test_cache.py | 85 +++++++++++++++++++++++++++++++++++ vllm/utils.py | 13 +++++- 5 files changed, 197 insertions(+), 2 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 765e231abd26..ce3f96d59e83 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -23,6 +23,14 @@ void reshape_and_cache( torch::Tensor& slot_mapping, const std::string& kv_cache_dtype); +void reshape_and_cache_flash( + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + torch::Tensor& num_tokens); + // Just for unittest void convert_fp8_e5m2( torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 7254010b8e3a..21057984e1e3 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -269,6 +269,95 @@ void reshape_and_cache( namespace vllm { +// flash-attention style cache funciton where key/value caches +// has the same shape of [num_blocks, block_size, num_heads, head_size] +template +__global__ void reshape_and_cache_flash_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, head_size] + scalar_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int64_t* __restrict__ num_tokens, // [1] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size) { + const int64_t num_tokens_ = num_tokens[0]; + const int64_t token_idx = blockIdx.x; + if (token_idx >= num_tokens_) { + return; + } + const int64_t slot_idx = slot_mapping[token_idx]; + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + + // Target idx for kv are the same. + const int64_t tgt_idx = block_idx * block_size * num_heads * head_size + + block_offset * num_heads * head_size + + head_idx * head_size + + head_offset; + + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + // TODO(sang): Support ENABLE_FP8_E5M2. 
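+    // Consecutive threads write consecutive elements of the num_heads * head_size
+    // slice, so the stores below are coalesced for both caches.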
+ key_cache[tgt_idx] = tgt_key; + value_cache[tgt_idx] = tgt_value; + } +} + +} // namespace vllm + +// TODO(sang): Support kv_cache_dtype +void reshape_and_cache_flash( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + torch::Tensor& num_tokens) // [1] +{ + int num_tokens_padded = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(1); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + dim3 grid(num_tokens_padded); + dim3 block(std::min(num_heads * head_size, 512)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), + "reshape_and_cache_flash_kernel", + [&] { + vllm::reshape_and_cache_flash_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + slot_mapping.data_ptr(), + num_tokens.data_ptr(), + key_stride, + value_stride, + num_heads, + head_size, + block_size); + }); +} + +namespace vllm { + template __global__ void convert_fp8_e5m2_kernel( const Tin* __restrict__ src_cache, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 5d062bb5700b..3b242327652f 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -79,6 +79,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "reshape_and_cache", &reshape_and_cache, "Reshape the key and value tensors and cache them"); + cache_ops.def( + "reshape_and_cache_flash", + &reshape_and_cache_flash, + "Reshape the flash-style key and value tensors and cache them."); cache_ops.def( "convert_fp8_e5m2", &convert_fp8_e5m2, diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index d8dc74bc7b00..1430b76270c5 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -7,6 +7,7 @@ from vllm._C import cache_ops from vllm.utils import is_hip +import torch.nn.functional as F COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -25,6 +26,7 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] +PADDINGS = [8, 16, 0] @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @@ -224,3 +226,86 @@ def test_swap_blocks( dist_key_caches[0][dst].cpu()) assert torch.allclose(src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("padding", PADDINGS) +@torch.inference_mode() +def test_reshape_and_cache_flash( + kv_cache_factory, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + padding: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Create a random slot mapping. 
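+    # A slot is a flat index into the cache: slot // block_size picks the block and
+    # slot % block_size the offset within it, mirroring the kernel's indexing.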
+ num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device='cuda') + _, key, value = qkv.unbind(dim=1) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(num_blocks, + block_size, + 1, + num_heads, + head_size, + dtype, + seed, + flash_style=True) + assert len(key_caches) == 1 and len(value_caches) == 0 + key_cache, value_cache = key_caches[0], value_caches[0] + + # Clone the KV caches. + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + num_tokens = torch.tensor([num_tokens], dtype=torch.long, device="cuda") + + def pad_key_value(key: torch.Tensor, value: torch.Tensor, + pad_size: int) -> Tuple[torch.Tensor, torch.Tensor]: + if pad_size == 0: + return key, value + return F.pad(key, (0, 0, 0, 0, 0, pad_size)),\ + F.pad(value, (0, 0, 0, 0, 0, pad_size)) + + # kv shapes: (num_blocks, block_size, num_heads, head_size) + # pad tokens. + padded_key, padded_value = pad_key_value(key, value, padding) + # Call the reshape_and_cache kernel. + cache_ops.reshape_and_cache_flash(padded_key, padded_value, key_cache, + value_cache, slot_mapping, num_tokens) + + # Run the reference implementation. + block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') + block_indicies = block_indicies.cpu().tolist() + block_offsets = slot_mapping % block_size + block_offsets = block_offsets.cpu().tolist() + for i in range(num_tokens): + block_idx = block_indicies[i] + block_offset = block_offsets[i] + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) diff --git a/vllm/utils.py b/vllm/utils.py index c8ac57de6f5f..03ebcc42524f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -245,6 +245,7 @@ def create_kv_caches_with_random( model_dtype: Optional[Union[str, torch.dtype]] = None, seed: Optional[int] = 0, device: Optional[str] = "cuda", + flash_style: bool = False, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -271,7 +272,11 @@ def create_kv_caches_with_random( scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + if flash_style: + key_cache_shape = (num_blocks, block_size, num_heads, head_size) + else: + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, + x) key_caches = [] for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, @@ -286,7 +291,11 @@ def create_kv_caches_with_random( f"Does not support key cache of type {cache_dtype}") key_caches.append(key_cache) - value_cache_shape = (num_blocks, num_heads, head_size, block_size) + if flash_style: + value_cache_shape = (num_blocks, block_size, num_heads, head_size) + else: + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches = [] for _ in range(num_layers): value_cache = torch.empty(size=value_cache_shape, From 9a0b6bea2e88ec37899a2aecc4aee62a0d388f0a Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 27 Feb 2024 23:36:19 -0800 Subject: [PATCH 02/88] [2/n] support flash attention kernel --- tests/kernels/test_flash_attention.py | 460 ++++++++++++++++++++++++ 
vllm/model_executor/layers/attention.py | 145 ++++++++ 2 files changed, 605 insertions(+) create mode 100644 tests/kernels/test_flash_attention.py diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py new file mode 100644 index 000000000000..6081bfcc0886 --- /dev/null +++ b/tests/kernels/test_flash_attention.py @@ -0,0 +1,460 @@ +import random +from typing import List, Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + +from vllm.model_executor.layers.attention import ( + flash_single_query_cached_kv_attention, + flash_multi_query_cached_kv_attention_varlen, +) +from vllm.utils import get_max_shared_memory_bytes + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +NUM_BLOCKS = 128 # Arbitrary values for testing +PARTITION_SIZE = 512 + +DTYPES = [torch.half, torch.bfloat16] +NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing +NUM_HEADS_SMALL = NUM_HEADS +HEAD_SIZES = [64, 80, 96, 112, 128, 256] +BLOCK_SIZES = [32] +USE_ALIBI = [False, True] +SEEDS = [0] +PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] + + +def pad_attention_inputs( + pad_config: Tuple[int, int], + block_size: int, + query: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Pad the attention inputs to the specified batch size and context length. 
+ """ + pad_batch_size, pad_max_context_len = pad_config + if pad_batch_size == 0: + return query, block_tables, context_lens, max_context_len + target_batch_size = ( + (query.shape[0] - 1) % pad_batch_size + 1) * pad_batch_size + target_block_size = pad_max_context_len // block_size + 1 + padded_query = F.pad(query, + (0, 0, 0, 0, 0, target_batch_size - query.shape[0])) + padded_block_table = F.pad(block_tables, + (0, target_block_size - block_tables.shape[1], + 0, target_batch_size - block_tables.shape[0])) + padded_context_lens = F.pad(context_lens, + (0, target_batch_size - context_lens.shape[0])) + return padded_query, padded_block_table, padded_context_lens, pad_max_context_len + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], + flash_style: bool = False, +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[-2] + head_size = value_cache.shape[-1] + block_size = value_cache.shape[-3] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + context_lens = context_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + context_len = int(context_lens[i]) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + if flash_style: + k = key_cache[block_number, block_offset, :, :] + else: + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + if flash_style: + v = value_cache[block_number, block_offset, :, :] + else: + v = value_cache[block_number, :, :, block_offset] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. 
+ position_ids = torch.arange(context_len, device="cuda").int() + alibi_bias = (position_ids - context_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +# @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +# @pytest.mark.parametrize("num_heads", NUM_HEADS) +# @pytest.mark.parametrize("head_size", HEAD_SIZES) +# @pytest.mark.parametrize("use_alibi", [False]) +# @pytest.mark.parametrize("block_size", [32]) +# @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +# @pytest.mark.parametrize("seed", SEEDS) +# @pytest.mark.parametrize("pad_config", PAD_CONFIGS) +@pytest.mark.parametrize("num_seqs", [3]) +@pytest.mark.parametrize("num_heads", [(40, 40), (64, 8)]) +@pytest.mark.parametrize("head_size", [80, 96]) +@pytest.mark.parametrize("use_alibi", [False]) +@pytest.mark.parametrize("block_size", [32]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("pad_config", [(0, 0)]) +@torch.inference_mode() +def test_flash_paged_attention( + kv_cache_factory, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + seed: int, + pad_config: Tuple[int, int], +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), + num_queries_per_kv) + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, + dtype=torch.float, + device="cuda") + + max_seq_len = MAX_SEQ_LEN if not pad_config[0] else (pad_config[1] - 1000) + context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] + context_lens[-1] = max_seq_len + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + dtype, + seed, + flash_style=True) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Call the paged attention kernel. + num_valid_tokens = torch.cuda.IntTensor([num_seqs]) + output = torch.empty_like(query) + + padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ + pad_attention_inputs(pad_config, block_size, query, + block_tables, context_lens, max_context_len) + + flash_single_query_cached_kv_attention( + output, + padded_query, + key_cache, + value_cache, + scale, + padded_block_table, + padded_context_lens, + block_size, + alibi_slopes, + ) + + # Run the reference implementation. 
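+    # The reference below gathers keys/values block by block from the same cache and
+    # runs eager masked attention, so any mismatch isolates the flash kernel itself.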
+ ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + context_lens, + scale, + alibi_slopes, + flash_style=True, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +def ref_multi_query_kv_attention_padded( + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + cu_seq_lens: List[int], + context_lens: List[int], + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + block_size = value_cache.shape[-3] + ref_outputs = [] + + for i in range(num_seqs): + q_start_idx = cu_seq_lens[i] + q_end_idx = cu_seq_lens[i + 1] + seq_len = q_end_idx - q_start_idx + + context_len = context_lens[i] + + block_table = block_tables[i] + keys = [] + values = [] + + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, block_offset, :, :] + keys.append(k) + + v = value_cache[block_number, block_offset, :, :] + values.append(v) + + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + q = query[q_start_idx:q_end_idx, :, :] + k = keys[:context_len, :, :] + v = values[:context_len, :, :] + + assert seq_len <= context_len + + # pad q if seq_len is less than context_len + # this is for correct calculation of attention. + if seq_len < context_len: + indices = [i % seq_len for i in range(context_len - seq_len)] + q_left_pad = q[indices, :, :] + q = torch.cat([q_left_pad, q], dim=0) + + # Create attention mask. + attn_mask = torch.triu(torch.ones(context_len, + context_len, + dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype, device="cuda") + + ref_output = ref_masked_attention( + q, + k, + v, + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output[-seq_len:, :, :]) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +def is_a100(): + return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 + + +if not is_a100(): + NUM_HEADS_SMALL = [(16, 16), (16, 8)] + MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) + +NUM_BLOCKS = 1024 +BLOCK_SIZE = 32 + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS_SMALL) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("version", ["flash"]) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_multi_query_kv_attention( + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + version: str, + seed: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. + # As the xformers library is already tested with its own tests, we can use + # a smaller MAX_SEQ_LEN here. 
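+    # All sequences' query tokens are packed into one tensor and addressed through
+    # cumulative offsets, e.g. seq_lens [3, 5, 2] give cu_seq_lens [0, 3, 8, 10].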
+ max_len = min(MAX_SEQ_LEN, 4096) + + seq_lens = [random.randint(1, max_len / 2) for i in range(num_seqs)] + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device="cuda") + + context_lens = random.sample(range(max_seq_len, max_len), num_seqs) + max_context_len = max(context_lens) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device="cuda") + + num_tokens = sum(seq_lens) + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + + cu_context_lens = [0] + for context_len in context_lens: + cu_context_lens.append(cu_context_lens[-1] + context_len) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + num_queries_per_kv = num_query_heads // num_kv_heads + + value_cache = torch.empty(NUM_BLOCKS, + BLOCK_SIZE, + num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + key_cache = torch.empty(NUM_BLOCKS, + BLOCK_SIZE, + num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + query = torch.empty(num_tokens, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + value_cache.uniform_(-scale, scale) + key_cache.uniform_(-scale, scale) + query.uniform_(-scale, scale) + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + BLOCK_SIZE - 1) // BLOCK_SIZE + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + output = torch.empty_like(query) + + if version == "flash": + flash_multi_query_cached_kv_attention_varlen( + output, + query, + key_cache, + value_cache, + scale, + block_tables, + torch.cuda.IntTensor(cu_seq_lens), + torch.cuda.IntTensor(cu_context_lens), + BLOCK_SIZE, + max_seq_len, + max_context_len, + None, + ) + else: + assert False, f"{version=} is not supported" + + ref_output = ref_multi_query_kv_attention_padded( + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + cu_seq_lens, + context_lens, + scale, + dtype, + ) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) \ No newline at end of file diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 2a82325b8021..2cb8bdf0a7c8 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -15,6 +15,13 @@ context_attention_fwd) from vllm.utils import is_hip +try: + from flash_attn import (flash_attn_with_page_attention, + flash_attn_varlen_with_page_attention) +except Exception as e: + flash_attn_with_page_attention = e + flash_attn_varlen_with_page_attention = e + _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. _PARTITION_SIZE = 512 @@ -347,3 +354,141 @@ def _paged_attention( input_metadata.kv_cache_dtype, ) return output + + +def flash_single_query_cached_kv_attention( + output: Optional[torch.Tensor], + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + alibi_slopes: Optional[torch.Tensor], + actual_batch_size: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Similar to vLLM's page attention, caclulates a single token's attention + based on key/value caches. The main difference is this uses flash attention + sytle key-value caches. 
+ Arguments: + output: [num_padded_tokens, num_heads, head_size], output tensor + to write. if None an new output tensor will be created. + query: [num_padded_tokens, num_heads, head_size], query tensor. + key_cache: [num_blocks, block_size, num_heads, head_size], key cache. + value_cache: [num_blocks, block_size, num_heads, head_size], value + cache. + scale: attention scale. + block_tables: [num_padded_tokens, max_context_len / block_size], + block tables. + context_lens: [num_padded_tokens], context lengths. + block_size: block size. + alibi_slopes: unused. + actual_batch_size: [1] actual batch size. + Returns: + output: [num_padded_tokens, num_heads, head_size] + """ + block_size = value_cache.shape[1] + assert block_size == 32, "only support block_size 32 for flash attention" + # TODO: support alibi_slopes + assert alibi_slopes is None, "doesn't support alibi_slopes" + num_tokens, num_heads, head_size = query.shape + out = flash_attn_with_page_attention( + query.view(num_tokens, 1, num_heads, head_size), + key_cache, + value_cache, + block_tables, + None, # key + None, # value + None, # cos + None, # sin + context_lens, + None, # cache_batch_idx + softmax_scale=scale, + causal=True, + window_size=(-1, -1), + rotary_interleaved=False, + num_splits=0, + actual_batch_size=actual_batch_size, + ) + if output is not None: + # in case that output is padded, only copy the valid part. + output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) + return out.view(num_tokens, num_heads, head_size) + + +def flash_multi_query_cached_kv_attention_varlen( + output: Optional[torch.Tensor], + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + scale: float, + block_tables: torch.Tensor, + cum_seqlens_q: torch.Tensor, + cum_context_len: torch.Tensor, + block_size: int, + max_query_len: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + actual_batch_size: Optional[torch.Tensor] = None, +): + """Efficient multi-query paged attention based on flash attention. + It calculates attentions of list of sequences packed in a single batch, + indexed by cum_seqlens_q where the seq_i's index is + [cum_seqlens_q[i], cum_seqlensq[i+1]]. + Similarlly, the length of context is stored in cum_seqlens_k with similar + fashions. + It also supports calculating attention incrementally, where context length + is longer than sequence length. + Arguments: + output: [num_padded_tokens, num_heads, head_size], output tensor to + write to. if None an new output tensor will be created. + query: [num_padded_tokens, num_heads, head_size], query tensor. + key_cache: [num_blocks, block_size, num_heads, head_size], key cache. + value_cache: [num_blocks, block_size, num_heads, head_size], + value cache. + scale: attention scale. + block_tables: [num_padded_tokens, max_context_len / block_size], + block tables. + cum_seqlens_q: [padded_batch_size + 1], cumulative sequence lengths + of query. + cum_context_len: [padded_batch_size + 1], cumulative lengths + of context. + block_size: block size. + max_query_len: max query length. + max_context_len: max context length. + alibi_slopes: unused. + actual_batch_size: [1] actual batch size. 
+ Returns: + output: [num_padded_tokens, num_heads, head_size] + """ + block_size = value_cache.shape[1] + assert block_size == 32, "only support block_size 32 for flash attention" + # TODO: support alibi_slopes + assert alibi_slopes is None, "doesn't support alibi_slopes" + + num_tokens, _, _ = query.shape + out = flash_attn_varlen_with_page_attention( + query, + key_cache, + value_cache, + block_tables, + cum_seqlens_q, + cum_context_len, + max_query_len, + max_context_len, + None, # key + None, # value + None, # cos_cache + None, # sin_cache + None, # cache_batch_idx + softmax_scale=scale, + causal=True, + window_size=(-1, -1), + rotary_interleaved=False, + num_splits=0, + actual_batch_size=actual_batch_size, + ) + if output is not None: + output[:num_tokens].copy_(out) + return out \ No newline at end of file From 6947167ea42f592df37ead195dbfb5c6906609bd Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 28 Feb 2024 01:47:52 -0800 Subject: [PATCH 03/88] oss flash attention works --- tests/kernels/test_flash_attention.py | 250 ++---------------------- vllm/model_executor/layers/attention.py | 135 +++---------- 2 files changed, 43 insertions(+), 342 deletions(-) diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index 6081bfcc0886..82c4350489d5 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -1,16 +1,12 @@ import random -from typing import List, Optional, Tuple +from typing import Optional, Tuple import pytest import torch import torch.nn.functional as F -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from vllm.model_executor.layers.attention import ( - flash_single_query_cached_kv_attention, - flash_multi_query_cached_kv_attention_varlen, -) + flash_attn_with_kvcache_paged, ) from vllm.utils import get_max_shared_memory_bytes FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 @@ -25,11 +21,16 @@ NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS_SMALL = NUM_HEADS -HEAD_SIZES = [64, 80, 96, 112, 128, 256] -BLOCK_SIZES = [32] +# head size should be bigger than or equal to block size. +HEAD_SIZES = [256] +# TODO(sang): https://github.com/Dao-AILab/flash-attention/pull/824 +# should fix the block size. But right now, the block size should be +# divisible by 256. 
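+# The engine-side check added later in this series (vllm/config.py) rejects
+# flash_style with block_size < 256 for the same reason, so the value tested
+# here matches what the engine accepts.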
+BLOCK_SIZES = [256] USE_ALIBI = [False, True] SEEDS = [0] -PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] +# PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] +PAD_CONFIGS = [(0, 0)] def pad_attention_inputs( @@ -137,22 +138,14 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -# @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -# @pytest.mark.parametrize("num_heads", NUM_HEADS) -# @pytest.mark.parametrize("head_size", HEAD_SIZES) -# @pytest.mark.parametrize("use_alibi", [False]) -# @pytest.mark.parametrize("block_size", [32]) -# @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) -# @pytest.mark.parametrize("seed", SEEDS) -# @pytest.mark.parametrize("pad_config", PAD_CONFIGS) -@pytest.mark.parametrize("num_seqs", [3]) -@pytest.mark.parametrize("num_heads", [(40, 40), (64, 8)]) -@pytest.mark.parametrize("head_size", [80, 96]) -@pytest.mark.parametrize("use_alibi", [False]) -@pytest.mark.parametrize("block_size", [32]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", [False, True]) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("pad_config", [(0, 0)]) +@pytest.mark.parametrize("pad_config", PAD_CONFIGS) @torch.inference_mode() def test_flash_paged_attention( kv_cache_factory, @@ -180,9 +173,6 @@ def test_flash_paged_attention( assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads - head_mapping = torch.repeat_interleave( - torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), - num_queries_per_kv) alibi_slopes = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, @@ -218,15 +208,13 @@ def test_flash_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. - num_valid_tokens = torch.cuda.IntTensor([num_seqs]) output = torch.empty_like(query) padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ pad_attention_inputs(pad_config, block_size, query, block_tables, context_lens, max_context_len) - flash_single_query_cached_kv_attention( - output, + output = flash_attn_with_kvcache_paged( padded_query, key_cache, value_cache, @@ -256,205 +244,3 @@ def test_flash_paged_attention( # implementations, there is a small numerical difference in the two # outputs. Thus, we use a relaxed tolerance for the test. 
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) - - -def ref_multi_query_kv_attention_padded( - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - cu_seq_lens: List[int], - context_lens: List[int], - scale: float, - dtype: torch.dtype, -) -> torch.Tensor: - num_seqs = len(cu_seq_lens) - 1 - block_size = value_cache.shape[-3] - ref_outputs = [] - - for i in range(num_seqs): - q_start_idx = cu_seq_lens[i] - q_end_idx = cu_seq_lens[i + 1] - seq_len = q_end_idx - q_start_idx - - context_len = context_lens[i] - - block_table = block_tables[i] - keys = [] - values = [] - - for j in range(context_len): - block_number = int(block_table[j // block_size]) - block_offset = j % block_size - - k = key_cache[block_number, block_offset, :, :] - keys.append(k) - - v = value_cache[block_number, block_offset, :, :] - values.append(v) - - keys = torch.stack(keys, dim=0) - values = torch.stack(values, dim=0) - - if num_queries_per_kv > 1: - # Handle MQA and GQA - keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) - values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - - q = query[q_start_idx:q_end_idx, :, :] - k = keys[:context_len, :, :] - v = values[:context_len, :, :] - - assert seq_len <= context_len - - # pad q if seq_len is less than context_len - # this is for correct calculation of attention. - if seq_len < context_len: - indices = [i % seq_len for i in range(context_len - seq_len)] - q_left_pad = q[indices, :, :] - q = torch.cat([q_left_pad, q], dim=0) - - # Create attention mask. - attn_mask = torch.triu(torch.ones(context_len, - context_len, - dtype=dtype), - diagonal=1) - attn_mask = attn_mask * torch.finfo(dtype).min - attn_mask = attn_mask.to(dtype=dtype, device="cuda") - - ref_output = ref_masked_attention( - q, - k, - v, - scale, - attn_mask=attn_mask, - ) - ref_outputs.append(ref_output[-seq_len:, :, :]) - ref_output = torch.cat(ref_outputs, dim=0) - return ref_output - - -def is_a100(): - return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 - - -if not is_a100(): - NUM_HEADS_SMALL = [(16, 16), (16, 8)] - MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) - -NUM_BLOCKS = 1024 -BLOCK_SIZE = 32 - - -@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS_SMALL) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("version", ["flash"]) -@pytest.mark.parametrize("seed", SEEDS) -@torch.inference_mode() -def test_multi_query_kv_attention( - num_seqs: int, - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - version: str, - seed: int, -) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. - # As the xformers library is already tested with its own tests, we can use - # a smaller MAX_SEQ_LEN here. 
- max_len = min(MAX_SEQ_LEN, 4096) - - seq_lens = [random.randint(1, max_len / 2) for i in range(num_seqs)] - max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device="cuda") - - context_lens = random.sample(range(max_seq_len, max_len), num_seqs) - max_context_len = max(context_lens) - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device="cuda") - - num_tokens = sum(seq_lens) - cu_seq_lens = [0] - for seq_len in seq_lens: - cu_seq_lens.append(cu_seq_lens[-1] + seq_len) - - cu_context_lens = [0] - for context_len in context_lens: - cu_context_lens.append(cu_context_lens[-1] + context_len) - - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - num_queries_per_kv = num_query_heads // num_kv_heads - - value_cache = torch.empty(NUM_BLOCKS, - BLOCK_SIZE, - num_kv_heads, - head_size, - dtype=dtype, - device="cuda") - key_cache = torch.empty(NUM_BLOCKS, - BLOCK_SIZE, - num_kv_heads, - head_size, - dtype=dtype, - device="cuda") - query = torch.empty(num_tokens, - num_query_heads, - head_size, - dtype=dtype, - device="cuda") - value_cache.uniform_(-scale, scale) - key_cache.uniform_(-scale, scale) - query.uniform_(-scale, scale) - - # Create the block tables. - max_num_blocks_per_seq = (max_context_len + BLOCK_SIZE - 1) // BLOCK_SIZE - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - - output = torch.empty_like(query) - - if version == "flash": - flash_multi_query_cached_kv_attention_varlen( - output, - query, - key_cache, - value_cache, - scale, - block_tables, - torch.cuda.IntTensor(cu_seq_lens), - torch.cuda.IntTensor(cu_context_lens), - BLOCK_SIZE, - max_seq_len, - max_context_len, - None, - ) - else: - assert False, f"{version=} is not supported" - - ref_output = ref_multi_query_kv_attention_padded( - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - cu_seq_lens, - context_lens, - scale, - dtype, - ) - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) \ No newline at end of file diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 2cb8bdf0a7c8..31ead4841139 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -15,12 +15,11 @@ context_attention_fwd) from vllm.utils import is_hip +# TODO(sang): Support varlen API. try: - from flash_attn import (flash_attn_with_page_attention, - flash_attn_varlen_with_page_attention) -except Exception as e: - flash_attn_with_page_attention = e - flash_attn_varlen_with_page_attention = e + from flash_attn import flash_attn_with_kvcache +except ImportError: + flash_attn_with_kvcache = None _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
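Note: with this guard, flash-attn may be absent at import time, so call sites have to verify that the import succeeded before taking the flash path. A minimal sketch of such a check, assuming the block-size constraint asserted in flash_attn_with_kvcache_paged below (the helper name is illustrative, not part of this patch):

    try:
        from flash_attn import flash_attn_with_kvcache
    except ImportError:
        flash_attn_with_kvcache = None

    def can_use_flash_kvcache(block_size: int) -> bool:
        # flash_attn_with_kvcache_paged asserts block_size % 256 == 0, so only take
        # the flash path when flash-attn is importable and the cache block size fits.
        return flash_attn_with_kvcache is not None and block_size % 256 == 0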
@@ -356,7 +355,7 @@ def _paged_attention( return output -def flash_single_query_cached_kv_attention( +def flash_attn_with_kvcache_paged( output: Optional[torch.Tensor], query: torch.Tensor, key_cache: torch.Tensor, @@ -365,130 +364,46 @@ def flash_single_query_cached_kv_attention( block_tables: torch.Tensor, context_lens: torch.Tensor, block_size: int, - alibi_slopes: Optional[torch.Tensor], - actual_batch_size: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Similar to vLLM's page attention, caclulates a single token's attention + """Similar to vLLM's page attention, calculates a single token's attention based on key/value caches. The main difference is this uses flash attention - sytle key-value caches. + style key-value caches. + Arguments: output: [num_padded_tokens, num_heads, head_size], output tensor to write. if None an new output tensor will be created. - query: [num_padded_tokens, num_heads, head_size], query tensor. - key_cache: [num_blocks, block_size, num_heads, head_size], key cache. - value_cache: [num_blocks, block_size, num_heads, head_size], value - cache. - scale: attention scale. - block_tables: [num_padded_tokens, max_context_len / block_size], - block tables. - context_lens: [num_padded_tokens], context lengths. - block_size: block size. - alibi_slopes: unused. - actual_batch_size: [1] actual batch size. + See https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py + for other arguments. + Returns: output: [num_padded_tokens, num_heads, head_size] """ block_size = value_cache.shape[1] - assert block_size == 32, "only support block_size 32 for flash attention" - # TODO: support alibi_slopes - assert alibi_slopes is None, "doesn't support alibi_slopes" + assert block_size % 256 == 0, "only support block_size divisible by 256." num_tokens, num_heads, head_size = query.shape - out = flash_attn_with_page_attention( + out = flash_attn_with_kvcache( query.view(num_tokens, 1, num_heads, head_size), key_cache, value_cache, - block_tables, - None, # key - None, # value - None, # cos - None, # sin - context_lens, - None, # cache_batch_idx + # Inplace update is slow. We don't use it. + # We assume kvcache is already updated before + # calling this API. + None, + None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens=context_lens, + cache_batch_idx=None, + block_table=block_tables, softmax_scale=scale, causal=True, window_size=(-1, -1), rotary_interleaved=False, + alibi_slopes=alibi_slopes, num_splits=0, - actual_batch_size=actual_batch_size, ) if output is not None: # in case that output is padded, only copy the valid part. output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) return out.view(num_tokens, num_heads, head_size) - - -def flash_multi_query_cached_kv_attention_varlen( - output: Optional[torch.Tensor], - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - scale: float, - block_tables: torch.Tensor, - cum_seqlens_q: torch.Tensor, - cum_context_len: torch.Tensor, - block_size: int, - max_query_len: int, - max_context_len: int, - alibi_slopes: Optional[torch.Tensor], - actual_batch_size: Optional[torch.Tensor] = None, -): - """Efficient multi-query paged attention based on flash attention. - It calculates attentions of list of sequences packed in a single batch, - indexed by cum_seqlens_q where the seq_i's index is - [cum_seqlens_q[i], cum_seqlensq[i+1]]. 
- Similarlly, the length of context is stored in cum_seqlens_k with similar - fashions. - It also supports calculating attention incrementally, where context length - is longer than sequence length. - Arguments: - output: [num_padded_tokens, num_heads, head_size], output tensor to - write to. if None an new output tensor will be created. - query: [num_padded_tokens, num_heads, head_size], query tensor. - key_cache: [num_blocks, block_size, num_heads, head_size], key cache. - value_cache: [num_blocks, block_size, num_heads, head_size], - value cache. - scale: attention scale. - block_tables: [num_padded_tokens, max_context_len / block_size], - block tables. - cum_seqlens_q: [padded_batch_size + 1], cumulative sequence lengths - of query. - cum_context_len: [padded_batch_size + 1], cumulative lengths - of context. - block_size: block size. - max_query_len: max query length. - max_context_len: max context length. - alibi_slopes: unused. - actual_batch_size: [1] actual batch size. - Returns: - output: [num_padded_tokens, num_heads, head_size] - """ - block_size = value_cache.shape[1] - assert block_size == 32, "only support block_size 32 for flash attention" - # TODO: support alibi_slopes - assert alibi_slopes is None, "doesn't support alibi_slopes" - - num_tokens, _, _ = query.shape - out = flash_attn_varlen_with_page_attention( - query, - key_cache, - value_cache, - block_tables, - cum_seqlens_q, - cum_context_len, - max_query_len, - max_context_len, - None, # key - None, # value - None, # cos_cache - None, # sin_cache - None, # cache_batch_idx - softmax_scale=scale, - causal=True, - window_size=(-1, -1), - rotary_interleaved=False, - num_splits=0, - actual_batch_size=actual_batch_size, - ) - if output is not None: - output[:num_tokens].copy_(out) - return out \ No newline at end of file From 4769a2636392d4ac1f25b2af758d008de4533f88 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 28 Feb 2024 06:47:18 -0800 Subject: [PATCH 04/88] in progress --- .buildkite/test-pipeline.yaml | 4 + benchmarks/benchmark_latency.py | 40 ++++++++- csrc/cache_kernels.cu | 9 +- requirements.txt | 1 + tests/chunked_prefill/test_correctness.py | 82 ++++++++++++++++++ tests/conftest.py | 3 + tests/kernels/test_cache.py | 2 +- tests/kernels/test_flash_attention.py | 1 - vllm/config.py | 14 +++ vllm/engine/arg_utils.py | 9 +- vllm/model_executor/input_metadata.py | 46 ++++++++++ vllm/model_executor/layers/attention.py | 101 ++++++++++++++-------- vllm/model_executor/models/llama.py | 13 ++- vllm/worker/cache_engine.py | 36 +++++--- vllm/worker/model_runner.py | 1 + vllm/worker/worker.py | 11 +++ 16 files changed, 310 insertions(+), 63 deletions(-) create mode 100644 tests/chunked_prefill/test_correctness.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index efcc4d2d07a1..c16bb4e3da24 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -43,6 +43,10 @@ steps: commands: - pytest -v -s prefix_caching +- label: Chunked Prefill Test + commands: + - pytest -v -s chunked_prefill + - label: Samplers Test command: pytest -v -s samplers --forked diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 6e3b679cb81b..d8083a826911 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -10,6 +10,13 @@ from vllm import LLM, SamplingParams +SAMPLE_PROMPTS = [ + "The president of the United States is", + "Hello, my name is", + "The capital of France is", + "The future of AI is", +] + def main(args: argparse.Namespace): 
print(args) @@ -57,10 +64,24 @@ def run_to_completion(profile_dir: Optional[str] = None): print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(prompt_token_ids=dummy_prompt_token_ids, - sampling_params=sampling_params, - use_tqdm=False) + if args.use_sample: + batch = ( + SAMPLE_PROMPTS * + (args.batch_size // len(SAMPLE_PROMPTS) + 1))[:args.batch_size] + outputs = llm.generate(prompts=batch, + sampling_params=sampling_params, + use_tqdm=False) + else: + outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids, + sampling_params=sampling_params, + use_tqdm=False) end_time = time.perf_counter() + if args.verbose: + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print( + f"Prompt: {prompt!r}, Generated text: {generated_text!r}") latency = end_time - start_time return latency @@ -145,5 +166,18 @@ def run_to_completion(profile_dir: Optional[str] = None): default="cuda", choices=["cuda"], help='device type for vLLM execution, supporting CUDA only currently.') + parser.add_argument('--flash-style', + action='store_true', + help='enable flash attention') + parser.add_argument('--block-size', + type=int, + default=16, + help='block size of key/value cache') + parser.add_argument('--use-sample', + action='store_true', + help='use sample input instead of dummy input') + parser.add_argument('--verbose', + action='store_true', + help='print generated text') args = parser.parse_args() main(args) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 21057984e1e3..7c44deafea6a 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -278,17 +278,12 @@ __global__ void reshape_and_cache_flash_kernel( scalar_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, head_size] scalar_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, head_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int64_t* __restrict__ num_tokens, // [1] const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size) { - const int64_t num_tokens_ = num_tokens[0]; const int64_t token_idx = blockIdx.x; - if (token_idx >= num_tokens_) { - return; - } const int64_t slot_idx = slot_mapping[token_idx]; const int64_t block_idx = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; @@ -323,8 +318,7 @@ void reshape_and_cache_flash( torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping, // [num_tokens] - torch::Tensor& num_tokens) // [1] + torch::Tensor& slot_mapping) // [num_tokens] { int num_tokens_padded = key.size(0); int num_heads = key.size(1); @@ -347,7 +341,6 @@ void reshape_and_cache_flash( key_cache.data_ptr(), value_cache.data_ptr(), slot_mapping.data_ptr(), - num_tokens.data_ptr(), key_stride, value_stride, num_heads, diff --git a/requirements.txt b/requirements.txt index d4599ec95d94..384993a4e4c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ prometheus_client >= 0.18.0 pynvml == 11.5.0 triton >= 2.1.0 cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. +flash-attn >= 2.5.0 # Required for chunked prefill. 
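For reference, a minimal end-to-end run of the flash-style path added in this series could look like the sketch below. It assumes the flash_style and block_size engine arguments land as wired up in vllm/engine/arg_utils.py further down; the model mirrors the one used in tests/chunked_prefill/test_correctness.py.

    from vllm import LLM, SamplingParams

    # flash_style routes attention through flash-attn's kv-cache kernel; the config
    # check added in vllm/config.py requires block_size >= 256 for this path.
    llm = LLM(model="JackFram/llama-68m", flash_style=True, block_size=256)
    params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(["The capital of France is"], params)
    print(outputs[0].outputs[0].text)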
diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py new file mode 100644 index 000000000000..abd0e1f7199a --- /dev/null +++ b/tests/chunked_prefill/test_correctness.py @@ -0,0 +1,82 @@ +import gc + +from typing import List + +import pytest +import torch + +from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel + +MODELS = [ + "JackFram/llama-68m", +] + +# SANG-TODO Read it from example.txt +TEST_PROMPTS = [ + # pylint: disable=line-too-long + "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", + "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", + "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", + # Different between page attention and flash attention. + # "Describe the basic components of a neural network and how it can be trained.", + "Write a short story about a robot that dreams for the first time.", + "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.", + "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.", + "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'", +] + + +# TODO(sang): Add chunked prefill parameters. +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_models( + vllm_runner, + model: str, + dtype: str, + max_tokens: int, +) -> None: + """ verify the flash attention has the same output + as page attention """ + print("loading page attention models..") + pg_model = vllm_runner(model, dtype=dtype) + expected_outputs = [] + + print("generating tokens...") + expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) + print("generating tokens finished") + + del pg_model + + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + + flash_attn_model = vllm_runner( + model, + dtype=dtype, + enable_cuda_graph=False, + flash_style=True, + ) + flash_attn_output_by_batchs = [] + for i in range(10): + prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] + flash_attn_output_by_batchs.append( + flash_attn_model.generate_greedy(prompts, max_tokens)) + + del flash_attn_model + + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + + for flash_attn_outputs in flash_attn_output_by_batchs: + for i in range(len(flash_attn_outputs)): + fa_output_ids, fa_output_str = flash_attn_outputs[i] + vllm_output_ids, vllm_output_str = expected_outputs[ + i % len(expected_outputs)] + print() + assert fa_output_ids == vllm_output_ids, ( + f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" + f"Test{i}:\nflash ouput: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" + ) diff --git a/tests/conftest.py b/tests/conftest.py index 30a3df89d9f1..b579bd961bb9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -165,6 +165,7 @@ def __init__( dtype: str = "half", disable_log_stats: bool = True, tensor_parallel_size: int = 1, + flash_style: bool = False, **kwargs, ) -> None: self.model = LLM( @@ -175,6 +176,8 @@ def __init__( swap_space=0, disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, + flash_style=flash_style, + block_size=32, **kwargs, ) diff --git a/tests/kernels/test_cache.py 
b/tests/kernels/test_cache.py index 1430b76270c5..4b7bb67b6ee6 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -294,7 +294,7 @@ def pad_key_value(key: torch.Tensor, value: torch.Tensor, padded_key, padded_value = pad_key_value(key, value, padding) # Call the reshape_and_cache kernel. cache_ops.reshape_and_cache_flash(padded_key, padded_value, key_cache, - value_cache, slot_mapping, num_tokens) + value_cache, slot_mapping) # Run the reference implementation. block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index 82c4350489d5..a9c43e31edc0 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -221,7 +221,6 @@ def test_flash_paged_attention( scale, padded_block_table, padded_context_lens, - block_size, alibi_slopes, ) diff --git a/vllm/config.py b/vllm/config.py index bd0dc89b585f..697801b439bc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -60,6 +60,7 @@ class ModelConfig: max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. + flash_style: Enable flash style page attention. """ def __init__( @@ -79,6 +80,7 @@ def __init__( quantization: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + flash_style: bool = False, ) -> None: self.model = model self.tokenizer = tokenizer @@ -93,6 +95,7 @@ def __init__( self.quantization = quantization self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture + self.flash_style = flash_style if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, @@ -295,12 +298,14 @@ def __init__( swap_space: int, cache_dtype: str, sliding_window: Optional[int] = None, + flash_style: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB self.cache_dtype = cache_dtype self.sliding_window = sliding_window + self.flash_style = flash_style self._verify_args() self._verify_cache_dtype() @@ -314,6 +319,15 @@ def _verify_args(self) -> None: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + if self.flash_style: + logger.info("Flash attention enabled.") + if self.block_size < 256: + # Flash style attention only supports block size >=256 for now. + # https://github.com/Dao-AILab/flash-attention/pull/824 will fix it. + raise ValueError( + "Flash style attention only supports block size >= 256. Got" + f"{self.block_size }") + def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a4efd171b871..e064f357e069 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -45,6 +45,7 @@ class EngineArgs: lora_dtype = 'auto' max_cpu_loras: Optional[int] = None device: str = 'cuda' + flash_style: bool = False def __post_init__(self): if self.tokenizer is None: @@ -271,6 +272,9 @@ def add_cli_args( choices=["cuda"], help=('Device type for vLLM execution. 
' 'Currently, only CUDA-compatible devices are supported.')) + parser.add_argument('--flash-style', + action='store_true', + help='use flash attention.') return parser @classmethod @@ -291,11 +295,12 @@ def create_engine_configs( self.trust_remote_code, self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, - self.enforce_eager, self.max_context_len_to_capture) + self.enforce_eager, self.max_context_len_to_capture, self.flash_style) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, - model_config.get_sliding_window()) + model_config.get_sliding_window(), + self.flash_style) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index f0a88ac8e27f..1beccd380355 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -13,6 +13,10 @@ class InputMetadata: context_lens: the length of attention context for each sequence. block_tables: The block tables. (Seq id -> list of physical block) kv_cache_dtype: Data type to store kv cache. + num_prompt_tokens: The number of tokens in the prompts. This might + include padding. + num_generation_tokens: The number of tokens in the generation sequences. + This might include padding. """ def __init__( @@ -27,6 +31,9 @@ def __init__( block_tables: Optional[torch.Tensor], use_cuda_graph: bool, kv_cache_dtype: str, + # SANG-TODO + # num_prompt_tokens: int, + # num_generation_tokens: int, ) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens @@ -43,6 +50,45 @@ def __init__( # FIXME(woosuk): This is a hack. self.attn_bias = None + # SANG-TODO + # # Prompt related metadata + # # This value might include padding if CudaGraph is enabled. + # self.num_prompts = len(prompt_lens) + # # This value is the source of truth. + # self.num_prompts_tensor = torch.cuda.IntTensor([self.num_prompts]) + # # This value might include padding if CudaGraph is enabled. + # self.num_prompt_tokens = num_prompt_tokens + # self.prompt_lens_tensor = torch.cuda.IntTensor(self.prompt_lens) + # self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 + + # # Cumulative prompt lengths for each prompt in the input + # # tensor. + # self.cum_prompt_query_lens = torch.zeros( + # self.num_prompts + 1, + # device=self.prompt_lens_tensor.device, + # dtype=torch.int32) + # # Cumulative context lengths. + # self.cum_prompt_context_lens = torch.zeros( + # self.num_prompts + 1, + # device=self.prompt_lens_tensor.device, + # dtype=torch.int32) + + # torch.cumsum(self.prompt_lens_tensor, + # dim=0, + # dtype=self.cum_prompt_query_lens.dtype, + # out=self.cum_prompt_query_lens[1:]) + + # # TODO: this will be different once we support chunked prefills. + # self.cum_prompt_context_lens = self.cum_prompt_query_lens + # self.max_context_len = max(self.max_context_len, self.max_prompt_len) + + # # Generation related metadata + # # This value might include padding if CudaGraph is enabled. + # self.num_generation_tokens = num_generation_tokens + # # This is the source of truth for the number of generation tokens. 
+ # self.num_generation_tokens_tensor = torch.cuda.IntTensor( + # [num_generation_tokens]) + def __repr__(self) -> str: return ("InputMetadata(" f"is_prompt={self.is_prompt}, " diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 31ead4841139..3c0a7b54ab2b 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -14,6 +14,8 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( context_attention_fwd) from vllm.utils import is_hip +from vllm.model_executor.layers.attention import ( + flash_attn_with_kvcache_paged, ) # TODO(sang): Support varlen API. try: @@ -47,6 +49,7 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + flash_style: bool = False, ) -> None: super().__init__() self.num_heads = num_heads @@ -57,6 +60,7 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) + self.flash_style = flash_style assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -133,14 +137,23 @@ def forward( # vectors will not be cached. This happens during the initial memory # profiling run. if key_cache is not None and value_cache is not None: - cache_ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - input_metadata.slot_mapping.flatten(), - input_metadata.kv_cache_dtype, - ) + if self.flash_style: + cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping.flatten() + ) + else: + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping.flatten(), + input_metadata.kv_cache_dtype, + ) if input_metadata.is_prompt: # normal attention @@ -215,32 +228,56 @@ def forward( else: # prefix-enabled attention output = torch.empty_like(query) - context_attention_fwd( + if self.flash_attention: + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + output = flash_attn_with_kvcache_paged( + query + key_cache: torch.Tensor, + value_cache: torch.Tensor, + self.scale + input_metadata.block_tables, + input_metadata.context_lens, + self.alibi_slopes, + ) + else: + context_attention_fwd( + query, + key, + value, + output, + key_cache, + value_cache, + input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.start_loc, + input_metadata.prompt_lens, + input_metadata.context_lens, + input_metadata.max_seq_len, + getattr(self, "alibi_slopes", None), + ) + + else: + # Decoding run. + if self.flash_style: + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + output = flash_attn_with_kvcache_paged( query, - key, - value, - output, key_cache, value_cache, - input_metadata.block_tables, # [BS, max_block_per_request] - input_metadata.start_loc, - input_metadata.prompt_lens, + self.scale, + input_metadata.block_tables, input_metadata.context_lens, - input_metadata.max_seq_len, - getattr(self, "alibi_slopes", None), + self.alibi_slopes + ) + else: + output = _paged_attention( + query, + key_cache, + value_cache, + input_metadata, + self.num_kv_heads, + self.scale, + self.alibi_slopes, ) - - else: - # Decoding run. - output = _paged_attention( - query, - key_cache, - value_cache, - input_metadata, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - ) # Reshape the output tensor. 
return output.view(batch_size, seq_len, hidden_size) @@ -305,6 +342,7 @@ def _paged_attention( # For context len > 8192, use V2 kernel to avoid shared memory shortage. use_v1 = input_metadata.max_context_len <= 8192 and ( max_num_partitions == 1 or num_seqs * num_heads > 512) + if use_v1: # Run PagedAttention V1. ops.paged_attention_v1( @@ -356,14 +394,12 @@ def _paged_attention( def flash_attn_with_kvcache_paged( - output: Optional[torch.Tensor], query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, scale: float, block_tables: torch.Tensor, context_lens: torch.Tensor, - block_size: int, alibi_slopes: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Similar to vLLM's page attention, calculates a single token's attention @@ -403,7 +439,4 @@ def flash_attn_with_kvcache_paged( alibi_slopes=alibi_slopes, num_splits=0, ) - if output is not None: - # in case that output is padded, only copy the valid part. - output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) return out.view(num_tokens, num_heads, head_size) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b7f6b8f3ec37..9fd4425bca02 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -93,6 +93,7 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, bias: bool = False, sliding_window: Optional[int] = None, + flash_style: bool = False, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -143,7 +144,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + sliding_window=sliding_window, + flash_style=flash_style) def forward( self, @@ -151,6 +153,7 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, + flash_style: bool = False, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -167,6 +170,7 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + flash_style: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -186,6 +190,7 @@ def __init__( linear_method=linear_method, bias=getattr(config, "bias", False), sliding_window=sliding_window, + flash_style=flash_style, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -234,6 +239,7 @@ def __init__( config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, lora_config: Optional[LoRAConfig] = None, + flash_style: bool = False, ) -> None: super().__init__() self.config = config @@ -248,7 +254,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method) + LlamaDecoderLayer(config, linear_method, flash_style) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -308,11 +314,12 @@ def __init__( config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, lora_config: Optional[LoRAConfig] = None, + flash_style: bool = False, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = LlamaModel(config, linear_method, lora_config=lora_config) + self.model = LlamaModel(config, linear_method, lora_config=lora_config, flash_style=flash_style) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/worker/cache_engine.py 
b/vllm/worker/cache_engine.py index bbe33989fc2a..80d7b68b884b 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -57,19 +57,33 @@ def __init__( def get_key_block_shape(self) -> Tuple[int, int, int, int]: element_size = torch.tensor([], dtype=self.dtype).element_size() x = 16 // element_size - return ( - self.num_heads, - self.head_size // x, - self.block_size, - x, - ) + if self.cache_config.flash_style: + return ( + self.block_size, + self.num_heads, + self.head_size, + ) + else: + return ( + self.num_heads, + self.head_size // x, + self.block_size, + x, + ) def get_value_block_shape(self) -> Tuple[int, int, int]: - return ( - self.num_heads, - self.head_size, - self.block_size, - ) + if self.cache_config.flash_style: + return ( + self.block_size, + self.num_heads, + self.head_size, + ) + else: + return ( + self.num_heads, + self.head_size, + self.block_size, + ) def allocate_gpu_cache(self) -> List[KVCache]: gpu_cache: List[KVCache] = [] diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b99a409e02d1..467ad28226b4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -475,6 +475,7 @@ def prepare_input_tensors( # NOTE: We assume that all sequences in the group are all prompts or # all decodes. is_prompt = seq_group_metadata_list[0].is_prompt + # SANG-TODO set num prompt tokens and generations? # Prepare input tensors. if is_prompt: (input_tokens, input_positions, input_metadata, prompt_lens, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9df518d155ec..870c7b7edd40 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -21,6 +21,8 @@ from vllm.lora.request import LoRARequest from vllm.utils import is_hip +MAX_INT_32 = 2**31 - 1 + class Worker: """A worker class that executes (a partition of) the model on a GPU. @@ -143,6 +145,15 @@ def profile_num_available_blocks( self.model_runner.remove_all_loras() gc.collect() torch.cuda.empty_cache() + + # Flash attention only allows max(int32) slots. + # TODO: Remove this once we support int64 in flash-attention. + if self.model_config.flash_style: + num_gpu_blocks = min(num_gpu_blocks, + MAX_INT_32 // cache_block_size) + num_cpu_blocks = min(num_cpu_blocks, + MAX_INT_32 // cache_block_size) + return num_gpu_blocks, num_cpu_blocks def init_cache_engine(self, cache_config: CacheConfig) -> None: From 963db44683aabeb8ac78774b8618b874403cc7cb Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 00:46:13 -0800 Subject: [PATCH 05/88] flash attn enabled. 
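For orientation: the flash-style KV cache introduced in this series uses the same
[num_blocks, block_size, num_heads, head_size] layout for both the key and the value
cache (see the cache_engine.py change above), and reshape_and_cache_flash scatters each
token's key/value into the block addressed by its slot. Below is a rough pure-PyTorch
reference sketch of that scatter, mirroring the reference check in
tests/kernels/test_cache.py and the slot_idx < 0 padding skip added to the CUDA kernel
in this commit. The tensor sizes and the function name are illustrative only, not taken
from the diff.

    import torch

    num_blocks, block_size, num_heads, head_size = 128, 256, 8, 64
    # Flash-style layout: key and value caches share the same shape,
    # [num_blocks, block_size, num_heads, head_size].
    key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
    value_cache = torch.zeros_like(key_cache)

    def reference_reshape_and_cache_flash(key, value, key_cache, value_cache,
                                          slot_mapping):
        # key/value: [num_tokens, num_heads, head_size]
        # slot_mapping: [num_tokens], one flat slot index per token
        block_size = key_cache.shape[1]
        for i, slot in enumerate(slot_mapping.tolist()):
            if slot < 0:
                # Negative slots mark padding tokens; the CUDA kernel
                # skips them the same way.
                continue
            block_idx, block_offset = slot // block_size, slot % block_size
            key_cache[block_idx, block_offset] = key[i]
            value_cache[block_idx, block_offset] = value[i]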
--- benchmarks/benchmark_latency.py | 17 ++++--- csrc/cache.h | 3 +- csrc/cache_kernels.cu | 7 ++- tests/chunked_prefill/test_correctness.py | 14 +++--- tests/conftest.py | 3 +- tests/kernels/test_cache.py | 22 ++------ vllm/engine/arg_utils.py | 3 +- vllm/model_executor/input_metadata.py | 28 ++++------- vllm/model_executor/layers/attention.py | 61 +++++++++-------------- vllm/model_executor/model_loader.py | 2 +- vllm/model_executor/models/__init__.py | 22 +++++++- vllm/model_executor/models/llama.py | 13 ++--- vllm/worker/model_runner.py | 8 +-- 13 files changed, 93 insertions(+), 110 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index d8083a826911..b915f913e7d1 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -65,23 +65,24 @@ def run_to_completion(profile_dir: Optional[str] = None): else: start_time = time.perf_counter() if args.use_sample: - batch = ( - SAMPLE_PROMPTS * - (args.batch_size // len(SAMPLE_PROMPTS) + 1))[:args.batch_size] + batch = (SAMPLE_PROMPTS * + (args.batch_size // len(SAMPLE_PROMPTS) + + 1))[:args.batch_size] outputs = llm.generate(prompts=batch, - sampling_params=sampling_params, - use_tqdm=False) + sampling_params=sampling_params, + use_tqdm=False) else: outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids, - sampling_params=sampling_params, - use_tqdm=False) + sampling_params=sampling_params, + use_tqdm=False) end_time = time.perf_counter() if args.verbose: for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print( - f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + f"Prompt: {prompt!r}, Generated text: {generated_text!r}" + ) latency = end_time - start_time return latency diff --git a/csrc/cache.h b/csrc/cache.h index ce3f96d59e83..1bca7e4e39a9 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -28,8 +28,7 @@ void reshape_and_cache_flash( torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - torch::Tensor& num_tokens); + torch::Tensor& slot_mapping); // Just for unittest void convert_fp8_e5m2( diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 7c44deafea6a..06da455117b0 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -174,6 +174,7 @@ __global__ void reshape_and_cache_kernel( const int64_t block_idx = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size; for (int i = threadIdx.x; i < n; i += blockDim.x) { const int64_t src_key_idx = token_idx * key_stride + i; @@ -285,6 +286,10 @@ __global__ void reshape_and_cache_flash_kernel( const int block_size) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. 
+ return; + } const int64_t block_idx = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; @@ -318,7 +323,7 @@ void reshape_and_cache_flash( torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping) // [num_tokens] + torch::Tensor& slot_mapping) { int num_tokens_padded = key.size(0); int num_heads = key.size(1); diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index abd0e1f7199a..1cabbbb9e3e1 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -11,7 +11,6 @@ "JackFram/llama-68m", ] -# SANG-TODO Read it from example.txt TEST_PROMPTS = [ # pylint: disable=line-too-long "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", @@ -30,11 +29,13 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("block_size", [256]) def test_models( vllm_runner, model: str, dtype: str, max_tokens: int, + block_size: int, ) -> None: """ verify the flash attention has the same output as page attention """ @@ -52,12 +53,10 @@ def test_models( gc.collect() torch.cuda.empty_cache() - flash_attn_model = vllm_runner( - model, - dtype=dtype, - enable_cuda_graph=False, - flash_style=True, - ) + flash_attn_model = vllm_runner(model, + dtype=dtype, + flash_style=True, + block_size=block_size) flash_attn_output_by_batchs = [] for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] @@ -75,7 +74,6 @@ def test_models( fa_output_ids, fa_output_str = flash_attn_outputs[i] vllm_output_ids, vllm_output_str = expected_outputs[ i % len(expected_outputs)] - print() assert fa_output_ids == vllm_output_ids, ( f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" f"Test{i}:\nflash ouput: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" diff --git a/tests/conftest.py b/tests/conftest.py index b579bd961bb9..4ccac964b999 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -165,6 +165,7 @@ def __init__( dtype: str = "half", disable_log_stats: bool = True, tensor_parallel_size: int = 1, + block_size: int = 32, flash_style: bool = False, **kwargs, ) -> None: @@ -177,7 +178,7 @@ def __init__( disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, flash_style=flash_style, - block_size=32, + block_size=block_size, **kwargs, ) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 4b7bb67b6ee6..23adc10df7cf 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -26,7 +26,6 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] -PADDINGS = [8, 16, 0] @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @@ -235,7 +234,6 @@ def test_swap_blocks( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("padding", PADDINGS) @torch.inference_mode() def test_reshape_and_cache_flash( kv_cache_factory, @@ -246,7 +244,6 @@ def test_reshape_and_cache_flash( num_blocks: int, dtype: torch.dtype, seed: int, - padding: int, ) -> None: random.seed(seed) torch.random.manual_seed(seed) @@ -274,27 +271,16 @@ def 
test_reshape_and_cache_flash( dtype, seed, flash_style=True) - assert len(key_caches) == 1 and len(value_caches) == 0 + assert len(key_caches) == 1 and len(value_caches) == 1 key_cache, value_cache = key_caches[0], value_caches[0] # Clone the KV caches. cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() - num_tokens = torch.tensor([num_tokens], dtype=torch.long, device="cuda") - - def pad_key_value(key: torch.Tensor, value: torch.Tensor, - pad_size: int) -> Tuple[torch.Tensor, torch.Tensor]: - if pad_size == 0: - return key, value - return F.pad(key, (0, 0, 0, 0, 0, pad_size)),\ - F.pad(value, (0, 0, 0, 0, 0, pad_size)) - - # kv shapes: (num_blocks, block_size, num_heads, head_size) - # pad tokens. - padded_key, padded_value = pad_key_value(key, value, padding) + # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache_flash(padded_key, padded_value, key_cache, - value_cache, slot_mapping) + cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping) # Run the reference implementation. block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e064f357e069..00c9f688e2a2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -295,7 +295,8 @@ def create_engine_configs( self.trust_remote_code, self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, - self.enforce_eager, self.max_context_len_to_capture, self.flash_style) + self.enforce_eager, self.max_context_len_to_capture, + self.flash_style) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 1beccd380355..97c23cdda069 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -9,6 +9,7 @@ class InputMetadata: Args: prompt_lens: Lengths of prompts. slot_mapping: The address to write the new KV to of each token. + index: token_id, value: address within kv_cache. max_context_len: The maximum context length. context_lens: the length of attention context for each sequence. block_tables: The block tables. (Seq id -> list of physical block) @@ -19,22 +20,13 @@ class InputMetadata: This might include padding. """ - def __init__( - self, - is_prompt: bool, - slot_mapping: torch.Tensor, - prompt_lens: Optional[torch.Tensor], - max_seq_len: Optional[int], - start_loc: Optional[torch.Tensor], - max_context_len: Optional[int], - context_lens: Optional[torch.Tensor], - block_tables: Optional[torch.Tensor], - use_cuda_graph: bool, - kv_cache_dtype: str, - # SANG-TODO - # num_prompt_tokens: int, - # num_generation_tokens: int, - ) -> None: + def __init__(self, is_prompt: bool, slot_mapping: torch.Tensor, + prompt_lens: Optional[torch.Tensor], + max_seq_len: Optional[int], start_loc: Optional[torch.Tensor], + max_context_len: Optional[int], + context_lens: Optional[torch.Tensor], + block_tables: Optional[torch.Tensor], use_cuda_graph: bool, + kv_cache_dtype: str, flash_style: bool) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens self.max_seq_len = max_seq_len @@ -45,6 +37,7 @@ def __init__( self.block_tables = block_tables self.use_cuda_graph = use_cuda_graph self.kv_cache_dtype = kv_cache_dtype + self.flash_style = flash_style # Set during the execution of the first attention op. 
# FIXME(woosuk): This is a hack. @@ -97,4 +90,5 @@ def __repr__(self) -> str: f"context_lens={self.context_lens}, " f"block_tables={self.block_tables}, " f"use_cuda_graph={self.use_cuda_graph}, " - f"kv_cache_dtype={self.kv_cache_dtype})") + f"kv_cache_dtype={self.kv_cache_dtype}), " + f"flash_style={self.flash_style}") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 3c0a7b54ab2b..51c47368e831 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -14,10 +14,7 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( context_attention_fwd) from vllm.utils import is_hip -from vllm.model_executor.layers.attention import ( - flash_attn_with_kvcache_paged, ) -# TODO(sang): Support varlen API. try: from flash_attn import flash_attn_with_kvcache except ImportError: @@ -49,7 +46,6 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, - flash_style: bool = False, ) -> None: super().__init__() self.num_heads = num_heads @@ -60,7 +56,6 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) - self.flash_style = flash_style assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -137,14 +132,10 @@ def forward( # vectors will not be cached. This happens during the initial memory # profiling run. if key_cache is not None and value_cache is not None: - if self.flash_style: + if input_metadata.flash_style: cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - input_metadata.slot_mapping.flatten() - ) + key, value, key_cache, value_cache, + input_metadata.slot_mapping.flatten()) else: cache_ops.reshape_and_cache( key, @@ -226,20 +217,20 @@ def forward( ) output = out.view_as(query) else: - # prefix-enabled attention - output = torch.empty_like(query) - if self.flash_attention: - query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + if input_metadata.flash_style: output = flash_attn_with_kvcache_paged( - query - key_cache: torch.Tensor, - value_cache: torch.Tensor, - self.scale + query.view(batch_size, seq_len, self.num_heads, + self.head_size), + key_cache, + value_cache, + self.scale, input_metadata.block_tables, input_metadata.context_lens, self.alibi_slopes, ) else: + # prefix-enabled attention + output = torch.empty_like(query) context_attention_fwd( query, key, @@ -247,7 +238,8 @@ def forward( output, key_cache, value_cache, - input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata. + block_tables, # [BS, max_block_per_request] input_metadata.start_loc, input_metadata.prompt_lens, input_metadata.context_lens, @@ -257,17 +249,12 @@ def forward( else: # Decoding run. - if self.flash_style: - query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + if input_metadata.flash_style: output = flash_attn_with_kvcache_paged( - query, - key_cache, - value_cache, - self.scale, - input_metadata.block_tables, - input_metadata.context_lens, - self.alibi_slopes - ) + query.view(batch_size, seq_len, self.num_heads, + self.head_size), key_cache, value_cache, + self.scale, input_metadata.block_tables, + input_metadata.context_lens, self.alibi_slopes) else: output = _paged_attention( query, @@ -407,19 +394,17 @@ def flash_attn_with_kvcache_paged( style key-value caches. 
Arguments: - output: [num_padded_tokens, num_heads, head_size], output tensor - to write. if None an new output tensor will be created. See https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py for other arguments. Returns: - output: [num_padded_tokens, num_heads, head_size] + output: [num_tokens, num_heads, head_size] """ block_size = value_cache.shape[1] assert block_size % 256 == 0, "only support block_size divisible by 256." - num_tokens, num_heads, head_size = query.shape + _, _, num_heads, head_size = query.shape out = flash_attn_with_kvcache( - query.view(num_tokens, 1, num_heads, head_size), + query, key_cache, value_cache, # Inplace update is slow. We don't use it. @@ -439,4 +424,6 @@ def flash_attn_with_kvcache_paged( alibi_slopes=alibi_slopes, num_splits=0, ) - return out.view(num_tokens, num_heads, head_size) + + # num_tokens == batch_size * seqlen + return out.view(-1, num_heads, head_size) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index ebe092b5d62b..40a0a9cf394b 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -29,7 +29,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: architectures = ["QuantMixtralForCausalLM"] for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) + model_cls = ModelRegistry.load_model_cls(arch, model_config) if model_cls is not None: return model_cls raise ValueError( diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 66d28207d664..5ed1cb5dbc5c 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -5,6 +5,7 @@ from vllm.logger import init_logger from vllm.utils import is_hip +from vllm.config import ModelConfig logger = init_logger(__name__) @@ -61,11 +62,18 @@ "Sliding window attention is not yet supported in ROCm's flash attention", } +_MODEL_CLASSES_SUPPORT_FLASH_ATTN = { + arch_and_class[1] + for _, arch_and_class in _MODELS.items() + if arch_and_class[1] == "LlamaForCausalLM" +} + class ModelRegistry: @staticmethod - def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + def load_model_cls(model_arch: str, + model_config: ModelConfig) -> Optional[Type[nn.Module]]: if model_arch not in _MODELS: return None if is_hip(): @@ -81,7 +89,17 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: module_name, model_cls_name = _MODELS[model_arch] module = importlib.import_module( f"vllm.model_executor.models.{module_name}") - return getattr(module, model_cls_name, None) + model_cls = getattr(module, model_cls_name, None) + + if (model_config.flash_style + and model_cls_name not in _MODEL_CLASSES_SUPPORT_FLASH_ATTN): + raise ValueError( + f"{model_config.model} doesn't support " + "flash attention in vLLM, but " + "flash_style=True is given. 
Choose one of models, " + f"{_MODEL_CLASSES_SUPPORT_FLASH_ATTN} to use " + "flash_style=True.") + return model_cls @staticmethod def get_supported_archs() -> List[str]: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 9fd4425bca02..b7f6b8f3ec37 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -93,7 +93,6 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, bias: bool = False, sliding_window: Optional[int] = None, - flash_style: bool = False, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -144,8 +143,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window, - flash_style=flash_style) + sliding_window=sliding_window) def forward( self, @@ -153,7 +151,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - flash_style: bool = False, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -170,7 +167,6 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, - flash_style: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -190,7 +186,6 @@ def __init__( linear_method=linear_method, bias=getattr(config, "bias", False), sliding_window=sliding_window, - flash_style=flash_style, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -239,7 +234,6 @@ def __init__( config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, lora_config: Optional[LoRAConfig] = None, - flash_style: bool = False, ) -> None: super().__init__() self.config = config @@ -254,7 +248,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method, flash_style) + LlamaDecoderLayer(config, linear_method) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -314,12 +308,11 @@ def __init__( config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, lora_config: Optional[LoRAConfig] = None, - flash_style: bool = False, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = LlamaModel(config, linear_method, lora_config=lora_config, flash_style=flash_style) + self.model = LlamaModel(config, linear_method, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 467ad28226b4..84470d0fa114 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -249,7 +249,7 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, - ) + flash_style=self.model_config.flash_style) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -377,7 +377,7 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, - ) + flash_style=self.model_config.flash_style) return (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -539,7 +539,7 @@ def prepare_input_tensors( block_tables=metadata_dict["block_tables"], use_cuda_graph=metadata_dict["use_cuda_graph"], 
kv_cache_dtype=metadata_dict["kv_cache_dtype"], - ) + flash_style=self.model_config.flash_style) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -723,7 +723,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: block_tables=block_tables[:batch_size], use_cuda_graph=True, kv_cache_dtype=self.kv_cache_dtype, - ) + flash_style=self.model_config.flash_style) if self.lora_config: lora_mapping = LoRAMapping( From 2c1bb6c9e6887c0e8efc1cada38c638bc4004529 Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 00:51:14 -0800 Subject: [PATCH 06/88] support every model --- tests/chunked_prefill/test_correctness.py | 2 +- vllm/model_executor/models/__init__.py | 15 --------------- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index b79156d2d48a..d63ac9f5429e 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -8,7 +8,7 @@ from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel MODELS = [ - "JackFram/llama-68m", + # "JackFram/llama-68m", "facebook/opt-125m", ] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index b624395c27b5..ea150ef9c6f5 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -62,13 +62,6 @@ "Sliding window attention is not yet supported in ROCm's flash attention", } -# TODO(sang): We may not need this. I think it should work with everything. -_MODEL_CLASSES_SUPPORT_FLASH_ATTN = { - arch_and_class[1] - for _, arch_and_class in _MODELS.items() - if arch_and_class[1] == "LlamaForCausalLM" -} - # Models not supported by Neuron. _NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"} @@ -102,14 +95,6 @@ def load_model_cls(model_arch: str, f"vllm.model_executor.models.{module_name}") model_cls = getattr(module, model_cls_name, None) - if (model_config.flash_style - and model_cls_name not in _MODEL_CLASSES_SUPPORT_FLASH_ATTN): - raise ValueError( - f"{model_config.model} doesn't support " - "flash attention in vLLM, but " - "flash_style=True is given. Choose one of models, " - f"{_MODEL_CLASSES_SUPPORT_FLASH_ATTN} to use " - "flash_style=True.") return model_cls @staticmethod From 2bb5e624a4ed843613b7439ef418eb21dac2fa50 Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 01:28:13 -0800 Subject: [PATCH 07/88] Fixed broken tests. 
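With the per-model whitelist removed in the previous commit, flash-style attention is
selected purely through engine arguments. A minimal usage sketch, assuming the keyword
plumbing added earlier in this series (flash_style and block_size are forwarded from
the LLM entrypoint into ModelConfig and CacheConfig, and the CacheConfig check requires
block_size >= 256); the model name is just the small test model used in
tests/chunked_prefill/test_correctness.py:

    from vllm import LLM, SamplingParams

    # flash_style/block_size mirror the EngineArgs fields added in this series;
    # block_size must be >= 256 for the flash-style kernels.
    llm = LLM(model="JackFram/llama-68m",
              flash_style=True,
              block_size=256)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=16))

On the command line, the switch added to the argument parser is --flash-style,
used together with a sufficiently large --block-size.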
--- tests/chunked_prefill/test_correctness.py | 2 +- tests/worker/test_model_runner.py | 2 + vllm/worker/model_runner.py | 53 ++++++++++++----------- 3 files changed, 30 insertions(+), 27 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index d63ac9f5429e..6659b83bfcb6 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -77,5 +77,5 @@ def test_models( i % len(expected_outputs)] assert fa_output_ids == vllm_output_ids, ( f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" - f"Test{i}:\nflash ouput: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" + f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" ) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index f44895a728c7..94c9b45157ec 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -3,6 +3,8 @@ from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + DeviceConfig) def test_prepare_prompt(): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fa337a68f87d..e95b147cb25d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -55,6 +55,9 @@ def __init__( # FIXME(woosuk): This is a hack to make the tests work. Refactor this. self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None) + self.flash_style = (self.model_config.flash_style + if model_config is not None else False) + self.device_config = (device_config if device_config is not None else DeviceConfig()) self.device = self.device_config.device @@ -245,18 +248,17 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - input_metadata = InputMetadata( - is_prompt=True, - slot_mapping=slot_mapping, - prompt_lens=prompt_lens_tensor, - max_seq_len=max_prompt_len, - start_loc=start_loc_tensor, - max_context_len=None, - context_lens=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.model_config.flash_style) + input_metadata = InputMetadata(is_prompt=True, + slot_mapping=slot_mapping, + prompt_lens=prompt_lens_tensor, + max_seq_len=max_prompt_len, + start_loc=start_loc_tensor, + max_context_len=None, + context_lens=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + kv_cache_dtype=self.kv_cache_dtype, + flash_style=self.flash_style) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -373,18 +375,17 @@ def _prepare_decode( _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping ] - input_metadata = InputMetadata( - is_prompt=False, - slot_mapping=slot_mapping, - prompt_lens=None, - max_seq_len=None, - start_loc=None, - max_context_len=max_context_len, - context_lens=context_lens, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.model_config.flash_style) + input_metadata = InputMetadata(is_prompt=False, + slot_mapping=slot_mapping, + prompt_lens=None, + max_seq_len=None, + start_loc=None, + max_context_len=max_context_len, + context_lens=context_lens, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + kv_cache_dtype=self.kv_cache_dtype, + 
flash_style=self.flash_style) return (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -547,7 +548,7 @@ def prepare_input_tensors( block_tables=metadata_dict["block_tables"], use_cuda_graph=metadata_dict["use_cuda_graph"], kv_cache_dtype=metadata_dict["kv_cache_dtype"], - flash_style=self.model_config.flash_style) + flash_style=self.flash_style) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -731,7 +732,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: block_tables=block_tables[:batch_size], use_cuda_graph=True, kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.model_config.flash_style) + flash_style=self.flash_style) if self.lora_config: lora_mapping = LoRAMapping( From 4d6a05f405ea553946753e544ab9087cd4fb34f0 Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 01:39:18 -0800 Subject: [PATCH 08/88] [2/n] scheduler changes --- benchmarks/benchmark_latency.py | 5 + tests/chunked_prefill/test_correctness.py | 24 +- tests/conftest.py | 6 + tests/core/test_scheduler.py | 446 ++++++++++++++++++++++ 4 files changed, 477 insertions(+), 4 deletions(-) create mode 100644 tests/core/test_scheduler.py diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index b915f913e7d1..51c792d3756a 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -33,6 +33,9 @@ def main(args: argparse.Namespace): enforce_eager=args.enforce_eager, kv_cache_dtype=args.kv_cache_dtype, device=args.device, + flash_style=args.flash_style, + max_chunked_prefill_len=args.max_chunked_prefill_len, + max_num_prompt_seqs=args.max_num_prompt_seqs, ) sampling_params = SamplingParams( @@ -180,5 +183,7 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument('--verbose', action='store_true', help='print generated text') + parser.add_argument('--max-chunked-prefill-len', type=int, default=-1) + parser.add_argument('--max-num-prompt-seqs', type=int, default=1000) args = parser.parse_args() main(args) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index 6659b83bfcb6..0bf0ddbf8a9f 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -31,15 +31,26 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("block_size", [256]) +@pytest.mark.parametrize("max_chunked_prefill_len", [-1, 16, 64]) +@pytest.mark.parametrize("max_num_prompt_seqs", [1, 2, 100]) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) def test_models( vllm_runner, model: str, dtype: str, max_tokens: int, block_size: int, + max_chunked_prefill_len: int, + max_num_prompt_seqs: int, + tensor_parallel_size: int, ) -> None: """ verify the flash attention has the same output as page attention """ + if torch.cuda.device_count() < tensor_parallel_size: + pytest.skip( + f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" + ) + print("loading page attention models..") pg_model = vllm_runner(model, dtype=dtype) expected_outputs = [] @@ -54,10 +65,15 @@ def test_models( gc.collect() torch.cuda.empty_cache() - flash_attn_model = vllm_runner(model, - dtype=dtype, - flash_style=True, - block_size=block_size) + flash_attn_model = vllm_runner( + model, + dtype=dtype, + flash_style=True, + block_size=block_size, + max_chunked_prefill_len=max_chunked_prefill_len, + max_num_prompt_seqs=max_num_prompt_seqs, + 
tensor_parallel_size=tensor_parallel_size) + flash_attn_output_by_batchs = [] for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] diff --git a/tests/conftest.py b/tests/conftest.py index 4ccac964b999..ad499265f7ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -167,6 +167,9 @@ def __init__( tensor_parallel_size: int = 1, block_size: int = 32, flash_style: bool = False, + max_chunked_prefill_len: int = -1, + max_num_prompt_seqs: int = 1000, + max_num_batched_tokens: int = 4096, **kwargs, ) -> None: self.model = LLM( @@ -179,6 +182,9 @@ def __init__( tensor_parallel_size=tensor_parallel_size, flash_style=flash_style, block_size=block_size, + max_chunked_prefill_len=max_chunked_prefill_len, + max_num_prompt_seqs=max_num_prompt_seqs, + max_num_batched_tokens=max_num_batched_tokens, **kwargs, ) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py new file mode 100644 index 000000000000..7b8ebafec4f4 --- /dev/null +++ b/tests/core/test_scheduler.py @@ -0,0 +1,446 @@ +# This library may only be used in the Anyscale Platform. +# Notwithstanding the terms of any license or notice within this container, +# you may not modify, copy or remove this file. +# Your right to use this library is subject to the +# Anyscale Terms of Service (anyscale.com/terms) +# or other written agreement between you and Anyscale. + +# Copyright (2023 and onwards) Anyscale, Inc. +# This Software includes software developed at Anyscale (anyscale.com/) +# and its use is subject to the included LICENSE file. + +import imp +from typing import List +import pytest +import time + +from vllm.config import CacheConfig, SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.sequence import SampleLogprob, SequenceGroup +from tests.utils import round_up_to_next_block + +from .utils import create_dummy_prompt + + +def test_scheduler_add_seq_group(): + block_size = 4 + scheduler_config = SchedulerConfig(100, 64, 1) + cache_config = CacheConfig(block_size, 1.0, 1) + cache_config.num_cpu_blocks = 4 + cache_config.num_gpu_blocks = 4 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq group to scheduler. + num_seq_group = 4 + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), block_size) + scheduler.add_seq_group(seq_group) + assert scheduler.get_num_unfinished_seq_groups() == i + 1 + + +def test_scheduler_abort_seq_group(): + block_size = 4 + scheduler_config = SchedulerConfig(100, 64, 1) + cache_config = CacheConfig(block_size, 1.0, 1) + cache_config.num_cpu_blocks = 4 + cache_config.num_gpu_blocks = 4 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add multiple seq groups to scheduler. + num_seq_group = 4 + request_ids = set() + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), block_size) + scheduler.add_seq_group(seq_group) + request_ids.add(str(i)) + + # Abort all added seq groups. + assert scheduler.get_num_unfinished_seq_groups() == num_seq_group + scheduler.abort_seq_group(request_ids) + assert scheduler.get_num_unfinished_seq_groups() == 0 + + +def test_scheduler_schedule_simple(): + block_size = 4 + num_seq_group = 4 + max_model_len = 16 + scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) + cache_config = CacheConfig(block_size, 1.0, 1) + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. 
+ running: List[SequenceGroup] = [] + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + num_processed_token_ids=block_size - + 1) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Schedule seq groups prompts. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs( + )[0].get_len() + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == num_seq_group + + # Schedule seq groups generation. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + assert out.num_batched_tokens == num_seq_group + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == num_seq_group + + +def test_scheduler_schedule_preempt_abort(): + block_size = 4 + max_model_len = 16 + scheduler_config = SchedulerConfig(64, 2, max_model_len) + cache_config = CacheConfig(block_size, 1.0, 1) + cache_config.num_cpu_blocks = 2 + cache_config.num_gpu_blocks = 2 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. + seq_a, seq_group_a = create_dummy_prompt("1", block_size) + seq_b, seq_group_b = create_dummy_prompt("2", block_size) + scheduler.add_seq_group(seq_group_a) + scheduler.add_seq_group(seq_group_b) + + # Schedule seq groups prompts. + seq_group_meta, out = scheduler.schedule() + assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] + assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2 + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == 2 + assert scheduler.get_num_unfinished_seq_groups() == 2 + + # Append "generated" tokens, allowing the sequence to mark prompt tokens as + # processed. + token_id = 0 + seq_a.append_token_id(token_id, {token_id: SampleLogprob(0.0)}) + seq_b.append_token_id(token_id, {token_id: SampleLogprob(0.0)}) + + # Schedule seq groups generation and preempt seq group b. + seq_group_meta, out = scheduler.schedule() + assert out.scheduled_seq_groups == [seq_group_a] + assert out.num_batched_tokens == 1 + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == 1 + assert scheduler.get_num_unfinished_seq_groups() == 2 + + # Abort seq group a. Re-schedule seq group b prompt with recomputation. + scheduler.abort_seq_group("1") + seq_group_meta, out = scheduler.schedule() + assert out.scheduled_seq_groups == [seq_group_b] + assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len() + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == 1 + assert scheduler.get_num_unfinished_seq_groups() == 1 + + +@pytest.mark.parametrize("block_size", [1, 8, 16, 32]) +@pytest.mark.parametrize("prompt_len", [1, 128]) +@pytest.mark.parametrize("num_unprocessed_tokens", [2, 17, 128]) +@pytest.mark.parametrize("num_seq_group", [1, 4, 16]) +def test_can_schedule_seqs_with_multiple_unprocessed_tokens( + block_size: int, prompt_len: int, num_unprocessed_tokens: int, + num_seq_group: int): + """Verify scheduler can schedule sequences with more than one unprocessed + tokens. This occurs when the worker emits more than one token. 
+ """ + max_model_len = 2048 + scheduler_config = SchedulerConfig(max_num_batched_tokens=max_model_len, + max_num_seqs=num_seq_group, + max_model_len=max_model_len) + cache_config = CacheConfig(block_size=block_size, + gpu_memory_utilization=1.0, + swap_space=0) + cache_config.num_cpu_blocks = 0 + cache_config.num_gpu_blocks = 8192 // block_size + scheduler = Scheduler(scheduler_config, cache_config, None) + + prompt_lens = [prompt_len for _ in range(num_seq_group)] + + token_ids_to_append = [ + list(range(num_unprocessed_tokens)) for _ in range(num_seq_group) + ] + + # Add seq groups to scheduler. + running: List[SequenceGroup] = [] + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(request_id=str(i), + prompt_length=prompt_lens[i], + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Schedule seq groups prompts. + _, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + + # Add tokens to sequences + for seq_group in out.scheduled_seq_groups: + for i, seq in enumerate(seq_group.get_seqs()): + + seq.append_token_ids(token_ids_to_append[i], + logprobs=[{ + token_id: SampleLogprob(logprob=0.0) + for token_id in token_ids_to_append[i] + }]) + + # Schedule seq groups generation. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + + for seq_group_metadata in seq_group_meta: + # Only one seq per group in this test. + seq_id = next(iter(seq_group_metadata.seq_data.keys())) + + block_table = seq_group_metadata.block_tables[seq_id] + blocks_required = (seq_group_metadata.seq_data[seq_id].get_len() - 1 + + block_size) // block_size + assert len(block_table) == blocks_required + + +@pytest.mark.parametrize("block_size", [1, 8, 16, 32]) +@pytest.mark.parametrize("prompt_len", [1, 128]) +@pytest.mark.parametrize("num_unprocessed_tokens", [1, 9]) +@pytest.mark.parametrize("num_preallocated_slots_per_step", [1, 9]) +@pytest.mark.parametrize("num_seq_group", [1, 4]) +def test_can_schedule_multiple_steps(block_size: int, prompt_len: int, + num_preallocated_slots_per_step: int, + num_unprocessed_tokens: int, + num_seq_group: int): + """Verify correct scheduling when the model runs more than one step per + scheduler iteration. + """ + max_model_len = 2048 + scheduler_config = SchedulerConfig( + max_num_batched_tokens=max_model_len, + max_num_seqs=num_seq_group, + max_model_len=max_model_len, + num_preallocated_slots_per_step=num_preallocated_slots_per_step) + cache_config = CacheConfig(block_size=block_size, + gpu_memory_utilization=1.0, + swap_space=0) + cache_config.num_cpu_blocks = 0 + cache_config.num_gpu_blocks = 8192 // block_size + scheduler = Scheduler(scheduler_config, cache_config, None) + + prompt_lens = [prompt_len for _ in range(num_seq_group)] + + token_ids_to_append = [ + list(range(num_unprocessed_tokens)) for _ in range(num_seq_group) + ] + + # Add seq groups to scheduler. + running: List[SequenceGroup] = [] + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(request_id=str(i), + prompt_length=prompt_lens[i], + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Schedule seq groups prompts. 
+ _, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + + # Add tokens to sequences + for seq_group in out.scheduled_seq_groups: + for i, seq in enumerate(seq_group.get_seqs()): + seq.append_token_ids(token_ids_to_append[i], + logprobs=[{ + token_id: SampleLogprob(logprob=0.0) + for token_id in token_ids_to_append[i] + }]) + + # Schedule seq groups generation. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + + for seq_group_metadata in seq_group_meta: + # Only one seq per group in this test. + seq_id = next(iter(seq_group_metadata.seq_data.keys())) + + # The last slot is not required because it is for the last generated + # token, and will be stored in the next iteration. + slots_required = (seq_group_metadata.seq_data[seq_id].get_len() + + num_preallocated_slots_per_step) + blocks_required = round_up_to_next_block(slots_required, block_size) + + block_table = seq_group_metadata.block_tables[seq_id] + assert len(block_table) == blocks_required + + +def test_scheduler_schedule_chunked_prefill(): + block_size = 4 + num_seq_group = 2 + max_model_len = 16 + max_chunked_prefill_len = 2 + max_num_prompt_seqs = 1 + scheduler_config = SchedulerConfig( + 64, + num_seq_group, + max_model_len, + flash_style=True, + max_chunked_prefill_len=max_chunked_prefill_len, + max_num_prompt_seqs=max_num_prompt_seqs) + cache_config = CacheConfig(block_size, 1.0, 1) + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. + seq_groups: List[SequenceGroup] = [] + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + num_processed_token_ids=block_size - + 1) + scheduler.add_seq_group(seq_group) + seq_groups.append(seq_group) + + # Schedule chunk prefill. Only the first seq_group should be scheduled. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(seq_groups[:1]) + seq_groups[0].get_num_unprefilled() == 2 + seq_groups[1].get_num_unprefilled() == 4 + assert out.num_batched_tokens == 2 + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == 1 + assert seq_group_meta[0].request_id == "0" + assert seq_group_meta[0].is_chunked_prefill + assert seq_group_meta[0].is_prompt + + # Schedule chunk prefill. Still Only the first seq_group should be scheduled. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(seq_groups[:1]) + seq_groups[0].get_num_unprefilled() == 0 + seq_groups[1].get_num_unprefilled() == 4 + assert out.num_batched_tokens == 2 + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == 1 + assert seq_group_meta[0].request_id == "0" + assert not seq_group_meta[0].is_chunked_prefill + assert seq_group_meta[0].is_prompt + + # Schedule chunk prefill. This time the second seq_group should be selected + # for chunk prefill, and the first seq_group should be select for decoding. 
+ seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(seq_groups) + seq_groups[0].get_num_unprefilled() == 0 + seq_groups[1].get_num_unprefilled() == 2 + assert out.num_batched_tokens == 3 + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == 2 + assert seq_group_meta[0].request_id == "1" + assert seq_group_meta[0].is_chunked_prefill + assert seq_group_meta[0].is_prompt + assert seq_group_meta[1].request_id == "0" + assert not seq_group_meta[1].is_chunked_prefill + assert not seq_group_meta[1].is_prompt + + +def test_scheduler_max_seqs(): + block_size = 4 + num_seq_group = 4 + max_seq_group = 2 + max_model_len = 16 + scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len) + cache_config = CacheConfig(block_size, 1.0, 1) + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + + all_seq_groups: List[SequenceGroup] = [] + # Add seq groups to scheduler. + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + num_processed_token_ids=block_size - + 1) + all_seq_groups.append(seq_group) + + # Append 1 seq group + scheduler.add_seq_group(all_seq_groups[0]) + + # Schedule seq groups prompts. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]]) + + # Schedule seq groups generation. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]]) + + # Append 2 more seq group + running: List[SequenceGroup] = [] + scheduler.add_seq_group(all_seq_groups[1]) + scheduler.add_seq_group(all_seq_groups[2]) + + # Schedule seq groups prompts. + # Only 1 seq group should be scheduled since max_seq_group is 2 + # and one is prompting. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]]) + + +@pytest.mark.parametrize("min_decodes_per_prefill", + [1, 2, 4, 8, 16, 32, 64, 128]) +def test_scheduler_delayed_prefill_scheduling(min_decodes_per_prefill): + block_size = 4 + num_seq_group = 64 + max_model_len = 16 + num_pattern_repeat = 4 + scheduler_config = SchedulerConfig( + 64, + num_seq_group, + max_model_len, + min_decodes_per_prefill=min_decodes_per_prefill) + cache_config = CacheConfig(block_size, 1.0, 1) + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + num_processed_token_ids=block_size - + 1) + scheduler.add_seq_group(seq_group) + + scheduled_is_prefilling = [] + # Schedule seq groups prompts. + for i in range((min_decodes_per_prefill + 1) * num_pattern_repeat): + seq_group_meta, out = scheduler.schedule() + # num_prompt_groups only come from prefilling_outputs + # use this to determine if a prefilling step is scheduled + scheduled_is_prefilling.append(out.num_prompt_groups > 0) + + # abort most of the seq groups to free slots for the next iteration + # note we leave one in the scheduler to avoid always schedule prefilling + for j in range(len(seq_group_meta) - 1): + scheduler.abort_seq_group(seq_group_meta[j].request_id) + + # the scheduled sequence with delayed prefilling should look like + # [True, False * min_decodes_per_prefill, + # True, False * min_decodes_per_prefill, ...] 
+ expected_pattern = [True] + [False] * min_decodes_per_prefill + assert scheduled_is_prefilling == expected_pattern * num_pattern_repeat,\ + f"Delayed refill scheduling is not correct, expected pattern \ + {expected_pattern} (to be repeated {num_pattern_repeat} times), \ + got {scheduled_is_prefilling}" From 0831f841fc7353eb4901a8920b17dc4807e9848e Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 01:52:00 -0800 Subject: [PATCH 09/88] [2/n] ip --- tests/core/__init__.py | 0 tests/core/test_scheduler.py | 391 ----------------------------------- tests/core/utils.py | 27 +++ 3 files changed, 27 insertions(+), 391 deletions(-) create mode 100644 tests/core/__init__.py create mode 100644 tests/core/utils.py diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 7b8ebafec4f4..c1fd9ba16dd9 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -1,403 +1,12 @@ -# This library may only be used in the Anyscale Platform. -# Notwithstanding the terms of any license or notice within this container, -# you may not modify, copy or remove this file. -# Your right to use this library is subject to the -# Anyscale Terms of Service (anyscale.com/terms) -# or other written agreement between you and Anyscale. - -# Copyright (2023 and onwards) Anyscale, Inc. -# This Software includes software developed at Anyscale (anyscale.com/) -# and its use is subject to the included LICENSE file. - -import imp from typing import List import pytest -import time from vllm.config import CacheConfig, SchedulerConfig from vllm.core.scheduler import Scheduler -from vllm.sequence import SampleLogprob, SequenceGroup -from tests.utils import round_up_to_next_block from .utils import create_dummy_prompt -def test_scheduler_add_seq_group(): - block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1) - cache_config = CacheConfig(block_size, 1.0, 1) - cache_config.num_cpu_blocks = 4 - cache_config.num_gpu_blocks = 4 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add seq group to scheduler. - num_seq_group = 4 - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), block_size) - scheduler.add_seq_group(seq_group) - assert scheduler.get_num_unfinished_seq_groups() == i + 1 - - -def test_scheduler_abort_seq_group(): - block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1) - cache_config = CacheConfig(block_size, 1.0, 1) - cache_config.num_cpu_blocks = 4 - cache_config.num_gpu_blocks = 4 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add multiple seq groups to scheduler. - num_seq_group = 4 - request_ids = set() - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), block_size) - scheduler.add_seq_group(seq_group) - request_ids.add(str(i)) - - # Abort all added seq groups. - assert scheduler.get_num_unfinished_seq_groups() == num_seq_group - scheduler.abort_seq_group(request_ids) - assert scheduler.get_num_unfinished_seq_groups() == 0 - - -def test_scheduler_schedule_simple(): - block_size = 4 - num_seq_group = 4 - max_model_len = 16 - scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) - cache_config = CacheConfig(block_size, 1.0, 1) - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add seq groups to scheduler. 
- running: List[SequenceGroup] = [] - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - num_processed_token_ids=block_size - - 1) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Schedule seq groups prompts. - seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(running) - assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs( - )[0].get_len() - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == num_seq_group - - # Schedule seq groups generation. - seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(running) - assert out.num_batched_tokens == num_seq_group - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == num_seq_group - - -def test_scheduler_schedule_preempt_abort(): - block_size = 4 - max_model_len = 16 - scheduler_config = SchedulerConfig(64, 2, max_model_len) - cache_config = CacheConfig(block_size, 1.0, 1) - cache_config.num_cpu_blocks = 2 - cache_config.num_gpu_blocks = 2 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add seq groups to scheduler. - seq_a, seq_group_a = create_dummy_prompt("1", block_size) - seq_b, seq_group_b = create_dummy_prompt("2", block_size) - scheduler.add_seq_group(seq_group_a) - scheduler.add_seq_group(seq_group_b) - - # Schedule seq groups prompts. - seq_group_meta, out = scheduler.schedule() - assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] - assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2 - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 2 - assert scheduler.get_num_unfinished_seq_groups() == 2 - - # Append "generated" tokens, allowing the sequence to mark prompt tokens as - # processed. - token_id = 0 - seq_a.append_token_id(token_id, {token_id: SampleLogprob(0.0)}) - seq_b.append_token_id(token_id, {token_id: SampleLogprob(0.0)}) - - # Schedule seq groups generation and preempt seq group b. - seq_group_meta, out = scheduler.schedule() - assert out.scheduled_seq_groups == [seq_group_a] - assert out.num_batched_tokens == 1 - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 1 - assert scheduler.get_num_unfinished_seq_groups() == 2 - - # Abort seq group a. Re-schedule seq group b prompt with recomputation. - scheduler.abort_seq_group("1") - seq_group_meta, out = scheduler.schedule() - assert out.scheduled_seq_groups == [seq_group_b] - assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len() - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 1 - assert scheduler.get_num_unfinished_seq_groups() == 1 - - -@pytest.mark.parametrize("block_size", [1, 8, 16, 32]) -@pytest.mark.parametrize("prompt_len", [1, 128]) -@pytest.mark.parametrize("num_unprocessed_tokens", [2, 17, 128]) -@pytest.mark.parametrize("num_seq_group", [1, 4, 16]) -def test_can_schedule_seqs_with_multiple_unprocessed_tokens( - block_size: int, prompt_len: int, num_unprocessed_tokens: int, - num_seq_group: int): - """Verify scheduler can schedule sequences with more than one unprocessed - tokens. This occurs when the worker emits more than one token. 
- """ - max_model_len = 2048 - scheduler_config = SchedulerConfig(max_num_batched_tokens=max_model_len, - max_num_seqs=num_seq_group, - max_model_len=max_model_len) - cache_config = CacheConfig(block_size=block_size, - gpu_memory_utilization=1.0, - swap_space=0) - cache_config.num_cpu_blocks = 0 - cache_config.num_gpu_blocks = 8192 // block_size - scheduler = Scheduler(scheduler_config, cache_config, None) - - prompt_lens = [prompt_len for _ in range(num_seq_group)] - - token_ids_to_append = [ - list(range(num_unprocessed_tokens)) for _ in range(num_seq_group) - ] - - # Add seq groups to scheduler. - running: List[SequenceGroup] = [] - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(request_id=str(i), - prompt_length=prompt_lens[i], - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Schedule seq groups prompts. - _, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(running) - - # Add tokens to sequences - for seq_group in out.scheduled_seq_groups: - for i, seq in enumerate(seq_group.get_seqs()): - - seq.append_token_ids(token_ids_to_append[i], - logprobs=[{ - token_id: SampleLogprob(logprob=0.0) - for token_id in token_ids_to_append[i] - }]) - - # Schedule seq groups generation. - seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(running) - - for seq_group_metadata in seq_group_meta: - # Only one seq per group in this test. - seq_id = next(iter(seq_group_metadata.seq_data.keys())) - - block_table = seq_group_metadata.block_tables[seq_id] - blocks_required = (seq_group_metadata.seq_data[seq_id].get_len() - 1 + - block_size) // block_size - assert len(block_table) == blocks_required - - -@pytest.mark.parametrize("block_size", [1, 8, 16, 32]) -@pytest.mark.parametrize("prompt_len", [1, 128]) -@pytest.mark.parametrize("num_unprocessed_tokens", [1, 9]) -@pytest.mark.parametrize("num_preallocated_slots_per_step", [1, 9]) -@pytest.mark.parametrize("num_seq_group", [1, 4]) -def test_can_schedule_multiple_steps(block_size: int, prompt_len: int, - num_preallocated_slots_per_step: int, - num_unprocessed_tokens: int, - num_seq_group: int): - """Verify correct scheduling when the model runs more than one step per - scheduler iteration. - """ - max_model_len = 2048 - scheduler_config = SchedulerConfig( - max_num_batched_tokens=max_model_len, - max_num_seqs=num_seq_group, - max_model_len=max_model_len, - num_preallocated_slots_per_step=num_preallocated_slots_per_step) - cache_config = CacheConfig(block_size=block_size, - gpu_memory_utilization=1.0, - swap_space=0) - cache_config.num_cpu_blocks = 0 - cache_config.num_gpu_blocks = 8192 // block_size - scheduler = Scheduler(scheduler_config, cache_config, None) - - prompt_lens = [prompt_len for _ in range(num_seq_group)] - - token_ids_to_append = [ - list(range(num_unprocessed_tokens)) for _ in range(num_seq_group) - ] - - # Add seq groups to scheduler. - running: List[SequenceGroup] = [] - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(request_id=str(i), - prompt_length=prompt_lens[i], - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Schedule seq groups prompts. 
- _, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(running) - - # Add tokens to sequences - for seq_group in out.scheduled_seq_groups: - for i, seq in enumerate(seq_group.get_seqs()): - seq.append_token_ids(token_ids_to_append[i], - logprobs=[{ - token_id: SampleLogprob(logprob=0.0) - for token_id in token_ids_to_append[i] - }]) - - # Schedule seq groups generation. - seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(running) - - for seq_group_metadata in seq_group_meta: - # Only one seq per group in this test. - seq_id = next(iter(seq_group_metadata.seq_data.keys())) - - # The last slot is not required because it is for the last generated - # token, and will be stored in the next iteration. - slots_required = (seq_group_metadata.seq_data[seq_id].get_len() + - num_preallocated_slots_per_step) - blocks_required = round_up_to_next_block(slots_required, block_size) - - block_table = seq_group_metadata.block_tables[seq_id] - assert len(block_table) == blocks_required - - -def test_scheduler_schedule_chunked_prefill(): - block_size = 4 - num_seq_group = 2 - max_model_len = 16 - max_chunked_prefill_len = 2 - max_num_prompt_seqs = 1 - scheduler_config = SchedulerConfig( - 64, - num_seq_group, - max_model_len, - flash_style=True, - max_chunked_prefill_len=max_chunked_prefill_len, - max_num_prompt_seqs=max_num_prompt_seqs) - cache_config = CacheConfig(block_size, 1.0, 1) - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add seq groups to scheduler. - seq_groups: List[SequenceGroup] = [] - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - num_processed_token_ids=block_size - - 1) - scheduler.add_seq_group(seq_group) - seq_groups.append(seq_group) - - # Schedule chunk prefill. Only the first seq_group should be scheduled. - seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(seq_groups[:1]) - seq_groups[0].get_num_unprefilled() == 2 - seq_groups[1].get_num_unprefilled() == 4 - assert out.num_batched_tokens == 2 - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 1 - assert seq_group_meta[0].request_id == "0" - assert seq_group_meta[0].is_chunked_prefill - assert seq_group_meta[0].is_prompt - - # Schedule chunk prefill. Still Only the first seq_group should be scheduled. - seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(seq_groups[:1]) - seq_groups[0].get_num_unprefilled() == 0 - seq_groups[1].get_num_unprefilled() == 4 - assert out.num_batched_tokens == 2 - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 1 - assert seq_group_meta[0].request_id == "0" - assert not seq_group_meta[0].is_chunked_prefill - assert seq_group_meta[0].is_prompt - - # Schedule chunk prefill. This time the second seq_group should be selected - # for chunk prefill, and the first seq_group should be select for decoding. 
- seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(seq_groups) - seq_groups[0].get_num_unprefilled() == 0 - seq_groups[1].get_num_unprefilled() == 2 - assert out.num_batched_tokens == 3 - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 2 - assert seq_group_meta[0].request_id == "1" - assert seq_group_meta[0].is_chunked_prefill - assert seq_group_meta[0].is_prompt - assert seq_group_meta[1].request_id == "0" - assert not seq_group_meta[1].is_chunked_prefill - assert not seq_group_meta[1].is_prompt - - -def test_scheduler_max_seqs(): - block_size = 4 - num_seq_group = 4 - max_seq_group = 2 - max_model_len = 16 - scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len) - cache_config = CacheConfig(block_size, 1.0, 1) - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - - all_seq_groups: List[SequenceGroup] = [] - # Add seq groups to scheduler. - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - num_processed_token_ids=block_size - - 1) - all_seq_groups.append(seq_group) - - # Append 1 seq group - scheduler.add_seq_group(all_seq_groups[0]) - - # Schedule seq groups prompts. - seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]]) - - # Schedule seq groups generation. - seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]]) - - # Append 2 more seq group - running: List[SequenceGroup] = [] - scheduler.add_seq_group(all_seq_groups[1]) - scheduler.add_seq_group(all_seq_groups[2]) - - # Schedule seq groups prompts. - # Only 1 seq group should be scheduled since max_seq_group is 2 - # and one is prompting. - seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]]) - - @pytest.mark.parametrize("min_decodes_per_prefill", [1, 2, 4, 8, 16, 32, 64, 128]) def test_scheduler_delayed_prefill_scheduling(min_decodes_per_prefill): diff --git a/tests/core/utils.py b/tests/core/utils.py new file mode 100644 index 000000000000..4c8d29c2246b --- /dev/null +++ b/tests/core/utils.py @@ -0,0 +1,27 @@ +import time +from typing import Tuple + +from vllm import SamplingParams +from vllm.sequence import Sequence, SequenceGroup + + +def create_dummy_prompt( + request_id: str, + prompt_length: int, + block_size: int = None, + num_processed_token_ids: int = 0) -> Tuple[Sequence, SequenceGroup]: + if not block_size: + block_size = prompt_length + + # Create dummy prompt sequence with tokens 0...block_size-1 + # and prompt "0 ... block_size". 
+ prompt_tokens = list(range(prompt_length)) + prompt_str = " ".join([str(t) for t in prompt_tokens]) + prompt = Sequence(int(request_id), + prompt_str, + prompt_tokens, + block_size) + seq_group = SequenceGroup(request_id, [prompt], SamplingParams(), + time.time(), time.perf_counter()) + + return prompt, seq_group From f31371f0a31bfc93a1c0abee8e7b8271b0e93d9e Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 05:50:07 -0800 Subject: [PATCH 10/88] [2/n]ip --- tests/core/test_scheduler.py | 79 ++++++++++++++++++--------- tests/kernels/test_flash_attention.py | 10 ++-- 2 files changed, 56 insertions(+), 33 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index c1fd9ba16dd9..fa81006e980e 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -3,53 +3,78 @@ from vllm.config import CacheConfig, SchedulerConfig from vllm.core.scheduler import Scheduler +from vllm.sequence import SequenceGroup from .utils import create_dummy_prompt -@pytest.mark.parametrize("min_decodes_per_prefill", - [1, 2, 4, 8, 16, 32, 64, 128]) -def test_scheduler_delayed_prefill_scheduling(min_decodes_per_prefill): +def test_scheduler_schedule_chunked_prefill(): block_size = 4 - num_seq_group = 64 + num_seq_group = 2 max_model_len = 16 - num_pattern_repeat = 4 + max_chunked_prefill_len = 2 + max_num_prompt_seqs = 1 scheduler_config = SchedulerConfig( 64, num_seq_group, max_model_len, - min_decodes_per_prefill=min_decodes_per_prefill) + flash_style=True, + max_chunked_prefill_len=max_chunked_prefill_len, + max_num_prompt_seqs=max_num_prompt_seqs) cache_config = CacheConfig(block_size, 1.0, 1) cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 scheduler = Scheduler(scheduler_config, cache_config, None) # Add seq groups to scheduler. + seq_groups: List[SequenceGroup] = [] for i in range(num_seq_group): _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size, num_processed_token_ids=block_size - 1) scheduler.add_seq_group(seq_group) + seq_groups.append(seq_group) - scheduled_is_prefilling = [] - # Schedule seq groups prompts. - for i in range((min_decodes_per_prefill + 1) * num_pattern_repeat): - seq_group_meta, out = scheduler.schedule() - # num_prompt_groups only come from prefilling_outputs - # use this to determine if a prefilling step is scheduled - scheduled_is_prefilling.append(out.num_prompt_groups > 0) - - # abort most of the seq groups to free slots for the next iteration - # note we leave one in the scheduler to avoid always schedule prefilling - for j in range(len(seq_group_meta) - 1): - scheduler.abort_seq_group(seq_group_meta[j].request_id) - - # the scheduled sequence with delayed prefilling should look like - # [True, False * min_decodes_per_prefill, - # True, False * min_decodes_per_prefill, ...] - expected_pattern = [True] + [False] * min_decodes_per_prefill - assert scheduled_is_prefilling == expected_pattern * num_pattern_repeat,\ - f"Delayed refill scheduling is not correct, expected pattern \ - {expected_pattern} (to be repeated {num_pattern_repeat} times), \ - got {scheduled_is_prefilling}" + # Schedule chunk prefill. Only the first seq_group should be scheduled. 
+    seq_group_meta, out = scheduler.schedule()
+    assert set(out.scheduled_seq_groups) == set(seq_groups[:1])
+    assert seq_groups[0].get_num_unprefilled() == 2
+    assert seq_groups[1].get_num_unprefilled() == 4
+    assert out.num_batched_tokens == 2
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == 1
+    assert seq_group_meta[0].request_id == "0"
+    assert seq_group_meta[0].is_chunked_prefill
+    assert seq_group_meta[0].is_prompt
+
+    # Schedule chunk prefill. Still only the first seq_group should be scheduled.
+    seq_group_meta, out = scheduler.schedule()
+    assert set(out.scheduled_seq_groups) == set(seq_groups[:1])
+    assert seq_groups[0].get_num_unprefilled() == 0
+    assert seq_groups[1].get_num_unprefilled() == 4
+    assert out.num_batched_tokens == 2
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == 1
+    assert seq_group_meta[0].request_id == "0"
+    assert not seq_group_meta[0].is_chunked_prefill
+    assert seq_group_meta[0].is_prompt
+
+    # Schedule chunk prefill. This time the second seq_group should be selected
+    # for chunk prefill, and the first seq_group should be selected for decoding.
+    seq_group_meta, out = scheduler.schedule()
+    assert set(out.scheduled_seq_groups) == set(seq_groups)
+    assert seq_groups[0].get_num_unprefilled() == 0
+    assert seq_groups[1].get_num_unprefilled() == 2
+    assert out.num_batched_tokens == 3
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == 2
+    assert seq_group_meta[0].request_id == "1"
+    assert seq_group_meta[0].is_chunked_prefill
+    assert seq_group_meta[0].is_prompt
+    assert seq_group_meta[1].request_id == "0"
+    assert not seq_group_meta[1].is_chunked_prefill
+    assert not seq_group_meta[1].is_prompt
diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py
index a9c43e31edc0..1618576e03bc 100644
--- a/tests/kernels/test_flash_attention.py
+++ b/tests/kernels/test_flash_attention.py
@@ -1,5 +1,5 @@
 import random
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 
 import pytest
 import torch
@@ -19,7 +19,7 @@
 DTYPES = [torch.half, torch.bfloat16]
 NUM_GEN_SEQS = [3, 6, 17]  # Arbitrary values for testing
 NUM_PREFILL_SEQS = [3, 6, 17]  # Arbitrary values for testing
-NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
+NUM_HEADS = [(1, 40), (40, 40), (64, 8)]  # Arbitrary values for testing
 NUM_HEADS_SMALL = NUM_HEADS
 # head size should be bigger than or equal to block size.
 HEAD_SIZES = [256]
@@ -171,7 +171,7 @@ def test_flash_paged_attention(
                         device="cuda")
     query.uniform_(-scale, scale)
 
-    assert num_query_heads % num_kv_heads == 0
+    # assert num_query_heads % num_kv_heads == 0
     num_queries_per_kv = num_query_heads // num_kv_heads
     alibi_slopes = None
     if use_alibi:
@@ -195,6 +195,7 @@ def test_flash_paged_attention(
         ]
         block_tables.append(block_table)
     block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+    breakpoint()
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS,
@@ -239,7 +240,4 @@ def test_flash_paged_attention(
         flash_style=True,
     )
 
-    # NOTE(woosuk): Due to the kernel-level differences in the two
-    # implementations, there is a small numerical difference in the two
-    # outputs. Thus, we use a relaxed tolerance for the test.
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) From 78bb887bc625c473dc3a5ce5369617a743048c3e Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 08:49:47 -0800 Subject: [PATCH 11/88] ip --- benchmarks/benchmark_latency.py | 2 + .../kernels/benchmark_paged_attention.py | 5 +- tests/chunked_prefill/test_correctness.py | 8 +-- tests/kernels/test_attention.py | 2 + tests/kernels/test_flash_attention.py | 52 +++++++++---------- vllm/model_executor/layers/attention.py | 2 + 6 files changed, 38 insertions(+), 33 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index b915f913e7d1..7edad46a6172 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -33,6 +33,8 @@ def main(args: argparse.Namespace): enforce_eager=args.enforce_eager, kv_cache_dtype=args.kv_cache_dtype, device=args.device, + block_size=args.block_size, + flash_style=args.flash_style, ) sampling_params = SamplingParams( diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index d921dea1220e..14ca569a3d66 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -7,6 +7,7 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random from vllm._C import ops +from vllm.model_executor.layers.attention import flash_attn_with_kvcache_paged NUM_BLOCKS = 1024 PARTITION_SIZE = 512 @@ -131,6 +132,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: alibi_slopes, kv_cache_dtype, ) + elif version == "flash": + else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() @@ -158,7 +161,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: description="Benchmark the paged attention kernel.") parser.add_argument("--version", type=str, - choices=["v1", "v2"], + choices=["v1", "v2", "flash"], default="v2") parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--context-len", type=int, default=4096) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index 6659b83bfcb6..e7ed28b86e04 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -8,7 +8,7 @@ from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel MODELS = [ - # "JackFram/llama-68m", + "JackFram/llama-68m", "facebook/opt-125m", ] @@ -58,10 +58,10 @@ def test_models( dtype=dtype, flash_style=True, block_size=block_size) - flash_attn_output_by_batchs = [] + flash_attn_output_by_batches = [] for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] - flash_attn_output_by_batchs.append( + flash_attn_output_by_batches.append( flash_attn_model.generate_greedy(prompts, max_tokens)) del flash_attn_model @@ -70,7 +70,7 @@ def test_models( gc.collect() torch.cuda.empty_cache() - for flash_attn_outputs in flash_attn_output_by_batchs: + for flash_attn_outputs in flash_attn_output_by_batches: for i in range(len(flash_attn_outputs)): fa_output_ids, fa_output_str = flash_attn_outputs[i] vllm_output_ids, vllm_output_str = expected_outputs[ diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index fb571de63d4e..e2f59dff1ca2 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -280,6 +280,7 @@ def ref_multi_query_kv_attention( num_seqs = len(cu_seq_lens) - 1 ref_outputs = [] for i in 
range(num_seqs): + breakpoint() start_idx = cu_seq_lens[i] end_idx = cu_seq_lens[i + 1] seq_len = end_idx - start_idx @@ -299,6 +300,7 @@ def ref_multi_query_kv_attention( ) ref_outputs.append(ref_output) ref_output = torch.cat(ref_outputs, dim=0) + breakpoint() return ref_output diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index a9c43e31edc0..7a98a0db9ca5 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -1,5 +1,5 @@ import random -from typing import Optional, Tuple +from typing import Optional, Tuple, List import pytest import torch @@ -19,7 +19,7 @@ DTYPES = [torch.half, torch.bfloat16] NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing -NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing +NUM_HEADS = [(1, 40), (40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS_SMALL = NUM_HEADS # head size should be bigger than or equal to block size. HEAD_SIZES = [256] @@ -29,8 +29,7 @@ BLOCK_SIZES = [256] USE_ALIBI = [False, True] SEEDS = [0] -# PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] -PAD_CONFIGS = [(0, 0)] +PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] def pad_attention_inputs( @@ -84,10 +83,8 @@ def ref_single_query_cached_kv_attention( context_lens: torch.Tensor, scale: float, alibi_slopes: Optional[torch.Tensor], - flash_style: bool = False, ) -> None: num_query_heads = query.shape[1] - num_kv_heads = value_cache.shape[-2] head_size = value_cache.shape[-1] block_size = value_cache.shape[-3] num_seqs = query.shape[0] @@ -105,18 +102,11 @@ def ref_single_query_cached_kv_attention( block_number = int(block_table[j // block_size]) block_offset = j % block_size - if flash_style: - k = key_cache[block_number, block_offset, :, :] - else: - k = key_cache[block_number, :, :, block_offset, :] - k = k.reshape(num_kv_heads, head_size) + k = key_cache[block_number, block_offset, :, :] keys.append(k) v = value_cache[block_number, :, :, block_offset] - if flash_style: - v = value_cache[block_number, block_offset, :, :] - else: - v = value_cache[block_number, :, :, block_offset] + v = value_cache[block_number, block_offset, :, :] values.append(v) keys = torch.stack(keys, dim=0) values = torch.stack(values, dim=0) @@ -138,12 +128,20 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("use_alibi", [False, True]) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +# @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +# @pytest.mark.parametrize("num_heads", NUM_HEADS) +# @pytest.mark.parametrize("head_size", HEAD_SIZES) +# @pytest.mark.parametrize("use_alibi", [False, True]) +# @pytest.mark.parametrize("block_size", BLOCK_SIZES) +# @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +# @pytest.mark.parametrize("seed", SEEDS) +# @pytest.mark.parametrize("pad_config", PAD_CONFIGS) +@pytest.mark.parametrize("num_seqs", [3]) +@pytest.mark.parametrize("num_heads", [(40,40)]) +@pytest.mark.parametrize("head_size", [256]) +@pytest.mark.parametrize("use_alibi", [True]) +@pytest.mark.parametrize("block_size", [256]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("seed", 
SEEDS) @pytest.mark.parametrize("pad_config", PAD_CONFIGS) @torch.inference_mode() @@ -164,14 +162,15 @@ def test_flash_paged_attention( scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, + query = torch.empty(num_seqs, # batch_size + 1, # seqlen num_query_heads, head_size, dtype=dtype, device="cuda") query.uniform_(-scale, scale) - assert num_query_heads % num_kv_heads == 0 + # assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads alibi_slopes = None if use_alibi: @@ -195,6 +194,7 @@ def test_flash_paged_attention( ] block_tables.append(block_table) block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + breakpoint() # Create the KV caches. key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, @@ -210,7 +210,7 @@ def test_flash_paged_attention( # Call the paged attention kernel. output = torch.empty_like(query) - padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ + padded_query, padded_block_table, padded_context_lens, _ = \ pad_attention_inputs(pad_config, block_size, query, block_tables, context_lens, max_context_len) @@ -236,10 +236,6 @@ def test_flash_paged_attention( context_lens, scale, alibi_slopes, - flash_style=True, ) - # NOTE(woosuk): Due to the kernel-level differences in the two - # implementations, there is a small numerical difference in the two - # outputs. Thus, we use a relaxed tolerance for the test. assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 51c47368e831..64b1beaa6205 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -229,6 +229,7 @@ def forward( self.alibi_slopes, ) else: + breakpoint() # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -360,6 +361,7 @@ def _paged_attention( device=output.device, ) max_logits = torch.empty_like(exp_sums) + breakpoint() ops.paged_attention_v2( output, exp_sums, From 42dd36264e67d636c91efe9ce537b651e223ce42 Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 20:56:36 -0800 Subject: [PATCH 12/88] [2/n] ip --- vllm/model_executor/layers/attention.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 64b1beaa6205..302720ba8fef 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -229,7 +229,6 @@ def forward( self.alibi_slopes, ) else: - breakpoint() # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -270,7 +269,6 @@ def forward( # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) - def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, @@ -361,7 +359,6 @@ def _paged_attention( device=output.device, ) max_logits = torch.empty_like(exp_sums) - breakpoint() ops.paged_attention_v2( output, exp_sums, @@ -395,6 +392,15 @@ def flash_attn_with_kvcache_paged( based on key/value caches. The main difference is this uses flash attention style key-value caches. + CHEN: + - input queries are flattened. + - First portion is N decoding tokens. + - Second portion is a single (or multiple) prefill token. + - Still needs to separate out (it doens't improve performance). + - oss flash kernel for prefill doesn't support varlen. 
+ - + + Arguments: See https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py for other arguments. From 74ac900cbf56dedea8d93d2b948b81e1aeaa5085 Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 21:32:59 -0800 Subject: [PATCH 13/88] seems to work. --- .../kernels/benchmark_paged_attention.py | 34 +++++++++++++------ tests/chunked_prefill/test_correctness.py | 2 -- tests/kernels/test_attention.py | 2 -- tests/kernels/test_cache.py | 1 - tests/kernels/test_flash_attention.py | 15 ++++---- tests/worker/test_model_runner.py | 2 -- vllm/model_executor/layers/attention.py | 5 ++- 7 files changed, 33 insertions(+), 28 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 14ca569a3d66..7b3c6405e259 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -65,14 +65,17 @@ def main( block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) # Create the KV cache. - key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, - block_size, - 1, - num_kv_heads, - head_size, - kv_cache_dtype, - dtype, - device=device) + flash_style = version == "flash" + key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + flash_style=flash_style) key_cache, value_cache = key_caches[0], value_caches[0] # Prepare for the paged attention kernel. @@ -133,7 +136,15 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: kv_cache_dtype, ) elif version == "flash": - + flash_attn_with_kvcache_paged( + query.view(num_seqs, 1, num_query_heads, head_size), + key_cache, + value_cache, + scale, + block_tables, + context_lens, + alibi_slopes=alibi_slopes, + ) else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() @@ -171,7 +182,10 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: type=int, choices=[64, 80, 96, 112, 128, 256], default=128) - parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) + parser.add_argument("--block-size", + type=int, + choices=[16, 32, 256], + default=16) parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--dtype", type=str, diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index e7ed28b86e04..72fb9f73766c 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -1,7 +1,5 @@ import gc -from typing import List - import pytest import torch diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index e2f59dff1ca2..fb571de63d4e 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -280,7 +280,6 @@ def ref_multi_query_kv_attention( num_seqs = len(cu_seq_lens) - 1 ref_outputs = [] for i in range(num_seqs): - breakpoint() start_idx = cu_seq_lens[i] end_idx = cu_seq_lens[i + 1] seq_len = end_idx - start_idx @@ -300,7 +299,6 @@ def ref_multi_query_kv_attention( ) ref_outputs.append(ref_output) ref_output = torch.cat(ref_outputs, dim=0) - breakpoint() return ref_output diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 23adc10df7cf..d2de4105b7f3 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -7,7 +7,6 @@ from vllm._C import cache_ops from vllm.utils import is_hip 
-import torch.nn.functional as F COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index 7a98a0db9ca5..74c9d3008f17 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -1,5 +1,5 @@ import random -from typing import Optional, Tuple, List +from typing import Optional, Tuple import pytest import torch @@ -125,7 +125,8 @@ def ref_single_query_cached_kv_attention( out = ref_masked_attention(q, keys, values, scale, alibi_bias) out = out.view(num_query_heads, head_size) - output[i].copy_(out, non_blocking=True) + # output[i].copy_(out, non_blocking=True) + output[i].copy_(out) # @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @@ -137,13 +138,13 @@ def ref_single_query_cached_kv_attention( # @pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("pad_config", PAD_CONFIGS) @pytest.mark.parametrize("num_seqs", [3]) -@pytest.mark.parametrize("num_heads", [(40,40)]) +@pytest.mark.parametrize("num_heads", [(40, 40)]) @pytest.mark.parametrize("head_size", [256]) @pytest.mark.parametrize("use_alibi", [True]) @pytest.mark.parametrize("block_size", [256]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("pad_config", PAD_CONFIGS) +@pytest.mark.parametrize("pad_config", [(0, 0)]) @torch.inference_mode() def test_flash_paged_attention( kv_cache_factory, @@ -162,8 +163,7 @@ def test_flash_paged_attention( scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, # batch_size - 1, # seqlen + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype, @@ -194,7 +194,6 @@ def test_flash_paged_attention( ] block_tables.append(block_table) block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - breakpoint() # Create the KV caches. key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, @@ -215,7 +214,7 @@ def test_flash_paged_attention( block_tables, context_lens, max_context_len) output = flash_attn_with_kvcache_paged( - padded_query, + padded_query.view(num_seqs, 1, num_query_heads, head_size), key_cache, value_cache, scale, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 94c9b45157ec..f44895a728c7 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -3,8 +3,6 @@ from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner -from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, - DeviceConfig) def test_prepare_prompt(): diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 64b1beaa6205..e04732d67d13 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -229,7 +229,6 @@ def forward( self.alibi_slopes, ) else: - breakpoint() # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -361,7 +360,6 @@ def _paged_attention( device=output.device, ) max_logits = torch.empty_like(exp_sums) - breakpoint() ops.paged_attention_v2( output, exp_sums, @@ -403,7 +401,8 @@ def flash_attn_with_kvcache_paged( output: [num_tokens, num_heads, head_size] """ block_size = value_cache.shape[1] - assert block_size % 256 == 0, "only support block_size divisible by 256." 
+ assert block_size % 256 == 0, ("only support block_size divisible by 256. " + f"Current block size: {block_size}") _, _, num_heads, head_size = query.shape out = flash_attn_with_kvcache( query, From 61418852541d49e30ed1d8e5bc780ec36d91640c Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 29 Feb 2024 23:16:45 -0800 Subject: [PATCH 14/88] [2/n] ip --- tests/chunked_prefill/test_correctness.py | 41 ++++--- tests/test_sequence.py | 142 ++++++++++++++++++++++ vllm/config.py | 23 +++- vllm/engine/arg_utils.py | 25 +++- vllm/engine/llm_engine.py | 1 - vllm/model_executor/layers/attention.py | 23 +++- vllm/worker/worker.py | 2 + 7 files changed, 229 insertions(+), 28 deletions(-) create mode 100644 tests/test_sequence.py diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index c0cd38d26abe..b3075abf8b58 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("block_size", [256]) @pytest.mark.parametrize("max_chunked_prefill_len", [-1, 16, 64]) @pytest.mark.parametrize("max_num_prompt_seqs", [1, 2, 100]) -@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("tensor_parallel_size", [1, 2]) def test_models( vllm_runner, model: str, @@ -49,23 +49,24 @@ def test_models( f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" ) - print("loading page attention models..") - pg_model = vllm_runner(model, dtype=dtype) - expected_outputs = [] + # print("loading page attention models..") + # pg_model = vllm_runner(model, dtype=dtype) + # expected_outputs = [] - print("generating tokens...") - expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) - print("generating tokens finished") + # print("generating tokens...") + # expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) + # print("generating tokens finished") - del pg_model + # del pg_model - destroy_model_parallel() - gc.collect() - torch.cuda.empty_cache() + # destroy_model_parallel() + # gc.collect() + # torch.cuda.empty_cache() flash_attn_model = vllm_runner( model, dtype=dtype, + block_size=block_size, flash_style=True, max_chunked_prefill_len=max_chunked_prefill_len, max_num_prompt_seqs=max_num_prompt_seqs, @@ -82,12 +83,12 @@ def test_models( gc.collect() torch.cuda.empty_cache() - for flash_attn_outputs in flash_attn_output_by_batches: - for i in range(len(flash_attn_outputs)): - fa_output_ids, fa_output_str = flash_attn_outputs[i] - vllm_output_ids, vllm_output_str = expected_outputs[ - i % len(expected_outputs)] - assert fa_output_ids == vllm_output_ids, ( - f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" - f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" - ) + # for flash_attn_outputs in flash_attn_output_by_batches: + # for i in range(len(flash_attn_outputs)): + # fa_output_ids, fa_output_str = flash_attn_outputs[i] + # vllm_output_ids, vllm_output_str = expected_outputs[ + # i % len(expected_outputs)] + # assert fa_output_ids == vllm_output_ids, ( + # f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" + # f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" + # ) diff --git a/tests/test_sequence.py b/tests/test_sequence.py new file mode 100644 index 000000000000..da6ef99367a9 --- /dev/null +++ b/tests/test_sequence.py @@ -0,0 +1,142 @@ +import pytest +from vllm.sequence import SequenceData, Sequence + 
+@pytest.fixture(name="sequence") +def create_sequence(seq_len: int, block_size: int) -> Sequence: + return Sequence( + seq_id=0, + prompt="", + prompt_token_ids=list(range(seq_len)), + block_size=block_size, + ) + +# @pytest.mark.parametrize("block_size", [1, 2, 4, 8]) +# @pytest.mark.parametrize("num_empty_slots", list(range(8))) +# @pytest.mark.parametrize("seq_len", [0, 1, 100]) +# def test_ensure_num_empty_slots(block_size: int, seq_len: int, +# num_empty_slots: int, sequence: Sequence): +# """Verify ensure_num_empty_slots correctly ensures empty slots. +# """ +# sequence.ensure_num_empty_slots(num_empty_slots) +# num_total_slots = block_size * len(sequence.logical_token_blocks) +# measured_num_empty_slots = sum(block.get_num_empty_slots() +# for block in sequence.logical_token_blocks) +# num_full_slots = num_total_slots - measured_num_empty_slots +# assert measured_num_empty_slots >= num_empty_slots +# assert num_full_slots == seq_len +# @pytest.fixture(name="sequence_with_extra_blocks") +# def add_blocks_to_sequence(sequence: Sequence, +# num_extra_blocks: int) -> Sequence: +# for _ in range(num_extra_blocks): +# sequence._append_logical_block() # pylint: disable=protected-access +# return sequence +# @pytest.mark.parametrize("num_tokens_to_append", [1, 10]) +# @pytest.mark.parametrize("seq_len", [0, 1, 100]) +# @pytest.mark.parametrize("block_size", [1, 2, 4, 8]) +# @pytest.mark.parametrize("num_extra_blocks", [0, 1, 100]) +# def test_append_tokens_correct_placement_in_blocks( +# num_tokens_to_append: int, sequence_with_extra_blocks: Sequence, +# block_size: int, seq_len: int): +# """Verify new tokens are appended at the end of the sequence, instead of the +# last block. This enables preallocated empty slots, which requires empty +# blocks after the sequence. +# """ +# token_ids = list(range(num_tokens_to_append)) +# logprobs = [{token_id: 0.0} for token_id in token_ids] +# seq_len_before_append = seq_len +# seq_len_after_append = seq_len_before_append + num_tokens_to_append +# sequence_with_extra_blocks.append_token_ids(token_ids, logprobs) +# # Assert number of full slots equal to total sequence length. +# assert sum(block_size - block.get_num_empty_slots() +# for block in sequence_with_extra_blocks.logical_token_blocks +# ) == seq_len_after_append +# # Assert each appended token is immediately after the original sequence. +# for i, token_id in enumerate(token_ids): +# index = seq_len_before_append + i +# block_token_ids = sequence_with_extra_blocks.logical_token_blocks[ +# index // block_size].get_token_ids() +# assert block_token_ids[index % block_size] == token_id +# @pytest.mark.parametrize("generation_or_prefill", ["generation", "prefill"]) +# @pytest.mark.parametrize("num_output_tokens", [0, 1, 10]) +# @pytest.mark.parametrize("num_prompt_tokens", [5, 50]) +# def test_get_unprocessed_tokens(generation_or_prefill: str, +# num_output_tokens: int, +# num_prompt_tokens: int): +# """Verify sequence data correctly tracks the number of processed tokens. 
+# """ +# is_generation = generation_or_prefill == "generation" +# if is_generation: +# generated_token_id = 1337 +# prompt_token_ids = list(range(num_prompt_tokens)) +# output_token_ids = list(range(num_output_tokens)) +# data = SequenceData( +# prompt_token_ids=prompt_token_ids[:], +# output_token_ids=output_token_ids[:], +# ) +# if is_generation: +# data.append_token_ids([generated_token_id], logprobs=[0.0]) +# unprocessed_token_ids = data.get_unprocessed_token_ids() +# unprocessed_token_positions = data.get_unprocessed_token_positions() +# if is_generation: +# assert unprocessed_token_ids == [generated_token_id] +# assert unprocessed_token_positions == [ +# num_prompt_tokens + num_output_tokens +# ] +# else: +# assert unprocessed_token_ids == prompt_token_ids + output_token_ids +# assert unprocessed_token_positions == list( +# range(num_prompt_tokens + num_output_tokens)) +# # Reset processed tokens. Everything should behave like a prompt run now. +# data.reset_processed_tokens() +# unprocessed_token_ids = data.get_unprocessed_token_ids() +# unprocessed_token_positions = data.get_unprocessed_token_positions() +# if is_generation: +# assert unprocessed_token_ids == (prompt_token_ids + output_token_ids + +# [generated_token_id]) +# assert unprocessed_token_positions == list( +# range(num_prompt_tokens + num_output_tokens + 1)) +# if not is_generation: +# assert unprocessed_token_ids == prompt_token_ids + output_token_ids +# assert unprocessed_token_positions == list( +# range(num_prompt_tokens + num_output_tokens)) + + +def test_sequence_data_prefill(): + seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4], output_token_ids=[]) + assert seq_data.get_prefill_range() == (0, 0) + assert seq_data.get_num_unprefilled() == 4 + + # advance by 2 + assert seq_data.advance_prefill_range(2) == 2 + assert seq_data.get_num_unprefilled() == 2 + assert seq_data.get_prefill_range() == (0, 2) + + # advance range by 3 even though there are only 2 unprefilled tokens + assert seq_data.advance_prefill_range(3) == 2 + assert seq_data.get_num_unprefilled() == 0 + assert seq_data.get_prefill_range() == (2, 4) + + # following advances should not change anything + assert seq_data.advance_prefill_range(2) == 0 + assert seq_data.get_num_unprefilled() == 0 + assert seq_data.get_prefill_range() == (4, 4) + + # append tokens and reset, simulating recompute + seq_data.append_token_ids([1], logprobs=[0.0]) + seq_data.reset_processed_tokens() + + # after reset, the prefill range should be reset to 0 + # but the num_unprefilled should include. + # output tokens + assert seq_data.get_prefill_range() == (0, 0) + assert seq_data.get_num_unprefilled() == 5 + + # advance by 2 + assert seq_data.advance_prefill_range(2) == 2 + assert seq_data.get_num_unprefilled() == 3 + assert seq_data.get_prefill_range() == (0, 2) + + # advance range by 3 even though there are only 2 unprefilled tokens + assert seq_data.advance_prefill_range(3) == 3 + assert seq_data.get_num_unprefilled() == 0 + assert seq_data.get_prefill_range() == (2, 5) \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index 2a0df4353add..c8e4c728ab60 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -453,6 +453,14 @@ class SchedulerConfig: max_model_len: Maximum length of a sequence (including prompt and generated text). max_paddings: Maximum number of paddings to be added to a batch. + max_chunked_prefill_len: The maximum length of tokens for prefill + requests. Longer requests will be chunked into multiple chunks. 
+ -1 means no chunking (disabled). This features is only supported + for flash style attention. + max_num_prompt_seqs: The maximum number of prompt sequences that can be + processed in a single iteration. + flash_style: Whether to use flash style attention. Only support + LLaMA models. """ def __init__( @@ -461,6 +469,9 @@ def __init__( max_num_seqs: int, max_model_len: int, max_paddings: int, + max_chunked_prefill_len: int = -1, + max_num_prompt_seqs: int = 1024, + flash_style: bool = False, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -471,10 +482,15 @@ def __init__( self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.max_paddings = max_paddings + self.chunked_prefill_enabled = max_chunked_prefill_len != -1 + self.max_chunked_prefill_len = max_chunked_prefill_len + self.max_num_prompt_seqs = max_num_prompt_seqs + self.flash_style = flash_style self._verify_args() def _verify_args(self) -> None: - if self.max_num_batched_tokens < self.max_model_len: + if self.max_num_batched_tokens < self.max_model_len and \ + not self.chunked_prefill_enabled: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " f"smaller than max_model_len ({self.max_model_len}). " @@ -487,6 +503,11 @@ def _verify_args(self) -> None: f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " f"({self.max_num_seqs}).") + if self.chunked_prefill_enabled and not self.flash_style: + # SANG-TODO It is probably fixable. + raise ValueError( + "chunked prefill is only supported for flash style") + class DeviceConfig: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 04a48d151b3b..673119c91230 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -46,6 +46,8 @@ class EngineArgs: max_cpu_loras: Optional[int] = None flash_style: bool = False device: str = 'auto' + max_chunked_prefill_len: int = -1 + max_num_prompt_seqs: int = 256 def __post_init__(self): if self.tokenizer is None: @@ -273,6 +275,17 @@ def add_cli_args( parser.add_argument('--flash-style', action='store_true', help='use flash attention.') + parser.add_argument( + '--max-chunked-prefill-len', + type=int, + default=-1, + help='max number of prefill tokens allowed in chunked prefill' + ', -1 means no limit') + parser.add_argument( + '--max-num-prompt-seqs', + type=int, + default=1024, + help='max number of prompt sequences allowed in prefill') return parser @classmethod @@ -305,10 +318,14 @@ def create_engine_configs( self.worker_use_ray, self.max_parallel_loading_workers, self.disable_custom_all_reduce) - scheduler_config = SchedulerConfig(self.max_num_batched_tokens, - self.max_num_seqs, - model_config.max_model_len, - self.max_paddings) + scheduler_config = SchedulerConfig( + self.max_num_batched_tokens, + self.max_num_seqs, + model_config.max_model_len, + self.max_paddings, + max_chunked_prefill_len=self.max_chunked_prefill_len, + max_num_prompt_seqs=self.max_num_prompt_seqs, + flash_style=self.flash_style,) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6f5af71426d7..be2a1ca13d04 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -831,7 +831,6 @@ def step(self) -> List[RequestOutput]: >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - if not scheduler_outputs.is_empty(): # Execute the model. 
            all_outputs = self._run_workers(
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index ffa84d7c03a7..2ae880104f40 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -36,6 +36,22 @@ class PagedAttention(nn.Module):
     2. Perform (multi-head/multi-query/grouped-query) attention using either
         xformers or the PagedAttention custom op.
     3. Return the output tensor.
+
+    Chunked Prefill support:
+
+    If chunked prefill is enabled, the input will include both prompt tokens
+    and generation tokens. The layout is as follows:
+    |<---------------------- num_valid_tokens -------------------------->|
+    |<--------- num_prompt_tokens ----->|<--- num_generation_tokens----->|
+    |<-prompt_0->|<-prompt_1->|...||<-gen_0->|<-gen_1->|......||
+
+    Notice that both num_prompt_tokens and num_generation_tokens
+    include padding.
+
+    The actual prompt length and offset are stored in cum_prompt_context_lens.
+    The actual number of generation tokens is stored in
+    num_generation_tokens_tensor.
+    To support chunked prefill, where the prompt and context might have
+    different lengths, we store the context's length in cum_prompt_context_lens.
     """
 
     def __init__(
@@ -98,6 +114,9 @@ def ref_masked_attention(
         out = torch.einsum("hqk,khd->qhd", attn_weights, value)
         return out
 
+    # def _update_cache(self):
+
+
     def forward(
         self,
         query: torch.Tensor,
@@ -114,9 +133,9 @@ def forward(
         key: shape = [batch_size, seq_len, num_kv_heads * head_size]
         value: shape = [batch_size, seq_len, num_kv_heads * head_size]
         key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-            block_size, x]
+            block_size, x]. None if it is a profiling run.
         value_cache: shape = [num_blocks, num_kv_heads, head_size,
-            block_size]
+            block_size]. None if it is a profiling run.
         input_metadata: metadata for the inputs.
     Returns:
         shape = [batch_size, seq_len, num_heads * head_size]
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 870c7b7edd40..9efe890150a0 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -117,6 +117,7 @@ def profile_num_available_blocks(
             gpu_memory_utilization: The fraction of the total GPU memory to use.
             cpu_swap_space: The size of the CPU swap space in bytes.
         """
+        breakpoint()
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
@@ -231,6 +232,7 @@ def execute_model(
 
         if num_seq_groups == 0:
             return {}
 
+        breakpoint()
         output = self.model_runner.execute_model(seq_group_metadata_list,
                                                  self.gpu_cache)
         return output

From 71bdada54fa2294e1d6483d88e4990a943c6817d Mon Sep 17 00:00:00 2001
From: sang
Date: Fri, 1 Mar 2024 06:38:28 -0800
Subject: [PATCH 15/88] .
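To make the token layout described in the PagedAttention docstring above
concrete, here is a minimal, self-contained sketch of how a flattened
chunked-prefill batch could be packed. It is illustrative only:
pack_chunked_prefill_batch, PAD_TOKEN_ID, and the fixed padding sizes are
assumptions made for this example, not names from the patch series.

    from typing import List

    PAD_TOKEN_ID = 0  # assumed padding value, purely illustrative


    def pack_chunked_prefill_batch(prompt_chunks: List[List[int]],
                                   generation_tokens: List[int],
                                   num_prompt_tokens: int,
                                   num_generation_tokens: int) -> List[int]:
        """Flatten prompt chunks and decode tokens into one padded batch.

        |<------ num_prompt_tokens ------>|<--- num_generation_tokens --->|
        |<-prompt_0->|<-prompt_1->|..pad..|<-gen_0->|<-gen_1->|....pad....|
        """
        prompt_part = [tok for chunk in prompt_chunks for tok in chunk]
        assert len(prompt_part) <= num_prompt_tokens
        prompt_part += [PAD_TOKEN_ID] * (num_prompt_tokens - len(prompt_part))

        gen_part = list(generation_tokens)
        assert len(gen_part) <= num_generation_tokens
        gen_part += [PAD_TOKEN_ID] * (num_generation_tokens - len(gen_part))

        # num_valid_tokens == num_prompt_tokens + num_generation_tokens.
        return prompt_part + gen_part


    if __name__ == "__main__":
        # Two prompts being (chunk-)prefilled plus three sequences decoding
        # one token each, padded to 8 prompt slots and 4 generation slots.
        batch = pack_chunked_prefill_batch(
            prompt_chunks=[[11, 12, 13], [21, 22]],
            generation_tokens=[31, 41, 51],
            num_prompt_tokens=8,
            num_generation_tokens=4)
        assert len(batch) == 12
        print(batch)

Running this prints a 12-slot batch with the padded prompt tokens first and
the per-sequence decode tokens after, mirroring the
num_prompt_tokens / num_generation_tokens split described above.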
--- tests/chunked_prefill/test_correctness.py | 4 +-- tests/models/test_models.py | 26 +++++++++--------- vllm/engine/llm_engine.py | 2 ++ vllm/entrypoints/llm.py | 1 + vllm/model_executor/layers/attention.py | 32 ++++++++++++++++++++++- vllm/worker/model_runner.py | 14 +++++++++- vllm/worker/worker.py | 3 +++ 7 files changed, 65 insertions(+), 17 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index 72fb9f73766c..b9b6353e04e6 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -7,7 +7,7 @@ MODELS = [ "JackFram/llama-68m", - "facebook/opt-125m", + # "facebook/opt-125m", ] TEST_PROMPTS = [ @@ -57,7 +57,7 @@ def test_models( flash_style=True, block_size=block_size) flash_attn_output_by_batches = [] - for i in range(10): + for i in range(2): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] flash_attn_output_by_batches.append( flash_attn_model.generate_greedy(prompts, max_tokens)) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index e44452e9893c..02afd3ac9868 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,19 +6,19 @@ MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", - "mistralai/Mistral-7B-v0.1", - "Deci/DeciLM-7b", - "tiiuae/falcon-7b", - "gpt2", - "bigcode/tiny_starcoder_py", - "EleutherAI/gpt-j-6b", - "EleutherAI/pythia-70m", - "bigscience/bloom-560m", - "mosaicml/mpt-7b", - "microsoft/phi-2", - "stabilityai/stablelm-3b-4e1t", - "allenai/OLMo-1B", + # "meta-llama/Llama-2-7b-hf", + # "mistralai/Mistral-7B-v0.1", + # "Deci/DeciLM-7b", + # "tiiuae/falcon-7b", + # "gpt2", + # "bigcode/tiny_starcoder_py", + # "EleutherAI/gpt-j-6b", + # "EleutherAI/pythia-70m", + # "bigscience/bloom-560m", + # "mosaicml/mpt-7b", + # "microsoft/phi-2", + # "stabilityai/stablelm-3b-4e1t", + # "allenai/OLMo-1B", ] diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6f5af71426d7..9f303e4b8209 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -831,6 +831,8 @@ def step(self) -> List[RequestOutput]: >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + print("SANG-TODO step seq_group_metadata_list length: ", + len(seq_group_metadata_list)) if not scheduler_outputs.is_empty(): # Execute the model. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fc82018d18eb..d29fabb3f47a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -152,6 +152,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ + print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index e04732d67d13..b9b60cd8c9ad 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -122,6 +122,7 @@ def forward( shape = [batch_size, seq_len, num_heads * head_size] """ batch_size, seq_len, hidden_size = query.shape + print("SANG-TODO query size: ", query.size()) # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -133,10 +134,12 @@ def forward( # profiling run. 
if key_cache is not None and value_cache is not None: if input_metadata.flash_style: + print("SANG-TODO reshape cache flash.") cache_ops.reshape_and_cache_flash( key, value, key_cache, value_cache, input_metadata.slot_mapping.flatten()) else: + print("SANG-TODO reshape cache.") cache_ops.reshape_and_cache( key, value, @@ -150,6 +153,33 @@ def forward( # normal attention if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): + print("SANG-TODO flash attn is used.") + print( + "SANG-TODO query size: ", + query.view(batch_size, seq_len, self.num_heads, + self.head_size).size()) + # if key_cache is not None and value_cache is not None: + # output2 = flash_attn_with_kvcache_paged( + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size), + # key_cache, + # value_cache, + # self.scale, + # input_metadata.block_tables, + # input_metadata.context_lens + seq_len, + # self.alibi_slopes, + # ) + # from flash_attn import flash_attn_func + # breakpoint() + # output3 = flash_attn_func( + # q=query.view(batch_size, seq_len, self.num_heads, + # self.head_size), + # k=key.view(batch_size, seq_len, self.num_kv_heads, self.head_size), + # v=value.view(batch_size, seq_len, self.num_kv_heads, self.head_size), + # softmax_scale=self.scale, + # causal=True, + # alibi_slopes=self.alibi_slopes, + # ) if self.num_kv_heads != self.num_heads: # As of Nov 2023, xformers only supports MHA. For MQA/GQA, # project the key and value tensors to the desired number of @@ -217,6 +247,7 @@ def forward( ) output = out.view_as(query) else: + # prefix-enabled attention if input_metadata.flash_style: output = flash_attn_with_kvcache_paged( query.view(batch_size, seq_len, self.num_heads, @@ -229,7 +260,6 @@ def forward( self.alibi_slopes, ) else: - # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( query, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e95b147cb25d..855daeddd1e2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -138,6 +138,8 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] + print("SANG-TODO # of requests (seq_group_metadata_list): ", + len(seq_group_metadata_list)) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -150,14 +152,22 @@ def _prepare_prompt( prompt_lens.append(prompt_len) prefix_len = 0 prefix = seq_group_metadata.prefix + print("SANG-TODO prefix, ", prefix) if prefix is not None and prefix.computed: prefix_len = prefix.get_length() prompt_tokens = prompt_tokens[prefix_len:] prefix_block_tables.append(prefix.get_block_numbers()) + context_len = prefix_len else: prefix_block_tables.append([]) + # if seq_group_metadata.block_tables is None: + # prefix_block_tables.append([]) + # else: + # prefix_block_tables.append( + # seq_group_metadata.block_tables[seq_id]) + context_len = prefix_len # actual prompt lens - context_lens.append(prefix_len) + context_lens.append(context_len) subquery_lens.append(prompt_len - prefix_len) input_tokens.append(prompt_tokens) @@ -487,10 +497,12 @@ def prepare_input_tensors( # SANG-TODO set num prompt tokens and generations? # Prepare input tensors. 
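The _prepare_prompt bookkeeping above tracks three per-sequence lengths whose relationship is easy to lose. A tiny worked example with made-up numbers for a prompt that starts with a computed prefix:

    # One sequence: prompt of 12 tokens, of which the first 8 are a
    # computed (already cached) prefix. Numbers are illustrative only.
    prompt_len = 12
    prefix_len = 8                           # KV already in the cache
    context_len = prefix_len                 # what context_lens records
    subquery_len = prompt_len - prefix_len   # 4 new query tokens this step

    # The prompt-stage attention then runs 4 query tokens against
    # context_len + subquery_len = 12 keys/values for this sequence.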
if is_prompt: + print("SANG-TODO execute model prompt.") (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_prompt(seq_group_metadata_list) else: + print("SANG-TODO execute model decode.") (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_decode(seq_group_metadata_list) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 870c7b7edd40..a2c5f8e0b1e0 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -117,6 +117,7 @@ def profile_num_available_blocks( gpu_memory_utilization: The fraction of the total GPU memory to use. cpu_swap_space: The size of the CPU swap space in bytes. """ + print("SANG-TODO profile_num_available_blocks") # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() @@ -153,6 +154,7 @@ def profile_num_available_blocks( MAX_INT_32 // cache_block_size) num_cpu_blocks = min(num_cpu_blocks, MAX_INT_32 // cache_block_size) + print("SANG-TODO profile_num_available_blocks done") return num_gpu_blocks, num_cpu_blocks @@ -205,6 +207,7 @@ def execute_model( blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, ) -> Optional[SamplerOutput]: + print("SANG-TODO execute model.") if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups = len(seq_group_metadata_list) From d4c3b5ddb1a66eae256aee8257b4a4bdeb1dad54 Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 1 Mar 2024 06:46:29 -0800 Subject: [PATCH 16/88] ip? --- vllm/worker/model_runner.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 855daeddd1e2..406e110550fa 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -159,13 +159,12 @@ def _prepare_prompt( prefix_block_tables.append(prefix.get_block_numbers()) context_len = prefix_len else: - prefix_block_tables.append([]) - # if seq_group_metadata.block_tables is None: - # prefix_block_tables.append([]) - # else: - # prefix_block_tables.append( - # seq_group_metadata.block_tables[seq_id]) - context_len = prefix_len + if seq_group_metadata.block_tables is None: + prefix_block_tables.append([]) + else: + prefix_block_tables.append( + seq_group_metadata.block_tables[seq_id]) + context_len = prompt_len # actual prompt lens context_lens.append(context_len) subquery_lens.append(prompt_len - prefix_len) From baef7c66e7c22698d3b86f0a8029cac073eb1bee Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 1 Mar 2024 07:10:20 -0800 Subject: [PATCH 17/88] block tables updated correctly --- tests/chunked_prefill/test_correctness.py | 2 +- vllm/model_executor/input_metadata.py | 4 +++- vllm/model_executor/layers/attention.py | 6 ++++-- vllm/worker/model_runner.py | 15 +++++++++++---- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index b9b6353e04e6..561ef24713dd 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -57,7 +57,7 @@ def test_models( flash_style=True, block_size=block_size) flash_attn_output_by_batches = [] - for i in range(2): + for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] flash_attn_output_by_batches.append( 
flash_attn_model.generate_greedy(prompts, max_tokens)) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 97c23cdda069..00acf71a10b5 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -26,7 +26,8 @@ def __init__(self, is_prompt: bool, slot_mapping: torch.Tensor, max_context_len: Optional[int], context_lens: Optional[torch.Tensor], block_tables: Optional[torch.Tensor], use_cuda_graph: bool, - kv_cache_dtype: str, flash_style: bool) -> None: + kv_cache_dtype: str, flash_style: bool, + prefix_enabled: bool) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens self.max_seq_len = max_seq_len @@ -38,6 +39,7 @@ def __init__(self, is_prompt: bool, slot_mapping: torch.Tensor, self.use_cuda_graph = use_cuda_graph self.kv_cache_dtype = kv_cache_dtype self.flash_style = flash_style + self.prefix_enabled = prefix_enabled # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index b9b60cd8c9ad..051472a9d09b 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -152,7 +152,8 @@ def forward( if input_metadata.is_prompt: # normal attention if (key_cache is None or value_cache is None - or input_metadata.block_tables.numel() == 0): + # or input_metadata.block_tables.numel() == 0): + or not input_metadata.prefix_enabled): print("SANG-TODO flash attn is used.") print( "SANG-TODO query size: ", @@ -256,7 +257,8 @@ def forward( value_cache, self.scale, input_metadata.block_tables, - input_metadata.context_lens, + # input_metadata.context_lens, + # seq_len, self.alibi_slopes, ) else: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 406e110550fa..db4165141b2a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -159,6 +159,7 @@ def _prepare_prompt( prefix_block_tables.append(prefix.get_block_numbers()) context_len = prefix_len else: + prefix_block_tables.append([]) if seq_group_metadata.block_tables is None: prefix_block_tables.append([]) else: @@ -267,7 +268,9 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.flash_style) + flash_style=self.flash_style, + prefix_enabled=prefix is not None + and prefix.computed) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -394,7 +397,8 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.flash_style) + flash_style=self.flash_style, + prefix_enabled=False) return (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -540,6 +544,7 @@ def prepare_input_tensors( sampling_metadata.selected_token_indices, "lora_requests": lora_requests, "lora_mapping": lora_mapping, + "prefix_enabled": input_metadata.prefix_enabled } broadcast_tensor_dict(metadata_dict, src=0) else: @@ -559,7 +564,8 @@ def prepare_input_tensors( block_tables=metadata_dict["block_tables"], use_cuda_graph=metadata_dict["use_cuda_graph"], kv_cache_dtype=metadata_dict["kv_cache_dtype"], - flash_style=self.flash_style) + flash_style=self.flash_style, + prefix_enabled=metadata_dict["prefix_enabled"]) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -743,7 
+749,8 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: block_tables=block_tables[:batch_size], use_cuda_graph=True, kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.flash_style) + flash_style=self.flash_style, + prefix_enabled=False) if self.lora_config: lora_mapping = LoRAMapping( From a12ec689d410dc15970cda7ad922dfeb528b1957 Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 1 Mar 2024 07:12:15 -0800 Subject: [PATCH 18/88] hopefully tests pass --- vllm/engine/llm_engine.py | 4 +-- vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/attention.py | 38 ++++++------------------- vllm/worker/model_runner.py | 10 +++---- vllm/worker/worker.py | 6 ++-- 5 files changed, 19 insertions(+), 41 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9f303e4b8209..a7784a2ca443 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -831,8 +831,8 @@ def step(self) -> List[RequestOutput]: >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - print("SANG-TODO step seq_group_metadata_list length: ", - len(seq_group_metadata_list)) + # print("SANG-TODO step seq_group_metadata_list length: ", + # len(seq_group_metadata_list)) if not scheduler_outputs.is_empty(): # Execute the model. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d29fabb3f47a..2af5a47f66e8 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -152,7 +152,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ - print("SANG-TODO generate: ", prompts, prompt_token_ids) + # print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 051472a9d09b..b99efdcc160a 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -122,7 +122,7 @@ def forward( shape = [batch_size, seq_len, num_heads * head_size] """ batch_size, seq_len, hidden_size = query.shape - print("SANG-TODO query size: ", query.size()) + # print("SANG-TODO query size: ", query.size()) # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -134,12 +134,12 @@ def forward( # profiling run. 
if key_cache is not None and value_cache is not None: if input_metadata.flash_style: - print("SANG-TODO reshape cache flash.") + # print("SANG-TODO reshape cache flash.") cache_ops.reshape_and_cache_flash( key, value, key_cache, value_cache, input_metadata.slot_mapping.flatten()) else: - print("SANG-TODO reshape cache.") + # print("SANG-TODO reshape cache.") cache_ops.reshape_and_cache( key, value, @@ -154,33 +154,11 @@ def forward( if (key_cache is None or value_cache is None # or input_metadata.block_tables.numel() == 0): or not input_metadata.prefix_enabled): - print("SANG-TODO flash attn is used.") - print( - "SANG-TODO query size: ", - query.view(batch_size, seq_len, self.num_heads, - self.head_size).size()) - # if key_cache is not None and value_cache is not None: - # output2 = flash_attn_with_kvcache_paged( - # query.view(batch_size, seq_len, self.num_heads, - # self.head_size), - # key_cache, - # value_cache, - # self.scale, - # input_metadata.block_tables, - # input_metadata.context_lens + seq_len, - # self.alibi_slopes, - # ) - # from flash_attn import flash_attn_func - # breakpoint() - # output3 = flash_attn_func( - # q=query.view(batch_size, seq_len, self.num_heads, - # self.head_size), - # k=key.view(batch_size, seq_len, self.num_kv_heads, self.head_size), - # v=value.view(batch_size, seq_len, self.num_kv_heads, self.head_size), - # softmax_scale=self.scale, - # causal=True, - # alibi_slopes=self.alibi_slopes, - # ) + # print("SANG-TODO flash attn is used.") + # print( + # "SANG-TODO query size: ", + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size).size()) if self.num_kv_heads != self.num_heads: # As of Nov 2023, xformers only supports MHA. For MQA/GQA, # project the key and value tensors to the desired number of diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index db4165141b2a..4251a5a6b61f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -138,8 +138,8 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] - print("SANG-TODO # of requests (seq_group_metadata_list): ", - len(seq_group_metadata_list)) + # print("SANG-TODO # of requests (seq_group_metadata_list): ", + # len(seq_group_metadata_list)) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -152,7 +152,7 @@ def _prepare_prompt( prompt_lens.append(prompt_len) prefix_len = 0 prefix = seq_group_metadata.prefix - print("SANG-TODO prefix, ", prefix) + # print("SANG-TODO prefix, ", prefix) if prefix is not None and prefix.computed: prefix_len = prefix.get_length() prompt_tokens = prompt_tokens[prefix_len:] @@ -500,12 +500,12 @@ def prepare_input_tensors( # SANG-TODO set num prompt tokens and generations? # Prepare input tensors. 
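The attention hunk above writes into one of two cache layouts depending on flash_style. For reference, the two sets of shapes as documented in this series, with illustrative sizes; x is assumed here to be the paged kernel's vectorization width (typically 16 // element_size):

    import torch

    num_blocks, block_size, num_kv_heads, head_size, x = 128, 16, 8, 128, 8

    # Paged-attention layout (written by cache_ops.reshape_and_cache):
    paged_key_cache = torch.empty(num_blocks, num_kv_heads, head_size // x,
                                  block_size, x)
    paged_value_cache = torch.empty(num_blocks, num_kv_heads, head_size,
                                    block_size)

    # Flash-style layout (written by cache_ops.reshape_and_cache_flash):
    flash_key_cache = torch.empty(num_blocks, block_size, num_kv_heads,
                                  head_size)
    flash_value_cache = torch.empty(num_blocks, block_size, num_kv_heads,
                                    head_size)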
if is_prompt: - print("SANG-TODO execute model prompt.") + # print("SANG-TODO execute model prompt.") (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_prompt(seq_group_metadata_list) else: - print("SANG-TODO execute model decode.") + # print("SANG-TODO execute model decode.") (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_decode(seq_group_metadata_list) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index a2c5f8e0b1e0..d8a1c67bb123 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -117,7 +117,7 @@ def profile_num_available_blocks( gpu_memory_utilization: The fraction of the total GPU memory to use. cpu_swap_space: The size of the CPU swap space in bytes. """ - print("SANG-TODO profile_num_available_blocks") + # print("SANG-TODO profile_num_available_blocks") # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() @@ -154,7 +154,7 @@ def profile_num_available_blocks( MAX_INT_32 // cache_block_size) num_cpu_blocks = min(num_cpu_blocks, MAX_INT_32 // cache_block_size) - print("SANG-TODO profile_num_available_blocks done") + # print("SANG-TODO profile_num_available_blocks done") return num_gpu_blocks, num_cpu_blocks @@ -207,7 +207,7 @@ def execute_model( blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, ) -> Optional[SamplerOutput]: - print("SANG-TODO execute model.") + # print("SANG-TODO execute model.") if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups = len(seq_group_metadata_list) From e40bc459e4190321db2f0082fad9a216e8cc6216 Mon Sep 17 00:00:00 2001 From: sang Date: Sat, 2 Mar 2024 23:22:44 -0800 Subject: [PATCH 19/88] [2/n] update sequence data --- tests/core/utils.py | 5 +- tests/samplers/test_sampler.py | 4 + tests/test_sequence.py | 113 +----------------------- tests/worker/spec_decode/utils.py | 1 + tests/worker/test_model_runner.py | 4 + vllm/config.py | 1 - vllm/core/scheduler.py | 1 + vllm/engine/arg_utils.py | 3 +- vllm/model_executor/layers/attention.py | 2 +- vllm/sequence.py | 62 ++++++++++++- vllm/worker/model_runner.py | 1 + vllm/worker/worker.py | 1 - 12 files changed, 79 insertions(+), 119 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index 4c8d29c2246b..aa2e921697eb 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -17,10 +17,7 @@ def create_dummy_prompt( # and prompt "0 ... block_size". 
prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), - prompt_str, - prompt_tokens, - block_size) + prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) seq_group = SequenceGroup(request_id, [prompt], SamplingParams(), time.time(), time.perf_counter()) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 31e865f42ff3..dd098c3df3a8 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -61,6 +61,7 @@ def _do_sample( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, + is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, @@ -237,6 +238,7 @@ def test_sampler_mixed(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, + is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, @@ -313,6 +315,7 @@ def pick_ith(token_ids, logits): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, + is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=SamplingParams(temperature=0, logits_processors=[pick_ith]), @@ -366,6 +369,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, + is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=SamplingParams( temperature=1, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index da6ef99367a9..5af4f9efb601 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,6 +1,7 @@ import pytest from vllm.sequence import SequenceData, Sequence + @pytest.fixture(name="sequence") def create_sequence(seq_len: int, block_size: int) -> Sequence: return Sequence( @@ -10,99 +11,10 @@ def create_sequence(seq_len: int, block_size: int) -> Sequence: block_size=block_size, ) -# @pytest.mark.parametrize("block_size", [1, 2, 4, 8]) -# @pytest.mark.parametrize("num_empty_slots", list(range(8))) -# @pytest.mark.parametrize("seq_len", [0, 1, 100]) -# def test_ensure_num_empty_slots(block_size: int, seq_len: int, -# num_empty_slots: int, sequence: Sequence): -# """Verify ensure_num_empty_slots correctly ensures empty slots. -# """ -# sequence.ensure_num_empty_slots(num_empty_slots) -# num_total_slots = block_size * len(sequence.logical_token_blocks) -# measured_num_empty_slots = sum(block.get_num_empty_slots() -# for block in sequence.logical_token_blocks) -# num_full_slots = num_total_slots - measured_num_empty_slots -# assert measured_num_empty_slots >= num_empty_slots -# assert num_full_slots == seq_len -# @pytest.fixture(name="sequence_with_extra_blocks") -# def add_blocks_to_sequence(sequence: Sequence, -# num_extra_blocks: int) -> Sequence: -# for _ in range(num_extra_blocks): -# sequence._append_logical_block() # pylint: disable=protected-access -# return sequence -# @pytest.mark.parametrize("num_tokens_to_append", [1, 10]) -# @pytest.mark.parametrize("seq_len", [0, 1, 100]) -# @pytest.mark.parametrize("block_size", [1, 2, 4, 8]) -# @pytest.mark.parametrize("num_extra_blocks", [0, 1, 100]) -# def test_append_tokens_correct_placement_in_blocks( -# num_tokens_to_append: int, sequence_with_extra_blocks: Sequence, -# block_size: int, seq_len: int): -# """Verify new tokens are appended at the end of the sequence, instead of the -# last block. 
This enables preallocated empty slots, which requires empty -# blocks after the sequence. -# """ -# token_ids = list(range(num_tokens_to_append)) -# logprobs = [{token_id: 0.0} for token_id in token_ids] -# seq_len_before_append = seq_len -# seq_len_after_append = seq_len_before_append + num_tokens_to_append -# sequence_with_extra_blocks.append_token_ids(token_ids, logprobs) -# # Assert number of full slots equal to total sequence length. -# assert sum(block_size - block.get_num_empty_slots() -# for block in sequence_with_extra_blocks.logical_token_blocks -# ) == seq_len_after_append -# # Assert each appended token is immediately after the original sequence. -# for i, token_id in enumerate(token_ids): -# index = seq_len_before_append + i -# block_token_ids = sequence_with_extra_blocks.logical_token_blocks[ -# index // block_size].get_token_ids() -# assert block_token_ids[index % block_size] == token_id -# @pytest.mark.parametrize("generation_or_prefill", ["generation", "prefill"]) -# @pytest.mark.parametrize("num_output_tokens", [0, 1, 10]) -# @pytest.mark.parametrize("num_prompt_tokens", [5, 50]) -# def test_get_unprocessed_tokens(generation_or_prefill: str, -# num_output_tokens: int, -# num_prompt_tokens: int): -# """Verify sequence data correctly tracks the number of processed tokens. -# """ -# is_generation = generation_or_prefill == "generation" -# if is_generation: -# generated_token_id = 1337 -# prompt_token_ids = list(range(num_prompt_tokens)) -# output_token_ids = list(range(num_output_tokens)) -# data = SequenceData( -# prompt_token_ids=prompt_token_ids[:], -# output_token_ids=output_token_ids[:], -# ) -# if is_generation: -# data.append_token_ids([generated_token_id], logprobs=[0.0]) -# unprocessed_token_ids = data.get_unprocessed_token_ids() -# unprocessed_token_positions = data.get_unprocessed_token_positions() -# if is_generation: -# assert unprocessed_token_ids == [generated_token_id] -# assert unprocessed_token_positions == [ -# num_prompt_tokens + num_output_tokens -# ] -# else: -# assert unprocessed_token_ids == prompt_token_ids + output_token_ids -# assert unprocessed_token_positions == list( -# range(num_prompt_tokens + num_output_tokens)) -# # Reset processed tokens. Everything should behave like a prompt run now. -# data.reset_processed_tokens() -# unprocessed_token_ids = data.get_unprocessed_token_ids() -# unprocessed_token_positions = data.get_unprocessed_token_positions() -# if is_generation: -# assert unprocessed_token_ids == (prompt_token_ids + output_token_ids + -# [generated_token_id]) -# assert unprocessed_token_positions == list( -# range(num_prompt_tokens + num_output_tokens + 1)) -# if not is_generation: -# assert unprocessed_token_ids == prompt_token_ids + output_token_ids -# assert unprocessed_token_positions == list( -# range(num_prompt_tokens + num_output_tokens)) - +# TODO(sang): Upstream more tests. def test_sequence_data_prefill(): - seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4], output_token_ids=[]) + seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4]) assert seq_data.get_prefill_range() == (0, 0) assert seq_data.get_num_unprefilled() == 4 @@ -122,21 +34,4 @@ def test_sequence_data_prefill(): assert seq_data.get_prefill_range() == (4, 4) # append tokens and reset, simulating recompute - seq_data.append_token_ids([1], logprobs=[0.0]) - seq_data.reset_processed_tokens() - - # after reset, the prefill range should be reset to 0 - # but the num_unprefilled should include. 
- # output tokens - assert seq_data.get_prefill_range() == (0, 0) - assert seq_data.get_num_unprefilled() == 5 - - # advance by 2 - assert seq_data.advance_prefill_range(2) == 2 - assert seq_data.get_num_unprefilled() == 3 - assert seq_data.get_prefill_range() == (0, 2) - - # advance range by 3 even though there are only 2 unprefilled tokens - assert seq_data.advance_prefill_range(3) == 3 - assert seq_data.get_num_unprefilled() == 0 - assert seq_data.get_prefill_range() == (2, 5) \ No newline at end of file + seq_data.append_token_id(1, logprob=0.0) diff --git a/tests/worker/spec_decode/utils.py b/tests/worker/spec_decode/utils.py index 8d74509fea48..8bfb49b9e8a0 100644 --- a/tests/worker/spec_decode/utils.py +++ b/tests/worker/spec_decode/utils.py @@ -153,6 +153,7 @@ def create_seq_group_metadata_from_prompts( SequenceGroupMetadata( request_id=str(i), is_prompt=len(cont_token_ids) == 0, + is_chunked_prefill=False, seq_data={ i: SequenceData(prompt_token_ids=prompt_token_ids[:] + diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index f44895a728c7..0e11d74a3a1a 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -21,6 +21,7 @@ def test_prepare_prompt(): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, + is_chunked_prefill=False, seq_data={0: SequenceData(seq_data)}, sampling_params=SamplingParams(temperature=0), block_tables={0: [1]}, @@ -48,3 +49,6 @@ def test_prepare_prompt(): device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) + + +# TODO(sang) Test chunked prefill prompt. diff --git a/vllm/config.py b/vllm/config.py index c8e4c728ab60..f3d79197d0af 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -509,7 +509,6 @@ def _verify_args(self) -> None: "chunked prefill is only supported for flash style") - class DeviceConfig: def __init__(self, device: str = "auto") -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5e7cc3091d77..c816529234e5 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -379,6 +379,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data[seq_id] = seq.data block_tables[seq_id] = self.block_manager.get_block_table(seq) + # SANG-TODO Update chunked prefill related info. seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=scheduler_outputs.prompt_run, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 673119c91230..924b40d7d0d6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -325,7 +325,8 @@ def create_engine_configs( self.max_paddings, max_chunked_prefill_len=self.max_chunked_prefill_len, max_num_prompt_seqs=self.max_num_prompt_seqs, - flash_style=self.flash_style,) + flash_style=self.flash_style, + ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 5122c0d1ef48..679f483b59c1 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -116,7 +116,6 @@ def ref_masked_attention( # def _update_cache(self): - def forward( self, query: torch.Tensor, @@ -298,6 +297,7 @@ def forward( # Reshape the output tensor. 
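The prefill-range API exercised by test_sequence_data_prefill above (and added to SequenceData in the sequence.py hunk below) is what lets a long prompt be fed to the model in several chunks. A usage sketch, assuming a 10-token prompt and a 4-token chunk budget; per the SequenceGroupMetadata docstring below, every chunk except the last one is flagged as a chunked prefill:

    from vllm.sequence import SequenceData

    seq_data = SequenceData(prompt_token_ids=list(range(10)))
    chunk_size = 4
    step = 0
    while seq_data.get_num_unprefilled() > 0:
        advanced = seq_data.advance_prefill_range(chunk_size)
        start, end = seq_data.get_prefill_range()
        is_chunked_prefill = seq_data.get_num_unprefilled() > 0
        print(step, (start, end), advanced, is_chunked_prefill)
        step += 1
    # step 0: (0, 4)   advanced=4  is_chunked_prefill=True
    # step 1: (4, 8)   advanced=4  is_chunked_prefill=True
    # step 2: (8, 10)  advanced=2  is_chunked_prefill=False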
return output.view(batch_size, seq_len, hidden_size) + def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, diff --git a/vllm/sequence.py b/vllm/sequence.py index 040e9756e15c..d4bd78068245 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -2,7 +2,7 @@ import copy import enum from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union from vllm.block import LogicalTokenBlock from vllm.prefix import Prefix @@ -79,6 +79,8 @@ class SequenceData: prompt_token_ids: The token IDs of the prompt. output_token_ids: The token IDs of the output. cumulative_logprob: The cumulative log probability of the output. + _prefill_start: The start index of the prefill. + _prefill_end: The end index of the prefill. """ def __init__( @@ -87,7 +89,9 @@ def __init__( ) -> None: self.prompt_token_ids = prompt_token_ids self.output_token_ids: List[int] = [] - self.cumulative_logprob = 0.0 + self.cumulative_logprob: float = 0.0 + self._prefill_start: int = 0 + self._prefill_end: int = 0 def append_token_id(self, token_id: int, logprob: float) -> None: self.output_token_ids.append(token_id) @@ -105,6 +109,36 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.prompt_token_ids + self.output_token_ids + def advance_prefill_range(self, size: int) -> int: + """Advance the prefill range by the specified amount + + Args: + size: The amount to advance the prefill range. + Returns: + The actual number of advanced tokens. + """ + self._prefill_start = self._prefill_end + # The increased range could be larger than the seq length. + # Clamp it to the seq length. + # Note that we use prompt_len + output_len instead of + # prompt_len here. This is because during recompute + # we need to prefill for both prompt and output. + self._prefill_end = min(self._prefill_end + size, self.get_len()) + return self._prefill_end - self._prefill_start + + def get_prefill_range(self) -> Tuple[int, int]: + """Returns the prefill range.""" + return self._prefill_start, self._prefill_end + + def get_num_unprefilled(self) -> int: + """Return the number of prefil tokens that are not completed. + + Note that we use prompt_len + output_len instead of + prompt_len here. This is because during recompute + we need to prefill for both prompt and output. + """ + return self.get_len() - self._prefill_end + def get_last_token_id(self) -> int: if not self.output_token_ids: return self.prompt_token_ids[-1] @@ -363,6 +397,24 @@ def get_unfinished_seqs(self) -> List[Sequence]: def get_finished_seqs(self) -> List[Sequence]: return [seq for seq in self.seqs_dict.values() if seq.is_finished()] + def advance_prefill_range(self, size: int) -> int: + """Advance the prefill range by the specified amount. + + Args: + size: The amount to advance the prefill range. + Returns: + The actual number of advanced tokens. + """ + # All sequences in the group should have the same prompt. + return [ + seq.data.advance_prefill_range(size) + for seq in self.seqs_dict.values() + ][0] + + def get_num_unprefilled(self) -> int: + # All sequences in the group should have the same prompt. + return list(self.seqs_dict.values())[0].data.get_num_unprefilled() + def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: return len(self.get_seqs(status)) @@ -402,6 +454,10 @@ class SequenceGroupMetadata: Args: request_id: The ID of the request. is_prompt: Whether the request is at prompt stage. 
+ is_chunked_prefill: Whether the request is at chunked prefill stage. + If a prefill request is chunked, the first ~ n-1th chunks are + chunked prefill requests. + Note that chunked_prefill is also a prompt stage. seq_data: The sequence data. (Seq id -> sequence data) sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block @@ -415,6 +471,7 @@ def __init__( self, request_id: str, is_prompt: bool, + is_chunked_prefill: bool, seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], @@ -424,6 +481,7 @@ def __init__( ) -> None: self.request_id = request_id self.is_prompt = is_prompt + self.is_chunked_prefill = is_chunked_prefill self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4251a5a6b61f..95ad335859f2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -652,6 +652,7 @@ def profile_run(self) -> None: seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, + is_chunked_prefill=False, seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b41f5c4249ed..d8a1c67bb123 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -234,7 +234,6 @@ def execute_model( if num_seq_groups == 0: return {} - breakpoint() output = self.model_runner.execute_model(seq_group_metadata_list, self.gpu_cache) return output From d85670f41fceae6ced169925ee66a51a9502f74e Mon Sep 17 00:00:00 2001 From: sang Date: Sun, 3 Mar 2024 00:49:24 -0800 Subject: [PATCH 20/88] [2/n] add prefill range apis --- tests/chunked_prefill/test_correctness.py | 46 ++++++++++++----------- vllm/core/scheduler.py | 2 + vllm/engine/llm_engine.py | 4 +- vllm/model_executor/input_metadata.py | 17 +++++++-- vllm/worker/model_runner.py | 32 ++++++++++++++-- 5 files changed, 70 insertions(+), 31 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index b7180617019f..b40f43ee6892 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -29,9 +29,11 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("block_size", [256]) -@pytest.mark.parametrize("max_chunked_prefill_len", [-1, 16, 64]) -@pytest.mark.parametrize("max_num_prompt_seqs", [1, 2, 100]) -@pytest.mark.parametrize("tensor_parallel_size", [1, 2]) +@pytest.mark.parametrize("max_chunked_prefill_len", [16]) +# SANG-TODO +# @pytest.mark.parametrize("max_num_prompt_seqs", [1, 2, 100]) +@pytest.mark.parametrize("max_num_prompt_seqs", [1]) +@pytest.mark.parametrize("tensor_parallel_size", [2]) def test_models( vllm_runner, model: str, @@ -49,19 +51,19 @@ def test_models( f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" ) - # print("loading page attention models..") - # pg_model = vllm_runner(model, dtype=dtype) - # expected_outputs = [] + print("loading page attention models..") + pg_model = vllm_runner(model, dtype=dtype) + expected_outputs = [] - # print("generating tokens...") - # expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) - # print("generating tokens finished") + print("generating tokens...") + expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) + 
print("generating tokens finished") - # del pg_model + del pg_model - # destroy_model_parallel() - # gc.collect() - # torch.cuda.empty_cache() + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() flash_attn_model = vllm_runner( model, @@ -83,12 +85,12 @@ def test_models( gc.collect() torch.cuda.empty_cache() - # for flash_attn_outputs in flash_attn_output_by_batches: - # for i in range(len(flash_attn_outputs)): - # fa_output_ids, fa_output_str = flash_attn_outputs[i] - # vllm_output_ids, vllm_output_str = expected_outputs[ - # i % len(expected_outputs)] - # assert fa_output_ids == vllm_output_ids, ( - # f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" - # f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" - # ) + for flash_attn_outputs in flash_attn_output_by_batches: + for i in range(len(flash_attn_outputs)): + fa_output_ids, fa_output_str = flash_attn_outputs[i] + vllm_output_ids, vllm_output_str = expected_outputs[ + i % len(expected_outputs)] + assert fa_output_ids == vllm_output_ids, ( + f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" + f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" + ) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c816529234e5..4565990d58e5 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -249,6 +249,7 @@ def _schedule(self) -> SchedulerOutputs: curr_loras.add(lora_int_id) self.waiting.popleft() self._allocate(seq_group) + seq_group.advance_prefill_range(num_prompt_tokens) self.running.append(seq_group) num_curr_seqs += num_new_seqs scheduled.append(seq_group) @@ -383,6 +384,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=scheduler_outputs.prompt_run, + is_chunked_prefill=False, seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 24057ed7809f..1c5b3b5338f3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -831,8 +831,8 @@ def step(self) -> List[RequestOutput]: >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - print("SANG-TODO step seq_group_metadata_list length: ", - len(seq_group_metadata_list)) + # print("SANG-TODO step seq_group_metadata_list length: ", + # len(seq_group_metadata_list)) if not scheduler_outputs.is_empty(): # Execute the model. all_outputs = self._run_workers( diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 00acf71a10b5..2f298689adf8 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -8,6 +8,8 @@ class InputMetadata: Args: prompt_lens: Lengths of prompts. + num_chunked_prefill: Number of chunked prefill requests across + sequences. slot_mapping: The address to write the new KV to of each token. index: token_id, value: address within kv_cache. max_context_len: The maximum context length. 
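The InputMetadata docstring above describes slot_mapping as the flat cache address each new token's KV is written to. A small illustration of how one slot is derived from a sequence's block table (the numbers are made up; this mirrors the block_number * block_size + block_offset scheme used by the model runner):

    block_size = 16
    block_table = [7, 3]          # physical blocks assigned to this sequence
    token_position = 20           # the 21st token of the sequence

    block_number = block_table[token_position // block_size]    # -> 3
    block_offset = token_position % block_size                   # -> 4
    slot = block_number * block_size + block_offset               # -> 52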
@@ -21,7 +23,7 @@ class InputMetadata: """ def __init__(self, is_prompt: bool, slot_mapping: torch.Tensor, - prompt_lens: Optional[torch.Tensor], + prompt_lens: Optional[torch.Tensor], num_chunked_prefill: int, max_seq_len: Optional[int], start_loc: Optional[torch.Tensor], max_context_len: Optional[int], context_lens: Optional[torch.Tensor], @@ -30,6 +32,7 @@ def __init__(self, is_prompt: bool, slot_mapping: torch.Tensor, prefix_enabled: bool) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens + self.num_chunked_prefill = num_chunked_prefill self.max_seq_len = max_seq_len self.start_loc = start_loc self.max_context_len = max_context_len @@ -72,17 +75,23 @@ def __init__(self, is_prompt: bool, slot_mapping: torch.Tensor, # dim=0, # dtype=self.cum_prompt_query_lens.dtype, # out=self.cum_prompt_query_lens[1:]) + # torch.cumsum(self.context_lens[:self.num_prompts], + # dim=0, + # dtype=self.cum_prompt_context_lens.dtype, + # out=self.cum_prompt_context_lens[1:]) # # TODO: this will be different once we support chunked prefills. # self.cum_prompt_context_lens = self.cum_prompt_query_lens - # self.max_context_len = max(self.max_context_len, self.max_prompt_len) + # self.max_context_len = max_context_len # # Generation related metadata # # This value might include padding if CudaGraph is enabled. # self.num_generation_tokens = num_generation_tokens # # This is the source of truth for the number of generation tokens. - # self.num_generation_tokens_tensor = torch.cuda.IntTensor( - # [num_generation_tokens]) + # self.num_generation_tokens_tensor = torch.tensor( + # [self.num_generation_tokens], + # dtype=torch.int32 if self.flash_style else torch.long, + # device='cuda') def __repr__(self) -> str: return ("InputMetadata(" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 95ad335859f2..8cee718dbe2e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -138,6 +138,7 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] + num_chunked_prefill = 0 # print("SANG-TODO # of requests (seq_group_metadata_list): ", # len(seq_group_metadata_list)) for seq_group_metadata in seq_group_metadata_list: @@ -146,8 +147,17 @@ def _prepare_prompt( assert len(seq_ids) == 1 seq_id = seq_ids[0] + if seq_group_metadata.is_chunked_prefill: + num_chunked_prefill += 1 + # TODO(sang): Support it. + if prefix is not None: + raise RuntimeError( + "chunked prefill cannot be used with prefix caching now." + ) + seq_data = seq_group_metadata.seq_data[seq_id] - prompt_tokens = seq_data.get_token_ids() + prefill_start, prefill_end = seq_data.get_prefill_range() + prompt_tokens = seq_data.get_token_ids()[prefill_start:prefill_end] prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) prefix_len = 0 @@ -173,8 +183,11 @@ def _prepare_prompt( input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. + # NOTE(sang): prefill_end is always # of prompts if chunked + # prefill is not enabled. Prefix caching is not working with + # chunked prefill now. 
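With chunked prefill, _prepare_prompt above only feeds the current window of the prompt returned by get_prefill_range. A worked example with illustrative numbers for the second chunk of a 10-token prompt and a 4-token chunk budget:

    # seq_data.get_prefill_range() for the second chunk:
    prefill_start, prefill_end = 4, 8

    token_ids = list(range(100, 110))                      # made-up prompt ids
    prompt_tokens = token_ids[prefill_start:prefill_end]   # ids 104..107

    # These four tokens occupy positions 4..7 of the sequence, so their
    # KV entries land in slots 4..7 of the sequence's block table.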
input_positions.append( - list(range(prefix_len, prefix_len + len(prompt_tokens)))) + list(range(prefix_len, prefix_len + prefill_end))) lora_id = seq_group_metadata.lora_int_id @@ -207,7 +220,14 @@ def _prepare_prompt( "Prefix caching is currently not supported with " "sliding window attention") start_idx = max(0, prompt_len - self.sliding_window) - for i in range(prefix_len, prompt_len): + + # If chunked prefill is enabled, prefix_len is always 0. + # TODO(sang) This is hack. We should clean it up when + # supporting prefix cache + chunked prefill. + if prefix_len == 0: + prefix_len = prefill_start + + for i in range(prefix_len, prefill_end): if i < start_idx: slot_mapping[-1].append(_PAD_SLOT_ID) continue @@ -261,6 +281,7 @@ def _prepare_prompt( input_metadata = InputMetadata(is_prompt=True, slot_mapping=slot_mapping, prompt_lens=prompt_lens_tensor, + num_chunked_prefill=num_chunked_prefill, max_seq_len=max_prompt_len, start_loc=start_loc_tensor, max_context_len=None, @@ -390,6 +411,7 @@ def _prepare_decode( input_metadata = InputMetadata(is_prompt=False, slot_mapping=slot_mapping, prompt_lens=None, + num_chunked_prefill=0, max_seq_len=None, start_loc=None, max_context_len=max_context_len, @@ -533,6 +555,7 @@ def prepare_input_tensors( "is_prompt": input_metadata.is_prompt, "slot_mapping": input_metadata.slot_mapping, "prompt_lens": input_metadata.prompt_lens, + "num_chunked_prefill": input_metadata.num_chunked_prefill, "max_seq_len": input_metadata.max_seq_len, "start_loc": input_metadata.start_loc, "max_context_len": input_metadata.max_context_len, @@ -557,6 +580,7 @@ def prepare_input_tensors( is_prompt=metadata_dict["is_prompt"], slot_mapping=metadata_dict["slot_mapping"], prompt_lens=metadata_dict["prompt_lens"], + num_chunked_prefill=metadata_dict["num_chunked_prefill"], max_seq_len=metadata_dict["max_seq_len"], start_loc=metadata_dict["start_loc"], max_context_len=metadata_dict["max_context_len"], @@ -649,6 +673,7 @@ def profile_run(self) -> None: seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) seq_data = SequenceData([0] * seq_len) + seq_data.advance_prefill_range(seq_len) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, @@ -743,6 +768,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: is_prompt=False, slot_mapping=slot_mapping[:batch_size], prompt_lens=None, + num_chunked_prefill=0, max_seq_len=None, start_loc=None, max_context_len=self.max_context_len_to_capture, From 08c8541ca04622c2e00811ef4f86caa993342631 Mon Sep 17 00:00:00 2001 From: sang Date: Sun, 3 Mar 2024 08:37:25 -0800 Subject: [PATCH 21/88] . 
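The scheduler side of chunked prefill is only partially shown in this series: the scheduler hands each prompt a budget per step and lets the new range API do the bookkeeping (see the seq_group.advance_prefill_range call added to the scheduler above). A hypothetical helper sketching that policy; the names below are illustrative, not the actual scheduler code:

    def schedule_prefill_chunk(seq_group, max_chunked_prefill_len: int):
        # Advance this prompt by at most the configured chunk budget.
        advanced = seq_group.advance_prefill_range(max_chunked_prefill_len)
        # Every chunk except the one that finishes the prompt is a
        # "chunked prefill" step; the final chunk is a regular prompt step.
        is_chunked_prefill = seq_group.get_num_unprefilled() > 0
        return advanced, is_chunked_prefill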
--- tests/chunked_prefill/test_correctness.py | 6 ++--- vllm/core/block_manager.py | 2 ++ vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/attention.py | 28 ++++++++++++++++------- vllm/worker/worker.py | 4 ++-- 5 files changed, 28 insertions(+), 14 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index 561ef24713dd..17ce58149fab 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -53,9 +53,9 @@ def test_models( torch.cuda.empty_cache() flash_attn_model = vllm_runner(model, - dtype=dtype, - flash_style=True, - block_size=block_size) + dtype=dtype,) + # flash_style=True, + # block_size=block_size) flash_attn_output_by_batches = [] for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 08d519ab767a..b5ae90e217ba 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -200,6 +200,8 @@ def allocate(self, seq_group: SequenceGroup) -> None: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) + print("allcate a new block for seq_group", seq_group) + print("block", block) block_table.append(block) # Assign the block table for each sequence. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 93adc98adda5..e1932bb761ef 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -146,7 +146,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ - # print("SANG-TODO generate: ", prompts, prompt_token_ids) + print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 17a0bedf9e0c..f80a47565e95 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -153,6 +153,7 @@ def forward( # normal attention if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): + # or not input_metadata.prefix_enabled): # print("SANG-TODO flash attn is used.") # print( # "SANG-TODO query size: ", @@ -212,7 +213,6 @@ def forward( query = query.unflatten(0, (batch_size, seq_len)) key = key.unflatten(0, (batch_size, seq_len)) value = value.unflatten(0, (batch_size, seq_len)) - out = xops.memory_efficient_attention_forward( query, key, @@ -224,6 +224,19 @@ def forward( (is_hip()) else None, ) output = out.view_as(query) + # if key_cache is not None and value_cache is not None: + # breakpoint() + # output2 = flash_attn_with_kvcache_paged( + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size), + # key_cache, + # value_cache, + # self.scale, + # input_metadata.block_tables, + # input_metadata.context_lens, + # self.alibi_slopes, + # ) + # breakpoint() else: if input_metadata.flash_style: output = flash_attn_with_kvcache_paged( @@ -237,6 +250,8 @@ def forward( self.alibi_slopes, ) else: + print("SANG-TODO context attention") + breakpoint() # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -257,6 +272,8 @@ def forward( else: # Decoding run. 
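The decoding branch just below goes through flash_attn_with_kvcache_paged, a thin wrapper over flash-attn's flash_attn_with_kvcache with a paged block table (the wrapper itself is simplified later in this series). A minimal call sketch with dummy tensors, assuming a flash-attn build that supports the block_table argument and a block size of 256 as used by the tests here:

    import torch
    from flash_attn import flash_attn_with_kvcache

    batch_size, num_heads, num_kv_heads, head_size = 2, 8, 8, 128
    num_blocks, block_size = 16, 256
    scale = head_size ** -0.5

    query = torch.randn(batch_size, 1, num_heads, head_size,
                        dtype=torch.float16, device="cuda")
    key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size,
                            dtype=torch.float16, device="cuda")
    value_cache = torch.randn_like(key_cache)
    block_tables = torch.zeros(batch_size, 4, dtype=torch.int32, device="cuda")
    context_lens = torch.tensor([10, 7], dtype=torch.int32, device="cuda")

    out = flash_attn_with_kvcache(
        query, key_cache, value_cache,
        None, None,                    # KV cache is already updated beforehand
        cache_seqlens=context_lens,
        block_table=block_tables,
        softmax_scale=scale,
        causal=True,
    )
    # out: [batch_size, 1, num_heads, head_size]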
+ # breakpoint() + print("SANG-TODO decoding") if input_metadata.flash_style: output = flash_attn_with_kvcache_paged( query.view(batch_size, seq_len, self.num_heads, @@ -419,17 +436,12 @@ def flash_attn_with_kvcache_paged( # Inplace update is slow. We don't use it. # We assume kvcache is already updated before # calling this API. - None, - None, - rotary_cos=None, - rotary_sin=None, + None, # key + None, # value cache_seqlens=context_lens, - cache_batch_idx=None, block_table=block_tables, softmax_scale=scale, causal=True, - window_size=(-1, -1), - rotary_interleaved=False, alibi_slopes=alibi_slopes, num_splits=0, ) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d8a1c67bb123..02975bcd5326 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -117,7 +117,7 @@ def profile_num_available_blocks( gpu_memory_utilization: The fraction of the total GPU memory to use. cpu_swap_space: The size of the CPU swap space in bytes. """ - # print("SANG-TODO profile_num_available_blocks") + print("SANG-TODO profile_num_available_blocks") # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() @@ -154,7 +154,7 @@ def profile_num_available_blocks( MAX_INT_32 // cache_block_size) num_cpu_blocks = min(num_cpu_blocks, MAX_INT_32 // cache_block_size) - # print("SANG-TODO profile_num_available_blocks done") + print("SANG-TODO profile_num_available_blocks done") return num_gpu_blocks, num_cpu_blocks From 3bac9af4f4166b0c059a355c9ece5bfe56e9613d Mon Sep 17 00:00:00 2001 From: sang Date: Sun, 3 Mar 2024 12:02:03 -0800 Subject: [PATCH 22/88] ip --- tests/chunked_prefill/test_correctness.py | 9 +- tests/kernels/test_flash_attention.py | 252 +++++++++++++++++- vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/attention.py | 150 ++++++----- .../layers/triton_kernel/prefix_prefill.py | 2 +- vllm/worker/worker.py | 4 +- 6 files changed, 344 insertions(+), 75 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index 17ce58149fab..d6cb0a4dc0ac 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -53,10 +53,12 @@ def test_models( torch.cuda.empty_cache() flash_attn_model = vllm_runner(model, - dtype=dtype,) + dtype=dtype) # flash_style=True, # block_size=block_size) flash_attn_output_by_batches = [] + # flash_attn_output_by_batches.extend(flash_attn_model.generate_greedy(TEST_PROMPTS, max_tokens)) + flash_attn_output_by_batches = [] for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] flash_attn_output_by_batches.append( @@ -64,6 +66,11 @@ def test_models( del flash_attn_model + # for e, f in zip(expected_outputs, flash_attn_output_by_batches): + # # print("expected: ", e[1]) + # # print("flash: ", f[1]) + # assert e[1] == f[1] + destroy_model_parallel() gc.collect() torch.cuda.empty_cache() diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index 74c9d3008f17..2938dff7cac8 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -1,5 +1,5 @@ import random -from typing import Optional, Tuple +from typing import Optional, Tuple, List import pytest import torch @@ -238,3 +238,253 @@ def test_flash_paged_attention( ) assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +def ref_multi_query_kv_attention( + cu_seq_lens: List[int], + query: 
torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + ref_outputs = [] + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + + # Create attention mask. + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype, device="cuda") + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +def ref_multi_query_kv_attention_padded( + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + cu_seq_lens: List[int], + context_lens: List[int], + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + block_size = value_cache.shape[-3] + ref_outputs = [] + + for i in range(num_seqs): + q_start_idx = cu_seq_lens[i] + q_end_idx = cu_seq_lens[i + 1] + seq_len = q_end_idx - q_start_idx + + context_len = context_lens[i] + + block_table = block_tables[i] + keys = [] + values = [] + + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, block_offset, :, :] + keys.append(k) + + v = value_cache[block_number, block_offset, :, :] + values.append(v) + + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + q = query[q_start_idx:q_end_idx, :, :] + k = keys[:context_len, :, :] + v = values[:context_len, :, :] + + assert seq_len <= context_len + + # pad q if seq_len is less than context_len + # this is for correct calculation of attention. + if seq_len < context_len: + indices = [i % seq_len for i in range(context_len - seq_len)] + q_left_pad = q[indices, :, :] + q = torch.cat([q_left_pad, q], dim=0) + + # Create attention mask. 
+ attn_mask = torch.triu(torch.ones(context_len, + context_len, + dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype, device="cuda") + ref_output = ref_masked_attention( + q, + k, + v, + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output[-seq_len:, :, :]) + breakpoint + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +def is_a100(): + return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 + + +if not is_a100(): + NUM_HEADS_SMALL = [(16, 16), (16, 8)] + MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) + + +@pytest.mark.parametrize("num_seqs", [17]) +@pytest.mark.parametrize("num_heads", [(40, 40)]) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("version", ["flash"]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("block_size", [256]) +@torch.inference_mode() +def test_multi_query_kv_attention( + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + version: str, + seed: int, + block_size: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. + # As the xformers library is already tested with its own tests, we can use + # a smaller MAX_SEQ_LEN here. + max_len = min(MAX_SEQ_LEN, 4096) + + seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] + max_seq_len = max(seq_lens) + context_lens = seq_lens + max_context_len = max(context_lens) + + num_tokens = sum(seq_lens) + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + + cu_context_lens = [0] + for context_len in context_lens: + cu_context_lens.append(cu_context_lens[-1] + context_len) + + print(f"cu_seq_lens={cu_seq_lens}, cu_context_lens={cu_context_lens}") + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + num_queries_per_kv = num_query_heads // num_kv_heads + + value_cache = torch.empty(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + key_cache = torch.empty(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + query = torch.empty(max_seq_len * num_seqs, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + value_cache.uniform_(-scale, scale) + key_cache.uniform_(-scale, scale) + query.uniform_(-scale, scale) + + # Create the block tables. 
+ max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + output = torch.empty_like(query) + + if version == "flash": + # flash_multi_query_cached_kv_attention_varlen( + # output, + # query, + # key_cache, + # value_cache, + # scale, + # block_tables, + # torch.cuda.IntTensor(cu_seq_lens), + # torch.cuda.IntTensor(cu_context_lens), + # block_size, + # max_seq_len, + # max_context_len, + # None, + # ) + from flash_attn import flash_attn_func + breakpoint() + # output = flash_attn_func( + # query.unsqueeze(0), + # k=key, + # v=value, + # softmax_scale=scale, + # causal=True, + # alibi_slopes=alibi_slopes, + # ) + output = flash_attn_with_kvcache_paged( + query.view(num_seqs, max_seq_len, num_query_heads, head_size), + key_cache, + value_cache, + scale, + block_tables, + torch.tensor(context_lens, dtype=torch.int, device="cuda"), + None, + ) + else: + assert False, f"{version=} is not supported" + + ref_output = ref_multi_query_kv_attention_padded( + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + cu_seq_lens, + context_lens, + scale, + dtype, + ) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e1932bb761ef..93adc98adda5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -146,7 +146,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ - print("SANG-TODO generate: ", prompts, prompt_token_ids) + # print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index f80a47565e95..db5e5f44ad42 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -154,76 +154,90 @@ def forward( if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # or not input_metadata.prefix_enabled): - # print("SANG-TODO flash attn is used.") - # print( - # "SANG-TODO query size: ", - # query.view(batch_size, seq_len, self.num_heads, - # self.head_size).size()) - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - # Set attention bias if not provided. This typically happens at - # the very attention layer of every iteration. - # FIXME(woosuk): This is a hack. 
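When num_kv_heads != num_heads, the xformers prompt path above expands the key/value heads so every query head has a matching KV head (xformers only supports MHA as of this series). A standalone illustration of that expansion, using made-up sizes:

    import torch

    num_tokens, num_heads, num_kv_heads, head_size = 6, 8, 2, 16
    num_queries_per_kv = num_heads // num_kv_heads        # -> 4

    key = torch.randn(num_tokens, num_kv_heads, head_size)
    expanded = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                         num_queries_per_kv, head_size)
    # expanded is a broadcast view of shape
    # [num_tokens, num_kv_heads, num_queries_per_kv, head_size],
    # i.e. each KV head is shared by num_queries_per_kv query heads.
    assert expanded.reshape(num_tokens, num_heads, head_size).shape == (6, 8, 16)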
- if input_metadata.attn_bias is None: + # if key_cache is not None and value_cache is not None and input_metadata.flash_style: + if False: + print("SANG-TODO flash attention.") + output = flash_attn_with_kvcache_paged( + query.view(batch_size, seq_len, self.num_heads, + self.head_size), + key_cache, + value_cache, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + self.alibi_slopes, + ) + else: + # print("SANG-TODO flash attn is used.") + # print( + # "SANG-TODO query size: ", + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size).size()) + if self.num_kv_heads != self.num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # TODO(woosuk): Use MQA/GQA kernels for higher performance. + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], + self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * batch_size) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = attn_bias + else: + input_metadata.attn_bias = _make_alibi_bias( + self.alibi_slopes, self.num_kv_heads, batch_size, + seq_len, query.dtype) + + if self.use_ref_attention: + output = self.ref_masked_attention( + query, + key, + value, + ) + # Using view got RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use reshape instead + return output.reshape(batch_size, seq_len, hidden_size) + + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. if self.alibi_slopes is None: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size) - if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention( - self.sliding_window) - input_metadata.attn_bias = attn_bias + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) else: - input_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, batch_size, - seq_len, query.dtype) - - if self.use_ref_attention: - output = self.ref_masked_attention( + query = query.unflatten(0, (batch_size, seq_len)) + key = key.unflatten(0, (batch_size, seq_len)) + value = value.unflatten(0, (batch_size, seq_len)) + out = xops.memory_efficient_attention_forward( query, key, value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, ) - # Using view got RuntimeError: view size is not compatible with input tensor's size and stride - # (at least one dimension spans across two contiguous subspaces). Use reshape instead - return output.reshape(batch_size, seq_len, hidden_size) - - # TODO(woosuk): Too many view operations. Let's try to reduce - # them in the future for code readability. 
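The alibi branch above defers to _make_alibi_bias for the prefill-time bias; the decode-side reference in the kernel tests builds the equivalent per-head bias directly. A minimal sketch of that form, assuming slopes shaped [num_heads] (the helper name is illustrative):

    import torch

    def make_decode_alibi_bias(alibi_slopes, context_len):
        # Returns [num_heads, 1, context_len]: the newest key position gets
        # bias 0, older positions get slope * (position - context_len + 1),
        # i.e. increasingly negative values further back in the context.
        position_ids = torch.arange(context_len, device=alibi_slopes.device)
        rel = (position_ids - context_len + 1).float()
        return alibi_slopes.view(-1, 1, 1) * rel.view(1, 1, -1)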
- if self.alibi_slopes is None: - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - else: - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) - out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ) - output = out.view_as(query) + output = out.view_as(query) # if key_cache is not None and value_cache is not None: # breakpoint() # output2 = flash_attn_with_kvcache_paged( @@ -239,6 +253,7 @@ def forward( # breakpoint() else: if input_metadata.flash_style: + assert False output = flash_attn_with_kvcache_paged( query.view(batch_size, seq_len, self.num_heads, self.head_size), @@ -250,8 +265,7 @@ def forward( self.alibi_slopes, ) else: - print("SANG-TODO context attention") - breakpoint() + # print("SANG-TODO context attention") # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -261,8 +275,7 @@ def forward( output, key_cache, value_cache, - input_metadata. - block_tables, # [BS, max_block_per_request] + input_metadata.block_tables, # [BS, max_block_per_request] input_metadata.start_loc, input_metadata.prompt_lens, input_metadata.context_lens, @@ -273,7 +286,6 @@ def forward( else: # Decoding run. # breakpoint() - print("SANG-TODO decoding") if input_metadata.flash_style: output = flash_attn_with_kvcache_paged( query.view(batch_size, seq_len, self.num_heads, diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index 70f09224f1cf..c6054de2b718 100644 --- a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -696,7 +696,7 @@ def context_attention_fwd(q, ) return - _fwd_kernel[grid]( + _fwd_kernel_flash_attn_v2[grid]( q, k, v, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 02975bcd5326..d8a1c67bb123 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -117,7 +117,7 @@ def profile_num_available_blocks( gpu_memory_utilization: The fraction of the total GPU memory to use. cpu_swap_space: The size of the CPU swap space in bytes. """ - print("SANG-TODO profile_num_available_blocks") + # print("SANG-TODO profile_num_available_blocks") # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. 
torch.cuda.empty_cache() @@ -154,7 +154,7 @@ def profile_num_available_blocks( MAX_INT_32 // cache_block_size) num_cpu_blocks = min(num_cpu_blocks, MAX_INT_32 // cache_block_size) - print("SANG-TODO profile_num_available_blocks done") + # print("SANG-TODO profile_num_available_blocks done") return num_gpu_blocks, num_cpu_blocks From 2487bda216be434087c4e7a8a4288461c895867a Mon Sep 17 00:00:00 2001 From: sang Date: Sun, 3 Mar 2024 12:05:26 -0800 Subject: [PATCH 23/88] ip --- tests/chunked_prefill/test_correctness.py | 35 +++++++++++------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index c33f7ff9f46c..a88169286043 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -73,29 +73,28 @@ def test_models( max_chunked_prefill_len=max_chunked_prefill_len, max_num_prompt_seqs=max_num_prompt_seqs, tensor_parallel_size=tensor_parallel_size) - flash_attn_output_by_batches = [] - for i in range(10): - prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] - flash_attn_output_by_batches.append( - flash_attn_model.generate_greedy(prompts, max_tokens)) + expected_outputs.extend(flash_attn_model.generate_greedy(TEST_PROMPTS, max_tokens)) + # flash_attn_output_by_batches = [] + # for i in range(10): + # prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] + # flash_attn_output_by_batches.append( + # flash_attn_model.generate_greedy(prompts, max_tokens)) del flash_attn_model - # for e, f in zip(expected_outputs, flash_attn_output_by_batches): - # # print("expected: ", e[1]) - # # print("flash: ", f[1]) - # assert e[1] == f[1] + for e, f in zip(expected_outputs, flash_attn_output_by_batches): + assert e[1] == f[1] destroy_model_parallel() gc.collect() torch.cuda.empty_cache() - for flash_attn_outputs in flash_attn_output_by_batches: - for i in range(len(flash_attn_outputs)): - fa_output_ids, fa_output_str = flash_attn_outputs[i] - vllm_output_ids, vllm_output_str = expected_outputs[ - i % len(expected_outputs)] - assert fa_output_ids == vllm_output_ids, ( - f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" - f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" - ) + # for flash_attn_outputs in flash_attn_output_by_batches: + # for i in range(len(flash_attn_outputs)): + # fa_output_ids, fa_output_str = flash_attn_outputs[i] + # vllm_output_ids, vllm_output_str = expected_outputs[ + # i % len(expected_outputs)] + # assert fa_output_ids == vllm_output_ids, ( + # f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" + # f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" + # ) From 81151e8e978844eade9c69643a928331b9dc25a2 Mon Sep 17 00:00:00 2001 From: sang Date: Sun, 3 Mar 2024 12:24:01 -0800 Subject: [PATCH 24/88] ip --- tests/chunked_prefill/test_correctness.py | 53 ++++++++++++++--------- vllm/config.py | 8 ++-- vllm/worker/model_runner.py | 19 ++++---- 3 files changed, 47 insertions(+), 33 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index a88169286043..3159f5c3a9d4 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -33,8 +33,9 @@ # SANG-TODO # @pytest.mark.parametrize("max_num_prompt_seqs", [1, 2, 100]) @pytest.mark.parametrize("max_num_prompt_seqs", [1]) -@pytest.mark.parametrize("tensor_parallel_size", [2]) 
+@pytest.mark.parametrize("tensor_parallel_size", [1]) def test_models( + hf_runner, vllm_runner, model: str, dtype: str, @@ -51,6 +52,10 @@ def test_models( f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" ) + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(TEST_PROMPTS, max_tokens) + del hf_model + print("loading page attention models..") pg_model = vllm_runner(model, dtype=dtype) expected_outputs = [] @@ -61,33 +66,41 @@ def test_models( del pg_model + for i in range(len(TEST_PROMPTS)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = expected_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + destroy_model_parallel() gc.collect() torch.cuda.empty_cache() - flash_attn_model = vllm_runner( - model, - dtype=dtype, - # block_size=block_size, - # flash_style=True, - max_chunked_prefill_len=max_chunked_prefill_len, - max_num_prompt_seqs=max_num_prompt_seqs, - tensor_parallel_size=tensor_parallel_size) - expected_outputs.extend(flash_attn_model.generate_greedy(TEST_PROMPTS, max_tokens)) # flash_attn_output_by_batches = [] - # for i in range(10): - # prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] - # flash_attn_output_by_batches.append( - # flash_attn_model.generate_greedy(prompts, max_tokens)) + # flash_attn_model = vllm_runner( + # model, + # dtype=dtype, + # # block_size=block_size, + # # flash_style=True, + # max_chunked_prefill_len=max_chunked_prefill_len, + # max_num_prompt_seqs=max_num_prompt_seqs, + # tensor_parallel_size=tensor_parallel_size) + # flash_attn_output_by_batches.extend(flash_attn_model.generate_greedy(TEST_PROMPTS, max_tokens)) + # # for i in range(10): + # # prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] + # # flash_attn_output_by_batches.append( + # # flash_attn_model.generate_greedy(prompts, max_tokens)) - del flash_attn_model + # del flash_attn_model - for e, f in zip(expected_outputs, flash_attn_output_by_batches): - assert e[1] == f[1] + # for e, f in zip(expected_outputs, flash_attn_output_by_batches): + # assert e[1] == f[1] - destroy_model_parallel() - gc.collect() - torch.cuda.empty_cache() + # destroy_model_parallel() + # gc.collect() + # torch.cuda.empty_cache() # for flash_attn_outputs in flash_attn_output_by_batches: # for i in range(len(flash_attn_outputs)): diff --git a/vllm/config.py b/vllm/config.py index 09053cbe0005..ce88ffc85d23 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -512,10 +512,10 @@ def _verify_args(self) -> None: f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " f"({self.max_num_seqs}).") - if self.chunked_prefill_enabled and not self.flash_style: - # SANG-TODO It is probably fixable. - raise ValueError( - "chunked prefill is only supported for flash style") + # if self.chunked_prefill_enabled and not self.flash_style: + # # SANG-TODO It is probably fixable. + # raise ValueError( + # "chunked prefill is only supported for flash style") class DeviceConfig: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b0148c928930..24e6ca82bbd8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -152,7 +152,7 @@ def _prepare_prompt( if seq_group_metadata.is_chunked_prefill: num_chunked_prefill += 1 # TODO(sang): Support it. 
- if prefix is not None: + if computed_block_nums is not None: raise RuntimeError( "chunked prefill cannot be used with prefix caching now." ) @@ -176,11 +176,12 @@ def _prepare_prompt( context_len = computed_len prefix_enabled = True else: - if seq_group_metadata.block_tables is None: - prefix_block_tables.append([]) - else: - prefix_block_tables.append( - seq_group_metadata.block_tables[seq_id]) + prefix_block_tables.append([]) + # if seq_group_metadata.block_tables is None: + # prefix_block_tables.append([]) + # else: + # prefix_block_tables.append( + # seq_group_metadata.block_tables[seq_id]) context_len = prompt_len context_lens.append(context_len) @@ -227,11 +228,11 @@ def _prepare_prompt( "sliding window attention") start_idx = max(0, prompt_len - self.sliding_window) - # If chunked prefill is enabled, prefix_len is always 0. + # If chunked prefill is enabled, computed_len is always 0. # TODO(sang) This is hack. We should clean it up when # supporting prefix cache + chunked prefill. - if prefix_len == 0: - prefix_len = prefill_start + if computed_len == 0: + computed_len = prefill_start for i in range(computed_len, prefill_end): if i < start_idx: From 31aa92052d43ce31aa93faa545db01a639c206c4 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 4 Mar 2024 00:41:48 -0800 Subject: [PATCH 25/88] ip --- tests/chunked_prefill/test_correctness.py | 49 +- tests/kernels/test_flash_attention.py | 779 ++++++++++++++++++++-- tests/models/test_models.py | 28 +- vllm/config.py | 4 +- vllm/core/block_manager.py | 4 +- vllm/engine/llm_engine.py | 4 +- vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/attention.py | 447 +++++++++---- vllm/worker/model_runner.py | 1 + 9 files changed, 1085 insertions(+), 233 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index d6cb0a4dc0ac..efd1d5610f2c 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -28,49 +28,51 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("block_size", [256]) +@pytest.mark.parametrize("block_size", [32]) +@pytest.mark.parametrize("max_num_prompt_seqs", [1]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) def test_models( vllm_runner, model: str, dtype: str, max_tokens: int, block_size: int, + max_num_prompt_seqs: int, + tensor_parallel_size: int, ) -> None: """ verify the flash attention has the same output as page attention """ - print("loading page attention models..") - pg_model = vllm_runner(model, dtype=dtype) - expected_outputs = [] + if torch.cuda.device_count() < tensor_parallel_size: + pytest.skip( + f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" + ) + # print("loading page attention models..") + # pg_model = vllm_runner(model, dtype=dtype) + # expected_outputs = [] - print("generating tokens...") - expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) - print("generating tokens finished") + # print("generating tokens...") + # expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) + # print("generating tokens finished") - del pg_model + # del pg_model - destroy_model_parallel() - gc.collect() - torch.cuda.empty_cache() + # destroy_model_parallel() + # gc.collect() + # torch.cuda.empty_cache() - flash_attn_model = vllm_runner(model, - dtype=dtype) - # flash_style=True, - # block_size=block_size) - flash_attn_output_by_batches = 
[] - # flash_attn_output_by_batches.extend(flash_attn_model.generate_greedy(TEST_PROMPTS, max_tokens)) flash_attn_output_by_batches = [] + flash_attn_model = vllm_runner( + model, + dtype=dtype, + block_size=block_size, + flash_style=True, + tensor_parallel_size=tensor_parallel_size) for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] flash_attn_output_by_batches.append( flash_attn_model.generate_greedy(prompts, max_tokens)) del flash_attn_model - - # for e, f in zip(expected_outputs, flash_attn_output_by_batches): - # # print("expected: ", e[1]) - # # print("flash: ", f[1]) - # assert e[1] == f[1] - destroy_model_parallel() gc.collect() torch.cuda.empty_cache() @@ -80,6 +82,7 @@ def test_models( fa_output_ids, fa_output_str = flash_attn_outputs[i] vllm_output_ids, vllm_output_str = expected_outputs[ i % len(expected_outputs)] + print(vllm_output_str) assert fa_output_ids == vllm_output_ids, ( f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index 2938dff7cac8..52f9ad3ccc9f 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -1,12 +1,505 @@ +# import random +# from typing import Optional, Tuple, List + +# import pytest +# import torch +# import torch.nn.functional as F + +# from vllm.model_executor.layers.attention import ( +# flash_attn_with_kvcache_paged, ) +# from vllm.utils import get_max_shared_memory_bytes + +# FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# # This will change depending on the compute capability. +# # - 512 as a buffer +# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# NUM_BLOCKS = 128 # Arbitrary values for testing +# PARTITION_SIZE = 512 + +# DTYPES = [torch.half, torch.bfloat16] +# NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing +# NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing +# NUM_HEADS = [(1, 40), (40, 40), (64, 8)] # Arbitrary values for testing +# NUM_HEADS_SMALL = NUM_HEADS +# # head size should be bigger than or equal to block size. +# HEAD_SIZES = [256] +# # TODO(sang): https://github.com/Dao-AILab/flash-attention/pull/824 +# # should fix the block size. But right now, the block size should be +# # divisible by 256. +# BLOCK_SIZES = [256] +# USE_ALIBI = [False, True] +# SEEDS = [0] +# PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] + + +# def pad_attention_inputs( +# pad_config: Tuple[int, int], +# block_size: int, +# query: torch.Tensor, +# block_tables: torch.Tensor, +# context_lens: torch.Tensor, +# max_context_len: int, +# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: +# """Pad the attention inputs to the specified batch size and context length. 
+# """ +# pad_batch_size, pad_max_context_len = pad_config +# if pad_batch_size == 0: +# return query, block_tables, context_lens, max_context_len +# target_batch_size = ( +# (query.shape[0] - 1) % pad_batch_size + 1) * pad_batch_size +# target_block_size = pad_max_context_len // block_size + 1 +# padded_query = F.pad(query, +# (0, 0, 0, 0, 0, target_batch_size - query.shape[0])) +# padded_block_table = F.pad(block_tables, +# (0, target_block_size - block_tables.shape[1], +# 0, target_batch_size - block_tables.shape[0])) +# padded_context_lens = F.pad(context_lens, +# (0, target_batch_size - context_lens.shape[0])) +# return padded_query, padded_block_table, padded_context_lens, pad_max_context_len + + +# def ref_masked_attention( +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# scale: float, +# attn_mask: Optional[torch.Tensor] = None, +# ) -> torch.Tensor: +# attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() +# if attn_mask is not None: +# attn_weights = attn_weights + attn_mask.float() +# attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) +# out = torch.einsum("hqk,khd->qhd", attn_weights, value) +# return out + + +# def ref_single_query_cached_kv_attention( +# output: torch.Tensor, +# query: torch.Tensor, +# num_queries_per_kv: int, +# key_cache: torch.Tensor, +# value_cache: torch.Tensor, +# block_tables: torch.Tensor, +# context_lens: torch.Tensor, +# scale: float, +# alibi_slopes: Optional[torch.Tensor], +# ) -> None: +# num_query_heads = query.shape[1] +# head_size = value_cache.shape[-1] +# block_size = value_cache.shape[-3] +# num_seqs = query.shape[0] + +# block_tables = block_tables.cpu().tolist() +# context_lens = context_lens.cpu().tolist() +# for i in range(num_seqs): +# q = query[i].unsqueeze(0) +# block_table = block_tables[i] +# context_len = int(context_lens[i]) + +# keys = [] +# values = [] +# for j in range(context_len): +# block_number = int(block_table[j // block_size]) +# block_offset = j % block_size + +# k = key_cache[block_number, block_offset, :, :] +# keys.append(k) + +# v = value_cache[block_number, :, :, block_offset] +# v = value_cache[block_number, block_offset, :, :] +# values.append(v) +# keys = torch.stack(keys, dim=0) +# values = torch.stack(values, dim=0) +# if num_queries_per_kv > 1: +# # Handle MQA and GQA +# keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) +# values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + +# alibi_bias = None +# if alibi_slopes is not None: +# # Create the ALiBi bias used in the paged attention kernel. 
+# position_ids = torch.arange(context_len, device="cuda").int() +# alibi_bias = (position_ids - context_len + 1).float() +# alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( +# 1, 1, -1) + +# out = ref_masked_attention(q, keys, values, scale, alibi_bias) +# out = out.view(num_query_heads, head_size) +# # output[i].copy_(out, non_blocking=True) +# output[i].copy_(out) + + +# # @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +# # @pytest.mark.parametrize("num_heads", NUM_HEADS) +# # @pytest.mark.parametrize("head_size", HEAD_SIZES) +# # @pytest.mark.parametrize("use_alibi", [False, True]) +# # @pytest.mark.parametrize("block_size", BLOCK_SIZES) +# # @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +# # @pytest.mark.parametrize("seed", SEEDS) +# # @pytest.mark.parametrize("pad_config", PAD_CONFIGS) +# @pytest.mark.parametrize("num_seqs", [3]) +# @pytest.mark.parametrize("num_heads", [(40, 40)]) +# @pytest.mark.parametrize("head_size", [256]) +# @pytest.mark.parametrize("use_alibi", [True]) +# @pytest.mark.parametrize("block_size", [256]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("seed", SEEDS) +# @pytest.mark.parametrize("pad_config", [(0, 0)]) +# @torch.inference_mode() +# def test_flash_paged_attention( +# kv_cache_factory, +# num_seqs: int, +# num_heads: Tuple[int, int], +# head_size: int, +# use_alibi: bool, +# block_size: int, +# dtype: torch.dtype, +# seed: int, +# pad_config: Tuple[int, int], +# ) -> None: +# random.seed(seed) +# torch.random.manual_seed(seed) +# torch.cuda.manual_seed(seed) + +# scale = float(1.0 / (head_size**0.5)) +# num_query_heads, num_kv_heads = num_heads +# query = torch.empty(num_seqs, +# num_query_heads, +# head_size, +# dtype=dtype, +# device="cuda") +# query.uniform_(-scale, scale) + +# # assert num_query_heads % num_kv_heads == 0 +# num_queries_per_kv = num_query_heads // num_kv_heads +# alibi_slopes = None +# if use_alibi: +# alibi_slopes = torch.randn(num_query_heads, +# dtype=torch.float, +# device="cuda") + +# max_seq_len = MAX_SEQ_LEN if not pad_config[0] else (pad_config[1] - 1000) +# context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] +# context_lens[-1] = max_seq_len +# max_context_len = max(context_lens) +# context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + +# # Create the block tables. +# max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size +# block_tables = [] +# for _ in range(num_seqs): +# block_table = [ +# random.randint(0, NUM_BLOCKS - 1) +# for _ in range(max_num_blocks_per_seq) +# ] +# block_tables.append(block_table) +# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + +# # Create the KV caches. +# key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, +# block_size, +# 1, +# num_kv_heads, +# head_size, +# dtype, +# seed, +# flash_style=True) +# key_cache, value_cache = key_caches[0], value_caches[0] + +# # Call the paged attention kernel. +# output = torch.empty_like(query) + +# padded_query, padded_block_table, padded_context_lens, _ = \ +# pad_attention_inputs(pad_config, block_size, query, +# block_tables, context_lens, max_context_len) + +# output = flash_attn_with_kvcache_paged( +# padded_query.view(num_seqs, 1, num_query_heads, head_size), +# key_cache, +# value_cache, +# scale, +# padded_block_table, +# padded_context_lens, +# alibi_slopes, +# ) + +# # Run the reference implementation. 
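The commented-out reference above gathers keys and values one token at a time by walking the block table; the same gather can also be written with advanced indexing, which can be easier to follow. A sketch assuming the flash-style [num_blocks, block_size, num_heads, head_size] layout (the helper name is illustrative):

    import torch

    def gather_context_kv(kv_cache, block_table, context_len):
        # kv_cache: [num_blocks, block_size, num_heads, head_size]
        # block_table: [max_num_blocks_per_seq], block ids for one sequence
        # Returns [context_len, num_heads, head_size] for positions
        # 0..context_len-1 of that sequence.
        block_size = kv_cache.shape[1]
        positions = torch.arange(context_len, device=kv_cache.device)
        physical_blocks = block_table[positions // block_size].long()
        offsets = positions % block_size
        return kv_cache[physical_blocks, offsets]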
+# ref_output = torch.empty_like(query) +# ref_single_query_cached_kv_attention( +# ref_output, +# query, +# num_queries_per_kv, +# key_cache, +# value_cache, +# block_tables, +# context_lens, +# scale, +# alibi_slopes, +# ) + +# assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +# def ref_multi_query_kv_attention( +# cu_seq_lens: List[int], +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# scale: float, +# dtype: torch.dtype, +# ) -> torch.Tensor: +# num_seqs = len(cu_seq_lens) - 1 +# ref_outputs = [] +# for i in range(num_seqs): +# start_idx = cu_seq_lens[i] +# end_idx = cu_seq_lens[i + 1] +# seq_len = end_idx - start_idx + +# # Create attention mask. +# attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), +# diagonal=1) +# attn_mask = attn_mask * torch.finfo(dtype).min +# attn_mask = attn_mask.to(dtype=dtype, device="cuda") + +# ref_output = ref_masked_attention( +# query[start_idx:end_idx], +# key[start_idx:end_idx], +# value[start_idx:end_idx], +# scale, +# attn_mask=attn_mask, +# ) +# ref_outputs.append(ref_output) +# ref_output = torch.cat(ref_outputs, dim=0) +# return ref_output + + +# def ref_multi_query_kv_attention_padded( +# query: torch.Tensor, +# num_queries_per_kv: int, +# key_cache: torch.Tensor, +# value_cache: torch.Tensor, +# block_tables: torch.Tensor, +# cu_seq_lens: List[int], +# context_lens: List[int], +# scale: float, +# dtype: torch.dtype, +# ) -> torch.Tensor: +# num_seqs = len(cu_seq_lens) - 1 +# block_size = value_cache.shape[-3] +# ref_outputs = [] + +# for i in range(num_seqs): +# q_start_idx = cu_seq_lens[i] +# q_end_idx = cu_seq_lens[i + 1] +# seq_len = q_end_idx - q_start_idx + +# context_len = context_lens[i] + +# block_table = block_tables[i] +# keys = [] +# values = [] + +# for j in range(context_len): +# block_number = int(block_table[j // block_size]) +# block_offset = j % block_size + +# k = key_cache[block_number, block_offset, :, :] +# keys.append(k) + +# v = value_cache[block_number, block_offset, :, :] +# values.append(v) + +# keys = torch.stack(keys, dim=0) +# values = torch.stack(values, dim=0) + +# if num_queries_per_kv > 1: +# # Handle MQA and GQA +# keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) +# values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + +# q = query[q_start_idx:q_end_idx, :, :] +# k = keys[:context_len, :, :] +# v = values[:context_len, :, :] + +# assert seq_len <= context_len + +# # pad q if seq_len is less than context_len +# # this is for correct calculation of attention. +# if seq_len < context_len: +# indices = [i % seq_len for i in range(context_len - seq_len)] +# q_left_pad = q[indices, :, :] +# q = torch.cat([q_left_pad, q], dim=0) + +# # Create attention mask. 
+# attn_mask = torch.triu(torch.ones(context_len, +# context_len, +# dtype=dtype), +# diagonal=1) +# attn_mask = attn_mask * torch.finfo(dtype).min +# attn_mask = attn_mask.to(dtype=dtype, device="cuda") +# ref_output = ref_masked_attention( +# q, +# k, +# v, +# scale, +# attn_mask=attn_mask, +# ) +# ref_outputs.append(ref_output[-seq_len:, :, :]) +# breakpoint +# ref_output = torch.cat(ref_outputs, dim=0) +# return ref_output + + +# def is_a100(): +# return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 + + +# if not is_a100(): +# NUM_HEADS_SMALL = [(16, 16), (16, 8)] +# MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) + + +# @pytest.mark.parametrize("num_seqs", [17]) +# @pytest.mark.parametrize("num_heads", [(40, 40)]) +# @pytest.mark.parametrize("head_size", HEAD_SIZES) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("version", ["flash"]) +# @pytest.mark.parametrize("seed", SEEDS) +# @pytest.mark.parametrize("block_size", [256]) +# @torch.inference_mode() +# def test_multi_query_kv_attention( +# num_seqs: int, +# num_heads: Tuple[int, int], +# head_size: int, +# dtype: torch.dtype, +# version: str, +# seed: int, +# block_size: int, +# ) -> None: +# random.seed(seed) +# torch.random.manual_seed(seed) +# torch.cuda.manual_seed(seed) + +# # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. +# # As the xformers library is already tested with its own tests, we can use +# # a smaller MAX_SEQ_LEN here. +# max_len = min(MAX_SEQ_LEN, 4096) + +# seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] +# max_seq_len = max(seq_lens) +# context_lens = seq_lens +# max_context_len = max(context_lens) + +# num_tokens = sum(seq_lens) +# cu_seq_lens = [0] +# for seq_len in seq_lens: +# cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + +# cu_context_lens = [0] +# for context_len in context_lens: +# cu_context_lens.append(cu_context_lens[-1] + context_len) + +# print(f"cu_seq_lens={cu_seq_lens}, cu_context_lens={cu_context_lens}") + +# scale = float(1.0 / (head_size**0.5)) +# num_query_heads, num_kv_heads = num_heads +# num_queries_per_kv = num_query_heads // num_kv_heads + +# value_cache = torch.empty(NUM_BLOCKS, +# block_size, +# num_kv_heads, +# head_size, +# dtype=dtype, +# device="cuda") +# key_cache = torch.empty(NUM_BLOCKS, +# block_size, +# num_kv_heads, +# head_size, +# dtype=dtype, +# device="cuda") +# query = torch.empty(max_seq_len * num_seqs, +# num_query_heads, +# head_size, +# dtype=dtype, +# device="cuda") +# value_cache.uniform_(-scale, scale) +# key_cache.uniform_(-scale, scale) +# query.uniform_(-scale, scale) + +# # Create the block tables. 
+# max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size +# block_tables = [] +# for _ in range(num_seqs): +# block_table = [ +# random.randint(0, NUM_BLOCKS - 1) +# for _ in range(max_num_blocks_per_seq) +# ] +# block_tables.append(block_table) +# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + +# output = torch.empty_like(query) + +# if version == "flash": +# # flash_multi_query_cached_kv_attention_varlen( +# # output, +# # query, +# # key_cache, +# # value_cache, +# # scale, +# # block_tables, +# # torch.cuda.IntTensor(cu_seq_lens), +# # torch.cuda.IntTensor(cu_context_lens), +# # block_size, +# # max_seq_len, +# # max_context_len, +# # None, +# # ) +# from flash_attn import flash_attn_func +# breakpoint() +# # output = flash_attn_func( +# # query.unsqueeze(0), +# # k=key, +# # v=value, +# # softmax_scale=scale, +# # causal=True, +# # alibi_slopes=alibi_slopes, +# # ) +# output = flash_attn_with_kvcache_paged( +# query.view(num_seqs, max_seq_len, num_query_heads, head_size), +# key_cache, +# value_cache, +# scale, +# block_tables, +# torch.tensor(context_lens, dtype=torch.int, device="cuda"), +# None, +# ) +# else: +# assert False, f"{version=} is not supported" + +# ref_output = ref_multi_query_kv_attention_padded( +# query, +# num_queries_per_kv, +# key_cache, +# value_cache, +# block_tables, +# cu_seq_lens, +# context_lens, +# scale, +# dtype, +# ) +# assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + import random -from typing import Optional, Tuple, List +from typing import List, Optional, Tuple import pytest import torch import torch.nn.functional as F from vllm.model_executor.layers.attention import ( - flash_attn_with_kvcache_paged, ) + flash_single_query_cached_kv_attention, + flash_multi_query_cached_kv_attention_varlen, +) from vllm.utils import get_max_shared_memory_bytes FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 @@ -19,14 +512,10 @@ DTYPES = [torch.half, torch.bfloat16] NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing -NUM_HEADS = [(1, 40), (40, 40), (64, 8)] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS_SMALL = NUM_HEADS -# head size should be bigger than or equal to block size. -HEAD_SIZES = [256] -# TODO(sang): https://github.com/Dao-AILab/flash-attention/pull/824 -# should fix the block size. But right now, the block size should be -# divisible by 256. 
-BLOCK_SIZES = [256] +HEAD_SIZES = [32, 64, 128, 256] +BLOCK_SIZES = [32, 64, 256] USE_ALIBI = [False, True] SEEDS = [0] PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] @@ -83,8 +572,10 @@ def ref_single_query_cached_kv_attention( context_lens: torch.Tensor, scale: float, alibi_slopes: Optional[torch.Tensor], + flash_style: bool = False, ) -> None: num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[-2] head_size = value_cache.shape[-1] block_size = value_cache.shape[-3] num_seqs = query.shape[0] @@ -102,11 +593,17 @@ def ref_single_query_cached_kv_attention( block_number = int(block_table[j // block_size]) block_offset = j % block_size - k = key_cache[block_number, block_offset, :, :] + if flash_style: + k = key_cache[block_number, block_offset, :, :] + else: + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) keys.append(k) - v = value_cache[block_number, :, :, block_offset] - v = value_cache[block_number, block_offset, :, :] + if flash_style: + v = value_cache[block_number, block_offset, :, :] + else: + v = value_cache[block_number, :, :, block_offset] values.append(v) keys = torch.stack(keys, dim=0) values = torch.stack(values, dim=0) @@ -125,24 +622,15 @@ def ref_single_query_cached_kv_attention( out = ref_masked_attention(q, keys, values, scale, alibi_bias) out = out.view(num_query_heads, head_size) - # output[i].copy_(out, non_blocking=True) - output[i].copy_(out) + output[i].copy_(out, non_blocking=True) -# @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -# @pytest.mark.parametrize("num_heads", NUM_HEADS) -# @pytest.mark.parametrize("head_size", HEAD_SIZES) -# @pytest.mark.parametrize("use_alibi", [False, True]) -# @pytest.mark.parametrize("block_size", BLOCK_SIZES) -# @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) -# @pytest.mark.parametrize("seed", SEEDS) -# @pytest.mark.parametrize("pad_config", PAD_CONFIGS) -@pytest.mark.parametrize("num_seqs", [3]) -@pytest.mark.parametrize("num_heads", [(40, 40)]) -@pytest.mark.parametrize("head_size", [256]) -@pytest.mark.parametrize("use_alibi", [True]) -@pytest.mark.parametrize("block_size", [256]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", [False]) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("pad_config", [(0, 0)]) @torch.inference_mode() @@ -170,8 +658,11 @@ def test_flash_paged_attention( device="cuda") query.uniform_(-scale, scale) - # assert num_query_heads % num_kv_heads == 0 + assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), + num_queries_per_kv) alibi_slopes = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, @@ -207,14 +698,16 @@ def test_flash_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. 
+ num_valid_tokens = torch.cuda.IntTensor([num_seqs]) output = torch.empty_like(query) - padded_query, padded_block_table, padded_context_lens, _ = \ + padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ pad_attention_inputs(pad_config, block_size, query, block_tables, context_lens, max_context_len) - output = flash_attn_with_kvcache_paged( - padded_query.view(num_seqs, 1, num_query_heads, head_size), + flash_single_query_cached_kv_attention( + output, + padded_query, key_cache, value_cache, scale, @@ -235,8 +728,12 @@ def test_flash_paged_attention( context_lens, scale, alibi_slopes, + flash_style=True, ) + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) @@ -337,6 +834,7 @@ def ref_multi_query_kv_attention_padded( diagonal=1) attn_mask = attn_mask * torch.finfo(dtype).min attn_mask = attn_mask.to(dtype=dtype, device="cuda") + ref_output = ref_masked_attention( q, k, @@ -345,7 +843,6 @@ def ref_multi_query_kv_attention_padded( attn_mask=attn_mask, ) ref_outputs.append(ref_output[-seq_len:, :, :]) - breakpoint ref_output = torch.cat(ref_outputs, dim=0) return ref_output @@ -359,13 +856,14 @@ def is_a100(): MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) -@pytest.mark.parametrize("num_seqs", [17]) -@pytest.mark.parametrize("num_heads", [(40, 40)]) +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS_SMALL) @pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("version", ["flash"]) +@pytest.mark.parametrize("chunked_prefill", [False, True]) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("block_size", [256]) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, @@ -374,6 +872,7 @@ def test_multi_query_kv_attention( dtype: torch.dtype, version: str, seed: int, + chunked_prefill: bool, block_size: int, ) -> None: random.seed(seed) @@ -387,8 +886,18 @@ def test_multi_query_kv_attention( seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] max_seq_len = max(seq_lens) - context_lens = seq_lens + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device="cuda") + + if chunked_prefill: + # context length will be different from seq_len if chunked_prefill is + # true. 
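In the varlen test below, sequences are packed into one flat tensor, so they are described by cumulative length arrays: sequence i owns query rows [cu_seq_lens[i], cu_seq_lens[i+1]) and attends to context_lens[i] cached tokens, which exceeds its query length when chunked prefill is simulated. A small sketch of how such arrays are derived:

    import itertools

    def build_cu_lens(seq_lens, context_lens):
        # cu_seq_lens[i] is the offset of sequence i's first query row in the
        # packed [sum(seq_lens), num_heads, head_size] tensor; cu_context_lens
        # plays the same role for the attended context.
        cu_seq_lens = [0] + list(itertools.accumulate(seq_lens))
        cu_context_lens = [0] + list(itertools.accumulate(context_lens))
        return cu_seq_lens, cu_context_lens

    # e.g. seq_lens=[3, 5] with context_lens=[7, 5] (the first sequence is
    # mid-way through a chunked prefill) gives cu_seq_lens=[0, 3, 8] and
    # cu_context_lens=[0, 7, 12].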
+ context_lens = random.sample(range(max_seq_len, max_len), num_seqs) + else: + context_lens = seq_lens max_context_len = max(context_lens) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device="cuda") num_tokens = sum(seq_lens) cu_seq_lens = [0] @@ -417,7 +926,7 @@ def test_multi_query_kv_attention( head_size, dtype=dtype, device="cuda") - query = torch.empty(max_seq_len * num_seqs, + query = torch.empty(num_tokens, num_query_heads, head_size, dtype=dtype, @@ -440,37 +949,18 @@ def test_multi_query_kv_attention( output = torch.empty_like(query) if version == "flash": - # flash_multi_query_cached_kv_attention_varlen( - # output, - # query, - # key_cache, - # value_cache, - # scale, - # block_tables, - # torch.cuda.IntTensor(cu_seq_lens), - # torch.cuda.IntTensor(cu_context_lens), - # block_size, - # max_seq_len, - # max_context_len, - # None, - # ) - from flash_attn import flash_attn_func - breakpoint() - # output = flash_attn_func( - # query.unsqueeze(0), - # k=key, - # v=value, - # softmax_scale=scale, - # causal=True, - # alibi_slopes=alibi_slopes, - # ) - output = flash_attn_with_kvcache_paged( - query.view(num_seqs, max_seq_len, num_query_heads, head_size), + flash_multi_query_cached_kv_attention_varlen( + output, + query, key_cache, value_cache, scale, block_tables, - torch.tensor(context_lens, dtype=torch.int, device="cuda"), + torch.cuda.IntTensor(cu_seq_lens), + torch.cuda.IntTensor(cu_context_lens), + block_size, + max_seq_len, + max_context_len, None, ) else: @@ -488,3 +978,160 @@ def test_multi_query_kv_attention( dtype, ) assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +@pytest.mark.parametrize("num_heads", [40]) +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("block_size", [32]) +@pytest.mark.parametrize("num_blocks", [128]) +@pytest.mark.parametrize("dtype", [torch.half]) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_e2e( + kv_cache_factory, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, +) -> None: + from vllm._C import cache_ops + batch_size = 2 + seqlen = 29 + num_tokens = batch_size * seqlen + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + block_tables = [] + for _ in range(batch_size): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + # Create a random slot mapping. + slot_mapping = [] + for i in range(0, 29): + block_number = int(block_tables[i // block_size]) + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + # for _ in range(23, 29): + # slot_mapping.append(-1) + for i in range(0, 29): + block_number = int(block_tables[1]) + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + # slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device='cuda') + query, key, value = qkv.unbind(dim=1) + # query[23:29] = 0 + # key[23:29] = 0 + # value[23:29] = 0 + # Create the KV caches. 
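The slot mapping assembled above flattens each token's (block, offset) pair into a single index, so writing into the flash-style cache is a plain scatter. A simplified reference of that write, mirroring the cloned-cache check this test performs (not the CUDA kernel itself):

    import torch

    def reference_cache_write(key, value, key_cache, value_cache, slot_mapping):
        # key/value: [num_tokens, num_heads, head_size]
        # key_cache/value_cache: [num_blocks, block_size, num_heads, head_size]
        # slot_mapping: [num_tokens]; slot = block_idx * block_size + offset,
        # with -1 marking padding tokens that must not be written.
        block_size = key_cache.shape[1]
        for i, slot in enumerate(slot_mapping.tolist()):
            if slot < 0:
                continue
            block_idx, offset = divmod(slot, block_size)
            key_cache[block_idx, offset] = key[i]
            value_cache[block_idx, offset] = value[i]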
+ + key_caches, value_caches = kv_cache_factory(num_blocks, + block_size, + 1, + num_heads, + head_size, + dtype, + seed, + flash_style=True) + assert len(key_caches) == 1 and len(value_caches) == 1 + key_cache, value_cache = key_caches[0], value_caches[0] + # Call the reshape_and_cache kernel. + cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping) + + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Run the reference implementation. + block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') + block_indicies = block_indicies.cpu().tolist() + block_offsets = slot_mapping % block_size + block_offsets = block_offsets.cpu().tolist() + for i in range(num_tokens): + block_idx = block_indicies[i] + block_offset = block_offsets[i] + print("block_idx", block_idx) + print("block_offset", block_offset) + if block_idx != -1: + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) + + for i in range(58): + print(i) + block_idx = block_indicies[i] + block_offset = block_offsets[i] + torch.allclose(key[i], key_cache[block_idx][block_offset]) + + from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + from xformers import ops as xops + seqlen = query.shape[1] + attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seqlen] * batch_size) + scale = float(1.0 / (head_size**0.5)) + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=None, + p=0.0, + scale=scale, + ) + + num_tokens, num_heads, head_size = query.shape + from flash_attn import flash_attn_func + output1 = flash_single_query_cached_kv_attention( + None, + query.view(batch_size, num_tokens // batch_size, num_heads, head_size), + key_cache, + value_cache, + scale, + block_tables, + torch.tensor([23, 29], dtype=torch.int, device='cuda'), + alibi_slopes=None, + ) + output2 = flash_attn_func( + # query.view(batch_size, num_tokens // batch_size, num_heads, head_size), + # key.view(batch_size, num_tokens // batch_size, num_heads, head_size), + # value.view(batch_size, num_tokens // batch_size, num_heads, head_size), + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + # key_cache, + # value_cache, + softmax_scale=scale, + # block_tables, + # torch.tensor([23, 29], dtype=torch.int, device='cuda'), + # alibi_slopes=None, + causal=True, + ) + output3 = flash_attn_func( + query.view(batch_size, num_tokens // batch_size, num_heads, head_size), + key.view(batch_size, num_tokens // batch_size, num_heads, head_size), + value.view(batch_size, num_tokens // batch_size, num_heads, head_size), + # key_cache, + # value_cache, + softmax_scale=scale, + # block_tables, + # torch.tensor([23, 29], dtype=torch.int, device='cuda'), + # alibi_slopes=None, + causal=True, + ) + + breakpoint() diff --git a/tests/models/test_models.py b/tests/models/test_models.py index fb567e837d28..b8ee03759284 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,20 +6,20 @@ MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", - "mistralai/Mistral-7B-v0.1", - "Deci/DeciLM-7b", - "tiiuae/falcon-7b", - "gpt2", - "bigcode/tiny_starcoder_py", - "EleutherAI/gpt-j-6b", - "EleutherAI/pythia-70m", - "bigscience/bloom-560m", - "mosaicml/mpt-7b", - "microsoft/phi-2", - "stabilityai/stablelm-3b-4e1t", - 
"allenai/OLMo-1B", - "bigcode/starcoder2-3b", + # "meta-llama/Llama-2-7b-hf", + # "mistralai/Mistral-7B-v0.1", + # "Deci/DeciLM-7b", + # "tiiuae/falcon-7b", + # "gpt2", + # "bigcode/tiny_starcoder_py", + # "EleutherAI/gpt-j-6b", + # "EleutherAI/pythia-70m", + # "bigscience/bloom-560m", + # "mosaicml/mpt-7b", + # "microsoft/phi-2", + # "stabilityai/stablelm-3b-4e1t", + # "allenai/OLMo-1B", + # "bigcode/starcoder2-3b", ] diff --git a/vllm/config.py b/vllm/config.py index 859a4ee734ad..bbafd4c06de3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -335,11 +335,11 @@ def _verify_args(self) -> None: if self.flash_style: logger.info("Flash attention enabled.") - if self.block_size < 256: + if self.block_size > 32: # Flash style attention only supports block size >=256 for now. # https://github.com/Dao-AILab/flash-attention/pull/824 will fix it. raise ValueError( - "Flash style attention only supports block size >= 256. Got" + "Flash style attention only supports block size <= 32. Got" f"{self.block_size }") def _verify_cache_dtype(self) -> None: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index b5ae90e217ba..465704f9a1c3 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -200,8 +200,8 @@ def allocate(self, seq_group: SequenceGroup) -> None: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) - print("allcate a new block for seq_group", seq_group) - print("block", block) + # print("allcate a new block for seq_group", seq_group) + # print("block", block) block_table.append(block) # Assign the block table for each sequence. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3f1b5eca5d05..f59a5605dcd2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -823,8 +823,8 @@ def step(self) -> List[RequestOutput]: >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - # print("SANG-TODO step seq_group_metadata_list length: ", - # len(seq_group_metadata_list)) + print("SANG-TODO step seq_group_metadata_list length: ", + len(seq_group_metadata_list)) if not scheduler_outputs.is_empty(): # Execute the model. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 93adc98adda5..e1932bb761ef 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -146,7 +146,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. 
""" - # print("SANG-TODO generate: ", prompts, prompt_token_ids) + print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index db5e5f44ad42..1a5fd49cf025 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -14,11 +14,18 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( context_attention_fwd) from vllm.utils import is_hip +from flash_attn import flash_attn_func +# try: +# from flash_attn import flash_attn_with_kvcache +# except ImportError: +# flash_attn_with_kvcache = None try: - from flash_attn import flash_attn_with_kvcache -except ImportError: - flash_attn_with_kvcache = None + from flash_attn import (flash_attn_with_page_attention, + flash_attn_varlen_with_page_attention) +except Exception as e: + flash_attn_with_page_attention = e + flash_attn_varlen_with_page_attention = e _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. @@ -152,120 +159,145 @@ def forward( if input_metadata.is_prompt: # normal attention if (key_cache is None or value_cache is None - or input_metadata.block_tables.numel() == 0): - # or not input_metadata.prefix_enabled): - # if key_cache is not None and value_cache is not None and input_metadata.flash_style: - if False: - print("SANG-TODO flash attention.") - output = flash_attn_with_kvcache_paged( - query.view(batch_size, seq_len, self.num_heads, - self.head_size), + # or input_metadata.block_tables.numel() == 0): + or not input_metadata.prefix_enabled): + # print("SANG-TODO flash attn is used.") + # print( + # "SANG-TODO query size: ", + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size).size()) + if self.num_kv_heads != self.num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # TODO(woosuk): Use MQA/GQA kernels for higher performance. + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], + self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * batch_size) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = attn_bias + else: + input_metadata.attn_bias = _make_alibi_bias( + self.alibi_slopes, self.num_kv_heads, batch_size, + seq_len, query.dtype) + + if self.use_ref_attention: + output = self.ref_masked_attention( + query, + key, + value, + ) + # Using view got RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use reshape instead + return output.reshape(batch_size, seq_len, hidden_size) + + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. 
+ if self.alibi_slopes is None: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + else: + query = query.unflatten(0, (batch_size, seq_len)) + key = key.unflatten(0, (batch_size, seq_len)) + value = value.unflatten(0, (batch_size, seq_len)) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + output = out.view_as(query) + if key_cache is not None and value_cache is not None: + output2 = flash_single_query_cached_kv_attention( + None, + query.view(batch_size, seq_len, self.num_heads, self.head_size), key_cache, value_cache, self.scale, input_metadata.block_tables, input_metadata.context_lens, - self.alibi_slopes, + alibi_slopes=self.alibi_slopes, ) - else: - # print("SANG-TODO flash attn is used.") - # print( - # "SANG-TODO query size: ", - # query.view(batch_size, seq_len, self.num_heads, - # self.head_size).size()) - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - # Set attention bias if not provided. This typically happens at - # the very attention layer of every iteration. - # FIXME(woosuk): This is a hack. - if input_metadata.attn_bias is None: - if self.alibi_slopes is None: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size) - if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention( - self.sliding_window) - input_metadata.attn_bias = attn_bias - else: - input_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, batch_size, - seq_len, query.dtype) - - if self.use_ref_attention: - output = self.ref_masked_attention( - query, - key, - value, - ) - # Using view got RuntimeError: view size is not compatible with input tensor's size and stride - # (at least one dimension spans across two contiguous subspaces). Use reshape instead - return output.reshape(batch_size, seq_len, hidden_size) - - # TODO(woosuk): Too many view operations. Let's try to reduce - # them in the future for code readability. 
- if self.alibi_slopes is None: - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - else: - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) - out = xops.memory_efficient_attention_forward( + output3 = flash_attn_func( + query.view(batch_size, seq_len, self.num_heads, self.head_size), + key.view(batch_size, seq_len, self.num_heads, self.head_size), + value.view(batch_size, seq_len, self.num_heads, self.head_size), + softmax_scale=self.scale, + causal=True, + ).view_as(query) + output4 = flash_attn_func( query, key, value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ) - output = out.view_as(query) - # if key_cache is not None and value_cache is not None: - # breakpoint() - # output2 = flash_attn_with_kvcache_paged( - # query.view(batch_size, seq_len, self.num_heads, - # self.head_size), - # key_cache, - # value_cache, - # self.scale, - # input_metadata.block_tables, - # input_metadata.context_lens, - # self.alibi_slopes, - # ) - # breakpoint() + softmax_scale=self.scale, + causal=True, + ).view_as(query) + breakpoint() + # if key_cache is not None and value_cache is not None: + # breakpoint() + # output2 = flash_attn_with_kvcache_paged( + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size), + # key_cache, + # value_cache, + # self.scale, + # input_metadata.block_tables, + # input_metadata.context_lens, + # self.alibi_slopes, + # ) + # breakpoint() else: if input_metadata.flash_style: - assert False - output = flash_attn_with_kvcache_paged( - query.view(batch_size, seq_len, self.num_heads, - self.head_size), + print("SANG-TODO flash prefill") + output = flash_single_query_cached_kv_attention( + None, + query.view(batch_size, seq_len, self.num_heads, self.head_size), key_cache, value_cache, self.scale, input_metadata.block_tables, input_metadata.context_lens, - self.alibi_slopes, + alibi_slopes=self.alibi_slopes, ) + # output = flash_multi_query_cached_kv_attention_varlen( + # None, + # query, + # key_cache, + # value_cache, + # self.scale, + # self.block_tables, + # input_metadata.start_loc, + # input_metadata.start_loc, + # max_query_len, + # max_context_len, + # alibi_slopes=self.alibi_slopes, + # ) else: - # print("SANG-TODO context attention") + print("SANG-TODO context attention") # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -287,11 +319,21 @@ def forward( # Decoding run. 
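For the decoding run handled below, each sequence contributes exactly one query token that attends to its context_lens[i] cached positions. Conceptually, per sequence and after gathering keys/values through the block table (see the gather sketch earlier), the computation reduces to the following (a sketch of the math, not the fused kernel):

    import torch

    def decode_attention_reference(q, keys, values, scale):
        # q: [num_heads, head_size] - the new token's query.
        # keys/values: [context_len, num_heads, head_size] - gathered from the
        # paged cache for this sequence.
        scores = scale * torch.einsum("hd,khd->hk", q.float(), keys.float())
        probs = torch.softmax(scores, dim=-1).to(values.dtype)
        return torch.einsum("hk,khd->hd", probs, values)  # [num_heads, head_size]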
# breakpoint() if input_metadata.flash_style: - output = flash_attn_with_kvcache_paged( - query.view(batch_size, seq_len, self.num_heads, - self.head_size), key_cache, value_cache, - self.scale, input_metadata.block_tables, - input_metadata.context_lens, self.alibi_slopes) + # output = flash_attn_with_kvcache_paged( + # query.view(batch_size, seq_len, self.num_heads, + # self.head_size), key_cache, value_cache, + # self.scale, input_metadata.block_tables, + # input_metadata.context_lens, self.alibi_slopes) + output = flash_single_query_cached_kv_attention( + None, + query.view(batch_size, 1, self.num_heads, self.head_size), + key_cache, + value_cache, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + alibi_slopes=self.alibi_slopes, + ) else: output = _paged_attention( query, @@ -417,46 +459,205 @@ def _paged_attention( return output -def flash_attn_with_kvcache_paged( +# def flash_attn_with_kvcache_paged( +# query: torch.Tensor, +# key_cache: torch.Tensor, +# value_cache: torch.Tensor, +# scale: float, +# block_tables: torch.Tensor, +# context_lens: torch.Tensor, +# alibi_slopes: Optional[torch.Tensor] = None, +# ) -> torch.Tensor: +# """Similar to vLLM's page attention, calculates a single token's attention +# based on key/value caches. The main difference is this uses flash attention +# style key-value caches. + +# Arguments: +# See https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py +# for other arguments. + +# Returns: +# output: [num_tokens, num_heads, head_size] +# """ +# block_size = value_cache.shape[1] +# assert block_size % 256 == 0, ("only support block_size divisible by 256. " +# f"Current block size: {block_size}") +# _, _, num_heads, head_size = query.shape +# out = flash_attn_with_kvcache( +# query, +# key_cache, +# value_cache, +# # Inplace update is slow. We don't use it. +# # We assume kvcache is already updated before +# # calling this API. +# None, # key +# None, # value +# cache_seqlens=context_lens, +# block_table=block_tables, +# softmax_scale=scale, +# causal=True, +# alibi_slopes=alibi_slopes, +# num_splits=0, +# ) + +# # num_tokens == batch_size * seqlen +# return out.view(-1, num_heads, head_size) + + +def flash_single_query_cached_kv_attention( + output: Optional[torch.Tensor], query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, scale: float, block_tables: torch.Tensor, context_lens: torch.Tensor, - alibi_slopes: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor], + actual_batch_size: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Similar to vLLM's page attention, calculates a single token's attention + """Similar to vLLM's page attention, caclulates a single token's attention based on key/value caches. The main difference is this uses flash attention - style key-value caches. + sytle key-value caches. Arguments: - See https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py - for other arguments. + output: [num_padded_tokens, num_heads, head_size], output tensor + to write. if None an new output tensor will be created. + query: [batch_size, num_padded_tokens, num_heads, head_size], query tensor. + key_cache: [num_blocks, block_size, num_heads, head_size], key cache. + value_cache: [num_blocks, block_size, num_heads, head_size], value + cache. + scale: attention scale. + block_tables: [num_padded_tokens, max_context_len / block_size], + block tables. 
+ context_lens: [num_padded_tokens], context lengths. + block_size: block size. + alibi_slopes: unused. + actual_batch_size: [1] actual batch size. Returns: - output: [num_tokens, num_heads, head_size] + output: [num_padded_tokens, num_heads, head_size] """ block_size = value_cache.shape[1] - assert block_size % 256 == 0, ("only support block_size divisible by 256. " - f"Current block size: {block_size}") - _, _, num_heads, head_size = query.shape - out = flash_attn_with_kvcache( + assert block_size >= 32, "only support block_size >= 32 for flash attention" + # TODO: support alibi_slopes + assert alibi_slopes is None, "doesn't support alibi_slopes" + batch_size, seqlen, num_heads, head_size = query.shape + num_tokens = batch_size * seqlen + out = flash_attn_with_page_attention( + query, + key_cache, + value_cache, + block_tables, + None, # key + None, # value + None, # cos + None, # sin + context_lens, + None, # cache_batch_idx + softmax_scale=scale, + causal=True, + window_size=(-1, -1), + rotary_interleaved=False, + num_splits=0, + actual_batch_size=actual_batch_size, + ) + return out + # if output is not None: + # # in case that output is padded, only copy the valid part. + # output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) + # return out.view(num_tokens, num_heads, head_size) + + +def flash_multi_query_cached_kv_attention_varlen( + output: Optional[torch.Tensor], + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + scale: float, + block_tables: torch.Tensor, + cum_seqlens_q: torch.Tensor, + cum_context_len: torch.Tensor, + max_query_len: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + actual_batch_size: Optional[torch.Tensor] = None, +): + """Efficient multi-query paged attention based on flash attention. + It calculates attentions of list of sequences packed in a single batch, + indexed by cum_seqlens_q where the seq_i's index is + [cum_seqlens_q[i], cum_seqlensq[i+1]]. + Similarlly, the length of context is stored in cum_seqlens_k with similar + fashions. + It also supports calculating attention incrementally, where context length + is longer than sequence length. + + Arguments: + output: [num_padded_tokens, num_heads, head_size], output tensor to + write to. if None an new output tensor will be created. + + prefill -> always 1 batch + query, head_size, head_dim + + varlen -> provides cumulative lengths of queries + + decoding -> always 1 batch + 1 * number_of_batch, head_size, head_dim + + query: [num_padded_tokens, num_heads, head_size], query tensor. + key_cache: [num_blocks, block_size, num_heads, head_size], key cache. + value_cache: [num_blocks, block_size, num_heads, head_size], + value cache. + scale: attention scale. + block_tables: [num_padded_tokens, max_context_len / block_size], + block tables. + - these two are the same if no chunked prefill (for prefill) + - If you do chunked prefill, it may be different + - when? see attention mask setting + - Each iteration, it can have smaller # of queries than (because it is chunked) + tokens we should attend to (context len). + - cum_seqlens_q: actual query length (actual prompt length). + - cum_context_len: context len that needs to be attended (subquery). + cum_seqlens_q: [padded_batch_size + 1], cumulative sequence lengths + of query. + cum_context_len: [padded_batch_size + 1], cumulative lengths + of context. + block_size: block size. + max_query_len: max query length. + max_context_len: max context length. + alibi_slopes: unused. 
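In plain PyTorch, what the single-query helper above computes is a gather over the paged cache followed by ordinary scaled dot-product attention. A slow reference sketch, assuming the `[num_blocks, block_size, num_heads, head_size]` layout described above and equal query/KV head counts (the names are illustrative, not part of the patch):

```python
import torch

def ref_single_query_paged_attention(query, key_cache, value_cache,
                                     block_tables, context_lens, scale):
    """query: [batch_size, num_heads, head_size], one decode token per sequence."""
    block_size = key_cache.shape[1]
    outputs = []
    for i in range(query.shape[0]):
        ctx_len = int(context_lens[i])
        num_blocks = (ctx_len + block_size - 1) // block_size
        blocks = block_tables[i][:num_blocks].long()
        # Gather the sequence's blocks, flatten them, keep only valid positions.
        k = key_cache[blocks].reshape(-1, *key_cache.shape[2:])[:ctx_len]
        v = value_cache[blocks].reshape(-1, *value_cache.shape[2:])[:ctx_len]
        attn = torch.einsum("hd,khd->hk", query[i].float(), k.float()) * scale
        attn = torch.softmax(attn, dim=-1)
        outputs.append(torch.einsum("hk,khd->hd", attn, v.float()))
    return torch.stack(outputs).to(query.dtype)
```

The real kernel fuses the gather with the attention computation; the sketch only pins down the indexing semantics.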
+ actual_batch_size: [1] actual batch size. + + Returns: + output: [num_padded_tokens, num_heads, head_size] + """ + block_size = value_cache.shape[1] + assert block_size >= 32, "only support block_size >= 32 for flash attention" + # TODO: support alibi_slopes + assert alibi_slopes is None, "doesn't support alibi_slopes" + + num_tokens, _, _ = query.shape + out = flash_attn_varlen_with_page_attention( query, key_cache, value_cache, - # Inplace update is slow. We don't use it. - # We assume kvcache is already updated before - # calling this API. + block_tables, + cum_seqlens_q, + cum_context_len, + max_query_len, + max_context_len, None, # key None, # value - cache_seqlens=context_lens, - block_table=block_tables, + None, # cos_cache + None, # sin_cache + None, # cache_batch_idx softmax_scale=scale, causal=True, - alibi_slopes=alibi_slopes, + window_size=(-1, -1), + rotary_interleaved=False, num_splits=0, + actual_batch_size=actual_batch_size, ) - # num_tokens == batch_size * seqlen - return out.view(-1, num_heads, head_size) + if output is not None: + output[:num_tokens].copy_(out) + return out \ No newline at end of file diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 84fc0a0e5528..5cddeed35538 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -224,6 +224,7 @@ def _prepare_prompt( slot_mapping[-1].append(slot) max_prompt_len = max(subquery_lens) + breakpoint() input_tokens = _make_tensor_with_pad(input_tokens, max_prompt_len, pad=0, From ef679d74a2296f941571c0c30dfaa0fd338783de Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 4 Mar 2024 05:09:04 -0800 Subject: [PATCH 26/88] . --- tests/kernels/test_flash_attention.py | 297 ++++++++++++------------ vllm/model_executor/layers/attention.py | 54 ++--- vllm/worker/model_runner.py | 1 - 3 files changed, 173 insertions(+), 179 deletions(-) diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index 52f9ad3ccc9f..eb999504690a 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -980,158 +980,157 @@ def test_multi_query_kv_attention( assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) -@pytest.mark.parametrize("num_heads", [40]) -@pytest.mark.parametrize("head_size", [64]) -@pytest.mark.parametrize("block_size", [32]) -@pytest.mark.parametrize("num_blocks", [128]) -@pytest.mark.parametrize("dtype", [torch.half]) -@pytest.mark.parametrize("seed", SEEDS) -@torch.inference_mode() -def test_e2e( - kv_cache_factory, - num_heads: int, - head_size: int, - block_size: int, - num_blocks: int, - dtype: torch.dtype, - seed: int, -) -> None: - from vllm._C import cache_ops - batch_size = 2 - seqlen = 29 - num_tokens = batch_size * seqlen - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) +# @pytest.mark.parametrize("num_heads", [40]) +# @pytest.mark.parametrize("head_size", [64]) +# @pytest.mark.parametrize("block_size", [32]) +# @pytest.mark.parametrize("num_blocks", [128]) +# @pytest.mark.parametrize("dtype", [torch.half]) +# @pytest.mark.parametrize("seed", SEEDS) +# @torch.inference_mode() +# def test_e2e( +# kv_cache_factory, +# num_heads: int, +# head_size: int, +# block_size: int, +# num_blocks: int, +# dtype: torch.dtype, +# seed: int, +# ) -> None: +# from vllm._C import cache_ops +# batch_size = 2 +# seqlen = 29 +# num_tokens = batch_size * seqlen +# random.seed(seed) +# torch.random.manual_seed(seed) +# torch.cuda.manual_seed(seed) - block_tables = [] - for _ in 
range(batch_size): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - # Create a random slot mapping. - slot_mapping = [] - for i in range(0, 29): - block_number = int(block_tables[i // block_size]) - block_offset = i % block_size - slot = block_number * block_size + block_offset - slot_mapping.append(slot) - # for _ in range(23, 29): - # slot_mapping.append(-1) - for i in range(0, 29): - block_number = int(block_tables[1]) - block_offset = i % block_size - slot = block_number * block_size + block_offset - slot_mapping.append(slot) - # slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') - - qkv = torch.randn(num_tokens, - 3, - num_heads, - head_size, - dtype=dtype, - device='cuda') - query, key, value = qkv.unbind(dim=1) - # query[23:29] = 0 - # key[23:29] = 0 - # value[23:29] = 0 - # Create the KV caches. +# block_tables = [] +# for _ in range(batch_size): +# block_table = [ +# random.randint(0, NUM_BLOCKS - 1) +# ] +# block_tables.append(block_table) +# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") +# # Create a random slot mapping. +# slot_mapping = [] +# for i in range(0, 29): +# block_number = int(block_tables[i // block_size]) +# block_offset = i % block_size +# slot = block_number * block_size + block_offset +# slot_mapping.append(slot) +# # for _ in range(23, 29): +# # slot_mapping.append(-1) +# for i in range(0, 29): +# block_number = int(block_tables[1]) +# block_offset = i % block_size +# slot = block_number * block_size + block_offset +# slot_mapping.append(slot) +# # slot_mapping = random.sample(range(num_slots), num_tokens) +# slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + +# qkv = torch.randn(num_tokens, +# 3, +# num_heads, +# head_size, +# dtype=dtype, +# device='cuda') +# query, key, value = qkv.unbind(dim=1) +# # query[23:29] = 0 +# # key[23:29] = 0 +# # value[23:29] = 0 +# # Create the KV caches. - key_caches, value_caches = kv_cache_factory(num_blocks, - block_size, - 1, - num_heads, - head_size, - dtype, - seed, - flash_style=True) - assert len(key_caches) == 1 and len(value_caches) == 1 - key_cache, value_cache = key_caches[0], value_caches[0] - # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping) +# key_caches, value_caches = kv_cache_factory(num_blocks, +# block_size, +# 1, +# num_heads, +# head_size, +# dtype, +# seed, +# flash_style=True) +# assert len(key_caches) == 1 and len(value_caches) == 1 +# key_cache, value_cache = key_caches[0], value_caches[0] +# # Call the reshape_and_cache kernel. +# cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, +# slot_mapping) - cloned_key_cache = key_cache.clone() - cloned_value_cache = value_cache.clone() +# cloned_key_cache = key_cache.clone() +# cloned_value_cache = value_cache.clone() - # Run the reference implementation. 
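The slot mapping built by hand above follows one rule: a token's slot is its physical block number scaled by the block size, plus its offset inside that block. As a tiny hypothetical helper:

```python
def token_to_slot(block_table, position, block_size):
    # block_table: physical block numbers for one sequence, in logical order.
    block_number = block_table[position // block_size]
    block_offset = position % block_size
    return block_number * block_size + block_offset
```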
- block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') - block_indicies = block_indicies.cpu().tolist() - block_offsets = slot_mapping % block_size - block_offsets = block_offsets.cpu().tolist() - for i in range(num_tokens): - block_idx = block_indicies[i] - block_offset = block_offsets[i] - print("block_idx", block_idx) - print("block_offset", block_offset) - if block_idx != -1: - cloned_key_cache[block_idx, block_offset, :, :] = key[i] - cloned_value_cache[block_idx, block_offset, :, :] = value[i] - assert torch.allclose(key_cache, cloned_key_cache) - assert torch.allclose(value_cache, cloned_value_cache) - - for i in range(58): - print(i) - block_idx = block_indicies[i] - block_offset = block_offsets[i] - torch.allclose(key[i], key_cache[block_idx][block_offset]) - - from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask - from xformers import ops as xops - seqlen = query.shape[1] - attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seqlen] * batch_size) - scale = float(1.0 / (head_size**0.5)) - output = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=None, - p=0.0, - scale=scale, - ) +# # Run the reference implementation. +# block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') +# block_indicies = block_indicies.cpu().tolist() +# block_offsets = slot_mapping % block_size +# block_offsets = block_offsets.cpu().tolist() +# for i in range(num_tokens): +# block_idx = block_indicies[i] +# block_offset = block_offsets[i] +# print("block_idx", block_idx) +# print("block_offset", block_offset) +# if block_idx != -1: +# cloned_key_cache[block_idx, block_offset, :, :] = key[i] +# cloned_value_cache[block_idx, block_offset, :, :] = value[i] +# assert torch.allclose(key_cache, cloned_key_cache) +# assert torch.allclose(value_cache, cloned_value_cache) + +# for i in range(58): +# print(i) +# block_idx = block_indicies[i] +# block_offset = block_offsets[i] +# torch.allclose(key[i], key_cache[block_idx][block_offset]) + +# from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask +# from xformers import ops as xops +# seqlen = query.shape[1] +# attn_bias = BlockDiagonalCausalMask.from_seqlens( +# [seqlen] * batch_size) +# scale = float(1.0 / (head_size**0.5)) +# output = xops.memory_efficient_attention_forward( +# query.unsqueeze(0), +# key.unsqueeze(0), +# value.unsqueeze(0), +# attn_bias=None, +# p=0.0, +# scale=scale, +# ) - num_tokens, num_heads, head_size = query.shape - from flash_attn import flash_attn_func - output1 = flash_single_query_cached_kv_attention( - None, - query.view(batch_size, num_tokens // batch_size, num_heads, head_size), - key_cache, - value_cache, - scale, - block_tables, - torch.tensor([23, 29], dtype=torch.int, device='cuda'), - alibi_slopes=None, - ) - output2 = flash_attn_func( - # query.view(batch_size, num_tokens // batch_size, num_heads, head_size), - # key.view(batch_size, num_tokens // batch_size, num_heads, head_size), - # value.view(batch_size, num_tokens // batch_size, num_heads, head_size), - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - # key_cache, - # value_cache, - softmax_scale=scale, - # block_tables, - # torch.tensor([23, 29], dtype=torch.int, device='cuda'), - # alibi_slopes=None, - causal=True, - ) - output3 = flash_attn_func( - query.view(batch_size, num_tokens // batch_size, num_heads, head_size), - key.view(batch_size, num_tokens // batch_size, num_heads, head_size), - value.view(batch_size, num_tokens 
// batch_size, num_heads, head_size), - # key_cache, - # value_cache, - softmax_scale=scale, - # block_tables, - # torch.tensor([23, 29], dtype=torch.int, device='cuda'), - # alibi_slopes=None, - causal=True, - ) +# num_tokens, num_heads, head_size = query.shape +# from flash_attn import flash_attn_func +# output1 = flash_single_query_cached_kv_attention( +# None, +# query.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# key_cache, +# value_cache, +# scale, +# block_tables, +# torch.tensor([23, 29], dtype=torch.int, device='cuda'), +# alibi_slopes=None, +# ) +# output2 = flash_attn_func( +# # query.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# # key.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# # value.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# query.unsqueeze(0), +# key.unsqueeze(0), +# value.unsqueeze(0), +# # key_cache, +# # value_cache, +# softmax_scale=scale, +# # block_tables, +# # torch.tensor([23, 29], dtype=torch.int, device='cuda'), +# # alibi_slopes=None, +# causal=True, +# ) +# output3 = flash_attn_func( +# query.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# key.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# value.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# # key_cache, +# # value_cache, +# softmax_scale=scale, +# # block_tables, +# # torch.tensor([23, 29], dtype=torch.int, device='cuda'), +# # alibi_slopes=None, +# causal=True, +# ) - breakpoint() diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 1a5fd49cf025..74c574e2903f 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -231,34 +231,32 @@ def forward( (is_hip()) else None, ) output = out.view_as(query) - if key_cache is not None and value_cache is not None: - output2 = flash_single_query_cached_kv_attention( - None, - query.view(batch_size, seq_len, self.num_heads, self.head_size), - key_cache, - value_cache, - self.scale, - input_metadata.block_tables, - input_metadata.context_lens, - alibi_slopes=self.alibi_slopes, - ) - output3 = flash_attn_func( - query.view(batch_size, seq_len, self.num_heads, self.head_size), - key.view(batch_size, seq_len, self.num_heads, self.head_size), - value.view(batch_size, seq_len, self.num_heads, self.head_size), - softmax_scale=self.scale, - causal=True, - ).view_as(query) - output4 = flash_attn_func( - query, - key, - value, - softmax_scale=self.scale, - causal=True, - ).view_as(query) - breakpoint() + # if key_cache is not None and value_cache is not None: + # output2 = flash_single_query_cached_kv_attention( + # None, + # query.view(batch_size, seq_len, self.num_heads, self.head_size), + # key_cache, + # value_cache, + # self.scale, + # input_metadata.block_tables, + # input_metadata.context_lens, + # alibi_slopes=self.alibi_slopes, + # ) + # output3 = flash_attn_func( + # query.view(batch_size, seq_len, self.num_heads, self.head_size), + # key.view(batch_size, seq_len, self.num_heads, self.head_size), + # value.view(batch_size, seq_len, self.num_heads, self.head_size), + # softmax_scale=self.scale, + # causal=True, + # ).view_as(query) + # output4 = flash_attn_func( + # query, + # key, + # value, + # softmax_scale=self.scale, + # causal=True, + # ).view_as(query) # if key_cache is not None and value_cache is not None: - # breakpoint() # output2 = flash_attn_with_kvcache_paged( # query.view(batch_size, seq_len, self.num_heads, # 
self.head_size), @@ -269,7 +267,6 @@ def forward( # input_metadata.context_lens, # self.alibi_slopes, # ) - # breakpoint() else: if input_metadata.flash_style: print("SANG-TODO flash prefill") @@ -317,7 +314,6 @@ def forward( else: # Decoding run. - # breakpoint() if input_metadata.flash_style: # output = flash_attn_with_kvcache_paged( # query.view(batch_size, seq_len, self.num_heads, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5cddeed35538..84fc0a0e5528 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -224,7 +224,6 @@ def _prepare_prompt( slot_mapping[-1].append(slot) max_prompt_len = max(subquery_lens) - breakpoint() input_tokens = _make_tensor_with_pad(input_tokens, max_prompt_len, pad=0, From 71bda972692059492c56223e5d713ea8fc9f0759 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 4 Mar 2024 05:18:34 -0800 Subject: [PATCH 27/88] . --- tests/chunked_prefill/test_correctness.py | 33 +- tests/kernels/test_flash_attention.py | 503 +--------------------- vllm/core/block_manager.py | 2 - vllm/engine/llm_engine.py | 4 +- vllm/entrypoints/llm.py | 2 +- vllm/model_executor/layers/attention.py | 47 +- 6 files changed, 45 insertions(+), 546 deletions(-) diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py index efd1d5610f2c..6b38bb97b41d 100644 --- a/tests/chunked_prefill/test_correctness.py +++ b/tests/chunked_prefill/test_correctness.py @@ -29,7 +29,6 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("block_size", [32]) -@pytest.mark.parametrize("max_num_prompt_seqs", [1]) @pytest.mark.parametrize("tensor_parallel_size", [1]) def test_models( vllm_runner, @@ -37,7 +36,6 @@ def test_models( dtype: str, max_tokens: int, block_size: int, - max_num_prompt_seqs: int, tensor_parallel_size: int, ) -> None: """ verify the flash attention has the same output @@ -46,27 +44,26 @@ def test_models( pytest.skip( f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" ) - # print("loading page attention models..") - # pg_model = vllm_runner(model, dtype=dtype) - # expected_outputs = [] + print("loading page attention models..") + pg_model = vllm_runner(model, dtype=dtype) + expected_outputs = [] - # print("generating tokens...") - # expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) - # print("generating tokens finished") + print("generating tokens...") + expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) + print("generating tokens finished") - # del pg_model + del pg_model - # destroy_model_parallel() - # gc.collect() - # torch.cuda.empty_cache() + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() flash_attn_output_by_batches = [] - flash_attn_model = vllm_runner( - model, - dtype=dtype, - block_size=block_size, - flash_style=True, - tensor_parallel_size=tensor_parallel_size) + flash_attn_model = vllm_runner(model, + dtype=dtype, + block_size=block_size, + flash_style=True, + tensor_parallel_size=tensor_parallel_size) for i in range(10): prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] flash_attn_output_by_batches.append( diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index eb999504690a..b5b43be50a64 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -1,494 +1,3 @@ -# import random -# from typing import Optional, Tuple, List - -# import pytest -# 
import torch -# import torch.nn.functional as F - -# from vllm.model_executor.layers.attention import ( -# flash_attn_with_kvcache_paged, ) -# from vllm.utils import get_max_shared_memory_bytes - -# FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 -# # This will change depending on the compute capability. -# # - 512 as a buffer -# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -# NUM_BLOCKS = 128 # Arbitrary values for testing -# PARTITION_SIZE = 512 - -# DTYPES = [torch.half, torch.bfloat16] -# NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing -# NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing -# NUM_HEADS = [(1, 40), (40, 40), (64, 8)] # Arbitrary values for testing -# NUM_HEADS_SMALL = NUM_HEADS -# # head size should be bigger than or equal to block size. -# HEAD_SIZES = [256] -# # TODO(sang): https://github.com/Dao-AILab/flash-attention/pull/824 -# # should fix the block size. But right now, the block size should be -# # divisible by 256. -# BLOCK_SIZES = [256] -# USE_ALIBI = [False, True] -# SEEDS = [0] -# PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] - - -# def pad_attention_inputs( -# pad_config: Tuple[int, int], -# block_size: int, -# query: torch.Tensor, -# block_tables: torch.Tensor, -# context_lens: torch.Tensor, -# max_context_len: int, -# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: -# """Pad the attention inputs to the specified batch size and context length. -# """ -# pad_batch_size, pad_max_context_len = pad_config -# if pad_batch_size == 0: -# return query, block_tables, context_lens, max_context_len -# target_batch_size = ( -# (query.shape[0] - 1) % pad_batch_size + 1) * pad_batch_size -# target_block_size = pad_max_context_len // block_size + 1 -# padded_query = F.pad(query, -# (0, 0, 0, 0, 0, target_batch_size - query.shape[0])) -# padded_block_table = F.pad(block_tables, -# (0, target_block_size - block_tables.shape[1], -# 0, target_batch_size - block_tables.shape[0])) -# padded_context_lens = F.pad(context_lens, -# (0, target_batch_size - context_lens.shape[0])) -# return padded_query, padded_block_table, padded_context_lens, pad_max_context_len - - -# def ref_masked_attention( -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# scale: float, -# attn_mask: Optional[torch.Tensor] = None, -# ) -> torch.Tensor: -# attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() -# if attn_mask is not None: -# attn_weights = attn_weights + attn_mask.float() -# attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) -# out = torch.einsum("hqk,khd->qhd", attn_weights, value) -# return out - - -# def ref_single_query_cached_kv_attention( -# output: torch.Tensor, -# query: torch.Tensor, -# num_queries_per_kv: int, -# key_cache: torch.Tensor, -# value_cache: torch.Tensor, -# block_tables: torch.Tensor, -# context_lens: torch.Tensor, -# scale: float, -# alibi_slopes: Optional[torch.Tensor], -# ) -> None: -# num_query_heads = query.shape[1] -# head_size = value_cache.shape[-1] -# block_size = value_cache.shape[-3] -# num_seqs = query.shape[0] - -# block_tables = block_tables.cpu().tolist() -# context_lens = context_lens.cpu().tolist() -# for i in range(num_seqs): -# q = query[i].unsqueeze(0) -# block_table = block_tables[i] -# context_len = int(context_lens[i]) - -# keys = [] -# values = [] -# for j in range(context_len): -# block_number = int(block_table[j // block_size]) -# block_offset = j % block_size - -# k = key_cache[block_number, 
block_offset, :, :] -# keys.append(k) - -# v = value_cache[block_number, :, :, block_offset] -# v = value_cache[block_number, block_offset, :, :] -# values.append(v) -# keys = torch.stack(keys, dim=0) -# values = torch.stack(values, dim=0) -# if num_queries_per_kv > 1: -# # Handle MQA and GQA -# keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) -# values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - -# alibi_bias = None -# if alibi_slopes is not None: -# # Create the ALiBi bias used in the paged attention kernel. -# position_ids = torch.arange(context_len, device="cuda").int() -# alibi_bias = (position_ids - context_len + 1).float() -# alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( -# 1, 1, -1) - -# out = ref_masked_attention(q, keys, values, scale, alibi_bias) -# out = out.view(num_query_heads, head_size) -# # output[i].copy_(out, non_blocking=True) -# output[i].copy_(out) - - -# # @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -# # @pytest.mark.parametrize("num_heads", NUM_HEADS) -# # @pytest.mark.parametrize("head_size", HEAD_SIZES) -# # @pytest.mark.parametrize("use_alibi", [False, True]) -# # @pytest.mark.parametrize("block_size", BLOCK_SIZES) -# # @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) -# # @pytest.mark.parametrize("seed", SEEDS) -# # @pytest.mark.parametrize("pad_config", PAD_CONFIGS) -# @pytest.mark.parametrize("num_seqs", [3]) -# @pytest.mark.parametrize("num_heads", [(40, 40)]) -# @pytest.mark.parametrize("head_size", [256]) -# @pytest.mark.parametrize("use_alibi", [True]) -# @pytest.mark.parametrize("block_size", [256]) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("seed", SEEDS) -# @pytest.mark.parametrize("pad_config", [(0, 0)]) -# @torch.inference_mode() -# def test_flash_paged_attention( -# kv_cache_factory, -# num_seqs: int, -# num_heads: Tuple[int, int], -# head_size: int, -# use_alibi: bool, -# block_size: int, -# dtype: torch.dtype, -# seed: int, -# pad_config: Tuple[int, int], -# ) -> None: -# random.seed(seed) -# torch.random.manual_seed(seed) -# torch.cuda.manual_seed(seed) - -# scale = float(1.0 / (head_size**0.5)) -# num_query_heads, num_kv_heads = num_heads -# query = torch.empty(num_seqs, -# num_query_heads, -# head_size, -# dtype=dtype, -# device="cuda") -# query.uniform_(-scale, scale) - -# # assert num_query_heads % num_kv_heads == 0 -# num_queries_per_kv = num_query_heads // num_kv_heads -# alibi_slopes = None -# if use_alibi: -# alibi_slopes = torch.randn(num_query_heads, -# dtype=torch.float, -# device="cuda") - -# max_seq_len = MAX_SEQ_LEN if not pad_config[0] else (pad_config[1] - 1000) -# context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] -# context_lens[-1] = max_seq_len -# max_context_len = max(context_lens) -# context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") - -# # Create the block tables. -# max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size -# block_tables = [] -# for _ in range(num_seqs): -# block_table = [ -# random.randint(0, NUM_BLOCKS - 1) -# for _ in range(max_num_blocks_per_seq) -# ] -# block_tables.append(block_table) -# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - -# # Create the KV caches. 
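The flash-style caches used throughout these tests share a single layout for keys and values, `[num_blocks, block_size, num_heads, head_size]`. A minimal allocation sketch with illustrative sizes:

```python
import torch

num_blocks, block_size, num_kv_heads, head_size = 128, 32, 8, 64
key_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size,
                        dtype=torch.half, device="cuda")
value_cache = torch.empty_like(key_cache)
```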
-# key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, -# block_size, -# 1, -# num_kv_heads, -# head_size, -# dtype, -# seed, -# flash_style=True) -# key_cache, value_cache = key_caches[0], value_caches[0] - -# # Call the paged attention kernel. -# output = torch.empty_like(query) - -# padded_query, padded_block_table, padded_context_lens, _ = \ -# pad_attention_inputs(pad_config, block_size, query, -# block_tables, context_lens, max_context_len) - -# output = flash_attn_with_kvcache_paged( -# padded_query.view(num_seqs, 1, num_query_heads, head_size), -# key_cache, -# value_cache, -# scale, -# padded_block_table, -# padded_context_lens, -# alibi_slopes, -# ) - -# # Run the reference implementation. -# ref_output = torch.empty_like(query) -# ref_single_query_cached_kv_attention( -# ref_output, -# query, -# num_queries_per_kv, -# key_cache, -# value_cache, -# block_tables, -# context_lens, -# scale, -# alibi_slopes, -# ) - -# assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) - - -# def ref_multi_query_kv_attention( -# cu_seq_lens: List[int], -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# scale: float, -# dtype: torch.dtype, -# ) -> torch.Tensor: -# num_seqs = len(cu_seq_lens) - 1 -# ref_outputs = [] -# for i in range(num_seqs): -# start_idx = cu_seq_lens[i] -# end_idx = cu_seq_lens[i + 1] -# seq_len = end_idx - start_idx - -# # Create attention mask. -# attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), -# diagonal=1) -# attn_mask = attn_mask * torch.finfo(dtype).min -# attn_mask = attn_mask.to(dtype=dtype, device="cuda") - -# ref_output = ref_masked_attention( -# query[start_idx:end_idx], -# key[start_idx:end_idx], -# value[start_idx:end_idx], -# scale, -# attn_mask=attn_mask, -# ) -# ref_outputs.append(ref_output) -# ref_output = torch.cat(ref_outputs, dim=0) -# return ref_output - - -# def ref_multi_query_kv_attention_padded( -# query: torch.Tensor, -# num_queries_per_kv: int, -# key_cache: torch.Tensor, -# value_cache: torch.Tensor, -# block_tables: torch.Tensor, -# cu_seq_lens: List[int], -# context_lens: List[int], -# scale: float, -# dtype: torch.dtype, -# ) -> torch.Tensor: -# num_seqs = len(cu_seq_lens) - 1 -# block_size = value_cache.shape[-3] -# ref_outputs = [] - -# for i in range(num_seqs): -# q_start_idx = cu_seq_lens[i] -# q_end_idx = cu_seq_lens[i + 1] -# seq_len = q_end_idx - q_start_idx - -# context_len = context_lens[i] - -# block_table = block_tables[i] -# keys = [] -# values = [] - -# for j in range(context_len): -# block_number = int(block_table[j // block_size]) -# block_offset = j % block_size - -# k = key_cache[block_number, block_offset, :, :] -# keys.append(k) - -# v = value_cache[block_number, block_offset, :, :] -# values.append(v) - -# keys = torch.stack(keys, dim=0) -# values = torch.stack(values, dim=0) - -# if num_queries_per_kv > 1: -# # Handle MQA and GQA -# keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) -# values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - -# q = query[q_start_idx:q_end_idx, :, :] -# k = keys[:context_len, :, :] -# v = values[:context_len, :, :] - -# assert seq_len <= context_len - -# # pad q if seq_len is less than context_len -# # this is for correct calculation of attention. -# if seq_len < context_len: -# indices = [i % seq_len for i in range(context_len - seq_len)] -# q_left_pad = q[indices, :, :] -# q = torch.cat([q_left_pad, q], dim=0) - -# # Create attention mask. 
-# attn_mask = torch.triu(torch.ones(context_len, -# context_len, -# dtype=dtype), -# diagonal=1) -# attn_mask = attn_mask * torch.finfo(dtype).min -# attn_mask = attn_mask.to(dtype=dtype, device="cuda") -# ref_output = ref_masked_attention( -# q, -# k, -# v, -# scale, -# attn_mask=attn_mask, -# ) -# ref_outputs.append(ref_output[-seq_len:, :, :]) -# breakpoint -# ref_output = torch.cat(ref_outputs, dim=0) -# return ref_output - - -# def is_a100(): -# return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 - - -# if not is_a100(): -# NUM_HEADS_SMALL = [(16, 16), (16, 8)] -# MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) - - -# @pytest.mark.parametrize("num_seqs", [17]) -# @pytest.mark.parametrize("num_heads", [(40, 40)]) -# @pytest.mark.parametrize("head_size", HEAD_SIZES) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("version", ["flash"]) -# @pytest.mark.parametrize("seed", SEEDS) -# @pytest.mark.parametrize("block_size", [256]) -# @torch.inference_mode() -# def test_multi_query_kv_attention( -# num_seqs: int, -# num_heads: Tuple[int, int], -# head_size: int, -# dtype: torch.dtype, -# version: str, -# seed: int, -# block_size: int, -# ) -> None: -# random.seed(seed) -# torch.random.manual_seed(seed) -# torch.cuda.manual_seed(seed) - -# # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. -# # As the xformers library is already tested with its own tests, we can use -# # a smaller MAX_SEQ_LEN here. -# max_len = min(MAX_SEQ_LEN, 4096) - -# seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] -# max_seq_len = max(seq_lens) -# context_lens = seq_lens -# max_context_len = max(context_lens) - -# num_tokens = sum(seq_lens) -# cu_seq_lens = [0] -# for seq_len in seq_lens: -# cu_seq_lens.append(cu_seq_lens[-1] + seq_len) - -# cu_context_lens = [0] -# for context_len in context_lens: -# cu_context_lens.append(cu_context_lens[-1] + context_len) - -# print(f"cu_seq_lens={cu_seq_lens}, cu_context_lens={cu_context_lens}") - -# scale = float(1.0 / (head_size**0.5)) -# num_query_heads, num_kv_heads = num_heads -# num_queries_per_kv = num_query_heads // num_kv_heads - -# value_cache = torch.empty(NUM_BLOCKS, -# block_size, -# num_kv_heads, -# head_size, -# dtype=dtype, -# device="cuda") -# key_cache = torch.empty(NUM_BLOCKS, -# block_size, -# num_kv_heads, -# head_size, -# dtype=dtype, -# device="cuda") -# query = torch.empty(max_seq_len * num_seqs, -# num_query_heads, -# head_size, -# dtype=dtype, -# device="cuda") -# value_cache.uniform_(-scale, scale) -# key_cache.uniform_(-scale, scale) -# query.uniform_(-scale, scale) - -# # Create the block tables. 
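The cu_seq_lens / cu_context_lens lists built above follow FlashAttention's varlen convention: entry i is the row where sequence i starts in the packed tensor, so sequence i occupies rows [cu[i], cu[i+1]). For example:

```python
seq_lens = [4, 7, 2]   # illustrative
cu_seq_lens = [0]
for seq_len in seq_lens:
    cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
# cu_seq_lens == [0, 4, 11, 13]; sequence 1 is rows 4..10 of the packed tensor.
```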
-# max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size -# block_tables = [] -# for _ in range(num_seqs): -# block_table = [ -# random.randint(0, NUM_BLOCKS - 1) -# for _ in range(max_num_blocks_per_seq) -# ] -# block_tables.append(block_table) -# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - -# output = torch.empty_like(query) - -# if version == "flash": -# # flash_multi_query_cached_kv_attention_varlen( -# # output, -# # query, -# # key_cache, -# # value_cache, -# # scale, -# # block_tables, -# # torch.cuda.IntTensor(cu_seq_lens), -# # torch.cuda.IntTensor(cu_context_lens), -# # block_size, -# # max_seq_len, -# # max_context_len, -# # None, -# # ) -# from flash_attn import flash_attn_func -# breakpoint() -# # output = flash_attn_func( -# # query.unsqueeze(0), -# # k=key, -# # v=value, -# # softmax_scale=scale, -# # causal=True, -# # alibi_slopes=alibi_slopes, -# # ) -# output = flash_attn_with_kvcache_paged( -# query.view(num_seqs, max_seq_len, num_query_heads, head_size), -# key_cache, -# value_cache, -# scale, -# block_tables, -# torch.tensor(context_lens, dtype=torch.int, device="cuda"), -# None, -# ) -# else: -# assert False, f"{version=} is not supported" - -# ref_output = ref_multi_query_kv_attention_padded( -# query, -# num_queries_per_kv, -# key_cache, -# value_cache, -# block_tables, -# cu_seq_lens, -# context_lens, -# scale, -# dtype, -# ) -# assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) - import random from typing import List, Optional, Tuple @@ -660,9 +169,6 @@ def test_flash_paged_attention( assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads - head_mapping = torch.repeat_interleave( - torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), - num_queries_per_kv) alibi_slopes = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, @@ -698,7 +204,6 @@ def test_flash_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. - num_valid_tokens = torch.cuda.IntTensor([num_seqs]) output = torch.empty_like(query) padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ @@ -707,7 +212,7 @@ def test_flash_paged_attention( flash_single_query_cached_kv_attention( output, - padded_query, + padded_query.view(num_seqs, 1, num_query_heads, head_size), key_cache, value_cache, scale, @@ -731,9 +236,6 @@ def test_flash_paged_attention( flash_style=True, ) - # NOTE(woosuk): Due to the kernel-level differences in the two - # implementations, there is a small numerical difference in the two - # outputs. Thus, we use a relaxed tolerance for the test. assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) @@ -1053,7 +555,7 @@ def test_multi_query_kv_attention( # # Call the reshape_and_cache kernel. 
# cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, # slot_mapping) - + # cloned_key_cache = key_cache.clone() # cloned_value_cache = value_cache.clone() @@ -1133,4 +635,3 @@ def test_multi_query_kv_attention( # # alibi_slopes=None, # causal=True, # ) - diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index a7c38fc4c0f2..daf83827a7e5 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -200,8 +200,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) - # print("allcate a new block for seq_group", seq_group) - # print("block", block) block_table.append(block) # Assign the block table for each sequence. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5283654d1a34..50bedbe89cd1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -836,8 +836,8 @@ def step(self) -> List[RequestOutput]: >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - print("SANG-TODO step seq_group_metadata_list length: ", - len(seq_group_metadata_list)) + # print("SANG-TODO step seq_group_metadata_list length: ", + # len(seq_group_metadata_list)) if not scheduler_outputs.is_empty(): # Execute the model. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e1932bb761ef..93adc98adda5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -146,7 +146,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ - print("SANG-TODO generate: ", prompts, prompt_token_ids) + # print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 74c574e2903f..03941f9f6c9a 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -172,17 +172,17 @@ def forward( # heads. # TODO(woosuk): Use MQA/GQA kernels for higher performance. query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) + self.num_queries_per_kv, + query.shape[-1]) key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + key.shape[-1]) value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) + None, :].expand(value.shape[0], + self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. 
@@ -269,10 +269,10 @@ def forward( # ) else: if input_metadata.flash_style: - print("SANG-TODO flash prefill") output = flash_single_query_cached_kv_attention( None, - query.view(batch_size, seq_len, self.num_heads, self.head_size), + query.view(batch_size, seq_len, self.num_heads, + self.head_size), key_cache, value_cache, self.scale, @@ -294,7 +294,7 @@ def forward( # alibi_slopes=self.alibi_slopes, # ) else: - print("SANG-TODO context attention") + # print("SANG-TODO context attention") # prefix-enabled attention output = torch.empty_like(query) context_attention_fwd( @@ -304,7 +304,8 @@ def forward( output, key_cache, value_cache, - input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata. + block_tables, # [BS, max_block_per_request] input_metadata.start_loc, input_metadata.prompt_lens, input_metadata.context_lens, @@ -455,6 +456,7 @@ def _paged_attention( return output +# OSS version. # def flash_attn_with_kvcache_paged( # query: torch.Tensor, # key_cache: torch.Tensor, @@ -501,7 +503,7 @@ def _paged_attention( def flash_single_query_cached_kv_attention( - output: Optional[torch.Tensor], + output: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -511,9 +513,9 @@ def flash_single_query_cached_kv_attention( alibi_slopes: Optional[torch.Tensor], actual_batch_size: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Similar to vLLM's page attention, caclulates a single token's attention + """Similar to vLLM's page attention, calculates a single token's attention based on key/value caches. The main difference is this uses flash attention - sytle key-value caches. + style key-value caches. Arguments: output: [num_padded_tokens, num_heads, head_size], output tensor @@ -538,6 +540,8 @@ def flash_single_query_cached_kv_attention( # TODO: support alibi_slopes assert alibi_slopes is None, "doesn't support alibi_slopes" batch_size, seqlen, num_heads, head_size = query.shape + assert seqlen == 1, ( + "Single query attention can be only used for decoding phase.") num_tokens = batch_size * seqlen out = flash_attn_with_page_attention( query, @@ -557,11 +561,10 @@ def flash_single_query_cached_kv_attention( num_splits=0, actual_batch_size=actual_batch_size, ) - return out - # if output is not None: - # # in case that output is padded, only copy the valid part. - # output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) - # return out.view(num_tokens, num_heads, head_size) + if output is not None: + # in case that output is padded, only copy the valid part. + output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) + return out.view(num_tokens, num_heads, head_size) def flash_multi_query_cached_kv_attention_varlen( @@ -656,4 +659,4 @@ def flash_multi_query_cached_kv_attention_varlen( if output is not None: output[:num_tokens].copy_(out) - return out \ No newline at end of file + return out From 4e00e7f00f2278045ad8970a70d50c4ac3b4241a Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 4 Mar 2024 05:18:44 -0800 Subject: [PATCH 28/88] done? 
--- tests/kernels/test_flash_attention.py | 6 +----- vllm/model_executor/layers/attention.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py index b5b43be50a64..66a668cb7dd5 100644 --- a/tests/kernels/test_flash_attention.py +++ b/tests/kernels/test_flash_attention.py @@ -388,7 +388,6 @@ def test_multi_query_kv_attention( seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] max_seq_len = max(seq_lens) - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device="cuda") if chunked_prefill: # context length will be different from seq_len if chunked_prefill is @@ -397,9 +396,6 @@ def test_multi_query_kv_attention( else: context_lens = seq_lens max_context_len = max(context_lens) - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device="cuda") num_tokens = sum(seq_lens) cu_seq_lens = [0] @@ -466,7 +462,7 @@ def test_multi_query_kv_attention( None, ) else: - assert False, f"{version=} is not supported" + raise AssertionError(f"{version=} is not supported") ref_output = ref_multi_query_kv_attention_padded( query, diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 03941f9f6c9a..be4343fde453 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( context_attention_fwd) from vllm.utils import is_hip -from flash_attn import flash_attn_func # try: # from flash_attn import flash_attn_with_kvcache From c0384a42dc2ed266ff269f946b6482221a5b746a Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 6 Mar 2024 05:37:42 -0800 Subject: [PATCH 29/88] Refactor 2d query to 1d query --- tests/core/test_scheduler.py | 18 +-- tests/models/test_models.py | 32 ++-- tests/worker/test_model_runner.py | 146 ++++++++++++++++- vllm/config.py | 3 - vllm/core/scheduler.py | 13 +- vllm/engine/arg_utils.py | 8 +- vllm/model_executor/input_metadata.py | 19 ++- vllm/model_executor/layers/activation.py | 4 +- vllm/model_executor/layers/attention.py | 152 ++++++++++-------- vllm/model_executor/layers/sampler.py | 1 - vllm/worker/model_runner.py | 189 +++++++++++++++-------- 11 files changed, 403 insertions(+), 182 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index ebfeb8ba0481..397101fa8610 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,7 +10,7 @@ def test_scheduler_add_seq_group(): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1, 256) + scheduler_config = SchedulerConfig(100, 64, 1) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -26,7 +26,7 @@ def test_scheduler_add_seq_group(): def test_scheduler_abort_seq_group(): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1, 256) + scheduler_config = SchedulerConfig(100, 64, 1) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -50,7 +50,7 @@ def test_scheduler_schedule_simple(): block_size = 4 num_seq_group = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len, 256) + scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -64,10 +64,10 @@ def 
test_scheduler_schedule_simple(): running.append(seq_group) # Schedule seq groups prompts. + num_tokens = block_size * num_seq_group seq_group_meta, out = scheduler.schedule() assert set(out.scheduled_seq_groups) == set(running) - assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs( - )[0].get_len() + assert out.num_batched_tokens == num_tokens assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == num_seq_group @@ -84,7 +84,7 @@ def test_scheduler_schedule_simple(): def test_scheduler_schedule_preempt_abort(): block_size = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, 2, max_model_len, 256) + scheduler_config = SchedulerConfig(64, 2, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 @@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort(): # Schedule seq groups prompts. seq_group_meta, out = scheduler.schedule() assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] - assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2 + assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == 2 @@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort(): scheduler.abort_seq_group("1") seq_group_meta, out = scheduler.schedule() assert out.scheduled_seq_groups == [seq_group_b] - assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len() + assert out.num_batched_tokens == 5 # 4 prompt + 1 generation. assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == 1 @@ -136,7 +136,7 @@ def test_scheduler_max_seqs(): num_seq_group = 4 max_seq_group = 2 max_model_len = 16 - scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len, 256) + scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 diff --git a/tests/models/test_models.py b/tests/models/test_models.py index fb567e837d28..c268b6fd4868 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,26 +6,27 @@ MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", - "mistralai/Mistral-7B-v0.1", - "Deci/DeciLM-7b", - "tiiuae/falcon-7b", - "gpt2", - "bigcode/tiny_starcoder_py", - "EleutherAI/gpt-j-6b", - "EleutherAI/pythia-70m", - "bigscience/bloom-560m", - "mosaicml/mpt-7b", - "microsoft/phi-2", - "stabilityai/stablelm-3b-4e1t", - "allenai/OLMo-1B", - "bigcode/starcoder2-3b", + # "meta-llama/Llama-2-7b-hf", + # "mistralai/Mistral-7B-v0.1", + # "Deci/DeciLM-7b", + # "tiiuae/falcon-7b", + # "gpt2", + # "bigcode/tiny_starcoder_py", + # "EleutherAI/gpt-j-6b", + # "EleutherAI/pythia-70m", + # "bigscience/bloom-560m", + # "mosaicml/mpt-7b", + # "microsoft/phi-2", + # "stabilityai/stablelm-3b-4e1t", + # "allenai/OLMo-1B", + # "bigcode/starcoder2-3b", ] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("enforce_eager", [False, True]) def test_models( hf_runner, vllm_runner, @@ -33,12 +34,13 @@ def test_models( model: str, dtype: str, max_tokens: int, + enforce_eager: bool, ) -> None: hf_model = hf_runner(model, dtype=dtype) hf_outputs = 
hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype) + vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index f44895a728c7..7fed21bc2aa9 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -2,7 +2,14 @@ import torch from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT +from vllm.config import ModelConfig + + +# Make sure the result is aligned. +def round_up_to_next_multiple_of_batch_size(n): + batch_size = _BATCH_SIZE_ALIGNMENT + return ((n + 7) // batch_size) * batch_size def test_prepare_prompt(): @@ -28,21 +35,148 @@ def test_prepare_prompt(): expected_selected_token_indices = [] selected_token_start_idx = 0 - max_seq_len = max(prompt_lens) for prompt_len in prompt_lens: expected_selected_token_indices.append(selected_token_start_idx + prompt_len - 1) - selected_token_start_idx += max_seq_len - input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = ( + selected_token_start_idx += prompt_len + input_tokens, input_positions, input_metadata, return_prompt_lens, _, _, _, _ = ( model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens + + # Verify input metadata is correct for prompts. + device = model_runner.device + assert input_metadata.is_prompt is True + assert torch.allclose(input_metadata.prompt_lens, + torch.tensor(prompt_lens, device=device)) + assert input_metadata.num_prompt_tokens == sum(prompt_lens) + assert input_metadata.num_generation_tokens == 0 + assert input_metadata.max_seq_len == max(prompt_lens) + # build start_loc + start_idx = 0 + start_loc = [start_idx] + # start_loc is padded. + for prompt_len in prompt_lens: + start_idx += prompt_len + start_loc.append(start_idx) + assert torch.allclose( + input_metadata.start_loc, + torch.tensor(start_loc, dtype=torch.long, device=device)) + assert input_metadata.max_context_len == max(prompt_lens) + assert torch.allclose( + input_metadata.context_lens, + torch.tensor(prompt_lens, dtype=torch.int, device=device)) + # assert input_metadata.slot_mapping == max(prompt_lens) + # block_tables + # Cuda graph should not be used for prerill. 
+ assert input_metadata.use_cuda_graph is False + assert input_metadata.kv_cache_dtype == "auto" + assert input_metadata.num_valid_tokens == round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)) + + assert input_tokens.shape == (round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)), ) + assert input_positions.shape == (round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)), ) + torch.testing.assert_close(input_tokens, input_positions) + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - assert input_tokens.shape == (batch_size, max_seq_len) - assert input_positions.shape == (batch_size, max_seq_len) + assert input_tokens.shape == (round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)), ) + assert input_positions.shape == (round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)), ) + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) + torch.testing.assert_close(input_tokens, input_positions) + + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) + + +def test_prepare_decode_cuda_graph(): + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=False, + ) + model_runner = ModelRunner(model_config, None, None, None, None) + model_runner.set_block_size(16) + + # Make sure the result is aligned. + def round_up_to_next_multiple_of_batch_size(n): + batch_size = _BATCH_SIZE_ALIGNMENT + return ((n + 7) // batch_size) * batch_size + + batch_size = random.randint(1, 256) + prompt_lens = [] + seq_group_metadata_list = [] + for i in range(batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_lens.append(prompt_len) + seq_data = list(range(prompt_len)) + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: SequenceData(seq_data)}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + )) + + input_tokens, input_positions, input_metadata, _, _, _ = ( + model_runner._prepare_decode(seq_group_metadata_list)) + + # Verify input metadata is correct for prompts. + device = model_runner.device + assert input_metadata.is_prompt is False + assert input_metadata.prompt_lens is None + assert input_metadata.num_prompt_tokens == 0 + assert input_metadata.num_generation_tokens == ( + round_up_to_next_multiple_of_batch_size(len(seq_group_metadata_list))) + assert input_metadata.max_seq_len is None + assert input_metadata.start_loc is None + assert input_metadata.max_context_len == max(prompt_lens) + assert torch.allclose( + input_metadata.context_lens[:len(prompt_lens)], + torch.tensor(prompt_lens, dtype=torch.int, device=device)) + # assert input_metadata.slot_mapping == max(prompt_lens) + # block_tables + # Cuda graph should not be used for prerill. 
+ assert input_metadata.use_cuda_graph is True + assert input_metadata.kv_cache_dtype == "auto" + assert input_metadata.num_valid_tokens == ( + round_up_to_next_multiple_of_batch_size(len(seq_group_metadata_list))) + + assert input_tokens.shape == (round_up_to_next_multiple_of_batch_size( + len(seq_group_metadata_list)), ) + assert input_positions.shape == (round_up_to_next_multiple_of_batch_size( + len(seq_group_metadata_list)), ) torch.testing.assert_close(input_tokens, input_positions) + # Verify Sampling + expected_selected_token_indices = [] + selected_token_start_idx = 0 + for prompt_len in prompt_lens: + expected_selected_token_indices.append(selected_token_start_idx) + selected_token_start_idx += 1 + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, diff --git a/vllm/config.py b/vllm/config.py index ef9a920f29c2..9867ca9a2c1e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -456,7 +456,6 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). - max_paddings: Maximum number of paddings to be added to a batch. """ def __init__( @@ -464,7 +463,6 @@ def __init__( max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, - max_paddings: int, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -474,7 +472,6 @@ def __init__( self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len - self.max_paddings = max_paddings self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c96c6d62ef19..3ec5216a0f18 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -173,12 +173,12 @@ def _schedule(self) -> SchedulerOutputs: curr_loras = set( seq_group.lora_int_id for seq_group in self.running) if self.lora_enabled else None - seq_lens: List[int] = [] # Optimization: We do not sort the waiting queue since the preempted # sequence groups are added to the front and the new sequence groups # are added to the back. leftover_waiting_sequences = deque() + num_batched_tokens = 0 while self.waiting: seq_group = self.waiting[0] waiting_seqs = seq_group.get_seqs( @@ -223,8 +223,7 @@ def _schedule(self) -> SchedulerOutputs: continue # If the number of batched tokens exceeds the limit, stop. 
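The scheduler hunk that follows replaces the padded token accounting (number of sequences times the longest prompt) with an exact running sum, which is what a flattened, padding-free prompt batch actually costs. A toy comparison with made-up lengths:

# Toy comparison of the two accounting schemes.
seq_lens = [100, 30, 10]

# Old scheme: every prompt is charged as if padded to the longest one.
padded_tokens = len(seq_lens) * max(seq_lens)   # 300

# New scheme: prompts are flattened, so only real tokens count.
exact_tokens = sum(seq_lens)                    # 140

assert (padded_tokens, exact_tokens) == (300, 140)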
- new_seq_lens = seq_lens + [num_prompt_tokens] - num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) + num_batched_tokens += num_prompt_tokens if (num_batched_tokens > self.scheduler_config.max_num_batched_tokens): break @@ -236,11 +235,6 @@ def _schedule(self) -> SchedulerOutputs: self.scheduler_config.max_num_seqs): break - num_paddings = num_batched_tokens - sum(new_seq_lens) - if num_paddings > self.scheduler_config.max_paddings: - break - seq_lens = new_seq_lens - if lora_int_id > 0: curr_loras.add(lora_int_id) self.waiting.popleft() @@ -255,8 +249,7 @@ def _schedule(self) -> SchedulerOutputs: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, prompt_run=True, - num_batched_tokens=len(seq_lens) * - max(seq_lens) if seq_lens else 0, + num_batched_tokens=num_batched_tokens, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c3dccdd5bb50..fb4462c0adff 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -30,7 +30,6 @@ class EngineArgs: gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 - max_paddings: int = 256 max_logprobs: int = 5 # OpenAI default value disable_log_stats: bool = False revision: Optional[str] = None @@ -209,10 +208,6 @@ def add_cli_args( type=int, default=EngineArgs.max_num_seqs, help='maximum number of sequences per iteration') - parser.add_argument('--max-paddings', - type=int, - default=EngineArgs.max_paddings, - help='maximum number of paddings in a batch') parser.add_argument( '--max-logprobs', type=int, @@ -322,8 +317,7 @@ def create_engine_configs( self.ray_workers_use_nsight) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, - model_config.max_model_len, - self.max_paddings) + model_config.max_model_len) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index f0a88ac8e27f..e7d45a6cbf3f 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -8,9 +8,16 @@ class InputMetadata: Args: prompt_lens: Lengths of prompts. - slot_mapping: The address to write the new KV to of each token. + slot_mapping: The index of each token mapped into a physical block + in block tables. E.g., if block_size is 32, 35 means it is in + the block number 1, 3rd index. + num_prompt_tokens: The number of tokens in the prompts. This might + include padding. + num_generation_tokens: The number of tokens in the generation sequences. + This might include padding. max_context_len: The maximum context length. context_lens: the length of attention context for each sequence. + I.e., the number of tokens that have attended so far. block_tables: The block tables. (Seq id -> list of physical block) kv_cache_dtype: Data type to store kv cache. 
""" @@ -20,6 +27,8 @@ def __init__( is_prompt: bool, slot_mapping: torch.Tensor, prompt_lens: Optional[torch.Tensor], + num_prompt_tokens: int, + num_generation_tokens: int, max_seq_len: Optional[int], start_loc: Optional[torch.Tensor], max_context_len: Optional[int], @@ -30,6 +39,8 @@ def __init__( ) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens + self.num_prompt_tokens = num_prompt_tokens + self.num_generation_tokens = num_generation_tokens self.max_seq_len = max_seq_len self.start_loc = start_loc self.max_context_len = max_context_len @@ -42,13 +53,17 @@ def __init__( # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. self.attn_bias = None + self.num_valid_tokens = slot_mapping.shape[0] def __repr__(self) -> str: return ("InputMetadata(" f"is_prompt={self.is_prompt}, " f"max_context_len={self.max_context_len}, " + f"num_generation_tokens={self.num_generation_tokens}, " + f"num_prompt_tokens={self.num_prompt_tokens}, " f"slot_mapping={self.slot_mapping}, " f"context_lens={self.context_lens}, " f"block_tables={self.block_tables}, " f"use_cuda_graph={self.use_cuda_graph}, " - f"kv_cache_dtype={self.kv_cache_dtype})") + f"kv_cache_dtype={self.kv_cache_dtype}) " + f"num_valid_tokens={self.num_valid_tokens}") diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 5a3a7b2dbaee..6d829395883a 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -20,8 +20,8 @@ class SiluAndMul(nn.Module): The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. Shapes: - x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) - return: (batch_size, seq_len, d) or (num_tokens, d) + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) """ def _forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 2a82325b8021..bfcdecc8876b 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -23,14 +23,28 @@ class PagedAttention(nn.Module): """MHA/MQA/GQA layer with PagedAttention. - This class takes query, key, and value tensors as input. The input tensors + This class takes flattened 1D query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + + If the input tensors contain prompt tokens, the layout is as follows: + |<---------------------- num_valid_tokens ---------------------->| + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->| + Otherwise, the layout is as follows: + |<------------------ num_valid_tokens ------------------->| + |<------- num_generation_tokens (M) ------->| + |<--generation_0-->|...|<--generation_M-1-->|<--padding-->| + The prompts might have different lengths, while the generation tokens always + have length 1. The paddings are appended to make the input length a multiple + of 8, which is desirable for Tensor Cores. + The class does the following: 1. Reshape and store the input key and value tensors in the KV cache. 2. Perform (multi-head/multi-query/grouped-query) attention using either xformers or the PagedAttention custom op. - 3. Return the output tensor. + 3. Output a flattened 1D tensor. """ def __init__( @@ -105,38 +119,50 @@ def forward( """PagedAttention forward pass. 
Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] + query: shape = [num_tokens num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] input_metadata: metadata for the inputs. Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ - batch_size, seq_len, hidden_size = query.shape + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) + output = torch.empty_like(query) # Reshape the keys and values and store them in the cache. # If key_cache and value_cache are not provided, the new key and value # vectors will not be cached. This happens during the initial memory # profiling run. - if key_cache is not None and value_cache is not None: + num_valid_tokens = input_metadata.num_valid_tokens + if (num_valid_tokens > 0 and key_cache is not None + and value_cache is not None): + key_to_cache = key[:num_valid_tokens] + value_to_cache = value[:num_valid_tokens] cache_ops.reshape_and_cache( - key, - value, + key_to_cache, + value_to_cache, key_cache, value_cache, input_metadata.slot_mapping.flatten(), input_metadata.kv_cache_dtype, ) - if input_metadata.is_prompt: + num_prompt_tokens = input_metadata.num_prompt_tokens + num_generation_tokens = input_metadata.num_generation_tokens + + if num_prompt_tokens > 0: + assert num_generation_tokens == 0 + query = query[:num_prompt_tokens] + key = key[:num_prompt_tokens] + value = value[:num_prompt_tokens] # normal attention if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): @@ -164,25 +190,25 @@ def forward( if input_metadata.attn_bias is None: if self.alibi_slopes is None: attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size) + input_metadata.prompt_lens.tolist()) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) input_metadata.attn_bias = attn_bias else: input_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, batch_size, - seq_len, query.dtype) + self.alibi_slopes, self.num_kv_heads, query.dtype, + input_metadata) if self.use_ref_attention: - output = self.ref_masked_attention( + output[:num_prompt_tokens] = self.ref_masked_attention( query, key, value, ) # Using view got RuntimeError: view size is not compatible with input tensor's size and stride # (at least one dimension spans across two contiguous subspaces). Use reshape instead - return output.reshape(batch_size, seq_len, hidden_size) + return output.reshape(num_tokens, hidden_size) # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. 
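The prompt path above builds its attention bias from the per-prompt lengths instead of a rectangular (batch_size, seq_len) mask. For illustration only, a pure-PyTorch boolean mask with the same semantics as a block-diagonal causal bias over the flattened tokens (True means the position may be attended to); this is not the xformers object itself, just its meaning:

import torch

def block_diagonal_causal_mask(prompt_lens):
    """Token i may attend to token j only if both belong to the same prompt
    and j <= i."""
    total = sum(prompt_lens)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for plen in prompt_lens:
        block = torch.tril(torch.ones(plen, plen, dtype=torch.bool))
        mask[start:start + plen, start:start + plen] = block
        start += plen
    return mask

m = block_diagonal_causal_mask([2, 3])
assert m.shape == (5, 5)
assert not m[2, 1]                # different prompts never attend to each other
assert m[4, 2] and not m[2, 4]    # causal within a prompt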
@@ -191,21 +217,22 @@ def forward( key = key.unsqueeze(0) value = value.unsqueeze(0) else: - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) + query = query.unflatten(0, (num_tokens)) + key = key.unflatten(0, (num_tokens)) + value = value.unflatten(0, (num_tokens)) - out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ) - output = out.view_as(query) + output[: + num_prompt_tokens] = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha. + MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ).view_as(query) else: # prefix-enabled attention output = torch.empty_like(query) @@ -224,10 +251,12 @@ def forward( getattr(self, "alibi_slopes", None), ) - else: + if num_generation_tokens > 0: + assert num_prompt_tokens == 0 # Decoding run. output = _paged_attention( - query, + output[num_prompt_tokens:num_valid_tokens], + query[num_prompt_tokens:num_valid_tokens], key_cache, value_cache, input_metadata, @@ -237,45 +266,44 @@ def forward( ) # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) + return output.view(-1, self.num_heads * self.head_size) def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, - batch_size: int, - seq_len: int, dtype: torch.dtype, + input_metadata: InputMetadata, ) -> LowerTriangularMaskWithTensorBias: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - - # When using custom attention bias, xformers requires the bias to - # be sliced from a tensor whose length is a multiple of 8. - padded_len = (seq_len + 7) // 8 * 8 - num_heads = alibi_slopes.shape[0] - bias = torch.empty( - batch_size, - num_heads, - seq_len, - padded_len, - device=alibi_slopes.device, - dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias.mul_(alibi_slopes[:, None, None]) - if num_heads != num_kv_heads: - bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) - attn_bias = LowerTriangularMaskWithTensorBias(bias) - return attn_bias + for prompt_len in input_metadata.prompt_lens: + bias = torch.arange(prompt_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. 
+ bias = bias[None, :] - bias[:, None] + + padded_len = (prompt_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + prompt_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :prompt_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + attn_bias = LowerTriangularMaskWithTensorBias(bias) + return attn_bias def _paged_attention( - query: torch.Tensor, + output: torch.Tensor, # [num_tokens, num_heads, head_size] + query: torch.Tensor, # [num_tokens, num_heads, head_size] key_cache: torch.Tensor, value_cache: torch.Tensor, input_metadata: InputMetadata, @@ -283,8 +311,6 @@ def _paged_attention( scale: float, alibi_slopes: Optional[torch.Tensor], ) -> torch.Tensor: - output = torch.empty_like(query) - block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = ( diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 320cb443524c..bb5aa281568f 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -125,7 +125,6 @@ def _prune_hidden_states( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) return hidden_states.index_select(0, sampling_metadata.selected_token_indices) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index aff8ebc90362..6c207a1abbe1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -28,9 +28,14 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] _PAD_SLOT_ID = -1 LORA_WARMUP_RANK = 8 -# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. +_BATCH_SIZE_ALIGNMENT = 8 +# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. -_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] +# Note that cuda graph is only used for decoding because it speeds up +# the performance when num_tokens < 200. Batch here means a single token. +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) +] class ModelRunner: @@ -124,9 +129,9 @@ def _prepare_prompt( ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() @@ -155,16 +160,18 @@ def _prepare_prompt( computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) + context_len = computed_len else: prefix_block_tables.append([]) + context_len = prompt_len # actual prompt lens - context_lens.append(computed_len) + context_lens.append(context_len) subquery_lens.append(prompt_len - computed_len) - input_tokens.append(prompt_tokens) + input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. 
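The positions built in the next lines start at computed_len rather than 0, because a prompt whose prefix is already cached only needs to run its uncached tail. A sketch with illustrative values (the real computed_block_nums come from the block manager):

block_size = 16
prompt_tokens = list(range(40))          # a 40-token prompt, made-up ids
computed_block_nums = [7, 11]            # two prefix blocks already cached

computed_len = len(computed_block_nums) * block_size   # 32 tokens reusable
remaining_tokens = prompt_tokens[computed_len:]        # only 8 tokens to run
positions = list(range(computed_len, computed_len + len(remaining_tokens)))

assert len(remaining_tokens) == 8
assert positions == list(range(32, 40))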
- input_positions.append( + input_positions.extend( list(range(computed_len, computed_len + len(prompt_tokens)))) lora_id = seq_group_metadata.lora_int_id @@ -181,11 +188,10 @@ def _prepare_prompt( if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. - slot_mapping.append([_PAD_SLOT_ID] * prompt_len) + slot_mapping.extend([_PAD_SLOT_ID] * prompt_len) continue # Compute the slot mapping. - slot_mapping.append([]) block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, prompt_len - sliding_window). @@ -200,30 +206,26 @@ def _prepare_prompt( start_idx = max(0, prompt_len - self.sliding_window) for i in range(computed_len, prompt_len): if i < start_idx: - slot_mapping[-1].append(_PAD_SLOT_ID) + slot_mapping.append(_PAD_SLOT_ID) continue block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping[-1].append(slot) + slot_mapping.append(slot) max_prompt_len = max(subquery_lens) - input_tokens = _make_tensor_with_pad(input_tokens, - max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - input_positions = _make_tensor_with_pad(input_positions, - max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) + num_prompt_tokens = len(input_tokens) + input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, + pad=0, + dtype=torch.long, + device=self.device) + input_positions = _make_tensor_with_pad_for_alignment( + input_positions, pad=0, dtype=torch.long, device=self.device) + slot_mapping = _make_tensor_with_pad_for_alignment(slot_mapping, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) lora_index_mapping = [ _pad_to_max(mapping, max_prompt_len, pad=0) for mapping in lora_index_mapping @@ -240,22 +242,30 @@ def _prepare_prompt( dtype=torch.int, device=self.device, ) - start_loc_tensor = torch.arange(0, - len(prompt_lens) * max_prompt_len, - max_prompt_len, - dtype=torch.long, - device=self.device) prompt_lens_tensor = torch.tensor(prompt_lens, dtype=torch.long, device=self.device) + # Cumulative index of each prompt. [prompt_lens + 1] + # [0, 0+1th, 0+1th+2nd, ...] 
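The cumulative start_loc described in the comment above is built in the next lines with torch.cumsum into a preallocated tensor; an equivalent stand-alone sketch with made-up lengths:

import torch

prompt_lens = torch.tensor([4, 2, 5], dtype=torch.long)

# One more entry than prompts: entry i is the offset of prompt i in the
# flattened token tensor, and the final entry is the total token count.
start_loc = torch.zeros(prompt_lens.shape[0] + 1, dtype=torch.long)
torch.cumsum(prompt_lens, dim=0, out=start_loc[1:])

assert start_loc.tolist() == [0, 4, 6, 11]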
+ start_loc_tensor = torch.zeros(prompt_lens_tensor.shape[0] + 1, + dtype=torch.long, + device=self.device) + + torch.cumsum(prompt_lens_tensor, + dim=0, + dtype=start_loc_tensor.dtype, + out=start_loc_tensor[1:]) + input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping, prompt_lens=prompt_lens_tensor, + num_prompt_tokens=num_prompt_tokens, + num_generation_tokens=0, max_seq_len=max_prompt_len, start_loc=start_loc_tensor, - max_context_len=None, + max_context_len=max(context_lens), context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -271,9 +281,9 @@ def _prepare_decode( ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] @@ -292,11 +302,11 @@ def _prepare_decode( for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) + input_tokens.append(generation_token) seq_len = seq_data.get_len() position = seq_len - 1 - input_positions.append([position]) + input_positions.append(position) context_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) @@ -306,7 +316,7 @@ def _prepare_decode( block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping.append([slot]) + slot_mapping.append(slot) lora_index_mapping.append([lora_id]) lora_prompt_mapping.append(lora_id) @@ -316,6 +326,9 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + # vLLM uses cuda graph only for decoding requests. + # See `capture_model` API for more details. + # For decoding requests, batch_size == input_tokens. batch_size = len(input_tokens) max_context_len = max(context_lens) use_captured_graph = ( @@ -328,32 +341,32 @@ def _prepare_decode( graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size for _ in range(graph_batch_size - batch_size): - input_tokens.append([]) - input_positions.append([]) - slot_mapping.append([]) + input_tokens.append(0) + input_positions.append(0) + slot_mapping.append(_PAD_SLOT_ID) context_lens.append(1) block_tables.append([]) batch_size = graph_batch_size - input_tokens = _make_tensor_with_pad(input_tokens, - max_len=1, - pad=0, - dtype=torch.long, - device=self.device) - input_positions = _make_tensor_with_pad(input_positions, - max_len=1, - pad=0, - dtype=torch.long, - device=self.device) - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_len=1, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) + input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, + pad=0, + dtype=torch.long, + device=self.device) + input_positions = _make_tensor_with_pad_for_alignment( + input_positions, pad=0, dtype=torch.long, device=self.device) + slot_mapping = _make_tensor_with_pad_for_alignment(slot_mapping, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) + # The first dimension corresponds to number of tokens. 
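Each decode token writes its key/value into one physical cache slot, computed a few lines above as block_number * block_size + block_offset. A small sketch of the mapping and its inverse, consistent with the slot_mapping example given in the InputMetadata docstring (illustrative helper name):

block_size = 32

def slot_for(block_table, position):
    # block_table maps logical block index -> physical block number.
    block_number = block_table[position // block_size]
    block_offset = position % block_size
    return block_number * block_size + block_offset

# With physical block 1 backing the first logical block, token position 3
# lands in slot 35, i.e. block 1, offset 3.
assert slot_for([1], 3) == 35
assert divmod(35, block_size) == (1, 3)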
+ assert context_lens.shape[0] == input_tokens.shape[0] + assert context_lens.shape[0] == input_positions.shape[0] + assert context_lens.shape[0] == slot_mapping.shape[0] + if use_captured_graph: # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -381,6 +394,8 @@ def _prepare_decode( is_prompt=False, slot_mapping=slot_mapping, prompt_lens=None, + num_prompt_tokens=0, + num_generation_tokens=len(input_tokens), max_seq_len=None, start_loc=None, max_context_len=max_context_len, @@ -406,7 +421,6 @@ def _prepare_sample( categorized_sample_indices_start_idx = 0 pin_memory = not self.in_wsl and not self.device_config.is_neuron - max_subquery_len = max(subquery_lens) if subquery_lens else 1 for i, seq_group_metadata in enumerate(seq_group_metadata_list): seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params @@ -431,7 +445,7 @@ def _prepare_sample( selected_token_start_idx + subquery_len - 1)) selected_token_indices.append(selected_token_start_idx + subquery_len - 1) - selected_token_start_idx += max_subquery_len + selected_token_start_idx += subquery_len if sampling_params.seed is not None: seq_group_metadata.state.generator = torch.Generator( @@ -520,6 +534,8 @@ def prepare_input_tensors( "is_prompt": input_metadata.is_prompt, "slot_mapping": input_metadata.slot_mapping, "prompt_lens": input_metadata.prompt_lens, + "num_prompt_tokens": input_metadata.num_prompt_tokens, + "num_generation_tokens": input_metadata.num_generation_tokens, "max_seq_len": input_metadata.max_seq_len, "start_loc": input_metadata.start_loc, "max_context_len": input_metadata.max_context_len, @@ -543,6 +559,8 @@ def prepare_input_tensors( is_prompt=metadata_dict["is_prompt"], slot_mapping=metadata_dict["slot_mapping"], prompt_lens=metadata_dict["prompt_lens"], + num_prompt_tokens=metadata_dict["num_prompt_tokens"], + num_generation_tokens=metadata_dict["num_generation_tokens"], max_seq_len=metadata_dict["max_seq_len"], start_loc=metadata_dict["start_loc"], max_context_len=metadata_dict["max_context_len"], @@ -579,6 +597,9 @@ def execute_model( # Execute the model. if input_metadata.use_cuda_graph: + # NOTE: We use cuda graph only when there are only + # decoding requests, which means the number of batch + # size is equivalent to number of input tokens. graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: @@ -680,6 +701,18 @@ def list_loras(self) -> Set[int]: @torch.inference_mode() def capture_model(self, kv_caches: List[KVCache]) -> None: + """Cuda graph capture a model. + + Note that cuda graph performance gain is negligible if number + of batched tokens are less than 200. And since Cuda graph + requires fixed sized tensors, supporting large/variable batch + size requires high GPU memory overhead. Thus, vLLM only captures + decoding requests. Mixed batch (chunked prefill + decoding) or + prefill requests are not captured. + + this API assumes the shape is N batches of tokens flattened to 1D + tensor, where is token's seqlen is 1. + """ # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never # deleted before the CUDA graphs. self.cupy_nccl_backend = cupy_utils.get_nccl_backend() @@ -698,10 +731,9 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # Prepare dummy inputs. These will be reused for all batch sizes. 
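The capture_model docstring above explains that only fixed-size decode batches are captured. A minimal, self-contained sketch of the capture-and-replay pattern with torch.cuda.CUDAGraph (not the ModelRunner code itself) shows why the dummy buffers prepared below must keep their shape: replay reuses the very same memory, so new inputs are fed by overwriting the captured buffers.

import torch

if torch.cuda.is_available():
    static_input = torch.zeros(8, device="cuda")
    static_output = torch.zeros(8, device="cuda")

    # Warm up once on a side stream so lazily-initialized kernels and
    # handles are created before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_output.copy_(static_input * 2)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output.copy_(static_input * 2)

    # Feed new data by writing into the captured input buffer, then replay.
    static_input.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    assert static_output[3].item() == 6.0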
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() - input_positions = torch.zeros(max_batch_size, 1, - dtype=torch.long).cuda() - slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() + input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() + input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() + slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() @@ -727,6 +759,8 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: is_prompt=False, slot_mapping=slot_mapping[:batch_size], prompt_lens=None, + num_prompt_tokens=0, + num_generation_tokens=0, max_seq_len=None, start_loc=None, max_context_len=self.max_context_len_to_capture, @@ -862,11 +896,32 @@ def _maybe_cupy_nccl(): yield +def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]: + return x + [pad] * ((-len(x)) % multiple_of) + + def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: assert len(x) <= max_len return x + [pad] * (max_len - len(x)) +def _make_tensor_with_pad_for_alignment( + x: List[int], + pad: int, + dtype: torch.dtype, + device: Optional[Union[str, torch.device]], +) -> torch.Tensor: + """Create a tensor of a given list x with padding. + It adds paddings to align with graph batch size. See + _get_graph_batch_size for more details. + # NOTE: This API is only for decoding requests. + """ + batch_size = len(x) + batch_size = _get_graph_batch_size(batch_size) + padded_x = _pad_to_alignment(x, batch_size, pad) + return torch.tensor(padded_x, dtype=dtype, device=device) + + def _make_tensor_with_pad( x: List[List[int]], max_len: int, @@ -879,12 +934,18 @@ def _make_tensor_with_pad( def _get_graph_batch_size(batch_size: int) -> int: + """Returns the padded batch size given actual batch size. + + Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, + 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... + """ if batch_size <= 2: return batch_size elif batch_size <= 4: return 4 else: - return (batch_size + 7) // 8 * 8 + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) def _async_h2d( From 6032edf76fd7c31c3fd5617c13bd7c0b70d227fc Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 6 Mar 2024 08:01:21 -0800 Subject: [PATCH 30/88] ., --- tests/models/test_models.py | 2 +- tests/prompts/example.txt | 9 +-------- vllm/engine/llm_engine.py | 2 +- vllm/model_executor/layers/attention.py | 2 ++ vllm/worker/model_runner.py | 12 +++++++----- vllm/worker/worker.py | 4 ++-- 6 files changed, 14 insertions(+), 17 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index c268b6fd4868..f3f416e63838 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -26,7 +26,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("enforce_eager", [False]) def test_models( hf_runner, vllm_runner, diff --git a/tests/prompts/example.txt b/tests/prompts/example.txt index e1b97bc6eee7..6e8c45b673ee 100644 --- a/tests/prompts/example.txt +++ b/tests/prompts/example.txt @@ -1,8 +1 @@ -vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. 
-Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020. -Compare and contrast artificial intelligence with human intelligence in terms of processing information. -Describe the basic components of a neural network and how it can be trained. -Write a short story about a robot that dreams for the first time. -Analyze the impact of the COVID-19 pandemic on global economic structures and future business models. -Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies. -Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' +vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. \ No newline at end of file diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8484014c9a13..4c25602c6e96 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -798,7 +798,7 @@ def _process_model_outputs( # Log stats. if self.log_stats: self.stat_logger.log(self._get_stats(scheduler_outputs)) - + # breakpoint() return request_outputs def step(self) -> List[RequestOutput]: diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index bfcdecc8876b..db67577fe881 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -157,6 +157,7 @@ def forward( num_prompt_tokens = input_metadata.num_prompt_tokens num_generation_tokens = input_metadata.num_generation_tokens + print(num_generation_tokens) if num_prompt_tokens > 0: assert num_generation_tokens == 0 @@ -252,6 +253,7 @@ def forward( ) if num_generation_tokens > 0: + breakpoint() assert num_prompt_tokens == 0 # Decoding run. output = _paged_attention( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6c207a1abbe1..fa27a62ceee8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -265,7 +265,7 @@ def _prepare_prompt( num_generation_tokens=0, max_seq_len=max_prompt_len, start_loc=start_loc_tensor, - max_context_len=max(context_lens), + max_context_len=None, context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -344,10 +344,11 @@ def _prepare_decode( input_tokens.append(0) input_positions.append(0) slot_mapping.append(_PAD_SLOT_ID) - context_lens.append(1) + context_lens.append(0) block_tables.append([]) batch_size = graph_batch_size + # Q: should we not pad when cuda graph is disabled? input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, pad=0, dtype=torch.long, @@ -610,6 +611,7 @@ def execute_model( kv_caches=kv_caches, input_metadata=input_metadata, ) + breakpoint() # Sample the next token. output = self.model.sample( @@ -717,7 +719,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # deleted before the CUDA graphs. self.cupy_nccl_backend = cupy_utils.get_nccl_backend() - assert not self.model_config.enforce_eager + # assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " "unexpected consequences if the model is not static. 
To " "run the model in eager mode, set 'enforce_eager=True' or " @@ -735,7 +737,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) - context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + context_lens = torch.zeros(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() graph_batch_size = _get_graph_batch_size( @@ -876,7 +878,7 @@ def forward( non_blocking=True) self.input_buffers["block_tables"].copy_(input_metadata.block_tables, non_blocking=True) - + breakpoint() # Run the graph. self.graph.replay() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 157e8c45836b..90ab9f421b7f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -153,8 +153,8 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: self.model_runner.set_block_size(self.cache_engine.block_size) def warm_up_model(self) -> None: - if not self.model_config.enforce_eager: - self.model_runner.capture_model(self.gpu_cache) + # if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.gpu_cache) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. set_random_seed(self.model_config.seed) From c1ab0b0bedf0e25f3d35c998f5216eb33b4275d1 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 6 Mar 2024 08:06:53 -0800 Subject: [PATCH 31/88] done --- tests/models/test_models.py | 2 +- vllm/engine/llm_engine.py | 1 - vllm/model_executor/layers/attention.py | 53 +++++++++---------------- vllm/worker/model_runner.py | 2 - 4 files changed, 20 insertions(+), 38 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index f3f416e63838..c268b6fd4868 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -26,7 +26,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("enforce_eager", [False, True]) def test_models( hf_runner, vllm_runner, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4c25602c6e96..ea10f3e18897 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -798,7 +798,6 @@ def _process_model_outputs( # Log stats. if self.log_stats: self.stat_logger.log(self._get_stats(scheduler_outputs)) - # breakpoint() return request_outputs def step(self) -> List[RequestOutput]: diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index db67577fe881..8c4619a11f1d 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -141,29 +141,17 @@ def forward( # If key_cache and value_cache are not provided, the new key and value # vectors will not be cached. This happens during the initial memory # profiling run. 
- num_valid_tokens = input_metadata.num_valid_tokens - if (num_valid_tokens > 0 and key_cache is not None - and value_cache is not None): - key_to_cache = key[:num_valid_tokens] - value_to_cache = value[:num_valid_tokens] + if (key_cache is not None and value_cache is not None): cache_ops.reshape_and_cache( - key_to_cache, - value_to_cache, + key, + value, key_cache, value_cache, input_metadata.slot_mapping.flatten(), input_metadata.kv_cache_dtype, ) - num_prompt_tokens = input_metadata.num_prompt_tokens - num_generation_tokens = input_metadata.num_generation_tokens - print(num_generation_tokens) - - if num_prompt_tokens > 0: - assert num_generation_tokens == 0 - query = query[:num_prompt_tokens] - key = key[:num_prompt_tokens] - value = value[:num_prompt_tokens] + if input_metadata.is_prompt: # normal attention if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): @@ -202,7 +190,7 @@ def forward( input_metadata) if self.use_ref_attention: - output[:num_prompt_tokens] = self.ref_masked_attention( + output = self.ref_masked_attention( query, key, value, @@ -222,18 +210,17 @@ def forward( key = key.unflatten(0, (num_tokens)) value = value.unflatten(0, (num_tokens)) - output[: - num_prompt_tokens] = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha. - MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ).view_as(query) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + output = out.view_as(query) else: # prefix-enabled attention output = torch.empty_like(query) @@ -252,13 +239,11 @@ def forward( getattr(self, "alibi_slopes", None), ) - if num_generation_tokens > 0: - breakpoint() - assert num_prompt_tokens == 0 + else: # Decoding run. output = _paged_attention( - output[num_prompt_tokens:num_valid_tokens], - query[num_prompt_tokens:num_valid_tokens], + output, + query, key_cache, value_cache, input_metadata, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fa27a62ceee8..2ad4ee534dd5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -611,7 +611,6 @@ def execute_model( kv_caches=kv_caches, input_metadata=input_metadata, ) - breakpoint() # Sample the next token. output = self.model.sample( @@ -878,7 +877,6 @@ def forward( non_blocking=True) self.input_buffers["block_tables"].copy_(input_metadata.block_tables, non_blocking=True) - breakpoint() # Run the graph. self.graph.replay() From f48dc72f5058f661cee56a1f34ccae314f9606d1 Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 7 Mar 2024 01:01:39 -0800 Subject: [PATCH 32/88] Addressed code review. 
--- tests/kernels/test_prefix_prefill.py | 1 - tests/lora/test_worker.py | 2 +- tests/models/test_models.py | 28 ++++++++++++++-------------- tests/prompts/example.txt | 9 ++++++++- tests/worker/test_model_runner.py | 13 +++++++++++-- vllm/worker/model_runner.py | 15 ++++++++------- vllm/worker/worker.py | 4 ++-- 7 files changed, 44 insertions(+), 28 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index c068b38a6691..76dd9bdf3a19 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -113,7 +113,6 @@ def test_contexted_kv_attention( # to V_cache[num_blocks, num_kv_heads, head_size, block_size] v_cache = v_cache.view(-1, block_size, num_kv_heads, head_size).permute(0, 2, 3, 1).contiguous() - # Warm up the Triton kernel by calling it once before actually measuring generation time context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, b_start_loc, b_seq_len, b_ctx_len, max_input_len) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 31a7c716afbf..e4538de35169 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -25,7 +25,7 @@ def test_worker_apply_lora(sql_lora_files): revision=None, ), parallel_config=ParallelConfig(1, 1, False), - scheduler_config=SchedulerConfig(32, 32, 32, 256), + scheduler_config=SchedulerConfig(32, 32, 32), device_config=DeviceConfig("cuda"), local_rank=0, rank=0, diff --git a/tests/models/test_models.py b/tests/models/test_models.py index c268b6fd4868..7bb93a519bf0 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,20 +6,20 @@ MODELS = [ "facebook/opt-125m", - # "meta-llama/Llama-2-7b-hf", - # "mistralai/Mistral-7B-v0.1", - # "Deci/DeciLM-7b", - # "tiiuae/falcon-7b", - # "gpt2", - # "bigcode/tiny_starcoder_py", - # "EleutherAI/gpt-j-6b", - # "EleutherAI/pythia-70m", - # "bigscience/bloom-560m", - # "mosaicml/mpt-7b", - # "microsoft/phi-2", - # "stabilityai/stablelm-3b-4e1t", - # "allenai/OLMo-1B", - # "bigcode/starcoder2-3b", + "meta-llama/Llama-2-7b-hf", + "mistralai/Mistral-7B-v0.1", + "Deci/DeciLM-7b", + "tiiuae/falcon-7b", + "gpt2", + "bigcode/tiny_starcoder_py", + "EleutherAI/gpt-j-6b", + "EleutherAI/pythia-70m", + "bigscience/bloom-560m", + "mosaicml/mpt-7b", + "microsoft/phi-2", + "stabilityai/stablelm-3b-4e1t", + "allenai/OLMo-1B", + "bigcode/starcoder2-3b", ] diff --git a/tests/prompts/example.txt b/tests/prompts/example.txt index 6e8c45b673ee..cef4d1d76873 100644 --- a/tests/prompts/example.txt +++ b/tests/prompts/example.txt @@ -1 +1,8 @@ -vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. \ No newline at end of file +vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. +Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020. +Compare and contrast artificial intelligence with human intelligence in terms of processing information. +Describe the basic components of a neural network and how it can be trained. +Write a short story about a robot that dreams for the first time. +Analyze the impact of the COVID-19 pandemic on global economic structures and future business models. +Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies. +Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' 
\ No newline at end of file diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 7fed21bc2aa9..6ae0c5eddd04 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -61,10 +61,17 @@ def test_prepare_prompt(): assert torch.allclose( input_metadata.start_loc, torch.tensor(start_loc, dtype=torch.long, device=device)) - assert input_metadata.max_context_len == max(prompt_lens) + assert input_metadata.max_context_len is None + # TODO(sang): The current definition of context_lens is the + # number of k/v that are already cached (before this run). + # It is inconsistent with decoding. assert torch.allclose( input_metadata.context_lens, - torch.tensor(prompt_lens, dtype=torch.int, device=device)) + torch.zeros(input_metadata.context_lens.shape[0], + dtype=torch.int, + device=device)) + + # SANG-TODO # assert input_metadata.slot_mapping == max(prompt_lens) # block_tables # Cuda graph should not be used for prerill. @@ -154,6 +161,8 @@ def round_up_to_next_multiple_of_batch_size(n): assert torch.allclose( input_metadata.context_lens[:len(prompt_lens)], torch.tensor(prompt_lens, dtype=torch.int, device=device)) + + # SANG-TODO # assert input_metadata.slot_mapping == max(prompt_lens) # block_tables # Cuda graph should not be used for prerill. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2ad4ee534dd5..04713914497f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -109,8 +109,7 @@ def load_model(self) -> None: ), "Model does not have embedding_padding_modules" self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens + - self.scheduler_config.max_paddings, vocab_size, + self.scheduler_config.max_num_batched_tokens, vocab_size, self.lora_config, self.device, self.model.embedding_modules, self.model.embedding_padding_modules) self.model = self.lora_manager.create_lora_manager(self.model) @@ -163,7 +162,7 @@ def _prepare_prompt( context_len = computed_len else: prefix_block_tables.append([]) - context_len = prompt_len + context_len = 0 # actual prompt lens context_lens.append(context_len) subquery_lens.append(prompt_len - computed_len) @@ -363,10 +362,12 @@ def _prepare_decode( dtype=torch.int, device=self.device) - # The first dimension corresponds to number of tokens. - assert context_lens.shape[0] == input_tokens.shape[0] - assert context_lens.shape[0] == input_positions.shape[0] - assert context_lens.shape[0] == slot_mapping.shape[0] + if use_captured_graph: + # When using cuda-graph all these tensors should be + # padded. + assert context_lens.shape[0] == input_tokens.shape[0] + assert context_lens.shape[0] == input_positions.shape[0] + assert context_lens.shape[0] == slot_mapping.shape[0] if use_captured_graph: # The shape of graph_block_tables is diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 90ab9f421b7f..157e8c45836b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -153,8 +153,8 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: self.model_runner.set_block_size(self.cache_engine.block_size) def warm_up_model(self) -> None: - # if not self.model_config.enforce_eager: - self.model_runner.capture_model(self.gpu_cache) + if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.gpu_cache) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
set_random_seed(self.model_config.seed) From 769b2b491939f9e461486fecd2ac97e80e15eb0c Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 7 Mar 2024 01:14:49 -0800 Subject: [PATCH 33/88] working --- vllm/model_executor/input_metadata.py | 9 +++++++++ vllm/model_executor/layers/attention.py | 9 +++++---- vllm/worker/model_runner.py | 14 ++++++-------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index e7d45a6cbf3f..be4f71445276 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -45,6 +45,13 @@ def __init__( self.start_loc = start_loc self.max_context_len = max_context_len self.slot_mapping = slot_mapping + # Index: The batched sequence's index. + # Value: The length of attention context. + # NOTE(sang): When it is prefill/decoding, + # the definition is different. For prefill, + # it means the the length of KV that are cached + # excluding the new KVs. In decoding, this + # includes a new KV. self.context_lens = context_lens self.block_tables = block_tables self.use_cuda_graph = use_cuda_graph @@ -53,6 +60,8 @@ def __init__( # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. self.attn_bias = None + # Number of valid tokens. It includes paddings. + # See attention.py for precise definition. self.num_valid_tokens = slot_mapping.shape[0] def __repr__(self) -> str: diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 8c4619a11f1d..a6fea54e9815 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -242,7 +242,6 @@ def forward( else: # Decoding run. output = _paged_attention( - output, query, key_cache, value_cache, @@ -289,15 +288,17 @@ def _make_alibi_bias( def _paged_attention( - output: torch.Tensor, # [num_tokens, num_heads, head_size] query: torch.Tensor, # [num_tokens, num_heads, head_size] - key_cache: torch.Tensor, - value_cache: torch.Tensor, + key_cache: torch. + Tensor, # [num_total_blocks, block_size, num_heads, head_size] + value_cache: torch. + Tensor, # [num_total_blocks, block_size, num_heads, head_size] input_metadata: InputMetadata, num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], ) -> torch.Tensor: + output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = ( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 04713914497f..99a5602ca967 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -31,8 +31,6 @@ _BATCH_SIZE_ALIGNMENT = 8 # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. -# Note that cuda graph is only used for decoding because it speeds up -# the performance when num_tokens < 200. Batch here means a single token. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) ] @@ -215,6 +213,9 @@ def _prepare_prompt( max_prompt_len = max(subquery_lens) num_prompt_tokens = len(input_tokens) + + # Pad tokens to better utilize tensor cores although + # cuda graph is not enabled. input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, pad=0, dtype=torch.long, @@ -347,7 +348,8 @@ def _prepare_decode( block_tables.append([]) batch_size = graph_batch_size - # Q: should we not pad when cuda graph is disabled? 
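The NOTE added to InputMetadata in this patch points out that context_lens means two slightly different things: for prefill it counts only the KV entries already cached before this run (excluding the tokens being processed), while for decode it includes the token being generated. A worked example with made-up numbers:

# Prefill without a cached prefix: nothing is cached yet.
prompt_len, computed_len = 6, 0
prefill_context_len = computed_len                    # 0

# Prefill on top of a 32-token cached prefix: only the prefix counts.
prompt_len, computed_len = 40, 32
prefix_prefill_context_len = computed_len             # 32

# Decode: the sequence is 11 tokens long including the token fed in this
# step, and the convention here counts that token's new KV as well.
seq_len = 11
decode_context_len = seq_len                          # 11

assert (prefill_context_len, prefix_prefill_context_len,
        decode_context_len) == (0, 32, 11)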
+ # Pad tokens to better utilize tensor cores although + # cuda graph is not enabled. input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, pad=0, dtype=torch.long, @@ -599,9 +601,6 @@ def execute_model( # Execute the model. if input_metadata.use_cuda_graph: - # NOTE: We use cuda graph only when there are only - # decoding requests, which means the number of batch - # size is equivalent to number of input tokens. graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: @@ -719,7 +718,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # deleted before the CUDA graphs. self.cupy_nccl_backend = cupy_utils.get_nccl_backend() - # assert not self.model_config.enforce_eager + assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " "unexpected consequences if the model is not static. To " "run the model in eager mode, set 'enforce_eager=True' or " @@ -915,7 +914,6 @@ def _make_tensor_with_pad_for_alignment( """Create a tensor of a given list x with padding. It adds paddings to align with graph batch size. See _get_graph_batch_size for more details. - # NOTE: This API is only for decoding requests. """ batch_size = len(x) batch_size = _get_graph_batch_size(batch_size) From f7347b8c704a587040873496c5398d0683b9decd Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 7 Mar 2024 06:42:08 -0800 Subject: [PATCH 34/88] working --- vllm/model_executor/input_metadata.py | 5 +++ .../layers/attention/attention.py | 7 ++-- .../layers/attention/backends/flash_attn.py | 42 +++++++++---------- .../layers/attention/backends/xformers.py | 1 + 4 files changed, 29 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index be4f71445276..d992f6b7d207 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -6,6 +6,11 @@ class InputMetadata: """Metadata for input sequences. Used in PagedAttention. + NOTE: Any python object stored here is not updated when it is + cuda-graph replays. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + Args: prompt_lens: Lengths of prompts. slot_mapping: The index of each token mapped into a physical block diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 4d288dbb81e7..58f1bbd99aad 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -5,7 +5,7 @@ import torch.nn as nn from vllm.model_executor.input_metadata import InputMetadata -from vllm.utils import is_hip +# from vllm.utils import is_hip class Attention(nn.Module): @@ -46,8 +46,9 @@ def __init__( sliding_window: Optional[int] = None, ) -> None: super().__init__() - if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and - torch.get_default_dtype() in (torch.float16, torch.bfloat16)): + # if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and + # torch.get_default_dtype() in (torch.float16, torch.bfloat16)): + if False: # Ampere or later NVIDIA GPUs. # NOTE(woosuk): FlashAttention does not support FP32. 
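The backend choice above is temporarily short-circuited with `if False` while flash-attn is being debugged in this series; the commented-out condition is the intended one. A sketch of that predicate in isolation (the real code takes is_hip() from vllm.utils; here it is passed in):

import torch

def should_use_flash_attn(is_hip: bool) -> bool:
    """FlashAttention only on non-ROCm, Ampere-or-newer GPUs running a
    16-bit default dtype; otherwise fall back to the xformers backend."""
    if is_hip or not torch.cuda.is_available():
        return False
    if torch.cuda.get_device_capability()[0] < 8:
        return False
    return torch.get_default_dtype() in (torch.float16, torch.bfloat16)

The import below then resolves to FlashAttentionBackend when the predicate holds and to the xformers backend otherwise.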
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index 512f4e49c7eb..20a3a072bb3b 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -53,18 +53,18 @@ def forward( """Forward pass with FlashAttention and PagedAttention. Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] input_metadata: metadata for the inputs. Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ - batch_size, seq_len, hidden_size = query.shape + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -74,27 +74,25 @@ def forward( # If key_cache and value_cache are not provided, the new key and value # vectors will not be cached. This happens during the initial memory # profiling run. - if key_cache is not None and value_cache is not None: - PagedAttentionImpl.reshape_and_cache(key, value, key_cache, - value_cache, input_metadata) + PagedAttentionImpl.reshape_and_cache(key, value, key_cache, + value_cache, input_metadata) if input_metadata.is_prompt: # Prompt run. if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # normal attention - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) - output = flash_attn_func( - query, - key, - value, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) + output = torch.empty_like(query) + query = query.unflatten(0, (num_tokens, )) + key = key.unflatten(0, (num_tokens, )) + value = value.unflatten(0, (num_tokens, )) + output = flash_attn_func(query, + key, + value, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes) else: # prefix-enabled attention output = PagedAttentionImpl.forward_prefix( @@ -104,8 +102,6 @@ def forward( key_cache, value_cache, input_metadata, - self.num_heads, - self.num_kv_heads, self.alibi_slopes, ) else: @@ -121,4 +117,4 @@ def forward( ) # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) + return output.view(num_tokens, hidden_size) diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention/backends/xformers.py index 5206959de333..269b73303fb0 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -83,6 +83,7 @@ def forward( if input_metadata.is_prompt: # Prompt run. + # Unless there's a prefix, context lens is all 0 for prefill. 
if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # normal attention From f91d73e40e4e60143eeddc904ef080d88aea4b07 Mon Sep 17 00:00:00 2001 From: sang Date: Thu, 7 Mar 2024 23:04:29 -0800 Subject: [PATCH 35/88] fix lora --- vllm/worker/model_runner.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d6be22ec89fb..060828346c3f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -182,7 +182,7 @@ def _prepare_prompt( if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping.append([lora_id] * (prompt_len - computed_len)) + lora_index_mapping += [lora_id] * (prompt_len - computed_len) lora_prompt_mapping.extend( [lora_id] * (prompt_len - computed_len @@ -232,10 +232,11 @@ def _prepare_prompt( pad=_PAD_SLOT_ID, dtype=torch.long, device=self.device) - lora_index_mapping = [ - _pad_to_max(mapping, max_prompt_len, pad=0) - for mapping in lora_index_mapping - ] + lora_index_mapping = _pad_to_alignment(lora_index_mapping, + _get_graph_batch_size( + len(lora_index_mapping)), + pad=0) + context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -323,7 +324,7 @@ def _prepare_decode( block_offset = position % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - lora_index_mapping.append([lora_id]) + lora_index_mapping.append(lora_id) lora_prompt_mapping.append(lora_id) if self.sliding_window is not None: @@ -396,9 +397,10 @@ def _prepare_decode( device=self.device, ) - lora_index_mapping = [ - _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping - ] + lora_index_mapping = _pad_to_alignment(lora_index_mapping, + _get_graph_batch_size( + len(lora_index_mapping)), + pad=0) input_metadata = InputMetadata( is_prompt=False, @@ -527,11 +529,8 @@ def prepare_input_tensors( subquery_lens) if self.lora_config: - flat_lora_index_mapping = [ - item for sublist in lora_index_mapping for item in sublist - ] lora_mapping = LoRAMapping( - flat_lora_index_mapping, + lora_index_mapping, lora_prompt_mapping, ) else: From f7d79dad9b61de77f0d1df5f20f86752adbb7f58 Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 8 Mar 2024 00:06:10 -0800 Subject: [PATCH 36/88] fixed --- .../spec_decode/test_multi_step_worker.py | 6 +++--- tests/worker/test_model_runner.py | 18 ++++++++++-------- vllm/worker/model_runner.py | 9 ++++++--- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/tests/worker/spec_decode/test_multi_step_worker.py b/tests/worker/spec_decode/test_multi_step_worker.py index ea5480290357..15a0546813eb 100644 --- a/tests/worker/spec_decode/test_multi_step_worker.py +++ b/tests/worker/spec_decode/test_multi_step_worker.py @@ -90,8 +90,8 @@ def test_same_output_for_single_step(): num_gpu_blocks, seed, ) - multi_step_worker.model_runner = worker.model_runner - multi_step_worker.cache_engine = worker.cache_engine + # multi_step_worker.model_runner = worker.model_runner + # multi_step_worker.cache_engine = worker.cache_engine num_steps = 1 @@ -258,4 +258,4 @@ def test_same_output_for_multi_step(): for multi_step_logprobs, single_step_logprobs in zip( multi_step_output_logprobs, single_step_output_logprobs): assert_logprobs_dict_allclose(multi_step_logprobs, - single_step_logprobs) + single_step_logprobs) \ No newline at end of file diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 
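The LoRA fix above replaces the per-request list-of-lists with one flat index list that is padded up to the CUDA-graph batch size. A rough sketch of what that padding does; `_get_graph_batch_size` and `_pad_to_alignment` below are illustrative stand-ins, not the actual helpers:

from typing import List

_CAPTURED_BATCH_SIZES = [1, 2, 4, 8, 16, 32]       # illustrative values only

def _get_graph_batch_size(batch_size: int) -> int:
    # Stand-in: round up to the nearest batch size with a captured graph.
    for size in _CAPTURED_BATCH_SIZES:
        if batch_size <= size:
            return size
    return batch_size

def _pad_to_alignment(x: List[int], target_len: int, pad: int) -> List[int]:
    # Pad a flat LoRA index mapping out to the padded (graph) batch size.
    return x + [pad] * (target_len - len(x))

lora_index_mapping = [3, 3, 0, 1, 1]               # one LoRA id per token
padded = _pad_to_alignment(lora_index_mapping,
                           _get_graph_batch_size(len(lora_index_mapping)),
                           pad=0)
assert padded == [3, 3, 0, 1, 1, 0, 0, 0]          # padded to the next size, 8
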
6ae0c5eddd04..514eedd9a415 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -19,6 +19,7 @@ def test_prepare_prompt(): batch_size = random.randint(1, 256) prompt_lens = [] seq_group_metadata_list = [] + block_tables = {0: [1]} for i in range(batch_size): # make sure all tokens fit into one block prompt_len = i % (model_runner.block_size - 1) + 1 @@ -30,7 +31,7 @@ def test_prepare_prompt(): is_prompt=True, seq_data={0: SequenceData(seq_data)}, sampling_params=SamplingParams(temperature=0), - block_tables={0: [1]}, + block_tables=block_tables, )) expected_selected_token_indices = [] @@ -70,10 +71,7 @@ def test_prepare_prompt(): torch.zeros(input_metadata.context_lens.shape[0], dtype=torch.int, device=device)) - - # SANG-TODO - # assert input_metadata.slot_mapping == max(prompt_lens) - # block_tables + assert input_metadata.block_tables is None # Cuda graph should not be used for prerill. assert input_metadata.use_cuda_graph is False assert input_metadata.kv_cache_dtype == "auto" @@ -162,9 +160,13 @@ def round_up_to_next_multiple_of_batch_size(n): input_metadata.context_lens[:len(prompt_lens)], torch.tensor(prompt_lens, dtype=torch.int, device=device)) - # SANG-TODO - # assert input_metadata.slot_mapping == max(prompt_lens) - # block_tables + # block table's first index corresponds to each batch, meaning in + # decoding it is each token. + assert input_metadata.block_tables.shape[0] == len(input_tokens) + # Block table's second dim correspondsd to each token's block number. + # It is padded up to + assert input_metadata.block_tables.shape[1] == ( + model_runner.get_max_block_per_batch()) # Cuda graph should not be used for prerill. assert input_metadata.use_cuda_graph is True assert input_metadata.kv_cache_dtype == "auto" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 060828346c3f..2edb21c4900f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -121,10 +121,13 @@ def load_model(self) -> None: def set_block_size(self, block_size: int) -> None: self.block_size = block_size - max_num_blocks = (self.max_context_len_to_capture + block_size - - 1) // block_size self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + dtype=np.int32) + + def get_max_block_per_batch(self): + block_size = self.block_size + return (self.max_context_len_to_capture + block_size - 1) // block_size def _prepare_prompt( self, From 406f1d40cde7069233c2cde0ac8cfffc50b2c81e Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 8 Mar 2024 00:38:30 -0800 Subject: [PATCH 37/88] fix --- tests/worker/test_model_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 514eedd9a415..55a078230b46 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -71,7 +71,11 @@ def test_prepare_prompt(): torch.zeros(input_metadata.context_lens.shape[0], dtype=torch.int, device=device)) - assert input_metadata.block_tables is None + + expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.allclose(input_metadata.block_tables, expected) # Cuda graph should not be used for prerill. 
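`get_max_block_per_batch` introduced above is a ceiling division, and it fixes the second dimension of both `graph_block_tables` and the padded `block_tables` the updated test asserts on. For example (the numbers are assumed defaults, not read from the config):

def get_max_block_per_batch(max_context_len_to_capture: int, block_size: int) -> int:
    # Largest number of KV-cache blocks any captured sequence can reference.
    return (max_context_len_to_capture + block_size - 1) // block_size

# With an 8192-token capture limit and 16-token blocks, every row of the
# block table reserves 512 block slots.
assert get_max_block_per_batch(8192, 16) == 512
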
assert input_metadata.use_cuda_graph is False assert input_metadata.kv_cache_dtype == "auto" From c067a4cc4633ab053d6ae6b333c38250bd186700 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 11 Mar 2024 04:46:41 -0700 Subject: [PATCH 38/88] working. --- tests/models/test_models.py | 6 +++--- vllm/config.py | 1 + vllm/model_executor/input_metadata.py | 2 +- vllm/worker/model_runner.py | 4 +++- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index c268b6fd4868..b52c27d951b7 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -5,12 +5,12 @@ import pytest MODELS = [ - "facebook/opt-125m", + # "facebook/opt-125m", # "meta-llama/Llama-2-7b-hf", # "mistralai/Mistral-7B-v0.1", # "Deci/DeciLM-7b", # "tiiuae/falcon-7b", - # "gpt2", + "gpt2", # "bigcode/tiny_starcoder_py", # "EleutherAI/gpt-j-6b", # "EleutherAI/pythia-70m", @@ -26,7 +26,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("enforce_eager", [False]) def test_models( hf_runner, vllm_runner, diff --git a/vllm/config.py b/vllm/config.py index 27d89bba3775..c0de0387180b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -80,6 +80,7 @@ def __init__( quantization: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + max_logprobs: int = 5, flash_style: bool = False, ) -> None: self.model = model diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index bc6a028f2c21..c8216e95e848 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -48,7 +48,7 @@ def __init__( block_tables: Optional[torch.Tensor], use_cuda_graph: bool, kv_cache_dtype: str, - self.flash_style: bool, + flash_style: bool, ) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2659b3120037..7303cfeb42f6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -197,7 +197,7 @@ def _prepare_prompt( # NOTE(sang): prefill_end is always # of prompts if chunked # prefill is not enabled. Prefix caching is not working with # chunked prefill now. - input_positions.append( + input_positions.extend( list(range(computed_len, computed_len + prefill_end))) lora_id = seq_group_metadata.lora_int_id @@ -309,6 +309,7 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, + flash_style=self.flash_style, ) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, @@ -448,6 +449,7 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, + flash_style=self.flash_style, ) return (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) From e1f244a7ce2a0391425dbc3d1d3f0b4ea5c03137 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 11 Mar 2024 05:08:53 -0700 Subject: [PATCH 39/88] clean up. 
--- benchmarks/benchmark_latency.py | 4 - .../kernels/benchmark_paged_attention.py | 31 +- csrc/cache.h | 7 - csrc/cache_kernels.cu | 86 --- csrc/pybind.cpp | 4 - requirements.txt | 1 - tests/conftest.py | 4 +- tests/core/utils.py | 2 +- tests/kernels/test_cache.py | 71 +- tests/kernels/test_flash_attention.py | 633 ------------------ vllm/config.py | 22 - vllm/engine/arg_utils.py | 11 +- vllm/model_executor/input_metadata.py | 48 -- .../layers/attention/ops/prefix_prefill.py | 4 +- vllm/model_executor/model_loader.py | 2 +- vllm/model_executor/models/__init__.py | 8 +- vllm/utils.py | 11 +- vllm/worker/cache_engine.py | 36 +- vllm/worker/model_runner.py | 17 +- vllm/worker/worker.py | 9 - 20 files changed, 40 insertions(+), 971 deletions(-) delete mode 100644 tests/kernels/test_flash_attention.py diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 9a07a583a6ad..e9e2a883db83 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -34,7 +34,6 @@ def main(args: argparse.Namespace): kv_cache_dtype=args.kv_cache_dtype, device=args.device, block_size=args.block_size, - flash_style=args.flash_style, max_chunked_prefill_len=args.max_chunked_prefill_len, max_num_prompt_seqs=args.max_num_prompt_seqs, ray_workers_use_nsight=args.ray_workers_use_nsight, @@ -172,9 +171,6 @@ def run_to_completion(profile_dir: Optional[str] = None): default="cuda", choices=["cuda"], help='device type for vLLM execution, supporting CUDA only currently.') - parser.add_argument('--flash-style', - action='store_true', - help='enable flash attention') parser.add_argument('--block-size', type=int, default=16, diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 7b3c6405e259..a39a0fb0b2bb 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -65,17 +65,14 @@ def main( block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) # Create the KV cache. - flash_style = version == "flash" - key_caches, value_caches = create_kv_caches_with_random( - NUM_BLOCKS, - block_size, - 1, - num_kv_heads, - head_size, - kv_cache_dtype, - dtype, - device=device, - flash_style=flash_style) + key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device) key_cache, value_cache = key_caches[0], value_caches[0] # Prepare for the paged attention kernel. 
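Since the "flash" benchmark path is dropped here, it may help to see the two KV-cache layouts side by side. The sizes below are made up; only the shapes follow the code in this series:

import torch

num_blocks, block_size, num_kv_heads, head_size = 128, 16, 8, 128
dtype = torch.half
x = 16 // torch.tensor([], dtype=dtype).element_size()   # elements per 16-byte vector

# Layout used by the remaining paged-attention kernels.
key_cache = torch.empty(num_blocks, num_kv_heads, head_size // x, block_size, x,
                        dtype=dtype)
value_cache = torch.empty(num_blocks, num_kv_heads, head_size, block_size,
                          dtype=dtype)

# Flash-style layout used by the code paths this patch removes.
flash_key_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size,
                              dtype=dtype)

# Same storage size, different indexing.
assert key_cache.numel() == value_cache.numel() == flash_key_cache.numel()
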
@@ -135,16 +132,6 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: alibi_slopes, kv_cache_dtype, ) - elif version == "flash": - flash_attn_with_kvcache_paged( - query.view(num_seqs, 1, num_query_heads, head_size), - key_cache, - value_cache, - scale, - block_tables, - context_lens, - alibi_slopes=alibi_slopes, - ) else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() @@ -172,7 +159,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: description="Benchmark the paged attention kernel.") parser.add_argument("--version", type=str, - choices=["v1", "v2", "flash"], + choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--context-len", type=int, default=4096) diff --git a/csrc/cache.h b/csrc/cache.h index 1bca7e4e39a9..765e231abd26 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -23,13 +23,6 @@ void reshape_and_cache( torch::Tensor& slot_mapping, const std::string& kv_cache_dtype); -void reshape_and_cache_flash( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping); - // Just for unittest void convert_fp8_e5m2( torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 06da455117b0..76709202b469 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -270,92 +270,6 @@ void reshape_and_cache( namespace vllm { -// flash-attention style cache funciton where key/value caches -// has the same shape of [num_blocks, block_size, num_heads, head_size] -template -__global__ void reshape_and_cache_flash_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, head_size] - scalar_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, head_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size) { - const int64_t token_idx = blockIdx.x; - const int64_t slot_idx = slot_mapping[token_idx]; - if (slot_idx < 0) { - // Padding token that should be ignored. - return; - } - const int64_t block_idx = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - - const int n = num_heads * head_size; - for (int i = threadIdx.x; i < n; i += blockDim.x) { - const int64_t src_key_idx = token_idx * key_stride + i; - const int64_t src_value_idx = token_idx * value_stride + i; - - const int head_idx = i / head_size; - const int head_offset = i % head_size; - - // Target idx for kv are the same. - const int64_t tgt_idx = block_idx * block_size * num_heads * head_size - + block_offset * num_heads * head_size - + head_idx * head_size - + head_offset; - - scalar_t tgt_key = key[src_key_idx]; - scalar_t tgt_value = value[src_value_idx]; - // TODO(sang): Support ENABLE_FP8_E5M2. 
- key_cache[tgt_idx] = tgt_key; - value_cache[tgt_idx] = tgt_value; - } -} - -} // namespace vllm - -// TODO(sang): Support kv_cache_dtype -void reshape_and_cache_flash( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping) -{ - int num_tokens_padded = key.size(0); - int num_heads = key.size(1); - int head_size = key.size(2); - int block_size = key_cache.size(1); - - int key_stride = key.stride(0); - int value_stride = value.stride(0); - - dim3 grid(num_tokens_padded); - dim3 block(std::min(num_heads * head_size, 512)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), - "reshape_and_cache_flash_kernel", - [&] { - vllm::reshape_and_cache_flash_kernel<<>>( - key.data_ptr(), - value.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - slot_mapping.data_ptr(), - key_stride, - value_stride, - num_heads, - head_size, - block_size); - }); -} - -namespace vllm { - template __global__ void convert_fp8_e5m2_kernel( const Tin* __restrict__ src_cache, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index a9ff2e8f5830..4b6ade756639 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -81,10 +81,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "reshape_and_cache", &reshape_and_cache, "Reshape the key and value tensors and cache them"); - cache_ops.def( - "reshape_and_cache_flash", - &reshape_and_cache_flash, - "Reshape the flash-style key and value tensors and cache them."); cache_ops.def( "convert_fp8_e5m2", &convert_fp8_e5m2, diff --git a/requirements.txt b/requirements.txt index 7aff2658f9d9..05ec2e804e13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,3 @@ pynvml == 11.5.0 triton >= 2.1.0 outlines >= 0.0.27 cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. -flash-attn >= 2.5.0 # Required for chunked prefill. 
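The deleted `reshape_and_cache_flash` kernel above amounts to a scatter by slot index; a pure-PyTorch reference, equivalent to the loop in the deleted unit test, looks like this (sizes are made up):

import torch

def reshape_and_cache_flash_ref(key, value, key_cache, value_cache, slot_mapping):
    # key/value: [num_tokens, num_heads, head_size]
    # key_cache/value_cache: [num_blocks, block_size, num_heads, head_size]
    block_size = key_cache.shape[1]
    for i, slot in enumerate(slot_mapping.tolist()):
        if slot < 0:                      # padding token, skipped like the kernel
            continue
        block_idx, block_offset = slot // block_size, slot % block_size
        key_cache[block_idx, block_offset] = key[i]
        value_cache[block_idx, block_offset] = value[i]

num_tokens, num_heads, head_size, block_size, num_blocks = 4, 2, 8, 16, 3
key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
value_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
slot_mapping = torch.tensor([0, 1, 17, -1], dtype=torch.long)
reshape_and_cache_flash_ref(key, value, key_cache, value_cache, slot_mapping)
assert torch.equal(key_cache[1, 1], key[2])   # slot 17 -> block 1, offset 1
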
diff --git a/tests/conftest.py b/tests/conftest.py index 42c80b00d89a..463053812168 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -165,8 +165,7 @@ def __init__( dtype: str = "half", disable_log_stats: bool = True, tensor_parallel_size: int = 1, - block_size: int = 32, - flash_style: bool = False, + block_size: int = 16, max_chunked_prefill_len: int = -1, max_num_prompt_seqs: int = 1000, max_num_batched_tokens: int = 4096, @@ -180,7 +179,6 @@ def __init__( swap_space=0, disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, - flash_style=flash_style, block_size=block_size, max_chunked_prefill_len=max_chunked_prefill_len, max_num_prompt_seqs=max_num_prompt_seqs, diff --git a/tests/core/utils.py b/tests/core/utils.py index db79e24535b1..4074fc0b24d1 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -17,7 +17,7 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) - seq_group = SequenceGroup(request_id, [prompt], SamplingParams() + seq_group = SequenceGroup(request_id, [prompt], SamplingParams(), time.time(), None) return prompt, seq_group diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index d2de4105b7f3..9716094a3e81 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -224,73 +224,4 @@ def test_swap_blocks( dist_key_caches[0][dst].cpu()) assert torch.allclose(src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()) - - -@pytest.mark.parametrize("num_tokens", NUM_TOKENS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@torch.inference_mode() -def test_reshape_and_cache_flash( - kv_cache_factory, - num_tokens: int, - num_heads: int, - head_size: int, - block_size: int, - num_blocks: int, - dtype: torch.dtype, - seed: int, -) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - # Create a random slot mapping. - num_slots = block_size * num_blocks - slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') - - qkv = torch.randn(num_tokens, - 3, - num_heads, - head_size, - dtype=dtype, - device='cuda') - _, key, value = qkv.unbind(dim=1) - - # Create the KV caches. - key_caches, value_caches = kv_cache_factory(num_blocks, - block_size, - 1, - num_heads, - head_size, - dtype, - seed, - flash_style=True) - assert len(key_caches) == 1 and len(value_caches) == 1 - key_cache, value_cache = key_caches[0], value_caches[0] - - # Clone the KV caches. - cloned_key_cache = key_cache.clone() - cloned_value_cache = value_cache.clone() - - # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping) - - # Run the reference implementation. 
- block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') - block_indicies = block_indicies.cpu().tolist() - block_offsets = slot_mapping % block_size - block_offsets = block_offsets.cpu().tolist() - for i in range(num_tokens): - block_idx = block_indicies[i] - block_offset = block_offsets[i] - cloned_key_cache[block_idx, block_offset, :, :] = key[i] - cloned_value_cache[block_idx, block_offset, :, :] = value[i] - - assert torch.allclose(key_cache, cloned_key_cache) - assert torch.allclose(value_cache, cloned_value_cache) + \ No newline at end of file diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py deleted file mode 100644 index 66a668cb7dd5..000000000000 --- a/tests/kernels/test_flash_attention.py +++ /dev/null @@ -1,633 +0,0 @@ -import random -from typing import List, Optional, Tuple - -import pytest -import torch -import torch.nn.functional as F - -from vllm.model_executor.layers.attention import ( - flash_single_query_cached_kv_attention, - flash_multi_query_cached_kv_attention_varlen, -) -from vllm.utils import get_max_shared_memory_bytes - -FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 -# This will change depending on the compute capability. -# - 512 as a buffer -MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -NUM_BLOCKS = 128 # Arbitrary values for testing -PARTITION_SIZE = 512 - -DTYPES = [torch.half, torch.bfloat16] -NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing -NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing -NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing -NUM_HEADS_SMALL = NUM_HEADS -HEAD_SIZES = [32, 64, 128, 256] -BLOCK_SIZES = [32, 64, 256] -USE_ALIBI = [False, True] -SEEDS = [0] -PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] - - -def pad_attention_inputs( - pad_config: Tuple[int, int], - block_size: int, - query: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - max_context_len: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: - """Pad the attention inputs to the specified batch size and context length. 
- """ - pad_batch_size, pad_max_context_len = pad_config - if pad_batch_size == 0: - return query, block_tables, context_lens, max_context_len - target_batch_size = ( - (query.shape[0] - 1) % pad_batch_size + 1) * pad_batch_size - target_block_size = pad_max_context_len // block_size + 1 - padded_query = F.pad(query, - (0, 0, 0, 0, 0, target_batch_size - query.shape[0])) - padded_block_table = F.pad(block_tables, - (0, target_block_size - block_tables.shape[1], - 0, target_batch_size - block_tables.shape[0])) - padded_context_lens = F.pad(context_lens, - (0, target_batch_size - context_lens.shape[0])) - return padded_query, padded_block_table, padded_context_lens, pad_max_context_len - - -def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - attn_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() - if attn_mask is not None: - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out - - -def ref_single_query_cached_kv_attention( - output: torch.Tensor, - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - scale: float, - alibi_slopes: Optional[torch.Tensor], - flash_style: bool = False, -) -> None: - num_query_heads = query.shape[1] - num_kv_heads = value_cache.shape[-2] - head_size = value_cache.shape[-1] - block_size = value_cache.shape[-3] - num_seqs = query.shape[0] - - block_tables = block_tables.cpu().tolist() - context_lens = context_lens.cpu().tolist() - for i in range(num_seqs): - q = query[i].unsqueeze(0) - block_table = block_tables[i] - context_len = int(context_lens[i]) - - keys = [] - values = [] - for j in range(context_len): - block_number = int(block_table[j // block_size]) - block_offset = j % block_size - - if flash_style: - k = key_cache[block_number, block_offset, :, :] - else: - k = key_cache[block_number, :, :, block_offset, :] - k = k.reshape(num_kv_heads, head_size) - keys.append(k) - - if flash_style: - v = value_cache[block_number, block_offset, :, :] - else: - v = value_cache[block_number, :, :, block_offset] - values.append(v) - keys = torch.stack(keys, dim=0) - values = torch.stack(values, dim=0) - if num_queries_per_kv > 1: - # Handle MQA and GQA - keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) - values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - - alibi_bias = None - if alibi_slopes is not None: - # Create the ALiBi bias used in the paged attention kernel. 
- position_ids = torch.arange(context_len, device="cuda").int() - alibi_bias = (position_ids - context_len + 1).float() - alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( - 1, 1, -1) - - out = ref_masked_attention(q, keys, values, scale, alibi_bias) - out = out.view(num_query_heads, head_size) - output[i].copy_(out, non_blocking=True) - - -@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("use_alibi", [False]) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("pad_config", [(0, 0)]) -@torch.inference_mode() -def test_flash_paged_attention( - kv_cache_factory, - num_seqs: int, - num_heads: Tuple[int, int], - head_size: int, - use_alibi: bool, - block_size: int, - dtype: torch.dtype, - seed: int, - pad_config: Tuple[int, int], -) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, - num_query_heads, - head_size, - dtype=dtype, - device="cuda") - query.uniform_(-scale, scale) - - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None - if use_alibi: - alibi_slopes = torch.randn(num_query_heads, - dtype=torch.float, - device="cuda") - - max_seq_len = MAX_SEQ_LEN if not pad_config[0] else (pad_config[1] - 1000) - context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] - context_lens[-1] = max_seq_len - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") - - # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - - # Create the KV caches. - key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, - block_size, - 1, - num_kv_heads, - head_size, - dtype, - seed, - flash_style=True) - key_cache, value_cache = key_caches[0], value_caches[0] - - # Call the paged attention kernel. - output = torch.empty_like(query) - - padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ - pad_attention_inputs(pad_config, block_size, query, - block_tables, context_lens, max_context_len) - - flash_single_query_cached_kv_attention( - output, - padded_query.view(num_seqs, 1, num_query_heads, head_size), - key_cache, - value_cache, - scale, - padded_block_table, - padded_context_lens, - alibi_slopes, - ) - - # Run the reference implementation. 
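For readers skimming the deleted test file, the reference math it checks against is compact enough to restate. This condensed version mirrors `ref_masked_attention` and the ALiBi bias construction above; the shapes are made up and it runs on CPU:

import torch

def ref_masked_attention(query, key, value, scale, attn_mask=None):
    # query: [q_len, heads, dim], key/value: [kv_len, heads, dim].
    attn = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn = attn + attn_mask.float()
    attn = torch.softmax(attn, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", attn, value)

# ALiBi bias for a single decode step: 0 at the newest position and
# increasingly negative (scaled per head) for older positions.
context_len, num_heads, head_size = 6, 2, 8
slopes = torch.tensor([0.5, 0.25])
position_ids = torch.arange(context_len)
alibi_bias = slopes.view(-1, 1, 1) * (position_ids - context_len + 1).float().view(1, 1, -1)

q = torch.randn(1, num_heads, head_size)
kv = torch.randn(context_len, num_heads, head_size)
out = ref_masked_attention(q, kv, kv, head_size**-0.5, attn_mask=alibi_bias)
assert out.shape == (1, num_heads, head_size)
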
- ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - context_lens, - scale, - alibi_slopes, - flash_style=True, - ) - - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) - - -def ref_multi_query_kv_attention( - cu_seq_lens: List[int], - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - dtype: torch.dtype, -) -> torch.Tensor: - num_seqs = len(cu_seq_lens) - 1 - ref_outputs = [] - for i in range(num_seqs): - start_idx = cu_seq_lens[i] - end_idx = cu_seq_lens[i + 1] - seq_len = end_idx - start_idx - - # Create attention mask. - attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), - diagonal=1) - attn_mask = attn_mask * torch.finfo(dtype).min - attn_mask = attn_mask.to(dtype=dtype, device="cuda") - - ref_output = ref_masked_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], - scale, - attn_mask=attn_mask, - ) - ref_outputs.append(ref_output) - ref_output = torch.cat(ref_outputs, dim=0) - return ref_output - - -def ref_multi_query_kv_attention_padded( - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - cu_seq_lens: List[int], - context_lens: List[int], - scale: float, - dtype: torch.dtype, -) -> torch.Tensor: - num_seqs = len(cu_seq_lens) - 1 - block_size = value_cache.shape[-3] - ref_outputs = [] - - for i in range(num_seqs): - q_start_idx = cu_seq_lens[i] - q_end_idx = cu_seq_lens[i + 1] - seq_len = q_end_idx - q_start_idx - - context_len = context_lens[i] - - block_table = block_tables[i] - keys = [] - values = [] - - for j in range(context_len): - block_number = int(block_table[j // block_size]) - block_offset = j % block_size - - k = key_cache[block_number, block_offset, :, :] - keys.append(k) - - v = value_cache[block_number, block_offset, :, :] - values.append(v) - - keys = torch.stack(keys, dim=0) - values = torch.stack(values, dim=0) - - if num_queries_per_kv > 1: - # Handle MQA and GQA - keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) - values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - - q = query[q_start_idx:q_end_idx, :, :] - k = keys[:context_len, :, :] - v = values[:context_len, :, :] - - assert seq_len <= context_len - - # pad q if seq_len is less than context_len - # this is for correct calculation of attention. - if seq_len < context_len: - indices = [i % seq_len for i in range(context_len - seq_len)] - q_left_pad = q[indices, :, :] - q = torch.cat([q_left_pad, q], dim=0) - - # Create attention mask. 
- attn_mask = torch.triu(torch.ones(context_len, - context_len, - dtype=dtype), - diagonal=1) - attn_mask = attn_mask * torch.finfo(dtype).min - attn_mask = attn_mask.to(dtype=dtype, device="cuda") - - ref_output = ref_masked_attention( - q, - k, - v, - scale, - attn_mask=attn_mask, - ) - ref_outputs.append(ref_output[-seq_len:, :, :]) - ref_output = torch.cat(ref_outputs, dim=0) - return ref_output - - -def is_a100(): - return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 - - -if not is_a100(): - NUM_HEADS_SMALL = [(16, 16), (16, 8)] - MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) - - -@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS_SMALL) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("version", ["flash"]) -@pytest.mark.parametrize("chunked_prefill", [False, True]) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@torch.inference_mode() -def test_multi_query_kv_attention( - num_seqs: int, - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - version: str, - seed: int, - chunked_prefill: bool, - block_size: int, -) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. - # As the xformers library is already tested with its own tests, we can use - # a smaller MAX_SEQ_LEN here. - max_len = min(MAX_SEQ_LEN, 4096) - - seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] - max_seq_len = max(seq_lens) - - if chunked_prefill: - # context length will be different from seq_len if chunked_prefill is - # true. - context_lens = random.sample(range(max_seq_len, max_len), num_seqs) - else: - context_lens = seq_lens - max_context_len = max(context_lens) - - num_tokens = sum(seq_lens) - cu_seq_lens = [0] - for seq_len in seq_lens: - cu_seq_lens.append(cu_seq_lens[-1] + seq_len) - - cu_context_lens = [0] - for context_len in context_lens: - cu_context_lens.append(cu_context_lens[-1] + context_len) - - print(f"cu_seq_lens={cu_seq_lens}, cu_context_lens={cu_context_lens}") - - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - num_queries_per_kv = num_query_heads // num_kv_heads - - value_cache = torch.empty(NUM_BLOCKS, - block_size, - num_kv_heads, - head_size, - dtype=dtype, - device="cuda") - key_cache = torch.empty(NUM_BLOCKS, - block_size, - num_kv_heads, - head_size, - dtype=dtype, - device="cuda") - query = torch.empty(num_tokens, - num_query_heads, - head_size, - dtype=dtype, - device="cuda") - value_cache.uniform_(-scale, scale) - key_cache.uniform_(-scale, scale) - query.uniform_(-scale, scale) - - # Create the block tables. 
- max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") - - output = torch.empty_like(query) - - if version == "flash": - flash_multi_query_cached_kv_attention_varlen( - output, - query, - key_cache, - value_cache, - scale, - block_tables, - torch.cuda.IntTensor(cu_seq_lens), - torch.cuda.IntTensor(cu_context_lens), - block_size, - max_seq_len, - max_context_len, - None, - ) - else: - raise AssertionError(f"{version=} is not supported") - - ref_output = ref_multi_query_kv_attention_padded( - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - cu_seq_lens, - context_lens, - scale, - dtype, - ) - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) - - -# @pytest.mark.parametrize("num_heads", [40]) -# @pytest.mark.parametrize("head_size", [64]) -# @pytest.mark.parametrize("block_size", [32]) -# @pytest.mark.parametrize("num_blocks", [128]) -# @pytest.mark.parametrize("dtype", [torch.half]) -# @pytest.mark.parametrize("seed", SEEDS) -# @torch.inference_mode() -# def test_e2e( -# kv_cache_factory, -# num_heads: int, -# head_size: int, -# block_size: int, -# num_blocks: int, -# dtype: torch.dtype, -# seed: int, -# ) -> None: -# from vllm._C import cache_ops -# batch_size = 2 -# seqlen = 29 -# num_tokens = batch_size * seqlen -# random.seed(seed) -# torch.random.manual_seed(seed) -# torch.cuda.manual_seed(seed) - -# block_tables = [] -# for _ in range(batch_size): -# block_table = [ -# random.randint(0, NUM_BLOCKS - 1) -# ] -# block_tables.append(block_table) -# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") -# # Create a random slot mapping. -# slot_mapping = [] -# for i in range(0, 29): -# block_number = int(block_tables[i // block_size]) -# block_offset = i % block_size -# slot = block_number * block_size + block_offset -# slot_mapping.append(slot) -# # for _ in range(23, 29): -# # slot_mapping.append(-1) -# for i in range(0, 29): -# block_number = int(block_tables[1]) -# block_offset = i % block_size -# slot = block_number * block_size + block_offset -# slot_mapping.append(slot) -# # slot_mapping = random.sample(range(num_slots), num_tokens) -# slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') - -# qkv = torch.randn(num_tokens, -# 3, -# num_heads, -# head_size, -# dtype=dtype, -# device='cuda') -# query, key, value = qkv.unbind(dim=1) -# # query[23:29] = 0 -# # key[23:29] = 0 -# # value[23:29] = 0 -# # Create the KV caches. - -# key_caches, value_caches = kv_cache_factory(num_blocks, -# block_size, -# 1, -# num_heads, -# head_size, -# dtype, -# seed, -# flash_style=True) -# assert len(key_caches) == 1 and len(value_caches) == 1 -# key_cache, value_cache = key_caches[0], value_caches[0] -# # Call the reshape_and_cache kernel. -# cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, -# slot_mapping) - -# cloned_key_cache = key_cache.clone() -# cloned_value_cache = value_cache.clone() - -# # Run the reference implementation. 
-# block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') -# block_indicies = block_indicies.cpu().tolist() -# block_offsets = slot_mapping % block_size -# block_offsets = block_offsets.cpu().tolist() -# for i in range(num_tokens): -# block_idx = block_indicies[i] -# block_offset = block_offsets[i] -# print("block_idx", block_idx) -# print("block_offset", block_offset) -# if block_idx != -1: -# cloned_key_cache[block_idx, block_offset, :, :] = key[i] -# cloned_value_cache[block_idx, block_offset, :, :] = value[i] -# assert torch.allclose(key_cache, cloned_key_cache) -# assert torch.allclose(value_cache, cloned_value_cache) - -# for i in range(58): -# print(i) -# block_idx = block_indicies[i] -# block_offset = block_offsets[i] -# torch.allclose(key[i], key_cache[block_idx][block_offset]) - -# from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask -# from xformers import ops as xops -# seqlen = query.shape[1] -# attn_bias = BlockDiagonalCausalMask.from_seqlens( -# [seqlen] * batch_size) -# scale = float(1.0 / (head_size**0.5)) -# output = xops.memory_efficient_attention_forward( -# query.unsqueeze(0), -# key.unsqueeze(0), -# value.unsqueeze(0), -# attn_bias=None, -# p=0.0, -# scale=scale, -# ) - -# num_tokens, num_heads, head_size = query.shape -# from flash_attn import flash_attn_func -# output1 = flash_single_query_cached_kv_attention( -# None, -# query.view(batch_size, num_tokens // batch_size, num_heads, head_size), -# key_cache, -# value_cache, -# scale, -# block_tables, -# torch.tensor([23, 29], dtype=torch.int, device='cuda'), -# alibi_slopes=None, -# ) -# output2 = flash_attn_func( -# # query.view(batch_size, num_tokens // batch_size, num_heads, head_size), -# # key.view(batch_size, num_tokens // batch_size, num_heads, head_size), -# # value.view(batch_size, num_tokens // batch_size, num_heads, head_size), -# query.unsqueeze(0), -# key.unsqueeze(0), -# value.unsqueeze(0), -# # key_cache, -# # value_cache, -# softmax_scale=scale, -# # block_tables, -# # torch.tensor([23, 29], dtype=torch.int, device='cuda'), -# # alibi_slopes=None, -# causal=True, -# ) -# output3 = flash_attn_func( -# query.view(batch_size, num_tokens // batch_size, num_heads, head_size), -# key.view(batch_size, num_tokens // batch_size, num_heads, head_size), -# value.view(batch_size, num_tokens // batch_size, num_heads, head_size), -# # key_cache, -# # value_cache, -# softmax_scale=scale, -# # block_tables, -# # torch.tensor([23, 29], dtype=torch.int, device='cuda'), -# # alibi_slopes=None, -# causal=True, -# ) diff --git a/vllm/config.py b/vllm/config.py index c0de0387180b..095e9603f564 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -60,7 +60,6 @@ class ModelConfig: max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. - flash_style: Enable flash style page attention. 
""" def __init__( @@ -81,7 +80,6 @@ def __init__( enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, max_logprobs: int = 5, - flash_style: bool = False, ) -> None: self.model = model self.tokenizer = tokenizer @@ -97,7 +95,6 @@ def __init__( self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture self.max_logprobs = max_logprobs - self.flash_style = flash_style if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, @@ -309,7 +306,6 @@ def __init__( cache_dtype: str, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, - flash_style: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -317,7 +313,6 @@ def __init__( self.cache_dtype = cache_dtype self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching - self.flash_style = flash_style self._verify_args() self._verify_cache_dtype() @@ -335,15 +330,6 @@ def _verify_args(self) -> None: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") - if self.flash_style: - logger.info("Flash attention enabled.") - if self.block_size > 32: - # Flash style attention only supports block size >=256 for now. - # https://github.com/Dao-AILab/flash-attention/pull/824 will fix it. - raise ValueError( - "Flash style attention only supports block size <= 32. Got" - f"{self.block_size }") - def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass @@ -476,8 +462,6 @@ class SchedulerConfig: for flash style attention. max_num_prompt_seqs: The maximum number of prompt sequences that can be processed in a single iteration. - flash_style: Whether to use flash style attention. Only support - LLaMA models. """ def __init__( @@ -487,7 +471,6 @@ def __init__( max_model_len: int, max_chunked_prefill_len: int = -1, max_num_prompt_seqs: int = 1024, - flash_style: bool = False, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -500,7 +483,6 @@ def __init__( self.chunked_prefill_enabled = max_chunked_prefill_len != -1 self.max_chunked_prefill_len = max_chunked_prefill_len self.max_num_prompt_seqs = max_num_prompt_seqs - self.flash_style = flash_style self._verify_args() def _verify_args(self) -> None: @@ -518,10 +500,6 @@ def _verify_args(self) -> None: f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " f"({self.max_num_seqs}).") - # if self.chunked_prefill_enabled and not self.flash_style: - # # SANG-TODO It is probably fixable. 
- # raise ValueError( - # "chunked prefill is only supported for flash style") class DeviceConfig: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a811814591e9..0abc4610b815 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -45,7 +45,6 @@ class EngineArgs: lora_extra_vocab_size: int = 256 lora_dtype = 'auto' max_cpu_loras: Optional[int] = None - flash_style: bool = False device: str = 'auto' ray_workers_use_nsight: bool = False max_chunked_prefill_len: int = -1 @@ -285,9 +284,6 @@ def add_cli_args( default=EngineArgs.device, choices=["auto", "cuda", "neuron"], help='Device type for vLLM execution.') - parser.add_argument('--flash-style', - action='store_true', - help='use flash attention.') parser.add_argument( '--max-chunked-prefill-len', type=int, @@ -320,14 +316,12 @@ def create_engine_configs( self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.enforce_eager, self.max_context_len_to_capture, - self.max_logprobs, - self.flash_style) + self.max_logprobs) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, model_config.get_sliding_window(), - self.enable_prefix_caching, - self.flash_style) + self.enable_prefix_caching) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, @@ -340,7 +334,6 @@ def create_engine_configs( model_config.max_model_len, max_chunked_prefill_len=self.max_chunked_prefill_len, max_num_prompt_seqs=self.max_num_prompt_seqs, - flash_style=self.flash_style, ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index c8216e95e848..d64ef88dd690 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -48,7 +48,6 @@ def __init__( block_tables: Optional[torch.Tensor], use_cuda_graph: bool, kv_cache_dtype: str, - flash_style: bool, ) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens @@ -70,7 +69,6 @@ def __init__( self.block_tables = block_tables self.use_cuda_graph = use_cuda_graph self.kv_cache_dtype = kv_cache_dtype - self.flash_style = flash_style # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. @@ -79,51 +77,6 @@ def __init__( # See attention.py for precise definition. self.num_valid_tokens = slot_mapping.shape[0] - # SANG-TODO - # # Prompt related metadata - # # This value might include padding if CudaGraph is enabled. - # self.num_prompts = len(prompt_lens) - # # This value is the source of truth. - # self.num_prompts_tensor = torch.cuda.IntTensor([self.num_prompts]) - # # This value might include padding if CudaGraph is enabled. - # self.num_prompt_tokens = num_prompt_tokens - # self.prompt_lens_tensor = torch.cuda.IntTensor(self.prompt_lens) - # self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 - - # # Cumulative prompt lengths for each prompt in the input - # # tensor. - # self.cum_prompt_query_lens = torch.zeros( - # self.num_prompts + 1, - # device=self.prompt_lens_tensor.device, - # dtype=torch.int32) - # # Cumulative context lengths. 
- # self.cum_prompt_context_lens = torch.zeros( - # self.num_prompts + 1, - # device=self.prompt_lens_tensor.device, - # dtype=torch.int32) - - # torch.cumsum(self.prompt_lens_tensor, - # dim=0, - # dtype=self.cum_prompt_query_lens.dtype, - # out=self.cum_prompt_query_lens[1:]) - # torch.cumsum(self.context_lens[:self.num_prompts], - # dim=0, - # dtype=self.cum_prompt_context_lens.dtype, - # out=self.cum_prompt_context_lens[1:]) - - # # TODO: this will be different once we support chunked prefills. - # self.cum_prompt_context_lens = self.cum_prompt_query_lens - # self.max_context_len = max_context_len - - # # Generation related metadata - # # This value might include padding if CudaGraph is enabled. - # self.num_generation_tokens = num_generation_tokens - # # This is the source of truth for the number of generation tokens. - # self.num_generation_tokens_tensor = torch.tensor( - # [self.num_generation_tokens], - # dtype=torch.int32 if self.flash_style else torch.long, - # device='cuda') - def __repr__(self) -> str: return ("InputMetadata(" f"is_prompt={self.is_prompt}, " @@ -134,6 +87,5 @@ def __repr__(self) -> str: f"context_lens={self.context_lens}, " f"block_tables={self.block_tables}, " f"use_cuda_graph={self.use_cuda_graph}, " - f"flash_style={self.flash_style} " f"kv_cache_dtype={self.kv_cache_dtype}) " f"num_valid_tokens={self.num_valid_tokens}") diff --git a/vllm/model_executor/layers/attention/ops/prefix_prefill.py b/vllm/model_executor/layers/attention/ops/prefix_prefill.py index c6054de2b718..6b1da8d0da51 100644 --- a/vllm/model_executor/layers/attention/ops/prefix_prefill.py +++ b/vllm/model_executor/layers/attention/ops/prefix_prefill.py @@ -632,7 +632,7 @@ def context_attention_fwd(q, alibi_slopes=None): cap = torch.cuda.get_device_capability() - BLOCK = 128 if cap[0] >= 8 else 64 + BLOCK = 32 if cap[0] >= 8 else 32 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv @@ -696,7 +696,7 @@ def context_attention_fwd(q, ) return - _fwd_kernel_flash_attn_v2[grid]( + _fwd_kernel[grid]( q, k, v, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 5d4a460074c9..cb64d80c8147 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -29,7 +29,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: architectures = ["QuantMixtralForCausalLM"] for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch, model_config) + model_cls = ModelRegistry.load_model_cls(arch) if model_cls is not None: return model_cls raise ValueError( diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 1e9e7b76659d..75c2ae1e9f48 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -5,7 +5,6 @@ from vllm.logger import init_logger from vllm.utils import is_hip, is_neuron -from vllm.config import ModelConfig logger = init_logger(__name__) @@ -70,8 +69,7 @@ class ModelRegistry: @staticmethod - def load_model_cls(model_arch: str, - model_config: ModelConfig) -> Optional[Type[nn.Module]]: + def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: if model_arch not in _MODELS: return None if is_hip(): @@ -94,9 +92,7 @@ def load_model_cls(model_arch: str, module_name = _NEURON_SUPPORTED_MODELS[model_arch] module = importlib.import_module( f"vllm.model_executor.models.{module_name}") - model_cls = getattr(module, model_cls_name, None) - - return model_cls + return 
getattr(module, model_cls_name, None) @staticmethod def get_supported_archs() -> List[str]: diff --git a/vllm/utils.py b/vllm/utils.py index 2d99c93df187..d5d89a262f31 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -254,7 +254,6 @@ def create_kv_caches_with_random( model_dtype: Optional[Union[str, torch.dtype]] = None, seed: Optional[int] = 0, device: Optional[str] = "cuda", - flash_style: bool = False, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -281,10 +280,7 @@ def create_kv_caches_with_random( scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() - if flash_style: - key_cache_shape = (num_blocks, block_size, num_heads, head_size) - else: - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_caches = [] for _ in range(num_layers): @@ -300,10 +296,7 @@ def create_kv_caches_with_random( f"Does not support key cache of type {cache_dtype}") key_caches.append(key_cache) - if flash_style: - value_cache_shape = (num_blocks, block_size, num_heads, head_size) - else: - value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_caches = [] for _ in range(num_layers): diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 2ca359e6261d..880299783935 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -60,33 +60,19 @@ def __init__( def get_key_block_shape(self) -> Tuple[int, int, int, int]: element_size = torch.tensor([], dtype=self.dtype).element_size() x = 16 // element_size - if self.cache_config.flash_style: - return ( - self.block_size, - self.num_heads, - self.head_size, - ) - else: - return ( - self.num_heads, - self.head_size // x, - self.block_size, - x, - ) + return ( + self.num_heads, + self.head_size // x, + self.block_size, + x, + ) def get_value_block_shape(self) -> Tuple[int, int, int]: - if self.cache_config.flash_style: - return ( - self.block_size, - self.num_heads, - self.head_size, - ) - else: - return ( - self.num_heads, - self.head_size, - self.block_size, - ) + return ( + self.num_heads, + self.head_size, + self.block_size, + ) def allocate_gpu_cache(self) -> List[KVCache]: gpu_cache: List[KVCache] = [] diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7303cfeb42f6..0a22c5b5c645 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -58,8 +58,6 @@ def __init__( # FIXME(woosuk): This is a hack to make the tests work. Refactor this. 
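In the cache_engine hunk above, `x` is the number of elements that fit in a 16-byte vector (`16 // element_size`), which is why the key cache splits `head_size` into `head_size // x` groups of `x`. A quick check of the per-block shapes the cache engine ends up with (head counts are made up):

import torch

def key_block_shape(num_heads, head_size, block_size, dtype):
    x = 16 // torch.tensor([], dtype=dtype).element_size()   # 16-byte vector width
    return (num_heads, head_size // x, block_size, x)

def value_block_shape(num_heads, head_size, block_size):
    return (num_heads, head_size, block_size)

assert key_block_shape(8, 128, 16, torch.half) == (8, 16, 16, 8)
assert value_block_shape(8, 128, 16) == (8, 128, 16)
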
self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None) - self.flash_style = (self.model_config.flash_style - if model_config is not None else False) self.device_config = (device_config if device_config is not None else DeviceConfig()) @@ -185,7 +183,12 @@ def _prepare_prompt( prefix_block_tables.append(computed_block_nums) context_len = computed_len else: - prefix_block_tables.append([]) + # prefix_block_tables.append([]) + if seq_group_metadata.block_tables is None: + prefix_block_tables.append([]) + else: + block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) context_len = 0 # actual prompt lens context_lens.append(context_len) @@ -309,7 +312,6 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.flash_style, ) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, @@ -449,7 +451,6 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.flash_style, ) return (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -616,8 +617,7 @@ def prepare_input_tensors( context_lens=metadata_dict["context_lens"], block_tables=metadata_dict["block_tables"], use_cuda_graph=metadata_dict["use_cuda_graph"], - kv_cache_dtype=metadata_dict["kv_cache_dtype"], - flash_style=self.flash_style) + kv_cache_dtype=metadata_dict["kv_cache_dtype"]) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -816,8 +816,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: context_lens=context_lens[:batch_size], block_tables=block_tables[:batch_size], use_cuda_graph=True, - kv_cache_dtype=self.kv_cache_dtype, - flash_style=self.flash_style) + kv_cache_dtype=self.kv_cache_dtype) if self.lora_config: lora_mapping = LoRAMapping( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c49ceb92b0e3..9d367bed4376 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -20,8 +20,6 @@ from vllm.worker.model_runner import ModelRunner from vllm.lora.request import LoRARequest -MAX_INT_32 = 2**31 - 1 - class Worker: """A worker class that executes (a partition of) the model on a GPU. @@ -146,13 +144,6 @@ def profile_num_available_blocks( gc.collect() torch.cuda.empty_cache() - # Flash attention only allows max(int32) slots. - # TODO: Remove this once we support int64 in flash-attention. - if self.model_config.flash_style: - num_gpu_blocks = min(num_gpu_blocks, - MAX_INT_32 // cache_block_size) - num_cpu_blocks = min(num_cpu_blocks, - MAX_INT_32 // cache_block_size) # print("SANG-TODO profile_num_available_blocks done") return num_gpu_blocks, num_cpu_blocks From d09eaf50a231773e2cd75b59ba4b25846bb5c951 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 11 Mar 2024 05:11:00 -0700 Subject: [PATCH 40/88] . --- tests/core/test_scheduler.py | 1 - tests/core/utils.py | 1 - tests/kernels/test_cache.py | 1 - vllm/utils.py | 3 +-- vllm/worker/model_runner.py | 7 +------ 5 files changed, 2 insertions(+), 11 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 19d8800a387c..1540f2fed9e8 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -238,4 +238,3 @@ def test_scheduler_max_seqs(): # and one is prompting. 
_, out = scheduler.schedule() assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]]) - diff --git a/tests/core/utils.py b/tests/core/utils.py index 4074fc0b24d1..6469789e8938 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -25,4 +25,3 @@ def create_dummy_prompt( def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size - diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 9716094a3e81..d8dc74bc7b00 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -224,4 +224,3 @@ def test_swap_blocks( dist_key_caches[0][dst].cpu()) assert torch.allclose(src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()) - \ No newline at end of file diff --git a/vllm/utils.py b/vllm/utils.py index d5d89a262f31..a7de36cca3a9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -280,8 +280,7 @@ def create_kv_caches_with_random( scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, - x) + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_caches = [] for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0a22c5b5c645..3c3a0fd73646 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -183,12 +183,7 @@ def _prepare_prompt( prefix_block_tables.append(computed_block_nums) context_len = computed_len else: - # prefix_block_tables.append([]) - if seq_group_metadata.block_tables is None: - prefix_block_tables.append([]) - else: - block_table = seq_group_metadata.block_tables[seq_id] - prefix_block_tables.append(block_table) + prefix_block_tables.append([]) context_len = 0 # actual prompt lens context_lens.append(context_len) From 93a7b90df5efddc20e6b1b4133bc8d9e4f8dce34 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 11 Mar 2024 18:58:23 -0700 Subject: [PATCH 41/88] . --- tests/models/test_models_2.py | 94 +++++++++++++++++++ .../layers/attention/backends/xformers.py | 1 + vllm/worker/model_runner.py | 7 +- 3 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 tests/models/test_models_2.py diff --git a/tests/models/test_models_2.py b/tests/models/test_models_2.py new file mode 100644 index 000000000000..442c0b9a0d11 --- /dev/null +++ b/tests/models/test_models_2.py @@ -0,0 +1,94 @@ +import gc + +import pytest +import torch + +from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel + +MODELS = [ + "JackFram/llama-68m", + # "facebook/opt-125m", +] + + +# SANG-TODO enforce_eager = True and chunked prefill currently doesn't work. +# TODO(sang): Add chunked prefill parameters. 
+# @pytest.mark.parametrize("model", MODELS) +# @pytest.mark.parametrize("dtype", ["half"]) +# @pytest.mark.parametrize("max_tokens", [128]) +# @pytest.mark.parametrize("max_chunked_prefill_len", [-1, 16, 64]) +# @pytest.mark.parametrize("max_num_prompt_seqs", [1, 2, 100]) +# @pytest.mark.parametrize("block_size", [32]) +# @pytest.mark.parametrize("tensor_parallel_size", [1, 2]) +# @pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("max_chunked_prefill_len", [500]) +@pytest.mark.parametrize("max_num_prompt_seqs", [256]) +@pytest.mark.parametrize("block_size", [32]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("enforce_eager", [False]) +def test_models( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + max_chunked_prefill_len: int, + max_num_prompt_seqs: int, + block_size: int, + tensor_parallel_size: int, + enforce_eager: bool, +) -> None: + """ verify the flash attention has the same output + as page attention """ + if torch.cuda.device_count() < tensor_parallel_size: + pytest.skip( + f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" + ) + print("loading page attention models..") + pg_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) + expected_outputs = [] + + print("generating tokens...") + expected_outputs.extend( + pg_model.generate_greedy(example_prompts, max_tokens)) + print("generating tokens finished") + + del pg_model + + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + + flash_attn_output_by_batches = [] + flash_attn_model = vllm_runner( + model, + dtype=dtype, + block_size=block_size, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager) + for i in range(5, 6): + prompts = [example_prompts[j % len(example_prompts)] for j in range(i)] + breakpoint() + flash_attn_output_by_batches.append( + flash_attn_model.generate_greedy(prompts, max_tokens)) + + del flash_attn_model + + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + + for flash_attn_outputs in flash_attn_output_by_batches: + for i in range(len(flash_attn_outputs)): + fa_output_ids, fa_output_str = flash_attn_outputs[i] + vllm_output_ids, vllm_output_str = expected_outputs[ + i % len(expected_outputs)] + # print("expected, ",vllm_output_str, "\n") + # print("actual:, ", fa_output_str, "\n") + assert fa_output_ids == vllm_output_ids, ( + f"Test{i}:\nflash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" + f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" + ) diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention/backends/xformers.py index 269b73303fb0..d1e264323963 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -162,6 +162,7 @@ def forward( else: # prefix-enabled attention + print("SANG-TODO prefix") output = PagedAttentionImpl.forward_prefix( query, key, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 501244e3dc2c..5e2086d5129c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -165,7 +165,12 @@ def _prepare_prompt( prefix_block_tables.append(computed_block_nums) context_len = computed_len else: - prefix_block_tables.append([]) + if seq_group_metadata.block_tables 
is None: + prefix_block_tables.append([]) + else: + block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) + # prefix_block_tables.append([]) context_len = 0 # actual prompt lens context_lens.append(context_len) From 647d8cc612b3cd10c9f829119f4ece5c97ed000f Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 12 Mar 2024 00:06:20 -0700 Subject: [PATCH 42/88] . --- tests/models/test_models_2.py | 77 +++++++++---------- .../layers/attention/ops/paged_attn.py | 1 + vllm/worker/model_runner.py | 7 +- 3 files changed, 41 insertions(+), 44 deletions(-) diff --git a/tests/models/test_models_2.py b/tests/models/test_models_2.py index 442c0b9a0d11..786534ec975d 100644 --- a/tests/models/test_models_2.py +++ b/tests/models/test_models_2.py @@ -1,9 +1,11 @@ import gc import pytest +from tests.conftest import example_prompts import torch from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel +from vllm import SamplingParams MODELS = [ "JackFram/llama-68m", @@ -24,22 +26,20 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("max_chunked_prefill_len", [500]) -@pytest.mark.parametrize("max_num_prompt_seqs", [256]) -@pytest.mark.parametrize("block_size", [32]) +@pytest.mark.parametrize("max_chunked_prefill_len", [-1]) @pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("num", [3]) def test_models( - vllm_runner, example_prompts, + vllm_runner, model: str, dtype: str, max_tokens: int, max_chunked_prefill_len: int, - max_num_prompt_seqs: int, - block_size: int, tensor_parallel_size: int, enforce_eager: bool, + num, ) -> None: """ verify the flash attention has the same output as page attention """ @@ -47,48 +47,41 @@ def test_models( pytest.skip( f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" ) - print("loading page attention models..") - pg_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) - expected_outputs = [] - print("generating tokens...") - expected_outputs.extend( - pg_model.generate_greedy(example_prompts, max_tokens)) - print("generating tokens finished") + def cleanup(): + torch.backends.cuda.matmul.allow_tf32 = False + torch.set_default_dtype(torch.float32) + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() - del pg_model + def evaluate(init_llm): + llm = init_llm() + outputs = llm.generate_greedy(example_prompts[num], max_tokens=max_tokens) + token_ids_list = [] + output_str_list = [] - destroy_model_parallel() - gc.collect() - torch.cuda.empty_cache() + for i in range(len(outputs)): + token_ids = outputs[i][0] + output_str = outputs[i][1] + token_ids_list.append(token_ids) + output_str_list.append(output_str) + del llm + cleanup() + return token_ids_list, output_str_list - flash_attn_output_by_batches = [] - flash_attn_model = vllm_runner( + vllm_token_ids, vllm_str = evaluate(lambda: vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)) + import os + os.environ["ENABLE"] = "1" + chunked_prefill_token_ids, chunked_str = evaluate(lambda: vllm_runner( model, dtype=dtype, - block_size=block_size, tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager) - for i in range(5, 6): - prompts = [example_prompts[j % len(example_prompts)] for j in range(i)] - breakpoint() - flash_attn_output_by_batches.append( - flash_attn_model.generate_greedy(prompts, max_tokens)) - - 
del flash_attn_model + enforce_eager=enforce_eager)) - destroy_model_parallel() - gc.collect() - torch.cuda.empty_cache() + for i in range(len(vllm_token_ids)): + print(f"TEST {i}") + print(f"{len(vllm_token_ids[i])=} {vllm_token_ids[i]=}\n{vllm_str[i]=}") + print(f"{len(chunked_prefill_token_ids[i])=} {chunked_prefill_token_ids[i]=}\n{chunked_str[i]=}\n") + assert vllm_token_ids[i] == chunked_prefill_token_ids[i] - for flash_attn_outputs in flash_attn_output_by_batches: - for i in range(len(flash_attn_outputs)): - fa_output_ids, fa_output_str = flash_attn_outputs[i] - vllm_output_ids, vllm_output_str = expected_outputs[ - i % len(expected_outputs)] - # print("expected, ",vllm_output_str, "\n") - # print("actual:, ", fa_output_str, "\n") - assert fa_output_ids == vllm_output_ids, ( - f"Test{i}:\nflash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" - f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" - ) diff --git a/vllm/model_executor/layers/attention/ops/paged_attn.py b/vllm/model_executor/layers/attention/ops/paged_attn.py index c5a9618c2395..f70c88d7128f 100644 --- a/vllm/model_executor/layers/attention/ops/paged_attn.py +++ b/vllm/model_executor/layers/attention/ops/paged_attn.py @@ -121,6 +121,7 @@ def forward_prefix( alibi_slopes: Optional[torch.Tensor], ) -> torch.Tensor: output = torch.empty_like(query) + print("SANG-TODO prefix attention!") context_attention_fwd( query, key, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5e2086d5129c..95ed95b53d1f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -154,7 +154,7 @@ def _prepare_prompt( prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) computed_len = 0 - + import os # NOTE: This only works for oooooooxxx style attention. 
computed_block_nums = seq_group_metadata.computed_block_nums if computed_block_nums is not None and len( @@ -164,7 +164,7 @@ def _prepare_prompt( prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) context_len = computed_len - else: + if os.getenv("ENABLE") is not None: if seq_group_metadata.block_tables is None: prefix_block_tables.append([]) else: @@ -172,6 +172,9 @@ def _prepare_prompt( prefix_block_tables.append(block_table) # prefix_block_tables.append([]) context_len = 0 + else: + prefix_block_tables.append([]) + context_len = 0 # actual prompt lens context_lens.append(context_len) subquery_lens.append(prompt_len - computed_len) From b2f4b3ea30672f8472dfd196be51b6b16493799a Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 12 Mar 2024 02:08:22 -0700 Subject: [PATCH 43/88] ip --- tests/models/test_models_2.py | 4 ++-- vllm/model_executor/layers/attention/ops/paged_attn.py | 8 ++++++++ vllm/worker/model_runner.py | 5 +++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/models/test_models_2.py b/tests/models/test_models_2.py index 786534ec975d..9683352b712a 100644 --- a/tests/models/test_models_2.py +++ b/tests/models/test_models_2.py @@ -28,7 +28,7 @@ @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_chunked_prefill_len", [-1]) @pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("num", [3]) def test_models( example_prompts, @@ -57,7 +57,7 @@ def cleanup(): def evaluate(init_llm): llm = init_llm() - outputs = llm.generate_greedy(example_prompts[num], max_tokens=max_tokens) + outputs = llm.generate_greedy(example_prompts, max_tokens=max_tokens) token_ids_list = [] output_str_list = [] diff --git a/vllm/model_executor/layers/attention/ops/paged_attn.py b/vllm/model_executor/layers/attention/ops/paged_attn.py index f70c88d7128f..32d18b0be6f6 100644 --- a/vllm/model_executor/layers/attention/ops/paged_attn.py +++ b/vllm/model_executor/layers/attention/ops/paged_attn.py @@ -122,6 +122,14 @@ def forward_prefix( ) -> torch.Tensor: output = torch.empty_like(query) print("SANG-TODO prefix attention!") + print(f"{input_metadata.block_tables=}") + print(f"{input_metadata.start_loc=}") + print(f"{input_metadata.prompt_lens=}") + print(f"{input_metadata.context_lens=}") + print(f"{input_metadata.max_seq_len=}") + print(f"{query.size()=}") + print(f"{key.size()=}") + print(f"{value.size()=}") context_attention_fwd( query, key, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 95ed95b53d1f..6b6549619b06 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -264,11 +264,12 @@ def _prepare_prompt( # Cumulative index of each prompt. [prompt_lens + 1] # [0, 0+1th, 0+1th+2nd, ...] - start_loc_tensor = torch.zeros(prompt_lens_tensor.shape[0] + 1, + subquery_lens_tensor = torch.tensor(subquery_lens, dtype=torch.long, device=self.device) + start_loc_tensor = torch.zeros(subquery_lens_tensor.shape[0], dtype=torch.long, device=self.device) - torch.cumsum(prompt_lens_tensor, + torch.cumsum(subquery_lens_tensor[:-1], dim=0, dtype=start_loc_tensor.dtype, out=start_loc_tensor[1:]) From cc8419fde7b6f000f54c63a9c04014807be490cd Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 12 Mar 2024 04:35:58 -0700 Subject: [PATCH 44/88] . 
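This part of the series indexes the flattened prompt/query tensor with cumulative
start locations derived from the per-sequence subquery lengths (see the
model_runner change above). A minimal sketch of that indexing in Python, with
made-up lengths purely for illustration:

import torch

# Hypothetical per-sequence query lengths; start_loc[i]:start_loc[i+1] then
# slices sequence i out of the flattened [num_tokens, ...] query tensor.
subquery_lens = torch.tensor([4, 6, 2], dtype=torch.long)
start_loc = torch.zeros(subquery_lens.shape[0] + 1, dtype=torch.long)
torch.cumsum(subquery_lens, dim=0, out=start_loc[1:])
print(start_loc.tolist())  # [0, 4, 10, 12]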
--- tests/models/test_models.py | 33 ++++++++++--------- .../layers/attention/ops/prefix_prefill.py | 2 +- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 7bb93a519bf0..28cc5275fe17 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -5,28 +5,28 @@ import pytest MODELS = [ - "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", - "mistralai/Mistral-7B-v0.1", - "Deci/DeciLM-7b", - "tiiuae/falcon-7b", + # "facebook/opt-125m", + # "meta-llama/Llama-2-7b-hf", + # "mistralai/Mistral-7B-v0.1", + # "Deci/DeciLM-7b", + # "tiiuae/falcon-7b", "gpt2", - "bigcode/tiny_starcoder_py", - "EleutherAI/gpt-j-6b", - "EleutherAI/pythia-70m", - "bigscience/bloom-560m", - "mosaicml/mpt-7b", - "microsoft/phi-2", - "stabilityai/stablelm-3b-4e1t", - "allenai/OLMo-1B", - "bigcode/starcoder2-3b", + # "bigcode/tiny_starcoder_py", + # "EleutherAI/gpt-j-6b", + # "EleutherAI/pythia-70m", + # "bigscience/bloom-560m", + # "mosaicml/mpt-7b", + # "microsoft/phi-2", + # "stabilityai/stablelm-3b-4e1t", + # "allenai/OLMo-1B", + # "bigcode/starcoder2-3b", ] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("enforce_eager", [False]) def test_models( hf_runner, vllm_runner, @@ -40,6 +40,9 @@ def test_models( hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model + import os + os.environ["ENABLE"] = "1" + vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/vllm/model_executor/layers/attention/ops/prefix_prefill.py b/vllm/model_executor/layers/attention/ops/prefix_prefill.py index 70f09224f1cf..541c0aae44c7 100644 --- a/vllm/model_executor/layers/attention/ops/prefix_prefill.py +++ b/vllm/model_executor/layers/attention/ops/prefix_prefill.py @@ -632,7 +632,7 @@ def context_attention_fwd(q, alibi_slopes=None): cap = torch.cuda.get_device_capability() - BLOCK = 128 if cap[0] >= 8 else 64 + BLOCK = 32 if cap[0] >= 8 else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv From 3cb8093fb74ef63f2f7f3865f254a8bd0a2b6605 Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 15 Mar 2024 17:09:32 -0700 Subject: [PATCH 45/88] ip addressing comments. --- tests/basic_correctness/test_cuda_graph.py | 41 +++++++++ tests/models/test_models.py | 5 -- tests/models/test_models_2.py | 87 ------------------- tests/spec_decode/test_multi_step_worker.py | 5 +- tests/worker/test_model_runner.py | 5 +- vllm/model_executor/input_metadata.py | 24 +++-- .../layers/attention/attention.py | 2 +- .../layers/attention/backends/flash_attn.py | 21 ++--- .../layers/attention/backends/xformers.py | 4 - 9 files changed, 67 insertions(+), 127 deletions(-) create mode 100644 tests/basic_correctness/test_cuda_graph.py delete mode 100644 tests/models/test_models_2.py diff --git a/tests/basic_correctness/test_cuda_graph.py b/tests/basic_correctness/test_cuda_graph.py new file mode 100644 index 000000000000..69ecb153eb47 --- /dev/null +++ b/tests/basic_correctness/test_cuda_graph.py @@ -0,0 +1,41 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling. + +Make sure both cuda graph & eager mode works. + +Run `pytest tests/models/test_models.py --forked`. 
+""" +import pytest + +MODELS = [ + "facebook/opt-125m", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("enforce_eager", [False, True]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + enforce_eager: bool, +) -> None: + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 28cc5275fe17..ce31288090e7 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -26,7 +26,6 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("enforce_eager", [False]) def test_models( hf_runner, vllm_runner, @@ -34,15 +33,11 @@ def test_models( model: str, dtype: str, max_tokens: int, - enforce_eager: bool, ) -> None: hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - import os - os.environ["ENABLE"] = "1" - vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/models/test_models_2.py b/tests/models/test_models_2.py deleted file mode 100644 index 9683352b712a..000000000000 --- a/tests/models/test_models_2.py +++ /dev/null @@ -1,87 +0,0 @@ -import gc - -import pytest -from tests.conftest import example_prompts -import torch - -from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel -from vllm import SamplingParams - -MODELS = [ - "JackFram/llama-68m", - # "facebook/opt-125m", -] - - -# SANG-TODO enforce_eager = True and chunked prefill currently doesn't work. -# TODO(sang): Add chunked prefill parameters. 
-# @pytest.mark.parametrize("model", MODELS) -# @pytest.mark.parametrize("dtype", ["half"]) -# @pytest.mark.parametrize("max_tokens", [128]) -# @pytest.mark.parametrize("max_chunked_prefill_len", [-1, 16, 64]) -# @pytest.mark.parametrize("max_num_prompt_seqs", [1, 2, 100]) -# @pytest.mark.parametrize("block_size", [32]) -# @pytest.mark.parametrize("tensor_parallel_size", [1, 2]) -# @pytest.mark.parametrize("enforce_eager", [True, False]) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("max_chunked_prefill_len", [-1]) -@pytest.mark.parametrize("tensor_parallel_size", [1]) -@pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("num", [3]) -def test_models( - example_prompts, - vllm_runner, - model: str, - dtype: str, - max_tokens: int, - max_chunked_prefill_len: int, - tensor_parallel_size: int, - enforce_eager: bool, - num, -) -> None: - """ verify the flash attention has the same output - as page attention """ - if torch.cuda.device_count() < tensor_parallel_size: - pytest.skip( - f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" - ) - - def cleanup(): - torch.backends.cuda.matmul.allow_tf32 = False - torch.set_default_dtype(torch.float32) - destroy_model_parallel() - gc.collect() - torch.cuda.empty_cache() - - def evaluate(init_llm): - llm = init_llm() - outputs = llm.generate_greedy(example_prompts, max_tokens=max_tokens) - token_ids_list = [] - output_str_list = [] - - for i in range(len(outputs)): - token_ids = outputs[i][0] - output_str = outputs[i][1] - token_ids_list.append(token_ids) - output_str_list.append(output_str) - del llm - cleanup() - return token_ids_list, output_str_list - - vllm_token_ids, vllm_str = evaluate(lambda: vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)) - import os - os.environ["ENABLE"] = "1" - chunked_prefill_token_ids, chunked_str = evaluate(lambda: vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager)) - - for i in range(len(vllm_token_ids)): - print(f"TEST {i}") - print(f"{len(vllm_token_ids[i])=} {vllm_token_ids[i]=}\n{vllm_str[i]=}") - print(f"{len(chunked_prefill_token_ids[i])=} {chunked_prefill_token_ids[i]=}\n{chunked_str[i]=}\n") - assert vllm_token_ids[i] == chunked_prefill_token_ids[i] - diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 9fb516faef5c..45b43ec59ee8 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -92,8 +92,8 @@ def test_same_output_for_single_step(): num_gpu_blocks, seed, ) - # multi_step_worker.model_runner = worker.model_runner - # multi_step_worker.cache_engine = worker.cache_engine + multi_step_worker.model_runner = worker.model_runner + multi_step_worker.cache_engine = worker.cache_engine num_steps = 1 @@ -418,4 +418,3 @@ def test_draft_proposals_mixed_k(): assert proposals.proposal_lens.tolist() == [ k for _ in range(expected_num_proposal_seqs - 1) ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k] ->>>>>>> main:tests/spec_decode/test_multi_step_worker.py diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 55a078230b46..fe59edc8050b 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -7,9 +7,8 @@ # Make sure the result is aligned. 
-def round_up_to_next_multiple_of_batch_size(n): - batch_size = _BATCH_SIZE_ALIGNMENT - return ((n + 7) // batch_size) * batch_size +def round_up_to_next_multiple_of_batch_size(batch_size): + return ((batch_size + _BATCH_SIZE_ALIGNMENT -1) // _BATCH_SIZE_ALIGNMENT) * _BATCH_SIZE_ALIGNMENT def test_prepare_prompt(): diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index d992f6b7d207..255ff32e572f 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -13,10 +13,11 @@ class InputMetadata: Args: prompt_lens: Lengths of prompts. - slot_mapping: The index of each token mapped into a physical block - in block tables. E.g., if block_size is 32, 35 means it is in - the block number 1, 3rd index. - num_prompt_tokens: The number of tokens in the prompts. This might + slot_mapping: The indices of the token slots that input tokens will be stored into. + E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the three tokens + are stored in the 3rd slot in block 2, 2nd slot in block 0, and 1st slot in block 1, + respectively. + num_prompt_tokens: The total number of tokens in the prompts. This might include padding. num_generation_tokens: The number of tokens in the generation sequences. This might include padding. @@ -50,13 +51,8 @@ def __init__( self.start_loc = start_loc self.max_context_len = max_context_len self.slot_mapping = slot_mapping - # Index: The batched sequence's index. - # Value: The length of attention context. - # NOTE(sang): When it is prefill/decoding, - # the definition is different. For prefill, - # it means the the length of KV that are cached - # excluding the new KVs. In decoding, this - # includes a new KV. + # [batch_size]. Each index means each sequence, and the value means the length of tokens stored in the kv cache. + # NOTE(sang): When it is prefill/decoding, the definition is different. For prefill, it means the the length of KV that are cached excluding the new KVs. In decoding, this includes a new KV. self.context_lens = context_lens self.block_tables = block_tables self.use_cuda_graph = use_cuda_graph @@ -78,6 +74,6 @@ def __repr__(self) -> str: f"slot_mapping={self.slot_mapping}, " f"context_lens={self.context_lens}, " f"block_tables={self.block_tables}, " - f"use_cuda_graph={self.use_cuda_graph}, " - f"kv_cache_dtype={self.kv_cache_dtype}) " - f"num_valid_tokens={self.num_valid_tokens}") + f"use_cuda_graph={self.use_cuda_graph} " + f"kv_cache_dtype={self.kv_cache_dtype} " + f"num_valid_tokens={self.num_valid_tokens})") diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 34a6dc2b1c0a..189200390658 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -19,7 +19,7 @@ class Attention(nn.Module): can either contain prompt tokens or generation tokens. 
- If the input tensors contain prompt tokens, the layout is as follows: + If the input tensors contain prompt tokens, the layout is as follows: |<---------------------- num_valid_tokens ---------------------->| |<--------------- num_prompt_tokens -------------->| |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->| diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index e203d2af17e6..ce22cbe90783 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -73,25 +73,26 @@ def forward( # If key_cache and value_cache are not provided, the new key and value # vectors will not be cached. This happens during the initial memory # profiling run. - PagedAttentionImpl.reshape_and_cache(key, value, key_cache, - value_cache, input_metadata) + if key_cache is not None and value_cache is not None: + PagedAttentionImpl.reshape_and_cache(key, value, key_cache, + value_cache, input_metadata) if input_metadata.is_prompt: # Prompt run. if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # normal attention - output = torch.empty_like(query) query = query.unflatten(0, (num_tokens, )) key = key.unflatten(0, (num_tokens, )) value = value.unflatten(0, (num_tokens, )) - output = flash_attn_func(query, - key, - value, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes) + output = flash_attn_func( + query, + key, + value, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes) else: # prefix-enabled attention output = PagedAttentionImpl.forward_prefix( diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention/backends/xformers.py index d1e264323963..68d9c68eabe0 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -143,10 +143,6 @@ def forward( query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) - else: - query = query.unflatten(0, (num_tokens)) - key = key.unflatten(0, (num_tokens)) - value = value.unflatten(0, (num_tokens)) out = xops.memory_efficient_attention_forward( query, From 5391129f2368ece0c7678b6454362211d0b0005b Mon Sep 17 00:00:00 2001 From: sang Date: Sun, 17 Mar 2024 20:46:29 -0700 Subject: [PATCH 46/88] Alibi slopes working now. 
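The xformers backend below builds one ALiBi bias per prompt: entry (i, j) of the
bias is slope * (j - i) for each head, and the causal mask is applied on top.
A rough sketch of that construction, illustrative only and not the exact code
in this patch:

import torch

def make_alibi_bias_sketch(alibi_slopes: torch.Tensor,
                           prompt_len: int) -> torch.Tensor:
    # Relative distance j - i between key position j and query position i.
    pos = torch.arange(prompt_len)
    rel = pos[None, :] - pos[:, None]         # [prompt_len, prompt_len]
    # One bias matrix per head, scaled by that head's slope; the attention
    # kernel additionally applies the lower-triangular (causal) mask.
    return alibi_slopes[:, None, None] * rel  # [num_heads, prompt_len, prompt_len]

bias = make_alibi_bias_sketch(torch.tensor([0.5, 0.25]), prompt_len=4)
print(bias.shape)  # torch.Size([2, 4, 4])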
--- tests/models/test_models.py | 6 +- tests/worker/test_model_runner.py | 3 +- vllm/model_executor/input_metadata.py | 17 ++- .../layers/attention/backends/flash_attn.py | 17 ++- .../layers/attention/backends/xformers.py | 132 ++++++++++++------ .../layers/attention/ops/paged_attn.py | 11 +- vllm/utils.py | 5 + vllm/worker/model_runner.py | 14 +- 8 files changed, 125 insertions(+), 80 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index ce31288090e7..cc7c6f5b1e41 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -10,11 +10,11 @@ # "mistralai/Mistral-7B-v0.1", # "Deci/DeciLM-7b", # "tiiuae/falcon-7b", - "gpt2", + # "gpt2", # "bigcode/tiny_starcoder_py", # "EleutherAI/gpt-j-6b", # "EleutherAI/pythia-70m", - # "bigscience/bloom-560m", + "bigscience/bloom-560m", # "mosaicml/mpt-7b", # "microsoft/phi-2", # "stabilityai/stablelm-3b-4e1t", @@ -38,7 +38,7 @@ def test_models( hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) + vllm_model = vllm_runner(model, dtype=dtype) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index fe59edc8050b..f26ad16fa41b 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -8,7 +8,8 @@ # Make sure the result is aligned. def round_up_to_next_multiple_of_batch_size(batch_size): - return ((batch_size + _BATCH_SIZE_ALIGNMENT -1) // _BATCH_SIZE_ALIGNMENT) * _BATCH_SIZE_ALIGNMENT + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT) * _BATCH_SIZE_ALIGNMENT def test_prepare_prompt(): diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 255ff32e572f..64210b50f9c2 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,4 +1,5 @@ -from typing import Optional +from typing import Optional, List +from xformers.ops.fmha.attn_bias import AttentionBias import torch @@ -12,7 +13,7 @@ class InputMetadata: updated from `CUDAGraphRunner.forward` API. Args: - prompt_lens: Lengths of prompts. + prompt_lens: Lengths of prompts per sequence. slot_mapping: The indices of the token slots that input tokens will be stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot in block 0, and 1st slot in block 1, @@ -32,7 +33,7 @@ def __init__( self, is_prompt: bool, slot_mapping: torch.Tensor, - prompt_lens: Optional[torch.Tensor], + prompt_lens: Optional[List], num_prompt_tokens: int, num_generation_tokens: int, max_seq_len: Optional[int], @@ -59,11 +60,19 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. # FIXME(woosuk): This is a hack. - self.attn_bias = None + self.attn_bias: Optional[List[AttentionBias]] = None # Number of valid tokens. It includes paddings. # See attention.py for precise definition. 
self.num_valid_tokens = slot_mapping.shape[0] + self.prompt_lens_tensor = None + if self.prompt_lens is not None: + self.prompt_lens_tensor = torch.tensor(self.prompt_lens, + dtype=torch.long, + device=slot_mapping.device) def __repr__(self) -> str: return ("InputMetadata(" diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index ce22cbe90783..0e3848233b96 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -75,7 +75,7 @@ def forward( # profiling run. if key_cache is not None and value_cache is not None: PagedAttentionImpl.reshape_and_cache(key, value, key_cache, - value_cache, input_metadata) + value_cache, input_metadata) if input_metadata.is_prompt: # Prompt run. @@ -85,14 +85,13 @@ def forward( query = query.unflatten(0, (num_tokens, )) key = key.unflatten(0, (num_tokens, )) value = value.unflatten(0, (num_tokens, )) - output = flash_attn_func( - query, - key, - value, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes) + output = flash_attn_func(query, + key, + value, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes) else: # prefix-enabled attention output = PagedAttentionImpl.forward_prefix( diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention/backends/xformers.py index 68d9c68eabe0..305501e7d4de 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -10,7 +10,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention.ops.paged_attn import ( PagedAttentionImpl) -from vllm.utils import is_hip +from vllm.utils import is_hip, _get_aligned_size class XFormersBackend: @@ -55,7 +55,7 @@ def forward( """Forward pass with xFormers and PagedAttention. Args: - query: shape = [num_tokens num_heads * head_size] + query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, @@ -67,11 +67,9 @@ def forward( shape = [num_tokens, num_heads * head_size] """ num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) - output = torch.empty_like(query) # Reshape the keys and values and store them in the cache. # If key_cache and value_cache are not provided, the new key and value @@ -83,10 +81,10 @@ def forward( if input_metadata.is_prompt: # Prompt run. - # Unless there's a prefix, context lens is all 0 for prefill. + # key_cache and value_cache is None when it is a profiling run. + # block tables are empty if the prompt has never been computed. if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): - # normal attention if self.num_kv_heads != self.num_heads: # As of Nov 2023, xformers only supports MHA. For MQA/GQA, # project the key and value tensors to the desired number of @@ -105,22 +103,6 @@ def forward( self.num_queries_per_kv, value.shape[-1]) - # Set attention bias if not provided. 
This typically happens at - # the very attention layer of every iteration. - # FIXME(woosuk): This is a hack. - if input_metadata.attn_bias is None: - if self.alibi_slopes is None: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - input_metadata.prompt_lens.tolist()) - if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention( - self.sliding_window) - input_metadata.attn_bias = attn_bias - else: - input_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, query.dtype, - input_metadata) - if self.use_ref_attention: output = _ref_masked_attention( query, @@ -137,28 +119,10 @@ def forward( # Use reshape instead. return output.reshape(num_tokens, hidden_size) - # TODO(woosuk): Too many view operations. Let's try to reduce - # them in the future for code readability. - if self.alibi_slopes is None: - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - - out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (is_hip()) else None, - ) - output = out.view_as(query) - + output = self._multi_query_kv_attention( + query, key, value, input_metadata) else: # prefix-enabled attention - print("SANG-TODO prefix") output = PagedAttentionImpl.forward_prefix( query, key, @@ -183,6 +147,81 @@ def forward( # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) + def _multi_query_kv_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + """Attention for 1D query of multiple prompts. Multiple prompt + tokens are flattened in to `query` input. + + Args: + output: shape = [num_prompt_tokens, num_heads, head_size] + query: shape = [num_prompt_tokens, num_heads, head_size] + key: shape = [num_prompt_tokens, num_kv_heads, head_size] + value: shape = [num_prompt_tokens, num_kv_heads, head_size] + input_metadata: metadata for paged attention. + """ + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + input_metadata.prompt_lens) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = [attn_bias] + else: + input_metadata.attn_bias = _make_alibi_bias( + self.alibi_slopes, self.num_kv_heads, query.dtype, + input_metadata) + + op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if ( + is_hip()) else None + # No alibi slopes. + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. + if self.alibi_slopes is None: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias[0], + p=0.0, + scale=self.scale, + op=op) + + return out.view_as(query) + + # Attention with alibi slopes. + # FIXME(woosuk): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. 
+ print("SANG-TODO alibi slopes.") + output = torch.empty_like(query) + start = 0 + for i, prompt_len in enumerate(input_metadata.prompt_lens): + end = start + prompt_len + out = xops.memory_efficient_attention_forward( + query[None, start:end], + key[None, start:end], + value[None, start:end], + attn_bias=input_metadata.attn_bias[i], + p=0.0, + scale=self.scale, + op=op) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out.squeeze(0)) + start += prompt_len + return output + def _make_alibi_bias( alibi_slopes: torch.Tensor, @@ -190,6 +229,7 @@ def _make_alibi_bias( dtype: torch.dtype, input_metadata: InputMetadata, ) -> LowerTriangularMaskWithTensorBias: + attn_biases = [] for prompt_len in input_metadata.prompt_lens: bias = torch.arange(prompt_len, dtype=dtype) # NOTE(zhuohan): HF uses @@ -197,9 +237,10 @@ def _make_alibi_bias( # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. + # Calculate a matrix where each element represents ith element - jth element. bias = bias[None, :] - bias[:, None] - padded_len = (prompt_len + 7) // 8 * 8 + padded_len = _get_aligned_size(prompt_len, 8) num_heads = alibi_slopes.shape[0] bias = torch.empty( 1, # batch size @@ -212,8 +253,9 @@ def _make_alibi_bias( bias.mul_(alibi_slopes[:, None, None]) if num_heads != num_kv_heads: bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) - attn_bias = LowerTriangularMaskWithTensorBias(bias) - return attn_bias + attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) + + return attn_biases def _check_use_ref_attention() -> bool: diff --git a/vllm/model_executor/layers/attention/ops/paged_attn.py b/vllm/model_executor/layers/attention/ops/paged_attn.py index 32d18b0be6f6..d5397bf1a78c 100644 --- a/vllm/model_executor/layers/attention/ops/paged_attn.py +++ b/vllm/model_executor/layers/attention/ops/paged_attn.py @@ -121,15 +121,6 @@ def forward_prefix( alibi_slopes: Optional[torch.Tensor], ) -> torch.Tensor: output = torch.empty_like(query) - print("SANG-TODO prefix attention!") - print(f"{input_metadata.block_tables=}") - print(f"{input_metadata.start_loc=}") - print(f"{input_metadata.prompt_lens=}") - print(f"{input_metadata.context_lens=}") - print(f"{input_metadata.max_seq_len=}") - print(f"{query.size()=}") - print(f"{key.size()=}") - print(f"{value.size()=}") context_attention_fwd( query, key, @@ -139,7 +130,7 @@ def forward_prefix( value_cache, input_metadata.block_tables, # [BS, max_block_per_request] input_metadata.start_loc, - input_metadata.prompt_lens, + input_metadata.prompt_lens_tensor, input_metadata.context_lens, input_metadata.max_seq_len, alibi_slopes, diff --git a/vllm/utils.py b/vllm/utils.py index fe6fd27962cd..2a128ba87664 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -312,6 +312,11 @@ def create_kv_caches_with_random( return key_caches, value_caches +def _get_aligned_size(batch_size: int, alignment: int) -> int: + """Returns the padded batch based on an alignment.""" + return ((batch_size + alignment - 1) // alignment * alignment) + + class measure_cuda_memory: def __init__(self, device=None): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6b6549619b06..adf1e91ffa49 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -21,7 +21,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest -from vllm.utils import in_wsl, 
measure_cuda_memory +from vllm.utils import in_wsl, measure_cuda_memory, _get_aligned_size logger = init_logger(__name__) @@ -258,13 +258,12 @@ def _prepare_prompt( dtype=torch.int, device=self.device, ) - prompt_lens_tensor = torch.tensor(prompt_lens, - dtype=torch.long, - device=self.device) # Cumulative index of each prompt. [prompt_lens + 1] # [0, 0+1th, 0+1th+2nd, ...] - subquery_lens_tensor = torch.tensor(subquery_lens, dtype=torch.long, device=self.device) + subquery_lens_tensor = torch.tensor(subquery_lens, + dtype=torch.long, + device=self.device) start_loc_tensor = torch.zeros(subquery_lens_tensor.shape[0], dtype=torch.long, device=self.device) @@ -277,7 +276,7 @@ def _prepare_prompt( input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping, - prompt_lens=prompt_lens_tensor, + prompt_lens=prompt_lens, num_prompt_tokens=num_prompt_tokens, num_generation_tokens=0, max_seq_len=max_prompt_len, @@ -961,8 +960,7 @@ def _get_graph_batch_size(batch_size: int) -> int: elif batch_size <= 4: return 4 else: - return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // - _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) + return _get_aligned_size(batch_size, _BATCH_SIZE_ALIGNMENT) def _async_h2d( From fe344f6af231a2546ea003f5fc48a1e539ed72d8 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 02:20:56 -0700 Subject: [PATCH 47/88] add new fieflds --- vllm/model_executor/input_metadata.py | 74 ++++++++++++------- .../layers/attention/attention.py | 14 ++-- .../layers/attention/backends/flash_attn.py | 4 +- .../layers/attention/backends/xformers.py | 1 - .../layers/attention/ops/paged_attn.py | 7 +- vllm/worker/model_runner.py | 64 ++++++++++------ 6 files changed, 100 insertions(+), 64 deletions(-) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index db1005732a81..1d1764332b98 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -10,44 +10,62 @@ class InputMetadata: """Metadata for input sequences. Used in PagedAttention. NOTE: Any python object stored here is not updated when it is - cuda-graph replays. If you have values that need to be changed + cuda-graph replayed. If you have values that need to be changed dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. - - Args: - prompt_lens: Lengths of prompts per sequence. - slot_mapping: The indices of the token slots that input tokens will be stored into. - E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the three tokens - are stored in the 3rd slot in block 2, 2nd slot in block 0, and 1st slot in block 1, - respectively. - num_prompt_tokens: The total number of tokens in the prompts. This might - include padding. - num_generation_tokens: The number of tokens in the generation sequences. - This might include padding. - max_context_len: The maximum context length. - context_lens: the length of attention context for each sequence. - I.e., the number of tokens that have attended so far. - block_tables: The block tables. (Seq id -> list of physical block) - kv_cache_dtype: Data type to store kv cache. """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. is_prompt: bool + # (num_tokens,). The indices of the token slots that input tokens will be stored into. 
+ # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the three tokens + # are stored in the 3rd slot in block 2, 2nd slot in block 0, and 1st slot in block 1, + # respectively. slot_mapping: torch.Tensor + # (batch_size,). The prompt length per sequence. None if it is a decoding. prompt_lens: Optional[List] + # prompt_lens stored as a tensor. + prompt_lens_tensor: Optional[torch.Tensor] + # The number of prompt tokens. Doesn't include padding. num_prompt_tokens: int + # The number of generation tokens. Doesn't include padding. num_generation_tokens: int - max_seq_len: Optional[int] - start_loc: Optional[torch.Tensor] + + """ + Definition of context_len, subquery_len, and seqlen. + |---------- N-1 iteration --------| + |---------------- N iteration ---------------------| + |- tokenA -|......................|-- newTokens ---| + |---------- context_len ----------| + |-------------------- seqlen ----------------------| + |- subquery_len -| + + """ + + # Maximum sequence length in the batch. + max_subquery_len: Optional[int] + # Maximum context length in the batch. max_context_len: Optional[int] - # [batch_size]. Each index means each sequence, and the value means the length of tokens stored in the kv cache. - # NOTE(sang): When it is prefill/decoding, the definition is different. For prefill, it means the the length of KV that are cached excluding the new KVs. In decoding, this includes a new KV. + # (batch_size + 1,). The cumulative subquery lengths of the sequences in the batch, used to index into q. + # E.g., if the subquery length is [4, 6], it is [0, 4, 10]. + subquery_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in the batch, used to index into k. + # E.g., if the sequence length is [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,). The length of context (tokens stored in KV cache) per + # sequence. It doesn't include the length of new tokens. context_lens: Optional[torch.Tensor] + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # The first dimension is padded if it is cuda-graph captured. block_tables: Optional[torch.Tensor] + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. use_cuda_graph: bool kv_cache_dtype: str - # Fields below are initialiezd in post init. - prompt_lens_tensor: Optional[torch.Tensor] = None - def __post_init__(self): # Set during the execution of the first attention op. # It is a list because it is needed to set per prompt @@ -55,7 +73,7 @@ def __post_init__(self): # from xformer API. # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None - if self.prompt_lens is not None: - self.prompt_lens_tensor = torch.tensor(self.prompt_lens, - dtype=torch.long, - device=self.slot_mapping.device) + + # Cuda graph is only used for decoding now. 
+ if self.use_cuda_graph: + assert self.num_prompt_tokens == 0 diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 189200390658..640a47ad3fdc 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger from vllm.model_executor.input_metadata import InputMetadata -# from vllm.utils import is_hip +from vllm.utils import is_hip logger = init_logger(__name__) @@ -18,16 +18,18 @@ class Attention(nn.Module): This class takes query, key, and value tensors as input. The input tensors can either contain prompt tokens or generation tokens. - If the input tensors contain prompt tokens, the layout is as follows: |<---------------------- num_valid_tokens ---------------------->| |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->| Otherwise, the layout is as follows: - |<------------------ num_valid_tokens ------------------->| - |<------- num_generation_tokens (M) ------->| - |<--generation_0-->|...|<--generation_M-1-->|<--padding-->| + |<------------------ num_valid_tokens -------------------------->| + |<---------- num_generation_tokens (M) ----------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Both prompt and generation can contain padding for cuda-graph (currently + decoding only) or to be aligned with length 8 (so that it can utilize tensor cores). The prompts might have different lengths, while the generation tokens always have length 1. The paddings are appended to make the input length a multiple diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index 0e3848233b96..a610188a9de2 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -82,9 +82,7 @@ def forward( if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # normal attention - query = query.unflatten(0, (num_tokens, )) - key = key.unflatten(0, (num_tokens, )) - value = value.unflatten(0, (num_tokens, )) + assert False output = flash_attn_func(query, key, value, diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention/backends/xformers.py index 305501e7d4de..444717229291 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -204,7 +204,6 @@ def _multi_query_kv_attention( # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. 
- print("SANG-TODO alibi slopes.") output = torch.empty_like(query) start = 0 for i, prompt_len in enumerate(input_metadata.prompt_lens): diff --git a/vllm/model_executor/layers/attention/ops/paged_attn.py b/vllm/model_executor/layers/attention/ops/paged_attn.py index d5397bf1a78c..a9623d95e39c 100644 --- a/vllm/model_executor/layers/attention/ops/paged_attn.py +++ b/vllm/model_executor/layers/attention/ops/paged_attn.py @@ -128,11 +128,12 @@ def forward_prefix( output, key_cache, value_cache, - input_metadata.block_tables, # [BS, max_block_per_request] - input_metadata.start_loc, + input_metadata.block_tables, + # subquery_start_loc is (batch_size + 1,) + input_metadata.subquery_start_loc[:-1], input_metadata.prompt_lens_tensor, input_metadata.context_lens, - input_metadata.max_seq_len, + input_metadata.max_subquery_len, alibi_slopes, ) return output diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2b7ebaaa15f9..6090e2d96a3b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -226,12 +226,10 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - max_prompt_len = max(subquery_lens) + max_subquery_len = max(subquery_lens) num_prompt_tokens = len(input_tokens) - assert max_prompt_len > 0 + assert max_subquery_len > 0 - # Pad tokens to better utilize tensor cores although - # cuda graph is not enabled. input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, pad=0, dtype=torch.long, @@ -260,29 +258,44 @@ def _prepare_prompt( device=self.device, ) - # Cumulative index of each prompt. [prompt_lens + 1] - # [0, 0+1th, 0+1th+2nd, ...] + # Query length can be shorter than key (i.e., prompt) when prefill + # is chunked or prefix cached. subquery_lens_tensor = torch.tensor(subquery_lens, dtype=torch.long, device=self.device) - start_loc_tensor = torch.zeros(subquery_lens_tensor.shape[0], + subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1, + dtype=torch.long, + device=self.device) + + seq_tensor = torch.add(subquery_lens_tensor, context_lens_tensor) + prompt_lens_tensor = torch.tensor(prompt_lens, + dtype=torch.long, + device=self.device) + seq_start_loc = torch.zeros(seq_tensor.shape[0] + 1, dtype=torch.long, device=self.device) - torch.cumsum(subquery_lens_tensor[:-1], + torch.cumsum(subquery_lens_tensor, dim=0, - dtype=start_loc_tensor.dtype, - out=start_loc_tensor[1:]) + dtype=subquery_start_loc.dtype, + out=subquery_start_loc[1:]) + + torch.cumsum(seq_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping, prompt_lens=prompt_lens, + prompt_lens_tensor=prompt_lens_tensor, num_prompt_tokens=num_prompt_tokens, num_generation_tokens=0, - max_seq_len=max_prompt_len, - start_loc=start_loc_tensor, + max_subquery_len=max_subquery_len, max_context_len=None, + subquery_start_loc=subquery_start_loc, + seq_start_loc=seq_start_loc, context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -416,11 +429,13 @@ def _prepare_decode( is_prompt=False, slot_mapping=slot_mapping, prompt_lens=None, + prompt_lens_tensor=None, num_prompt_tokens=0, num_generation_tokens=len(input_tokens), - max_seq_len=None, - start_loc=None, + max_subquery_len=None, max_context_len=max_context_len, + subquery_start_loc=None, + seq_start_loc=None, context_lens=context_lens, block_tables=block_tables, use_cuda_graph=use_captured_graph, @@ -696,15 +711,15 @@ def list_loras(self) -> Set[int]: def 
capture_model(self, kv_caches: List[KVCache]) -> None: """Cuda graph capture a model. - Note that cuda graph performance gain is negligible if number - of batched tokens are less than 200. And since Cuda graph + Note that CUDA graph's performance gain is negligible if number + of batched tokens are larger than 200. And since CUDA graph requires fixed sized tensors, supporting large/variable batch size requires high GPU memory overhead. Thus, vLLM only captures decoding requests. Mixed batch (chunked prefill + decoding) or prefill requests are not captured. - this API assumes the shape is N batches of tokens flattened to 1D - tensor, where is token's seqlen is 1. + Since it is used for decoding-only, it assumes there's only 1 token + per sequence in the batch. """ # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never # deleted before the CUDA graphs. @@ -728,7 +743,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) - context_lens = torch.zeros(max_batch_size, dtype=torch.int32).cuda() + context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() graph_batch_size = _get_graph_batch_size( @@ -752,11 +767,13 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: is_prompt=False, slot_mapping=slot_mapping[:batch_size], prompt_lens=None, + prompt_lens_tensor=None, num_prompt_tokens=0, - num_generation_tokens=0, - max_seq_len=None, - start_loc=None, + num_generation_tokens=batch_size, + max_subquery_len=None, max_context_len=self.max_context_len_to_capture, + subquery_start_loc=None, + seq_start_loc=None, context_lens=context_lens[:batch_size], block_tables=block_tables[:batch_size], use_cuda_graph=True, @@ -908,6 +925,7 @@ def _make_tensor_with_pad_for_alignment( device: Optional[Union[str, torch.device]], ) -> torch.Tensor: """Create a tensor of a given list x with padding. + It adds paddings to align with graph batch size. See _get_graph_batch_size for more details. """ @@ -930,7 +948,7 @@ def _make_tensor_with_pad( def _get_graph_batch_size(batch_size: int) -> int: """Returns the padded batch size given actual batch size. - + Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... 
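+    For example, assuming _BATCH_SIZE_ALIGNMENT is 8, a batch of 5 is padded
+    to 8 and a batch of 9 is padded to 16 (values shown for illustration).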
""" From e619c4ec2a0907e7d6cc835b7b502e778f7c718d Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 02:48:52 -0700 Subject: [PATCH 48/88] Flash attn works now --- tests/models/test_models.py | 2 +- vllm/model_executor/input_metadata.py | 2 + .../layers/attention/attention.py | 12 ++-- .../layers/attention/backends/flash_attn.py | 24 +++++--- vllm/worker/model_runner.py | 60 +++++++++---------- 5 files changed, 51 insertions(+), 49 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index cc7c6f5b1e41..070eb7017dd5 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -24,7 +24,7 @@ @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) def test_models( hf_runner, diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 1d1764332b98..70841974ed42 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -46,6 +46,8 @@ class InputMetadata: max_subquery_len: Optional[int] # Maximum context length in the batch. max_context_len: Optional[int] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in the batch, used to index into q. # E.g., if the subquery length is [4, 6], it is [0, 4, 10]. subquery_start_loc: Optional[torch.Tensor] diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 640a47ad3fdc..8f7c6af9b6f2 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -19,17 +19,15 @@ class Attention(nn.Module): can either contain prompt tokens or generation tokens. If the input tensors contain prompt tokens, the layout is as follows: - |<---------------------- num_valid_tokens ---------------------->| |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| Otherwise, the layout is as follows: - |<------------------ num_valid_tokens -------------------------->| - |<---------- num_generation_tokens (M) ----------->| + |<------------------ num_generation_tokens (M) ----------------->| |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| - Both prompt and generation can contain padding for cuda-graph (currently - decoding only) or to be aligned with length 8 (so that it can utilize tensor cores). + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. The prompts might have different lengths, while the generation tokens always have length 1. 
The paddings are appended to make the input length a multiple @@ -52,7 +50,7 @@ def __init__( sliding_window: Optional[int] = None, ) -> None: super().__init__() - if False and _use_flash_attn(): + if _use_flash_attn(): from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501 self.backend = FlashAttentionBackend(num_heads, head_size, scale, num_kv_heads, alibi_slopes, diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index a610188a9de2..8a3d3dc66a3a 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -1,7 +1,7 @@ """Attention layer with Flash and PagedAttention.""" from typing import List, Optional -from flash_attn import flash_attn_func +from flash_attn import flash_attn_func, flash_attn_varlen_func import torch from vllm.model_executor.input_metadata import InputMetadata @@ -77,19 +77,25 @@ def forward( PagedAttentionImpl.reshape_and_cache(key, value, key_cache, value_cache, input_metadata) + if input_metadata.is_prompt: # Prompt run. if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # normal attention - assert False - output = flash_attn_func(query, - key, - value, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes) + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=input_metadata.subquery_start_loc, + cu_seqlens_k=input_metadata.seq_start_loc, + max_seqlen_q=input_metadata.max_subquery_len, + max_seqlen_k=input_metadata.max_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) else: # prefix-enabled attention output = PagedAttentionImpl.forward_prefix( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6090e2d96a3b..f194e3678913 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -123,7 +123,7 @@ def set_block_size(self, block_size: int) -> None: (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), dtype=np.int32) - def get_max_block_per_batch(self): + def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_context_len_to_capture + block_size - 1) // block_size @@ -227,19 +227,14 @@ def _prepare_prompt( slot_mapping.append(slot) max_subquery_len = max(subquery_lens) + max_seq_len = max(prompt_lens) num_prompt_tokens = len(input_tokens) assert max_subquery_len > 0 - input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, - pad=0, - dtype=torch.long, - device=self.device) - input_positions = _make_tensor_with_pad_for_alignment( + input_tokens = _make_tensor(input_tokens, pad=0, dtype=torch.long, device=self.device) + input_positions = _make_tensor( input_positions, pad=0, dtype=torch.long, device=self.device) - slot_mapping = _make_tensor_with_pad_for_alignment(slot_mapping, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) + slot_mapping = _make_tensor(slot_mapping, pad=_PAD_SLOT_ID, dtype=torch.long, device=self.device) lora_index_mapping = _pad_to_alignment(lora_index_mapping, _get_graph_batch_size( len(lora_index_mapping)), @@ -264,15 +259,14 @@ def _prepare_prompt( dtype=torch.long, device=self.device) subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1, - dtype=torch.long, + dtype=torch.int32, device=self.device) - seq_tensor = 
torch.add(subquery_lens_tensor, context_lens_tensor) prompt_lens_tensor = torch.tensor(prompt_lens, dtype=torch.long, device=self.device) - seq_start_loc = torch.zeros(seq_tensor.shape[0] + 1, - dtype=torch.long, + seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, + dtype=torch.int32, device=self.device) torch.cumsum(subquery_lens_tensor, @@ -280,7 +274,7 @@ def _prepare_prompt( dtype=subquery_start_loc.dtype, out=subquery_start_loc[1:]) - torch.cumsum(seq_tensor, + torch.cumsum(prompt_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) @@ -294,6 +288,7 @@ def _prepare_prompt( num_generation_tokens=0, max_subquery_len=max_subquery_len, max_context_len=None, + max_seq_len=max_seq_len, subquery_start_loc=subquery_start_loc, seq_start_loc=seq_start_loc, context_lens=context_lens_tensor, @@ -380,16 +375,10 @@ def _prepare_decode( # Pad tokens to better utilize tensor cores although # cuda graph is not enabled. - input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, - pad=0, - dtype=torch.long, - device=self.device) - input_positions = _make_tensor_with_pad_for_alignment( + input_tokens = _make_tensor(input_tokens, pad=0, dtype=torch.long, device=self.device) + input_positions = _make_tensor( input_positions, pad=0, dtype=torch.long, device=self.device) - slot_mapping = _make_tensor_with_pad_for_alignment(slot_mapping, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) + slot_mapping = _make_tensor(slot_mapping, pad=_PAD_SLOT_ID, dtype=torch.long, device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -434,6 +423,7 @@ def _prepare_decode( num_generation_tokens=len(input_tokens), max_subquery_len=None, max_context_len=max_context_len, + max_seq_len=None, subquery_start_loc=None, seq_start_loc=None, context_lens=context_lens, @@ -772,6 +762,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: num_generation_tokens=batch_size, max_subquery_len=None, max_context_len=self.max_context_len_to_capture, + max_seq_len=None, subquery_start_loc=None, seq_start_loc=None, context_lens=context_lens[:batch_size], @@ -918,21 +909,22 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: return x + [pad] * (max_len - len(x)) -def _make_tensor_with_pad_for_alignment( +def _make_tensor( x: List[int], pad: int, dtype: torch.dtype, device: Optional[Union[str, torch.device]], + align: bool = False, ) -> torch.Tensor: - """Create a tensor of a given list x with padding. + """Create a tensor of a given list. - It adds paddings to align with graph batch size. See - _get_graph_batch_size for more details. + If `align` is True, it creates a tensor with a padding with a given `pad`. """ - batch_size = len(x) - batch_size = _get_graph_batch_size(batch_size) - padded_x = _pad_to_alignment(x, batch_size, pad) - return torch.tensor(padded_x, dtype=dtype, device=device) + if align: + batch_size = len(x) + batch_size = _get_graph_batch_size(batch_size) + x = _pad_to_alignment(x, batch_size, pad) + return torch.tensor(x, dtype=dtype, device=device) def _make_tensor_with_pad( @@ -942,6 +934,10 @@ def _make_tensor_with_pad( dtype: torch.dtype, device: Optional[Union[str, torch.device]], ) -> torch.Tensor: + """Make a padded tensor of a 2D inputs. + + The padding is applied to the end of each inner list until it reaches `max_len`. 
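    For example (illustrative values), with max_len=3 and pad=0:

        _make_tensor_with_pad([[1, 2], [3]], max_len=3, pad=0,
                              dtype=torch.long, device="cpu")
        # -> tensor([[1, 2, 0],
        #            [3, 0, 0]])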
+ """ padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] return torch.tensor(padded_x, dtype=dtype, device=device) From 9c86aa34d96034c565d09b080c356be907a9421c Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 02:59:58 -0700 Subject: [PATCH 49/88] Linting --- .../test_basic_correctness.py | 4 +- tests/basic_correctness/test_cuda_graph.py | 41 ------------- tests/models/test_models.py | 30 +++++----- tests/prompts/example.txt | 2 +- tests/worker/test_model_runner.py | 39 ++++++------ vllm/model_executor/input_metadata.py | 19 +++--- .../layers/attention/attention.py | 5 +- .../layers/attention/backends/flash_attn.py | 3 +- .../layers/attention/backends/xformers.py | 3 +- .../layers/attention/ops/paged_attn.py | 2 +- .../layers/attention/ops/prefix_prefill.py | 2 +- vllm/worker/model_runner.py | 59 +++++++++++-------- 12 files changed, 87 insertions(+), 122 deletions(-) delete mode 100644 tests/basic_correctness/test_cuda_graph.py diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index fe67e0f2f480..da0176306b4e 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -13,6 +13,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("enforce_eager", [False, True]) def test_models( hf_runner, vllm_runner, @@ -20,12 +21,13 @@ def test_models( model: str, dtype: str, max_tokens: int, + enforce_eager: bool, ) -> None: hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype) + vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/basic_correctness/test_cuda_graph.py b/tests/basic_correctness/test_cuda_graph.py deleted file mode 100644 index 69ecb153eb47..000000000000 --- a/tests/basic_correctness/test_cuda_graph.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Compare the outputs of HF and vLLM when using greedy sampling. - -Make sure both cuda graph & eager mode works. - -Run `pytest tests/models/test_models.py --forked`. 
-""" -import pytest - -MODELS = [ - "facebook/opt-125m", -] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("enforce_eager", [False, True]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - enforce_eager: bool, -) -> None: - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - del hf_model - - vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - del vllm_model - - for i in range(len(example_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 070eb7017dd5..fb567e837d28 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -5,26 +5,26 @@ import pytest MODELS = [ - # "facebook/opt-125m", - # "meta-llama/Llama-2-7b-hf", - # "mistralai/Mistral-7B-v0.1", - # "Deci/DeciLM-7b", - # "tiiuae/falcon-7b", - # "gpt2", - # "bigcode/tiny_starcoder_py", - # "EleutherAI/gpt-j-6b", - # "EleutherAI/pythia-70m", + "facebook/opt-125m", + "meta-llama/Llama-2-7b-hf", + "mistralai/Mistral-7B-v0.1", + "Deci/DeciLM-7b", + "tiiuae/falcon-7b", + "gpt2", + "bigcode/tiny_starcoder_py", + "EleutherAI/gpt-j-6b", + "EleutherAI/pythia-70m", "bigscience/bloom-560m", - # "mosaicml/mpt-7b", - # "microsoft/phi-2", - # "stabilityai/stablelm-3b-4e1t", - # "allenai/OLMo-1B", - # "bigcode/starcoder2-3b", + "mosaicml/mpt-7b", + "microsoft/phi-2", + "stabilityai/stablelm-3b-4e1t", + "allenai/OLMo-1B", + "bigcode/starcoder2-3b", ] @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [128]) def test_models( hf_runner, diff --git a/tests/prompts/example.txt b/tests/prompts/example.txt index cef4d1d76873..e1b97bc6eee7 100644 --- a/tests/prompts/example.txt +++ b/tests/prompts/example.txt @@ -5,4 +5,4 @@ Describe the basic components of a neural network and how it can be trained. Write a short story about a robot that dreams for the first time. Analyze the impact of the COVID-19 pandemic on global economic structures and future business models. Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies. -Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' \ No newline at end of file +Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index f26ad16fa41b..af921b00f7df 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -4,12 +4,11 @@ from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT from vllm.config import ModelConfig +from vllm.utils import _get_aligned_size -# Make sure the result is aligned. 
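The replacement helper introduced here rounds a size up to a multiple of an
alignment; a minimal standalone sketch of that behavior (the formula mirrors
the one used elsewhere in this series, values are illustrative):

    def _get_aligned_size(batch_size: int, alignment: int) -> int:
        return (batch_size + alignment - 1) // alignment * alignment

    assert _get_aligned_size(5, 8) == 8
    assert _get_aligned_size(16, 8) == 16
    assert _get_aligned_size(17, 8) == 24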
-def round_up_to_next_multiple_of_batch_size(batch_size): - return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // - _BATCH_SIZE_ALIGNMENT) * _BATCH_SIZE_ALIGNMENT +def get_aligned_size(batch_size): + return _get_aligned_size(batch_size, _BATCH_SIZE_ALIGNMENT) def test_prepare_prompt(): @@ -40,8 +39,8 @@ def test_prepare_prompt(): expected_selected_token_indices.append(selected_token_start_idx + prompt_len - 1) selected_token_start_idx += prompt_len - input_tokens, input_positions, input_metadata, return_prompt_lens, _, _, _, _ = ( - model_runner._prepare_prompt(seq_group_metadata_list)) + (input_tokens, input_positions, input_metadata, return_prompt_lens, _, _, + _, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens # Verify input metadata is correct for prompts. @@ -79,22 +78,18 @@ def test_prepare_prompt(): # Cuda graph should not be used for prerill. assert input_metadata.use_cuda_graph is False assert input_metadata.kv_cache_dtype == "auto" - assert input_metadata.num_valid_tokens == round_up_to_next_multiple_of_batch_size( + assert input_metadata.num_valid_tokens == _get_aligned_size( sum(prompt_lens)) - assert input_tokens.shape == (round_up_to_next_multiple_of_batch_size( - sum(prompt_lens)), ) - assert input_positions.shape == (round_up_to_next_multiple_of_batch_size( - sum(prompt_lens)), ) + assert input_tokens.shape == (get_aligned_size(sum(prompt_lens)), ) + assert input_positions.shape == (get_aligned_size(sum(prompt_lens)), ) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - assert input_tokens.shape == (round_up_to_next_multiple_of_batch_size( - sum(prompt_lens)), ) - assert input_positions.shape == (round_up_to_next_multiple_of_batch_size( - sum(prompt_lens)), ) + assert input_tokens.shape == (get_aligned_size(sum(prompt_lens)), ) + assert input_positions.shape == (get_aligned_size(sum(prompt_lens)), ) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, @@ -126,7 +121,7 @@ def test_prepare_decode_cuda_graph(): model_runner.set_block_size(16) # Make sure the result is aligned. - def round_up_to_next_multiple_of_batch_size(n): + def get_aligned_size(n): batch_size = _BATCH_SIZE_ALIGNMENT return ((n + 7) // batch_size) * batch_size @@ -155,8 +150,8 @@ def round_up_to_next_multiple_of_batch_size(n): assert input_metadata.is_prompt is False assert input_metadata.prompt_lens is None assert input_metadata.num_prompt_tokens == 0 - assert input_metadata.num_generation_tokens == ( - round_up_to_next_multiple_of_batch_size(len(seq_group_metadata_list))) + assert input_metadata.num_generation_tokens == (get_aligned_size( + len(seq_group_metadata_list))) assert input_metadata.max_seq_len is None assert input_metadata.start_loc is None assert input_metadata.max_context_len == max(prompt_lens) @@ -174,12 +169,12 @@ def round_up_to_next_multiple_of_batch_size(n): # Cuda graph should not be used for prerill. 
assert input_metadata.use_cuda_graph is True assert input_metadata.kv_cache_dtype == "auto" - assert input_metadata.num_valid_tokens == ( - round_up_to_next_multiple_of_batch_size(len(seq_group_metadata_list))) + assert input_metadata.num_valid_tokens == (get_aligned_size( + len(seq_group_metadata_list))) - assert input_tokens.shape == (round_up_to_next_multiple_of_batch_size( + assert input_tokens.shape == (get_aligned_size( len(seq_group_metadata_list)), ) - assert input_positions.shape == (round_up_to_next_multiple_of_batch_size( + assert input_positions.shape == (get_aligned_size( len(seq_group_metadata_list)), ) torch.testing.assert_close(input_tokens, input_positions) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 70841974ed42..b11c49a592f0 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -17,10 +17,10 @@ class InputMetadata: # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (num_tokens,). The indices of the token slots that input tokens will be stored into. - # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the three tokens - # are stored in the 3rd slot in block 2, 2nd slot in block 0, and 1st slot in block 1, - # respectively. + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor # (batch_size,). The prompt length per sequence. None if it is a decoding. prompt_lens: Optional[List] @@ -30,7 +30,6 @@ class InputMetadata: num_prompt_tokens: int # The number of generation tokens. Doesn't include padding. num_generation_tokens: int - """ Definition of context_len, subquery_len, and seqlen. |---------- N-1 iteration --------| @@ -48,11 +47,13 @@ class InputMetadata: max_context_len: Optional[int] # Maximum sequence length in the batch. max_seq_len: Optional[int] - # (batch_size + 1,). The cumulative subquery lengths of the sequences in the batch, used to index into q. - # E.g., if the subquery length is [4, 6], it is [0, 4, 10]. + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. subquery_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in the batch, used to index into k. - # E.g., if the sequence length is [4, 6], it is [0, 4, 10]. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] # (batch_size,). The length of context (tokens stored in KV cache) per # sequence. It doesn't include the length of new tokens. diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 8f7c6af9b6f2..2e62c02c1d8c 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -29,9 +29,8 @@ class Attention(nn.Module): Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. 
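    As a rough illustration (sizes made up): a prompt batch with lengths
    [5, 7] produces tensors with exactly 12 rows and no padding, while a
    decode batch of 3 sequences running under a captured CUDA graph is
    padded up to the captured batch size, 4 in that case.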
- The prompts might have different lengths, while the generation tokens always - have length 1. The paddings are appended to make the input length a multiple - of 8, which is desirable for Tensor Cores. + The prompts might have different lengths, while the generation tokens + always have length 1. The class does the following: diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index 8a3d3dc66a3a..ff2cc4bbfd42 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -1,7 +1,7 @@ """Attention layer with Flash and PagedAttention.""" from typing import List, Optional -from flash_attn import flash_attn_func, flash_attn_varlen_func +from flash_attn import flash_attn_varlen_func import torch from vllm.model_executor.input_metadata import InputMetadata @@ -77,7 +77,6 @@ def forward( PagedAttentionImpl.reshape_and_cache(key, value, key_cache, value_cache, input_metadata) - if input_metadata.is_prompt: # Prompt run. if (key_cache is None or value_cache is None diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention/backends/xformers.py index 444717229291..610ae13374af 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -236,7 +236,8 @@ def _make_alibi_bias( # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. - # Calculate a matrix where each element represents ith element - jth element. + # Calculate a matrix where each element represents ith element- jth + # element. bias = bias[None, :] - bias[:, None] padded_len = _get_aligned_size(prompt_len, 8) diff --git a/vllm/model_executor/layers/attention/ops/paged_attn.py b/vllm/model_executor/layers/attention/ops/paged_attn.py index a9623d95e39c..3105ba37b983 100644 --- a/vllm/model_executor/layers/attention/ops/paged_attn.py +++ b/vllm/model_executor/layers/attention/ops/paged_attn.py @@ -129,7 +129,7 @@ def forward_prefix( key_cache, value_cache, input_metadata.block_tables, - # subquery_start_loc is (batch_size + 1,) + # subquery_start_loc is (batch_size + 1,) input_metadata.subquery_start_loc[:-1], input_metadata.prompt_lens_tensor, input_metadata.context_lens, diff --git a/vllm/model_executor/layers/attention/ops/prefix_prefill.py b/vllm/model_executor/layers/attention/ops/prefix_prefill.py index 541c0aae44c7..70f09224f1cf 100644 --- a/vllm/model_executor/layers/attention/ops/prefix_prefill.py +++ b/vllm/model_executor/layers/attention/ops/prefix_prefill.py @@ -632,7 +632,7 @@ def context_attention_fwd(q, alibi_slopes=None): cap = torch.cuda.get_device_capability() - BLOCK = 32 if cap[0] >= 8 else 64 + BLOCK = 128 if cap[0] >= 8 else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f194e3678913..684ed8df6058 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -155,7 +155,7 @@ def _prepare_prompt( prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) computed_len = 0 - import os + # NOTE: This only works for oooooooxxx style attention. 
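            # Illustration (hypothetical values): the "o"s are prompt tokens
            # whose KV entries are already cached (covered by
            # computed_block_nums), the "x"s are the new tokens of this
            # prefill. If the cached prefix covers 32 tokens, computed_len is
            # 32, only prompt_tokens[32:] is run through the model, and
            # context_len is set to 32.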
computed_block_nums = seq_group_metadata.computed_block_nums if computed_block_nums is not None and len( @@ -165,14 +165,6 @@ def _prepare_prompt( prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) context_len = computed_len - if os.getenv("ENABLE") is not None: - if seq_group_metadata.block_tables is None: - prefix_block_tables.append([]) - else: - block_table = seq_group_metadata.block_tables[seq_id] - prefix_block_tables.append(block_table) - # prefix_block_tables.append([]) - context_len = 0 else: prefix_block_tables.append([]) context_len = 0 @@ -231,10 +223,18 @@ def _prepare_prompt( num_prompt_tokens = len(input_tokens) assert max_subquery_len > 0 - input_tokens = _make_tensor(input_tokens, pad=0, dtype=torch.long, device=self.device) - input_positions = _make_tensor( - input_positions, pad=0, dtype=torch.long, device=self.device) - slot_mapping = _make_tensor(slot_mapping, pad=_PAD_SLOT_ID, dtype=torch.long, device=self.device) + input_tokens = _make_tensor(input_tokens, + pad=0, + dtype=torch.long, + device=self.device) + input_positions = _make_tensor(input_positions, + pad=0, + dtype=torch.long, + device=self.device) + slot_mapping = _make_tensor(slot_mapping, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) lora_index_mapping = _pad_to_alignment(lora_index_mapping, _get_graph_batch_size( len(lora_index_mapping)), @@ -259,15 +259,15 @@ def _prepare_prompt( dtype=torch.long, device=self.device) subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) + dtype=torch.int32, + device=self.device) prompt_lens_tensor = torch.tensor(prompt_lens, - dtype=torch.long, - device=self.device) + dtype=torch.long, + device=self.device) seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) + dtype=torch.int32, + device=self.device) torch.cumsum(subquery_lens_tensor, dim=0, @@ -375,10 +375,18 @@ def _prepare_decode( # Pad tokens to better utilize tensor cores although # cuda graph is not enabled. - input_tokens = _make_tensor(input_tokens, pad=0, dtype=torch.long, device=self.device) - input_positions = _make_tensor( - input_positions, pad=0, dtype=torch.long, device=self.device) - slot_mapping = _make_tensor(slot_mapping, pad=_PAD_SLOT_ID, dtype=torch.long, device=self.device) + input_tokens = _make_tensor(input_tokens, + pad=0, + dtype=torch.long, + device=self.device) + input_positions = _make_tensor(input_positions, + pad=0, + dtype=torch.long, + device=self.device) + slot_mapping = _make_tensor(slot_mapping, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -702,7 +710,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: """Cuda graph capture a model. Note that CUDA graph's performance gain is negligible if number - of batched tokens are larger than 200. And since CUDA graph + of batched tokens are larger than 200. And since CUDA graph requires fixed sized tensors, supporting large/variable batch size requires high GPU memory overhead. Thus, vLLM only captures decoding requests. Mixed batch (chunked prefill + decoding) or @@ -936,7 +944,8 @@ def _make_tensor_with_pad( ) -> torch.Tensor: """Make a padded tensor of a 2D inputs. - The padding is applied to the end of each inner list until it reaches `max_len`. + The padding is applied to the end of each inner list until it reaches + `max_len`. 
""" padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] return torch.tensor(padded_x, dtype=dtype, device=device) From 5b4aa095f50b935d173618332cd56109f40ad589 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 03:32:36 -0700 Subject: [PATCH 50/88] temporary --- tests/spec_decode/test_multi_step_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 45b43ec59ee8..5f788549d44d 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -92,8 +92,8 @@ def test_same_output_for_single_step(): num_gpu_blocks, seed, ) - multi_step_worker.model_runner = worker.model_runner - multi_step_worker.cache_engine = worker.cache_engine + # multi_step_worker.model_runner = worker.model_runner + # multi_step_worker.cache_engine = worker.cache_engine num_steps = 1 From 4cced78f2e2101fa76bb97231f2fe6b391cde817 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 03:44:25 -0700 Subject: [PATCH 51/88] fix tests --- vllm/model_executor/input_metadata.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 47163b20fd5a..8b2a110d39df 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -70,7 +70,6 @@ class InputMetadata: # Cuda-graph is currently enabled for decoding only. use_cuda_graph: bool kv_cache_dtype: str ->>>>>>> 1dquery def __post_init__(self): # Set during the execution of the first attention op. From cdb7a2c91065365aed34eea63707d83e9db037f1 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 04:36:25 -0700 Subject: [PATCH 52/88] Fixed --- vllm/worker/model_runner.py | 75 +++++++++++++------------------------ 1 file changed, 25 insertions(+), 50 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 684ed8df6058..e9565846937e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -35,6 +35,10 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) ] +# True if inputs should be aligned. It is currently disabled. +# Aligning inputs can better utilize tensor cores. 
+# https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/ +SHOULD_ALIGN = False class ModelRunner: @@ -223,22 +227,18 @@ def _prepare_prompt( num_prompt_tokens = len(input_tokens) assert max_subquery_len > 0 - input_tokens = _make_tensor(input_tokens, - pad=0, + input_tokens = torch.tensor(_align_if_necessary(input_tokens, pad=0), dtype=torch.long, device=self.device) - input_positions = _make_tensor(input_positions, - pad=0, + input_positions = torch.tensor(_align_if_necessary(input_positions, + pad=0), dtype=torch.long, device=self.device) - slot_mapping = _make_tensor(slot_mapping, - pad=_PAD_SLOT_ID, + slot_mapping = torch.tensor(_align_if_necessary(slot_mapping, + pad=_PAD_SLOT_ID), dtype=torch.long, device=self.device) - lora_index_mapping = _pad_to_alignment(lora_index_mapping, - _get_graph_batch_size( - len(lora_index_mapping)), - pad=0) + lora_index_mapping = _align_if_necessary(lora_index_mapping, pad=0) context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, @@ -360,36 +360,26 @@ def _prepare_decode( not self.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_context_len <= self.max_context_len_to_capture) - if use_captured_graph: - # Pad the input tokens, positions, and slot mapping to match the - # batch size of the captured graph. - graph_batch_size = _get_graph_batch_size(batch_size) - assert graph_batch_size >= batch_size - for _ in range(graph_batch_size - batch_size): - input_tokens.append(0) - input_positions.append(0) - slot_mapping.append(_PAD_SLOT_ID) - context_lens.append(0) - block_tables.append([]) - batch_size = graph_batch_size # Pad tokens to better utilize tensor cores although # cuda graph is not enabled. - input_tokens = _make_tensor(input_tokens, - pad=0, + input_tokens = torch.tensor(_align_if_necessary( + input_tokens, pad=0, should_align=use_captured_graph), dtype=torch.long, device=self.device) - input_positions = _make_tensor(input_positions, - pad=0, + input_positions = torch.tensor(_align_if_necessary( + input_positions, pad=0, should_align=use_captured_graph), dtype=torch.long, device=self.device) - slot_mapping = _make_tensor(slot_mapping, - pad=_PAD_SLOT_ID, + slot_mapping = torch.tensor(_align_if_necessary( + slot_mapping, pad=_PAD_SLOT_ID, should_align=use_captured_graph), dtype=torch.long, device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) + lora_index_mapping = _align_if_necessary( + lora_index_mapping, pad=0, should_align=use_captured_graph) if use_captured_graph: # When using cuda-graph all these tensors should be @@ -398,7 +388,6 @@ def _prepare_decode( assert context_lens.shape[0] == input_positions.shape[0] assert context_lens.shape[0] == slot_mapping.shape[0] - if use_captured_graph: # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.graph_block_tables[:batch_size] @@ -417,11 +406,6 @@ def _prepare_decode( device=self.device, ) - lora_index_mapping = _pad_to_alignment(lora_index_mapping, - _get_graph_batch_size( - len(lora_index_mapping)), - pad=0) - input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping, @@ -917,22 +901,13 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: return x + [pad] * (max_len - len(x)) -def _make_tensor( - x: List[int], - pad: int, - dtype: torch.dtype, - device: Optional[Union[str, torch.device]], - align: bool = False, -) -> torch.Tensor: - """Create a tensor of a given list. 
- - If `align` is True, it creates a tensor with a padding with a given `pad`. - """ - if align: - batch_size = len(x) - batch_size = _get_graph_batch_size(batch_size) - x = _pad_to_alignment(x, batch_size, pad) - return torch.tensor(x, dtype=dtype, device=device) +def _align_if_necessary(x: List[int], pad: int, should_align=SHOULD_ALIGN): + """Align flattened 1D inputs by a fixed alignment size.""" + if not should_align: + return x + batch_size = len(x) + batch_size = _get_graph_batch_size(batch_size) + return _pad_to_alignment(x, batch_size, pad) def _make_tensor_with_pad( From d87b651781d61fc01c6c7f8e24f6a466d98c28ea Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 06:14:39 -0700 Subject: [PATCH 53/88] Pass unit tests. --- tests/worker/test_model_runner.py | 46 ++++++++++++++++--------------- vllm/worker/model_runner.py | 8 +++++- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index af921b00f7df..e462ea7935eb 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -46,25 +46,35 @@ def test_prepare_prompt(): # Verify input metadata is correct for prompts. device = model_runner.device assert input_metadata.is_prompt is True - assert torch.allclose(input_metadata.prompt_lens, + assert torch.allclose(input_metadata.prompt_lens_tensor, torch.tensor(prompt_lens, device=device)) + assert input_metadata.prompt_lens == prompt_lens assert input_metadata.num_prompt_tokens == sum(prompt_lens) assert input_metadata.num_generation_tokens == 0 assert input_metadata.max_seq_len == max(prompt_lens) - # build start_loc + + # Test subquery start locs. start_idx = 0 start_loc = [start_idx] - # start_loc is padded. for prompt_len in prompt_lens: start_idx += prompt_len start_loc.append(start_idx) assert torch.allclose( - input_metadata.start_loc, - torch.tensor(start_loc, dtype=torch.long, device=device)) + input_metadata.subquery_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + + # Test seq start locs. Note that for normal prefill it is + # equivalent to subquery_start_loc. + start_idx = 0 + seq_start_loc = [start_idx] + for prompt_len in prompt_lens: + start_idx += prompt_len + seq_start_loc.append(start_idx) + + assert torch.allclose( + input_metadata.seq_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) assert input_metadata.max_context_len is None - # TODO(sang): The current definition of context_lens is the - # number of k/v that are already cached (before this run). - # It is inconsistent with decoding. assert torch.allclose( input_metadata.context_lens, torch.zeros(input_metadata.context_lens.shape[0], @@ -78,18 +88,16 @@ def test_prepare_prompt(): # Cuda graph should not be used for prerill. 
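    # For reference, subquery_start_loc / seq_start_loc are built in
    # _prepare_prompt with a cumulative sum; a minimal sketch of that pattern
    # (values illustrative):
    #     lens = torch.tensor([4, 6], dtype=torch.long)
    #     start_loc = torch.zeros(lens.shape[0] + 1, dtype=torch.int32)
    #     torch.cumsum(lens, dim=0, dtype=start_loc.dtype, out=start_loc[1:])
    #     # start_loc -> tensor([0, 4, 10], dtype=torch.int32)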
assert input_metadata.use_cuda_graph is False assert input_metadata.kv_cache_dtype == "auto" - assert input_metadata.num_valid_tokens == _get_aligned_size( - sum(prompt_lens)) - assert input_tokens.shape == (get_aligned_size(sum(prompt_lens)), ) - assert input_positions.shape == (get_aligned_size(sum(prompt_lens)), ) + assert input_tokens.shape == (sum(prompt_lens), ) + assert input_positions.shape == (sum(prompt_lens), ) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - assert input_tokens.shape == (get_aligned_size(sum(prompt_lens)), ) - assert input_positions.shape == (get_aligned_size(sum(prompt_lens)), ) + assert input_tokens.shape == (sum(prompt_lens), ) + assert input_positions.shape == (sum(prompt_lens), ) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, @@ -120,11 +128,6 @@ def test_prepare_decode_cuda_graph(): model_runner = ModelRunner(model_config, None, None, None, None) model_runner.set_block_size(16) - # Make sure the result is aligned. - def get_aligned_size(n): - batch_size = _BATCH_SIZE_ALIGNMENT - return ((n + 7) // batch_size) * batch_size - batch_size = random.randint(1, 256) prompt_lens = [] seq_group_metadata_list = [] @@ -153,7 +156,8 @@ def get_aligned_size(n): assert input_metadata.num_generation_tokens == (get_aligned_size( len(seq_group_metadata_list))) assert input_metadata.max_seq_len is None - assert input_metadata.start_loc is None + assert input_metadata.subquery_start_loc is None + assert input_metadata.seq_start_loc is None assert input_metadata.max_context_len == max(prompt_lens) assert torch.allclose( input_metadata.context_lens[:len(prompt_lens)], @@ -169,8 +173,6 @@ def get_aligned_size(n): # Cuda graph should not be used for prerill. assert input_metadata.use_cuda_graph is True assert input_metadata.kv_cache_dtype == "auto" - assert input_metadata.num_valid_tokens == (get_aligned_size( - len(seq_group_metadata_list))) assert input_tokens.shape == (get_aligned_size( len(seq_group_metadata_list)), ) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e9565846937e..63defb17f0d6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -360,6 +360,8 @@ def _prepare_decode( not self.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_context_len <= self.max_context_len_to_capture) + if use_captured_graph: + batch_size = _get_graph_batch_size(batch_size) # Pad tokens to better utilize tensor cores although # cuda graph is not enabled. 
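When the captured graph is used, the decode inputs are padded out to the
captured batch size before the tensors are built; a rough sketch of the effect
(illustrative values, 3 real sequences padded to a captured size of 4):

    input_tokens    += [0]              # pad token id
    input_positions += [0]
    slot_mapping    += [_PAD_SLOT_ID]   # placeholder slot for the pad entry
    context_lens    += [0]
    block_tables    += [[]]             # empty block table for the pad entry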
@@ -375,9 +377,13 @@ def _prepare_decode( slot_mapping, pad=_PAD_SLOT_ID, should_align=use_captured_graph), dtype=torch.long, device=self.device) - context_lens = torch.tensor(context_lens, + context_lens = torch.tensor(_align_if_necessary( + context_lens, pad=0, should_align=use_captured_graph), dtype=torch.int, device=self.device) + block_tables = _align_if_necessary(block_tables, + pad=[], + should_align=use_captured_graph) lora_index_mapping = _align_if_necessary( lora_index_mapping, pad=0, should_align=use_captured_graph) From 2c18896731209b7d9e981749b5bdcc8c7cefc4e8 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 06:51:12 -0700 Subject: [PATCH 54/88] experiment --- tests/models/test_models.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index fb567e837d28..b8ee03759284 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,20 +6,20 @@ MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", - "mistralai/Mistral-7B-v0.1", - "Deci/DeciLM-7b", - "tiiuae/falcon-7b", - "gpt2", - "bigcode/tiny_starcoder_py", - "EleutherAI/gpt-j-6b", - "EleutherAI/pythia-70m", - "bigscience/bloom-560m", - "mosaicml/mpt-7b", - "microsoft/phi-2", - "stabilityai/stablelm-3b-4e1t", - "allenai/OLMo-1B", - "bigcode/starcoder2-3b", + # "meta-llama/Llama-2-7b-hf", + # "mistralai/Mistral-7B-v0.1", + # "Deci/DeciLM-7b", + # "tiiuae/falcon-7b", + # "gpt2", + # "bigcode/tiny_starcoder_py", + # "EleutherAI/gpt-j-6b", + # "EleutherAI/pythia-70m", + # "bigscience/bloom-560m", + # "mosaicml/mpt-7b", + # "microsoft/phi-2", + # "stabilityai/stablelm-3b-4e1t", + # "allenai/OLMo-1B", + # "bigcode/starcoder2-3b", ] From b46f902b89d96e6a0103ab3fdbe9390a73ddc55d Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 07:00:30 -0700 Subject: [PATCH 55/88] . --- tests/models/test_models.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index b8ee03759284..fb567e837d28 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,20 +6,20 @@ MODELS = [ "facebook/opt-125m", - # "meta-llama/Llama-2-7b-hf", - # "mistralai/Mistral-7B-v0.1", - # "Deci/DeciLM-7b", - # "tiiuae/falcon-7b", - # "gpt2", - # "bigcode/tiny_starcoder_py", - # "EleutherAI/gpt-j-6b", - # "EleutherAI/pythia-70m", - # "bigscience/bloom-560m", - # "mosaicml/mpt-7b", - # "microsoft/phi-2", - # "stabilityai/stablelm-3b-4e1t", - # "allenai/OLMo-1B", - # "bigcode/starcoder2-3b", + "meta-llama/Llama-2-7b-hf", + "mistralai/Mistral-7B-v0.1", + "Deci/DeciLM-7b", + "tiiuae/falcon-7b", + "gpt2", + "bigcode/tiny_starcoder_py", + "EleutherAI/gpt-j-6b", + "EleutherAI/pythia-70m", + "bigscience/bloom-560m", + "mosaicml/mpt-7b", + "microsoft/phi-2", + "stabilityai/stablelm-3b-4e1t", + "allenai/OLMo-1B", + "bigcode/starcoder2-3b", ] From 07b22f8d6c207ef5d07c5292163d772ed8f323c3 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 07:14:05 -0700 Subject: [PATCH 56/88] . 
--- .buildkite/test-pipeline.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 2c7dd9f304b9..f599857243eb 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -37,10 +37,10 @@ steps: command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 -- label: Models Test - commands: - - pytest -v -s models --forked - soft_fail: true +# - label: Models Test +# commands: +# - pytest -v -s models --forked +# soft_fail: true - label: Prefix Caching Test commands: From 9bd7ea16fff9f5f24adefd36f0b1a94ae3878fa7 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 07:17:00 -0700 Subject: [PATCH 57/88] . --- .buildkite/test-pipeline.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f599857243eb..2c7dd9f304b9 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -37,10 +37,10 @@ steps: command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 -# - label: Models Test -# commands: -# - pytest -v -s models --forked -# soft_fail: true +- label: Models Test + commands: + - pytest -v -s models --forked + soft_fail: true - label: Prefix Caching Test commands: From c55402f1d6fae190ee7d0262c58e77eebce51ebd Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 07:24:43 -0700 Subject: [PATCH 58/88] trial --- .buildkite/test-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 2c7dd9f304b9..42e77f2f7bd1 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -47,7 +47,7 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test - command: pytest -v -s samplers --forked + command: pytest -v -s samplers - label: Worker Test command: pytest -v -s worker From a13cf7ebe482cce2d77ba54b38cba84464a7fe0e Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 07:45:50 -0700 Subject: [PATCH 59/88] remove --fork --- .buildkite/test-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 42e77f2f7bd1..451760cbb046 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -56,7 +56,7 @@ steps: command: pytest -v -s spec_decode - label: LoRA Test %N - command: pytest -v -s lora --forked --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 - label: Metrics Test From ec91304a0bc543c996db4fdb8205f19a0fb872ac Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 18 Mar 2024 17:31:34 -0700 Subject: [PATCH 60/88] fixed --- vllm/model_executor/input_metadata.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index f0e34eee44ad..c595fefa115e 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -30,7 +30,6 @@ class InputMetadata: num_prompt_tokens: int # The number of generation tokens. Doesn't include padding. num_generation_tokens: int - """ Definition of context_len, subquery_len, and seqlen. 
|---------- N-1 iteration --------| From 2e6e9199fee04329c3c75e4f1cb2ea2dce2e203e Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 19 Mar 2024 03:09:57 -0700 Subject: [PATCH 61/88] Addressed code review. --- tests/worker/test_model_runner.py | 13 ++--- vllm/model_executor/input_metadata.py | 17 ++++-- .../layers/attention/attention.py | 16 +----- .../layers/attention/backends/flash_attn.py | 21 ++++++- .../layers/attention/backends/xformers.py | 55 ++++++++++++------ vllm/utils.py | 5 -- vllm/worker/model_runner.py | 56 +++++++++---------- 7 files changed, 102 insertions(+), 81 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index e462ea7935eb..44b22c2bd8a2 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,14 +1,13 @@ import random import torch +from vllm.config import ModelConfig from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT -from vllm.config import ModelConfig -from vllm.utils import _get_aligned_size -def get_aligned_size(batch_size): - return _get_aligned_size(batch_size, _BATCH_SIZE_ALIGNMENT) +def get_aligned_size(batch_size: int, alignment: int): + return ((batch_size + alignment - 1) // alignment * alignment) def test_prepare_prompt(): @@ -154,7 +153,7 @@ def test_prepare_decode_cuda_graph(): assert input_metadata.prompt_lens is None assert input_metadata.num_prompt_tokens == 0 assert input_metadata.num_generation_tokens == (get_aligned_size( - len(seq_group_metadata_list))) + len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT)) assert input_metadata.max_seq_len is None assert input_metadata.subquery_start_loc is None assert input_metadata.seq_start_loc is None @@ -175,9 +174,9 @@ def test_prepare_decode_cuda_graph(): assert input_metadata.kv_cache_dtype == "auto" assert input_tokens.shape == (get_aligned_size( - len(seq_group_metadata_list)), ) + len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), ) assert input_positions.shape == (get_aligned_size( - len(seq_group_metadata_list)), ) + len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), ) torch.testing.assert_close(input_tokens, input_positions) # Verify Sampling diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index c595fefa115e..8f7061795711 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,8 +1,8 @@ from dataclasses import dataclass, fields from typing import Optional, List, Any, Dict -from xformers.ops.fmha.attn_bias import AttentionBias import torch +from xformers.ops.fmha.attn_bias import AttentionBias @dataclass @@ -23,7 +23,7 @@ class InputMetadata: # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List] + prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] # The number of prompt tokens. Doesn't include padding. @@ -38,10 +38,13 @@ class InputMetadata: |---------- context_len ----------| |-------------------- seqlen ----------------------| |- subquery_len -| - + + WARNING: context_len has different definition depending on if it is + prefill vs decoding. When it is prefill, it doesn't include new + tokens. When it is for decoding, it includes a new token. """ - # Maximum sequence length in the batch. + # Maximum subquery length in the batch. 
max_subquery_len: Optional[int] # Maximum context length in the batch. max_context_len: Optional[int] @@ -56,13 +59,15 @@ class InputMetadata: # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] # (batch_size,). The length of context (tokens stored in KV cache) per - # sequence. It doesn't include the length of new tokens. + # sequence. WARNING: When it is a prefill request, it doesn't include new + # tokens. When it is for decoding, it includes a new token. context_lens: Optional[torch.Tensor] # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks # in the kv cache. Each block can contain up to block_size tokens. - # The first dimension is padded if it is cuda-graph captured. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. block_tables: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 2e62c02c1d8c..ae598b029a00 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -18,25 +18,11 @@ class Attention(nn.Module): This class takes query, key, and value tensors as input. The input tensors can either contain prompt tokens or generation tokens. - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| - - Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - The class does the following: 1. Store the input key and value tensors in the KV cache. 2. Perform (multi-head/multi-query/grouped-query) attention. - 3. Output a flattened 1D tensor. + 3. Output the output tensor. """ def __init__( diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index ff2cc4bbfd42..9ce5851f3650 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -10,6 +10,21 @@ class FlashAttentionBackend: + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ def __init__( self, @@ -82,13 +97,15 @@ def forward( if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. 
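                # Illustration (made-up sizes): for two prompts of lengths
                # 4 and 6, q/k/v are flattened to 10 rows, seq_start_loc is
                # [0, 4, 10], and max_seq_len is 6; flash_attn_varlen_func
                # uses these cumulative offsets to treat the flattened
                # tensors as a batch of variable-length sequences.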
output = flash_attn_varlen_func( q=query, k=key, v=value, - cu_seqlens_q=input_metadata.subquery_start_loc, + cu_seqlens_q=input_metadata.seq_start_loc, cu_seqlens_k=input_metadata.seq_start_loc, - max_seqlen_q=input_metadata.max_subquery_len, + max_seqlen_q=input_metadata.max_seq_len, max_seqlen_k=input_metadata.max_seq_len, softmax_scale=self.scale, causal=True, diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention/backends/xformers.py index 610ae13374af..9e053e52687e 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -10,10 +10,25 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention.ops.paged_attn import ( PagedAttentionImpl) -from vllm.utils import is_hip, _get_aligned_size +from vllm.utils import is_hip class XFormersBackend: + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ def __init__( self, @@ -81,7 +96,7 @@ def forward( if input_metadata.is_prompt: # Prompt run. - # key_cache and value_cache is None when it is a profiling run. + # key_cache and value_cache are None when it is a profiling run. # block tables are empty if the prompt has never been computed. if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): @@ -104,22 +119,30 @@ def forward( value.shape[-1]) if self.use_ref_attention: - output = _ref_masked_attention( - query, - key, - value, - self.num_heads, - self.num_kv_heads, - self.head_size, - self.scale, - ) + output = torch.empty_like(query) + start = 0 + for _, prompt_len in enumerate(input_metadata.prompt_lens): + end = start + prompt_len + out = _ref_masked_attention( + query[None, start:end], + key[None, start:end], + value[None, start:end], + self.num_heads, + self.num_kv_heads, + self.head_size, + self.scale, + ) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out) + start += prompt_len + # Using view got RuntimeError: view size is not compatible # with input tensor's size and stride (at least one # dimension spans across two contiguous subspaces). # Use reshape instead. return output.reshape(num_tokens, hidden_size) - output = self._multi_query_kv_attention( + output = self._run_memory_efficient_xformer_forward( query, key, value, input_metadata) else: # prefix-enabled attention @@ -147,7 +170,7 @@ def forward( # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) - def _multi_query_kv_attention( + def _run_memory_efficient_xformer_forward( self, query: torch.Tensor, key: torch.Tensor, @@ -240,7 +263,7 @@ def _make_alibi_bias( # element. 
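    # Worked example (illustrative): for a prompt of length 3,
    #     bias = [0, 1, 2]
    #     bias[None, :] - bias[:, None] =
    #         [[ 0,  1,  2],
    #          [-1,  0,  1],
    #          [-2, -1,  0]]
    # i.e. entry (i, j) equals j - i, which is then combined with the
    # per-head ALiBi slopes.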
bias = bias[None, :] - bias[:, None] - padded_len = _get_aligned_size(prompt_len, 8) + padded_len = (prompt_len + 7) // 8 * 8 num_heads = alibi_slopes.shape[0] bias = torch.empty( 1, # batch size @@ -275,10 +298,6 @@ def _ref_masked_attention( head_size: int, scale: float, ) -> torch.Tensor: - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - seq_len, _, _ = query.shape attn_mask = torch.triu(torch.ones(seq_len, seq_len, diff --git a/vllm/utils.py b/vllm/utils.py index 4f6ad4a17fdc..729a4332af96 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -337,11 +337,6 @@ def create_kv_caches_with_random( return key_caches, value_caches -def _get_aligned_size(batch_size: int, alignment: int) -> int: - """Returns the padded batch based on an alignment.""" - return ((batch_size + alignment - 1) // alignment * alignment) - - class measure_cuda_memory: def __init__(self, device=None): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 67ea4edfdc87..066052440dc9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -21,7 +21,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest -from vllm.utils import in_wsl, measure_cuda_memory, _get_aligned_size +from vllm.utils import in_wsl, measure_cuda_memory logger = init_logger(__name__) @@ -34,10 +34,6 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) ] -# True if inputs should be aligned. It is currently disabled. -# Aligning inputs can better utilize tensor cores. -# https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/ -SHOULD_ALIGN = False class ModelRunner: @@ -226,18 +222,16 @@ def _prepare_prompt( num_prompt_tokens = len(input_tokens) assert max_subquery_len > 0 - input_tokens = torch.tensor(_align_if_necessary(input_tokens, pad=0), + input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) - input_positions = torch.tensor(_align_if_necessary(input_positions, - pad=0), + input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) - slot_mapping = torch.tensor(_align_if_necessary(slot_mapping, - pad=_PAD_SLOT_ID), + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) - lora_index_mapping = _align_if_necessary(lora_index_mapping, pad=0) + lora_index_mapping = lora_index_mapping context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, @@ -364,27 +358,28 @@ def _prepare_decode( # Pad tokens to better utilize tensor cores although # cuda graph is not enabled. 
- input_tokens = torch.tensor(_align_if_necessary( - input_tokens, pad=0, should_align=use_captured_graph), + input_tokens = torch.tensor(_pad_for_cuda_graph( + input_tokens, pad=0, use_captured_graph=use_captured_graph), dtype=torch.long, device=self.device) - input_positions = torch.tensor(_align_if_necessary( - input_positions, pad=0, should_align=use_captured_graph), + input_positions = torch.tensor(_pad_for_cuda_graph( + input_positions, pad=0, use_captured_graph=use_captured_graph), dtype=torch.long, device=self.device) - slot_mapping = torch.tensor(_align_if_necessary( - slot_mapping, pad=_PAD_SLOT_ID, should_align=use_captured_graph), + slot_mapping = torch.tensor(_pad_for_cuda_graph( + slot_mapping, + pad=_PAD_SLOT_ID, + use_captured_graph=use_captured_graph), dtype=torch.long, device=self.device) - context_lens = torch.tensor(_align_if_necessary( - context_lens, pad=0, should_align=use_captured_graph), + context_lens = torch.tensor(_pad_for_cuda_graph( + context_lens, pad=0, use_captured_graph=use_captured_graph), dtype=torch.int, device=self.device) - block_tables = _align_if_necessary(block_tables, - pad=[], - should_align=use_captured_graph) - lora_index_mapping = _align_if_necessary( - lora_index_mapping, pad=0, should_align=use_captured_graph) + block_tables = _pad_for_cuda_graph( + block_tables, pad=[], use_captured_graph=use_captured_graph) + lora_index_mapping = _pad_for_cuda_graph( + lora_index_mapping, pad=0, use_captured_graph=use_captured_graph) if use_captured_graph: # When using cuda-graph all these tensors should be @@ -906,10 +901,14 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: return x + [pad] * (max_len - len(x)) -def _align_if_necessary(x: List[int], pad: int, should_align=SHOULD_ALIGN): - """Align flattened 1D inputs by a fixed alignment size.""" - if not should_align: +def _pad_for_cuda_graph(x: List[int], pad: int, use_captured_graph: bool): + """Pad flattened 1D inputs by a fixed alignment size for cuda graph. + + This function is no-op if use_captured_graph is False. 
+ """ + if not use_captured_graph: return x + batch_size = len(x) batch_size = _get_graph_batch_size(batch_size) return _pad_to_alignment(x, batch_size, pad) @@ -942,7 +941,8 @@ def _get_graph_batch_size(batch_size: int) -> int: elif batch_size <= 4: return 4 else: - return _get_aligned_size(batch_size, _BATCH_SIZE_ALIGNMENT) + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) def _async_h2d( From ac7828cefc0c2e11d54361503fb10d4bf955b8a2 Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 19 Mar 2024 03:18:34 -0700 Subject: [PATCH 62/88] revert removing forked --- .buildkite/test-pipeline.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 17f4c3367082..6ae351130f20 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -47,7 +47,7 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test - command: pytest -v -s samplers + command: pytest -v -s samplers --forked - label: Worker Test command: pytest -v -s worker @@ -56,7 +56,7 @@ steps: command: pytest -v -s spec_decode - label: LoRA Test %N - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + command: pytest -v -s lora --forked --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 - label: Metrics Test From 3d7f1a1e56b194aeddc81622cb00f0ff706a8796 Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 19 Mar 2024 05:23:07 -0700 Subject: [PATCH 63/88] done --- .buildkite/test-pipeline.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6ae351130f20..17f4c3367082 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -47,7 +47,7 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test - command: pytest -v -s samplers --forked + command: pytest -v -s samplers - label: Worker Test command: pytest -v -s worker @@ -56,7 +56,7 @@ steps: command: pytest -v -s spec_decode - label: LoRA Test %N - command: pytest -v -s lora --forked --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 - label: Metrics Test From fa3ce4e90b3c26cb74e30845009a3770e8ca7199 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 20 Mar 2024 02:54:32 -0700 Subject: [PATCH 64/88] final code review. --- vllm/model_executor/input_metadata.py | 2 + .../layers/attention/backends/xformers.py | 4 ++ vllm/worker/model_runner.py | 48 ++++++------------- 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 8f7061795711..35245865fb1b 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -48,12 +48,14 @@ class InputMetadata: max_subquery_len: Optional[int] # Maximum context length in the batch. max_context_len: Optional[int] + # FIXME: It is for flash attn. # Maximum sequence length in the batch. max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. subquery_start_loc: Optional[torch.Tensor] + # FIXME: It is for flash attn. # (batch_size + 1,). 
The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention/backends/xformers.py index 9e053e52687e..d3508bbb21df 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -119,6 +119,7 @@ def forward( value.shape[-1]) if self.use_ref_attention: + print("ref attention used.") output = torch.empty_like(query) start = 0 for _, prompt_len in enumerate(input_metadata.prompt_lens): @@ -298,6 +299,9 @@ def _ref_masked_attention( head_size: int, scale: float, ) -> torch.Tensor: + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) seq_len, _, _ = query.shape attn_mask = torch.triu(torch.ones(seq_len, seq_len, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 066052440dc9..04348aa79bfc 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -354,32 +354,29 @@ def _prepare_decode( and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_context_len <= self.max_context_len_to_capture) if use_captured_graph: - batch_size = _get_graph_batch_size(batch_size) + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + for _ in range(graph_batch_size - batch_size): + input_tokens.append(0) + input_positions.append(0) + slot_mapping.append(_PAD_SLOT_ID) + context_lens.append(1) + block_tables.append([]) + lora_index_mapping.append(0) + batch_size = graph_batch_size - # Pad tokens to better utilize tensor cores although - # cuda graph is not enabled. - input_tokens = torch.tensor(_pad_for_cuda_graph( - input_tokens, pad=0, use_captured_graph=use_captured_graph), + input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) - input_positions = torch.tensor(_pad_for_cuda_graph( - input_positions, pad=0, use_captured_graph=use_captured_graph), + input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) - slot_mapping = torch.tensor(_pad_for_cuda_graph( - slot_mapping, - pad=_PAD_SLOT_ID, - use_captured_graph=use_captured_graph), + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) - context_lens = torch.tensor(_pad_for_cuda_graph( - context_lens, pad=0, use_captured_graph=use_captured_graph), + context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) - block_tables = _pad_for_cuda_graph( - block_tables, pad=[], use_captured_graph=use_captured_graph) - lora_index_mapping = _pad_for_cuda_graph( - lora_index_mapping, pad=0, use_captured_graph=use_captured_graph) if use_captured_graph: # When using cuda-graph all these tensors should be @@ -892,28 +889,11 @@ def _maybe_cupy_nccl(): yield -def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]: - return x + [pad] * ((-len(x)) % multiple_of) - - def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: assert len(x) <= max_len return x + [pad] * (max_len - len(x)) -def _pad_for_cuda_graph(x: List[int], pad: int, use_captured_graph: bool): - """Pad flattened 1D inputs by a fixed alignment size for cuda graph. - - This function is no-op if use_captured_graph is False. 
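With the padding helper reverted, `_prepare_decode` now pads its flattened lists in place up to the captured CUDA-graph batch size. A self-contained sketch of that bucketing-and-padding step, under stated assumptions: `_BATCH_SIZE_ALIGNMENT = 8` and `_PAD_SLOT_ID = -1` are illustrative values (the real constants live in `model_runner.py`), the small-batch branches mirror `_BATCH_SIZES_TO_CAPTURE = [1, 2, 4, ...]` shown in the diff, and `pad_decode_inputs` is a hypothetical helper, not a vLLM function:

```python
from typing import List

_BATCH_SIZE_ALIGNMENT = 8   # assumed; see model_runner.py for the real value
_PAD_SLOT_ID = -1           # assumed padding slot id for illustration

def graph_batch_size(batch_size: int) -> int:
    # Round up to the nearest captured size: 1, 2, 4, 8, 16, ...
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1)
            // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

def pad_decode_inputs(input_tokens: List[int],
                      slot_mapping: List[int],
                      context_lens: List[int]) -> int:
    """Append pad entries until the batch matches a captured graph size."""
    target = graph_batch_size(len(input_tokens))
    for _ in range(target - len(input_tokens)):
        input_tokens.append(0)             # dummy token id
        slot_mapping.append(_PAD_SLOT_ID)  # padding slot (illustrative)
        context_lens.append(1)             # minimal context length for padding
    return target

tokens, slots, ctx = [11, 22, 33, 44, 55], [0, 1, 2, 3, 4], [3, 3, 3, 3, 3]
assert pad_decode_inputs(tokens, slots, ctx) == 8
assert len(tokens) == len(slots) == len(ctx) == 8
```

A batch of 5 decode sequences is bucketed to the captured size 8, so the captured graph can be replayed without recompilation; the extra entries are throwaway padding.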
- """ - if not use_captured_graph: - return x - - batch_size = len(x) - batch_size = _get_graph_batch_size(batch_size) - return _pad_to_alignment(x, batch_size, pad) - - def _make_tensor_with_pad( x: List[List[int]], max_len: int, From 8bc0af5174a807a82e589cda065d283ae402a26c Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 20 Mar 2024 18:47:50 -0700 Subject: [PATCH 65/88] . --- tests/test_sequence.py | 9 +- vllm/engine/arg_utils.py | 11 +- .../einops-0.7.0.dist-info/INSTALLER | 1 - .../einops-0.7.0.dist-info/METADATA | 360 ----- .../einops-0.7.0.dist-info/RECORD | 45 - .../einops-0.7.0.dist-info/REQUESTED | 0 .../einops-0.7.0.dist-info/WHEEL | 4 - .../einops-0.7.0.dist-info/licenses/LICENSE | 21 - vllm/thirdparty_files/einops/__init__.py | 15 - vllm/thirdparty_files/einops/_backends.py | 662 --------- .../einops/_torch_specific.py | 127 -- vllm/thirdparty_files/einops/array_api.py | 119 -- vllm/thirdparty_files/einops/einops.py | 901 ------------ .../einops/experimental/__init__.py | 0 .../einops/experimental/data_api_packing.py | 137 -- .../einops/experimental/indexing.py | 393 ------ .../einops/layers/__init__.py | 106 -- .../thirdparty_files/einops/layers/_einmix.py | 176 --- .../thirdparty_files/einops/layers/chainer.py | 53 - vllm/thirdparty_files/einops/layers/flax.py | 80 -- vllm/thirdparty_files/einops/layers/keras.py | 9 - .../thirdparty_files/einops/layers/oneflow.py | 53 - vllm/thirdparty_files/einops/layers/paddle.py | 59 - .../einops/layers/tensorflow.py | 85 -- vllm/thirdparty_files/einops/layers/torch.py | 68 - vllm/thirdparty_files/einops/packing.py | 191 --- vllm/thirdparty_files/einops/parsing.py | 149 -- vllm/thirdparty_files/einops/py.typed | 0 .../flash_attn-2.5.6.dist-info/AUTHORS | 1 - .../flash_attn-2.5.6.dist-info/INSTALLER | 1 - .../flash_attn-2.5.6.dist-info/LICENSE | 29 - .../flash_attn-2.5.6.dist-info/METADATA | 430 ------ .../flash_attn-2.5.6.dist-info/RECORD | 103 -- .../flash_attn-2.5.6.dist-info/REQUESTED | 0 .../flash_attn-2.5.6.dist-info/WHEEL | 5 - .../flash_attn-2.5.6.dist-info/top_level.txt | 2 - vllm/thirdparty_files/flash_attn/__init__.py | 11 - .../flash_attn/bert_padding.py | 213 --- .../flash_attn/flash_attn_interface.py | 1209 ----------------- .../flash_attn/flash_attn_triton.py | 1160 ---------------- .../flash_attn/flash_attn_triton_og.py | 365 ----- .../flash_attn/flash_blocksparse_attention.py | 197 --- .../flash_blocksparse_attn_interface.py | 200 --- .../flash_attn/fused_softmax.py | 201 --- .../flash_attn/layers/__init__.py | 0 .../flash_attn/layers/patch_embed.py | 67 - .../flash_attn/layers/rotary.py | 481 ------- .../flash_attn/losses/__init__.py | 0 .../flash_attn/losses/cross_entropy.py | 84 -- .../flash_attn/models/__init__.py | 0 .../flash_attn/models/baichuan.py | 151 -- .../flash_attn/models/bert.py | 764 ----------- .../flash_attn/models/bigcode.py | 233 ---- .../flash_attn/models/btlm.py | 102 -- .../flash_attn/models/falcon.py | 143 -- .../thirdparty_files/flash_attn/models/gpt.py | 1080 --------------- .../flash_attn/models/gpt_neox.py | 124 -- .../flash_attn/models/gptj.py | 109 -- .../flash_attn/models/llama.py | 422 ------ .../thirdparty_files/flash_attn/models/opt.py | 116 -- .../thirdparty_files/flash_attn/models/vit.py | 373 ----- .../flash_attn/modules/__init__.py | 0 .../flash_attn/modules/block.py | 397 ------ .../flash_attn/modules/embedding.py | 216 --- .../flash_attn/modules/mha.py | 1016 -------------- .../flash_attn/modules/mlp.py | 191 --- .../flash_attn/ops/__init__.py | 0 .../flash_attn/ops/activations.py | 
135 -- .../flash_attn/ops/fused_dense.py | 688 ---------- .../flash_attn/ops/layer_norm.py | 800 ----------- .../flash_attn/ops/rms_norm.py | 174 --- .../flash_attn/ops/triton/__init__.py | 1 - .../flash_attn/ops/triton/cross_entropy.py | 320 ----- .../flash_attn/ops/triton/k_activations.py | 162 --- .../flash_attn/ops/triton/layer_norm.py | 1086 --------------- .../flash_attn/ops/triton/linear.py | 594 -------- .../flash_attn/ops/triton/mlp.py | 149 -- .../flash_attn/ops/triton/rotary.py | 240 ---- .../flash_attn/utils/__init__.py | 0 .../flash_attn/utils/benchmark.py | 268 ---- .../flash_attn/utils/distributed.py | 144 -- .../flash_attn/utils/generation.py | 735 ---------- .../flash_attn/utils/pretrained.py | 79 -- 83 files changed, 9 insertions(+), 19596 deletions(-) delete mode 100644 vllm/thirdparty_files/einops-0.7.0.dist-info/INSTALLER delete mode 100644 vllm/thirdparty_files/einops-0.7.0.dist-info/METADATA delete mode 100644 vllm/thirdparty_files/einops-0.7.0.dist-info/RECORD delete mode 100644 vllm/thirdparty_files/einops-0.7.0.dist-info/REQUESTED delete mode 100644 vllm/thirdparty_files/einops-0.7.0.dist-info/WHEEL delete mode 100644 vllm/thirdparty_files/einops-0.7.0.dist-info/licenses/LICENSE delete mode 100644 vllm/thirdparty_files/einops/__init__.py delete mode 100644 vllm/thirdparty_files/einops/_backends.py delete mode 100644 vllm/thirdparty_files/einops/_torch_specific.py delete mode 100644 vllm/thirdparty_files/einops/array_api.py delete mode 100644 vllm/thirdparty_files/einops/einops.py delete mode 100644 vllm/thirdparty_files/einops/experimental/__init__.py delete mode 100644 vllm/thirdparty_files/einops/experimental/data_api_packing.py delete mode 100644 vllm/thirdparty_files/einops/experimental/indexing.py delete mode 100644 vllm/thirdparty_files/einops/layers/__init__.py delete mode 100644 vllm/thirdparty_files/einops/layers/_einmix.py delete mode 100644 vllm/thirdparty_files/einops/layers/chainer.py delete mode 100644 vllm/thirdparty_files/einops/layers/flax.py delete mode 100644 vllm/thirdparty_files/einops/layers/keras.py delete mode 100644 vllm/thirdparty_files/einops/layers/oneflow.py delete mode 100644 vllm/thirdparty_files/einops/layers/paddle.py delete mode 100644 vllm/thirdparty_files/einops/layers/tensorflow.py delete mode 100644 vllm/thirdparty_files/einops/layers/torch.py delete mode 100644 vllm/thirdparty_files/einops/packing.py delete mode 100644 vllm/thirdparty_files/einops/parsing.py delete mode 100644 vllm/thirdparty_files/einops/py.typed delete mode 100644 vllm/thirdparty_files/flash_attn-2.5.6.dist-info/AUTHORS delete mode 100644 vllm/thirdparty_files/flash_attn-2.5.6.dist-info/INSTALLER delete mode 100644 vllm/thirdparty_files/flash_attn-2.5.6.dist-info/LICENSE delete mode 100644 vllm/thirdparty_files/flash_attn-2.5.6.dist-info/METADATA delete mode 100644 vllm/thirdparty_files/flash_attn-2.5.6.dist-info/RECORD delete mode 100644 vllm/thirdparty_files/flash_attn-2.5.6.dist-info/REQUESTED delete mode 100644 vllm/thirdparty_files/flash_attn-2.5.6.dist-info/WHEEL delete mode 100644 vllm/thirdparty_files/flash_attn-2.5.6.dist-info/top_level.txt delete mode 100644 vllm/thirdparty_files/flash_attn/__init__.py delete mode 100644 vllm/thirdparty_files/flash_attn/bert_padding.py delete mode 100644 vllm/thirdparty_files/flash_attn/flash_attn_interface.py delete mode 100644 vllm/thirdparty_files/flash_attn/flash_attn_triton.py delete mode 100644 vllm/thirdparty_files/flash_attn/flash_attn_triton_og.py delete mode 100644 
vllm/thirdparty_files/flash_attn/flash_blocksparse_attention.py delete mode 100644 vllm/thirdparty_files/flash_attn/flash_blocksparse_attn_interface.py delete mode 100644 vllm/thirdparty_files/flash_attn/fused_softmax.py delete mode 100644 vllm/thirdparty_files/flash_attn/layers/__init__.py delete mode 100644 vllm/thirdparty_files/flash_attn/layers/patch_embed.py delete mode 100644 vllm/thirdparty_files/flash_attn/layers/rotary.py delete mode 100644 vllm/thirdparty_files/flash_attn/losses/__init__.py delete mode 100644 vllm/thirdparty_files/flash_attn/losses/cross_entropy.py delete mode 100644 vllm/thirdparty_files/flash_attn/models/__init__.py delete mode 100644 vllm/thirdparty_files/flash_attn/models/baichuan.py delete mode 100644 vllm/thirdparty_files/flash_attn/models/bert.py delete mode 100644 vllm/thirdparty_files/flash_attn/models/bigcode.py delete mode 100644 vllm/thirdparty_files/flash_attn/models/btlm.py delete mode 100644 vllm/thirdparty_files/flash_attn/models/falcon.py delete mode 100644 vllm/thirdparty_files/flash_attn/models/gpt.py delete mode 100644 vllm/thirdparty_files/flash_attn/models/gpt_neox.py delete mode 100644 vllm/thirdparty_files/flash_attn/models/gptj.py delete mode 100644 vllm/thirdparty_files/flash_attn/models/llama.py delete mode 100644 vllm/thirdparty_files/flash_attn/models/opt.py delete mode 100644 vllm/thirdparty_files/flash_attn/models/vit.py delete mode 100644 vllm/thirdparty_files/flash_attn/modules/__init__.py delete mode 100644 vllm/thirdparty_files/flash_attn/modules/block.py delete mode 100644 vllm/thirdparty_files/flash_attn/modules/embedding.py delete mode 100644 vllm/thirdparty_files/flash_attn/modules/mha.py delete mode 100644 vllm/thirdparty_files/flash_attn/modules/mlp.py delete mode 100644 vllm/thirdparty_files/flash_attn/ops/__init__.py delete mode 100644 vllm/thirdparty_files/flash_attn/ops/activations.py delete mode 100644 vllm/thirdparty_files/flash_attn/ops/fused_dense.py delete mode 100644 vllm/thirdparty_files/flash_attn/ops/layer_norm.py delete mode 100644 vllm/thirdparty_files/flash_attn/ops/rms_norm.py delete mode 100644 vllm/thirdparty_files/flash_attn/ops/triton/__init__.py delete mode 100644 vllm/thirdparty_files/flash_attn/ops/triton/cross_entropy.py delete mode 100644 vllm/thirdparty_files/flash_attn/ops/triton/k_activations.py delete mode 100644 vllm/thirdparty_files/flash_attn/ops/triton/layer_norm.py delete mode 100644 vllm/thirdparty_files/flash_attn/ops/triton/linear.py delete mode 100644 vllm/thirdparty_files/flash_attn/ops/triton/mlp.py delete mode 100644 vllm/thirdparty_files/flash_attn/ops/triton/rotary.py delete mode 100644 vllm/thirdparty_files/flash_attn/utils/__init__.py delete mode 100644 vllm/thirdparty_files/flash_attn/utils/benchmark.py delete mode 100644 vllm/thirdparty_files/flash_attn/utils/distributed.py delete mode 100644 vllm/thirdparty_files/flash_attn/utils/generation.py delete mode 100644 vllm/thirdparty_files/flash_attn/utils/pretrained.py diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 0953ccb593cc..c948abe437a1 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,9 +1,6 @@ import pytest -from vllm.sequence import ( - SequenceData, Sequence, - SequenceGroupOutput, SamplerOutput, - SequenceOutput -) +from vllm.sequence import (SequenceData, Sequence, SequenceGroupOutput, + SamplerOutput, SequenceOutput) @pytest.fixture(name="sequence") @@ -84,4 +81,4 @@ def test_sequence_data_prefill(): assert seq_data.get_prefill_range() == (4, 4) # append tokens and reset, 
simulating recompute - seq_data.append_token_id(1, logprob=0.0) \ No newline at end of file + seq_data.append_token_id(1, logprob=0.0) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f7e5d05a23c8..8dad83fe33ef 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -353,11 +353,12 @@ def create_engine_configs( self.tokenizer_pool_type, self.tokenizer_pool_extra_config, ), self.ray_workers_use_nsight) - scheduler_config = SchedulerConfig(self.max_num_batched_tokens, - self.max_num_seqs, - model_config.max_model_len, - max_chunked_prefill_len=self.max_chunked_prefill_len, - max_num_prompt_seqs=self.max_num_prompt_seqs) + scheduler_config = SchedulerConfig( + self.max_num_batched_tokens, + self.max_num_seqs, + model_config.max_model_len, + max_chunked_prefill_len=self.max_chunked_prefill_len, + max_num_prompt_seqs=self.max_num_prompt_seqs) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/thirdparty_files/einops-0.7.0.dist-info/INSTALLER b/vllm/thirdparty_files/einops-0.7.0.dist-info/INSTALLER deleted file mode 100644 index a1b589e38a32..000000000000 --- a/vllm/thirdparty_files/einops-0.7.0.dist-info/INSTALLER +++ /dev/null @@ -1 +0,0 @@ -pip diff --git a/vllm/thirdparty_files/einops-0.7.0.dist-info/METADATA b/vllm/thirdparty_files/einops-0.7.0.dist-info/METADATA deleted file mode 100644 index 9cbf7ce09d84..000000000000 --- a/vllm/thirdparty_files/einops-0.7.0.dist-info/METADATA +++ /dev/null @@ -1,360 +0,0 @@ -Metadata-Version: 2.1 -Name: einops -Version: 0.7.0 -Summary: A new flavour of deep learning operations -Project-URL: Homepage, https://github.com/arogozhnikov/einops -Author: Alex Rogozhnikov -License: MIT -License-File: LICENSE -Keywords: deep learning,einops,machine learning,neural networks,scientific computations,tensor manipulation -Classifier: Intended Audience :: Science/Research -Classifier: License :: OSI Approved :: MIT License -Classifier: Programming Language :: Python :: 3 -Requires-Python: >=3.8 -Description-Content-Type: text/markdown - - - - - - -https://user-images.githubusercontent.com/6318811/177030658-66f0eb5d-e136-44d8-99c9-86ae298ead5b.mp4 - - - - -# einops -[![Run tests](https://github.com/arogozhnikov/einops/actions/workflows/run_tests.yml/badge.svg)](https://github.com/arogozhnikov/einops/actions/workflows/run_tests.yml) -[![PyPI version](https://badge.fury.io/py/einops.svg)](https://badge.fury.io/py/einops) -[![Documentation](https://img.shields.io/badge/documentation-link-blue.svg)](https://einops.rocks/) -![Supported python versions](https://raw.githubusercontent.com/arogozhnikov/einops/master/docs/resources/python_badge.svg) - - -Flexible and powerful tensor operations for readable and reliable code.
-Supports numpy, pytorch, tensorflow, jax, and [others](#supported-frameworks). - -## Recent updates: - -- 0.7.0: no-hassle `torch.compile`, support of [array api standard](https://data-apis.org/array-api/latest/API_specification/index.html) and more -- 10'000🎉: github reports that more than 10k project use einops -- see how to use einops with [torch.compile](https://github.com/arogozhnikov/einops/wiki/Using-torch.compile-with-einops) -- einops 0.6.1: paddle backend added -- einops 0.6 introduces [packing and unpacking](https://github.com/arogozhnikov/einops/blob/master/docs/4-pack-and-unpack.ipynb) -- einops 0.5: einsum is now a part of einops -- [Einops paper](https://openreview.net/pdf?id=oapKSVM2bcj) is accepted for oral presentation at ICLR 2022 (yes, it worth reading). - Talk recordings are [available](https://iclr.cc/virtual/2022/oral/6603) - - -
-Previous updates -- flax and oneflow backend added -- torch.jit.script is supported for pytorch layers -- powerful EinMix added to einops. [Einmix tutorial notebook](https://github.com/arogozhnikov/einops/blob/master/docs/3-einmix-layer.ipynb) -
- - - -## Tweets - -> In case you need convincing arguments for setting aside time to learn about einsum and einops... -[Tim Rocktäschel, FAIR](https://twitter.com/_rockt/status/1230818967205425152) - -> Writing better code with PyTorch and einops 👌 -[Andrej Karpathy, AI at Tesla](https://twitter.com/karpathy/status/1290826075916779520) - -> Slowly but surely, einops is seeping in to every nook and cranny of my code. If you find yourself shuffling around bazillion dimensional tensors, this might change your life -[Nasim Rahaman, MILA (Montreal)](https://twitter.com/nasim_rahaman/status/1216022614755463169) - -[More testimonials](https://einops.rocks/pages/testimonials/) - - - -## Contents - -- [Installation](#Installation) -- [Documentation](https://einops.rocks/) -- [Tutorial](#Tutorials) -- [API micro-reference](#API) -- [Why using einops](#Why-using-einops-notation) -- [Supported frameworks](#Supported-frameworks) -- [Citing](#Citing) -- [Repository](https://github.com/arogozhnikov/einops) and [discussions](https://github.com/arogozhnikov/einops/discussions) - -## Installation - -Plain and simple: -```bash -pip install einops -``` - - - -## Tutorials - -Tutorials are the most convenient way to see `einops` in action - -- part 1: [einops fundamentals](https://github.com/arogozhnikov/einops/blob/master/docs/1-einops-basics.ipynb) -- part 2: [einops for deep learning](https://github.com/arogozhnikov/einops/blob/master/docs/2-einops-for-deep-learning.ipynb) -- part 3: [packing and unpacking](https://github.com/arogozhnikov/einops/blob/master/docs/4-pack-and-unpack.ipynb) -- part 4: [improve pytorch code with einops](http://einops.rocks/pytorch-examples.html) - -Kapil Sachdeva recorded a small [intro to einops](https://www.youtube.com/watch?v=xGy75Pjsqzo). - -## API - -`einops` has a minimalistic yet powerful API. - -Three core operations provided ([einops tutorial](https://github.com/arogozhnikov/einops/blob/master/docs/) -shows those cover stacking, reshape, transposition, squeeze/unsqueeze, repeat, tile, concatenate, view and numerous reductions) - -```python -from einops import rearrange, reduce, repeat -# rearrange elements according to the pattern -output_tensor = rearrange(input_tensor, 't b c -> b c t') -# combine rearrangement and reduction -output_tensor = reduce(input_tensor, 'b c (h h2) (w w2) -> b h w c', 'mean', h2=2, w2=2) -# copy along a new axis -output_tensor = repeat(input_tensor, 'h w -> h w c', c=3) -``` - -Later additions to the family are `pack` and `unpack` functions (better than stack/split/concatenate): - -```python -from einops import pack, unpack -# pack and unpack allow reversibly 'packing' multiple tensors into one. -# Packed tensors may be of different dimensionality: -packed, ps = pack([class_token_bc, image_tokens_bhwc, text_tokens_btc], 'b * c') -class_emb_bc, image_emb_bhwc, text_emb_btc = unpack(transformer(packed), ps, 'b * c') -``` - -Finally, einops provides einsum with a support of multi-lettered names: - -```python -from einops import einsum, pack, unpack -# einsum is like ... einsum, generic and flexible dot-product -# but 1) axes can be multi-lettered 2) pattern goes last 3) works with multiple frameworks -C = einsum(A, B, 'b t1 head c, b t2 head c -> b head t1 t2') -``` - -### EinMix - -`EinMix` is a generic linear layer, perfect for MLP Mixers and similar architectures. 
- -### Layers - -Einops provides layers (`einops` keeps a separate version for each framework) that reflect corresponding functions - -```python -from einops.layers.torch import Rearrange, Reduce -from einops.layers.tensorflow import Rearrange, Reduce -from einops.layers.flax import Rearrange, Reduce -from einops.layers.paddle import Rearrange, Reduce -from einops.layers.keras import Rearrange, Reduce -from einops.layers.chainer import Rearrange, Reduce -``` - -
-Example of using layers within a pytorch model -Example given for pytorch, but code in other frameworks is almost identical - -```python -from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU -from einops.layers.torch import Rearrange - -model = Sequential( - ..., - Conv2d(6, 16, kernel_size=5), - MaxPool2d(kernel_size=2), - # flattening without need to write forward - Rearrange('b c h w -> b (c h w)'), - Linear(16*5*5, 120), - ReLU(), - Linear(120, 10), -) -``` - -No more flatten needed! - -Additionally, torch users will benefit from layers as those are script-able and compile-able. -
- - - - -## Naming - -`einops` stands for Einstein-Inspired Notation for operations -(though "Einstein operations" is more attractive and easier to remember). - -Notation was loosely inspired by Einstein summation (in particular by `numpy.einsum` operation). - -## Why use `einops` notation?! - - -### Semantic information (being verbose in expectations) - -```python -y = x.view(x.shape[0], -1) -y = rearrange(x, 'b c h w -> b (c h w)') -``` -While these two lines are doing the same job in *some* context, -the second one provides information about the input and output. -In other words, `einops` focuses on interface: *what is the input and output*, not *how* the output is computed. - -The next operation looks similar: - -```python -y = rearrange(x, 'time c h w -> time (c h w)') -``` -but it gives the reader a hint: -this is not an independent batch of images we are processing, -but rather a sequence (video). - -Semantic information makes the code easier to read and maintain. - -### Convenient checks - -Reconsider the same example: - -```python -y = x.view(x.shape[0], -1) # x: (batch, 256, 19, 19) -y = rearrange(x, 'b c h w -> b (c h w)') -``` -The second line checks that the input has four dimensions, -but you can also specify particular dimensions. -That's opposed to just writing comments about shapes since comments don't prevent mistakes, not tested, and without code review tend to be outdated -```python -y = x.view(x.shape[0], -1) # x: (batch, 256, 19, 19) -y = rearrange(x, 'b c h w -> b (c h w)', c=256, h=19, w=19) -``` - -### Result is strictly determined - -Below we have at least two ways to define the depth-to-space operation -```python -# depth-to-space -rearrange(x, 'b c (h h2) (w w2) -> b (c h2 w2) h w', h2=2, w2=2) -rearrange(x, 'b c (h h2) (w w2) -> b (h2 w2 c) h w', h2=2, w2=2) -``` -There are at least four more ways to do it. Which one is used by the framework? - -These details are ignored, since *usually* it makes no difference, -but it can make a big difference (e.g. if you use grouped convolutions in the next stage), -and you'd like to specify this in your code. - - -### Uniformity - -```python -reduce(x, 'b c (x dx) -> b c x', 'max', dx=2) -reduce(x, 'b c (x dx) (y dy) -> b c x y', 'max', dx=2, dy=3) -reduce(x, 'b c (x dx) (y dy) (z dz) -> b c x y z', 'max', dx=2, dy=3, dz=4) -``` -These examples demonstrated that we don't use separate operations for 1d/2d/3d pooling, -those are all defined in a uniform way. - -Space-to-depth and depth-to space are defined in many frameworks but how about width-to-height? Here you go: - -```python -rearrange(x, 'b c h (w w2) -> b c (h w2) w', w2=2) -``` - -### Framework independent behavior - -Even simple functions are defined differently by different frameworks - -```python -y = x.flatten() # or flatten(x) -``` - -Suppose `x`'s shape was `(3, 4, 5)`, then `y` has shape ... - -- numpy, pytorch, cupy, chainer: `(60,)` -- keras, tensorflow.layers, gluon: `(3, 20)` - -`einops` works the same way in all frameworks. - -### Independence of framework terminology - -Example: `tile` vs `repeat` causes lots of confusion. 
To copy image along width: -```python -np.tile(image, (1, 2)) # in numpy -image.repeat(1, 2) # pytorch's repeat ~ numpy's tile -``` - -With einops you don't need to decipher which axis was repeated: -```python -repeat(image, 'h w -> h (tile w)', tile=2) # in numpy -repeat(image, 'h w -> h (tile w)', tile=2) # in pytorch -repeat(image, 'h w -> h (tile w)', tile=2) # in tf -repeat(image, 'h w -> h (tile w)', tile=2) # in jax -repeat(image, 'h w -> h (tile w)', tile=2) # in cupy -... (etc.) -``` - -[Testimonials](https://einops.rocks/pages/testimonials/) provide users' perspective on the same question. - -## Supported frameworks - -Einops works with ... - -- [numpy](http://www.numpy.org/) -- [pytorch](https://pytorch.org/) -- [tensorflow](https://www.tensorflow.org/) -- [jax](https://github.com/google/jax) -- [cupy](https://cupy.chainer.org/) -- [chainer](https://chainer.org/) -- [tf.keras](https://www.tensorflow.org/guide/keras) -- [oneflow](https://github.com/Oneflow-Inc/oneflow) (experimental) -- [flax](https://github.com/google/flax) (experimental) -- [paddle](https://github.com/PaddlePaddle/Paddle) (experimental) - -Additionally, starting from einops 0.7.0 einops can be used with any framework that supports [Python array API standard](https://data-apis.org/array-api/latest/API_specification/index.html) - -## Citing einops - -Please use the following bibtex record - -```text -@inproceedings{ - rogozhnikov2022einops, - title={Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation}, - author={Alex Rogozhnikov}, - booktitle={International Conference on Learning Representations}, - year={2022}, - url={https://openreview.net/forum?id=oapKSVM2bcj} -} -``` - - -## Supported python versions - -`einops` works with python 3.8 or later. 
diff --git a/vllm/thirdparty_files/einops-0.7.0.dist-info/RECORD b/vllm/thirdparty_files/einops-0.7.0.dist-info/RECORD deleted file mode 100644 index bb822a67f816..000000000000 --- a/vllm/thirdparty_files/einops-0.7.0.dist-info/RECORD +++ /dev/null @@ -1,45 +0,0 @@ -einops-0.7.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -einops-0.7.0.dist-info/METADATA,sha256=iiGuU5-2fSwfbC4q8Rm0bEvaA1WN3qslFMQJxnPUSKg,13078 -einops-0.7.0.dist-info/RECORD,, -einops-0.7.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -einops-0.7.0.dist-info/WHEEL,sha256=9QBuHhg6FNW7lppboF2vKVbCGTVzsFykgRQjjlajrhA,87 -einops-0.7.0.dist-info/licenses/LICENSE,sha256=MNmENkKW9R_67K1LAe4SfpUlDFBokY1LZvyWIGcj5DQ,1073 -einops/__init__.py,sha256=kmiFcK59IAAR7yoWBeIzca6vQqU0X_ITwt_Or6XakRo,388 -einops/__pycache__/__init__.cpython-39.pyc,, -einops/__pycache__/_backends.cpython-39.pyc,, -einops/__pycache__/_torch_specific.cpython-39.pyc,, -einops/__pycache__/array_api.cpython-39.pyc,, -einops/__pycache__/einops.cpython-39.pyc,, -einops/__pycache__/packing.cpython-39.pyc,, -einops/__pycache__/parsing.cpython-39.pyc,, -einops/_backends.py,sha256=e7faFZ1DrYEGHGC-JhPIBM1_npCnYgruwzM8pG6BcqE,19693 -einops/_torch_specific.py,sha256=c9V_pqU_ayd1UE8i1CaCTkOSMeRdlDLm3WDD5q0gsvY,4138 -einops/array_api.py,sha256=rhtT1_nMDj9IYMmvdGZd0oktYvV9r1OG7w-dAiGQeUg,5211 -einops/einops.py,sha256=1Y4AW4htn2rk0jAyBxoW-CtAxZLAbxmEeDZQFx81jN0,36808 -einops/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -einops/experimental/__pycache__/__init__.cpython-39.pyc,, -einops/experimental/__pycache__/data_api_packing.cpython-39.pyc,, -einops/experimental/__pycache__/indexing.cpython-39.pyc,, -einops/experimental/data_api_packing.py,sha256=CKCxtqMek1T6BlBM6OkG4YvaynFmdWgVEvbPTDZ6DqE,4690 -einops/experimental/indexing.py,sha256=f4d3lVVuS0bzGKYczlpk3Y5U_kk6_AWt887i_2qzjv4,14777 -einops/layers/__init__.py,sha256=n5FI4v2Zs2qLeI1FuuK_7AfyauRiVtsnQTw0c_mr0IM,3747 -einops/layers/__pycache__/__init__.cpython-39.pyc,, -einops/layers/__pycache__/_einmix.cpython-39.pyc,, -einops/layers/__pycache__/chainer.cpython-39.pyc,, -einops/layers/__pycache__/flax.cpython-39.pyc,, -einops/layers/__pycache__/keras.cpython-39.pyc,, -einops/layers/__pycache__/oneflow.cpython-39.pyc,, -einops/layers/__pycache__/paddle.cpython-39.pyc,, -einops/layers/__pycache__/tensorflow.cpython-39.pyc,, -einops/layers/__pycache__/torch.cpython-39.pyc,, -einops/layers/_einmix.py,sha256=szayQWYvJzCwOhdfXJY72A7EtBYrNc91nzvOm-uMREs,8578 -einops/layers/chainer.py,sha256=BPvzqGV9-3xpoXHsVfkIR5wbsvmN0QnPJXY_mdsD2Dk,1981 -einops/layers/flax.py,sha256=rTgGl4HvCwhIoJYLhjpCG_nB3aWC3xGHV7v7q8a2qe8,2621 -einops/layers/keras.py,sha256=RTsR-aim1Sco5VXI2W1Qs639hJRJ0hWIilTZCs3Ftn4,212 -einops/layers/oneflow.py,sha256=XlZugzWGgAp3eGZLM_UsusoU-5Mws2Qfwwn8yywLYWs,2046 -einops/layers/paddle.py,sha256=-nvq9jJitZ8QWN31SdcEPXtjzh2xjpJGLdifOSOC_48,2063 -einops/layers/tensorflow.py,sha256=xEFbNU681Jva6Qy4AjHwwlmxdJfE28sTXDfhYxpF5jg,3299 -einops/layers/torch.py,sha256=Gj3tZ8sSnrnC-RvlK_6zTpgS89jdQofils30vikUVnE,2595 -einops/packing.py,sha256=z9VfP8SWgNTvi76atmYslw608ATkf8vKZHSTUa9srrI,7668 -einops/parsing.py,sha256=ChR0sKBmN2z--lyEpZvExN7gZdjTgm4V3cera0Yf4AM,6717 -einops/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 diff --git a/vllm/thirdparty_files/einops-0.7.0.dist-info/REQUESTED b/vllm/thirdparty_files/einops-0.7.0.dist-info/REQUESTED deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git 
a/vllm/thirdparty_files/einops-0.7.0.dist-info/WHEEL b/vllm/thirdparty_files/einops-0.7.0.dist-info/WHEEL deleted file mode 100644 index ba1a8af28bcc..000000000000 --- a/vllm/thirdparty_files/einops-0.7.0.dist-info/WHEEL +++ /dev/null @@ -1,4 +0,0 @@ -Wheel-Version: 1.0 -Generator: hatchling 1.18.0 -Root-Is-Purelib: true -Tag: py3-none-any diff --git a/vllm/thirdparty_files/einops-0.7.0.dist-info/licenses/LICENSE b/vllm/thirdparty_files/einops-0.7.0.dist-info/licenses/LICENSE deleted file mode 100644 index 3a654e906619..000000000000 --- a/vllm/thirdparty_files/einops-0.7.0.dist-info/licenses/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2018 Alex Rogozhnikov - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/vllm/thirdparty_files/einops/__init__.py b/vllm/thirdparty_files/einops/__init__.py deleted file mode 100644 index a24af7ac7ac3..000000000000 --- a/vllm/thirdparty_files/einops/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -__author__ = 'Alex Rogozhnikov' -__version__ = '0.7.0' - - -class EinopsError(RuntimeError): - """ Runtime error thrown by einops """ - pass - - -__all__ = ['rearrange', 'reduce', 'repeat', 'einsum', - 'pack', 'unpack', - 'parse_shape', 'asnumpy', 'EinopsError'] - -from .einops import rearrange, reduce, repeat, einsum, parse_shape, asnumpy -from .packing import pack, unpack \ No newline at end of file diff --git a/vllm/thirdparty_files/einops/_backends.py b/vllm/thirdparty_files/einops/_backends.py deleted file mode 100644 index 40e0502ae76b..000000000000 --- a/vllm/thirdparty_files/einops/_backends.py +++ /dev/null @@ -1,662 +0,0 @@ -""" -Backends in `einops` are organized to meet the following requirements -- backends are not imported unless those are actually needed, because - - backends may not be installed - - importing all available backends will drive to significant memory footprint - - backends may be present but installed with errors (but never used), - importing may drive to crashes -- backend should be either symbolic or imperative - - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined -- if backend can't provide symbols for shape dimensions, UnknownSize objects are used -""" - -import sys - -__author__ = "Alex Rogozhnikov" - -_loaded_backends: dict = {} -_type2backend: dict = {} -_debug_importing = False - - -def get_backend(tensor) -> "AbstractBackend": - """ - Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor. 
- If needed, imports package and creates backend - """ - _type = type(tensor) - _result = _type2backend.get(_type, None) - if _result is not None: - return _result - - for framework_name, backend in list(_loaded_backends.items()): - if backend.is_appropriate_type(tensor): - _type2backend[_type] = backend - return backend - - # Find backend subclasses recursively - backend_subclasses = [] - backends = AbstractBackend.__subclasses__() - while backends: - backend = backends.pop() - backends += backend.__subclasses__() - backend_subclasses.append(backend) - - for BackendSubclass in backend_subclasses: - if _debug_importing: - print("Testing for subclass of ", BackendSubclass) - if BackendSubclass.framework_name not in _loaded_backends: - # check that module was already imported. Otherwise it can't be imported - if BackendSubclass.framework_name in sys.modules: - if _debug_importing: - print("Imported backend for ", BackendSubclass.framework_name) - backend = BackendSubclass() - _loaded_backends[backend.framework_name] = backend - if backend.is_appropriate_type(tensor): - _type2backend[_type] = backend - return backend - - raise RuntimeError("Tensor type unknown to einops {}".format(type(tensor))) - - -class AbstractBackend: - """Base backend class, major part of methods are only for debugging purposes.""" - - framework_name: str - - def is_appropriate_type(self, tensor): - """helper method should recognize tensors it can handle""" - raise NotImplementedError() - - def from_numpy(self, x): - raise NotImplementedError("framework doesn't support imperative execution") - - def to_numpy(self, x): - raise NotImplementedError("framework doesn't support imperative execution") - - def create_symbol(self, shape): - raise NotImplementedError("framework doesn't support symbolic computations") - - def eval_symbol(self, symbol, input_dict): - raise NotImplementedError("framework doesn't support symbolic computations") - - def arange(self, start, stop): - # supplementary method used only in testing, so should implement CPU version - raise NotImplementedError("framework doesn't implement arange") - - def shape(self, x): - """shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)""" - return x.shape - - def reshape(self, x, shape): - return x.reshape(shape) - - def transpose(self, x, axes): - return x.transpose(axes) - - def reduce(self, x, operation, axes): - return getattr(x, operation)(axis=axes) - - def stack_on_zeroth_dimension(self, tensors: list): - raise NotImplementedError() - - def add_axis(self, x, new_position): - raise NotImplementedError() - - def add_axes(self, x, n_axes, pos2len): - repeats = [1] * n_axes - for axis_position, axis_length in pos2len.items(): - x = self.add_axis(x, axis_position) - repeats[axis_position] = axis_length - return self.tile(x, tuple(repeats)) - - def tile(self, x, repeats): - """repeats - same lengths as x.shape""" - raise NotImplementedError() - - def concat(self, tensors, axis: int): - """concatenates tensors along axis. - Assume identical across tensors: devices, dtypes and shapes except selected axis.""" - raise NotImplementedError() - - def is_float_type(self, x): - # some backends (torch) can't compute average for non-floating types. 
- # Decided to drop average for all backends if type is not floating - raise NotImplementedError() - - def layers(self): - raise NotImplementedError("backend does not provide layers") - - def __repr__(self): - return "".format(self.framework_name) - - def einsum(self, pattern, *x): - raise NotImplementedError("backend does not support einsum") - - -class UnknownSize: - """pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements""" - - def __floordiv__(self, other): - return self - - def __eq__(self, other): - return True # we don't know actual size - - def __mul__(self, other): - return self - - def __rmul__(self, other): - return self - - def __hash__(self): - return hash(None) - - -class NumpyBackend(AbstractBackend): - framework_name = "numpy" - - def __init__(self): - import numpy - - self.np = numpy - - def is_appropriate_type(self, tensor): - return isinstance(tensor, self.np.ndarray) - - def from_numpy(self, x): - return x - - def to_numpy(self, x): - return x - - def arange(self, start, stop): - return self.np.arange(start, stop) - - def stack_on_zeroth_dimension(self, tensors: list): - return self.np.stack(tensors) - - def tile(self, x, repeats): - return self.np.tile(x, repeats) - - def concat(self, tensors, axis: int): - return self.np.concatenate(tensors, axis=axis) - - def is_float_type(self, x): - return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16") - - def add_axis(self, x, new_position): - return self.np.expand_dims(x, new_position) - - def einsum(self, pattern, *x): - return self.np.einsum(pattern, *x) - - -class JaxBackend(NumpyBackend): - framework_name = "jax" - - def __init__(self): - super(JaxBackend, self).__init__() - self.onp = self.np - - import jax.numpy - - self.np = jax.numpy - - def from_numpy(self, x): - return self.np.asarray(x) - - def to_numpy(self, x): - return self.onp.asarray(x) - - -class TorchBackend(AbstractBackend): - framework_name = "torch" - - def __init__(self): - import torch - - self.torch = torch - # importing would register operations in torch._dynamo for torch.compile - from . 
import _torch_specific # noqa - - def is_appropriate_type(self, tensor): - return isinstance(tensor, self.torch.Tensor) - - def from_numpy(self, x): - variable = self.torch.from_numpy(x) - if self.is_float_type(variable): - # attach grad only to floating types - variable.requires_grad = True - return variable - - def to_numpy(self, x): - return x.detach().cpu().numpy() - - def arange(self, start, stop): - return self.torch.arange(start, stop, dtype=self.torch.int64) - - def reduce(self, x, operation, reduced_axes): - if operation == "min": - return x.amin(dim=reduced_axes) - elif operation == "max": - return x.amax(dim=reduced_axes) - elif operation == "sum": - return x.sum(dim=reduced_axes) - elif operation == "mean": - return x.mean(dim=reduced_axes) - elif operation in ("any", "all", "prod"): - # pytorch supports reducing only one operation at a time - for i in list(sorted(reduced_axes))[::-1]: - x = getattr(x, operation)(dim=i) - return x - else: - raise NotImplementedError("Unknown reduction ", operation) - - def transpose(self, x, axes): - return x.permute(axes) - - def stack_on_zeroth_dimension(self, tensors: list): - return self.torch.stack(tensors) - - def add_axes(self, x, n_axes, pos2len): - repeats = [-1] * n_axes - for axis_position, axis_length in pos2len.items(): - x = self.add_axis(x, axis_position) - repeats[axis_position] = axis_length - return x.expand(repeats) - - def tile(self, x, repeats): - return x.repeat(repeats) - - def concat(self, tensors, axis: int): - return self.torch.cat(tensors, dim=axis) - - def add_axis(self, x, new_position): - return self.torch.unsqueeze(x, new_position) - - def is_float_type(self, x): - return x.dtype in [self.torch.float16, self.torch.float32, self.torch.float64, self.torch.bfloat16] - - def layers(self): - from .layers import torch - - return torch - - def einsum(self, pattern, *x): - return self.torch.einsum(pattern, *x) - - -class CupyBackend(AbstractBackend): - framework_name = "cupy" - - def __init__(self): - import cupy - - self.cupy = cupy - - def is_appropriate_type(self, tensor): - return isinstance(tensor, self.cupy.ndarray) - - def from_numpy(self, x): - return self.cupy.asarray(x) - - def to_numpy(self, x): - return self.cupy.asnumpy(x) - - def arange(self, start, stop): - return self.cupy.arange(start, stop) - - def stack_on_zeroth_dimension(self, tensors: list): - return self.cupy.stack(tensors) - - def tile(self, x, repeats): - return self.cupy.tile(x, repeats) - - def concat(self, tensors, axis: int): - return self.cupy.concatenate(tensors, axis=axis) - - def add_axis(self, x, new_position): - return self.cupy.expand_dims(x, new_position) - - def is_float_type(self, x): - return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16") - - def einsum(self, pattern, *x): - return self.cupy.einsum(pattern, *x) - - -class ChainerBackend(AbstractBackend): - framework_name = "chainer" - - def __init__(self): - import chainer - import numpy - - self.numpy = numpy - self.chainer = chainer - - def is_appropriate_type(self, tensor): - return isinstance(tensor, self.chainer.Variable) - - def from_numpy(self, x): - return self.chainer.Variable(x.astype("float32")) - - def to_numpy(self, x): - if isinstance(x, self.chainer.Variable): - x = x.data - return x - - def arange(self, start, stop): - return self.numpy.arange(start, stop) - - def reduce(self, x, operation, axes): - return getattr(self.chainer.functions, operation)(x, axis=axes) - - def stack_on_zeroth_dimension(self, tensors: list): - return 
self.chainer.functions.stack(tensors) - - def tile(self, x, repeats): - return self.chainer.functions.tile(x, repeats) - - def concat(self, tensors, axis: int): - return self.chainer.functions.concat(tensors, axis=axis) - - def add_axis(self, x, new_position): - return self.chainer.functions.expand_dims(x, new_position) - - def is_float_type(self, x): - return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16") - - def layers(self): - from .layers import chainer - - return chainer - - def einsum(self, pattern, *x): - return self.chainer.functions.einsum(pattern, *x) - - -class HashableTuple: - """Overcomes non-hashability of symbolic elements""" - - def __init__(self, elements: tuple): - self.elements = elements - - def __iter__(self): - for x in self.elements: - yield x - - def __len__(self): - return len(self.elements) - - def __getitem__(self, item): - return self.elements[item] - - # default equality and hash is used (True only with itself, hash taken of id) - - -class TensorflowBackend(AbstractBackend): - framework_name = "tensorflow" - - def __init__(self): - import tensorflow - - self.tf = tensorflow - - def is_appropriate_type(self, tensor): - return isinstance(tensor, (self.tf.Tensor, self.tf.Variable)) - - def from_numpy(self, x): - assert self.tf.executing_eagerly() - return self.tf.convert_to_tensor(x) - - def to_numpy(self, x): - assert self.tf.executing_eagerly() - return x.numpy() - - def arange(self, start, stop): - return self.tf.range(start, stop) - - def shape(self, x): - if self.tf.executing_eagerly(): - return tuple(UnknownSize() if d is None else int(d) for d in x.shape) - else: - static_shape = x.shape.as_list() - tf_shape = self.tf.shape(x) - # use the static shape where known, otherwise use the TF shape components - shape = tuple([s or tf_shape[dim] for dim, s in enumerate(static_shape)]) - try: - hash(shape) - return shape - except: - # unhashable symbols in shape. Wrap tuple to be hashable. 
- return HashableTuple(shape) - - def reduce(self, x, operation, axes): - return getattr(self.tf, "reduce_" + operation)(x, axis=axes) - - def reshape(self, x, shape): - return self.tf.reshape(x, shape) - - def transpose(self, x, axes): - return self.tf.transpose(x, axes) - - def stack_on_zeroth_dimension(self, tensors: list): - return self.tf.stack(tensors) - - def tile(self, x, repeats): - return self.tf.tile(x, repeats) - - def concat(self, tensors, axis: int): - return self.tf.concat(tensors, axis=axis) - - def add_axis(self, x, new_position): - return self.tf.expand_dims(x, new_position) - - def is_float_type(self, x): - return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16") - - def layers(self): - from .layers import tensorflow - - return tensorflow - - def einsum(self, pattern, *x): - return self.tf.einsum(pattern, *x) - - -class KerasBackend(AbstractBackend): - framework_name = "tensorflow.keras" - - def __init__(self): - import tensorflow as tf - - self.tf = tf - self.keras = tf.keras - self.K = tf.keras.backend - - def is_appropriate_type(self, tensor): - return self.tf.is_tensor(tensor) and self.K.is_keras_tensor(tensor) - - def create_symbol(self, shape): - return self.keras.Input(batch_shape=shape) - - def eval_symbol(self, symbol, input_dict): - model = self.keras.models.Model([var for (var, _) in input_dict], symbol) - return model.predict_on_batch([val for (_, val) in input_dict]) - - def arange(self, start, stop): - return self.K.arange(start, stop) - - def shape(self, x): - shape = self.K.shape(x) # tf tensor - return HashableTuple(tuple(shape)) - - def reduce(self, x, operation, axes): - return getattr(self.K, operation)(x, axis=axes) - - def reshape(self, x, shape): - return self.K.reshape(x, shape) - - def transpose(self, x, axes): - return self.K.permute_dimensions(x, axes) - - def stack_on_zeroth_dimension(self, tensors: list): - return self.K.stack(tensors) - - def tile(self, x, repeats): - return self.K.tile(x, repeats) - - def concat(self, tensors, axis: int): - return self.K.concatenate(tensors, axis=axis) - - def add_axis(self, x, new_position): - return self.K.expand_dims(x, new_position) - - def is_float_type(self, x): - return "float" in self.K.dtype(x) - - def layers(self): - from .layers import keras - - return keras - - -class OneFlowBackend(AbstractBackend): - framework_name = "oneflow" - - def __init__(self): - import oneflow as flow - - self.flow = flow - - def is_appropriate_type(self, tensor): - return isinstance(tensor, self.flow.Tensor) - - def from_numpy(self, x): - variable = self.flow.from_numpy(x) - if self.is_float_type(variable): - # attach grad only to floating types - variable.requires_grad = True - return variable - - def to_numpy(self, x): - return x.detach().cpu().numpy() - - def arange(self, start, stop): - return self.flow.arange(start, stop, dtype=self.flow.int64) - - def reduce(self, x, operation, reduced_axes): - for axis in sorted(reduced_axes, reverse=True): - if operation == "min": - x, _ = x.min(dim=axis) - elif operation == "max": - x, _ = x.max(dim=axis) - elif operation in ["sum", "mean", "prod", "any", "all"]: - x = getattr(x, operation)(dim=axis) - else: - raise NotImplementedError("Unknown reduction ", operation) - return x - - def transpose(self, x, axes): - return x.permute(axes) - - def stack_on_zeroth_dimension(self, tensors: list): - return self.flow.stack(tensors) - - def add_axes(self, x, n_axes, pos2len): - repeats = [-1] * n_axes - for axis_position, axis_length in pos2len.items(): - x = 
self.add_axis(x, axis_position) - repeats[axis_position] = axis_length - return x.expand(*repeats) - - def tile(self, x, repeats): - return x.repeat(repeats) - - def concat(self, tensors, axis: int): - return self.flow.concat(tensors, dim=axis) - - def add_axis(self, x, new_position): - return self.flow.unsqueeze(x, new_position) - - def is_float_type(self, x): - return x.dtype in [self.flow.float16, self.flow.float32, self.flow.float64] - - def layers(self): - from .layers import oneflow - - return oneflow - - def einsum(self, pattern, *x): - return self.flow.einsum(pattern, *x) - - -class PaddleBackend(AbstractBackend): - framework_name = "paddle" - - def __init__(self): - import paddle - - self.paddle = paddle - - def is_appropriate_type(self, tensor): - return isinstance(tensor, (self.paddle.Tensor, self.paddle.static.Variable)) - - def from_numpy(self, x): - tensor = self.paddle.to_tensor(x) - tensor.stop_gradient = False - return tensor - - def to_numpy(self, x): - return x.detach().numpy() - - def arange(self, start, stop): - return self.paddle.arange(start, stop, dtype=self.paddle.int64) - - def reduce(self, x, operation, axes): - if len(axes) == x.ndim: - # currently paddle returns 1d tensor instead of 0d - return super().reduce(x, operation, axes).squeeze(0) - else: - return super().reduce(x, operation, axes) - - def transpose(self, x, axes): - return x.transpose(axes) - - def add_axes(self, x, n_axes, pos2len): - repeats = [-1] * n_axes - for axis_position, axis_length in pos2len.items(): - x = self.add_axis(x, axis_position) - repeats[axis_position] = axis_length - return x.expand(repeats) - - def stack_on_zeroth_dimension(self, tensors: list): - return self.paddle.stack(tensors) - - def reshape(self, x, shape): - return x.reshape(shape) - - def tile(self, x, repeats): - return x.tile(repeats) - - def concat(self, tensors, axis: int): - return self.paddle.concat(tensors, axis=axis) - - def add_axis(self, x, new_position): - return x.unsqueeze(new_position) - - def is_float_type(self, x): - return x.dtype in [self.paddle.float16, self.paddle.float32, self.paddle.float64] - - def layers(self): - from .layers import paddle - - return paddle - - def einsum(self, pattern, *x): - return self.paddle.einsum(pattern, *x) - - def shape(self, x): - return tuple(x.shape) diff --git a/vllm/thirdparty_files/einops/_torch_specific.py b/vllm/thirdparty_files/einops/_torch_specific.py deleted file mode 100644 index 670459c27d79..000000000000 --- a/vllm/thirdparty_files/einops/_torch_specific.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -Specialization of einops for torch. - -Unfortunately, torch's jit scripting mechanism isn't strong enough, -and to have scripting supported at least for layers, -a number of additional moves is needed. - -Design of main operations (dynamic resolution by lookup) is unlikely -to be implemented by torch.jit.script, -but torch.compile seems to work with operations just fine. -""" -import warnings -from typing import Dict, List, Tuple - -import torch -from einops.einops import TransformRecipe, _reconstruct_from_shape_uncached - - -class TorchJitBackend: - """ - Completely static backend that mimics part of normal backend functionality - but restricted to be within torchscript. 
- """ - - @staticmethod - def reduce(x: torch.Tensor, operation: str, reduced_axes: List[int]): - if operation == "min": - return x.amin(dim=reduced_axes) - elif operation == "max": - return x.amax(dim=reduced_axes) - elif operation == "sum": - return x.sum(dim=reduced_axes) - elif operation == "mean": - return x.mean(dim=reduced_axes) - elif operation == "prod": - for i in list(sorted(reduced_axes))[::-1]: - x = x.prod(dim=i) - return x - else: - raise NotImplementedError("Unknown reduction ", operation) - - @staticmethod - def transpose(x, axes: List[int]): - return x.permute(axes) - - @staticmethod - def stack_on_zeroth_dimension(tensors: List[torch.Tensor]): - return torch.stack(tensors) - - @staticmethod - def tile(x, repeats: List[int]): - return x.repeat(repeats) - - @staticmethod - def add_axes(x, n_axes: int, pos2len: Dict[int, int]): - repeats = [-1] * n_axes - for axis_position, axis_length in pos2len.items(): - x = torch.unsqueeze(x, axis_position) - repeats[axis_position] = axis_length - return x.expand(repeats) - - @staticmethod - def is_float_type(x): - return x.dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16] - - @staticmethod - def shape(x): - return x.shape - - @staticmethod - def reshape(x, shape: List[int]): - return x.reshape(shape) - - -# mirrors einops.einops._apply_recipe -def apply_for_scriptable_torch( - recipe: TransformRecipe, tensor: torch.Tensor, reduction_type: str, axes_dims: List[Tuple[str, int]] -) -> torch.Tensor: - backend = TorchJitBackend - ( - init_shapes, - axes_reordering, - reduced_axes, - added_axes, - final_shapes, - n_axes_w_added, - ) = _reconstruct_from_shape_uncached(recipe, backend.shape(tensor), axes_dims=axes_dims) - if init_shapes is not None: - tensor = backend.reshape(tensor, init_shapes) - if axes_reordering is not None: - tensor = backend.transpose(tensor, axes_reordering) - if len(reduced_axes) > 0: - tensor = backend.reduce(tensor, operation=reduction_type, reduced_axes=reduced_axes) - if len(added_axes) > 0: - tensor = backend.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes) - if final_shapes is not None: - tensor = backend.reshape(tensor, final_shapes) - return tensor - - -def allow_ops_in_compiled_graph(): - if hasattr(torch, "__version__") and torch.__version__[0] < "2": - # torch._dynamo and torch.compile appear in pytorch 2.0 - return - try: - from torch._dynamo import allow_in_graph - except ImportError: - warnings.warn("allow_ops_in_compiled_graph failed to import torch: ensure pytorch >=2.0", ImportWarning) - return - - from .einops import rearrange, reduce, repeat, einsum - from .packing import pack, unpack - - allow_in_graph(rearrange) - allow_in_graph(reduce) - allow_in_graph(repeat) - allow_in_graph(einsum) - allow_in_graph(pack) - allow_in_graph(unpack) - - # CF: https://github.com/pytorch/pytorch/blob/2df939aacac68e9621fbd5d876c78d86e72b41e2/torch/_dynamo/__init__.py#L222 - global _ops_were_registered_in_torchdynamo - _ops_were_registered_in_torchdynamo = True - - -# module import automatically registers ops in torchdynamo -allow_ops_in_compiled_graph() diff --git a/vllm/thirdparty_files/einops/array_api.py b/vllm/thirdparty_files/einops/array_api.py deleted file mode 100644 index e150d07def54..000000000000 --- a/vllm/thirdparty_files/einops/array_api.py +++ /dev/null @@ -1,119 +0,0 @@ -from typing import List, Tuple, Sequence -from .einops import Tensor, Reduction, EinopsError, _prepare_transformation_recipe, _apply_recipe_array_api -from .packing import 
analyze_pattern, prod - - -def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor: - if isinstance(tensor, list): - if len(tensor) == 0: - raise TypeError("Einops can't be applied to an empty list") - xp = tensor[0].__array_namespace__() - tensor = xp.stack(tensor) - else: - xp = tensor.__array_namespace__() - try: - hashable_axes_lengths = tuple(axes_lengths.items()) - recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=tensor.ndim) - return _apply_recipe_array_api( - xp, - recipe=recipe, tensor=tensor, reduction_type=reduction, axes_lengths=hashable_axes_lengths, - ) - except EinopsError as e: - message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern) - if not isinstance(tensor, list): - message += "\n Input tensor shape: {}. ".format(tensor.shape) - else: - message += "\n Input is list. " - message += "Additional info: {}.".format(axes_lengths) - raise EinopsError(message + "\n {}".format(e)) - - - -def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor: - return reduce(tensor, pattern, reduction="repeat", **axes_lengths) - - -def rearrange(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor: - return reduce(tensor, pattern, reduction="rearrange", **axes_lengths) - - -def asnumpy(tensor: Tensor): - import numpy as np - return np.from_dlpack(tensor) - -Shape = Tuple - -def pack(tensors: Sequence[Tensor], pattern: str) -> Tuple[Tensor, List[Shape]]: - n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, 'pack') - xp = tensors[0].__array_namespace__() - - reshaped_tensors: List[Tensor] = [] - packed_shapes: List[Shape] = [] - for i, tensor in enumerate(tensors): - shape = tensor.shape - if len(shape) < min_axes: - raise EinopsError(f'packed tensor #{i} (enumeration starts with 0) has shape {shape}, ' - f'while pattern {pattern} assumes at least {min_axes} axes') - axis_after_packed_axes = len(shape) - n_axes_after - packed_shapes.append(shape[n_axes_before:axis_after_packed_axes]) - reshaped_tensors.append(xp.reshape(tensor, (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:]))) - - return xp.concat(reshaped_tensors, axis=n_axes_before), packed_shapes - - - -def unpack(tensor: Tensor, packed_shapes: List[Shape], pattern: str) -> List[Tensor]: - xp = tensor.__array_namespace__() - n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, opname='unpack') - - # backend = get_backend(tensor) - input_shape = tensor.shape - if len(input_shape) != n_axes_before + 1 + n_axes_after: - raise EinopsError(f'unpack(..., {pattern}) received input of wrong dim with shape {input_shape}') - - unpacked_axis: int = n_axes_before - - lengths_of_composed_axes: List[int] = [ - -1 if -1 in p_shape else prod(p_shape) - for p_shape in packed_shapes - ] - - n_unknown_composed_axes = sum(x == -1 for x in lengths_of_composed_axes) - if n_unknown_composed_axes > 1: - raise EinopsError( - f"unpack(..., {pattern}) received more than one -1 in {packed_shapes} and can't infer dimensions" - ) - - # following manipulations allow to skip some shape verifications - # and leave it to backends - - # [[], [2, 3], [4], [-1, 5], [6]] < examples of packed_axis - # split positions when computed should be - # [0, 1, 7, 11, N-6 , N ], where N = length of axis - split_positions = [0] * len(packed_shapes) + [input_shape[unpacked_axis]] - if n_unknown_composed_axes == 0: - for i, x in enumerate(lengths_of_composed_axes[:-1]): - split_positions[i + 1] = split_positions[i] + x - 
else: - unknown_composed_axis: int = lengths_of_composed_axes.index(-1) - for i in range(unknown_composed_axis): - split_positions[i + 1] = split_positions[i] + lengths_of_composed_axes[i] - for j in range(unknown_composed_axis + 1, len(lengths_of_composed_axes))[::-1]: - split_positions[j] = split_positions[j + 1] - lengths_of_composed_axes[j] - - shape_start = input_shape[:unpacked_axis] - shape_end = input_shape[unpacked_axis + 1:] - slice_filler = (slice(None, None),) * unpacked_axis - try: - return [ - xp.reshape( - # shortest way slice arbitrary axis - tensor[(*slice_filler, slice(split_positions[i], split_positions[i + 1]), ...)], - (*shape_start, *element_shape, *shape_end) - ) - for i, element_shape in enumerate(packed_shapes) - ] - except BaseException: - # this hits if there is an error during reshapes, which means passed shapes were incorrect - raise RuntimeError(f'Error during unpack(..., "{pattern}"): could not split axis of size {split_positions[-1]}' - f' into requested {packed_shapes}') diff --git a/vllm/thirdparty_files/einops/einops.py b/vllm/thirdparty_files/einops/einops.py deleted file mode 100644 index bb9b39e8aa45..000000000000 --- a/vllm/thirdparty_files/einops/einops.py +++ /dev/null @@ -1,901 +0,0 @@ -import functools -import itertools -import string -import typing -from collections import OrderedDict -from typing import Set, Tuple, List, Dict, Union, Callable, Optional, TypeVar, cast, Any - -if typing.TYPE_CHECKING: - # for docstrings in pycharm - import numpy as np - -from . import EinopsError -from ._backends import get_backend -from .parsing import ParsedExpression, _ellipsis, AnonymousAxis - -Tensor = TypeVar("Tensor") -ReductionCallable = Callable[[Tensor, Tuple[int, ...]], Tensor] -Reduction = Union[str, ReductionCallable] - -_reductions = ("min", "max", "sum", "mean", "prod", "any", "all") - -# magic integers are required to stay within -# traceable subset of language -_unknown_axis_length = -999999 -_expected_axis_length = -99999 - - -def _product(sequence: List[int]) -> int: - """minimalistic product that works both with numbers and symbols. 
Supports empty lists""" - result = 1 - for element in sequence: - result *= element - return result - - -def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend): - if callable(reduction_type): - # custom callable - return reduction_type(tensor, tuple(reduced_axes)) - else: - # one of built-in operations - assert reduction_type in _reductions - if reduction_type == "mean": - if not backend.is_float_type(tensor): - raise NotImplementedError("reduce_mean is not available for non-floating tensors") - return backend.reduce(tensor, reduction_type, tuple(reduced_axes)) - - -def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes): - # 'collapses' neighboring axes if those participate in the result pattern in the same order - # TODO add support for added_axes - assert len(axes_reordering) + len(reduced_axes) == len(init_shapes) - # joining consecutive axes that will be reduced - # possibly we can skip this if all backends can optimize this (not sure) - reduced_axes = tuple(sorted(reduced_axes)) - for i in range(len(reduced_axes) - 1)[::-1]: - if reduced_axes[i] + 1 == reduced_axes[i + 1]: - removed_axis = reduced_axes[i + 1] - removed_length = init_shapes[removed_axis] - init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1 :] - init_shapes[removed_axis - 1] *= removed_length - reduced_axes = reduced_axes[: i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2 :]) - - # removing axes that are moved together during reshape - def build_mapping(): - init_to_final = {} - for axis in range(len(init_shapes)): - if axis in reduced_axes: - init_to_final[axis] = None - else: - after_reduction = sum(x is not None for x in init_to_final.values()) - init_to_final[axis] = list(axes_reordering).index(after_reduction) - return init_to_final - - init_axis_to_final_axis = build_mapping() - - for init_axis in range(len(init_shapes) - 1)[::-1]: - if init_axis_to_final_axis[init_axis] is None: - continue - if init_axis_to_final_axis[init_axis + 1] is None: - continue - if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]: - removed_axis = init_axis + 1 - removed_length = init_shapes[removed_axis] - removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis)) - - reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes) - init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1 :] - init_shapes[removed_axis - 1] *= removed_length - old_reordering = axes_reordering - axes_reordering = [] - for axis in old_reordering: - if axis == removed_axis_after_reduction: - pass - elif axis < removed_axis_after_reduction: - axes_reordering.append(axis) - else: - axes_reordering.append(axis - 1) - init_axis_to_final_axis = build_mapping() - - return init_shapes, reduced_axes, axes_reordering, final_shapes - - -CookedRecipe = Tuple[Optional[List[int]], Optional[List[int]], List[int], Dict[int, int], Optional[List[int]], int] - -# Actual type is tuple[tuple[str, int], ...] -# However torch.jit.script does not "understand" the correct type, -# and torch_specific will use list version. -HashableAxesLengths = Tuple[Tuple[str, int], ...] -FakeHashableAxesLengths = List[Tuple[str, int]] - - -class TransformRecipe: - """ - Recipe describes actual computation pathway. - Recipe can be applied to a tensor or variable. - """ - - # structure is non-mutable. 
In future, this can be non-mutable dataclass (python 3.7+) - # update: pytorch 2.0 torch.jit.script seems to have problems with dataclasses unless they were explicitly provided - - def __init__( - self, - # list of sizes (or just sizes) for elementary axes as they appear in left expression. - # this is what (after computing unknown parts) will be a shape after first transposition. - # This does not include any ellipsis dimensions. - elementary_axes_lengths: List[int], - # if additional axes are provided, they should be set in prev array - # This shows mapping from name to position - axis_name2elementary_axis: Dict[str, int], - # each dimension in input can help to reconstruct length of one elementary axis - # or verify one of dimensions. Each element points to element of elementary_axes_lengths. - input_composition_known_unknown: List[Tuple[List[int], List[int]]], - # permutation applied to elementary axes, if ellipsis is absent - axes_permutation: List[int], - # permutation puts reduced axes in the end, we only need to know the first position. - first_reduced_axis: int, - # at which positions which of elementary axes should appear. Axis position -> axis index. - added_axes: Dict[int, int], - # ids of axes as they appear in result, again pointers to elementary_axes_lengths, - # only used to infer result dimensions - output_composite_axes: List[List[int]], - ): - self.elementary_axes_lengths: List[int] = elementary_axes_lengths - self.axis_name2elementary_axis: Dict[str, int] = axis_name2elementary_axis - self.input_composition_known_unknown: List[Tuple[List[int], List[int]]] = input_composition_known_unknown - self.axes_permutation: List[int] = axes_permutation - - self.first_reduced_axis: int = first_reduced_axis - self.added_axes: Dict[int, int] = added_axes - self.output_composite_axes: List[List[int]] = output_composite_axes - - -def _reconstruct_from_shape_uncached( - self: TransformRecipe, shape: List[int], axes_dims: FakeHashableAxesLengths -) -> CookedRecipe: - """ - Reconstruct all actual parameters using shape. - Shape is a tuple that may contain integers, shape symbols (tf, theano) and UnknownSize (tf, previously mxnet) - known axes can be integers or symbols, but not Nones. 
- """ - # magic number - need_init_reshape = False - - # last axis is allocated for collapsed ellipsis - axes_lengths: List[int] = list(self.elementary_axes_lengths) - for axis, dim in axes_dims: - axes_lengths[self.axis_name2elementary_axis[axis]] = dim - - for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composition_known_unknown): - length = shape[input_axis] - if len(known_axes) == 0 and len(unknown_axes) == 1: - # shortcut for the most common case - axes_lengths[unknown_axes[0]] = length - continue - - known_product = 1 - for axis in known_axes: - known_product *= axes_lengths[axis] - - if len(unknown_axes) == 0: - if isinstance(length, int) and isinstance(known_product, int) and length != known_product: - raise EinopsError(f"Shape mismatch, {length} != {known_product}") - else: - # assert len(unknown_axes) == 1, 'this is enforced when recipe is created, so commented out' - if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0: - raise EinopsError(f"Shape mismatch, can't divide axis of length {length} in chunks of {known_product}") - - unknown_axis = unknown_axes[0] - inferred_length: int = length // known_product - axes_lengths[unknown_axis] = inferred_length - - if len(known_axes) + len(unknown_axes) != 1: - need_init_reshape = True - - # at this point all axes_lengths are computed (either have values or variables, but not Nones) - - # elementary axes are ordered as they appear in input, then all added axes - init_shapes: Optional[List[int]] = axes_lengths[: len(self.axes_permutation)] if need_init_reshape else None - - need_final_reshape = False - final_shapes: List[int] = [] - for grouping in self.output_composite_axes: - lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping] - final_shapes.append(_product(lengths)) - if len(lengths) != 1: - need_final_reshape = True - - added_axes: Dict[int, int] = { - pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items() - } - - # this list can be empty - reduced_axes = list(range(self.first_reduced_axis, len(self.axes_permutation))) - - n_axes_after_adding_axes = len(added_axes) + len(self.axes_permutation) - - axes_reordering: Optional[List[int]] = self.axes_permutation - if self.axes_permutation == list(range(len(self.axes_permutation))): - axes_reordering = None - - _final_shapes = final_shapes if need_final_reshape else None - return init_shapes, axes_reordering, reduced_axes, added_axes, _final_shapes, n_axes_after_adding_axes - - -_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached) - - -def _apply_recipe( - backend, recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths -) -> Tensor: - # this method implements actual work for all backends for 3 operations - try: - init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape( - recipe, backend.shape(tensor), axes_lengths - ) - except TypeError: - # shape or one of passed axes lengths is not hashable (i.e. 
they are symbols) - _result = _reconstruct_from_shape_uncached(recipe, backend.shape(tensor), axes_lengths) - (init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added) = _result - if init_shapes is not None: - tensor = backend.reshape(tensor, init_shapes) - if axes_reordering is not None: - tensor = backend.transpose(tensor, axes_reordering) - if len(reduced_axes) > 0: - tensor = _reduce_axes(tensor, reduction_type=reduction_type, reduced_axes=reduced_axes, backend=backend) - if len(added_axes) > 0: - tensor = backend.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes) - if final_shapes is not None: - tensor = backend.reshape(tensor, final_shapes) - return tensor - - -def _apply_recipe_array_api( - xp, recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths -) -> Tensor: - # completely-inline implementation - init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape( - recipe, tensor.shape, axes_lengths - ) - if init_shapes is not None: - tensor = xp.reshape(tensor, init_shapes) - if axes_reordering is not None: - tensor = xp.permute_dims(tensor, axes_reordering) - if len(reduced_axes) > 0: - if callable(reduction_type): - # custom callable - tensor = reduction_type(tensor, tuple(reduced_axes)) - else: - # one of built-in operations - assert reduction_type in _reductions - tensor = getattr(xp, reduction_type)(tensor, axis=tuple(reduced_axes)) - if len(added_axes) > 0: - # we use broadcasting - for axis_position, axis_length in added_axes.items(): - tensor = xp.expand_dims(tensor, axis=axis_position) - - final_shape = list(tensor.shape) - for axis_position, axis_length in added_axes.items(): - final_shape[axis_position] = axis_length - - tensor = xp.broadcast_to(tensor, final_shape) - if final_shapes is not None: - tensor = xp.reshape(tensor, final_shapes) - return tensor - - -@functools.lru_cache(256) -def _prepare_transformation_recipe( - pattern: str, - operation: Reduction, - axes_names: Tuple[str, ...], - ndim: int, -) -> TransformRecipe: - """Perform initial parsing of pattern and provided supplementary info - axes_lengths is a tuple of tuples (axis_name, axis_length) - """ - left_str, rght_str = pattern.split("->") - left = ParsedExpression(left_str) - rght = ParsedExpression(rght_str) - - # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction - if not left.has_ellipsis and rght.has_ellipsis: - raise EinopsError("Ellipsis found in right side, but not left side of a pattern {}".format(pattern)) - if left.has_ellipsis and left.has_ellipsis_parenthesized: - raise EinopsError("Ellipsis inside parenthesis in the left side is not allowed: {}".format(pattern)) - if operation == "rearrange": - if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes: - raise EinopsError("Non-unitary anonymous axes are not supported in rearrange (exception is length 1)") - difference = set.symmetric_difference(left.identifiers, rght.identifiers) - if len(difference) > 0: - raise EinopsError("Identifiers only on one side of expression (should be on both): {}".format(difference)) - elif operation == "repeat": - difference = set.difference(left.identifiers, rght.identifiers) - if len(difference) > 0: - raise EinopsError("Unexpected identifiers on the left side of repeat: {}".format(difference)) - axes_without_size = set.difference( - {ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)}, - 
{*left.identifiers, *axes_names}, - ) - if len(axes_without_size) > 0: - raise EinopsError("Specify sizes for new axes in repeat: {}".format(axes_without_size)) - elif operation in _reductions or callable(operation): - difference = set.difference(rght.identifiers, left.identifiers) - if len(difference) > 0: - raise EinopsError("Unexpected identifiers on the right side of reduce {}: {}".format(operation, difference)) - else: - raise EinopsError("Unknown reduction {}. Expect one of {}.".format(operation, _reductions)) - - if left.has_ellipsis: - n_other_dims = len(left.composition) - 1 - if ndim < n_other_dims: - raise EinopsError(f"Wrong shape: expected >={n_other_dims} dims. Received {ndim}-dim tensor.") - ellipsis_ndim = ndim - n_other_dims - ell_axes = [_ellipsis + str(i) for i in range(ellipsis_ndim)] - left_composition = [] - for composite_axis in left.composition: - if composite_axis == _ellipsis: - for axis in ell_axes: - left_composition.append([axis]) - else: - left_composition.append(composite_axis) - - rght_composition = [] - for composite_axis in rght.composition: - if composite_axis == _ellipsis: - for axis in ell_axes: - rght_composition.append([axis]) - else: - group = [] - for axis in composite_axis: - if axis == _ellipsis: - group.extend(ell_axes) - else: - group.append(axis) - rght_composition.append(group) - - left.identifiers.update(ell_axes) - left.identifiers.remove(_ellipsis) - if rght.has_ellipsis: - rght.identifiers.update(ell_axes) - rght.identifiers.remove(_ellipsis) - else: - if ndim != len(left.composition): - raise EinopsError(f"Wrong shape: expected {len(left.composition)} dims. Received {ndim}-dim tensor.") - left_composition = left.composition - rght_composition = rght.composition - - # parsing all dimensions to find out lengths - axis_name2known_length: Dict[Union[str, AnonymousAxis], int] = OrderedDict() - for composite_axis in left_composition: - for axis_name in composite_axis: - if isinstance(axis_name, AnonymousAxis): - axis_name2known_length[axis_name] = axis_name.value - else: - axis_name2known_length[axis_name] = _unknown_axis_length - - # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point - - repeat_axes_names = [] - for axis_name in rght.identifiers: - if axis_name not in axis_name2known_length: - if isinstance(axis_name, AnonymousAxis): - axis_name2known_length[axis_name] = axis_name.value - else: - axis_name2known_length[axis_name] = _unknown_axis_length - repeat_axes_names.append(axis_name) - - axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)} - - # axes provided as kwargs - for elementary_axis in axes_names: - if not ParsedExpression.check_axis_name(elementary_axis): - raise EinopsError("Invalid name for an axis", elementary_axis) - if elementary_axis not in axis_name2known_length: - raise EinopsError("Axis {} is not used in transform".format(elementary_axis)) - axis_name2known_length[elementary_axis] = _expected_axis_length - - input_axes_known_unknown = [] - # some shapes are inferred later - all information is prepared for faster inference - for i, composite_axis in enumerate(left_composition): - known: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length} - unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length} - if len(unknown) > 1: - raise EinopsError("Could not infer sizes for {}".format(unknown)) - assert len(unknown) + len(known) == len(composite_axis) - 
input_axes_known_unknown.append( - ([axis_name2position[axis] for axis in known], [axis_name2position[axis] for axis in unknown]) - ) - - axis_position_after_reduction: Dict[str, int] = {} - for axis_name in itertools.chain(*left_composition): - if axis_name in rght.identifiers: - axis_position_after_reduction[axis_name] = len(axis_position_after_reduction) - - result_axes_grouping: List[List[int]] = [ - [axis_name2position[axis] for axis in composite_axis] for i, composite_axis in enumerate(rght_composition) - ] - - ordered_axis_left = list(itertools.chain(*left_composition)) - ordered_axis_rght = list(itertools.chain(*rght_composition)) - reduced_axes = [axis for axis in ordered_axis_left if axis not in rght.identifiers] - order_after_transposition = [axis for axis in ordered_axis_rght if axis in left.identifiers] + reduced_axes - axes_permutation = [ordered_axis_left.index(axis) for axis in order_after_transposition] - added_axes = { - i: axis_name2position[axis_name] - for i, axis_name in enumerate(ordered_axis_rght) - if axis_name not in left.identifiers - } - - first_reduced_axis = len(order_after_transposition) - len(reduced_axes) - - return TransformRecipe( - elementary_axes_lengths=list(axis_name2known_length.values()), - axis_name2elementary_axis={axis: axis_name2position[axis] for axis in axes_names}, - input_composition_known_unknown=input_axes_known_unknown, - axes_permutation=axes_permutation, - first_reduced_axis=first_reduced_axis, - added_axes=added_axes, - output_composite_axes=result_axes_grouping, - ) - - -def _prepare_recipes_for_all_dims( - pattern: str, operation: Reduction, axes_names: Tuple[str, ...] -) -> Dict[int, TransformRecipe]: - """ - Internal function, used in layers. - Layer makes all recipe creation when it is initialized, thus to keep recipes simple we pre-compute for all dims - """ - left_str, rght_str = pattern.split("->") - left = ParsedExpression(left_str) - dims = [len(left.composition)] - if left.has_ellipsis: - dims = [len(left.composition) - 1 + ellipsis_dims for ellipsis_dims in range(8)] - return {ndim: _prepare_transformation_recipe(pattern, operation, axes_names, ndim=ndim) for ndim in dims} - - -def reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor: - """ - einops.reduce provides combination of reordering and reduction using reader-friendly notation. 
- - Examples for reduce operation: - - ```python - >>> x = np.random.randn(100, 32, 64) - - # perform max-reduction on the first axis - >>> y = reduce(x, 't b c -> b c', 'max') - - # same as previous, but with clearer axes meaning - >>> y = reduce(x, 'time batch channel -> batch channel', 'max') - - >>> x = np.random.randn(10, 20, 30, 40) - - # 2d max-pooling with kernel size = 2 * 2 for image processing - >>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2) - - # if one wants to go back to the original height and width, depth-to-space trick can be applied - >>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2) - >>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w') - - # Adaptive 2d max-pooling to 3 * 4 grid - >>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape - (10, 20, 3, 4) - - # Global average pooling - >>> reduce(x, 'b c h w -> b c', 'mean').shape - (10, 20) - - # Subtracting mean over batch for each channel - >>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean') - - # Subtracting per-image mean for each channel - >>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean') - - ``` - - Parameters: - tensor: tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). - list of tensors is also accepted, those should be of the same type and shape - pattern: string, reduction pattern - reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive - alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided. - This allows using various reductions, examples: np.max, tf.reduce_logsumexp, torch.var, etc. - axes_lengths: any additional specifications for dimensions - - Returns: - tensor of the same type as input - """ - try: - if isinstance(tensor, list): - if len(tensor) == 0: - raise TypeError("Rearrange/Reduce/Repeat can't be applied to an empty list") - backend = get_backend(tensor[0]) - tensor = backend.stack_on_zeroth_dimension(tensor) - else: - backend = get_backend(tensor) - - hashable_axes_lengths = tuple(axes_lengths.items()) - shape = backend.shape(tensor) - recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=len(shape)) - return _apply_recipe( - backend, recipe, cast(Tensor, tensor), reduction_type=reduction, axes_lengths=hashable_axes_lengths - ) - except EinopsError as e: - message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern) - if not isinstance(tensor, list): - message += "\n Input tensor shape: {}. ".format(shape) - else: - message += "\n Input is list. " - message += "Additional info: {}.".format(axes_lengths) - raise EinopsError(message + "\n {}".format(e)) - - -def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor: - """ - einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors. - This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, - stack, concatenate and other operations. 
- - Examples for rearrange operation: - - ```python - # suppose we have a set of 32 images in "h w c" format (height-width-channel) - >>> images = [np.random.randn(30, 40, 3) for _ in range(32)] - - # stack along first (batch) axis, output is a single array - >>> rearrange(images, 'b h w c -> b h w c').shape - (32, 30, 40, 3) - - # concatenate images along height (vertical axis), 960 = 32 * 30 - >>> rearrange(images, 'b h w c -> (b h) w c').shape - (960, 40, 3) - - # concatenated images along horizontal axis, 1280 = 32 * 40 - >>> rearrange(images, 'b h w c -> h (b w) c').shape - (30, 1280, 3) - - # reordered axes to "b c h w" format for deep learning - >>> rearrange(images, 'b h w c -> b c h w').shape - (32, 3, 30, 40) - - # flattened each image into a vector, 3600 = 30 * 40 * 3 - >>> rearrange(images, 'b h w c -> b (c h w)').shape - (32, 3600) - - # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2 - >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape - (128, 15, 20, 3) - - # space-to-depth operation - >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape - (32, 15, 20, 12) - - ``` - - When composing axes, C-order enumeration used (consecutive elements have different last axis) - Find more examples in einops tutorial. - - Parameters: - tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). - list of tensors is also accepted, those should be of the same type and shape - pattern: string, rearrangement pattern - axes_lengths: any additional specifications for dimensions - - Returns: - tensor of the same type as input. If possible, a view to the original tensor is returned. - - """ - return reduce(tensor, pattern, reduction="rearrange", **axes_lengths) - - -def repeat(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor: - """ - einops.repeat allows reordering elements and repeating them in arbitrary combinations. - This operation includes functionality of repeat, tile, broadcast functions. - - Examples for repeat operation: - - ```python - # a grayscale image (of shape height x width) - >>> image = np.random.randn(30, 40) - - # change it to RGB format by repeating in each channel - >>> repeat(image, 'h w -> h w c', c=3).shape - (30, 40, 3) - - # repeat image 2 times along height (vertical axis) - >>> repeat(image, 'h w -> (repeat h) w', repeat=2).shape - (60, 40) - - # repeat image 2 time along height and 3 times along width - >>> repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape - (60, 120) - - # convert each pixel to a small square 2x2. Upsample image by 2x - >>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape - (60, 80) - - # pixelate image first by downsampling by 2x, then upsampling - >>> downsampled = reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2) - >>> repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape - (30, 40) - - ``` - - When composing axes, C-order enumeration used (consecutive elements have different last axis) - Find more examples in einops tutorial. - - Parameters: - tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). - list of tensors is also accepted, those should be of the same type and shape - pattern: string, rearrangement pattern - axes_lengths: any additional specifications for dimensions - - Returns: - Tensor of the same type as input. If possible, a view to the original tensor is returned. 
- - """ - return reduce(tensor, pattern, reduction="repeat", **axes_lengths) - - -def parse_shape(x, pattern: str) -> dict: - """ - Parse a tensor shape to dictionary mapping axes names to their lengths. - - ```python - # Use underscore to skip the dimension in parsing. - >>> x = np.zeros([2, 3, 5, 7]) - >>> parse_shape(x, 'batch _ h w') - {'batch': 2, 'h': 5, 'w': 7} - - # `parse_shape` output can be used to specify axes_lengths for other operations: - >>> y = np.zeros([700]) - >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape - (2, 10, 5, 7) - - ``` - - For symbolic frameworks may return symbols, not integers. - - Parameters: - x: tensor of any supported framework - pattern: str, space separated names for axes, underscore means skip axis - - Returns: - dict, maps axes names to their lengths - """ - exp = ParsedExpression(pattern, allow_underscore=True) - shape = get_backend(x).shape(x) - if exp.has_composed_axes(): - raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}") - if len(shape) != len(exp.composition): - if exp.has_ellipsis: - if len(shape) < len(exp.composition) - 1: - raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}") - else: - raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}") - if exp.has_ellipsis: - ellipsis_idx = exp.composition.index(_ellipsis) - composition = ( - exp.composition[:ellipsis_idx] - + ["_"] * (len(shape) - len(exp.composition) + 1) - + exp.composition[ellipsis_idx + 1 :] - ) - else: - composition = exp.composition - result = {} - for (axis_name,), axis_length in zip(composition, shape): # type: ignore - if axis_name != "_": - result[axis_name] = axis_length - return result - - -# _enumerate_directions is not exposed in the public API -def _enumerate_directions(x): - """ - For an n-dimensional tensor, returns tensors to enumerate each axis. - ```python - x = np.zeros([2, 3, 4]) # or any other tensor - i, j, k = _enumerate_directions(x) - result = i + 2*j + 3*k - ``` - - `result[i, j, k] = i + 2j + 3k`, and also has the same shape as result - Works very similarly to numpy.ogrid (open indexing grid) - """ - backend = get_backend(x) - shape = backend.shape(x) - result = [] - for axis_id, axis_length in enumerate(shape): - shape = [1] * len(shape) - shape[axis_id] = axis_length - result.append(backend.reshape(backend.arange(0, axis_length), shape)) - return result - - -# to avoid importing numpy -np_ndarray = Any - - -def asnumpy(tensor) -> np_ndarray: - """ - Convert a tensor of an imperative framework (i.e. numpy/cupy/torch/jax/etc.) 
to `numpy.ndarray` - - Parameters: - tensor: tensor of any known imperative framework - - Returns: - `numpy.ndarray`, converted to numpy - """ - return get_backend(tensor).to_numpy(tensor) - - -def _validate_einsum_axis_name(axis_name): - if len(axis_name) == 0: - raise NotImplementedError("Singleton () axes are not yet supported in einsum.") - if len(axis_name) > 1: - raise NotImplementedError("Shape rearrangement is not yet supported in einsum.") - - axis_name = axis_name[0] - - if isinstance(axis_name, AnonymousAxis): - raise NotImplementedError("Anonymous axes are not yet supported in einsum.") - if len(axis_name) == 0: - raise RuntimeError("Encountered empty axis name in einsum.") - if not isinstance(axis_name, str): - raise RuntimeError("Axis name in einsum must be a string.") - - -@functools.lru_cache(256) -def _compactify_pattern_for_einsum(pattern: str) -> str: - if "->" not in pattern: - # numpy allows this, so make sure users - # don't accidentally do something like this. - raise ValueError("Einsum pattern must contain '->'.") - lefts_str, right_str = pattern.split("->") - - lefts = [ParsedExpression(left, allow_underscore=True, allow_duplicates=True) for left in lefts_str.split(",")] - - right = ParsedExpression(right_str, allow_underscore=True) - - # Start from 'a' and go up to 'Z' - output_axis_names = string.ascii_letters - i = 0 - axis_name_mapping = {} - - left_patterns = [] - for left in lefts: - left_pattern = "" - for raw_axis_name in left.composition: - if raw_axis_name == _ellipsis: - left_pattern += "..." - continue - - _validate_einsum_axis_name(raw_axis_name) - axis_name = raw_axis_name[0] - if axis_name not in axis_name_mapping: - if i >= len(output_axis_names): - raise RuntimeError("Too many axes in einsum.") - axis_name_mapping[axis_name] = output_axis_names[i] - i += 1 - - left_pattern += axis_name_mapping[axis_name] - left_patterns.append(left_pattern) - - compact_pattern = ",".join(left_patterns) + "->" - - for raw_axis_name in right.composition: - if raw_axis_name == _ellipsis: - compact_pattern += "..." - continue - - _validate_einsum_axis_name(raw_axis_name) - axis_name = raw_axis_name[0] - - if axis_name not in axis_name_mapping: - raise EinopsError(f"Unknown axis {axis_name} on right side of einsum {pattern}.") - - compact_pattern += axis_name_mapping[axis_name] - - return compact_pattern - - -@typing.overload -def einsum(tensor: Tensor, pattern: str, /) -> Tensor: - ... - - -@typing.overload -def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str, /) -> Tensor: - ... - - -@typing.overload -def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str, /) -> Tensor: - ... - - -@typing.overload -def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str, /) -> Tensor: - ... - - -def einsum(*tensors_and_pattern: Union[Tensor, str]) -> Tensor: - """ - einops.einsum calls einsum operations with einops-style named - axes indexing, computing tensor products with an arbitrary - number of tensors. Unlike typical einsum syntax, here you must - pass tensors first, and then the pattern. - - Also, note that rearrange operations such as `"(batch chan) out"`, - or singleton axes `()`, are not currently supported. 
- - Examples: - - For a given pattern such as: - ```python - >>> x, y, z = np.random.randn(3, 20, 20, 20) - >>> output = einsum(x, y, z, "a b c, c b d, a g k -> a b k") - - ``` - the following formula is computed: - ```tex - output[a, b, k] = - \sum_{c, d, g} x[a, b, c] * y[c, b, d] * z[a, g, k] - ``` - where the summation over `c`, `d`, and `g` is performed - because those axes names do not appear on the right-hand side. - - Let's see some additional examples: - ```python - # Filter a set of images: - >>> batched_images = np.random.randn(128, 16, 16) - >>> filters = np.random.randn(16, 16, 30) - >>> result = einsum(batched_images, filters, - ... "batch h w, h w channel -> batch channel") - >>> result.shape - (128, 30) - - # Matrix multiplication, with an unknown input shape: - >>> batch_shape = (50, 30) - >>> data = np.random.randn(*batch_shape, 20) - >>> weights = np.random.randn(10, 20) - >>> result = einsum(weights, data, - ... "out_dim in_dim, ... in_dim -> ... out_dim") - >>> result.shape - (50, 30, 10) - - # Matrix trace on a single tensor: - >>> matrix = np.random.randn(10, 10) - >>> result = einsum(matrix, "i i ->") - >>> result.shape - () - - ``` - - Parameters: - tensors_and_pattern: - tensors: tensors of any supported library (numpy, tensorflow, pytorch, jax). - pattern: string, einsum pattern, with commas - separating specifications for each tensor. - pattern should be provided after all tensors. - - Returns: - Tensor of the same type as input, after processing with einsum. - - """ - if len(tensors_and_pattern) <= 1: - raise ValueError( - "`einops.einsum` takes at minimum two arguments: the tensors (at least one), followed by the pattern." - ) - pattern = tensors_and_pattern[-1] - if not isinstance(pattern, str): - raise ValueError( - "The last argument passed to `einops.einsum` must be a string, representing the einsum pattern." - ) - tensors = tensors_and_pattern[:-1] - pattern = _compactify_pattern_for_einsum(pattern) - return get_backend(tensors[0]).einsum(pattern, *tensors) diff --git a/vllm/thirdparty_files/einops/experimental/__init__.py b/vllm/thirdparty_files/einops/experimental/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/thirdparty_files/einops/experimental/data_api_packing.py b/vllm/thirdparty_files/einops/experimental/data_api_packing.py deleted file mode 100644 index 5e3e04c58c4a..000000000000 --- a/vllm/thirdparty_files/einops/experimental/data_api_packing.py +++ /dev/null @@ -1,137 +0,0 @@ -from typing import List, TypeVar, Tuple, Sequence - -from einops import EinopsError - -T = TypeVar('T') - -Shape = Tuple[int, ...] 
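Note on the einsum wrapper deleted above: the vendored `einops.einsum` compacts a named-axis pattern into a single-letter pattern (axes lettered in order of first appearance) and then delegates to the backend's einsum. A minimal sketch in plain NumPy of the compacted call it would produce, reusing the image-filtering example from the deleted docstring; this is an illustration only, not part of the patch:

import numpy as np

# "batch h w, h w channel -> batch channel" compacts to "abc,bcd->ad"
# (batch=a, h=b, w=c, channel=d, assigned in order of first appearance).
batched_images = np.random.randn(128, 16, 16)
filters = np.random.randn(16, 16, 30)
result = np.einsum("abc,bcd->ad", batched_images, filters)
assert result.shape == (128, 30)  # h and w are summed out, as in the named pattern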
- - -def pack(pattern: str, tensors: Sequence[T]) -> Tuple[T, List[Shape]]: - axes = pattern.split() - if len(axes) != len(set(axes)): - raise EinopsError(f'Duplicates in axes names in pack("{pattern}", ...)') - if '*' not in axes: - raise EinopsError(f'No *-axis in pack("{pattern}", ...)') - - # need some validation of identifiers - - n_axes_before = axes.index('*') - n_axes_after = len(axes) - n_axes_before - 1 - min_axes = n_axes_before + n_axes_after - - xp = tensors[0].__array_namespace__() - - reshaped_tensors: List[T] = [] - packed_shapes: List[Shape] = [] - for i, tensor in enumerate(tensors): - shape = tensor.shape - if len(shape) < min_axes: - raise EinopsError(f'packed tensor #{i} (enumeration starts with 0) has shape {shape}, ' - f'while pattern {pattern} assumes at least {min_axes} axes') - axis_after_packed_axes = len(shape) - n_axes_after - packed_shapes.append(shape[n_axes_before:]) - reshaped_tensors.append( - xp.reshape(tensor, (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:])) - ) - - return xp.concat(reshaped_tensors, axis=n_axes_before), packed_shapes - - -def prod(x: Shape) -> int: - result = 1 - for i in x: - result *= i - return result - - -def unpack(pattern: str, tensor: T, packed_shapes: List[Shape]) -> List[T]: - axes = pattern.split() - if len(axes) != len(set(axes)): - raise EinopsError(f'Duplicates in axes names in unpack("{pattern}", ...)') - if '*' not in axes: - raise EinopsError(f'No *-axis in unpack("{pattern}", ...)') - - # need some validation of identifiers - - input_shape = tensor.shape - if len(input_shape) != len(axes): - raise EinopsError(f'unpack({pattern}, ...) received input of wrong dim with shape {input_shape}') - - unpacked_axis = axes.index('*') - - lengths_of_composed_axes: List[int] = [ - -1 if -1 in p_shape else prod(p_shape) - for p_shape in packed_shapes - ] - - n_unknown_composed_axes = sum(x == -1 for x in lengths_of_composed_axes) - if n_unknown_composed_axes > 1: - raise EinopsError( - f"unpack({pattern}, ...) 
received more than one -1 in {packed_shapes} and can't infer dimensions" - ) - - # following manipulations allow to skip some shape verifications - # and leave them to backends - - # [[], [2, 3], [4], [-1, 5], [6]] < examples of packed_axis - # split positions when computed should be - # [0, 1, 7, 11, N-6 , N ], where N = length of axis - split_positions = [0] * len(packed_shapes) + [input_shape[unpacked_axis]] - if n_unknown_composed_axes == 0: - for i, x in enumerate(lengths_of_composed_axes[:-1]): - split_positions[i + 1] = split_positions[i] + x - else: - unknown_composed_axis: int = lengths_of_composed_axes.index(-1) - for i in range(unknown_composed_axis): - split_positions[i + 1] = split_positions[i] + lengths_of_composed_axes[i] - for j in range(unknown_composed_axis + 1, len(lengths_of_composed_axes))[::-1]: - split_positions[j] = split_positions[j + 1] + lengths_of_composed_axes[j] - - xp = tensor.__array_namespace__() - shape_start = input_shape[:unpacked_axis] - shape_end = input_shape[unpacked_axis + 1:] - slice_filler = (slice(None, None),) * unpacked_axis - return [ - xp.reshape( - # shortest way slice arbitrary axis - tensor[(*slice_filler, slice(split_positions[i], split_positions[i + 1]))], - (*shape_start, *element_shape, *shape_end) - ) - for i, element_shape in enumerate(packed_shapes) - ] - - -if __name__ == '__main__': - import numpy.array_api as np - - H = 100 - W = 101 - C = 3 - - r = np.zeros((H, W)) - g = np.zeros((H, W)) - b = np.zeros((H, W)) - embeddings = np.zeros((H, W, 32)) - - im = np.stack([r, g, b], axis=-1) - print(im.shape) - - image, shapes = pack('h w *', [r, g, b]) - print(image.shape, shapes) - - print(type(image)) - print(type(im)) - assert np.all(np.equal(image, im)) - - images_and_embedding, shapes = pack('h w *', [r, g, b, embeddings]) - print(images_and_embedding.shape, shapes) - r2, g2, b2, embeddings2 = unpack('h w *', images_and_embedding, shapes) - assert np.all(np.equal(r, r2)) - assert np.all(np.equal(g, g2)) - assert np.all(np.equal(b, b2)) - assert np.all(np.equal(embeddings, embeddings2)) - - print([x.shape for x in unpack('h w *', images_and_embedding, shapes[1:])]) - - print('all is fine') diff --git a/vllm/thirdparty_files/einops/experimental/indexing.py b/vllm/thirdparty_files/einops/experimental/indexing.py deleted file mode 100644 index 35cef4b304e5..000000000000 --- a/vllm/thirdparty_files/einops/experimental/indexing.py +++ /dev/null @@ -1,393 +0,0 @@ -""" - -Indexing one array with the other(s). - -Concept for discussion. - -Notation targets hard cases, not simple ones, like indexing of 1d-array with another 1d-array -(notation supports that, but you can't simplify arr[ind], and there is no reason to) - -Examples - -1. query for every token in sequence a token in the image. Images and sequences are paired - einindex('b t c <- b h w c, [h, w] b t', arr_bhwc, [h_indices_bt, w_indices_bt]) - - this is equivalent, so you can pass indexers idependently or together - einindex('b t c <- b h w c, [h, w] b t', arr_bhwc, np.asarray([h_indices_bt, w_indices_bt])) - - after some thinking I decided that having first axis for indexing variable is not too restrictive, - but should simplify mapping of such cases. - For this reason [...] part should always go first in indexer. - - This makes the largest difference with einindex https://github.com/malmaud/einindex, - which has almost identical grammar, but puts special dimension last, while we put it first. 
- This trick allows naturally decomposing multiindex into individual dimensions or visa versa. - - -2. query for every token in the video the most suitable word in a (matching) sentence - einindex('b t h w <- seq b, [seq] t b h w', arr_tbc, [t_indices_bhw]) - - note, that only one indexer is used, but still it has to be enclosed in the list. - That's a price for being generic. Alternatively leading singleton dimension can be added. - - -3. (not supported now, future planning) - for every timeframe in a video, find the token with the highest norm (across h and w), and compose a new stack of them - indices_2bt = argmax(x_bthwc.norm(dim=-1), 'b t h w -> [h, w] b t') - selected_embeddings_btc = einindex('b t c <- b t h w c, [h, w] b t', x_bthwc, indices_2bt) - - while currently question is around 'how do we index', - it is important to pre-align that with a question 'what are natural ways to get indices'. - Most common are min/max. less common options: topk (works here), random sampling. - - - -Some important properties of this notation: -- support for multiple indexers, including using a single tensor to keep multiple indexers -- 'batch' indexing, when some axes of indexer and array should be matched -- universal (one-indexing-to-rule-them-all) -- extensible for (named) ellipses, including variadic number of indexers -- extensible for einops-style compositions and decompositions -- extensible for outer indexing when indexers are not aligned - -Current implementation based on python array api and uses loops, -because no appropriate indexing available in the standard. - -""" - -from typing import List, Union, TypeVar, Tuple - -from einops import EinopsError - -T = TypeVar('T') - - -class CompositionDecomposition: - def __init__( - self, - decomposed_shape: List[str], - composed_shape: List[List[str]], - ): - flat_shape = [] - for x in composed_shape: - flat_shape.extend(x) - - self.compose_transposition: Tuple[int, ...] = tuple([decomposed_shape.index(x) for x in flat_shape]) - self.decompose_transposition: Tuple[int, ...] 
= tuple([flat_shape.index(x) for x in decomposed_shape]) - self.composed_shape = composed_shape - self.decomposed_shape = decomposed_shape - - def decompose(self, x, known_axes_lengths: dict[str, int]): - xp = x.__array_namespace__() - shape = x.shape - - flat_shape = [] - - for i, axis_group in enumerate(self.composed_shape): - unknown_axis_name = None - known_sizes_prod = 1 - for axis_name in axis_group: - if axis_name in known_axes_lengths: - known_sizes_prod *= known_axes_lengths[axis_name] - else: - if unknown_axis_name is None: - unknown_axis_name = axis_name - else: - raise EinopsError("Can't infer the size") - - if unknown_axis_name is None: - assert shape[i] == known_sizes_prod - else: - known_axes_lengths[unknown_axis_name] = shape[i] // known_sizes_prod - - for axis in axis_group: - flat_shape.append(known_axes_lengths[axis]) - - x = xp.reshape(x, flat_shape) - return xp.permute_dims(x, self.decompose_transposition) - - def compose(self, x, known_axes_lengths: dict[str, int]): - xp = x.__array_namespace__() - - for axis_len, axis_name in zip(x.shape, self.decomposed_shape): - if axis_name in known_axes_lengths: - assert known_axes_lengths[axis_name] == axis_len - else: - known_axes_lengths[axis_name] = axis_len - - x = xp.permute_dims(x, self.compose_transposition) - new_shape = [] - for axis_group in self.composed_shape: - composed_axis_size = 1 - for axis_name in axis_group: - composed_axis_size *= known_axes_lengths[axis_name] - new_shape.append(composed_axis_size) - - return xp.reshape(x, tuple(new_shape)) - - -def arange_at_position(xp, n_axes, axis, axis_len, device=None): - x = xp.arange(axis_len, dtype=xp.int64, device=device) - shape = [1] * n_axes - shape[axis] = axis_len - x = xp.reshape(x, shape) - return x - - -class IndexingFormula: - - def __init__(self, pattern: str): - """ - :param pattern: example 'b t c <- b hsel wsel c, [hsel, wsel] b t' - """ - self.pattern = pattern - left, right = pattern.split('<-') - arg_split = right.index(',') - arr_pattern, ind_pattern = right[:arg_split], right[arg_split + 1:] - ind_pattern = ind_pattern.strip() - # print( - # arr_pattern, '\n', - # ind_pattern, - # ) - assert ind_pattern.startswith('['), 'composition axis should go first in indexer (second argument) [h w] i j k' - composition_start = ind_pattern.index('[') - composition_end = ind_pattern.index(']') - composition = ind_pattern[composition_start + 1: composition_end] - ind_other_axes = ind_pattern[composition_end + 1:] - - self.result_axes_names = left.split() - self.array_axes_names = arr_pattern.split() - self.indexing_axes_names = [x.strip() for x in composition.split(',')] - self.indexer_other_axes_names = ind_other_axes.split() - - for group_name, group in [ - ('result', self.result_axes_names), - ('array', self.array_axes_names), - ('indexer', self.indexing_axes_names + self.indexer_other_axes_names), - ]: - if len(set(group)) != len(group): - # need more verbosity, which axis, raise - raise EinopsError(f'{group_name} pattern ({group}) contains a duplicated axis') - - axis_groups = [ - self.result_axes_names, - self.array_axes_names, - self.indexing_axes_names, - self.indexer_other_axes_names, - ] - - all_axes = set() - for group in axis_groups: - all_axes.update(group) - - self.indexer_axes = [] - self.batch_axes = [] - self.result_and_index_axes = [] - self.result_and_array_axes = [] - - for axis in all_axes: - presence = tuple(axis in g for g in axis_groups) - # want match-case here. 
sweet dreams - if presence == (False, True, True, False): - self.indexer_axes.append(axis) - elif presence[2]: - raise EinopsError(f'Wrong usage of indexer variable {axis}') - elif presence == (True, True, False, True): - self.batch_axes.append(axis) - elif presence == (True, False, False, True): - self.result_and_index_axes.append(axis) - elif presence == (True, True, False, False): - self.result_and_array_axes.append(axis) - else: - # TODO better categorization of wrong usage patterns - raise EinopsError(f'{axis} is used incorrectly in {pattern}') - - assert set(self.indexer_axes) == set(self.indexing_axes_names) - # order of these variables matters, since we can't lose mapping here - self.indexer_axes = self.indexing_axes_names - - self.array_composition = CompositionDecomposition( - decomposed_shape=self.array_axes_names, - composed_shape=[self.batch_axes + self.indexer_axes, self.result_and_array_axes], - ) - - self.index_composition = CompositionDecomposition( - decomposed_shape=self.indexer_other_axes_names, - # single axis after composition - composed_shape=[self.batch_axes + self.result_and_index_axes], - ) - - self.result_composition = CompositionDecomposition( - decomposed_shape=self.result_axes_names, - composed_shape=[self.batch_axes + self.result_and_index_axes, self.result_and_array_axes], - ) - - def apply_to_array_api(self, arr: T, ind: Union[T, List[T]]): - known_axes_sizes: dict[str, int] = {} - xp = arr.__array_namespace__() - - if not isinstance(ind, list): - ind = [ind[i, ...] for i in range(ind.shape[0])] - - for indexer in ind: - assert len(indexer.shape) == len(self.indexer_other_axes_names) - - # step 1. transpose, reshapes of arr; learn its dimensions - arr_2d = self.array_composition.compose(arr, known_axes_sizes) - - # step 2. compute shifts and create an actual indexing array - shift = 1 - full_index = xp.zeros([1] * len(ind[0].shape), dtype=xp.int64, device=arr.device) - - # original order: [*batch-like axes, *indexing_axes,] - # now we need to traverse them in the opposite direction - - for axis_name, indexer in list(zip(self.indexing_axes_names, ind))[::-1]: - full_index = full_index + shift * (indexer % known_axes_sizes[axis_name]) - shift *= known_axes_sizes[axis_name] - - for axis_name in self.batch_axes[::-1]: - axis_id = self.indexer_other_axes_names.index(axis_name) - full_index = full_index + arange_at_position( - xp, len(self.indexer_other_axes_names), axis=axis_id, axis_len=known_axes_sizes[axis_name], - device=arr.device, - ) * shift - shift *= known_axes_sizes[axis_name] - - assert shift == arr_2d.shape[0] - - # step 3. Flatten index - full_index = self.index_composition.compose(full_index, known_axes_sizes) - - # step 4. indexing - # python array api lacks any integer indexing, so... I use loops. - # did you know that there is conceptual programming ... just like art? - # result_2d = arr_2d[full_index] - result_2d = xp.stack([arr_2d[full_index[i], :] for i in range(full_index.shape[0])]) - - # step 5. doing resulting - result = self.result_composition.decompose(result_2d, known_axes_sizes) - return result - - -def einindex(pattern: str, arr: T, /, ind: Union[T, List[T]]): - """ - Demonstrates how einindex should work. - Supports data-api compliant arrays. 
- """ - formula = IndexingFormula(pattern) - return formula.apply_to_array_api(arr, ind) - - -def test_composition_and_decomposition(): - import numpy.array_api as np - x = np.arange(2 * 3 * 5 * 7) - x = np.reshape(x, (2, 3, 5, 7)) - comp = CompositionDecomposition( - decomposed_shape=['a', 'b', 'c', 'd'], - composed_shape=[['a', 'b'], ['c', 'd']], - ) - assert comp.compose(x, known_axes_lengths={}).shape == (2 * 3, 5 * 7) - - y = CompositionDecomposition( - decomposed_shape=['a', 'b', 'c', 'd'], - composed_shape=[['a', 'b'], [], ['c', 'd']], - ).compose(x, {}) - assert y.shape == (2 * 3, 1, 5 * 7) - assert np.all(np.reshape(x, (-1,)) == np.reshape(y, (-1,))) - - comp = CompositionDecomposition( - decomposed_shape=['a', 'b', 'e', 'c', 'd'], - composed_shape=[['e', 'c'], ['b'], ['a', 'd']], - ) - x = np.arange(2 * 3 * 5 * 7 * 3) - x = np.reshape(x, (2, 3, 5, 7, 3)) - - axes = {} - y = comp.compose(x, axes) - x2 = comp.decompose(y, axes) - assert np.all(x == x2) - - -def test_simple_indexing(): - import numpy.array_api as np - - # simple 2d test - arr = np.reshape(np.arange(5 * 7), (5, 7)) - ind = np.arange(7) % 5 - x = einindex('j <- i j, [i] j', arr, [ind]) - for j, i in enumerate(ind): - assert arr[i, j] == x[j] - - y = einindex('j <- j i, [i] j', np.permute_dims(arr, (1, 0)), [ind]) - for j, i in enumerate(ind): - assert arr[i, j] == y[j] - - -def test_multidimensional_indexing(): - import numpy.array_api as np - - embedding_bhwc = ( - + arange_at_position(np, 4, 0, 2) * 1000 - + arange_at_position(np, 4, 1, 3) * 100 - + arange_at_position(np, 4, 2, 5) * 10 - + arange_at_position(np, 4, 3, 7) * 1 - ) - - hindices_bt = np.reshape(np.arange(6), (2, 3)) % 3 - windices_bt = np.reshape(np.arange(6), (2, 3)) % 5 - - # imagine that you have pairs of image <> sentence - # your goal is to get most suitable token from image for every token in sentence - # thus for every token in sentence you compute best k and v - - result = einindex('c t b <- b h w c, [h, w] b t', embedding_bhwc, [hindices_bt, windices_bt]) - # example of using a single array for indexing multiple axes - hw_indices_bt = np.stack([hindices_bt, windices_bt]) - result2 = einindex('c t b <- b h w c, [h, w] b t', embedding_bhwc, hw_indices_bt) - assert np.all(result == result2) - - # check vs manual element computation - result_manual = result * 0 - for b in range(2): - for t in range(3): - for c in range(7): - h = hindices_bt[b, t] - w = windices_bt[b, t] - result_manual[c, t, b] = embedding_bhwc[b, h, w, c] - - assert np.all(result == result_manual) - - -def test_reverse_indexing(): - import numpy.array_api as np - - C, T, B = 2, 3, 5 - # G = GPU, batch-like varaible - G = 4 - H = 7 - W = 9 - - arr_gtbc = ( - + arange_at_position(np, 4, 0, G) * 1000 - + arange_at_position(np, 4, 1, T) * 100 - + arange_at_position(np, 4, 2, B) * 10 - + arange_at_position(np, 4, 3, C) * 1 - ) - - t_indices_gbhw = np.reshape(np.arange(G * B * H * W), (G, B, H, W)) % T - - result = einindex('g b c h w <- g t b c, [t] g b h w', arr_gtbc, [t_indices_gbhw]) - - result_manual = result * 0 - for g in range(G): - for b in range(B): - for c in range(C): - for h in range(H): - for w in range(W): - t = t_indices_gbhw[g, b, h, w] - result_manual[g, b, c, h, w] = arr_gtbc[g, t, b, c] - - assert np.all(result == result_manual) - - diff --git a/vllm/thirdparty_files/einops/layers/__init__.py b/vllm/thirdparty_files/einops/layers/__init__.py deleted file mode 100644 index cc3118636e76..000000000000 --- a/vllm/thirdparty_files/einops/layers/__init__.py +++ /dev/null 
@@ -1,106 +0,0 @@ -__author__ = 'Alex Rogozhnikov' - -from typing import Any, Dict - - -from ..einops import TransformRecipe, _apply_recipe, _prepare_recipes_for_all_dims, get_backend -from .. import EinopsError - - -class RearrangeMixin: - """ - Rearrange layer behaves identically to einops.rearrange operation. - - :param pattern: str, rearrangement pattern - :param axes_lengths: any additional specification of dimensions - - See einops.rearrange for source_examples. - """ - - def __init__(self, pattern: str, **axes_lengths: Any) -> None: - super().__init__() - self.pattern = pattern - self.axes_lengths = axes_lengths - # self._recipe = self.recipe() # checking parameters - self._multirecipe = self.multirecipe() - self._axes_lengths = tuple(self.axes_lengths.items()) - - def __repr__(self) -> str: - params = repr(self.pattern) - for axis, length in self.axes_lengths.items(): - params += ', {}={}'.format(axis, length) - return '{}({})'.format(self.__class__.__name__, params) - - def multirecipe(self) -> Dict[int, TransformRecipe]: - try: - return _prepare_recipes_for_all_dims( - self.pattern, operation='rearrange', axes_names=tuple(self.axes_lengths) - ) - except EinopsError as e: - raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e)) - - def _apply_recipe(self, x): - backend = get_backend(x) - return _apply_recipe( - backend=backend, - recipe=self._multirecipe[len(x.shape)], - tensor=x, - reduction_type='rearrange', - axes_lengths=self._axes_lengths, - ) - - def __getstate__(self): - return {'pattern': self.pattern, 'axes_lengths': self.axes_lengths} - - def __setstate__(self, state): - self.__init__(pattern=state['pattern'], **state['axes_lengths']) - - -class ReduceMixin: - """ - Reduce layer behaves identically to einops.reduce operation. - - :param pattern: str, rearrangement pattern - :param reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive - :param axes_lengths: any additional specification of dimensions - - See einops.reduce for source_examples. 
- """ - - def __init__(self, pattern: str, reduction: str, **axes_lengths: Any): - super().__init__() - self.pattern = pattern - self.reduction = reduction - self.axes_lengths = axes_lengths - self._multirecipe = self.multirecipe() - self._axes_lengths = tuple(self.axes_lengths.items()) - - def __repr__(self): - params = '{!r}, {!r}'.format(self.pattern, self.reduction) - for axis, length in self.axes_lengths.items(): - params += ', {}={}'.format(axis, length) - return '{}({})'.format(self.__class__.__name__, params) - - def multirecipe(self) -> Dict[int, TransformRecipe]: - try: - return _prepare_recipes_for_all_dims( - self.pattern, operation=self.reduction, axes_names=tuple(self.axes_lengths) - ) - except EinopsError as e: - raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e)) - - def _apply_recipe(self, x): - backend = get_backend(x) - return _apply_recipe( - backend=backend, - recipe=self._multirecipe[len(x.shape)], - tensor=x, - reduction_type=self.reduction, - axes_lengths=self._axes_lengths, - ) - - def __getstate__(self): - return {'pattern': self.pattern, 'reduction': self.reduction, 'axes_lengths': self.axes_lengths} - - def __setstate__(self, state): - self.__init__(pattern=state['pattern'], reduction=state['reduction'], **state['axes_lengths']) diff --git a/vllm/thirdparty_files/einops/layers/_einmix.py b/vllm/thirdparty_files/einops/layers/_einmix.py deleted file mode 100644 index 5d0d4bcf740b..000000000000 --- a/vllm/thirdparty_files/einops/layers/_einmix.py +++ /dev/null @@ -1,176 +0,0 @@ -from typing import Any, List, Optional, Dict - -from einops import EinopsError -from einops.parsing import ParsedExpression -import warnings -import string -from ..einops import _product - - -def _report_axes(axes: set, report_message: str): - if len(axes) > 0: - raise EinopsError(report_message.format(axes)) - - -class _EinmixMixin: - def __init__(self, pattern: str, weight_shape: str, bias_shape: Optional[str]=None, **axes_lengths: Any): - """ - EinMix - Einstein summation with automated tensor management and axis packing/unpacking. - - EinMix is an advanced tool, helpful tutorial: - https://github.com/arogozhnikov/einops/blob/master/docs/3-einmix-layer.ipynb - - Imagine taking einsum with two arguments, one of each input, and one - tensor with weights - >>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight) - - This layer manages weights for you, syntax highlights separate role of weight matrix - >>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out') - But otherwise it is the same einsum under the hood. - - Simple linear layer with bias term (you have one like that in your framework) - >>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20) - There is no restriction to mix the last axis. Let's mix along height - >>> EinMix('h w c-> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32) - Channel-wise multiplication (like one used in normalizations) - >>> EinMix('t b c -> t b c', weight_shape='c', c=128) - Multi-head linear layer (each head is own linear layer): - >>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...) - - ... ah yes, you need to specify all dimensions of weight shape/bias shape in parameters. 
- - Use cases: - - when channel dimension is not last, use EinMix, not transposition - - patch/segment embeddings - - when need only within-group connections to reduce number of weights and computations - - perfect as a part of sequential models - - next-gen MLPs (follow tutorial to learn more) - - Uniform He initialization is applied to weight tensor and encounters for number of elements mixed. - - Parameters - :param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output - :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized in a layer - :param bias_shape: axes of bias added to output. Weights of this shape are created and stored. If `None` (the default), no bias is added. - :param axes_lengths: dimensions of weight tensor - """ - super().__init__() - self.pattern = pattern - self.weight_shape = weight_shape - self.bias_shape = bias_shape - self.axes_lengths = axes_lengths - self.initialize_einmix(pattern=pattern, weight_shape=weight_shape, bias_shape=bias_shape, axes_lengths=axes_lengths) - - def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optional[str], axes_lengths: dict): - left_pattern, right_pattern = pattern.split('->') - left = ParsedExpression(left_pattern) - right = ParsedExpression(right_pattern) - weight = ParsedExpression(weight_shape) - _report_axes( - set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}), - 'Unrecognized identifiers on the right side of EinMix {}' - ) - - if left.has_ellipsis or right.has_ellipsis or weight.has_ellipsis: - raise EinopsError('Ellipsis is not supported in EinMix (right now)') - if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]): - raise EinopsError('Anonymous axes (numbers) are not allowed in EinMix') - if '(' in weight_shape or ')' in weight_shape: - raise EinopsError(f'Parenthesis is not allowed in weight shape: {weight_shape}') - - pre_reshape_pattern = None - pre_reshape_lengths = None - post_reshape_pattern = None - if any(len(group) != 1 for group in left.composition): - names: List[str] = [] - for group in left.composition: - names += group - composition = ' '.join(names) - pre_reshape_pattern = f'{left_pattern}->{composition}' - pre_reshape_lengths = {name: length for name, length in axes_lengths.items() if name in names} - - if any(len(group) != 1 for group in right.composition): - names = [] - for group in right.composition: - names += group - composition = ' '.join(names) - post_reshape_pattern = f'{composition}->{right_pattern}' - - self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {}) - - for axis in weight.identifiers: - if axis not in axes_lengths: - raise EinopsError('Dimension {} of weight should be specified'.format(axis)) - _report_axes( - set.difference(set(axes_lengths), {*left.identifiers, *weight.identifiers}), - 'Axes {} are not used in pattern', - ) - _report_axes( - set.difference(weight.identifiers, {*left.identifiers, *right.identifiers}), - 'Weight axes {} are redundant' - ) - if len(weight.identifiers) == 0: - warnings.warn('EinMix: weight has no dimensions (means multiplication by a number)') - - _weight_shape = [axes_lengths[axis] for axis, in weight.composition] - # single output element is a combination of fan_in input elements - _fan_in = _product([axes_lengths[axis] for axis, in weight.composition if axis not in right.identifiers]) - if bias_shape is not None: - if not isinstance(bias_shape, str): - raise 
EinopsError('bias shape should be string specifying which axes bias depends on') - bias = ParsedExpression(bias_shape) - _report_axes( - set.difference(bias.identifiers, right.identifiers), - 'Bias axes {} not present in output' - ) - _report_axes( - set.difference(bias.identifiers, set(axes_lengths)), - 'Sizes not provided for bias axes {}', - ) - - _bias_shape = [] - for axes in right.composition: - for axis in axes: - if axis in bias.identifiers: - _bias_shape.append(axes_lengths[axis]) - else: - _bias_shape.append(1) - else: - _bias_shape = None - - weight_bound = (3 / _fan_in) ** 0.5 - bias_bound = (1 / _fan_in) ** 0.5 - self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound) - - # rewrite einsum expression with single-letter latin identifiers so that - # expression will be understood by any framework - mapped_identifiers = {*left.identifiers, *right.identifiers, *weight.identifiers} - mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapped_identifiers)} - - def write_flat(axes: list): - return ''.join(mapping2letters[axis] for axis in axes) - - self.einsum_pattern: str = '{},{}->{}'.format( - write_flat(left.flat_axes_order()), - write_flat(weight.flat_axes_order()), - write_flat(right.flat_axes_order()), - ) - - def _create_rearrange_layers(self, - pre_reshape_pattern: Optional[str], - pre_reshape_lengths: Optional[Dict], - post_reshape_pattern: Optional[str], - post_reshape_lengths: Optional[Dict]): - raise NotImplementedError('Should be defined in framework implementations') - - def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): - """ Shape and implementations """ - raise NotImplementedError('Should be defined in framework implementations') - - def __repr__(self): - params = repr(self.pattern) - params += f", '{self.weight_shape}'" - if self.bias_shape is not None: - params += f", '{self.bias_shape}'" - for axis, length in self.axes_lengths.items(): - params += ', {}={}'.format(axis, length) - return '{}({})'.format(self.__class__.__name__, params) diff --git a/vllm/thirdparty_files/einops/layers/chainer.py b/vllm/thirdparty_files/einops/layers/chainer.py deleted file mode 100644 index 0214bc323c4e..000000000000 --- a/vllm/thirdparty_files/einops/layers/chainer.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Optional, Dict, cast - -import chainer - -from . 
import RearrangeMixin, ReduceMixin -from ._einmix import _EinmixMixin - -__author__ = 'Alex Rogozhnikov' - - -class Rearrange(RearrangeMixin, chainer.Link): - def __call__(self, x): - return self._apply_recipe(x) - - -class Reduce(ReduceMixin, chainer.Link): - def __call__(self, x): - return self._apply_recipe(x) - - -class EinMix(_EinmixMixin, chainer.Link): - def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): - uniform = chainer.variable.initializers.Uniform - with self.init_scope(): - self.weight = chainer.variable.Parameter(uniform(weight_bound), weight_shape) - if bias_shape is not None: - self.bias = chainer.variable.Parameter(uniform(bias_bound), bias_shape) - else: - self.bias = None - - def _create_rearrange_layers(self, - pre_reshape_pattern: Optional[str], - pre_reshape_lengths: Optional[Dict], - post_reshape_pattern: Optional[str], - post_reshape_lengths: Optional[Dict], - ): - self.pre_rearrange = None - if pre_reshape_pattern is not None: - self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths)) - - self.post_rearrange = None - if post_reshape_pattern is not None: - self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths)) - - def __call__(self, input): - if self.pre_rearrange is not None: - input = self.pre_rearrange(input) - result = chainer.functions.einsum(self.einsum_pattern, input, self.weight) - if self.bias is not None: - result = result + self.bias - if self.post_rearrange is not None: - result = self.post_rearrange(result) - return result diff --git a/vllm/thirdparty_files/einops/layers/flax.py b/vllm/thirdparty_files/einops/layers/flax.py deleted file mode 100644 index abd4ec5b4224..000000000000 --- a/vllm/thirdparty_files/einops/layers/flax.py +++ /dev/null @@ -1,80 +0,0 @@ -from dataclasses import field -from typing import Optional, Dict, cast - -import flax.linen as nn -import jax -import jax.numpy as jnp - -from . 
import RearrangeMixin, ReduceMixin -from ._einmix import _EinmixMixin - -__author__ = 'Alex Rogozhnikov' - - -class Reduce(nn.Module): - pattern: str - reduction: str - sizes: dict = field(default_factory=lambda: {}) - - def setup(self): - self.reducer = ReduceMixin(self.pattern, self.reduction, **self.sizes) - - def __call__(self, input): - return self.reducer._apply_recipe(input) - - -class Rearrange(nn.Module): - pattern: str - sizes: dict = field(default_factory=lambda: {}) - - def setup(self): - self.rearranger = RearrangeMixin(self.pattern, **self.sizes) - - def __call__(self, input): - return self.rearranger._apply_recipe(input) - - -class EinMix(nn.Module, _EinmixMixin): - pattern: str - weight_shape: str - bias_shape: Optional[str] = None - sizes: dict = field(default_factory=lambda: {}) - - def setup(self): - self.initialize_einmix( - pattern=self.pattern, - weight_shape=self.weight_shape, - bias_shape=self.bias_shape, - axes_lengths=self.sizes, - ) - - def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): - self.weight = self.param("weight", jax.nn.initializers.uniform(weight_bound), weight_shape) - - if bias_shape is not None: - self.bias = self.param("bias", jax.nn.initializers.uniform(bias_bound), bias_shape) - else: - self.bias = None - - def _create_rearrange_layers(self, - pre_reshape_pattern: Optional[str], - pre_reshape_lengths: Optional[Dict], - post_reshape_pattern: Optional[str], - post_reshape_lengths: Optional[Dict]): - self.pre_rearrange = None - if pre_reshape_pattern is not None: - self.pre_rearrange = Rearrange(pre_reshape_pattern, sizes=cast(dict, pre_reshape_lengths)) - - self.post_rearrange = None - if post_reshape_pattern is not None: - self.post_rearrange = Rearrange(post_reshape_pattern, sizes=cast(dict, post_reshape_lengths)) - - def __call__(self, input): - if self.pre_rearrange is not None: - input = self.pre_rearrange(input) - result = jnp.einsum(self.einsum_pattern, input, self.weight) - if self.bias is not None: - result += self.bias - if self.post_rearrange is not None: - result = self.post_rearrange(result) - return result diff --git a/vllm/thirdparty_files/einops/layers/keras.py b/vllm/thirdparty_files/einops/layers/keras.py deleted file mode 100644 index e2533a2f7743..000000000000 --- a/vllm/thirdparty_files/einops/layers/keras.py +++ /dev/null @@ -1,9 +0,0 @@ -__author__ = 'Alex Rogozhnikov' - -from ..layers.tensorflow import Rearrange, Reduce, EinMix - -keras_custom_objects = { - Rearrange.__name__: Rearrange, - Reduce.__name__: Reduce, - EinMix.__name__: EinMix, -} diff --git a/vllm/thirdparty_files/einops/layers/oneflow.py b/vllm/thirdparty_files/einops/layers/oneflow.py deleted file mode 100644 index 2885404db2c1..000000000000 --- a/vllm/thirdparty_files/einops/layers/oneflow.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Optional, Dict, cast - -import oneflow as flow - -from . 
import RearrangeMixin, ReduceMixin -from ._einmix import _EinmixMixin - -__author__ = 'Tianhe Ren & Depeng Liang' - - -class Rearrange(RearrangeMixin, flow.nn.Module): - def forward(self, input): - return self._apply_recipe(input) - - -class Reduce(ReduceMixin, flow.nn.Module): - def forward(self, input): - return self._apply_recipe(input) - - -class EinMix(_EinmixMixin, flow.nn.Module): - def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): - self.weight = flow.nn.Parameter(flow.zeros(weight_shape).uniform_(-weight_bound, weight_bound), - requires_grad=True) - if bias_shape is not None: - self.bias = flow.nn.Parameter(flow.zeros(bias_shape).uniform_(-bias_bound, bias_bound), - requires_grad=True) - else: - self.bias = None - - def _create_rearrange_layers(self, - pre_reshape_pattern: Optional[str], - pre_reshape_lengths: Optional[Dict], - post_reshape_pattern: Optional[str], - post_reshape_lengths: Optional[Dict], - ): - self.pre_rearrange = None - if pre_reshape_pattern is not None: - self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths)) - - self.post_rearrange = None - if post_reshape_pattern is not None: - self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths)) - - def forward(self, input): - if self.pre_rearrange is not None: - input = self.pre_rearrange(input) - result = flow.einsum(self.einsum_pattern, input, self.weight) - if self.bias is not None: - result += self.bias - if self.post_rearrange is not None: - result = self.post_rearrange(result) - return result diff --git a/vllm/thirdparty_files/einops/layers/paddle.py b/vllm/thirdparty_files/einops/layers/paddle.py deleted file mode 100644 index c3335604a4b8..000000000000 --- a/vllm/thirdparty_files/einops/layers/paddle.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import Optional, Dict, cast - -import paddle - -from . 
import RearrangeMixin, ReduceMixin -from ._einmix import _EinmixMixin - -__author__ = 'PaddlePaddle' - - -class Rearrange(RearrangeMixin, paddle.nn.Layer): - def forward(self, input): - return self._apply_recipe(input) - - -class Reduce(ReduceMixin, paddle.nn.Layer): - def forward(self, input): - return self._apply_recipe(input) - - -class EinMix(_EinmixMixin, paddle.nn.Layer): - def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): - self.weight = self.create_parameter( - weight_shape, - default_initializer=paddle.nn.initializer.Uniform(-weight_bound, weight_bound) - ) - - if bias_shape is not None: - self.bias = self.create_parameter( - bias_shape, - default_initializer=paddle.nn.initializer.Uniform(-bias_bound, bias_bound) - ) - else: - self.bias = None - - def _create_rearrange_layers(self, - pre_reshape_pattern: Optional[str], - pre_reshape_lengths: Optional[Dict], - post_reshape_pattern: Optional[str], - post_reshape_lengths: Optional[Dict], - ): - self.pre_rearrange = None - if pre_reshape_pattern is not None: - self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths)) - - self.post_rearrange = None - if post_reshape_pattern is not None: - self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths)) - - def forward(self, input): - if self.pre_rearrange is not None: - input = self.pre_rearrange(input) - - result = paddle.einsum(self.einsum_pattern, input, self.weight) - if self.bias is not None: - result += self.bias - if self.post_rearrange is not None: - result = self.post_rearrange(result) - return result \ No newline at end of file diff --git a/vllm/thirdparty_files/einops/layers/tensorflow.py b/vllm/thirdparty_files/einops/layers/tensorflow.py deleted file mode 100644 index c89a71ad6048..000000000000 --- a/vllm/thirdparty_files/einops/layers/tensorflow.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import List, Optional, Dict, cast - -import tensorflow as tf -from tensorflow.keras.layers import Layer - -from .._backends import UnknownSize -from . 
import RearrangeMixin, ReduceMixin -from ._einmix import _EinmixMixin -from ..einops import TransformRecipe, _reconstruct_from_shape_uncached - -__author__ = 'Alex Rogozhnikov' - - -def _compute_output_shape(recipe: TransformRecipe, input_shape) -> List[Optional[int]]: - input_shape = [UnknownSize() if d is None else int(d) for d in input_shape] - init_shapes, reduced_axes, axes_reordering, added_axes, final_shape = \ - _reconstruct_from_shape_uncached(recipe, input_shape) - output_shape: List[Optional[int]] = [None if isinstance(d, UnknownSize) else int(d) for d in final_shape] - return output_shape - - -class Rearrange(RearrangeMixin, Layer): - def compute_output_shape(self, input_shape): - return _compute_output_shape(self.recipe(), input_shape) - - def call(self, inputs): - return self._apply_recipe(inputs) - - def get_config(self): - return {'pattern': self.pattern, **self.axes_lengths} - - -class Reduce(ReduceMixin, Layer): - def compute_output_shape(self, input_shape): - return _compute_output_shape(self.recipe(), input_shape) - - def call(self, inputs): - return self._apply_recipe(inputs) - - def get_config(self): - return {'pattern': self.pattern, 'reduction': self.reduction, **self.axes_lengths} - - -class EinMix(_EinmixMixin, Layer): - def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): - self.weight = tf.Variable(tf.random_uniform_initializer(-weight_bound, weight_bound)(shape=weight_shape), - trainable=True) - if bias_shape is not None: - self.bias = tf.Variable(tf.random_uniform_initializer(-bias_bound, bias_bound)(shape=bias_shape), - trainable=True) - else: - self.bias = None - - def _create_rearrange_layers(self, - pre_reshape_pattern: Optional[str], - pre_reshape_lengths: Optional[Dict], - post_reshape_pattern: Optional[str], - post_reshape_lengths: Optional[Dict], - ): - self.pre_rearrange = None - if pre_reshape_pattern is not None: - self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths)) - - self.post_rearrange = None - if post_reshape_pattern is not None: - self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths)) - - def build(self, input_shape): - pass - - def call(self, inputs): - if self.pre_rearrange is not None: - inputs = self.pre_rearrange(inputs) - result = tf.einsum(self.einsum_pattern, inputs, self.weight) - if self.bias is not None: - result = result + self.bias - if self.post_rearrange is not None: - result = self.post_rearrange(result) - return result - - def get_config(self): - return {'pattern': self.pattern, - 'weight_shape': self.weight_shape, - 'bias_shape': self.bias_shape, - **self.axes_lengths} diff --git a/vllm/thirdparty_files/einops/layers/torch.py b/vllm/thirdparty_files/einops/layers/torch.py deleted file mode 100644 index 83abb98f779d..000000000000 --- a/vllm/thirdparty_files/einops/layers/torch.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import Optional, Dict, cast - -import torch - -from . 
import RearrangeMixin, ReduceMixin -from ._einmix import _EinmixMixin -from .._torch_specific import apply_for_scriptable_torch - -__author__ = 'Alex Rogozhnikov' - - -class Rearrange(RearrangeMixin, torch.nn.Module): - def forward(self, input): - recipe = self._multirecipe[input.ndim] - return apply_for_scriptable_torch( - recipe, input, reduction_type='rearrange', axes_dims=self._axes_lengths - ) - - def _apply_recipe(self, x): - # overriding parent method to prevent it's scripting - pass - - -class Reduce(ReduceMixin, torch.nn.Module): - def forward(self, input): - recipe = self._multirecipe[input.ndim] - return apply_for_scriptable_torch( - recipe, input, reduction_type=self.reduction, axes_dims=self._axes_lengths - ) - - def _apply_recipe(self, x): - # overriding parent method to prevent it's scripting - pass - - -class EinMix(_EinmixMixin, torch.nn.Module): - def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): - self.weight = torch.nn.Parameter(torch.zeros(weight_shape).uniform_(-weight_bound, weight_bound), - requires_grad=True) - if bias_shape is not None: - self.bias = torch.nn.Parameter(torch.zeros(bias_shape).uniform_(-bias_bound, bias_bound), - requires_grad=True) - else: - self.bias = None - - def _create_rearrange_layers(self, - pre_reshape_pattern: Optional[str], - pre_reshape_lengths: Optional[Dict], - post_reshape_pattern: Optional[str], - post_reshape_lengths: Optional[Dict], - ): - self.pre_rearrange = None - if pre_reshape_pattern is not None: - self.pre_rearrange = Rearrange(pre_reshape_pattern, **cast(dict, pre_reshape_lengths)) - - self.post_rearrange = None - if post_reshape_pattern is not None: - self.post_rearrange = Rearrange(post_reshape_pattern, **cast(dict, post_reshape_lengths)) - - def forward(self, input): - if self.pre_rearrange is not None: - input = self.pre_rearrange(input) - result = torch.einsum(self.einsum_pattern, input, self.weight) - if self.bias is not None: - result += self.bias - if self.post_rearrange is not None: - result = self.post_rearrange(result) - return result diff --git a/vllm/thirdparty_files/einops/packing.py b/vllm/thirdparty_files/einops/packing.py deleted file mode 100644 index d47a1bdb0797..000000000000 --- a/vllm/thirdparty_files/einops/packing.py +++ /dev/null @@ -1,191 +0,0 @@ -from functools import lru_cache -from typing import List, Union, TypeVar, Tuple, Sequence - -from einops import EinopsError - -from einops._backends import get_backend -from einops.parsing import ParsedExpression - -Tensor = TypeVar('Tensor') - -Shape = Union[Tuple[int, ...], List[int]] - - -@lru_cache(maxsize=128) -def analyze_pattern(pattern: str, opname: str) -> Tuple[int, int, int]: - # Maybe some validation of identifiers? - axes = pattern.split() - axes_set = set(axes) - if len(axes) != len(axes_set): - raise EinopsError(f'Duplicates in axes names in {opname}(..., "{pattern}")') - if '*' not in axes_set: - raise EinopsError(f'No *-axis in {opname}(..., "{pattern}")') - for axis in axes: - if axis != '*': - is_valid, reason = ParsedExpression.check_axis_name_return_reason(axis) - if not is_valid: - raise EinopsError(f'Invalid axis name {axis} in {opname}(..., "{pattern}")') - n_axes_before = axes.index('*') - n_axes_after = len(axes) - n_axes_before - 1 - min_axes = n_axes_before + n_axes_after - return n_axes_before, n_axes_after, min_axes - - -def pack(tensors: Sequence[Tensor], pattern: str) -> Tuple[Tensor, List[Shape]]: - """ - Packs several tensors into one. 
- See einops tutorial for introduction into packing (and how it replaces stack and concatenation). - - Parameters: - tensors: tensors to be packed, can be of different dimensionality - pattern: pattern that is shared for all inputs and output, e.g. "i j * k" or "batch seq *" - - Returns: - (packed_tensor, packed_shapes aka PS) - - Example: - ```python - >>> from numpy import zeros as Z - >>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])] - >>> packed, ps = pack(inputs, 'i j * k') - >>> packed.shape, ps - ((2, 3, 71, 5), [(), (7,), (7, 9)]) - ``` - - In this example, axes were matched to: i=2, j=3, k=5 based on order (first, second, and last). - All other axes were 'packed' and concatenated. - PS (packed shapes) contains information about axes that were matched to '*' in every input. - Resulting tensor has as many elements as all inputs in total. - - Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct order. - - ```python - >>> inputs_unpacked = unpack(packed, ps, 'i j * k') - >>> [x.shape for x in inputs_unpacked] - [(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)] - ``` - - Read the tutorial for introduction and application scenarios. - """ - n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, 'pack') - - # packing zero tensors is illegal - backend = get_backend(tensors[0]) - - reshaped_tensors: List[Tensor] = [] - packed_shapes: List[Shape] = [] - for i, tensor in enumerate(tensors): - shape = backend.shape(tensor) - if len(shape) < min_axes: - raise EinopsError(f'packed tensor #{i} (enumeration starts with 0) has shape {shape}, ' - f'while pattern {pattern} assumes at least {min_axes} axes') - axis_after_packed_axes = len(shape) - n_axes_after - packed_shapes.append(shape[n_axes_before:axis_after_packed_axes]) - reshaped_tensors.append( - backend.reshape(tensor, (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:])) - ) - - return backend.concat(reshaped_tensors, axis=n_axes_before), packed_shapes - - -def prod(x: Shape) -> int: - result = 1 - for i in x: - result *= i - return result - - -def unpack(tensor: Tensor, packed_shapes: List[Shape], pattern: str) -> List[Tensor]: - """ - Unpacks a single tensor into several by splitting over a selected axes. - See einops tutorial for introduction into packing (and how it replaces stack and concatenation). - - Parameters: - tensor: tensor to be unpacked - packed_shapes: packed_shapes (aka PS) is a list of shapes that take place of '*' in each output. - output will contain a single tensor for every provided shape - pattern: pattern that is shared for input and all outputs, e.g. "i j * k" or "batch seq *", - where * designates an axis to be unpacked - - Returns: - list of tensors - - If framework supports views, results are views to the original tensor. - - Example: - ```python - >>> from numpy import zeros as Z - >>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])] - >>> packed, ps = pack(inputs, 'i j * k') - >>> packed.shape, ps - ((2, 3, 71, 5), [(), (7,), (7, 9)]) - ``` - - In this example, axes were matched to: i=2, j=3, k=5 based on order (first, second, and last). - All other axes were 'packed' and concatenated. - PS (packed shapes) contains information about axes that were matched to '*' in every input. - Resulting tensor has as many elements as all inputs in total. - - Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct order. 
- - ```python - >>> inputs_unpacked = unpack(packed, ps, 'i j * k') - >>> [x.shape for x in inputs_unpacked] - [(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)] - ``` - - Read the tutorial for introduction and application scenarios. - """ - n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, opname='unpack') - - backend = get_backend(tensor) - input_shape = backend.shape(tensor) - if len(input_shape) != n_axes_before + 1 + n_axes_after: - raise EinopsError(f'unpack(..., {pattern}) received input of wrong dim with shape {input_shape}') - - unpacked_axis: int = n_axes_before - - lengths_of_composed_axes: List[int] = [ - -1 if -1 in p_shape else prod(p_shape) - for p_shape in packed_shapes - ] - - n_unknown_composed_axes = sum(x == -1 for x in lengths_of_composed_axes) - if n_unknown_composed_axes > 1: - raise EinopsError( - f"unpack(..., {pattern}) received more than one -1 in {packed_shapes} and can't infer dimensions" - ) - - # following manipulations allow to skip some shape verifications - # and leave it to backends - - # [[], [2, 3], [4], [-1, 5], [6]] < examples of packed_axis - # split positions when computed should be - # [0, 1, 7, 11, N-6 , N ], where N = length of axis - split_positions = [0] * len(packed_shapes) + [input_shape[unpacked_axis]] - if n_unknown_composed_axes == 0: - for i, x in enumerate(lengths_of_composed_axes[:-1]): - split_positions[i + 1] = split_positions[i] + x - else: - unknown_composed_axis: int = lengths_of_composed_axes.index(-1) - for i in range(unknown_composed_axis): - split_positions[i + 1] = split_positions[i] + lengths_of_composed_axes[i] - for j in range(unknown_composed_axis + 1, len(lengths_of_composed_axes))[::-1]: - split_positions[j] = split_positions[j + 1] - lengths_of_composed_axes[j] - - shape_start = input_shape[:unpacked_axis] - shape_end = input_shape[unpacked_axis + 1:] - slice_filler = (slice(None, None),) * unpacked_axis - try: - return [ - backend.reshape( - # shortest way slice arbitrary axis - tensor[(*slice_filler, slice(split_positions[i], split_positions[i + 1]))], - (*shape_start, *element_shape, *shape_end) - ) - for i, element_shape in enumerate(packed_shapes) - ] - except BaseException: - # this hits if there is an error during reshapes, which means passed shapes were incorrect - raise RuntimeError(f'Error during unpack(..., "{pattern}"): could not split axis of size {split_positions[-1]}' - f' into requested {packed_shapes}') diff --git a/vllm/thirdparty_files/einops/parsing.py b/vllm/thirdparty_files/einops/parsing.py deleted file mode 100644 index df0f4c53032f..000000000000 --- a/vllm/thirdparty_files/einops/parsing.py +++ /dev/null @@ -1,149 +0,0 @@ -from einops import EinopsError -import keyword -import warnings -from typing import List, Optional, Set, Tuple, Union - -_ellipsis: str = '…' # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated - - -class AnonymousAxis(object): - """Important thing: all instances of this class are not equal to each other """ - - def __init__(self, value: str): - self.value = int(value) - if self.value <= 1: - if self.value == 1: - raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue') - else: - raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value)) - - def __repr__(self): - return "{}-axis".format(str(self.value)) - - -class ParsedExpression: - """ - non-mutable structure that contains information about one side of expression (e.g. 
'b c (h w)') - and keeps some information important for downstream - """ - def __init__(self, expression: str, *, allow_underscore: bool = False, - allow_duplicates: bool = False): - self.has_ellipsis: bool = False - self.has_ellipsis_parenthesized: Optional[bool] = None - self.identifiers: Set[str] = set() - # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition - self.has_non_unitary_anonymous_axes: bool = False - # composition keeps structure of composite axes, see how different corner cases are handled in tests - self.composition: List[Union[List[str], str]] = [] - if '.' in expression: - if '...' not in expression: - raise EinopsError('Expression may contain dots only inside ellipsis (...)') - if str.count(expression, '...') != 1 or str.count(expression, '.') != 3: - raise EinopsError( - 'Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ') - expression = expression.replace('...', _ellipsis) - self.has_ellipsis = True - - bracket_group: Optional[List[str]] = None - - def add_axis_name(x): - if x in self.identifiers: - if not (allow_underscore and x == "_") and not allow_duplicates: - raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x)) - if x == _ellipsis: - self.identifiers.add(_ellipsis) - if bracket_group is None: - self.composition.append(_ellipsis) - self.has_ellipsis_parenthesized = False - else: - bracket_group.append(_ellipsis) - self.has_ellipsis_parenthesized = True - else: - is_number = str.isdecimal(x) - if is_number and int(x) == 1: - # handling the case of anonymous axis of length 1 - if bracket_group is None: - self.composition.append([]) - else: - pass # no need to think about 1s inside parenthesis - return - is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore) - if not (is_number or is_axis_name): - raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason)) - if is_number: - x = AnonymousAxis(x) - self.identifiers.add(x) - if is_number: - self.has_non_unitary_anonymous_axes = True - if bracket_group is None: - self.composition.append([x]) - else: - bracket_group.append(x) - - current_identifier = None - for char in expression: - if char in '() ': - if current_identifier is not None: - add_axis_name(current_identifier) - current_identifier = None - if char == '(': - if bracket_group is not None: - raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)") - bracket_group = [] - elif char == ')': - if bracket_group is None: - raise EinopsError('Brackets are not balanced') - self.composition.append(bracket_group) - bracket_group = None - elif str.isalnum(char) or char in ['_', _ellipsis]: - if current_identifier is None: - current_identifier = char - else: - current_identifier += char - else: - raise EinopsError("Unknown character '{}'".format(char)) - - if bracket_group is not None: - raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression)) - if current_identifier is not None: - add_axis_name(current_identifier) - - def flat_axes_order(self) -> List: - result = [] - for composed_axis in self.composition: - assert isinstance(composed_axis, list), 'does not work with ellipsis' - for axis in composed_axis: - result.append(axis) - return result - - def has_composed_axes(self) -> bool: - # this will ignore 1 inside brackets - for axes in self.composition: - if isinstance(axes, list) and len(axes) > 1: - return True - return False - - @staticmethod - def 
check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]: - if not str.isidentifier(name): - return False, 'not a valid python identifier' - elif name[0] == '_' or name[-1] == '_': - if name == '_' and allow_underscore: - return True, '' - return False, 'axis name should should not start or end with underscore' - else: - if keyword.iskeyword(name): - warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning) - if name in ['axis']: - warnings.warn("It is discouraged to use 'axis' as an axis name " - "and will raise an error in future", FutureWarning) - return True, '' - - @staticmethod - def check_axis_name(name: str) -> bool: - """ - Valid axes names are python identifiers except keywords, - and additionally should not start or end with underscore - """ - is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name) - return is_valid diff --git a/vllm/thirdparty_files/einops/py.typed b/vllm/thirdparty_files/einops/py.typed deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/AUTHORS b/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/AUTHORS deleted file mode 100644 index e35a781665ea..000000000000 --- a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/AUTHORS +++ /dev/null @@ -1 +0,0 @@ -Tri Dao, trid@cs.stanford.edu \ No newline at end of file diff --git a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/INSTALLER b/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/INSTALLER deleted file mode 100644 index a1b589e38a32..000000000000 --- a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/INSTALLER +++ /dev/null @@ -1 +0,0 @@ -pip diff --git a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/LICENSE b/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/LICENSE deleted file mode 100644 index 5860e4b33f3d..000000000000 --- a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/LICENSE +++ /dev/null @@ -1,29 +0,0 @@ -BSD 3-Clause License - -Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -* Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/METADATA b/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/METADATA deleted file mode 100644 index 6642859bec85..000000000000 --- a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/METADATA +++ /dev/null @@ -1,430 +0,0 @@ -Metadata-Version: 2.1 -Name: flash-attn -Version: 2.5.6 -Summary: Flash Attention: Fast and Memory-Efficient Exact Attention -Home-page: https://github.com/Dao-AILab/flash-attention -Author: Tri Dao -Author-email: trid@cs.stanford.edu -Classifier: Programming Language :: Python :: 3 -Classifier: License :: OSI Approved :: BSD License -Classifier: Operating System :: Unix -Requires-Python: >=3.7 -Description-Content-Type: text/markdown -License-File: LICENSE -License-File: AUTHORS -Requires-Dist: torch -Requires-Dist: einops -Requires-Dist: packaging -Requires-Dist: ninja - -# FlashAttention -This repository provides the official implementation of FlashAttention and -FlashAttention-2 from the -following papers. - -**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness** -Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré -Paper: https://arxiv.org/abs/2205.14135 -IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention. -![FlashAttention](assets/flashattn_banner.jpg) - -**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning** -Tri Dao - -Paper: https://tridao.me/publications/flash2/flash2.pdf - -![FlashAttention-2](assets/flashattention_logo.png) - - -## Usage - -We've been very happy to see FlashAttention being widely adopted in such a short -time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md) -contains a partial list of places where FlashAttention is being used. - -FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). -Please cite and credit FlashAttention if you use it. - -## Installation and features - -Requirements: -- CUDA 11.6 and above. -- PyTorch 1.12 and above. -- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue. - -We recommend the -[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) -container from Nvidia, which has all the required tools to install FlashAttention. - -To install: -1. Make sure that PyTorch is installed. -2. Make sure that `packaging` is installed (`pip install packaging`) -3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja ---version` then `echo $?` should return exit code 0). 
If not (sometimes `ninja ---version` then `echo $?` returns a nonzero exit code), uninstall then reinstall -`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`, -compiling can take a very long time (2h) since it does not use multiple CPU -cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine. -4. Then: -```sh -pip install flash-attn --no-build-isolation -``` -Alternatively you can compile from source: -```sh -python setup.py install -``` - -If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might -run too many parallel compilation jobs that could exhaust the amount of RAM. To -limit the number of parallel compilation jobs, you can set the environment -variable `MAX_JOBS`: -```sh -MAX_JOBS=4 pip install flash-attn --no-build-isolation -``` - -Interface: `src/flash_attention_interface.py` - -FlashAttention-2 currently supports: -1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing - GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing - GPUs for now. -2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs). -3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5. - - -## How to use FlashAttention - -The main functions implement scaled dot product attention (softmax(Q @ K^T * -softmax_scale) @ V): -```python -from flash_attn import flash_attn_qkvpacked_func, flash_attn_func -``` - -```python -flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, - window_size=(-1, -1), alibi_slopes=None, deterministic=False): -"""dropout_p should be set to 0.0 during evaluation -If Q, K, V are already stacked into 1 tensor, this function will be faster than -calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation -of the gradients of Q, K, V. -If window_size != (-1, -1), implements sliding window local attention. Query at position i -will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. -Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to - the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. -Return: - out: (batch_size, seqlen, nheads, headdim). -""" -``` - -```python -flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, - window_size=(-1, -1), alibi_slopes=None, deterministic=False): -"""dropout_p should be set to 0.0 during evaluation -Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads -than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. -For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head -0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. 
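# Illustration of the grouping just described (assuming the standard contiguous
# GQA mapping implied by the example): with nheads = 6 and nheads_k = 2 the group
# size is 6 // 2 = 3, so query head h attends to kv head h // 3 --
# heads 0, 1, 2 -> kv head 0 and heads 3, 4, 5 -> kv head 1.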
-If window_size != (-1, -1), implements sliding window local attention. Query at position i -will only attend to keys between -[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - -Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. -Return: - out: (batch_size, seqlen, nheads, headdim). -""" -``` - -```python -def flash_attn_with_kvcache( - q, - k_cache, - v_cache, - k=None, - v=None, - rotary_cos=None, - rotary_sin=None, - cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, - cache_batch_idx: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - rotary_interleaved=True, - alibi_slopes=None, -): - """ - If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from - k and v. This is useful for incremental decoding: you can pass in the cached keys/values from - the previous step, and update them with the new keys/values from the current step, and do - attention with the updated cache, all in 1 kernel. - - If you pass in k / v, you must make sure that the cache is large enough to hold the new values. - For example, the KV cache could be pre-allocated with the max sequence length, and you can use - cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. - - Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be - rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos - and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at - indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). - - See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. - - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. 
- - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Note: Does not support backward pass. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) - page_block_size must be a multiple of 256. - v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) - k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate - k with k_cache, starting at the indices specified by cache_seqlens. - v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. - rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding - to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. - rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. - cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the - KV cache. - block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. - cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. - If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. - If the indices are not distinct, and k and v are provided, the values updated in the cache - might come from any of the duplicate indices. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. - If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, - rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 - (i.e. GPT-NeoX style). - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - - Return: - out: (batch_size, seqlen, nheads, headdim). - """ -``` - -To see how these functions are used in a multi-head attention layer (which -includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py). 
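As a concrete sanity check of the shapes documented above (sizes here are illustrative; assumes a supported GPU and that the package is installed as described in the installation section):

```python
import torch
from flash_attn import flash_attn_func

# Shapes follow the docstring: q is (batch, seqlen, nheads, headdim) and k/v may
# use fewer heads (MQA/GQA), as long as nheads is divisible by nheads_k.
batch, seqlen, nheads, nheads_k, headdim = 2, 128, 8, 2, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
assert out.shape == (batch, seqlen, nheads, headdim)
```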
- -## Changelog - -### 2.0: Complete rewrite, 2x faster -Upgrading from FlashAttention (1.x) to FlashAttention-2 - -These functions have been renamed: -- `flash_attn_unpadded_func` -> `flash_attn_varlen_func` -- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func` -- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func` - -If the inputs have the same sequence lengths in the same batch, it is simpler -and faster to use these functions: -```python -flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False) -``` -```python -flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False) -``` -### 2.1: Change behavior of causal flag - -If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the -bottom right corner of the attention matrix, instead of the top-left corner. - -For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = -masked out) is: -v2.0: - 1 0 0 0 0 - 1 1 0 0 0 -v2.1: - 1 1 1 1 0 - 1 1 1 1 1 - -If seqlen_q = 5 and seqlen_k = 2, the causal mask is: -v2.0: - 1 0 - 1 1 - 1 1 - 1 1 - 1 1 -v2.1: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 -If the row of the mask is all zero, the output will be zero. - -### 2.2: Optimize for inference - -Optimize for inference (iterative decoding) when query has very small sequence -length (e.g., query sequence length = 1). The bottleneck here is to load KV -cache as fast as possible, and we split the loading across different thread -blocks, with a separate kernel to combine results. - -See the function `flash_attn_with_kvcache` with more features for inference -(perform rotary embedding, updating KV cache inplace). - -Thanks to the xformers team, and in particular Daniel Haziza, for this -collaboration. - -### 2.3: Local (i.e., sliding window) attention - -Implement sliding window attention (i.e., local attention). Thanks to [Mistral -AI](https://mistral.ai/) and in particular Timothée Lacroix for this -contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model. - -### 2.4: ALiBi (attention with linear bias), deterministic backward pass. - -Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution. - -Implement deterministic backward pass. Thanks to engineers from [Meituan](www.meituan.com) for this contribution. - -### 2.5: Paged KV cache. - -Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)). -Thanks to @beginlner for this contribution. - -## Performance - -We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory). - -We currently have benchmarks for these GPUs: -* [A100](#a100) -* [H100](#h100) - - - -### A100 - -We display FlashAttention speedup using these parameters: -* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads). -* Sequence length 512, 1k, 2k, 4k, 8k, 16k. -* Batch size set to 16k / seqlen. - -#### Speedup - -![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png) - -#### Memory - -![FlashAttention memory](assets/flashattn_memory.jpg) - -We show memory savings in this graph (note that memory footprint is the same no matter if you use dropout or masking). 
-Memory savings are proportional to sequence length -- since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length. -We see 10X memory savings at sequence length 2K, and 20X at 4K. -As a result, FlashAttention can scale to much longer sequence lengths. - -### H100 - -![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png) - -## Full model code and training script - -We have released the full GPT model -[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py). -We also provide optimized implementations of other layers (e.g., MLP, LayerNorm, -cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x -compared to the baseline implementation from Huggingface, reaching up to 225 -TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need -any activation checkpointing). - -We also include a training -[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to -train GPT2 on Openwebtext and GPT3 on The Pile. - -## Triton implementation of FlashAttention - -Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton: -https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py - -As Triton is a higher-level language than CUDA, it might be easier to understand -and experiment with. The notations in the Triton implementation are also closer -to what's used in our paper. - -We also have an experimental implementation in Triton that support attention -bias (e.g. ALiBi): -https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py - - -## Tests -We test that FlashAttention produces the same output and gradient as a reference -implementation, up to some numerical tolerance. In particular, we check that the -maximum numerical error of FlashAttention is at most twice the numerical error -of a baseline implementation in Pytorch (for different head dimensions, input -dtype, sequence length, causal / non-causal). - -To run the tests: -```sh -pytest -q -s tests/test_flash_attn.py -``` -## When you encounter issues - -This new release of FlashAttention-2 has been tested on several GPT-style -models, mostly on A100 GPUs. - -If you encounter bugs, please open a GitHub Issue! - -## Citation -If you use this codebase, or otherwise found our work valuable, please cite: -``` -@inproceedings{dao2022flashattention, - title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, - author={Dao, Tri and Fu, Daniel Y. 
and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, - booktitle={Advances in Neural Information Processing Systems}, - year={2022} -} -@article{dao2023flashattention2, - title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, - author={Dao, Tri}, - year={2023} -} -``` diff --git a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/RECORD b/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/RECORD deleted file mode 100644 index a10839fa6068..000000000000 --- a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/RECORD +++ /dev/null @@ -1,103 +0,0 @@ -flash_attn-2.5.6.dist-info/AUTHORS,sha256=879BRIJqYoQbf5rrxQV_ddotMqZSpXPtxnJQ7JSjd6c,29 -flash_attn-2.5.6.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -flash_attn-2.5.6.dist-info/LICENSE,sha256=jJzLlsBl5wYTW2y60nm3IdphVuUfOl8nxrMymvlBbXM,1558 -flash_attn-2.5.6.dist-info/METADATA,sha256=lQWMph0JsxH5Bol92fSflE-J1M1V3QNbC-553b7-lYw,19145 -flash_attn-2.5.6.dist-info/RECORD,, -flash_attn-2.5.6.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -flash_attn-2.5.6.dist-info/WHEEL,sha256=cPiEulY4lHJMdGCxx29HxfDkwhV9C6172sJgZUA9dSs,103 -flash_attn-2.5.6.dist-info/top_level.txt,sha256=M0iiwJuMya9VMy0DnzgZYGe6v-YSy4TY5LNv2xY3fVQ,29 -flash_attn/__init__.py,sha256=mNeJ6pmb-Y7gqafbzyqRcT1ILzDQpX7GgaA95CYxBLA,285 -flash_attn/__pycache__/__init__.cpython-39.pyc,, -flash_attn/__pycache__/bert_padding.cpython-39.pyc,, -flash_attn/__pycache__/flash_attn_interface.cpython-39.pyc,, -flash_attn/__pycache__/flash_attn_triton.cpython-39.pyc,, -flash_attn/__pycache__/flash_attn_triton_og.cpython-39.pyc,, -flash_attn/__pycache__/flash_blocksparse_attention.cpython-39.pyc,, -flash_attn/__pycache__/flash_blocksparse_attn_interface.cpython-39.pyc,, -flash_attn/__pycache__/fused_softmax.cpython-39.pyc,, -flash_attn/bert_padding.py,sha256=MYMu_Dg9AcnM4-D56X0QGxp5WieqJ045RAvC4kPFI5w,9535 -flash_attn/flash_attn_interface.py,sha256=fOU_b14DTcWpGpNHxnSB1lvnhdcA9ZuaO3BBMhWDdig,45128 -flash_attn/flash_attn_triton.py,sha256=Du81zbh8Ls70ExEsm00opziGvjGFfcZCoZDUO2zut9Q,41112 -flash_attn/flash_attn_triton_og.py,sha256=LmvDju7LJG-wOYhoR6Zc2AmdPK2oWyB1VJpMjRhnWnE,11328 -flash_attn/flash_blocksparse_attention.py,sha256=aJlttNZVxVaktCNYAfP5AdqeZDu8jv42_ZbTkRnDkWg,7469 -flash_attn/flash_blocksparse_attn_interface.py,sha256=2qK2KvVCt851_j8ZzHvjS-aMfdgVDu1yne67-iScWfo,7265 -flash_attn/fused_softmax.py,sha256=0-XbXo7R1a5h4-EpUzPy--lwlGytfTDW34WGM5nmBAY,7793 -flash_attn/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -flash_attn/layers/__pycache__/__init__.cpython-39.pyc,, -flash_attn/layers/__pycache__/patch_embed.cpython-39.pyc,, -flash_attn/layers/__pycache__/rotary.cpython-39.pyc,, -flash_attn/layers/patch_embed.py,sha256=H58CgME_qSOPTZLOG08wFgrQS1j34pvNwMPrkTj3Ek4,2136 -flash_attn/layers/rotary.py,sha256=RmDtuIpbFY-dqLATKwaPTjuVswcGJgL21_LvHwn2uw8,18874 -flash_attn/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -flash_attn/losses/__pycache__/__init__.cpython-39.pyc,, -flash_attn/losses/__pycache__/cross_entropy.cpython-39.pyc,, -flash_attn/losses/cross_entropy.py,sha256=XmyE7jGX5SE9Etuz1_BSRSHlZH2oi9leE9cJZF3Lyj8,3133 -flash_attn/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -flash_attn/models/__pycache__/__init__.cpython-39.pyc,, -flash_attn/models/__pycache__/baichuan.cpython-39.pyc,, -flash_attn/models/__pycache__/bert.cpython-39.pyc,, -flash_attn/models/__pycache__/bigcode.cpython-39.pyc,, 
-flash_attn/models/__pycache__/btlm.cpython-39.pyc,, -flash_attn/models/__pycache__/falcon.cpython-39.pyc,, -flash_attn/models/__pycache__/gpt.cpython-39.pyc,, -flash_attn/models/__pycache__/gpt_neox.cpython-39.pyc,, -flash_attn/models/__pycache__/gptj.cpython-39.pyc,, -flash_attn/models/__pycache__/llama.cpython-39.pyc,, -flash_attn/models/__pycache__/opt.cpython-39.pyc,, -flash_attn/models/__pycache__/vit.cpython-39.pyc,, -flash_attn/models/baichuan.py,sha256=eFNWwoRQ02AIeQP0OoK8pNvYw0dqnHOshLigCQPkAEc,5730 -flash_attn/models/bert.py,sha256=-y6wVYzAfDqWWeO6n-dLapT1scn0lIsadKJKFzn48Vg,33241 -flash_attn/models/bigcode.py,sha256=mkYeItoJtmWVf2wKkUs5oXjwdbTdGSo5eHxi0-1maZ8,9383 -flash_attn/models/btlm.py,sha256=d8YDjYTa2G1DutYu-YuVf15S_Dn6oKn8-HzERoersLA,4631 -flash_attn/models/falcon.py,sha256=mA3wGv1a4zhbrUSlFNVVmTgVjiXc1sFTOi55eYpgSPo,6033 -flash_attn/models/gpt.py,sha256=_Eu0Kh0RQoXUVRSsVZQEKCLD1etHDi7w6Dc0_yrbN3I,47663 -flash_attn/models/gpt_neox.py,sha256=_704a9KQ2PcnID8uMV7yZ4ggjGlh1zZH5gszue6D1bI,5159 -flash_attn/models/gptj.py,sha256=k2eqMNyMbU7CJVM_BHBjlKt0ByFz6ITSETqS1mJa89g,4436 -flash_attn/models/llama.py,sha256=bDRI308iRpeJngZLrQlLTGYAmwYotqzUxnjBMirfn-k,16581 -flash_attn/models/opt.py,sha256=L0ZIWKpSP44lcEbiVCzVT9un_5gFMAW6cvnS3KHcb-A,5164 -flash_attn/models/vit.py,sha256=7i0WUI_jZvQ5TMoSKPPzf77ZcyMDfDJuQaINzXN_iQU,14074 -flash_attn/modules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -flash_attn/modules/__pycache__/__init__.cpython-39.pyc,, -flash_attn/modules/__pycache__/block.cpython-39.pyc,, -flash_attn/modules/__pycache__/embedding.cpython-39.pyc,, -flash_attn/modules/__pycache__/mha.cpython-39.pyc,, -flash_attn/modules/__pycache__/mlp.cpython-39.pyc,, -flash_attn/modules/block.py,sha256=WLi7JKj9_Zpk89ppzC7WTIoykJJ7TLOJbUSZePNnW1E,17349 -flash_attn/modules/embedding.py,sha256=RCVeeiomlGNkLeQD8G6Udvex-NDI_xKD45hXjgZ2lbQ,8693 -flash_attn/modules/mha.py,sha256=_x3QP5zAWBdeZwowDJ-Qq4c6a1HNjio6B2CE5H7HnYA,43075 -flash_attn/modules/mlp.py,sha256=G6KPQagfKq1DRn7hQRJ3OHznFJLZHj_PiidZE_zcLgg,6033 -flash_attn/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -flash_attn/ops/__pycache__/__init__.cpython-39.pyc,, -flash_attn/ops/__pycache__/activations.cpython-39.pyc,, -flash_attn/ops/__pycache__/fused_dense.cpython-39.pyc,, -flash_attn/ops/__pycache__/layer_norm.cpython-39.pyc,, -flash_attn/ops/__pycache__/rms_norm.cpython-39.pyc,, -flash_attn/ops/activations.py,sha256=4f9iruZ2SKJSmOlNQ9L3t5EpQ2tKJVlyy-iBBF6sMgs,3936 -flash_attn/ops/fused_dense.py,sha256=ACJKqkIfxZibxI3nb5ycb3pXBKaL_CM63rUUyQYNAUE,27907 -flash_attn/ops/layer_norm.py,sha256=zr7NXIm-2mtEynTp1CS0fbFGI2Mqdp41dY4AfDWF6EQ,22443 -flash_attn/ops/rms_norm.py,sha256=XEnihcj0a4aSz4LO55m5iKGVn4HKTeKN8TIyHjuDgxI,3988 -flash_attn/ops/triton/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1 -flash_attn/ops/triton/__pycache__/__init__.cpython-39.pyc,, -flash_attn/ops/triton/__pycache__/cross_entropy.cpython-39.pyc,, -flash_attn/ops/triton/__pycache__/k_activations.cpython-39.pyc,, -flash_attn/ops/triton/__pycache__/layer_norm.cpython-39.pyc,, -flash_attn/ops/triton/__pycache__/linear.cpython-39.pyc,, -flash_attn/ops/triton/__pycache__/mlp.cpython-39.pyc,, -flash_attn/ops/triton/__pycache__/rotary.cpython-39.pyc,, -flash_attn/ops/triton/cross_entropy.py,sha256=XxjYoQIe8v40gKUp5kdrrOGPfrKjcYTwZqHITu4OfzI,12546 -flash_attn/ops/triton/k_activations.py,sha256=-Z3vIyO4JkqBMipKsPvhzmxljtBdIhJCsl_M-_ESqBo,4034 
-flash_attn/ops/triton/layer_norm.py,sha256=7pyChANqCgLJGnuXtlGB78-kj4zPdSlk2Sm5zfYT9Fc,34966 -flash_attn/ops/triton/linear.py,sha256=OtRvKz8xdpl-7v3q_ZTaS9fdBt9XrzMyapgRr50uBbM,20841 -flash_attn/ops/triton/mlp.py,sha256=_5lbZJFZg_pXeXYITGt4V_6LkB_yddClB_jt-diCOdw,6068 -flash_attn/ops/triton/rotary.py,sha256=WNol7_u1QJs3SL7RBoewRvGM4jiH_H9_hSFmO-ljioY,8990 -flash_attn/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -flash_attn/utils/__pycache__/__init__.cpython-39.pyc,, -flash_attn/utils/__pycache__/benchmark.cpython-39.pyc,, -flash_attn/utils/__pycache__/distributed.cpython-39.pyc,, -flash_attn/utils/__pycache__/generation.cpython-39.pyc,, -flash_attn/utils/__pycache__/pretrained.cpython-39.pyc,, -flash_attn/utils/benchmark.py,sha256=JDtzdVhFyMIQqs3edbcXdXnmDf-O7RVpmZmn2ZFCvI0,7369 -flash_attn/utils/distributed.py,sha256=qhcybRXtslssuV9LYaQy37haPaPtklM4YUMDx9UvnnQ,5825 -flash_attn/utils/generation.py,sha256=4rh4XRDXN3xCfmPt4dtQz4m3StTIjyCg8L2VNZwdaVo,30466 -flash_attn/utils/pretrained.py,sha256=VZ6qk90sBJA7M86gRzPsNc_CkQXkj5HyrJvwl0I355k,3246 -flash_attn_2_cuda.cpython-39-x86_64-linux-gnu.so,sha256=DUSvDx6EXrT4-TYAiRKAoCpB1zg3_N2izLPxSsWc-7Q,403107440 diff --git a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/REQUESTED b/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/REQUESTED deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/WHEEL b/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/WHEEL deleted file mode 100644 index 704037a08e9e..000000000000 --- a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/WHEEL +++ /dev/null @@ -1,5 +0,0 @@ -Wheel-Version: 1.0 -Generator: bdist_wheel (0.42.0) -Root-Is-Purelib: false -Tag: cp39-cp39-linux_x86_64 - diff --git a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/top_level.txt b/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/top_level.txt deleted file mode 100644 index 459764cff235..000000000000 --- a/vllm/thirdparty_files/flash_attn-2.5.6.dist-info/top_level.txt +++ /dev/null @@ -1,2 +0,0 @@ -flash_attn -flash_attn_2_cuda diff --git a/vllm/thirdparty_files/flash_attn/__init__.py b/vllm/thirdparty_files/flash_attn/__init__.py deleted file mode 100644 index 756253685eb8..000000000000 --- a/vllm/thirdparty_files/flash_attn/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -__version__ = "2.5.6" - -from flash_attn.flash_attn_interface import ( - flash_attn_func, - flash_attn_kvpacked_func, - flash_attn_qkvpacked_func, - flash_attn_varlen_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_with_kvcache, -) diff --git a/vllm/thirdparty_files/flash_attn/bert_padding.py b/vllm/thirdparty_files/flash_attn/bert_padding.py deleted file mode 100644 index 1d447d3f660e..000000000000 --- a/vllm/thirdparty_files/flash_attn/bert_padding.py +++ /dev/null @@ -1,213 +0,0 @@ -# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py - -import torch -import torch.nn.functional as F -from einops import rearrange, repeat - - -class IndexFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, input, indices): - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - # return input[indices] - return torch.gather( - rearrange(input, "b ... 
-> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) - ).reshape(-1, *other_shape) - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - grad_output = rearrange(grad_output, "b ... -> b (...)") - grad_input = torch.zeros( - [ctx.first_axis_dim, grad_output.shape[1]], - device=grad_output.device, - dtype=grad_output.dtype, - ) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - # grad_input[indices] = grad_output - grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -index_first_axis = IndexFirstAxis.apply - - -class IndexPutFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, values, indices, first_axis_dim): - ctx.save_for_backward(indices) - assert indices.ndim == 1 - assert values.ndim >= 2 - output = torch.zeros( - first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype - ) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - output[indices] = values - # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) - return output - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - grad_values = grad_output[indices] - # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) - return grad_values, None, None - - -index_put_first_axis = IndexPutFirstAxis.apply - - -class IndexFirstAxisResidual(torch.autograd.Function): - @staticmethod - def forward(ctx, input, indices): - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - output = input[indices] - # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last - # memory format to channel_first. In other words, input might not be contiguous. - # If we don't detach, Pytorch complains about output being a view and is being modified inplace - return output, input.detach() - - @staticmethod - def backward(ctx, grad_output, grad_residual): - (indices,) = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - assert grad_residual.shape[1:] == other_shape - grad_input = grad_residual - # grad_input[indices] += grad_output - indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) - indices = indices.expand_as(grad_output) - grad_input.scatter_add_(0, indices, grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -index_first_axis_residual = IndexFirstAxisResidual.apply - - -def unpad_input(hidden_states, attention_mask): - """ - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. 
- max_seqlen_in_batch: int - """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the - # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim - # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to - # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, - # so we write custom forward and backward to make it a bit faster. - return ( - index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): - """ - Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). - The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). - - For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: - ``` - [ - [2, 3, 0, 0, 0, 0], - [3, 2, 0, 0, 0, 0], - [6, 0, 0, 0, 0, 0] - ] - ``` - , which refers to the 3D-attention mask: - ``` - [ - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0], - [0, 0, 1, 1, 0, 0], - [0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 1] - ], - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 1] - ], - [ - [1, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1] - ] - ] - ```. - - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - """ - length = attention_mask_in_length.sum(dim=-1) - seqlen = attention_mask_in_length.size(-1) - attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1) - real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() - seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] - indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the - # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim - # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to - # index with integer indices. 
Moreover, torch's index is a bit slower than it needs to be, - # so we write custom forward and backward to make it a bit faster. - return ( - index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def pad_input(hidden_states, indices, batch, seqlen): - """ - Arguments: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. - batch: int, batch size for the padded sequence. - seqlen: int, maximum sequence length for the padded sequence. - Return: - hidden_states: (batch, seqlen, ...) - """ - dim = hidden_states.shape[-1] - # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) - # output[indices] = hidden_states - output = index_put_first_axis(hidden_states, indices, batch * seqlen) - return rearrange(output, "(b s) ... -> b s ...", b=batch) diff --git a/vllm/thirdparty_files/flash_attn/flash_attn_interface.py b/vllm/thirdparty_files/flash_attn/flash_attn_interface.py deleted file mode 100644 index a1ef865dd2b1..000000000000 --- a/vllm/thirdparty_files/flash_attn/flash_attn_interface.py +++ /dev/null @@ -1,1209 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -from typing import Optional, Union - -import torch -import torch.nn as nn - -# isort: off -# We need to import the CUDA kernels after importing torch -import flash_attn_2_cuda as flash_attn_cuda - -# isort: on - - -def _get_block_size_n(device, head_dim, is_dropout, is_causal): - # This should match the block sizes in the CUDA kernel - assert head_dim <= 256 - major, minor = torch.cuda.get_device_capability(device) - is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) - is_sm80 = major == 8 and minor == 0 - is_sm90 = major == 9 and minor == 0 - if head_dim <= 32: - return 128 - if head_dim <= 64: - return 128 if not is_dropout else 64 - elif head_dim <= 96: - return 64 - elif head_dim <= 128: - if is_sm8x: - return 64 if (not is_dropout and is_causal) else 32 - else: - return 64 if not is_dropout else 32 - elif head_dim <= 160: - if is_sm8x: - return 64 - else: - return 32 - elif head_dim <= 192: - return 64 - elif head_dim <= 224: - return 64 - elif head_dim <= 256: - return 64 - - -def _flash_attn_forward( - q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax -): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd( - q, - k, - v, - None, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size[0], - window_size[1], - return_softmax, - None, - ) - return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state - - -def _flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - return_softmax, -): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd( - q, - k, - v, - None, - cu_seqlens_q, - cu_seqlens_k, - None, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - False, - causal, - window_size[0], - window_size[1], - 
return_softmax, - None, - ) - # if out.isnan().any() or softmax_lse.isnan().any(): - # breakpoint() - return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state - - -def _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, -): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - dq, dk, dv, softmax_d, = flash_attn_cuda.bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size[0], - window_size[1], - deterministic, - None, - rng_state, - ) - return dq, dk, dv, softmax_d - - -def _flash_attn_varlen_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - rng_state=None, -): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - False, - causal, - window_size[0], - window_size[1], - deterministic, - None, - rng_state, - ) - # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): - # breakpoint() - return dq, dk, dv, softmax_d - - -class FlashAttnQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - ): - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) - dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dqkv[:, :, 0], - dqkv[:, :, 1], - dqkv[:, :, 2], - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None - - -class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - 
alibi_slopes, - deterministic, - return_softmax, - ): - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) - ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors - qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) - dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) - _flash_attn_varlen_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - cu_seqlens, - cu_seqlens, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None - - -class FlashAttnKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - kv, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( - q, - kv[:, :, 0], - kv[:, :, 1], - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - dq = torch.empty_like(q) - kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) - dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dkv[:, :, 0], - dkv[:, :, 1], - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None - - -class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( - q, - kv[:, 0], - 
kv[:, 1], - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state - ) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors - dq = torch.empty_like(q) - kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) - dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) - _flash_attn_varlen_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dkv[:, 0], - dkv[:, 1], - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None, None, None, None - - -class FlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax and dropout_p > 0, - ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - _flash_attn_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None - - -class FlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_softmax, - ): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal=causal, - window_size=window_size, - alibi_slopes=alibi_slopes, - 
return_softmax=return_softmax and dropout_p > 0, - ) - ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state - ) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - _flash_attn_varlen_backward( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size, - ctx.alibi_slopes, - ctx.deterministic, - rng_state=rng_state, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If Q, K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of Q, K, V. - For multi-query and grouped-query attention (MQA/GQA), please see - flash_attn_kvpacked_func and flash_attn_func. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. - - Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to - the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). 
- """ - return FlashAttnQKVPackedFunc.apply( - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - ) - - -def flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of K, V. - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - kv: (batch_size, seqlen, 2, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). 
- """ - return FlashAttnKVPackedFunc.apply( - q, - kv, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - ) - - -def flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). 
- """ - return FlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - ) - - -def flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If Q, K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of Q, K, V. - For multi-query and grouped-query attention (MQA/GQA), please see - flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. - - Arguments: - qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into qkv. - max_seqlen: int. Maximum sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenQKVPackedFunc.apply( - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - ) - - -def flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - If K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of K, V. - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. 
Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenKVPackedFunc.apply( - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - ) - - -def flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads - than Q. 
Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). 
- """ - return FlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_attn_probs, - ) - - -def flash_attn_with_kvcache( - q, - k_cache, - v_cache, - k=None, - v=None, - rotary_cos=None, - rotary_sin=None, - cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, - cache_batch_idx: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - rotary_interleaved=True, - alibi_slopes=None, - num_splits=0, -): - """ - If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from - k and v. This is useful for incremental decoding: you can pass in the cached keys/values from - the previous step, and update them with the new keys/values from the current step, and do - attention with the updated cache, all in 1 kernel. - - If you pass in k / v, you must make sure that the cache is large enough to hold the new values. - For example, the KV cache could be pre-allocated with the max sequence length, and you can use - cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. - - Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be - rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos - and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at - indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). - - See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. - - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Note: Does not support backward pass. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) - page_block_size must be a multiple of 256. - v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) - k [optional]: (batch_size, seqlen_new, nheads_k, headdim). 
If not None, we concatenate - k with k_cache, starting at the indices specified by cache_seqlens. - v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. - rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding - to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. - rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. - cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the - KV cache. - block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. - cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. - If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. - If the indices are not distinct, and k and v are provided, the values updated in the cache - might come from any of the duplicate indices. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. - If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, - rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 - (i.e. GPT-NeoX style). - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - num_splits: int. If > 1, split the key/value into this many chunks along the sequence. - If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic - to automatically determine the number of splits. - Don't change this unless you know what you are doing. - - Return: - out: (batch_size, seqlen, nheads, headdim). - """ - assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" - assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" - maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - if cache_seqlens is not None and isinstance(cache_seqlens, int): - cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device - ) - cache_seqlens = maybe_contiguous(cache_seqlens) - cache_batch_idx = maybe_contiguous(cache_batch_idx) - block_table = maybe_contiguous(block_table) - out, softmax_lse = flash_attn_cuda.fwd_kvcache( - q, - k_cache, - v_cache, - k, - v, - cache_seqlens, - rotary_cos, - rotary_sin, - cache_batch_idx, - block_table, - alibi_slopes, - None, - softmax_scale, - causal, - window_size[0], - window_size[1], - rotary_interleaved, - num_splits, - ) - return out diff --git a/vllm/thirdparty_files/flash_attn/flash_attn_triton.py b/vllm/thirdparty_files/flash_attn/flash_attn_triton.py deleted file mode 100644 index 30420c057adf..000000000000 --- a/vllm/thirdparty_files/flash_attn/flash_attn_triton.py +++ /dev/null @@ -1,1160 +0,0 @@ -""" -*Experimental* implementation of FlashAttention in Triton. -Tested with triton==2.0.0.dev20221202. 
-Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions -other than 64: -https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 -We'll update this implementation with the new Triton backend once this is fixed. - -We use the FlashAttention implementation from Phil Tillet a starting point. -https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py - -Changes: -- Implement both causal and non-causal attention. -- Implement both self-attention and cross-attention. -- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. -- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. -- Support attention bias. -- Speed up the forward pass a bit, and only store the LSE instead of m and l. -- Make the backward for d=128 much faster by reducing register spilling. -- Optionally parallelize the backward pass across seqlen_k, to deal with the case of -small batch size * nheads. - -Caution: -- This is an *experimental* implementation. The forward pass should be quite robust but -I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). -- This implementation has only been tested on A100. -- If you plan to use headdim other than 64 and 128, you should test for race conditions -(due to the Triton compiler), as done in tests/test_flash_attn.py -"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions -for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident -that there are none left for other head dimensions. - -Differences between this Triton version and the CUDA version: -- Triton version doesn't support dropout. -- Triton forward is generally faster than CUDA forward, while Triton backward is -generally slower than CUDA backward. Overall Triton forward + backward is slightly slower -than CUDA forward + backward. -- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). -- Triton version supports attention bias, while CUDA version doesn't. -""" - -import math - -import torch -import triton -import triton.language as tl - - -# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 -# @triton.autotune( -# configs=[ -# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), -# # This config has a race condition when EVEN_M == False, disabling it for now. 
-# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), -# ], -# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] -# ) -@triton.heuristics( - { - "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, - "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, - "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - } -) -@triton.jit -def _fwd_kernel( - Q, - K, - V, - Bias, - Out, - Lse, - TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug - softmax_scale, - stride_qb, - stride_qh, - stride_qm, - stride_kb, - stride_kh, - stride_kn, - stride_vb, - stride_vh, - stride_vn, - stride_bb, - stride_bh, - stride_bm, - stride_ob, - stride_oh, - stride_om, - nheads, - seqlen_q, - seqlen_k, - seqlen_q_rounded, - headdim, - CACHE_KEY_SEQLEN_Q, - CACHE_KEY_SEQLEN_K, - BIAS_TYPE: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - start_m = tl.program_id(0) - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - # off_b = tl.program_id(1) - # off_h = tl.program_id(2) - # off_hb = off_b * nheads + off_h - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_HEADDIM) - # Initialize pointers to Q, K, V - # Adding parenthesis around indexing might use int32 math instead of int64 math? - # https://github.com/openai/triton/issues/741 - # I'm seeing a tiny bit of difference (5-7us) - q_ptrs = ( - Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) - ) - k_ptrs = ( - K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) - ) - v_ptrs = ( - V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) - ) - if BIAS_TYPE == "vector": - b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n - elif BIAS_TYPE == "matrix": - b_ptrs = ( - Bias - + off_b * stride_bb - + off_h * stride_bh - + (offs_m[:, None] * stride_bm + offs_n[None, :]) - ) - # initialize pointer to m and l - t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m - lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) - # load q: it will stay in SRAM throughout - # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call - # tl.load(q_ptrs), we get the wrong output! 
- if EVEN_M & EVEN_N: - if EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) - else: - q = tl.load( - q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 - ) - # loop over k, v and update accumulator - end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) - for start_n in range(0, end_n, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - k = tl.load(k_ptrs + start_n * stride_kn) - else: - k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) - else: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, - ) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, trans_b=True) - # Trying to combine the two masks seem to make the result wrong - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) - if IS_CAUSAL: - qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) - if BIAS_TYPE != "none": - if BIAS_TYPE == "vector": - if EVEN_N: - bias = tl.load(b_ptrs + start_n).to(tl.float32) - else: - bias = tl.load( - b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0 - ).to(tl.float32) - bias = bias[None, :] - elif BIAS_TYPE == "matrix": - if EVEN_M & EVEN_N: - bias = tl.load(b_ptrs + start_n).to(tl.float32) - else: - bias = tl.load( - b_ptrs + start_n, - mask=(offs_m[:, None] < seqlen_q) - & ((start_n + offs_n)[None, :] < seqlen_k), - other=0.0, - ).to(tl.float32) - # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler - # can then fuse the mult and add into an fma instruction. But if we have bias we need to - # to multiply with softmax_scale here. 
- qk = qk * softmax_scale + bias - m_ij = tl.maximum(tl.max(qk, 1), lse_i) - p = tl.exp(qk - m_ij[:, None]) - else: - m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) - p = tl.exp(qk * softmax_scale - m_ij[:, None]) - l_ij = tl.sum(p, 1) - - # scale acc_o - acc_o_scale = tl.exp(m_i - m_ij) - - # # -- update output accumulator -- - # BUG: have to store and immediately load - tl.store(t_ptrs, acc_o_scale) - acc_o_scale = tl.load(t_ptrs) - acc_o = acc_o * acc_o_scale[:, None] - # update acc_o - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - v = tl.load(v_ptrs + start_n * stride_vn) - else: - v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) - else: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, - ) - p = p.to(v.dtype) - acc_o += tl.dot(p, v) - - # -- update statistics - m_i = m_ij - l_i_new = tl.exp(lse_i - m_ij) + l_ij - lse_i = m_ij + tl.log(l_i_new) - - o_scale = tl.exp(m_i - lse_i) - # BUG: have to store and immediately load - tl.store(t_ptrs, o_scale) - o_scale = tl.load(t_ptrs) - acc_o = acc_o * o_scale[:, None] - # rematerialize offsets to save registers - start_m = tl.program_id(0) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # write back l and m - lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m - tl.store(lse_ptrs, lse_i) - # initialize pointers to output - offs_d = tl.arange(0, BLOCK_HEADDIM) - out_ptrs = ( - Out - + off_b * stride_ob - + off_h * stride_oh - + (offs_m[:, None] * stride_om + offs_d[None, :]) - ) - if EVEN_M: - if EVEN_HEADDIM: - tl.store(out_ptrs, acc_o) - else: - tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) - else: - tl.store( - out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) - ) - - -@triton.jit -def _bwd_preprocess_do_o_dot( - Out, - DO, - Delta, - stride_ob, - stride_oh, - stride_om, - stride_dob, - stride_doh, - stride_dom, - nheads, - seqlen_q, - seqlen_q_rounded, - headdim, - BLOCK_M: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, -): - start_m = tl.program_id(0) - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_d = tl.arange(0, BLOCK_HEADDIM) - # load - o = tl.load( - Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ).to(tl.float32) - do = tl.load( - DO - + off_b * stride_dob - + off_h * stride_doh - + offs_m[:, None] * stride_dom - + offs_d[None, :], - mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ).to(tl.float32) - delta = tl.sum(o * do, axis=1) - # write-back - tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) - - -@triton.jit -def _bwd_store_dk_dv( - dk_ptrs, - dv_ptrs, - dk, - dv, - offs_n, - offs_d, - seqlen_k, - headdim, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, -): - # [2022-11-01] TD: Same bug. 
In the case of EVEN_N=True and EVEN_M=False, - # if we just call tl.store(dv_ptrs), there's a race condition - if EVEN_N & EVEN_M: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) - else: - tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) - tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) - tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) - else: - tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - - -@triton.jit -def _bwd_kernel_one_col_block( - start_n, - Q, - K, - V, - Bias, - DO, - DQ, - DK, - DV, - LSE, - D, - softmax_scale, - stride_qm, - stride_kn, - stride_vn, - stride_bm, - stride_dom, - stride_dqm, - stride_dkn, - stride_dvn, - seqlen_q, - seqlen_k, - headdim, - ATOMIC_ADD: tl.constexpr, - BIAS_TYPE: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) - begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M - # initialize row/col offsets - offs_qm = begin_m + tl.arange(0, BLOCK_M) - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_m = tl.arange(0, BLOCK_M) - offs_d = tl.arange(0, BLOCK_HEADDIM) - # initialize pointers to value-like data - q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) - v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) - do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) - dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) - if BIAS_TYPE == "vector": - b_ptrs = Bias + offs_n - elif BIAS_TYPE == "matrix": - b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) - # initialize dv and dk - dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) - dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) - # There seems to be some problem with Triton pipelining that makes results wrong for - # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop - # may have zero step, and pipelining with the bias matrix could screw it up. - # So we just exit early. - if begin_m >= seqlen_q: - dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) - dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) - _bwd_store_dk_dv( - dk_ptrs, - dv_ptrs, - dk, - dv, - offs_n, - offs_d, - seqlen_k, - headdim, - EVEN_M=EVEN_M, - EVEN_N=EVEN_N, - EVEN_HEADDIM=EVEN_HEADDIM, - ) - return - # k and v stay in SRAM throughout - # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, - # if we just call tl.load(k_ptrs), we get the wrong output! 
- if EVEN_N & EVEN_M: - if EVEN_HEADDIM: - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - else: - k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) - v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) - else: - k = tl.load( - k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 - ) - v = tl.load( - v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 - ) - # loop over rows - num_block_m = tl.cdiv(seqlen_q, BLOCK_M) - for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): - start_m = tl.multiple_of(start_m, BLOCK_M) - offs_m_curr = start_m + offs_m - # load q, k, v, do on-chip - # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) - if EVEN_M & EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) - else: - q = tl.load( - q_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ) - # recompute p = softmax(qk, dim=-1).T - qk = tl.dot(q, k, trans_b=True) - # Trying to combine the two masks seem to make the result wrong - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) - if IS_CAUSAL: - qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) - if BIAS_TYPE != "none": - tl.debug_barrier() # Race condition otherwise - if BIAS_TYPE == "vector": - if EVEN_N: - bias = tl.load(b_ptrs).to(tl.float32) - else: - bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) - bias = bias[None, :] - elif BIAS_TYPE == "matrix": - if EVEN_M & EVEN_N: - bias = tl.load(b_ptrs).to(tl.float32) - else: - bias = tl.load( - b_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), - other=0.0, - ).to(tl.float32) - qk = qk * softmax_scale + bias - # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. - # Also wrong for headdim=64. - if not (EVEN_M & EVEN_HEADDIM): - tl.debug_barrier() - lse_i = tl.load(LSE + offs_m_curr) - if BIAS_TYPE == "none": - p = tl.exp(qk * softmax_scale - lse_i[:, None]) - else: - p = tl.exp(qk - lse_i[:, None]) - # compute dv - # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call - # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs - # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, - # the output is correct. - if EVEN_M & EVEN_HEADDIM: - do = tl.load(do_ptrs) - else: - # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. - do = tl.load( - do_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ) - # if EVEN_M: - # if EVEN_HEADDIM: - # do = tl.load(do_ptrs) - # else: - # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - # else: - # if EVEN_HEADDIM: - # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) - # else: - # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) - # & (offs_d[None, :] < headdim), other=0.0) - dv += tl.dot(p.to(do.dtype), do, trans_a=True) - # compute dp = dot(v, do) - # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. 
- # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True - # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False - if not (EVEN_M & EVEN_HEADDIM): - tl.debug_barrier() - dp = tl.dot(do, v, trans_b=True) - # There's a race condition for headdim=48 - if not EVEN_HEADDIM: - tl.debug_barrier() - # compute ds = p * (dp - delta[:, None]) - # Putting the subtraction after the dp matmul (instead of before) is slightly faster - Di = tl.load(D + offs_m_curr) - # Converting ds to q.dtype here reduces register pressure and makes it much faster - # for BLOCK_HEADDIM=128 - ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) - # compute dk = dot(ds.T, q) - dk += tl.dot(ds, q, trans_a=True) - # compute dq - if not ( - EVEN_M & EVEN_HEADDIM - ): # Otherewise there's a race condition when BIAS_TYPE='matrix' - tl.debug_barrier() - if not ATOMIC_ADD: - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds, k) - tl.store(dq_ptrs, dq, eviction_policy="evict_last") - else: - if EVEN_HEADDIM: - dq = tl.load( - dq_ptrs, - mask=offs_m_curr[:, None] < seqlen_q, - other=0.0, - eviction_policy="evict_last", - ) - dq += tl.dot(ds, k) - tl.store( - dq_ptrs, - dq, - mask=offs_m_curr[:, None] < seqlen_q, - eviction_policy="evict_last", - ) - else: - dq = tl.load( - dq_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - eviction_policy="evict_last", - ) - dq += tl.dot(ds, k) - tl.store( - dq_ptrs, - dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - eviction_policy="evict_last", - ) - else: # If we're parallelizing across the seqlen_k dimension - dq = tl.dot(ds, k) - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - tl.atomic_add(dq_ptrs, dq) - else: - if EVEN_HEADDIM: - tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) - else: - tl.atomic_add( - dq_ptrs, - dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - ) - # increment pointers - dq_ptrs += BLOCK_M * stride_dqm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_dom - if BIAS_TYPE == "matrix": - b_ptrs += BLOCK_M * stride_bm - # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) - dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) - _bwd_store_dk_dv( - dk_ptrs, - dv_ptrs, - dk, - dv, - offs_n, - offs_d, - seqlen_k, - headdim, - EVEN_M=EVEN_M, - EVEN_N=EVEN_N, - EVEN_HEADDIM=EVEN_HEADDIM, - ) - - -def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, - num_warps=8, - num_stages=1, - pre_hook=init_to_zero("DQ"), - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, - num_warps=8, - num_stages=1, - pre_hook=init_to_zero("DQ"), - ), - # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now - # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), - # 
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), - ], - key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"], -) -@triton.heuristics( - { - "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, - "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, - "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - } -) -@triton.jit -def _bwd_kernel( - Q, - K, - V, - Bias, - DO, - DQ, - DK, - DV, - LSE, - D, - softmax_scale, - stride_qb, - stride_qh, - stride_qm, - stride_kb, - stride_kh, - stride_kn, - stride_vb, - stride_vh, - stride_vn, - stride_bb, - stride_bh, - stride_bm, - stride_dob, - stride_doh, - stride_dom, - stride_dqb, - stride_dqh, - stride_dqm, - stride_dkb, - stride_dkh, - stride_dkn, - stride_dvb, - stride_dvh, - stride_dvn, - nheads, - seqlen_q, - seqlen_k, - seqlen_q_rounded, - headdim, - CACHE_KEY_SEQLEN_Q, - CACHE_KEY_SEQLEN_K, - BIAS_TYPE: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - # offset pointers for batch/head - Q += off_b * stride_qb + off_h * stride_qh - K += off_b * stride_kb + off_h * stride_kh - V += off_b * stride_vb + off_h * stride_vh - DO += off_b * stride_dob + off_h * stride_doh - DQ += off_b * stride_dqb + off_h * stride_dqh - DK += off_b * stride_dkb + off_h * stride_dkh - DV += off_b * stride_dvb + off_h * stride_dvh - if BIAS_TYPE != "none": - Bias += off_b * stride_bb + off_h * stride_bh - # pointer to row-wise quantities in value-like data - D += off_hb * seqlen_q_rounded - LSE += off_hb * seqlen_q_rounded - if not SEQUENCE_PARALLEL: - num_block_n = tl.cdiv(seqlen_k, BLOCK_N) - for start_n in range(0, num_block_n): - _bwd_kernel_one_col_block( - start_n, - Q, - K, - V, - Bias, - DO, - DQ, - DK, - DV, - LSE, - D, - softmax_scale, - stride_qm, - stride_kn, - stride_vn, - stride_bm, - stride_dom, - stride_dqm, - stride_dkn, - stride_dvn, - seqlen_q, - seqlen_k, - headdim, - ATOMIC_ADD=False, - BIAS_TYPE=BIAS_TYPE, - IS_CAUSAL=IS_CAUSAL, - BLOCK_HEADDIM=BLOCK_HEADDIM, - EVEN_M=EVEN_M, - EVEN_N=EVEN_N, - EVEN_HEADDIM=EVEN_HEADDIM, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - else: - start_n = tl.program_id(0) - _bwd_kernel_one_col_block( - start_n, - Q, - K, - V, - Bias, - DO, - DQ, - DK, - DV, - LSE, - D, - softmax_scale, - stride_qm, - stride_kn, - stride_vn, - stride_bm, - stride_dom, - stride_dqm, - stride_dkn, - stride_dvn, - seqlen_q, - seqlen_k, - headdim, - ATOMIC_ADD=True, - BIAS_TYPE=BIAS_TYPE, - IS_CAUSAL=IS_CAUSAL, - BLOCK_HEADDIM=BLOCK_HEADDIM, - EVEN_M=EVEN_M, - EVEN_N=EVEN_N, - EVEN_HEADDIM=EVEN_HEADDIM, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) - - -def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None): - # shape constraints - batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, _, _ = k.shape - assert k.shape == (batch, seqlen_k, nheads, d) - assert v.shape == (batch, seqlen_k, nheads, d) - assert d <= 128, "FlashAttention only support head dimensions up to 128" - assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" - assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" - assert q.is_cuda and k.is_cuda and v.is_cuda - softmax_scale = 
softmax_scale or 1.0 / math.sqrt(d) - - has_bias = bias is not None - bias_type = "none" - if has_bias: - assert bias.dtype in [q.dtype, torch.float] - assert bias.is_cuda - assert bias.dim() == 4 - if bias.stride(-1) != 1: - bias = bias.contiguous() - if bias.shape[2:] == (1, seqlen_k): - bias_type = "vector" - elif bias.shape[2:] == (seqlen_q, seqlen_k): - bias_type = "matrix" - else: - raise RuntimeError( - "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)" - ) - bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) - bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) - - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) - tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) - o = torch.empty_like(q) - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - BLOCK = 128 - num_warps = 4 if d <= 64 else 8 - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) - _fwd_kernel[grid]( - q, - k, - v, - bias, - o, - lse, - tmp, - softmax_scale, - q.stride(0), - q.stride(2), - q.stride(1), - k.stride(0), - k.stride(2), - k.stride(1), - v.stride(0), - v.stride(2), - v.stride(1), - *bias_strides, - o.stride(0), - o.stride(2), - o.stride(1), - nheads, - seqlen_q, - seqlen_k, - seqlen_q_rounded, - d, - seqlen_q // 32, - seqlen_k // 32, # key for triton cache (limit number of compilations) - # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # IS_CAUSAL=causal, BLOCK_HEADDIM=d, - bias_type, - causal, - BLOCK_HEADDIM, - BLOCK_M=BLOCK, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return o, lse, softmax_scale # softmax_scale could have been updated - - -def _flash_attn_backward( - do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None -): - # Make sure that the last dimension is contiguous - if do.stride(-1) != 1: - do = do.contiguous() - batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, _, _ = k.shape - # assert d in {16, 32, 64, 128} - assert d <= 128 - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - assert lse.shape == (batch, nheads, seqlen_q_rounded) - assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 - assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 - softmax_scale = softmax_scale or 1.0 / math.sqrt(d) - # dq_accum = torch.zeros_like(q, dtype=torch.float32) - dq_accum = torch.empty_like(q, dtype=torch.float32) - delta = torch.empty_like(lse) - # delta = torch.zeros_like(lse) - - BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) - _bwd_preprocess_do_o_dot[grid]( - o, - do, - delta, - o.stride(0), - o.stride(2), - o.stride(1), - do.stride(0), - do.stride(2), - do.stride(1), - nheads, - seqlen_q, - seqlen_q_rounded, - d, - BLOCK_M=128, - BLOCK_HEADDIM=BLOCK_HEADDIM, - ) - - has_bias = bias is not None - bias_type = "none" - if has_bias: - assert bias.dtype in [q.dtype, torch.float] - assert bias.is_cuda - assert bias.dim() == 4 - assert bias.stride(-1) == 1 - if bias.shape[2:] == (1, seqlen_k): - bias_type = "vector" - elif bias.shape[2:] == (seqlen_q, seqlen_k): - bias_type = "matrix" - else: - raise RuntimeError( - "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)" - ) - bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) - bias_strides = (bias.stride(0), 
bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) - - # BLOCK_M = 128 - # BLOCK_N = 64 - # num_warps = 4 - grid = lambda META: ( - triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, - batch * nheads, - ) - _bwd_kernel[grid]( - q, - k, - v, - bias, - do, - dq_accum, - dk, - dv, - lse, - delta, - softmax_scale, - q.stride(0), - q.stride(2), - q.stride(1), - k.stride(0), - k.stride(2), - k.stride(1), - v.stride(0), - v.stride(2), - v.stride(1), - *bias_strides, - do.stride(0), - do.stride(2), - do.stride(1), - dq_accum.stride(0), - dq_accum.stride(2), - dq_accum.stride(1), - dk.stride(0), - dk.stride(2), - dk.stride(1), - dv.stride(0), - dv.stride(2), - dv.stride(1), - nheads, - seqlen_q, - seqlen_k, - seqlen_q_rounded, - d, - seqlen_q // 32, - seqlen_k // 32, # key for triton cache (limit number of compilations) - # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # IS_CAUSAL=causal, BLOCK_HEADDIM=d, - bias_type, - causal, - BLOCK_HEADDIM, - # SEQUENCE_PARALLEL=False, - # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - # num_warps=num_warps, - # num_stages=1, - ) - dq.copy_(dq_accum) - - -class FlashAttnQKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None): - """ - qkv: (batch, seqlen, 3, nheads, headdim) - bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen). - For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen). - ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen) - """ - # Make sure that the last dimension is contiguous - if qkv.stride(-1) != 1: - qkv = qkv.contiguous() - o, lse, ctx.softmax_scale = _flash_attn_forward( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - bias=bias, - causal=causal, - softmax_scale=softmax_scale, - ) - ctx.save_for_backward(qkv, o, lse, bias) - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - qkv, o, lse, bias = ctx.saved_tensors - assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet" - # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd - # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. - with torch.inference_mode(): - dqkv = torch.empty_like(qkv) - _flash_attn_backward( - do, - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - o, - lse, - dqkv[:, :, 0], - dqkv[:, :, 1], - dqkv[:, :, 2], - bias=bias, - causal=ctx.causal, - softmax_scale=ctx.softmax_scale, - ) - return dqkv, None, None, None - - -flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply - - -class FlashAttnKVPackedFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None): - """ - q: (batch, seqlen_q, nheads, headdim) - kv: (batch, seqlen_k, 2, nheads, headdim) - bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). - For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). 
- ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) - """ - # Make sure that the last dimension is contiguous - q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]] - o, lse, ctx.softmax_scale = _flash_attn_forward( - q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale - ) - ctx.save_for_backward(q, kv, o, lse, bias) - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - q, kv, o, lse, bias = ctx.saved_tensors - if len(ctx.needs_input_grad) >= 3: - assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet" - # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd - # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. - with torch.inference_mode(): - dq = torch.empty_like(q) - dkv = torch.empty_like(kv) - _flash_attn_backward( - do, - q, - kv[:, :, 0], - kv[:, :, 1], - o, - lse, - dq, - dkv[:, :, 0], - dkv[:, :, 1], - bias=bias, - causal=ctx.causal, - softmax_scale=ctx.softmax_scale, - ) - return dq, dkv, None, None, None - - -flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply - - -class FlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): - """ - q: (batch_size, seqlen_q, nheads, headdim) - k, v: (batch_size, seqlen_k, nheads, headdim) - bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). - For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). - ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) - """ - # Make sure that the last dimension is contiguous - q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] - o, lse, ctx.softmax_scale = _flash_attn_forward( - q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale - ) - ctx.save_for_backward(q, k, v, o, lse, bias) - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - q, k, v, o, lse, bias = ctx.saved_tensors - assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet" - # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd - # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. - with torch.inference_mode(): - dq = torch.empty_like(q) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - _flash_attn_backward( - do, - q, - k, - v, - o, - lse, - dq, - dk, - dv, - bias=bias, - causal=ctx.causal, - softmax_scale=ctx.softmax_scale, - ) - return dq, dk, dv, None, None, None - - -flash_attn_func = FlashAttnFunc.apply diff --git a/vllm/thirdparty_files/flash_attn/flash_attn_triton_og.py b/vllm/thirdparty_files/flash_attn/flash_attn_triton_og.py deleted file mode 100644 index f2ddb99487b4..000000000000 --- a/vllm/thirdparty_files/flash_attn/flash_attn_triton_og.py +++ /dev/null @@ -1,365 +0,0 @@ -# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py -# for benchmarking. 
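For context on what this deletion removes from callers: the functional wrappers defined above (flash_attn_func and friends) were invoked positionally, since torch.autograd.Function.apply takes no keyword arguments. A hypothetical usage sketch against the pre-patch module path, not a test taken from the repository:

import torch
from vllm.thirdparty_files.flash_attn.flash_attn_triton import flash_attn_func  # module removed by this patch

batch, seqlen, nheads, d = 2, 512, 8, 64
q, k, v = (
    torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)
# Positional args: (q, k, v, bias, causal, softmax_scale); bias may be a broadcastable
# (1, nheads, 1, seqlen) ALiBi-style tensor as described in the docstrings above.
out = flash_attn_func(q, k, v, None, True, None)
out.sum().backward()   # the Triton backward populates q.grad, k.grad, v.grad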
-# We fixed a few dtype cast to make it work for bf16 - -""" -Fused Attention -=============== -This is a Triton implementation of the Flash Attention algorithm -(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) -""" - -import pytest -import torch -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - TMP, - L, - M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - Z, - H, - N_CTX, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk - off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk - # Initialize pointers to Q, K, V - q_ptrs = Q + off_q - k_ptrs = K + off_k - v_ptrs = V + off_v - # initialize pointer to m and l - t_ptrs = TMP + off_hz * N_CTX + offs_m - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # load q: it will stay in SRAM throughout - q = tl.load(q_ptrs) - # loop over k, v and update accumulator - for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + start_n * stride_kn) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, trans_b=True) - qk *= sm_scale - qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + start_n * stride_vk) - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # rematerialize offsets to save registers - start_m = tl.program_id(0) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # write back l and m - l_ptrs = L + off_hz * N_CTX + offs_m - m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(l_ptrs, l_i) - tl.store(m_ptrs, m_i) - # initialize pointers to output - offs_n = tl.arange(0, BLOCK_DMODEL) - off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - - -@triton.jit -def _bwd_preprocess( - Out, - DO, - L, - NewDO, - Delta, - BLOCK_M: tl.constexpr, - D_HEAD: tl.constexpr, -): - off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) - off_n = tl.arange(0, D_HEAD) - # load - o = tl.load(Out + 
off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - denom = tl.load(L + off_m).to(tl.float32) - # compute - do = do / denom[:, None] - delta = tl.sum(o * do, axis=1) - # write-back - tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) - tl.store(Delta + off_m, delta) - - -@triton.jit -def _bwd_kernel( - Q, - K, - V, - sm_scale, - Out, - DO, - DQ, - DK, - DV, - L, - M, - D, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - Z, - H, - N_CTX, - num_block, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - off_hz = tl.program_id(0) - off_z = off_hz // H - off_h = off_hz % H - # offset pointers for batch/head - Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_qz + off_h * stride_qh - V += off_z * stride_qz + off_h * stride_qh - DO += off_z * stride_qz + off_h * stride_qh - DQ += off_z * stride_qz + off_h * stride_qh - DK += off_z * stride_qz + off_h * stride_qh - DV += off_z * stride_qz + off_h * stride_qh - for start_n in range(0, num_block): - lo = start_n * BLOCK_M - # initialize row/col offsets - offs_qm = lo + tl.arange(0, BLOCK_M) - offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) - offs_m = tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_DMODEL) - # initialize pointers to value-like data - q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - # pointer to row-wise quantities in value-like data - D_ptrs = D + off_hz * N_CTX - m_ptrs = M + off_hz * N_CTX - # initialize dv amd dk - dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # k and v stay in SRAM throughout - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - # loop over rows - for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): - offs_m_curr = start_m + offs_m - # load q, k, v, do on-chip - q = tl.load(q_ptrs) - # recompute p = softmax(qk, dim=-1).T - # NOTE: `do` is pre-divided by `l`; no normalization here - qk = tl.dot(q, k, trans_b=True) - qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) - m = tl.load(m_ptrs + offs_m_curr) - p = tl.exp(qk * sm_scale - m[:, None]) - # compute dv - do = tl.load(do_ptrs) - dv += tl.dot(p.to(do.dtype), do, trans_a=True) - # compute dp = dot(v, do) - Di = tl.load(D_ptrs + offs_m_curr) - dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] - dp += tl.dot(do, v, trans_b=True) - # compute ds = p * (dp - delta[:, None]) - ds = p * dp * sm_scale - # compute dk = dot(ds.T, q) - dk += tl.dot(ds.to(q.dtype), q, trans_a=True) - # # compute dq - dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds.to(k.dtype), k) - tl.store(dq_ptrs, dq, eviction_policy="evict_last") - # # increment pointers - dq_ptrs += BLOCK_M * stride_qm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_qm - # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) - - -class 
_attention(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, sm_scale): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - o = torch.empty_like(q) - grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) - tmp = torch.empty( - (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 - ) - L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - num_warps = 4 if Lk <= 64 else 8 - - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - tmp, - L, - m, - o, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), - q.shape[0], - q.shape[1], - q.shape[2], - BLOCK_M=BLOCK, - BLOCK_N=BLOCK, - BLOCK_DMODEL=Lk, - num_warps=num_warps, - num_stages=1, - ) - ctx.save_for_backward(q, k, v, o, L, m) - ctx.BLOCK = BLOCK - ctx.grid = grid - ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = Lk - return o - - @staticmethod - def backward(ctx, do): - q, k, v, o, l, m = ctx.saved_tensors - do = do.contiguous() - dq = torch.zeros_like(q, dtype=torch.float32) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - do_scaled = torch.empty_like(do) - delta = torch.empty_like(l) - _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)]( - o, - do, - l, - do_scaled, - delta, - BLOCK_M=ctx.BLOCK, - D_HEAD=ctx.BLOCK_DMODEL, - ) - - # NOTE: kernel currently buggy for other values of `num_warps` - num_warps = 8 - _bwd_kernel[(ctx.grid[1],)]( - q, - k, - v, - ctx.sm_scale, - o, - do_scaled, - dq, - dk, - dv, - l, - m, - delta, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - q.shape[0], - q.shape[1], - q.shape[2], - ctx.grid[0], - BLOCK_M=ctx.BLOCK, - BLOCK_N=ctx.BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, - num_warps=num_warps, - num_stages=1, - ) - return dq.to(q.dtype), dk, dv, None - - -attention = _attention.apply diff --git a/vllm/thirdparty_files/flash_attn/flash_blocksparse_attention.py b/vllm/thirdparty_files/flash_attn/flash_blocksparse_attention.py deleted file mode 100644 index 03798d16ffbb..000000000000 --- a/vllm/thirdparty_files/flash_attn/flash_blocksparse_attention.py +++ /dev/null @@ -1,197 +0,0 @@ -import math - -import hydra -import torch -import torch.nn as nn -from einops import rearrange - -from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input -from flash_attn.flash_blocksparse_attn_interface import ( - convert_blockmask, - flash_blocksparse_attn_func, -) - - -class FlashBlocksparseAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_temp: The temperature to use for the softmax attention. 
- (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.1) - """ - - def __init__( - self, - sparsity_config, - softmax_temp=None, - attention_dropout=0.0, - max_seq_length=2048, - device=None, - dtype=None, - ): - super().__init__() - self.sparsity_config = hydra.utils.instantiate(sparsity_config) - self.softmax_temp = softmax_temp - self.dropout_p = attention_dropout - - # initialize sparse layout and register as buffer - max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256 - layout = self.sparsity_config.make_layout(max_seq_length) - self.register_buffer("layout", layout) - blockmask_converted = convert_blockmask(self.layout, causal=False) - self.register_buffer("blockmask_converted", blockmask_converted) - # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}') - - def forward( - self, - qkv, - attn_mask=None, - key_padding_mask=None, - causal=False, - cu_seqlens=None, - max_s=None, - need_weights=False, - convert_mask=True, - ): - """Implements the multihead softmax attention. - Arguments - --------- - qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None - attn_mask: An implementation of BaseMask that encodes where each - query can attend to - key_padding_mask: An implementation of BaseMask that encodes how - many query each sequence in the batch consists of - """ - assert not need_weights - assert attn_mask is None - assert qkv.dtype == torch.float16 - assert qkv.is_cuda - - if cu_seqlens is None: - batch_size = qkv.shape[0] - seqlen = qkv.shape[1] - # Convert mask to take a subset - seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256 - assert seqlen_rounded // 16 <= self.layout.shape[0], ( - seqlen_rounded // 256 <= self.layout.shape[1] - ) - blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256] - if key_padding_mask is None: - qkv = rearrange(qkv, "b s ... -> (b s) ...") - max_s = seqlen - cu_seqlens = torch.arange( - 0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device - ) - output = flash_blocksparse_attn_func( - qkv, - cu_seqlens, - blockmask, - self.dropout_p if self.training else 0.0, - max_s, - softmax_scale=self.softmax_temp, - causal=causal, - ) - output = rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) - else: - key_padding_mask_bool = key_padding_mask.bool_matrix - nheads = qkv.shape[-2] - x = rearrange(qkv, "b s three h d -> b s (three h d)") - x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool) - x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) - output_unpad = flash_blocksparse_attn_func( - x_unpad, - cu_seqlens, - blockmask, - self.dropout_p if self.training else 0.0, - max_s, - softmax_scale=self.softmax_temp, - causal=causal, - ) - output = rearrange( - pad_input( - rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen - ), - "b s (h d) -> b s h d", - h=nheads, - ) - else: - assert max_s is not None - seqlen = max_s - # Convert mask to take a subset - seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256 - assert seqlen_rounded // 16 <= self.layout.shape[0], ( - seqlen_rounded // 256 <= self.layout.shape[1] - ) - blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256] - if convert_mask: - output = flash_blocksparse_attn_func( - qkv, - cu_seqlens, - blockmask, - self.dropout_p if self.training else 0.0, - max_s, - softmax_scale=self.softmax_temp, - causal=causal, - ) - else: - output = flash_blocksparse_attn_func( - qkv, - cu_seqlens, - self.blockmask_converted, - self.dropout_p if self.training else 0.0, - max_s, - softmax_scale=self.softmax_temp, - causal=causal, - convert_mask=False, - ) - - return output, None - - -class FlashBlocksparseMHA(nn.Module): - def __init__( - self, - embed_dim, - num_heads, - sparsity_config, - bias=True, - batch_first=True, - attention_dropout=0.0, - causal=False, - max_seq_length=2048, - device=None, - dtype=None, - **kwargs, - ) -> None: - assert batch_first - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.embed_dim = embed_dim - self.causal = causal - - self.num_heads = num_heads - assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" - self.head_dim = self.embed_dim // num_heads - assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64" - - self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) - self.inner_attn = FlashBlocksparseAttention( - sparsity_config, - attention_dropout=attention_dropout, - max_seq_length=max_seq_length, - **factory_kwargs, - ) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) - - def forward( - self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False - ): - qkv = self.Wqkv(x) - qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads) - context, attn_weights = self.inner_attn( - qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal - ) - return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights diff --git a/vllm/thirdparty_files/flash_attn/flash_blocksparse_attn_interface.py b/vllm/thirdparty_files/flash_attn/flash_blocksparse_attn_interface.py deleted file mode 100644 index 9ce3fe8c1344..000000000000 --- a/vllm/thirdparty_files/flash_attn/flash_blocksparse_attn_interface.py +++ /dev/null @@ -1,200 +0,0 @@ -# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py -import flash_attn_cuda -import torch -import torch.nn as nn - - -def convert_blockmask(blockmask, causal): - """Convert from the 0-1 format to the format used by the CUDA code. 
- 0 means the block is skipped. - nonzero means the block is not skipped. - Argument: - blockmask: (row, col): a 0-1 tensor - Return: - blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row - indices of the nonzero blocks, padded with -1 to reach length @row. - The indices are multiplied by 4, with the smallest bit used to encode whether - it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is - the last nonzero in its row.. - """ - assert not causal - # TD [2022-05-13]: The indexing and sorting is very tricky - nrow, ncol = blockmask.shape - # Sort does not support bool on CUDA - blockmask = blockmask.to(dtype=torch.uint8) - nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True) - nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0) - last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1] - last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[ - torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row - ] - first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0] - first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[ - torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row - ] - nonzero_idx = nonzero_sorted_rowidx * 4 - nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2 - nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1 - nonzero_idx[nonzero_val == 0] = -1 - return nonzero_idx.T.contiguous().to(dtype=torch.int32) - - -def _flash_blocksparse_attn_forward( - qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax -): - context, softmax_lse, *rest = flash_attn_cuda.fwd_block( - qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None - ) - # if context.isnan().any() or softmax_lse.isnan().any(): - # breakpoint() - S_dmask = rest[0] if return_softmax else None - return context, softmax_lse, S_dmask - - -def _flash_blocksparse_attn_backward( - dout, - qkv, - out, - S_dmask, - softmax_lse, - cu_seqlens, - blockmask, - dropout_p, - max_s, - softmax_scale, - causal, -): - dqkv, dp, softmax_d = flash_attn_cuda.bwd_block( - dout, - qkv, - out, - S_dmask, - softmax_lse, - cu_seqlens, - blockmask, - dropout_p, - softmax_scale, - max_s, - causal, - None, - ) - # if dqkv.isnan().any() or softmax_d.isnan().any(): - # breakpoint() - return dqkv - - -class FlashBlocksparseAttnFun(torch.autograd.Function): - @staticmethod - def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal): - # Save rng_state because the backward pass will regenerate the dropout mask - rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward( - qkv, - cu_seqlens, - blockmask, - dropout_p, - max_s, - softmax_scale, - causal=causal, - return_softmax=False, - ) - ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state) - ctx.dropout_p = dropout_p - ctx.max_s = max_s - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return context - - @staticmethod - def backward(ctx, dout): - qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors - if rng_state is not None: - cur_rng_state = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(rng_state) - # S_dmask 
is None, temporarily use another tensor just to get it running - dqkv = _flash_blocksparse_attn_backward( - dout, - qkv, - context, - context, - softmax_lse, - cu_seqlens, - blockmask, - ctx.dropout_p, - ctx.max_s, - ctx.softmax_scale, - ctx.causal, - ) - if rng_state is not None: - torch.cuda.set_rng_state(cur_rng_state) - return dqkv, None, None, None, None, None, None, None - - -# We duplicate code to return both the output and the softmax for testing -# Returning both makes backward a bit slower, so we want to keep using the other version for speed. -class FlashBlocksparseAttnFunWithS(torch.autograd.Function): - @staticmethod - def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal): - # Save rng_state because the backward pass is gonna regenerate the dropout mask - rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward( - qkv, - cu_seqlens, - blockmask, - dropout_p, - max_s, - softmax_scale, - causal=causal, - return_softmax=True, - ) - ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state) - ctx.dropout_p = dropout_p - ctx.max_s = max_s - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return context, S_dmask, softmax_lse - - @staticmethod - def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored): - qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors - if rng_state is not None: - cur_rng_state = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(rng_state) - dqkv = _flash_blocksparse_attn_backward( - dout, - qkv, - context, - S_dmask, - softmax_lse, - cu_seqlens, - blockmask, - ctx.dropout_p, - ctx.max_s, - ctx.softmax_scale, - ctx.causal, - ) - if rng_state is not None: - torch.cuda.set_rng_state(cur_rng_state) - return dqkv, None, None, None, None, None, None - - -def flash_blocksparse_attn_func( - qkv, - cu_seqlens, - blockmask, - dropout_p, - max_s, - softmax_scale=None, - causal=False, - return_attn_probs=False, - convert_mask=True, -): - """dropout_p should be set to 0.0 during evaluation""" - func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS - if convert_mask: - blockmask = convert_blockmask(blockmask, causal=causal) - return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal) diff --git a/vllm/thirdparty_files/flash_attn/fused_softmax.py b/vllm/thirdparty_files/flash_attn/fused_softmax.py deleted file mode 100644 index 382f94f092cd..000000000000 --- a/vllm/thirdparty_files/flash_attn/fused_softmax.py +++ /dev/null @@ -1,201 +0,0 @@ -# [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py -# for benchmarking. -# We added support for seqlen=2k and seqlen=4k - -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -import torch -from apex._autocast_utils import _cast_if_autocast_enabled -from apex.transformer.enums import AttnMaskType -from fused_softmax_lib import ( - scaled_masked_softmax_backward, - scaled_masked_softmax_forward, - scaled_masked_softmax_get_batch_per_block, - scaled_upper_triang_masked_softmax_backward, - scaled_upper_triang_masked_softmax_forward, -) - - -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - scale_t = torch.tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_upper_triang_masked_softmax_backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None - - -def scaled_upper_triang_masked_softmax(inputs, _, scale): - b, np, sq, sk = inputs.size() - assert sq == sk, "causal mask is only for self attention" - # Reshaping input to 3D tensor (attn_batches, sq, sk) - inputs = inputs.view(-1, sq, sk) - args = _cast_if_autocast_enabled(inputs, scale) - with torch.cuda.amp.autocast(enabled=False): - probs = ScaledUpperTriangMaskedSoftmax.apply(*args) - return probs.view(b, np, sq, sk) - - -# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`. -# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context. -# So I needed to manually write two `torch.autograd.Function` inheritances. -# Fused operation which performs following three operations in sequence -# 1. Scale the tensor. -# 2. Apply the mask. -# 3. Perform softmax. -class ScaledMaskedSoftmax(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs, mask, scale): - scale_t = torch.tensor([scale]) - softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -def scaled_masked_softmax(inputs, mask, scale): - # input is 4D tensor (b, np, sq, sk) - args = _cast_if_autocast_enabled(inputs, mask, scale) - with torch.cuda.amp.autocast(enabled=False): - return ScaledMaskedSoftmax.apply(*args) - - -class FusedScaleMaskSoftmax(torch.nn.Module): - """ - fused operation: scaling + mask + softmax - - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - input_in_bf16: flag to indicate if input in bf16 data format. - attn_mask_type: attention mask type (pad or causal) - scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. 
- """ - - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ): - super().__init__() - self.input_in_fp16 = input_in_fp16 - self.input_in_bf16 = input_in_bf16 - if self.input_in_fp16 and self.input_in_bf16: - raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.") - self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - - if not (self.scale is None or softmax_in_fp32): - raise RuntimeError("softmax should be in fp32 when scaled") - - if self.scaled_masked_softmax_fusion: - if self.attn_mask_type == AttnMaskType.causal: - self.fused_softmax_func = scaled_upper_triang_masked_softmax - elif self.attn_mask_type == AttnMaskType.padding: - self.fused_softmax_func = scaled_masked_softmax - else: - raise ValueError("Invalid attn_mask_type.") - - def forward(self, input, mask): - # [b, np, sq, sk] - assert input.dim() == 4 - - if self.is_kernel_available(mask, *input.size()): - return self.forward_fused_softmax(input, mask) - else: - return self.forward_torch_softmax(input, mask) - - def is_kernel_available(self, mask, b, np, sq, sk): - attn_batches = b * np - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and ( - self.attn_mask_type == AttnMaskType.causal - or (self.attn_mask_type == AttnMaskType.padding and mask is not None) - ) - and 16 < sk <= 8192 # sk must be 16 ~ 8192 - and sq % 4 == 0 # sq must be divisor of 4 - and sk % 4 == 0 # sk must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): - if 0 <= sk <= 8192: - batch_per_block = self.get_batch_per_block(sq, sk, b, np) - - if self.attn_mask_type == AttnMaskType.causal: - if attn_batches % batch_per_block == 0: - return True - else: - if sq % batch_per_block == 0: - return True - return False - - def forward_fused_softmax(self, input, mask): - # input.shape = [b, np, sq, sk] - scale = self.scale if self.scale is not None else 1.0 - return self.fused_softmax_func(input, mask, scale) - - def forward_torch_softmax(self, input, mask): - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() - - return probs - - @staticmethod - def get_batch_per_block(sq, sk, b, np): - return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np) diff --git a/vllm/thirdparty_files/flash_attn/layers/__init__.py b/vllm/thirdparty_files/flash_attn/layers/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/thirdparty_files/flash_attn/layers/patch_embed.py b/vllm/thirdparty_files/flash_attn/layers/patch_embed.py deleted file mode 100644 index 05562f8e8bcd..000000000000 --- a/vllm/thirdparty_files/flash_attn/layers/patch_embed.py +++ /dev/null @@ -1,67 +0,0 @@ -# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py -# But we use nn.Linear instead of Conv2d and it's about 8x faster. 
- -from functools import partial - -import torch.nn as nn -from einops import rearrange -from torch import _assert -from torch.nn.modules.utils import _pair - -try: - from flash_attn.ops.fused_dense import FusedDense -except ImportError: - FusedDense = None - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding""" - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - norm_layer=None, - flatten=True, - bias=True, - fused_bias_fc=False, - ): - super().__init__() - img_size = _pair(img_size) - patch_size = _pair(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten - if fused_bias_fc and FusedDense is None: - raise ImportError("fused_dense is not installed") - - linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense - self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - def forward(self, x): - _, _, H, W = x.shape - _assert( - H == self.img_size[0], - f"Input image height ({H}) doesn't match model ({self.img_size[0]}).", - ) - _assert( - W == self.img_size[1], - f"Input image width ({W}) doesn't match model ({self.img_size[1]}).", - ) - x = self.proj( - rearrange( - x, - "b c (h p1) (w p2) -> b h w (c p1 p2)", - p1=self.patch_size[0], - p2=self.patch_size[1], - ) - ) - if self.flatten: - x = rearrange(x, "b h w c -> b (h w) c") - x = self.norm(x) - return x diff --git a/vllm/thirdparty_files/flash_attn/layers/rotary.py b/vllm/thirdparty_files/flash_attn/layers/rotary.py deleted file mode 100644 index 215f518ea7ca..000000000000 --- a/vllm/thirdparty_files/flash_attn/layers/rotary.py +++ /dev/null @@ -1,481 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -import math -from typing import Optional, Tuple, Union - -import torch -from einops import rearrange, repeat -from flash_attn.ops.triton.rotary import apply_rotary - - -def rotate_half(x, interleaved=False): - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) - - -def apply_rotary_emb_torch(x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") - sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") - return torch.cat( - [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], - dim=-1, - ) - - -class ApplyRotaryEmb(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - cos, - sin, - interleaved=False, - inplace=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - ): - out = apply_rotary( - x, - cos, - sin, - seqlen_offsets=seqlen_offsets, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - interleaved=interleaved, - inplace=inplace, - ) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - ctx.inplace = inplace - ctx.max_seqlen = max_seqlen - return out if not inplace else x - - @staticmethod - def backward(ctx, do): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors - else: - cos, sin, cu_seqlens = ctx.saved_tensors - # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with - # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. - if not ctx.interleaved and not ctx.inplace: - do = do.clone() - dx = apply_rotary( - do, - cos, - sin, - seqlen_offsets=seqlen_offsets, - cu_seqlens=cu_seqlens, - max_seqlen=ctx.max_seqlen, - interleaved=ctx.interleaved, - inplace=ctx.inplace, - conjugate=True, - ) - return dx, None, None, None, None, None, None, None - - -def apply_rotary_emb( - x, - cos, - sin, - interleaved=False, - inplace=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, -): - """ - Arguments: - x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim) - cos, sin: (seqlen_rotary, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - inplace: if True, apply rotary embedding in-place. - seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. - Most commonly used in inference when we have KV cache. - cu_seqlens: (batch + 1,) or None - max_seqlen: int - Return: - out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim) - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. 
- """ - return ApplyRotaryEmb.apply( - x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen - ) - - -# For backward compatibility -apply_rotary_emb_func = apply_rotary_emb - - -class ApplyRotaryEmbQKV_(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cos, - sin, - cos_k=None, - sin_k=None, - interleaved=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - ): - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - if cos_k is None and sin_k is None and qkv.is_contiguous(): - # Call 1 kernel instead of 2 kernels - # We need qkv to be contiguous so that when we reshape to combine (3, nheads) - # dimensions, we get the same tensor - # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") - qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) - apply_rotary( - qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True - ) - else: - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - q, k = qkv[:, :, 0], qkv[:, :, 1] - apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True) - apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True) - ctx.save_for_backward(cos, sin, cos_k, sin_k) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin, cos_k, sin_k) - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - return qkv - - @staticmethod - def backward(ctx, dqkv): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors - else: - cos, sin, cos_k, sin_k = ctx.saved_tensors - if cos_k is None and sin_k is None and dqkv.is_contiguous(): - # Call 1 kernel instead of 2 kernels - # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) - # dimensions, we get the same tensor - dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d") - apply_rotary( - dqk, - cos, - sin, - seqlen_offsets=seqlen_offsets, - interleaved=ctx.interleaved, - inplace=True, - conjugate=True, - ) - else: - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] - apply_rotary( - dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True - ) - apply_rotary( - dk, - cos_k, - sin_k, - seqlen_offsets, - interleaved=ctx.interleaved, - inplace=True, - conjugate=True, - ) - return dqkv, None, None, None, None, None, None - - -def apply_rotary_emb_qkv_( - qkv, - cos, - sin, - cos_k=None, - sin_k=None, - interleaved=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, -): - """ - Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - cos_k, sin_k: (seqlen, rotary_dim / 2), optional - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. - Most commonly used in inference when we have KV cache. - Return: - qkv: (batch_size, seqlen, 3, nheads, headdim) - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of Q and K. 
- """ - return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets) - - -class ApplyRotaryEmbKV_(torch.autograd.Function): - @staticmethod - def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): - batch, seqlen, two, nheads, headdim = kv.shape - assert two == 2 - k = kv[:, :, 0] - apply_rotary( - k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True - ) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - return kv - - @staticmethod - def backward(ctx, dkv): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, seqlen_offsets = ctx.saved_tensors - else: - cos, sin = ctx.saved_tensors - apply_rotary( - dkv[:, :, 0], - cos, - sin, - seqlen_offsets=seqlen_offsets, - interleaved=ctx.interleaved, - inplace=True, - conjugate=True, - ) - return dkv, None, None, None, None - - -apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply - - -def apply_rotary_emb_kv_( - kv, - cos, - sin, - interleaved=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, -): - """ - Arguments: - kv: (batch_size, seqlen, 2, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. - Most commonly used in inference when we have KV cache. - Return: - kv: (batch_size, seqlen, 2, nheads, headdim) - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of K. - """ - return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets) - - -class RotaryEmbedding(torch.nn.Module): - """ - The rotary position embeddings from RoFormer_ (Su et. al). - A crucial insight from the method is that the query and keys are - transformed by rotation matrices which depend on the relative positions. - - Other implementations are available in the Rotary Transformer repo_ and in - GPT-NeoX_, GPT-NeoX was an inspiration - - .. _RoFormer: https://arxiv.org/abs/2104.09864 - .. _repo: https://github.com/ZhuiyiTechnology/roformer - .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox - - If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). - A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 - Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py - """ - - def __init__( - self, - dim: int, - base=10000.0, - interleaved=False, - scale_base=None, - pos_idx_in_fp32=True, - device=None, - ): - """ - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, - otherwise they might be in lower precision. - This option was added because previously (before 2023-07-02), when we construct - the position indices, we use the dtype of self.inv_freq. 
In most cases this would - be fp32, but if the model is trained in pure bf16 (not mixed precision), then - self.inv_freq would be bf16, and the position indices are also in bf16. - Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the - embeddings for some positions will coincide. - To maintain compatibility with models previously trained in pure bf16, - we add this option. - """ - super().__init__() - self.dim = dim - self.base = float(base) - self.pos_idx_in_fp32 = pos_idx_in_fp32 - # Generate and save the inverse frequency buffer (non trainable) - inv_freq = self._compute_inv_freq(device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.interleaved = interleaved - self.scale_base = scale_base - scale = ( - (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) - if scale_base is not None - else None - ) - self.register_buffer("scale", scale, persistent=False) - - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - - def _compute_inv_freq(self, device=None): - return 1.0 / ( - self.base - ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) - ) - - def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): - # Reset the tables if the sequence length has changed, - # if we're on a new device (possibly due to tracing for instance), - # or if we're switching from inference mode to training - if ( - seqlen > self._seq_len_cached - or self._cos_cached is None - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - or (self.training and self._cos_cached.is_inference()) - ): - self._seq_len_cached = seqlen - # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 - # And the output of arange can be quite large, so bf16 would lose a lot of precision. - # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. - if self.pos_idx_in_fp32: - t = torch.arange(seqlen, device=device, dtype=torch.float32) - # We want fp32 here as well since inv_freq will be multiplied with t, and the output - # will be large. Having it in bf16 will lose a lot of precision and cause the - # cos & sin output to change significantly. 
- # We want to recompute self.inv_freq if it was not loaded in fp32 - if self.inv_freq.dtype != torch.float32: - inv_freq = self._compute_inv_freq(device=device) - else: - inv_freq = self.inv_freq - else: - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - inv_freq = self.inv_freq - # Don't do einsum, it converts fp32 to fp16 under AMP - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer(t, inv_freq) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - else: - power = ( - torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - - seqlen // 2 - ) / self.scale_base - scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") - # We want the multiplication by scale to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) - - def forward( - self, - qkv: torch.Tensor, - kv: Optional[torch.Tensor] = None, - seqlen_offset: Union[int, torch.Tensor] = 0, - max_seqlen: Optional[int] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, - else it's just q of shape (batch, seqlen, nheads, headdim) - kv: (batch, seqlen, 2, nheads, headdim) - seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. - Most commonly used in inference when we have KV cache. - If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one - should pass in max_seqlen, which will update the cos / sin cache up to that length. - Apply rotary embedding *inplace* to qkv and / or kv. - """ - seqlen = qkv.shape[1] - if max_seqlen is not None: - self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) - elif isinstance(seqlen_offset, int): - self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) - if kv is None: - if self.scale is None: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - else: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached, - self._sin_cached, - self._cos_k_cached, - self._sin_k_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - else: - q = qkv - q = apply_rotary_emb_func( - q, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - inplace=True, - seqlen_offsets=seqlen_offset, - ) - if self.scale is None: - kv = apply_rotary_emb_kv_( - kv, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - else: - kv = apply_rotary_emb_kv_( - kv, - self._cos_k_cached, - self._sin_k_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - return q, kv diff --git a/vllm/thirdparty_files/flash_attn/losses/__init__.py b/vllm/thirdparty_files/flash_attn/losses/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/thirdparty_files/flash_attn/losses/cross_entropy.py b/vllm/thirdparty_files/flash_attn/losses/cross_entropy.py deleted file mode 100644 index 2a1b77a3495a..000000000000 --- a/vllm/thirdparty_files/flash_attn/losses/cross_entropy.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) 2023, Tri Dao. 
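# --- Illustrative sketch (editorial; not part of the deleted files): a numeric check of
# the property that the RotaryEmbedding docstring in the rotary.py removed above relies
# on -- rotating q and k by their absolute positions makes their dot product depend only
# on the relative offset. This uses the same non-interleaved (GPT-NeoX style) pairing as
# rotate_half / apply_rotary_emb_torch above; the helper below is a standalone reference,
# not the flash_attn API.
import torch

def rope(x, pos, base=10000.0):
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    ang = pos * inv_freq
    cos, sin = torch.cos(ang), torch.sin(ang)
    x1, x2 = x[..., : d // 2], x[..., d // 2:]
    # 2D rotation of each (x1_i, x2_i) pair by angle pos * theta_i.
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

torch.manual_seed(0)
q, k = torch.randn(64), torch.randn(64)
# Same relative offset (3) at two different absolute positions -> same score.
s1 = torch.dot(rope(q, 5.0), rope(k, 2.0))
s2 = torch.dot(rope(q, 105.0), rope(k, 102.0))
assert torch.allclose(s1, s2, atol=1e-4)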
- -import torch -import torch.nn as nn - -from flash_attn.ops.triton.cross_entropy import cross_entropy_loss - - -class CrossEntropyLoss(nn.Module): - def __init__( - self, - ignore_index=-100, - reduction="mean", - label_smoothing=0.0, - logit_scale=1.0, - lse_square_scale=0.0, - inplace_backward=False, - process_group=None, - return_z_loss=False, - ): - """ - Arguments: - ignored_index: int. If labels == ignored_index, the loss is set to 0.0. - label_smoothing: float - lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. - This is also referred to as "z-loss". - inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. - This saves memory. - process_group: if not None, we're doing Tensor Parallel: each process is responsible for - one part of the vocab. The loss will be aggregated across processes. - return_z_loss: bool. If True, we return the component of the loss contributed by - the lse_square_scale value. This value is only for logging and does not support - backprop. - """ - super().__init__() - if reduction not in ["mean", "none", "sum"]: - raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'") - self.ignore_index = ignore_index - self.reduction = reduction - self.label_smoothing = label_smoothing - self.logit_scale = logit_scale - self.lse_square_scale = lse_square_scale - self.inplace_backward = inplace_backward - self.process_group = process_group - self.return_z_loss = return_z_loss - - def forward(self, input, target): - """ - Arguments: - input: (batch, vocab_size) - target: (batch,) - Returns: - losses: (batch,) if reduction is 'none', else (1,), dtype float - z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss) - """ - assert input.is_cuda and target.is_cuda, "Only support CUDA tensors" - loss, z_loss = cross_entropy_loss( - input, - target, - label_smoothing=self.label_smoothing, - logit_scale=self.logit_scale, - lse_square_scale=self.lse_square_scale, - ignored_index=self.ignore_index, - inplace_backward=self.inplace_backward, - process_group=self.process_group, - ) - if self.reduction == "mean": - loss = loss.sum() / (target != self.ignore_index).sum() - elif self.reduction == "sum": - loss = loss.sum() - else: - loss = loss - - if not self.return_z_loss: - return loss - - if self.reduction == "mean": - z_loss = z_loss.sum() / (target != self.ignore_index).sum() - elif self.reduction == "sum": - z_loss = z_loss.sum() - else: - z_loss = z_loss - - return loss, z_loss diff --git a/vllm/thirdparty_files/flash_attn/models/__init__.py b/vllm/thirdparty_files/flash_attn/models/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/thirdparty_files/flash_attn/models/baichuan.py b/vllm/thirdparty_files/flash_attn/models/baichuan.py deleted file mode 100644 index 97d030782187..000000000000 --- a/vllm/thirdparty_files/flash_attn/models/baichuan.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) 2023, GGGGGGXY, Tri Dao. 
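# --- Illustrative sketch (editorial; not part of the deleted files): a plain-PyTorch
# rendering of the "z-loss" term described in the CrossEntropyLoss docstring of the
# cross_entropy.py removed above. The fused Triton kernel is replaced by standard ops,
# and ignore_index / label smoothing / tensor-parallel handling are omitted, so this is
# only a sketch of the math, not a drop-in replacement.
import torch
import torch.nn.functional as F

def ce_with_z_loss(logits, target, lse_square_scale=1e-4):
    # Standard per-token cross entropy ...
    ce = F.cross_entropy(logits, target, reduction="none")
    # ... plus lse_square_scale * logsumexp(logits)^2, the "z-loss" regularizer.
    z_loss = lse_square_scale * torch.logsumexp(logits, dim=-1).pow(2)
    return (ce + z_loss).mean(), z_loss.mean()

logits = torch.randn(8, 32000)
target = torch.randint(0, 32000, (8,))
loss, z = ce_with_z_loss(logits, target)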
- -import math -import json -import re -from pathlib import Path - -from collections import OrderedDict - -import torch -import torch.nn.functional as F - -from einops import rearrange -from transformers import GPT2Config, AutoConfig, PretrainedConfig - - -def remap_state_dict_hf_baichuan(state_dict, config): - def key_mapping_layers(key): - return re.sub(r"^model.", "transformer.", key) - - state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) - - # Word embedding - def key_mapping_emb(key): - return re.sub( - r"^transformer.embed_tokens.", - "transformer.embeddings.word_embeddings.", - key, - ) - - state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = ( - math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) - * pad_vocab_size_multiple - ) - state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) - ) - if getattr(config, "tie_word_embeddings"): - state_dict["lm_head.weight"] = state_dict[ - "transformer.embeddings.word_embeddings.weight" - ] - else: - output_embeddings = state_dict.pop("lm_head.weight") - # Need to recompute vocab_size since Baichuan shards the word embeddings and output embeddings - # differently. - vocab_size = ( - math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) - * pad_vocab_size_multiple - ) - # It's possible that vocab_size is padded to be a multiple of 8, for example. - state_dict["lm_head.weight"] = F.pad( - output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) - ) - - # LayerNorm - def key_mapping_ln(key): - key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key) - key = re.sub( - r"^transformer.layers.(\d+).input_layernorm.", - r"transformer.layers.\1.norm1.", - key, - ) - key = re.sub( - r"^transformer.layers.(\d+).post_attention_layernorm.", - r"transformer.layers.\1.norm2.", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - # MLP - for l in range(config.n_layer): - w1 = state_dict.pop(f"transformer.layers.{l}.mlp.gate_proj.weight") - w3 = state_dict.pop(f"transformer.layers.{l}.mlp.up_proj.weight") - # Our ordering is different - state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat( - [w3, w1], dim=0 - ) - - def key_mapping_mlp(key): - return re.sub( - r"^transformer.layers.(\d+).mlp.down_proj.", - r"transformer.layers.\1.mlp.fc2.", - key, - ) - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # Attention - def key_mapping_attn(key): - key = re.sub( - r"^transformer.layers.(\d+).self_attn.W_pack.", - r"transformer.layers.\1.mixer.Wqkv.", - key, - ) - key = re.sub( - r"^transformer.layers.(\d+).self_attn.o_proj.", - r"transformer.layers.\1.mixer.out_proj.", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - for l in range(config.n_layer): - # pop rotary_emb.inv_freq from state dict - state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq", None) - return state_dict - - -def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config: - # HACK: the config doesn't have say whether it's rotary or alibi. 
- # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi). - # HACK: the config doesn't have say whether it uses norm head. - # So we have to infer from the vocab size - # (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head). - use_rotary = baichuan_config.hidden_size < 5000 - return GPT2Config( - vocab_size=baichuan_config.vocab_size, - n_positions=0, # No absolute position embedding - n_embd=baichuan_config.hidden_size, - n_layer=baichuan_config.num_hidden_layers, - n_head=baichuan_config.num_attention_heads, - n_inner=baichuan_config.intermediate_size, - activation_function="swiglu", # Hardcode since HF calls it 'silu' - # baichuan doesn't have dropout, idk if it's because they only release the inference code - resid_pdrop=0.0, - embd_pdrop=0.0, - attn_pdrop=0.0, - layer_norm_epsilon=baichuan_config.rms_norm_eps, - initializer_range=baichuan_config.initializer_range, - bos_token_id=baichuan_config.bos_token_id, - eos_token_id=baichuan_config.eos_token_id, - # These are new arguments not in the original GPT2Config - pad_token_id=baichuan_config.pad_token_id, # Idk if this does anything - rms_norm=True, - rotary_emb_fraction=1.0 if use_rotary else 0.0, - rotary_emb_interleaved=False, - use_alibi=not use_rotary, - use_flash_attn=not use_rotary, # Alibi code path requires flash_attn - tie_word_embeddings=False, - norm_head=baichuan_config.vocab_size > 70000, - qkv_proj_bias=False, - out_proj_bias=False, - mlp_fc1_bias=False, - mlp_fc2_bias=False, - ) diff --git a/vllm/thirdparty_files/flash_attn/models/bert.py b/vllm/thirdparty_files/flash_attn/models/bert.py deleted file mode 100644 index 33d6935202a1..000000000000 --- a/vllm/thirdparty_files/flash_attn/models/bert.py +++ /dev/null @@ -1,764 +0,0 @@ -# Copyright (c) 2022, Tri Dao. -# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation. 
-# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py -# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py - -# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py - -import logging -import re -from collections import OrderedDict -from collections.abc import Sequence -from functools import partial -from typing import Any, Mapping - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from transformers import BertConfig, PretrainedConfig -from transformers.models.bert.modeling_bert import ( - BaseModelOutputWithPoolingAndCrossAttentions, - BertForPreTrainingOutput, -) - -from flash_attn.bert_padding import ( - index_first_axis, - index_first_axis_residual, - pad_input, - unpad_input, -) -from flash_attn.modules.block import Block -from flash_attn.modules.embedding import BertEmbeddings -from flash_attn.modules.mha import MHA -from flash_attn.modules.mlp import FusedMLP, Mlp -from flash_attn.utils.pretrained import state_dict_from_pretrained - -try: - from flash_attn.ops.fused_dense import FusedDense -except ImportError: - FusedDense = None - -try: - from flash_attn.ops.triton.layer_norm import layer_norm_fn -except ImportError: - layer_norm_fn = None - - -try: - from flash_attn.losses.cross_entropy import CrossEntropyLoss -except ImportError: - CrossEntropyLoss = None - - -logger = logging.getLogger(__name__) - - -def create_mixer_cls(config, cross_attn=False, return_residual=False): - use_flash_attn = getattr(config, "use_flash_attn", False) - fused_bias_fc = getattr(config, "fused_bias_fc", False) - rotary_kwargs = {} - if config.position_embedding_type == "rotary": - rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size) - rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0) - rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None) - rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False) - mixer_cls = partial( - MHA, - num_heads=config.num_attention_heads, - cross_attn=cross_attn, - dropout=config.attention_probs_dropout_prob, - causal=False, - fused_bias_fc=fused_bias_fc, - use_flash_attn=use_flash_attn, - return_residual=return_residual, - **rotary_kwargs, - ) - return mixer_cls - - -def create_mlp_cls(config, layer_idx=None, return_residual=False): - inner_dim = config.intermediate_size - fused_mlp = getattr(config, "fused_mlp", False) - if fused_mlp: - assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], ( - "fused_mlp only " "supports approximate gelu" - ) - if not fused_mlp: - approximate = ( - "tanh" - if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] - else "none" - ) - mlp_cls = partial( - Mlp, - hidden_features=inner_dim, - activation=partial(F.gelu, approximate=approximate), - return_residual=return_residual, - ) - else: - if FusedMLP is None: - raise ImportError("fused_dense is not installed") - mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0) - # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer - if isinstance(mlp_checkpoint_lvl, Sequence): - assert layer_idx is not None - mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] - mlp_cls = 
partial( - FusedMLP, - hidden_features=inner_dim, - checkpoint_lvl=mlp_checkpoint_lvl, - return_residual=return_residual, - ) - return mlp_cls - - -def create_block(config, layer_idx=None): - last_layer_subset = getattr(config, "last_layer_subset", False) - cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1 - # TD [2022-12-19]: For cross attention (last layer), we actually want to return the - # residual x_kv, not residual x. But it's annoying to change the API (and it only affects - # one layer) so we just choose not to return residual in this case. - return_residual = not cross_attn - mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual) - mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual) - norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps) - block = Block( - config.hidden_size, - mixer_cls, - mlp_cls, - norm_cls=norm_cls, - prenorm=False, - resid_dropout1=config.hidden_dropout_prob, - resid_dropout2=config.hidden_dropout_prob, - fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False), - return_residual=return_residual, - ) - return block - - -# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748 -def _init_weights(module, initializer_range=0.02): - if isinstance(module, nn.Linear): - nn.init.normal_(module.weight, std=initializer_range) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=initializer_range) - if module.padding_idx is not None: - nn.init.zeros_(module.weight[module.padding_idx]) - - -class BertEncoder(nn.Module): - def __init__(self, config: BertConfig): - super().__init__() - self.use_flash_attn = getattr(config, "use_flash_attn", False) - self.layers = nn.ModuleList( - [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)] - ) - - def forward(self, hidden_states, key_padding_mask=None, subset_mask=None): - """If subset_mask is not None, we only want output for the subset of the sequence. - This means that we only compute the last layer output for these tokens. 
- subset_mask: (batch, seqlen), dtype=torch.bool - """ - if key_padding_mask is None or not self.use_flash_attn: - mixer_kwargs = ( - {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None - ) - for layer in self.layers: - hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) - if subset_mask is not None: - hidden_states = hidden_states[subset_mask] - else: - batch, seqlen = hidden_states.shape[:2] - hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input( - hidden_states, key_padding_mask - ) - mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch} - if subset_mask is None: - for layer in self.layers: - hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) - hidden_states = pad_input(hidden_states, indices, batch, seqlen) - else: - for layer in self.layers[:-1]: - hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) - if key_padding_mask is not None: - subset_idx = torch.nonzero( - subset_mask[key_padding_mask], as_tuple=False - ).flatten() - subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32) - subset_cu_seqlens = F.pad( - torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0) - ) - else: - subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten() - subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32) - subset_cu_seqlens = F.pad( - torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0) - ) - hidden_states_subset, hidden_states = index_first_axis_residual( - hidden_states, subset_idx - ) - # It's ok to set max_seqlen_q to be much larger - mixer_kwargs = { - "x_kv": hidden_states, - "cu_seqlens": subset_cu_seqlens, - "max_seqlen": max_seqlen_in_batch, - "cu_seqlens_k": cu_seqlens, - "max_seqlen_k": max_seqlen_in_batch, - } - hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs) - return hidden_states - - -class BertPooler(nn.Module): - def __init__(self, config): - super().__init__() - fused_bias_fc = getattr(config, "fused_bias_fc", False) - if fused_bias_fc and FusedDense is None: - raise ImportError("fused_dense is not installed") - linear_cls = nn.Linear if not fused_bias_fc else FusedDense - self.dense = linear_cls(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states, pool=True): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. 
- first_token_tensor = hidden_states[:, 0] if pool else hidden_states - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class BertPredictionHeadTransform(nn.Module): - def __init__(self, config): - super().__init__() - fused_bias_fc = getattr(config, "fused_bias_fc", False) - if fused_bias_fc and FusedDense is None: - raise ImportError("fused_dense is not installed") - self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False) - if self.fused_dropout_add_ln and layer_norm_fn is None: - raise ImportError("Triton is not installed") - linear_cls = nn.Linear if not fused_bias_fc else FusedDense - self.dense = linear_cls(config.hidden_size, config.hidden_size) - approximate = ( - "tanh" - if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] - else "none" - ) - self.transform_act_fn = nn.GELU(approximate=approximate) - self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - if not self.fused_dropout_add_ln: - hidden_states = self.layer_norm(hidden_states) - else: - hidden_states = layer_norm_fn( - hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps - ) - return hidden_states - - -class BertLMPredictionHead(nn.Module): - def __init__(self, config): - super().__init__() - fused_bias_fc = getattr(config, "fused_bias_fc", False) - if fused_bias_fc and FusedDense is None: - raise ImportError("fused_dense is not installed") - linear_cls = nn.Linear if not fused_bias_fc else FusedDense - - self.transform = BertPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True) - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -class BertPreTrainingHeads(nn.Module): - def __init__(self, config): - super().__init__() - self.predictions = BertLMPredictionHead(config) - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, sequence_output, pooled_output): - prediction_scores = self.predictions(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class BertPreTrainedModel(nn.Module): - """An abstract class to handle weights initialization and - a simple interface for dowloading and loading pretrained models. - """ - - def __init__(self, config, *inputs, **kwargs): - super().__init__() - if not isinstance(config, BertConfig): - raise ValueError( - "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " - "To create a model from a Google pretrained model use " - "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( - self.__class__.__name__, self.__class__.__name__ - ) - ) - self.config = config - - @classmethod - def from_pretrained(cls, model_name, config, *inputs, **kwargs): - """ - Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. - Download and cache the pre-trained model file if needed. - - Params: - pretrained_model_name_or_path: either: - - a path or url to a pretrained model archive containing: - . 
`bert_config.json` a configuration file for the model - . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `model.chkpt` a TensorFlow checkpoint - *inputs, **kwargs: additional input for the specific Bert class - (ex: num_labels for BertForSequenceClassification) - """ - # Instantiate model. - model = cls(config, *inputs, **kwargs) - load_return = model.load_state_dict( - remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False - ) - logger.info(load_return) - return model - - -class BertModel(BertPreTrainedModel): - def __init__(self, config: BertConfig, add_pooling_layer=True): - super().__init__(config) - self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - if config.vocab_size % self.pad_vocab_size_multiple != 0: - config.vocab_size += self.pad_vocab_size_multiple - ( - config.vocab_size % self.pad_vocab_size_multiple - ) - self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False) - if self.fused_dropout_add_ln and layer_norm_fn is None: - raise ImportError("Triton is not installed") - assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"] - - self.embeddings = BertEmbeddings( - config.hidden_size, - config.vocab_size, - config.max_position_embeddings, - config.type_vocab_size, - padding_idx=config.pad_token_id, - ) - self.emb_drop = nn.Dropout(config.hidden_dropout_prob) - self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.encoder = BertEncoder(config) - self.pooler = BertPooler(config) if add_pooling_layer else None - - self.apply(partial(_init_weights, initializer_range=config.initializer_range)) - - def forward( - self, - input_ids, - position_ids=None, - token_type_ids=None, - attention_mask=None, - masked_tokens_mask=None, - ): - """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining), - we only want the output for the masked tokens. This means that we only compute the last - layer output for these tokens. - masked_tokens_mask: (batch, seqlen), dtype=torch.bool - """ - hidden_states = self.embeddings( - input_ids, position_ids=position_ids, token_type_ids=token_type_ids - ) - # TD [2022-12:18]: Don't need to force residual in fp32 - # BERT puts embedding LayerNorm before embedding dropout. - if not self.fused_dropout_add_ln: - hidden_states = self.emb_ln(hidden_states) - else: - hidden_states = layer_norm_fn( - hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps - ) - hidden_states = self.emb_drop(hidden_states) - - if masked_tokens_mask is not None: - batch_size, seqlen = input_ids.shape[:2] - # We also need the first column for the CLS token - first_col_mask = torch.zeros( - batch_size, seqlen, dtype=torch.bool, device=input_ids.device - ) - first_col_mask[:, 0] = True - subset_mask = masked_tokens_mask | first_col_mask - else: - subset_mask = None - - sequence_output = self.encoder( - hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask - ) - - if masked_tokens_mask is None: - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - else: - # TD [2022-03-01]: the indexing here is very tricky. 
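# (Editorial annotation, not in the original bert.py:) at this point, when
# last_layer_subset is enabled, `sequence_output` holds only the rows selected by
# `subset_mask` (the CLS column plus the masked tokens), in unpadded token order.
# Note that `subset_idx = subset_mask[attention_mask]` is a boolean mask over the
# unpadded tokens despite its name. Each of the other masks is first restricted to the
# unpadded tokens via `[attention_mask]` and then to the subset rows via `[subset_idx]`,
# so `first_col_mask[...]` ends up selecting the CLS rows (for pooling) and
# `masked_tokens_mask[...]` the masked-token rows (for the MLM head) out of
# `sequence_output`.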
- if attention_mask is not None: - subset_idx = subset_mask[attention_mask] - pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]] - sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]] - else: - pool_input = sequence_output[first_col_mask[subset_mask]] - sequence_output = sequence_output[masked_tokens_mask[subset_mask]] - pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None - - return BaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - ) - - -class BertForPreTraining(BertPreTrainedModel): - def __init__(self, config: BertConfig): - super().__init__(config) - # If dense_seq_output, we only need to pass the hidden states for the masked out tokens - # (around 15%) to the classifier heads. - self.dense_seq_output = getattr(config, "dense_seq_output", False) - # If last_layer_subset, we only need the compute the last layer for a subset of tokens - # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction). - self.last_layer_subset = getattr(config, "last_layer_subset", False) - if self.last_layer_subset: - assert self.dense_seq_output, "last_layer_subset requires dense_seq_output" - use_xentropy = getattr(config, "use_xentropy", False) - if use_xentropy and CrossEntropyLoss is None: - raise ImportError("xentropy_cuda is not installed") - loss_cls = ( - nn.CrossEntropyLoss - if not use_xentropy - else partial(CrossEntropyLoss, inplace_backward=True) - ) - - self.bert = BertModel(config) - self.cls = BertPreTrainingHeads(config) - self.mlm_loss = loss_cls(ignore_index=0) - self.nsp_loss = loss_cls(ignore_index=-1) - - # Initialize weights and apply final processing - self.apply(partial(_init_weights, initializer_range=config.initializer_range)) - self.tie_weights() - - def tie_weights(self): - self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight - - def forward( - self, - input_ids, - position_ids=None, - token_type_ids=None, - attention_mask=None, - labels=None, - next_sentence_label=None, - ): - """ - If labels are provided, they must be 0 for masked out tokens (as specified in the attention - mask). - Outputs: - if `labels` and `next_sentence_label` are not `None`: - Outputs the total_loss which is the sum of the masked language modeling loss and the next - sentence classification loss. - if `labels` or `next_sentence_label` is `None`: - Outputs a tuple comprising - - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and - - the next sentence classification logits of shape [batch_size, 2]. 
- - """ - masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None - outputs = self.bert( - input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask.bool() if attention_mask is not None else None, - masked_tokens_mask=masked_tokens_mask, - ) - sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output - if self.dense_seq_output and labels is not None: - masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten() - if not self.last_layer_subset: - sequence_output = index_first_axis( - rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx - ) - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) - - total_loss = None - if labels is not None and next_sentence_label is not None: - if ( - self.dense_seq_output and labels is not None - ): # prediction_scores are already flattened - masked_lm_loss = self.mlm_loss( - prediction_scores, labels.flatten()[masked_token_idx] - ) - else: - masked_lm_loss = self.mlm_loss( - rearrange(prediction_scores, "... v -> (...) v"), - rearrange(labels, "... -> (...)"), - ) - next_sentence_loss = self.nsp_loss( - rearrange(seq_relationship_score, "... t -> (...) t"), - rearrange(next_sentence_label, "... -> (...)"), - ) - total_loss = masked_lm_loss.float() + next_sentence_loss.float() - - return BertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - ) - - -def remap_state_dict(state_dict, config: PretrainedConfig): - """ - Map the state_dict of a Huggingface BERT model to be flash_attn compatible. - """ - - # LayerNorm - def key_mapping_ln_gamma_beta(key): - key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key) - key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key) - return key - - state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()) - - # Layers - def key_mapping_layers(key): - return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key) - - state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) - - # LayerNorm - def key_mapping_ln(key): - key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key) - key = re.sub( - r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)", - r"bert.encoder.layers.\1.norm1.\2", - key, - ) - key = re.sub( - r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)", - r"bert.encoder.layers.\1.norm2.\2", - key, - ) - key = re.sub( - r"^cls.predictions.transform.LayerNorm.(weight|bias)", - r"cls.predictions.transform.layer_norm.\1", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - # MLP - def key_mapping_mlp(key): - key = re.sub( - r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)", - r"bert.encoder.layers.\1.mlp.fc1.\2", - key, - ) - key = re.sub( - r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)", - r"bert.encoder.layers.\1.mlp.fc2.\2", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # Attention - last_layer_subset = getattr(config, "last_layer_subset", False) - for d in range(config.num_hidden_layers): - Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight") - Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight") - Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight") - bq 
= state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias") - bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias") - bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias") - if not (last_layer_subset and d == config.num_hidden_layers - 1): - state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat( - [Wq, Wk, Wv], dim=0 - ) - state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0) - else: - state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq - state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0) - state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq - state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0) - - def key_mapping_attn(key): - return re.sub( - r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)", - r"bert.encoder.layers.\1.mixer.out_proj.\2", - key, - ) - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - - def key_mapping_decoder_bias(key): - return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key) - - state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items()) - - # Word embedding - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - if pad_vocab_size_multiple > 1: - word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"] - state_dict["bert.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0]) - ) - decoder_weight = state_dict["cls.predictions.decoder.weight"] - state_dict["cls.predictions.decoder.weight"] = F.pad( - decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0]) - ) - # If the vocab was padded, we want to set the decoder bias for those padded indices to be - # strongly negative (i.e. the decoder shouldn't predict those indices). - # TD [2022-05-09]: I don't think it affects the MLPerf training. - decoder_bias = state_dict["cls.predictions.decoder.bias"] - state_dict["cls.predictions.decoder.bias"] = F.pad( - decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0 - ) - - return state_dict - - -def inv_remap_state_dict(state_dict, config: PretrainedConfig): - """ - Map the state_dict of a flash_attn model to be Huggingface BERT compatible. - - This function is meant to be the inverse of remap_state_dict. 
- """ - # Word embedding - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - if pad_vocab_size_multiple > 1: - word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"] - decoder_weight = state_dict["cls.predictions.decoder.weight"] - decoder_bias = state_dict["cls.predictions.decoder.bias"] - # unpad embeddings - state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[ - : config.orig_vocab_size, : - ] - state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :] - state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size] - - for d in range(config.num_hidden_layers): - last_layer_subset = getattr(config, "last_layer_subset", False) - if not last_layer_subset or d != (config.num_hidden_layers - 1): - Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight") - Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias") - state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[ - : Wqkv_weights.shape[0] // 3, : - ] - state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[ - Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, : - ] - state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[ - 2 * Wqkv_weights.shape[0] // 3 :, : - ] - state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[ - : Wqkv_biases.shape[0] // 3 - ] - state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[ - Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3 - ] - state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[ - 2 * Wqkv_biases.shape[0] // 3 : - ] - else: - Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight") - Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight") - Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias") - Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias") - state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight - state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[ - : Wkv_weights.shape[0] // 2, : - ] - state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[ - Wkv_weights.shape[0] // 2 :, : - ] - state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias - state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[ - : Wkv_biases.shape[0] // 2 - ] - state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[ - Wkv_biases.shape[0] // 2 : - ] - - def inv_key_mapping_ln(key): - key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key) - key = re.sub( - r"bert.encoder.layers.(\d+).norm1.(weight|bias)", - r"bert.encoder.layers.\1.attention.output.LayerNorm.\2", - key, - ) - key = re.sub( - r"bert.encoder.layers.(\d+).norm2.(weight|bias)", - r"bert.encoder.layers.\1.output.LayerNorm.\2", - key, - ) - key = re.sub( - r"cls.predictions.transform.layer_norm.(weight|bias)", - r"cls.predictions.transform.LayerNorm.\1", - key, - ) - return key - - def inv_key_mapping_ln_gamma_beta(key): - key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key) - key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key) - return key - - def inv_key_mapping_layers(key): - return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key) - - def inv_key_mapping_mlp(key): - key = re.sub( - 
r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)", - r"bert.encoder.layer.\1.intermediate.dense.\2", - key, - ) - key = re.sub( - r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)", - r"bert.encoder.layer.\1.output.dense.\2", - key, - ) - return key - - def inv_key_mapping_attn(key): - return re.sub( - r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)", - r"bert.encoder.layer.\1.attention.output.dense.\2", - key, - ) - - def inv_key_mapping_decoder_bias(key): - return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key) - - state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items()) - state_dict = OrderedDict( - (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items() - ) - state_dict = OrderedDict( - (inv_key_mapping_layers(key), value) for key, value in state_dict.items() - ) - state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items()) - state_dict = OrderedDict( - (inv_key_mapping_attn(key), value) for key, value in state_dict.items() - ) - state_dict = OrderedDict( - (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items() - ) - - return state_dict diff --git a/vllm/thirdparty_files/flash_attn/models/bigcode.py b/vllm/thirdparty_files/flash_attn/models/bigcode.py deleted file mode 100644 index 234944d4d690..000000000000 --- a/vllm/thirdparty_files/flash_attn/models/bigcode.py +++ /dev/null @@ -1,233 +0,0 @@ -import math -import re -from collections import OrderedDict - -import torch -import torch.nn.functional as F -from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig - - -def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig): - """ - Map the state_dict of a Huggingface BigCode model to be flash_attn compatible. - """ - - # Word embedding and position embedding - def key_mapping_pos_emb(key): - return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key) - - state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop("transformer.wte.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. 
- pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) - ) - state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] - - # LayerNorm - def key_mapping_ln(key): - key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key) - key = re.sub( - r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", - r"transformer.layers.\1.norm\2.\3", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - def key_mapping_mlp(key): - key = re.sub( - r"^transformer.h.(\d+).mlp.c_fc.weight", - r"transformer.layers.\1.mlp.fc1.weight", - key, - ) - key = re.sub( - r"^transformer.h.(\d+).mlp.c_proj.weight", - r"transformer.layers.\1.mlp.fc2.weight", - key, - ) - key = re.sub( - r"^transformer.h.(\d+).mlp.c_fc.bias", - r"transformer.layers.\1.mlp.fc1.bias", - key, - ) - key = re.sub( - r"^transformer.h.(\d+).mlp.c_proj.bias", - r"transformer.layers.\1.mlp.fc2.bias", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # TODO: add support for multi-head attention - assert config.multi_query, "Only multi-query attention is supported" - - # Attention - for d in range(config.num_hidden_layers): - embed_dim = config.n_embd - head_dim = embed_dim // config.n_head - - c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight") - # with multi-query attention, the weights have shape (embed_dim, embed_dim + head_dim + head_dim) - # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112 - # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183 - # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim) - q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0) - # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim) - k = torch.tile(k, (config.n_head, 1)) - v = torch.tile(v, (config.n_head, 1)) - state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0) - - # same deal with the bias - c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias") - # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim) - q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0) - # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim) - k = torch.tile(k, (config.n_head,)) - v = torch.tile(v, (config.n_head,)) - state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0) - - def key_mapping_attn(key): - key = re.sub( - r"^transformer.h.(\d+).attn.c_proj.weight", - r"transformer.layers.\1.mixer.out_proj.weight", - key, - ) - key = re.sub( - r"^transformer.h.(\d+).attn.c_proj.bias", - r"transformer.layers.\1.mixer.out_proj.bias", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - - return state_dict - - -def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig): - """ - Map the state_dict of a flash_attn model to be 
Huggingface BigCode compatible. - - This function is meant to be the inverse of remap_state_dict_hf_bigcode. - """ - - # Word embedding and position embeddings - def inv_key_mapping_pos_emb(key): - return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key) - - state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") - - word_embeddings = word_embeddings[:, : config.vocab_size] - state_dict["transformer.wte.weight"] = word_embeddings - state_dict["lm_head.weight"] = word_embeddings - - # LayerNorm - def inv_key_mapping_ln(key): - key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key) - key = re.sub( - r"^transformer.layers.(\d+).norm(1|2).(weight|bias)", - r"transformer.h.\1.ln_\2.\3", - key, - ) - return key - - state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items()) - - # MLPs - def inv_key_mapping_mlp(key): - key = re.sub( - r"^transformer.layers.(\d+).mlp.fc1.weight", - r"transformer.h.\1.mlp.c_fc.weight", - key, - ) - key = re.sub( - r"^transformer.layers.(\d+).mlp.fc2.weight", - r"transformer.h.\1.mlp.c_proj.weight", - key, - ) - key = re.sub( - r"^transformer.layers.(\d+).mlp.fc1.bias", - r"transformer.h.\1.mlp.c_fc.bias", - key, - ) - key = re.sub( - r"^transformer.layers.(\d+).mlp.fc2.bias", - r"transformer.h.\1.mlp.c_proj.bias", - key, - ) - return key - - state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # Attention - for d in range(config.num_hidden_layers): - embed_dim = config.n_embd - head_dim = embed_dim // config.n_head - - Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight") - q, k, v = torch.split( - Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0 - ) - c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0) - state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight - - # Same deal with the bias - Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias") - q, k, v = torch.split( - Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0 - ) - c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0) - state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias - - def inv_key_mapping_attn(key): - key = re.sub( - r"^transformer.layers.(\d+).mixer.out_proj.weight", - r"transformer.h.\1.attn.c_proj.weight", - key, - ) - key = re.sub( - r"^transformer.layers.(\d+).mixer.out_proj.bias", - r"transformer.h.\1.attn.c_proj.bias", - key, - ) - return key - - state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items()) - - return state_dict - - -def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config: - return GPT2Config( - activation_function=bigcode_config.activation_function, - attn_pdrop=bigcode_config.attn_pdrop, - bos_token_id=bigcode_config.bos_token_id, - embd_pdrop=bigcode_config.embd_pdrop, - eos_token_id=bigcode_config.eos_token_id, - initializer_range=bigcode_config.initializer_range, - layer_norm_epsilon=bigcode_config.layer_norm_epsilon, - max_batch_size=bigcode_config.max_batch_size, - max_sequence_length=bigcode_config.max_sequence_length, - model_type=bigcode_config.model_type, - multi_query=bigcode_config.multi_query, - n_embd=bigcode_config.n_embd, - n_head=bigcode_config.n_head, - n_inner=bigcode_config.n_inner, - n_layer=bigcode_config.n_layer, - 
n_positions=bigcode_config.n_positions, - resid_pdrop=bigcode_config.resid_pdrop, - scale_attn_weights=bigcode_config.scale_attn_weights, - summary_activation=bigcode_config.summary_activation, - summary_first_dropout=bigcode_config.summary_first_dropout, - summary_proj_to_labels=bigcode_config.summary_proj_to_labels, - summary_type=bigcode_config.summary_type, - summary_use_proj=bigcode_config.summary_use_proj, - use_cache=bigcode_config.use_cache, - vocab_size=bigcode_config.vocab_size, - ) diff --git a/vllm/thirdparty_files/flash_attn/models/btlm.py b/vllm/thirdparty_files/flash_attn/models/btlm.py deleted file mode 100644 index 295e12062320..000000000000 --- a/vllm/thirdparty_files/flash_attn/models/btlm.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -import math -import json -import re -from pathlib import Path - -from collections import OrderedDict - -import torch -import torch.nn.functional as F - -from einops import rearrange -from transformers import GPT2Config, AutoConfig, PretrainedConfig - - -def remap_state_dict_hf_btlm(state_dict, config): - # Word embedding and position embedding - def key_mapping_pos_emb(key): - return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key) - - if "transformer.wpe.weight" in state_dict: - state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop("transformer.wte.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) - ) - state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] - - # LayerNorm - def key_mapping_ln(key): - key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key) - key = re.sub(r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - # MLP - for d in range(config.num_hidden_layers): - W1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.weight") - W3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.weight") - state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = torch.cat([W1.t(), W3.t()], dim=0) - b1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.bias") - b3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.bias") - state_dict[f"transformer.layers.{d}.mlp.fc1.bias"] = torch.cat([b1, b3], dim=0) - W2 = state_dict.pop(f"transformer.h.{d}.mlp.c_proj.weight") - state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t() - - def key_mapping_mlp(key): - key = re.sub(r"^transformer.h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key) - return key - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # Attention - for d in range(config.num_hidden_layers): - Wqkv = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight") - state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t() - Wout = state_dict.pop(f"transformer.h.{d}.attn.c_proj.weight") - state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t() - state_dict.pop(f"transformer.relative_pe.slopes") # We don't store the Alibi slopes - - def key_mapping_attn(key): - key = 
re.sub(r"^transformer.h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key) - key = re.sub( - r"^transformer.h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key - ) - return key - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - - return state_dict - - -def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config: - return GPT2Config( - vocab_size=btlm_config.vocab_size, - n_positions=0 if btlm_config.position_embedding_type == "alibi" else btlm_config.n_positions, - n_embd=btlm_config.hidden_size, - n_layer=btlm_config.num_hidden_layers, - n_head=btlm_config.num_attention_heads, - n_inner=btlm_config.n_inner, - activation_function=btlm_config.activation_function, - resid_pdrop=btlm_config.resid_pdrop, - embd_pdrop=btlm_config.embd_pdrop, - attn_pdrop=btlm_config.attn_pdrop, - layer_norm_epsilon=btlm_config.layer_norm_epsilon, - initializer_range=btlm_config.initializer_range, - bos_token_id=btlm_config.bos_token_id, - eos_token_id=btlm_config.eos_token_id, - # These are new arguments not in the original GPT2Config - use_alibi=btlm_config.position_embedding_type == "alibi", - use_flash_attn=btlm_config.position_embedding_type == "alibi", # Alibi code path requires flash_attn - mup_width_scale=btlm_config.mup_width_scale, - mup_embeddings_multiplier=btlm_config.mup_embeddings_scale, - mup_output_multiplier=btlm_config.mup_output_alpha, - mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d, - mlp_multiple_of=1, - ) diff --git a/vllm/thirdparty_files/flash_attn/models/falcon.py b/vllm/thirdparty_files/flash_attn/models/falcon.py deleted file mode 100644 index 4b02ec772774..000000000000 --- a/vllm/thirdparty_files/flash_attn/models/falcon.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -import math -import re -from collections import OrderedDict - -import torch -import torch.nn.functional as F -from einops import rearrange -from transformers import FalconConfig, GPT2Config - - -def remap_state_dict_hf_falcon(state_dict, config): - def key_mapping_layers(key): - return re.sub(r"^transformer.h.", "transformer.layers.", key) - - state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) - # Word embedding - def key_mapping_emb(key): - return re.sub( - r"^transformer.word_embeddings.", "transformer.embeddings.word_embeddings.", key - ) - - state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) - ) - if getattr(config, "tie_word_embeddings"): - state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] - else: - output_embeddings = state_dict.pop("lm_head.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. 
- state_dict["lm_head.weight"] = F.pad( - output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) - ) - output_embeddings_bias = state_dict.pop("lm_head.bias") - state_dict["lm_head.bias"] = F.pad( - output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0]) - ) - - # LayerNorm - def key_mapping_ln(key): - key = re.sub( - r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key - ) - key = re.sub( - r"^transformer.layers.(\d+).post_attention_layernorm.", - r"transformer.layers.\1.norm2.", - key, - ) - key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key) - key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - # MLP - def key_mapping_mlp(key): - key = re.sub( - r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key - ) - key = re.sub( - r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key - ) - return key - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - def key_mapping_attn(key): - key = re.sub( - r"^transformer.layers.(\d+).self_attention.query_key_value.", - r"transformer.layers.\1.mixer.Wqkv.", - key, - ) - key = re.sub( - r"^transformer.layers.(\d+).self_attention.dense.", - r"transformer.layers.\1.mixer.out_proj.", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - n_head = config.n_head - n_head_kv = getattr(config, "n_head_kv", 1) - headdim = config.hidden_size // n_head - for l in range(config.n_layer): - # The weights are stored in a different layout compared to our implementation - Wqkv = rearrange( - state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"), - "(group ratio headdim) ... -> group ratio headdim ...", - ratio=n_head // n_head_kv + 2, - headdim=headdim, - ) - Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...") - Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...") - Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...") - state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) - - return state_dict - - -def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config: - # The 40b config uses "n_head_kv" instead of "num_kv_heads" - n_head_kv = getattr( - falcon_config, - "n_head_kv", - 1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head, - ) - # HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config. 
- # So we have to infer it from the number of heads in the key/value block - parallel_block_tied_norm = n_head_kv == 1 - return GPT2Config( - vocab_size=falcon_config.vocab_size, - n_positions=0, # No absolute position embedding - n_embd=falcon_config.hidden_size, - n_layer=falcon_config.n_layer, - n_head=falcon_config.n_head, - n_inner=falcon_config.hidden_size * 4, - activation_function="gelu", - resid_pdrop=falcon_config.hidden_dropout, - embd_pdrop=0.0, # There doesn't seem to be any embedding dropout - attn_pdrop=falcon_config.attention_dropout, - layer_norm_epsilon=falcon_config.layer_norm_epsilon, - initializer_range=falcon_config.initializer_range, - bos_token_id=falcon_config.bos_token_id, - eos_token_id=falcon_config.eos_token_id, - # These are new arguments not in the original GPT2Config - parallel_block=falcon_config.parallel_attn, - n_head_kv=n_head_kv, - parallel_block_tied_norm=parallel_block_tied_norm, - rotary_emb_fraction=1.0, - rotary_emb_interleaved=False, - tie_word_embeddings=True, - qkv_proj_bias=falcon_config.bias, - out_proj_bias=falcon_config.bias, - mlp_fc1_bias=falcon_config.bias, - mlp_fc2_bias=falcon_config.bias, - lm_head_bias=False, - ) diff --git a/vllm/thirdparty_files/flash_attn/models/gpt.py b/vllm/thirdparty_files/flash_attn/models/gpt.py deleted file mode 100644 index 71540da954bb..000000000000 --- a/vllm/thirdparty_files/flash_attn/models/gpt.py +++ /dev/null @@ -1,1080 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -import logging -import math -import re -from collections import OrderedDict, namedtuple -from collections.abc import Sequence -from functools import partial -from typing import Dict, List - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from transformers import GPT2Config - -from flash_attn.models.bigcode import remap_state_dict_hf_bigcode -from flash_attn.models.falcon import remap_state_dict_hf_falcon -from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox -from flash_attn.models.gptj import remap_state_dict_hf_gptj -from flash_attn.models.llama import remap_state_dict_hf_llama -from flash_attn.models.opt import remap_state_dict_hf_opt -from flash_attn.modules.block import Block, ParallelBlock -from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings -from flash_attn.modules.mha import MHA, ParallelMHA -from flash_attn.modules.mlp import ( - FusedMLP, - GatedMlp, - Mlp, - ParallelFusedMLP, - ParallelGatedMlp, - ParallelMLP, -) -from flash_attn.ops.activations import sqrelu_fwd -from flash_attn.utils.distributed import ( - all_gather, - all_gather_raw, - get_dim_for_local_rank, - sync_shared_params, -) -from flash_attn.utils.generation import GenerationMixin -from flash_attn.utils.pretrained import state_dict_from_pretrained - -try: - from flash_attn.ops.fused_dense import ColumnParallelLinear -except ImportError: - ColumnParallelLinear = None - -try: - from flash_attn.ops.triton.mlp import FusedDenseSqreluDense -except ImportError: - FusedDenseSqreluDense = None - -try: - from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm -except ImportError: - layer_norm_fn, RMSNorm = None, None - -logger = logging.getLogger(__name__) - - -def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) 
else 1.0 - softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power)) - softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0) - if config.scale_attn_by_inverse_layer_idx: - assert layer_idx is not None - softmax_scale /= float(layer_idx + 1) - dwconv = getattr(config, "attn_dwconv", False) - if dwconv: - assert process_group is None, "TensorParallel MHA does not support dwconv yet" - qkv_proj_bias = getattr(config, "qkv_proj_bias", True) - out_proj_bias = getattr(config, "out_proj_bias", True) - rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim) - rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0) - rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None) - rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False) - use_alibi = getattr(config, "use_alibi", False) - window_size = getattr(config, "window_size", (-1, -1)) - use_flash_attn = getattr(config, "use_flash_attn", False) - fused_bias_fc = getattr(config, "fused_bias_fc", False) - if not fused_bias_fc: - assert process_group is None, "TensorParallel MHA requires fused_bias_fc" - mha_cls = MHA if process_group is None else ParallelMHA - serial_kwargs = ( - {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {} - ) - parallel_kwargs = ( - { - "process_group": process_group, - "sequence_parallel": getattr(config, "sequence_parallel", True), - } - if process_group is not None - else {} - ) - num_heads_kv = getattr(config, "n_head_kv", None) - mixer_cls = partial( - mha_cls, - num_heads=config.num_attention_heads, - num_heads_kv=num_heads_kv, - qkv_proj_bias=qkv_proj_bias, - out_proj_bias=out_proj_bias, - dropout=config.attn_pdrop, - softmax_scale=softmax_scale, - causal=True, - layer_idx=layer_idx, - rotary_emb_dim=rotary_emb_dim, - rotary_emb_base=rotary_emb_base, - rotary_emb_scale_base=rotary_emb_scale_base, - rotary_emb_interleaved=rotary_emb_interleaved, - use_alibi=use_alibi, - window_size=window_size, - use_flash_attn=use_flash_attn, - **serial_kwargs, - **parallel_kwargs, - **factory_kwargs, - ) - return mixer_cls - - -def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True) - mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True) - fused_mlp = getattr(config, "fused_mlp", False) - if fused_mlp: - assert config.activation_function in [ - "gelu_new", - "gelu_fast", - "gelu_approx", - "gelu_pytorch_tanh", - "relu", - "sqrelu", - ] - fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False) - if fused_dense_sqrelu_dense: - assert config.activation_function == "sqrelu", ( - "fused_dense_sqrelu_dense only " "supports approximate activation_function sqrelu" - ) - assert not (fused_dense_sqrelu_dense and fused_mlp) - if not fused_mlp and not fused_dense_sqrelu_dense: - assert config.activation_function in [ - "gelu", - "gelu_new", - "gelu_fast", - "gelu_approx", - "gelu_pytorch_tanh", - "relu", - "sqrelu", - "glu", - "swiglu", - "geglu", - ] - if config.activation_function in ["glu", "swiglu", "geglu"]: - activation = ( - F.sigmoid - if config.activation_function == "glu" - else (F.silu if config.activation_function == "swiglu" else F.gelu) - ) - mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp - parallel_kwargs = ( - { - "process_group": process_group, - "sequence_parallel": getattr(config, "sequence_parallel", True), - } - if 
process_group is not None - else {} - ) - mlp_multiple_of = getattr(config, "mlp_multiple_of", 128) - mlp_cls = partial( - mlp_cls, - hidden_features=config.n_inner, - activation=activation, - bias1=mlp_fc1_bias, - bias2=mlp_fc2_bias, - multiple_of=mlp_multiple_of, - **parallel_kwargs, - **factory_kwargs, - ) - else: - if config.activation_function == "relu": - activation = partial(F.relu, inplace=True) - elif config.activation_function == "sqrelu": - activation = sqrelu_fwd - else: - approximate = ( - "tanh" - if config.activation_function - in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"] - else "none" - ) - activation = partial(F.gelu, approximate=approximate) - mlp_cls = Mlp if process_group is None else ParallelMLP - parallel_kwargs = ( - { - "process_group": process_group, - "sequence_parallel": getattr(config, "sequence_parallel", True), - } - if process_group is not None - else {} - ) - mlp_cls = partial( - mlp_cls, - hidden_features=config.n_inner, - activation=activation, - bias1=mlp_fc1_bias, - bias2=mlp_fc2_bias, - **parallel_kwargs, - **factory_kwargs, - ) - else: - mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0) - # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer - if isinstance(mlp_checkpoint_lvl, Sequence): - assert layer_idx is not None - mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] - if fused_mlp: - if FusedMLP is None: - raise ImportError("fused_dense is not installed") - activation = ( - "gelu_approx" - if config.activation_function - in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"] - else config.activation_function - ) - mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP - parallel_kwargs = ( - { - "process_group": process_group, - "sequence_parallel": getattr(config, "sequence_parallel", True), - } - if process_group is not None - else {} - ) - mlp_cls = partial( - mlp_cls, - hidden_features=config.n_inner, - activation=activation, - checkpoint_lvl=mlp_checkpoint_lvl, - bias1=mlp_fc1_bias, - bias2=mlp_fc2_bias, - **parallel_kwargs, - **factory_kwargs, - ) - elif fused_dense_sqrelu_dense: - if process_group is not None: - assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense" - assert FusedDenseSqreluDense is not None - mlp_cls = partial( - FusedDenseSqreluDense, - hidden_features=config.n_inner, - checkpoint_lvl=mlp_checkpoint_lvl, - **factory_kwargs, - ) - else: - raise RuntimeError("MLP type not supported") - return mlp_cls - - -def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - sequence_parallel = getattr(config, "sequence_parallel", True) - mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs) - mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs) - use_rms_norm = getattr(config, "rms_norm", False) - norm_cls = partial( - nn.LayerNorm if not use_rms_norm else RMSNorm, - eps=config.layer_norm_epsilon, - **factory_kwargs, - ) - # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable - residual_in_fp32 = getattr(config, "residual_in_fp32", False) - resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop - prenorm = getattr(config, "prenorm", True) - parallel_block = getattr(config, "parallel_block", False) - if not parallel_block: - block = Block( - config.hidden_size, - mixer_cls, - mlp_cls, - 
norm_cls=norm_cls, - prenorm=prenorm, - resid_dropout1=resid_dropout1, - resid_dropout2=config.resid_pdrop, - fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False), - residual_in_fp32=residual_in_fp32, - sequence_parallel=sequence_parallel and process_group is not None, - mark_shared_params=process_group is not None, - ) - else: - assert prenorm - block = ParallelBlock( - config.hidden_size, - mixer_cls, - mlp_cls, - norm_cls=norm_cls, - resid_dropout1=resid_dropout1, - resid_dropout2=config.resid_pdrop, - tied_norm=getattr(config, "parallel_block_tied_norm", False), - fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False), - residual_in_fp32=residual_in_fp32, - sequence_parallel=sequence_parallel and process_group is not None, - mark_shared_params=process_group is not None, - ) - block.layer_idx = layer_idx - return block - - -class GPTPreTrainedModel(nn.Module): - """An abstract class to handle weights initialization and - a simple interface for dowloading and loading pretrained models. - """ - - def __init__(self, config, *inputs, **kwargs): - super().__init__() - if not isinstance(config, GPT2Config): - raise ValueError( - "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. " - "To create a model from a Google pretrained model use " - "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( - self.__class__.__name__, self.__class__.__name__ - ) - ) - self.config = config - - @classmethod - def from_pretrained( - cls, - model_name, - config, - *args, - strict=True, - device=None, - dtype=None, - world_size=1, - rank=0, - **kwargs, - ): - """ - Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict. - Download and cache the pre-trained model file if needed. - """ - # Instantiate model. 
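# Hedged usage sketch (illustrative only; the model/config names are assumptions, not
# part of this patch): loading HF GPT-2 weights through this classmethod would look
# roughly like
#
#   from transformers import GPT2Config
#   config = GPT2Config.from_pretrained("gpt2")
#   model = GPTLMHeadModel.from_pretrained(
#       "gpt2", config, device="cuda", dtype=torch.float16
#   )
#
# The branch below selects the matching remap_state_dict_* helper from the model-name
# prefix, loads the checkpoint on CPU (the model already lives on the target device),
# and optionally shards the remapped state_dict for tensor parallelism.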
- model = cls(config, *args, device=device, dtype=dtype, **kwargs) - # Load state_dict in cpu because we already initialized the model in GPU, and we don't - # want extra stuff taking up more GPU memory - state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype) - if model_name.startswith("gpt2"): - state_dict = remap_state_dict_hf_gpt2(state_dict, config) - elif model_name.startswith("facebook/opt"): - state_dict = remap_state_dict_hf_opt(state_dict, config) - elif model_name.startswith("EleutherAI/gpt-j-") or model_name.startswith( - "togethercomputer/GPT-JT-" - ): - state_dict = remap_state_dict_hf_gptj(state_dict, config) - elif ( - model_name.startswith("EleutherAI/gpt-neox-") - or model_name.startswith("EleutherAI/pythia-") - or model_name.startswith("togethercomputer/RedPajama-INCITE-") - ): - state_dict = remap_state_dict_hf_gpt_neox(state_dict, config) - elif model_name.startswith("tiiuae/falcon-"): - state_dict = remap_state_dict_hf_falcon(state_dict, config) - elif model_name.startswith("meta-llama/Llama-"): - state_dict = remap_state_dict_hf_llama(state_dict, config) - elif model_name.startswith("bigcode/") or model_name.startswith("WizardLM/"): - state_dict = remap_state_dict_hf_bigcode(state_dict, config) - else: - raise NotImplementedError(f"Model {model_name} not supported") - if world_size > 1: - state_dict = shard_state_dict_tp(state_dict, config, world_size, rank) - load_return = model.load_state_dict(state_dict, strict=strict) - logger.info(load_return) - return model - - -# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 -def _init_weights( - module, n_layer, initializer_range=0.02, mup_width_scale=1.0, rescale_prenorm_residual=True -): - mup_init_scale = math.sqrt(mup_width_scale) - if isinstance(module, nn.Linear): - nn.init.normal_(module.weight, std=initializer_range * mup_init_scale) - optim_cfg = getattr(module.weight, "_optim", {}) - optim_cfg.update({"lr_multiplier": mup_width_scale}) - setattr(module.weight, "_optim", optim_cfg) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=initializer_range) - - if rescale_prenorm_residual: - # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # - # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if name in ["out_proj.weight", "fc2.weight"]: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - nn.init.normal_( - p, mean=0.0, std=initializer_range * mup_init_scale / math.sqrt(2 * n_layer) - ) - - -class GPTModel(GPTPreTrainedModel): - def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None): - super().__init__(config) - factory_kwargs = {"device": device, "dtype": dtype} - self.process_group = process_group - self.sequence_parallel = getattr(config, "sequence_parallel", True) - assert config.activation_function in [ - "gelu", - "gelu_new", - "gelu_fast", - "gelu_approx", - "gelu_pytorch_tanh", - "relu", - "sqrelu", - "glu", - "swiglu", - "geglu", - ] - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = ( - math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - ) - self.embeddings_multiplier = getattr(config, "mup_embeddings_multiplier", 1.0) - # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable - self.residual_in_fp32 = getattr(config, "residual_in_fp32", False) - # These 2 options are for OPT-350m - self.prenorm = getattr(config, "prenorm", True) - use_rms_norm = getattr(config, "rms_norm", False) - word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None) - # For GPT-J, GPT-NeoX - self.parallel_block = getattr(config, "parallel_block", False) - - if process_group is None: - self.embeddings = GPT2Embeddings( - config.hidden_size, - vocab_size, - config.max_position_embeddings, - word_embed_proj_dim=word_embed_proj_dim, - **factory_kwargs, - ) - else: - self.embeddings = ParallelGPT2Embeddings( - config.hidden_size, - vocab_size, - config.max_position_embeddings, - process_group=process_group, - sequence_parallel=self.sequence_parallel, - **factory_kwargs, - ) - - # We change the order of dropout, residual and layer norm: - # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: - # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and - # the main branch (output of MLP). The model definition is unchanged, but the mapping of the - # nn.Dropout probabilities are changed. - # This is for performance reason: we can fuse dropout + add + layer_norm. - self.layers = nn.ModuleList( - [ - create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs) - for i in range(config.num_hidden_layers) - ] - ) - rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0) - if rotary_emb_fraction > 0.0: # Tie all the RotaryEmbedding modules to share the same cos/sin cache - for layer in self.layers[1:]: - layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb - - self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False) - if self.fused_dropout_add_ln: - if layer_norm_fn is None: - raise ImportError("Triton is not installed") - if self.prenorm: - self.drop_f = nn.Dropout(config.resid_pdrop) - norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm - self.ln_f = norm_cls( - config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs - ) - if process_group is not None: - for p in self.ln_f.parameters(): - # Mark the norm parameters as "shared_params" so that we sync their values at init. 
- p._shared_params = True - # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads. - if self.sequence_parallel: - p._sequence_parallel = True - - self.apply( - partial( - _init_weights, - n_layer=config.num_hidden_layers, - initializer_range=config.initializer_range, - mup_width_scale=getattr(config, "mup_width_scale", 1.0), - ) - ) - self.tie_weights() - - def tie_weights(self): - if self.process_group is not None: - sync_shared_params(self, self.process_group) - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return { - i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - for i, layer in enumerate(self.layers) - } - - def forward(self, input_ids, position_ids=None, inference_params=None): - # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen - # dimensions so that we can split on it easily, in case of small batch size. - # Only the attention layers need to know the seqlen. - embedding_kwargs = ( - {"combine_batch_seqlen_dim": True} - if self.process_group is not None and self.sequence_parallel - else {} - ) - hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs) - if self.embeddings_multiplier != 1.0: - hidden_states = hidden_states * self.embeddings_multiplier - if self.parallel_block: - hidden_states2 = None - residual = None - mixer_kwargs = ( - {"seqlen": input_ids.shape[1]} - if self.process_group is not None and self.sequence_parallel - else {} - ) - if inference_params is not None: - mixer_kwargs["inference_params"] = inference_params - for layer in self.layers: - if self.prenorm: - if not self.parallel_block: - hidden_states, residual = layer( - hidden_states, residual, mixer_kwargs=mixer_kwargs - ) - else: - hidden_states, hidden_states2, residual = layer( - hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs - ) - else: - hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) - if self.prenorm: - if not self.fused_dropout_add_ln: - dropped = self.drop_f(hidden_states) - if not self.parallel_block: - residual = (dropped + residual) if residual is not None else dropped - else: - dropped2 = self.drop_f(hidden_states2) - residual = ( - (residual + dropped + dropped2) - if residual is not None - else dropped + dropped2 - ) - hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype)) - else: - # Set prenorm=False here since we don't need the residual - hidden_states = layer_norm_fn( - hidden_states, - self.ln_f.weight, - self.ln_f.bias, - residual=residual, - x1=None if not self.parallel_block else hidden_states2, - eps=self.ln_f.eps, - dropout_p=self.drop_f.p if self.training else 0.0, - prenorm=False, - is_rms_norm=isinstance(self.ln_f, RMSNorm) - ) - return hidden_states - - -class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): - def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__(config) - self.process_group = process_group - self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs) - self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True) - lm_head_bias = getattr(config, "lm_head_bias", False) - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = ( - math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - ) - # This option is for OPT-350m - word_embed_proj_dim = 
getattr(config, "word_embed_proj_dim", None) - embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim - if word_embed_proj_dim is not None: - self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs) - else: - self.project_out = None - mup_width_scale = getattr(config, "mup_width_scale", 1.0) - mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0) - self.output_scale = mup_output_multiplier * mup_width_scale - if process_group is None: - self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs) - else: - if ColumnParallelLinear is None: - raise ImportError("fused_dense_lib is not installed") - self.lm_head = ColumnParallelLinear( - embed_dim, - vocab_size, - process_group, - bias=lm_head_bias, - sequence_parallel=getattr(config, "sequence_parallel", True), - **factory_kwargs, - ) - self.norm_head = getattr(config, "norm_head", False) - # Initialize weights and apply final processing - self.apply( - partial( - _init_weights, - n_layer=config.num_hidden_layers, - initializer_range=config.initializer_range, - mup_width_scale=mup_width_scale, - ) - ) - self.tie_weights() - - def tie_weights(self): - if self.tie_word_embeddings: - self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight - if self.process_group is not None: - sync_shared_params(self, self.process_group) - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.transformer.allocate_inference_cache( - batch_size, max_seqlen, dtype=dtype, **kwargs - ) - - def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): - """ - input_ids: (batch, seqlen) int tensor - inference_params: for generation. Adapted from Megatron-LM (and Apex) - https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 - num_last_tokens: if > 0, only return the logits for the last n tokens - """ - assert ( - input_ids.ndim == 2 - ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}" - b, slen = input_ids.shape - hidden_states = self.transformer( - input_ids, position_ids=position_ids, inference_params=inference_params - ) - if inference_params is not None: - assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode" - if num_last_tokens > 0: - hidden_states = hidden_states[:, -num_last_tokens:] - if self.project_out is not None: - hidden_states = self.project_out(hidden_states) - if self.output_scale != 1.0: - hidden_states = hidden_states * self.output_scale - if not self.norm_head: - lm_logits = self.lm_head(hidden_states) - else: - lm_head_weight = F.normalize(self.lm_head.weight) - if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel: - hidden_states = all_gather(hidden_states, self.lm_head.process_group) - lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias) - # During inference, we want the full logit for sampling - if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None: - lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group) - lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... 
(n d)", b=b) - CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) - return CausalLMOutput(logits=lm_logits) - - def load_state_dict(self, state_dict, strict=True): - # Remapping from our checkpoints that used a different ordering of layers in the block - # Previous: Attn / MLP -> Dropout -> Add -> LN - # Current: Dropout -> Add -> LN -> Attn / MLP - if "transformer.ln_0.weight" in state_dict: - n_layers = len(self.transformer.layers) - ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight") - ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias") - state_dict["transformer.ln_f.weight"] = ln_weight - state_dict["transformer.ln_f.bias"] = ln_bias - for l in reversed(range(n_layers)): - ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight") - ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias") - state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight - state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias - if l > 0: - ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight") - ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias") - state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight - state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias - ln_weight = state_dict.pop("transformer.ln_0.weight") - ln_bias = state_dict.pop("transformer.ln_0.bias") - state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight - state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias - return super().load_state_dict(state_dict, strict=strict) - - -def shard_state_dict_tp(state_dict, config, world_size, rank): - """Convert the state_dict of a standard GPT model to the state_dict of a GPT model - with tensor parallel. - - This function modifies state_dict in place. - """ - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - assert vocab_size % world_size == 0 - assert config.hidden_size % world_size == 0 - inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size - assert inner_dim % world_size == 0 - - n_head = config.n_head - n_head_kv = getattr(config, "n_head_kv", n_head) - - embed_dim = config.hidden_size - head_dim = embed_dim // n_head - - def shard_first_dim(state_dict, key): - if key in state_dict: - x = state_dict[key] - dim = x.shape[0] // world_size - state_dict[key] = x[rank * dim : (rank + 1) * dim] - - def shard_last_dim(state_dict, key, multiple_of=1): - if key in state_dict: - x = state_dict[key] - dim_each_rank = [ - get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of) - for local_rank in range(world_size) - ] - beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1)) - state_dict[key] = x[..., beg:end] - - def shard_gatedmlp_fc1_dim(state_dict, key): - if key in state_dict: - x = state_dict[key] - dim = x.shape[0] // world_size // 2 - state_dict[key] = rearrange( - rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim], - "two o ... 
-> (two o) ...", - ) - - def shard_qkv_headdim(state_dict, key): - if key in state_dict: - n_head_each_rank = [ - get_dim_for_local_rank(n_head, world_size, local_rank) - for local_rank in range(world_size) - ] - n_head_kv_each_rank = [ - get_dim_for_local_rank(n_head_kv, world_size, local_rank) - for local_rank in range(world_size) - ] - - beg_n_head = sum(n_head_each_rank[:rank]) - end_n_head = sum(n_head_each_rank[: rank + 1]) - - beg_n_head_kv = sum(n_head_kv_each_rank[:rank]) - end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1]) - - if n_head_kv == n_head: - x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3) - state_dict[key] = rearrange( - x[:, beg_n_head * head_dim : end_n_head * head_dim], - "three d ... -> (three d) ...", - ) - else: - x = rearrange( - state_dict[key], - "(nheadqkv headdim) ... -> nheadqkv headdim ...", - nheadqkv=n_head + 2 * n_head_kv, - ) - state_dict[key] = rearrange( - torch.cat( - [ - x[beg_n_head:end_n_head], - x[n_head + beg_n_head_kv : n_head + end_n_head_kv], - x[ - n_head - + n_head_kv - + beg_n_head_kv : n_head - + n_head_kv - + end_n_head_kv - ], - ], - dim=0, - ), - "nheadqkv headdim ... -> (nheadqkv headdim) ...", - ) - - shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight") - if "lm_head.weight" in state_dict: - shard_first_dim(state_dict, "lm_head.weight") - if "transformer.embeddings.position_embeddings.weight" in state_dict: - shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight") - for i in range(config.num_hidden_layers): - shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight") - shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias") - shard_last_dim( - state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim - ) - if rank != 0: - state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None) - if config.activation_function in ["glu", "swiglu", "geglu"]: - shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight") - shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias") - else: - shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight") - shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias") - shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight") - if rank != 0: - state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None) - return state_dict - - -def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config): - """Convert the list of sharded state_dict of a GPT model with tensor parallel to - the state_dict of a standard GPT model. - - This function is meant to be the "reverse" of shard_state_dict_tp. - - Precondition: - - state_dicts should be ordered in the same way as the shards were created. - """ - world_size = len(state_dicts) - keys = state_dicts[0].keys() - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - assert vocab_size % world_size == 0 - assert config.hidden_size % world_size == 0 - inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size - assert inner_dim % world_size == 0 - assert config.hidden_size % config.n_head == 0 - headdim = config.hidden_size // config.n_head - - # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim. - # vocab_size // world_size coordinates are nonzero. 
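
As an aside, the per-rank Wqkv slicing performed by shard_qkv_headdim above, and its inverse combine_qkv_headdim defined just below, can be checked with a small standalone round trip. This is only an illustrative sketch for the n_head_kv == n_head case; the sizes (n_head=4, head_dim=2, world_size=2) are made up.

import torch
from einops import rearrange

n_head, head_dim, hidden = 4, 2, 8
world_size = 2
heads_per_rank = n_head // world_size
wqkv = torch.randn(3 * n_head * head_dim, hidden)  # packed as ((three n_head head_dim), hidden)

# Shard: expose the (q, k, v) axis, then slice the flattened head dimension per rank.
shards = []
for rank in range(world_size):
    x = rearrange(wqkv, "(three d) h -> three d h", three=3)
    beg = rank * heads_per_rank * head_dim
    end = (rank + 1) * heads_per_rank * head_dim
    shards.append(rearrange(x[:, beg:end], "three d h -> (three d) h"))

# Combine: concatenate the per-rank head slices back, as combine_qkv_headdim does.
xs = [rearrange(s, "(three d) h -> three d h", three=3) for s in shards]
restored = rearrange(torch.cat(xs, dim=1), "three d h -> (three d) h")
assert torch.equal(restored, wqkv)
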
- def combine_word_embeddings(state_dicts, state_dict, key): - dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1 - state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim) - - def combine_dim(state_dicts, state_dict, key, dim=-1): - if key in state_dict: - state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim) - - def combine_qkv_headdim(state_dicts, state_dict, key): - n_head = config.n_head - n_head_kv = getattr(config, "n_head_kv", n_head) - if key in state_dict: - if n_head_kv == n_head: - xs = [ - rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts - ] - state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...") - else: - n_head_each_rank = [ - get_dim_for_local_rank(n_head, world_size, local_rank) - for local_rank in range(world_size) - ] - n_head_kv_each_rank = [ - get_dim_for_local_rank(n_head_kv, world_size, local_rank) - for local_rank in range(world_size) - ] - xs = [ - rearrange( - s[key], - "(nheadqkv headdim) ... -> nheadqkv headdim ...", - nheadqkv=rank_n_head + 2 * rank_n_head_kv, - headdim=headdim, - ) - for s, rank_n_head, rank_n_head_kv in zip( - state_dicts, n_head_each_rank, n_head_kv_each_rank - ) - ] - wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0) - wk = torch.cat( - [ - x[ - n_head_each_rank[rank] : n_head_each_rank[rank] - + n_head_kv_each_rank[rank] - ] - for rank, x in enumerate(xs) - ], - dim=0, - ) - wv = torch.cat( - [ - x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :] - for rank, x in enumerate(xs) - ], - dim=0, - ) - wqkv = torch.cat( - [wq, wk, wv], - dim=0, - ) - state_dict[key] = rearrange( - wqkv, - "nheadqkv headdim ... -> (nheadqkv headdim) ...", - ) - - def combine_gated_mlp(state_dicts, state_dict, key): - if key in state_dict: - xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts] - state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... 
-> (two d) ...") - - state_dict = state_dicts[0].copy() # don't modify state_dict[0] inplace - combine_word_embeddings( - state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight" - ) - if "lm_head.weight" in state_dict: - combine_word_embeddings(state_dicts, state_dict, "lm_head.weight") - if "transformer.embeddings.position_embeddings.weight" in state_dict: - combine_dim( - state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1 - ) - mlp_combine_fn = ( - combine_gated_mlp - if config.activation_function in ["glu", "swiglu", "geglu"] - else partial(combine_dim, dim=0) - ) - for i in range(config.num_hidden_layers): - combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight") - combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias") - combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1) - mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight") - combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias", 0) - combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1) - return state_dict - - -def remap_state_dict_hf_gpt2(state_dict, config): - # Word embedding and position embedding - def key_mapping_pos_emb(key): - return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key) - - state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop("wte.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) - ) - state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] - - # LayerNorm - def key_mapping_ln(key): - key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key) - key = re.sub(r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - # MLP - for d in range(config.num_hidden_layers): - W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight") - state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t() - W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight") - state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t() - - def key_mapping_mlp(key): - key = re.sub(r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key) - key = re.sub(r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key) - return key - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # Attention - for d in range(config.num_hidden_layers): - state_dict.pop(f"h.{d}.attn.bias") # We don't store this bias - Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight") - state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t() - Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight") - state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t() - - def key_mapping_attn(key): - key = re.sub(r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key) - key = re.sub( - r"^h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key - ) - return key - - state_dict = 
OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - - return state_dict - - -def remap_state_dict_megatron(state_dict, config): - def key_mapping_transformer(key): - key = re.sub(r"^language_model.encoder.", "transformer.", key) - key = re.sub(r"^language_model.", "transformer.", key) - return key - - state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items()) - - # Word embedding and position embedding - def key_mapping_pos_emb(key): - return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key) - - state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = ( - math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple - ) - state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) - ) - state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] - - # LayerNorm - def key_mapping_ln(key): - key = re.sub(r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1", key) - key = re.sub( - r"^transformer.layers.(\d+).input_layernorm.(weight|bias)", - r"transformer.layers.\1.norm1.\2", - key, - ) - key = re.sub( - r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)", - r"transformer.layers.\1.norm2.\2", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - # MLP - def key_mapping_mlp(key): - key = re.sub( - r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)", - r"transformer.layers.\1.mlp.fc1.\2", - key, - ) - key = re.sub( - r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)", - r"transformer.layers.\1.mlp.fc2.\2", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # Attention - def key_mapping_attn(key): - key = re.sub( - r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq", - r"transformer.layers.\1.mixer.rotary_emb.inv_freq", - key, - ) - key = re.sub( - r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)", - r"transformer.layers.\1.mixer.Wqkv.\2", - key, - ) - key = re.sub( - r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)", - r"transformer.layers.\1.mixer.out_proj.\2", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim) - # while we store Wqkv as ((3 nheads headdim), hidden_dim) - headdim = config.hidden_size // config.num_attention_heads - for d in range(config.num_hidden_layers): - Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight") - state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange( - Wqkv, - "(nheads three headdim) ... 
-> (three nheads headdim) ...", - three=3, - headdim=headdim, - ) - bqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias") - state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = rearrange( - bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim - ) - - return state_dict diff --git a/vllm/thirdparty_files/flash_attn/models/gpt_neox.py b/vllm/thirdparty_files/flash_attn/models/gpt_neox.py deleted file mode 100644 index c38940441722..000000000000 --- a/vllm/thirdparty_files/flash_attn/models/gpt_neox.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -import math -import re -from collections import OrderedDict - -import torch -import torch.nn.functional as F -from einops import rearrange -from transformers import GPT2Config, GPTNeoXConfig - - -def remap_state_dict_hf_gpt_neox(state_dict, config): - def key_mapping_layers(key): - return re.sub(r"^gpt_neox.", "transformer.", key) - - state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) - # Word embedding - def key_mapping_emb(key): - return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key) - - state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) - ) - if getattr(config, "tie_word_embeddings", False): - state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] - else: - output_embeddings = state_dict.pop("embed_out.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. 
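
The vocab-padding idiom used here (and in every other remapping helper in these files) is easier to read in isolation. A minimal standalone sketch, assuming a GPT-2-sized table and pad_vocab_size_multiple=8 (both values made up for illustration):

import math
import torch
import torch.nn.functional as F

pad_vocab_size_multiple = 8                 # assumed value, for illustration
word_embeddings = torch.randn(50257, 768)   # assumed GPT-2-sized embedding table

# Round the vocab size up to the next multiple.
vocab_size = math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
# F.pad's tuple is ordered from the last dim backwards, so (0, 0, 0, n)
# leaves the hidden dim alone and appends n zero rows to the table.
padded = F.pad(word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]))
assert padded.shape == (50264, 768)
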
- state_dict["lm_head.weight"] = F.pad( - output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) - ) - - # LayerNorm - def key_mapping_ln(key): - key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key) - key = re.sub( - r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key - ) - key = re.sub( - r"^transformer.layers.(\d+).post_attention_layernorm.", - r"transformer.layers.\1.norm2.", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - # MLP - def key_mapping_mlp(key): - key = re.sub( - r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key - ) - key = re.sub( - r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key - ) - return key - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # Attention - for l in range(config.n_layer): - # We don't store these biases - state_dict.pop(f"transformer.layers.{l}.attention.bias") - state_dict.pop(f"transformer.layers.{l}.attention.masked_bias") - # We don't store these - state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None) - # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim) - # while we store Wqkv as ((3 nheads headdim), hidden_dim) - headdim = config.hidden_size // config.num_attention_heads - Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight") - state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange( - Wqkv, - "(nheads three headdim) ... -> (three nheads headdim) ...", - three=3, - headdim=headdim, - ) - bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias") - state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange( - bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim - ) - - def key_mapping_attn(key): - key = re.sub( - r"^transformer.layers.(\d+).attention.dense.", - r"transformer.layers.\1.mixer.out_proj.", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - - return state_dict - - -def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config: - assert gpt_neox_config.rotary_emb_base == 10000 - return GPT2Config( - vocab_size=gpt_neox_config.vocab_size, - n_positions=0, # No absolute position embedding - n_embd=gpt_neox_config.hidden_size, - n_layer=gpt_neox_config.num_hidden_layers, - n_head=gpt_neox_config.num_attention_heads, - n_inner=gpt_neox_config.intermediate_size, - activation_function=gpt_neox_config.hidden_act, - resid_pdrop=0.0, # No dropout - embd_pdrop=0.0, - attn_pdrop=0.0, - layer_norm_epsilon=gpt_neox_config.layer_norm_eps, - initializer_range=gpt_neox_config.initializer_range, - bos_token_id=gpt_neox_config.bos_token_id, - eos_token_id=gpt_neox_config.eos_token_id, - # These are new arguments not in the original GPT2Config - prenorm=True, - parallel_block=gpt_neox_config.use_parallel_residual, - parallel_block_tied_norm=False, - rotary_emb_fraction=gpt_neox_config.rotary_pct, - tie_word_embeddings=gpt_neox_config.tie_word_embeddings, - ) diff --git a/vllm/thirdparty_files/flash_attn/models/gptj.py b/vllm/thirdparty_files/flash_attn/models/gptj.py deleted file mode 100644 index ca2330d79ce5..000000000000 --- a/vllm/thirdparty_files/flash_attn/models/gptj.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) 2023, Tri Dao. 
- -import math -import re -from collections import OrderedDict - -import torch -import torch.nn.functional as F -from transformers import GPT2Config, GPTJConfig - - -def remap_state_dict_hf_gptj(state_dict, config): - def key_mapping_layers(key): - return re.sub(r"^transformer.h.", "transformer.layers.", key) - - state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) - # Word embedding - def key_mapping_emb(key): - return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key) - - state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) - ) - if getattr(config, "tie_word_embeddings"): - state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] - else: - output_embeddings = state_dict.pop("lm_head.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. - state_dict["lm_head.weight"] = F.pad( - output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) - ) - output_embeddings_bias = state_dict.pop("lm_head.bias") - state_dict["lm_head.bias"] = F.pad( - output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0]) - ) - - # LayerNorm - def key_mapping_ln(key): - return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key) - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - # MLP - def key_mapping_mlp(key): - key = re.sub( - r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key - ) - key = re.sub( - r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key - ) - return key - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # Attention - for l in range(config.n_layer): - Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight") - Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight") - Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight") - state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) - # We don't store these biases - state_dict.pop(f"transformer.layers.{l}.attn.bias") - state_dict.pop(f"transformer.layers.{l}.attn.masked_bias") - - def key_mapping_attn(key): - return re.sub( - r"^transformer.layers.(\d+).attn.out_proj.", - r"transformer.layers.\1.mixer.out_proj.", - key, - ) - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - - return state_dict - - -def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config: - headdim = gptj_config.n_embd // gptj_config.n_head - return GPT2Config( - vocab_size=gptj_config.vocab_size, - n_positions=0, # No absolute position embedding - n_embd=gptj_config.n_embd, - n_layer=gptj_config.n_layer, - n_head=gptj_config.n_head, - n_inner=gptj_config.n_inner, - activation_function=gptj_config.activation_function, - resid_pdrop=gptj_config.resid_pdrop, - embd_pdrop=gptj_config.embd_pdrop, - attn_pdrop=gptj_config.attn_pdrop, - layer_norm_epsilon=gptj_config.layer_norm_epsilon, - 
initializer_range=gptj_config.initializer_range, - bos_token_id=gptj_config.bos_token_id, - eos_token_id=gptj_config.eos_token_id, - # These are new arguments not in the original GPT2Config - prenorm=True, - parallel_block=True, - parallel_block_tied_norm=True, - rotary_emb_fraction=gptj_config.rotary_dim / headdim, - rotary_emb_interleaved=True, - tie_word_embeddings=False, - qkv_proj_bias=False, - out_proj_bias=False, - lm_head_bias=True, - ) diff --git a/vllm/thirdparty_files/flash_attn/models/llama.py b/vllm/thirdparty_files/flash_attn/models/llama.py deleted file mode 100644 index 3bfb51d17e27..000000000000 --- a/vllm/thirdparty_files/flash_attn/models/llama.py +++ /dev/null @@ -1,422 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -import json -import math -import os -import re -from collections import OrderedDict -from pathlib import Path -from typing import Dict, List, Union - -import torch -import torch.nn.functional as F -from sentencepiece import SentencePieceProcessor -from transformers import GPT2Config, LlamaConfig - -from einops import rearrange - - -def remap_state_dict_meta_llama( - state_dict: Dict[str, torch.Tensor], config: GPT2Config -) -> Dict[str, torch.Tensor]: - """Convert the state_dict in Meta format to standard GPT format. - - This function modifies state_dict in place. - """ - - def key_mapping_layers(key): - return f"transformer.{key}" if not key.startswith("output.") else key - - state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) - - # Word embedding - def key_mapping_emb(key): - return re.sub( - r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key - ) - - state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = ( - math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple - ) - state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) - ) - if getattr(config, "tie_word_embeddings"): - state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] - else: - output_embeddings = state_dict.pop("output.weight") - # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings - # differently. - vocab_size = ( - math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) - * pad_vocab_size_multiple - ) - # It's possible that vocab_size is padded to be a multiple of 8, for example. 
- state_dict["lm_head.weight"] = F.pad( - output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) - ) - - # LayerNorm - def key_mapping_ln(key): - key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key) - key = re.sub( - r"^transformer.layers.(\d+).attention_norm.", - r"transformer.layers.\1.norm1.", - key, - ) - key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - # MLP - for l in range(config.n_layer): - w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight") - w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight") - # Our ordering is different - state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0) - - def key_mapping_mlp(key): - return re.sub( - r"^transformer.layers.(\d+).feed_forward.w2.", - r"transformer.layers.\1.mlp.fc2.", - key, - ) - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # Attention - for l in range(config.n_layer): - Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight") - Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight") - Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight") - state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) - # We don't store these - state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None) - - def key_mapping_attn(key): - return re.sub( - r"^transformer.layers.(\d+).attention.wo.", - r"transformer.layers.\1.mixer.out_proj.", - key, - ) - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - - state_dict.pop("transformer.rope.freqs", None) - - return state_dict - - -def remap_state_dict_hf_llama( - state_dict: Dict[str, torch.Tensor], config: GPT2Config -) -> Dict[str, torch.Tensor]: - """Convert the state_dict in Hugging Face format to standard GPT format. - - This function modifies state_dict in place. - """ - - # Embedding - def key_mapping_emb(key): - return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key) - - state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = ( - math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple - ) - state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) - ) - - # LM head - if getattr(config, "tie_word_embeddings"): - state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] - else: - output_embeddings = state_dict.pop("lm_head.weight") - # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings - # differently. - vocab_size = ( - math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) - * pad_vocab_size_multiple - ) - # It's possible that vocab_size is padded to be a multiple of 8, for example. 
- state_dict["lm_head.weight"] = F.pad( - output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) - ) - - # MLP - for l in range(config.n_layer): - # Fusing weights this way based on difference in the following: - # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220 - # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115 - w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight") - w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight") - state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0) - - def key_mapping_mlp(key): - return re.sub( - r"^model.layers.(\d+).mlp.down_proj.", - r"transformer.layers.\1.mlp.fc2.", - key, - ) - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # LayerNorm - def key_mapping_ln(key): - key = re.sub(r"^model.norm.", r"transformer.ln_f.", key) - key = re.sub( - r"^model.layers.(\d+).input_layernorm.", - r"transformer.layers.\1.norm1.", - key, - ) - key = re.sub( - r"^model.layers.(\d+).post_attention_layernorm.", - r"transformer.layers.\1.norm2.", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - def inv_permute(w): - # Inverse of permute implemented in: - # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114 - return rearrange( - w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2 - ) - - # Attention - for l in range(config.n_layer): - Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight") - Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight") - Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight") - - state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat( - [inv_permute(Wq), inv_permute(Wk), Wv], dim=0 - ) - # We don't store these - state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None) - - def key_mapping_attn(key): - return re.sub( - r"^model.layers.(\d+).self_attn.o_proj.", - r"transformer.layers.\1.mixer.out_proj.", - key, - ) - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - return state_dict - - -def inv_remap_state_dict_hf_llama( - state_dict: Dict[str, torch.Tensor], config: GPT2Config -) -> Dict[str, torch.Tensor]: - """Convert the state_dict in standard GPT format to Hugging Face format. - - This function is meant to be the inverse of remap_state_dict_hf_llama, up to a - multiplier pad in the embedding and lm_head. That is if the original embedding - isn't a multiple of pad_vocab_size_multiple, then - inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict. - - This function modifies state_dict in place. 
- """ - - # Embedding - def key_mapping_emb(key): - return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key) - - state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop("model.embed_tokens.weight") - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = ( - math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple - ) - state_dict["model.embed_tokens.weight"] = F.pad( - word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) - ) - - # LM head - if getattr(config, "tie_word_embeddings"): - state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"] - else: - output_embeddings = state_dict.pop("lm_head.weight") - vocab_size = ( - math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) - * pad_vocab_size_multiple - ) - state_dict["lm_head.weight"] = F.pad( - output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) - ) - - # MLP - for l in range(config.n_layer): - w3, w1 = torch.chunk( - state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0 - ) - state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1 - state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3 - - def key_mapping_mlp(key): - return re.sub( - r"^transformer.layers.(\d+).mlp.fc2.", - r"model.layers.\1.mlp.down_proj.", - key, - ) - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # LayerNorm - def key_mapping_ln(key): - key = re.sub(r"^transformer.ln_f.", r"model.norm.", key) - key = re.sub( - r"^transformer.layers.(\d+).norm1.", - r"model.layers.\1.input_layernorm.", - key, - ) - key = re.sub( - r"^transformer.layers.(\d+).norm2.", - r"model.layers.\1.post_attention_layernorm.", - key, - ) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - def permute(w): - return rearrange( - w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2 - ) - - n_head = config.n_head - n_head_kv = getattr(config, "n_head_kv", n_head) - - embed_dim = config.hidden_size - head_dim = embed_dim // n_head - - q_dim = n_head * head_dim - k_dim = v_dim = n_head_kv * head_dim - - # Attention - for l in range(config.n_layer): - Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight") - Wq = Wqkv[:q_dim] - Wk = Wqkv[q_dim : q_dim + k_dim] - Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim] - state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq) - state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk) - state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv - state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None) - - def key_mapping_attn(key): - return re.sub( - r"^transformer.layers.(\d+).mixer.out_proj.", - r"model.layers.\1.self_attn.o_proj.", - key, - ) - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - return state_dict - - -def config_from_meta_checkpoint( - checkpoint_path: Union[str, os.PathLike], model_name: str -) -> LlamaConfig: - """Load a LlamaConfig from a checkpoint path.""" - with open(Path(checkpoint_path) / model_name / "params.json") as f: - params = json.load(f) - config = LlamaConfig( - hidden_size=params["dim"], - intermediate_size=None, - num_attention_heads=params["n_heads"], - num_hidden_layers=params["n_layers"], - rms_norm_eps=params["norm_eps"], - num_key_value_heads=params.get("n_kv_heads", None), - ) - 
multiple_of = params.get("multiple_of", 1) - ffn_dim_multiplier = params.get("ffn_dim_multiplier", None) - - # Compute the hidden dimension of the MLP - # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224 - intermediate_size = 4 * config.hidden_size - # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199 - intermediate_size = int(2 * intermediate_size / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - intermediate_size = int(ffn_dim_multiplier * intermediate_size) - intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of) - - config.intermediate_size = intermediate_size - if "rope_theta" in params: - config.rotary_emb_base = params["rope_theta"] - config.vocab_size = 32000 - # some CodeLLaMa have vocab_size 32000, some 32016 - # Sadly it's not specified in the `params.json` file :( - tokenizer = Path(checkpoint_path) / model_name / "tokenizer.model" - if tokenizer.is_file(): - config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size() - return config - - -def config_from_hf_checkpoint( - checkpoint_path: Union[str, os.PathLike], model_name: str -) -> LlamaConfig: - return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json") - - -def config_from_checkpoint( - checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta" -) -> LlamaConfig: - if checkpoint_format == "meta": - return config_from_meta_checkpoint(checkpoint_path, model_name) - else: - return config_from_hf_checkpoint(checkpoint_path, model_name) - - -def state_dicts_from_checkpoint( - checkpoint_path: Union[str, os.PathLike], model_name: str -) -> List[dict]: - # Need to sort, otherwise we mess up the ordering and the weights are wrong - return [ - torch.load(path, map_location="cpu") - for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth")) - ] - - -def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config: - return GPT2Config( - vocab_size=llama_config.vocab_size, - n_positions=0, # No absolute position embedding - n_embd=llama_config.hidden_size, - n_layer=llama_config.num_hidden_layers, - n_head=llama_config.num_attention_heads, - n_inner=llama_config.intermediate_size, - activation_function="swiglu", # Hardcode since HF calls it 'silu' - # Llama doesn't have dropout, idk if it's because they only release the inference code - resid_pdrop=0.0, - embd_pdrop=0.0, - attn_pdrop=0.0, - layer_norm_epsilon=llama_config.rms_norm_eps, - initializer_range=llama_config.initializer_range, - bos_token_id=llama_config.bos_token_id, - eos_token_id=llama_config.eos_token_id, - # These are new arguments not in the original GPT2Config - pad_token_id=llama_config.pad_token_id, # Idk if this does anything - rms_norm=True, - rotary_emb_fraction=1.0, - rotary_emb_interleaved=True, - tie_word_embeddings=False, - qkv_proj_bias=False, - out_proj_bias=False, - mlp_fc1_bias=False, - mlp_fc2_bias=False, - rotary_emb_base=getattr(llama_config, "rotary_emb_base", 10000.0), - n_head_kv=llama_config.num_key_value_heads, - ) diff --git a/vllm/thirdparty_files/flash_attn/models/opt.py b/vllm/thirdparty_files/flash_attn/models/opt.py deleted file mode 100644 index 501f9eb6cf44..000000000000 --- a/vllm/thirdparty_files/flash_attn/models/opt.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) 2023, Tri Dao. 
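
The intermediate_size computation above is easy to sanity-check by hand. With the commonly cited LLaMA-7B values (assumed here for illustration, not read from any checkpoint), it reproduces the familiar MLP width of 11008:

# Assumed LLaMA-7B params.json values, for illustration only.
hidden_size = 4096
multiple_of = 256
ffn_dim_multiplier = None

intermediate_size = 4 * hidden_size                 # 16384
intermediate_size = int(2 * intermediate_size / 3)  # 10922
if ffn_dim_multiplier is not None:
    intermediate_size = int(ffn_dim_multiplier * intermediate_size)
intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
assert intermediate_size == 11008                   # rounded up to the next multiple of 256
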
- -import math -import re -from collections import OrderedDict - -import torch -import torch.nn.functional as F -from transformers import GPT2Config, OPTConfig - - -def remap_state_dict_hf_opt(state_dict, config): - def key_mapping_model(key): - key = re.sub(r"^model.decoder.", "transformer.", key) - # The OPT-350m model uses '^decoder' instead of '^model.decoder' - key = re.sub(r"^decoder.", "transformer.", key) - return key - - state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items()) - # Word embedding and position embedding - def key_mapping_emb(key): - key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key) - # The OPT-350m model uses has project_in and project_out - key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key) - key = re.sub(r"^transformer.project_out.", "project_out.", key) - key = re.sub( - r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key - ) - return key - - state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) - # OPT uses the first 2 indices of pos_emb for padding tokens - pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight") - state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:] - word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") - # It's possible that vocab_size is padded to be a multiple of 8, for example. - pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) - vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple - state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( - word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) - ) - state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] - - # LayerNorm - def key_mapping_ln(key): - key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key) - # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm' - key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key) - key = re.sub( - r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key - ) - key = re.sub( - r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key - ) - return key - - state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - - # MLP - def key_mapping_mlp(key): - return re.sub( - r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key - ) - - state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) - - # Attention - for l in range(config.n_layer): - Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight") - Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight") - Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight") - bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias") - bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias") - bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias") - state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) - state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0) - - def key_mapping_attn(key): - return re.sub( - r"^transformer.layers.(\d+).self_attn.out_proj.", - r"transformer.layers.\1.mixer.out_proj.", - key, - ) - - 
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - - return state_dict - - -def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config: - assert opt_config.layerdrop == 0.0 - assert opt_config.layer_norm_elementwise_affine - word_embed_proj_dim = ( - None - if opt_config.word_embed_proj_dim == opt_config.hidden_size - else opt_config.word_embed_proj_dim - ) - return GPT2Config( - vocab_size=opt_config.vocab_size, - n_positions=opt_config.max_position_embeddings, - n_embd=opt_config.hidden_size, - n_layer=opt_config.num_hidden_layers, - n_head=opt_config.num_attention_heads, - n_inner=opt_config.ffn_dim, - activation_function=opt_config.activation_function, - resid_pdrop=opt_config.dropout, - # HF's implementation of OPT doesn't seem to have embedding dropout - embd_pdrop=opt_config.dropout, - attn_pdrop=opt_config.attention_dropout, - initializer_range=opt_config.init_std, - bos_token_id=opt_config.bos_token_id, - eos_token_id=opt_config.eos_token_id, - # These are new arguments not in the original GPT2Config - prenorm=opt_config.do_layer_norm_before, - word_embed_proj_dim=word_embed_proj_dim, - ) diff --git a/vllm/thirdparty_files/flash_attn/models/vit.py b/vllm/thirdparty_files/flash_attn/models/vit.py deleted file mode 100644 index 4602fd7414d2..000000000000 --- a/vllm/thirdparty_files/flash_attn/models/vit.py +++ /dev/null @@ -1,373 +0,0 @@ -# Copyright (c) 2022, Tri Dao. -# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -import math -import re -from collections import OrderedDict -from copy import deepcopy -from functools import partial - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from timm.models.helpers import named_apply -from torch.nn.init import trunc_normal_ -from torchvision.ops import StochasticDepth - -from flash_attn.layers.patch_embed import PatchEmbed -from flash_attn.modules.block import Block -from flash_attn.modules.mha import MHA -from flash_attn.modules.mlp import FusedMLP, Mlp - -try: - from flash_attn.ops.triton.layer_norm import layer_norm_fn -except ImportError: - layer_norm_fn = None - - -def create_mixer_cls( - num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False -): - mixer_cls = partial( - MHA, - num_heads=num_heads, - cross_attn=cross_attn, - qkv_proj_bias=qkv_bias, - dropout=attn_drop, - fused_bias_fc=fused_bias_fc, - use_flash_attn=use_flash_attn, - ) - return mixer_cls - - -def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp): - inner_dim = int(embed_dim * mlp_ratio) - if not fused_mlp: - mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer()) - else: - mlp_cls = partial(FusedMLP, hidden_features=inner_dim) - return mlp_cls - - -def create_block( - embed_dim, - num_heads, - mlp_ratio, - qkv_bias, - drop_rate, - attn_drop_rate, - drop_path1, - drop_path2, - norm_layer, - act_layer, - use_flash_attn, - fused_bias_fc, - fused_mlp, - fused_dropout_add_ln, - layer_idx=None, - n_layer=None, - last_layer_subset=False, -): - mixer_cls = create_mixer_cls( - num_heads, - qkv_bias, - attn_drop_rate, - use_flash_attn, - fused_bias_fc, - cross_attn=(last_layer_subset and layer_idx == n_layer - 1), - ) - mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp) - # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed - block = Block( - embed_dim, - mixer_cls, - mlp_cls, - norm_cls=norm_layer, - prenorm=True, 
- resid_dropout1=drop_rate, - resid_dropout2=drop_rate, - drop_path1=drop_path1, - drop_path2=drop_path2, - fused_dropout_add_ln=fused_dropout_add_ln, - residual_in_fp32=True, - ) - return block - - -class VisionTransformer(nn.Module): - """Vision Transformer - A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - - https://arxiv.org/abs/2010.11929 - """ - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - global_pool="token", - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=True, - init_values=None, - class_token=True, - no_embed_class=False, - pre_norm=False, - fc_norm=None, - drop_rate=0.0, - attn_drop_rate=0.0, - drop_path_rate=0.0, - weight_init="", - embed_layer=PatchEmbed, - norm_layer=None, - act_layer=None, - use_flash_attn=False, - fused_bias_fc=False, - fused_mlp=False, - fused_dropout_add_ln=False, - ): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - num_classes (int): number of classes for classification head - global_pool (str): type of global pooling for final sequence (default: 'token') - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - init_values: (float): layer-scale init values - class_token (bool): use class token - fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) - drop_rate (float): dropout rate - attn_drop_rate (float): attention dropout rate - drop_path_rate (float): stochastic depth rate - weight_init (str): weight init scheme - embed_layer (nn.Module): patch embedding layer - norm_layer: (nn.Module): normalization layer - act_layer: (nn.Module): MLP activation layer - """ - super().__init__() - assert global_pool == "token", "Only support pooling with CLS token" - assert class_token - assert init_values is None, "LayerScale is not supported yet" - assert weight_init == "" - assert fc_norm is None - # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk - assert not pre_norm - use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - act_layer = act_layer or nn.GELU - - self.num_classes = num_classes - self.global_pool = global_pool - self.num_features = ( - self.embed_dim - ) = embed_dim # num_features for consistency with other models - self.num_prefix_tokens = 1 if class_token else 0 - self.no_embed_class = no_embed_class - - patch_embed_extra_kwargs = ( - {"fused_bias_fc": fused_bias_fc} if embed_layer is PatchEmbed else {} - ) - self.patch_embed = embed_layer( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) - **patch_embed_extra_kwargs, - ) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None - embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens - self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) - - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, depth) - ] # stochastic depth decay rule - - # We change the order of dropout, residual and layer norm: - # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: - # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and - # the main branch (output of MLP). The model definition is unchanged, but the mapping of the - # nn.Dropout probabilities are changed. - # This is for performance reason: we can fuse dropout + add + layer_norm. - self.blocks = nn.ModuleList( - [ - create_block( - embed_dim, - num_heads, - mlp_ratio, - qkv_bias, - drop_rate, - attn_drop_rate, - drop_path1=dpr[i - 1] if i > 0 else 0.0, - drop_path2=dpr[i], - norm_layer=norm_layer, - act_layer=act_layer, - use_flash_attn=use_flash_attn, - fused_bias_fc=fused_bias_fc, - fused_mlp=fused_mlp, - fused_dropout_add_ln=fused_dropout_add_ln, - layer_idx=i, - n_layer=depth, - last_layer_subset=(global_pool == "token"), - ) - for i in range(depth) - ] - ) - - self.dropout = nn.Dropout(p=drop_rate) - self.drop_path = StochasticDepth(p=dpr[-1], mode="row") - self.norm = norm_layer(embed_dim) - - self.fused_dropout_add_ln = fused_dropout_add_ln - if self.fused_dropout_add_ln and layer_norm_fn is None: - raise ImportError("Triton is not installed") - - # Classifier Head - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - - self.init_weights(weight_init) - - def init_weights(self, mode=""): - assert mode == "" - trunc_normal_(self.pos_embed, std=0.02) - if self.cls_token is not None: - nn.init.normal_(self.cls_token, std=1e-6) - named_apply(init_weights_vit_timm, self) - - def _init_weights(self, m): - # this fn left here for compat with downstream users - init_weights_vit_timm(m) - - @torch.jit.ignore - def no_weight_decay(self): - return {"pos_embed", "cls_token"} - - def _pos_embed(self, x): - if self.no_embed_class: - # deit-3, updated JAX (big vision) - # position embedding does not overlap with class token, add then concat - x = x + self.pos_embed - if self.cls_token is not None: - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - else: - # original timm, JAX, and deit vit impl - # pos_embed has entry for class token, concat then add - if self.cls_token is not None: - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = x + self.pos_embed - return x - - def forward_features(self, x, all_tokens=True): - """ - If all_tokens==False and self.global_pool == 'token', we only return the features for the - cls token. - """ - x = self.patch_embed(x) - hidden_states = self._pos_embed(x) - residual = None - if self.global_pool != "token" or all_tokens: - # if True: - for block in self.blocks: - hidden_states, residual = block(hidden_states, residual) - else: - for block in self.blocks[:-1]: - hidden_states, residual = block(hidden_states, residual) - # For the last layer, we only want the 1st token of the output. So we do cross-attention - # where the query is the 1st token and the key/value is the whole sequence. 
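
Why querying only the CLS token in the last block is safe: each attention output row depends only on its own query row, so restricting the query to the first token reproduces the first row of full self-attention. A toy single-head check (standalone sketch; sizes made up):

import torch

torch.manual_seed(0)
seqlen, dim = 5, 8                    # toy sizes
q = torch.randn(seqlen, dim)
k = torch.randn(seqlen, dim)
v = torch.randn(seqlen, dim)

def attn(q, k, v):
    scores = (q @ k.T) / dim ** 0.5
    return torch.softmax(scores, dim=-1) @ v

full = attn(q, k, v)          # attention output for every position
cls_only = attn(q[:1], k, v)  # query restricted to the first (CLS) token
assert torch.allclose(full[:1], cls_only)
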
- hidden_states, residual = self.blocks[-1]( - hidden_states, residual, mixer_subset=slice(0, 1) - ) - if not self.fused_dropout_add_ln: - residual = self.drop_path(self.dropout(hidden_states)) + residual - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - else: - if self.drop_path.p == 0 or not self.training: - rowscale = None - else: - rowscale = self.drop_path( - torch.ones( - hidden_states.shape[:-1], - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - ) - # Set prenorm=False here since we don't need to the residual - hidden_states = layer_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - eps=self.norm.eps, - dropout_p=self.dropout.p if self.training else 0.0, - rowscale=rowscale, - prenorm=False, - ) - return hidden_states - - def forward_head(self, x, pre_logits: bool = False): - if self.global_pool: - x = x[:, self.num_prefix_tokens :].mean(dim=1) if self.global_pool == "avg" else x[:, 0] - return x if pre_logits else self.head(x) - - def forward(self, x): - x = self.forward_features(x, all_tokens=False) - x = self.forward_head(x) - return x - - def load_state_dict(self, state_dict, strict=True): - patch_embed_weight = state_dict["patch_embed.proj.weight"] - if patch_embed_weight.dim() == 4: - # convert from Conv2d to Linear - state_dict["patch_embed.proj.weight"] = rearrange( - patch_embed_weight, "o c h w -> o (c h w)" - ) - - def key_mapping_attn(key): - key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key) - key = re.sub(r"^blocks.(\d+).attn.proj.", r"blocks.\1.mixer.out_proj.", key) - return key - - state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) - n_layer = len(self.blocks) - # Convert from Wqkv to Wq and Wkv for cross attention (last layer) - if ( - self.blocks[-1].mixer.cross_attn - and f"blocks.{n_layer - 1}.mixer.Wqkv.weight" in state_dict - ): - Wqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.weight") - bqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.bias") - state_dict[f"blocks.{n_layer - 1}.mixer.Wq.weight"] = Wqkv[: self.embed_dim] - state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.weight"] = Wqkv[self.embed_dim :] - state_dict[f"blocks.{n_layer - 1}.mixer.Wq.bias"] = bqkv[: self.embed_dim] - state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.bias"] = bqkv[self.embed_dim :] - return super().load_state_dict(state_dict, strict=strict) - - -def init_weights_vit_timm(module: nn.Module, name: str = ""): - """ViT weight initialization, original timm impl (for reproducibility)""" - if isinstance(module, nn.Linear): - trunc_normal_(module.weight, std=0.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif hasattr(module, "init_weights"): - module.init_weights() - - -def vit_base_patch16_224(pretrained=False, **kwargs): - """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 
- """ - assert not pretrained - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = VisionTransformer(**model_kwargs) - return model diff --git a/vllm/thirdparty_files/flash_attn/modules/__init__.py b/vllm/thirdparty_files/flash_attn/modules/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/thirdparty_files/flash_attn/modules/block.py b/vllm/thirdparty_files/flash_attn/modules/block.py deleted file mode 100644 index be8e8b864b60..000000000000 --- a/vllm/thirdparty_files/flash_attn/modules/block.py +++ /dev/null @@ -1,397 +0,0 @@ -# Copyright (c) 2024, Tri Dao. - -from functools import partial -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torchvision.ops import StochasticDepth - -from flash_attn.modules.mha import MHA -from flash_attn.modules.mlp import Mlp - -try: - from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm -except ImportError: - layer_norm_fn, RMSNorm = None, None - - -class Block(nn.Module): - def __init__( - self, - dim, - mixer_cls=None, - mlp_cls=None, - norm_cls=nn.LayerNorm, - dropout_cls=nn.Dropout, - prenorm=True, - resid_dropout1=0.0, - resid_dropout2=0.0, - drop_path1=0.0, - drop_path2=0.0, - fused_dropout_add_ln=False, - return_residual=False, - residual_in_fp32=False, - sequence_parallel=False, - mark_shared_params=False, - ): - """ - For prenorm=True, this Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both - the hidden_states (output of the MLP) and the residual. - This is for performance reasons, as we can fuse the dropout, add and LayerNorm. - The residual needs to be provided (except for the very first block). - - For prenorm=False, this Block has the same structure as a regular postnorm Transformer - block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN. - - return_residual: whether each of the sub-layers (mixer and mlp) will return the residual. - This is for performance reason: for post-norm architecture, returning the input allows us - to fuse the backward of nn.Linear with the residual connection. 
- """ - super().__init__() - self.prenorm = prenorm - self.fused_dropout_add_ln = fused_dropout_add_ln - self.return_residual = return_residual - self.residual_in_fp32 = residual_in_fp32 - if self.residual_in_fp32: - assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True" - if mixer_cls is None: - mixer_cls = partial(MHA, num_heads=dim // 64) - if mlp_cls is None: - mlp_cls = partial(Mlp, hidden_features=4 * dim) - self.mixer = mixer_cls(dim) - self.dropout1 = dropout_cls(resid_dropout1) - self.drop_path1 = StochasticDepth(drop_path1, mode="row") - self.norm1 = norm_cls(dim) - self.mlp = mlp_cls(dim) - if not isinstance(self.mlp, nn.Identity): - self.dropout2 = dropout_cls(resid_dropout2) - self.drop_path2 = StochasticDepth(drop_path2, mode="row") - self.norm2 = norm_cls(dim) - - if self.fused_dropout_add_ln: - assert layer_norm_fn is not None, "Triton is not installed" - assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance( - self.dropout1, nn.Dropout - ) - - # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, - # then the input to each worker in the tensor parallel group will be different. - # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. - # For now this is not an issue because we always use sequence_parallel=True during training - # and only use sequence_parallel=False during inference. - - # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. - if sequence_parallel: - for p in self.norm1.parameters(): - p._sequence_parallel = True - if hasattr(self, "norm2"): - for p in self.norm2.parameters(): - p._sequence_parallel = True - # Mark the norm parameters as "shared_params" so that we sync their values at init. - if mark_shared_params: - for p in self.norm1.parameters(): - p._shared_params = True - if hasattr(self, "norm2"): - for p in self.norm2.parameters(): - p._shared_params = True - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - - def forward( - self, - hidden_states: Tensor, - residual: Optional[Tensor] = None, - mixer_subset=None, - mixer_kwargs=None, - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual)) - mixer_subset: for cross-attention only. If not None, will take a subset of x - before applying the query projection. Useful for e.g., ViT where we only care - about the CLS token in the last layer. 
- """ - if self.prenorm: - if not self.fused_dropout_add_ln: - dropped = self.drop_path1(self.dropout1(hidden_states)) - residual = (dropped + residual) if residual is not None else dropped - hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - if self.drop_path1.p == 0 or not self.training: - rowscale1 = None - else: - rowscale1 = self.drop_path1( - torch.ones( - hidden_states.shape[:-1], - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - ) - hidden_states, residual = layer_norm_fn( - hidden_states, - self.norm1.weight, - self.norm1.bias, - residual=residual, - eps=self.norm1.eps, - dropout_p=self.dropout1.p if self.training else 0.0, - rowscale=rowscale1, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - is_rms_norm=isinstance(self.norm1, RMSNorm) - ) - if mixer_kwargs is None: - mixer_kwargs = {} - if mixer_subset is not None: - mixer_kwargs["mixer_subset"] = mixer_subset - hidden_states = self.mixer(hidden_states, **mixer_kwargs) - if mixer_subset is not None: - residual = residual[:, mixer_subset] - if not isinstance(self.mlp, nn.Identity): - if not self.fused_dropout_add_ln: - dropped = self.drop_path2(self.dropout2(hidden_states)) - residual = (dropped + residual) if residual is not None else dropped - hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - if self.drop_path2.p == 0 or not self.training: - rowscale2 = None - else: - rowscale2 = self.drop_path2( - torch.ones( - hidden_states.shape[:-1], - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - ) - hidden_states, residual = layer_norm_fn( - hidden_states, - self.norm2.weight, - self.norm2.bias, - residual=residual, - eps=self.norm2.eps, - dropout_p=self.dropout2.p if self.training else 0.0, - rowscale=rowscale2, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - is_rms_norm=isinstance(self.norm2, RMSNorm) - ) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - else: - assert residual is None - mixer_out = self.mixer( - hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {}) - ) - if self.return_residual: # mixer out is actually a pair here - mixer_out, hidden_states = mixer_out - if not self.fused_dropout_add_ln: - hidden_states = self.norm1( - (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to( - dtype=self.norm1.weight.dtype - ) - ) - else: - if self.drop_path1.p == 0 or not self.training: - rowscale1 = None - else: - rowscale1 = self.drop_path1( - torch.ones( - mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype - ) - ) - hidden_states = layer_norm_fn( - mixer_out, - self.norm1.weight, - self.norm1.bias, - residual=hidden_states, - eps=self.norm1.eps, - dropout_p=self.dropout1.p if self.training else 0.0, - rowscale=rowscale1, - prenorm=False, - is_rms_norm=isinstance(self.norm1, RMSNorm) - ) - if not isinstance(self.mlp, nn.Identity): - mlp_out = self.mlp(hidden_states) - if self.return_residual: # mlp out is actually a pair here - mlp_out, hidden_states = mlp_out - if not self.fused_dropout_add_ln: - hidden_states = self.norm2( - (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to( - dtype=self.norm2.weight.dtype - ) - ) - else: - if self.drop_path2.p == 0 or not self.training: - rowscale2 = None - else: - rowscale2 = self.drop_path2( - torch.ones( - mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype 
- ) - ) - hidden_states = layer_norm_fn( - mlp_out, - self.norm2.weight, - self.norm2.bias, - residual=hidden_states, - eps=self.norm2.eps, - dropout_p=self.dropout2.p if self.training else 0.0, - rowscale=rowscale2, - prenorm=False, - is_rms_norm=isinstance(self.norm2, RMSNorm) - ) - return hidden_states - - -class ParallelBlock(nn.Module): - """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX, - and PaLM. - """ - - def __init__( - self, - dim, - mixer_cls=None, - mlp_cls=None, - norm_cls=nn.LayerNorm, - dropout_cls=nn.Dropout, - resid_dropout1=0.0, - resid_dropout2=0.0, - tied_norm=False, - fused_dropout_add_ln=False, - residual_in_fp32=False, - sequence_parallel=False, - mark_shared_params=False, - ): - """ - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA / MLP -> Dropout -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both - the hidden_states (output1 of the MHA / MLP) and the residual. - This is for performance reasons, as we can fuse the dropout, add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.tied_norm = tied_norm - self.fused_dropout_add_ln = fused_dropout_add_ln - self.residual_in_fp32 = residual_in_fp32 - if mixer_cls is None: - mixer_cls = partial(MHA, num_heads=dim // 64) - if mlp_cls is None: - mlp_cls = partial(Mlp, hidden_features=4 * dim) - self.mixer = mixer_cls(dim) - self.dropout1 = dropout_cls(resid_dropout1) - self.norm1 = norm_cls(dim) - self.mlp = mlp_cls(dim) - self.dropout2 = dropout_cls(resid_dropout2) - if not self.tied_norm: - self.norm2 = norm_cls(dim) - - if self.fused_dropout_add_ln: - assert layer_norm_fn is not None, "Triton is not installed" - assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance( - self.dropout1, nn.Dropout - ) - - # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, - # then the input to each worker in the tensor parallel group will be different. - # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. - # For now this is not an issue because we always use sequence_parallel=True during training - # and only use sequence_parallel=False during inference. - - # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. - if sequence_parallel: - for p in self.norm1.parameters(): - p._sequence_parallel = True - if hasattr(self, "norm2"): - for p in self.norm2.parameters(): - p._sequence_parallel = True - # Mark the norm parameters as "shared_params" so that we sync their values at init. - if mark_shared_params: - for p in self.norm1.parameters(): - p._shared_params = True - if hasattr(self, "norm2"): - for p in self.norm2.parameters(): - p._shared_params = True - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - - def forward( - self, - hidden_states1: Tensor, - hidden_states2: Optional[Tensor] = None, - residual: Optional[Tensor] = None, - mixer_kwargs=None, - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states1: the output of the previous attention (mixer) or embedding layer. - hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1). - residual. 
- """ - # TODO: Ideally we should only do the allgather / allreduce once for - # the Linear to MLP & Attention - if not self.fused_dropout_add_ln: - dropped1 = self.dropout1(hidden_states1) - # For the very 1st block, we only want 1 dropout, not two different dropouts - if hidden_states2 is not None: - dropped2 = self.dropout2(hidden_states2) - residual = ( - (residual + dropped1 + dropped2) - if residual is not None - else dropped1 + dropped2 - ) - else: - residual = (residual + dropped1) if residual is not None else dropped1 - hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) - hidden_states2 = ( - self.norm2(residual.to(dtype=self.norm2.weight.dtype)) - if not self.tied_norm - else hidden_states1 - ) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - weight2, bias2 = ( - (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None) - ) - hidden_states1, *rest, residual = layer_norm_fn( - hidden_states1, - self.norm1.weight, - self.norm1.bias, - residual=residual, - x1=hidden_states2, - weight1=weight2, - bias1=bias2, - eps=self.norm1.eps, - dropout_p=self.dropout1.p if self.training else 0.0, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - is_rms_norm=isinstance(self.norm1, RMSNorm) - ) - if self.tied_norm: - hidden_states2 = hidden_states1 - else: - hidden_states2, = rest - if mixer_kwargs is None: - mixer_kwargs = {} - hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs) - hidden_states2 = self.mlp(hidden_states2) - return hidden_states1, hidden_states2, residual diff --git a/vllm/thirdparty_files/flash_attn/modules/embedding.py b/vllm/thirdparty_files/flash_attn/modules/embedding.py deleted file mode 100644 index 33587d09413d..000000000000 --- a/vllm/thirdparty_files/flash_attn/modules/embedding.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) 2022, Tri Dao. 
- -import torch -import torch.nn as nn -from einops import rearrange -from torch import Tensor - -from flash_attn.utils.distributed import all_reduce, reduce_scatter - - -class GPT2Embeddings(nn.Module): - def __init__( - self, - embed_dim, - vocab_size, - max_position_embeddings, - padding_idx=None, - word_embed_proj_dim=None, - device=None, - dtype=None, - ): - """ - If max_position_embeddings <= 0, there's no position embeddings - If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension - the project up to embed_dim - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - if word_embed_proj_dim is None: - self.word_embeddings = nn.Embedding( - vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs - ) - self.project_in = None - else: - self.word_embeddings = nn.Embedding( - vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs - ) - self.project_in = nn.Linear( - word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs - ) - self.max_position_embeddings = max_position_embeddings - if self.max_position_embeddings > 0: - self.position_embeddings = nn.Embedding( - max_position_embeddings, embed_dim, **factory_kwargs - ) - - def forward(self, input_ids, position_ids=None): - """ - input_ids: (batch, seqlen) - position_ids: (batch, seqlen) - """ - batch_size, seqlen = input_ids.shape - embeddings = self.word_embeddings(input_ids) - if self.project_in is not None: - embeddings = self.project_in(embeddings) - if self.max_position_embeddings > 0: - if position_ids is None: - position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) - position_embeddings = self.position_embeddings(position_ids) - embeddings = embeddings + position_embeddings - return embeddings - - -class BertEmbeddings(nn.Module): - def __init__( - self, - embed_dim, - vocab_size, - max_position_embeddings, - type_vocab_size, - padding_idx=None, - device=None, - dtype=None, - ): - """ - If max_position_embeddings <= 0, there's no position embeddings - If type_vocab_size <= 0, there's no token type embeddings - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.word_embeddings = nn.Embedding( - vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs - ) - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - if self.max_position_embeddings > 0: - self.position_embeddings = nn.Embedding( - max_position_embeddings, embed_dim, **factory_kwargs - ) - if self.type_vocab_size > 0: - self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs) - - def forward(self, input_ids, position_ids=None, token_type_ids=None): - """ - input_ids: (batch, seqlen) - position_ids: (batch, seqlen) - token_type_ids: (batch, seqlen) - """ - batch_size, seqlen = input_ids.shape - embeddings = self.word_embeddings(input_ids) - if self.max_position_embeddings > 0: - if position_ids is None: - position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) - position_embeddings = self.position_embeddings(position_ids) - embeddings = embeddings + position_embeddings - if self.type_vocab_size > 0: - if token_type_ids is None: - token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = embeddings + token_type_embeddings - return embeddings - - -class VocabParallelEmbedding(nn.Embedding): - def __init__(self, num_embeddings, 
*args, process_group=None, padding_idx=None, **kwargs): - self.process_group = process_group - if process_group is not None: - world_size = torch.distributed.get_world_size(process_group) - if num_embeddings % world_size != 0: - raise ValueError( - f"num_embeddings ({num_embeddings}) must be divisible by " - f"world_size ({world_size})" - ) - if world_size > 1 and padding_idx is not None: - raise RuntimeError("ParallelEmbedding does not support padding_idx") - else: - world_size = 1 - super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs) - - def forward(self, input: Tensor) -> Tensor: - if self.process_group is None: - return super().forward(input) - else: - rank = torch.distributed.get_rank(self.process_group) - vocab_size = self.num_embeddings - vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size - # Create a mask of valid vocab ids (1 means it needs to be masked). - input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index) - input = input - vocab_start_index - input[input_ids_mask] = 0 - embeddings = super().forward(input) - embeddings[input_ids_mask] = 0.0 - return embeddings - - -class ColumnParallelEmbedding(nn.Embedding): - def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs): - self.process_group = process_group - if process_group is not None: - world_size = torch.distributed.get_world_size(process_group) - if embedding_dim % world_size != 0: - raise ValueError( - f"embedding_dim ({embedding_dim}) must be divisible by " - f"world_size ({world_size})" - ) - else: - world_size = 1 - super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs) - - -class ParallelGPT2Embeddings(nn.Module): - def __init__( - self, - embed_dim, - vocab_size, - max_position_embeddings, - process_group, - padding_idx=None, - sequence_parallel=True, - device=None, - dtype=None, - ): - """ - If max_position_embeddings <= 0, there's no position embeddings - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.process_group = process_group - self.sequence_parallel = sequence_parallel - self.word_embeddings = VocabParallelEmbedding( - vocab_size, - embed_dim, - padding_idx=padding_idx, - process_group=process_group, - **factory_kwargs, - ) - self.max_position_embeddings = max_position_embeddings - if self.max_position_embeddings > 0: - self.position_embeddings = ColumnParallelEmbedding( - max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs - ) - - def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False): - """ - input_ids: (batch, seqlen) - position_ids: (batch, seqlen) - """ - batch_size, seqlen = input_ids.shape - world_size = torch.distributed.get_world_size(self.process_group) - embeddings = self.word_embeddings(input_ids) - if self.max_position_embeddings > 0: - if position_ids is None: - position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) - position_embeddings = self.position_embeddings(position_ids) - if world_size <= 1: - embeddings = embeddings + position_embeddings - else: - partition_dim = self.position_embeddings.embedding_dim - rank = torch.distributed.get_rank(self.process_group) - embeddings[ - ..., rank * partition_dim : (rank + 1) * partition_dim - ] += position_embeddings - if combine_batch_seqlen_dim: - embeddings = rearrange(embeddings, "b s d -> (b s) d") - reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce - return embeddings if 
world_size <= 1 else reduce_fn(embeddings, self.process_group) diff --git a/vllm/thirdparty_files/flash_attn/modules/mha.py b/vllm/thirdparty_files/flash_attn/modules/mha.py deleted file mode 100644 index 89c7680d5257..000000000000 --- a/vllm/thirdparty_files/flash_attn/modules/mha.py +++ /dev/null @@ -1,1016 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -import math -from functools import partial - -import torch -import torch.nn as nn -from einops import rearrange, repeat - -from flash_attn.utils.distributed import get_dim_for_local_rank - -try: - from flash_attn import ( - flash_attn_kvpacked_func, - flash_attn_qkvpacked_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_with_kvcache, - ) -except ImportError: - flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None - flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None - flash_attn_with_kvcache = None - -try: - from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear -except ImportError: - FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None - -try: - from flash_attn.layers.rotary import RotaryEmbedding -except ImportError: - RotaryEmbedding = None - - -# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742 -def get_alibi_slopes(nheads): - def get_slopes_power_of_2(nheads): - start = 2 ** (-(2 ** -(math.log2(nheads) - 3))) - ratio = start - return [start * ratio**i for i in range(nheads)] - - if math.log2(nheads).is_integer(): - return get_slopes_power_of_2(nheads) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(nheads)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2] - ) - - -class FlashSelfAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__( - self, - causal=False, - softmax_scale=None, - attention_dropout=0.0, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - ): - super().__init__() - assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed" - assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed" - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) - self.window_size = window_size - self.deterministic = deterministic - - def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): - """Implements the multihead softmax attention. - Arguments - --------- - qkv: The tensor containing the query, key, and value. - If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D). - If cu_seqlens is not None and max_seqlen is not None, then qkv has shape - (total, 3, H, D), where total is the sum of the sequence lengths in the batch. - causal: if passed, will override self.causal - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into qkv. - max_seqlen: int. Maximum sequence length in the batch. 
- Returns: - -------- - out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None, - else (B, S, H, D). - """ - assert qkv.dtype in [torch.float16, torch.bfloat16] - assert qkv.is_cuda - causal = self.causal if causal is None else causal - unpadded = cu_seqlens is not None - if unpadded: - assert cu_seqlens.dtype == torch.int32 - assert max_seqlen is not None - assert isinstance(max_seqlen, int) - return flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - self.drop.p if self.training else 0.0, - softmax_scale=self.softmax_scale, - causal=causal, - alibi_slopes=self.alibi_slopes, - window_size=self.window_size, - deterministic=self.deterministic, - ) - else: - return flash_attn_qkvpacked_func( - qkv, - self.drop.p if self.training else 0.0, - softmax_scale=self.softmax_scale, - causal=causal, - alibi_slopes=self.alibi_slopes, - window_size=self.window_size, - deterministic=self.deterministic, - ) - - -class FlashCrossAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__( - self, - causal=False, - softmax_scale=None, - attention_dropout=0.0, - alibi_slopes=None, - window_size=(-1, -1), - deterministic=False, - ): - super().__init__() - assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed" - assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed" - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) - self.window_size = window_size - self.deterministic = deterministic - - def forward( - self, - q, - kv, - causal=None, - cu_seqlens=None, - max_seqlen=None, - cu_seqlens_k=None, - max_seqlen_k=None, - ): - """Implements the multihead softmax attention. - Arguments - --------- - q: The tensor containing the query. (B, Sq, H, D) - kv: The tensor containing the key and value. (B, Sk, 2, H_k, D) - causal: if passed, will override self.causal - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - max_seqlen: int. Maximum sequence length in the batch of q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_k: int. Maximum sequence length in the batch of k and v. 
- """ - assert q.dtype in [torch.float16, torch.bfloat16] - assert q.is_cuda and kv.is_cuda - causal = self.causal if causal is None else causal - unpadded = cu_seqlens is not None - if unpadded: - assert cu_seqlens.dtype == torch.int32 - assert max_seqlen is not None - assert isinstance(max_seqlen, int) - assert cu_seqlens_k is not None - assert cu_seqlens_k.dtype == torch.int32 - assert max_seqlen_k is not None - assert isinstance(max_seqlen, int) - return flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens, - cu_seqlens_k, - max_seqlen, - max_seqlen_k, - self.drop.p if self.training else 0.0, - softmax_scale=self.softmax_scale, - causal=causal, - alibi_slopes=self.alibi_slopes, - window_size=self.window_size, - deterministic=self.deterministic, - ) - else: - batch_size, seqlen_q = q.shape[0], q.shape[1] - seqlen_k = kv.shape[1] - assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] - return flash_attn_kvpacked_func( - q, - kv, - self.drop.p if self.training else 0.0, - causal=causal, - softmax_scale=self.softmax_scale, - alibi_slopes=self.alibi_slopes, - window_size=self.window_size, - deterministic=self.deterministic, - ) - - -class SelfAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - def forward(self, qkv, causal=None, key_padding_mask=None): - """Implements the multihead softmax attention. - Arguments - --------- - qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. (B, S) - """ - batch_size, seqlen = qkv.shape[0], qkv.shape[1] - causal = self.causal if causal is None else causal - q, k, v = qkv.unbind(dim=2) - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) - if key_padding_mask is not None: - padding_mask = torch.full( - (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device - ) - padding_mask.masked_fill_(key_padding_mask, 0.0) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") - if causal: - # "triu_tril_cuda_template" not implemented for 'BFloat16' - # So we have to construct the mask in float - causal_mask = torch.triu( - torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1 - ) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + causal_mask.to(dtype=scores.dtype) - attention = torch.softmax(scores, dim=-1, dtype=v.dtype) - attention_drop = self.drop(attention) - output = torch.einsum("bhts,bshd->bthd", attention_drop, v) - return output - - -class CrossAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. 
- (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - def forward(self, q, kv, causal=None, key_padding_mask=None): - """Implements the multihead softmax attention. - Arguments - --------- - q: The tensor containing the query. (B, Sq, H, D) - kv: The tensor containing the key and value. (B, Sk, 2, H_k, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. (B, Sk) - """ - batch_size, seqlen_q = q.shape[0], q.shape[1] - causal = self.causal if causal is None else causal - seqlen_k = kv.shape[1] - assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] - if kv.shape[3] != q.shape[2]: # MQA/GQA - kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3]) - k, v = kv.unbind(dim=2) - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) - if key_padding_mask is not None: - padding_mask = torch.full( - (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device - ) - padding_mask.masked_fill_(key_padding_mask, 0.0) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") - if causal: - # causal mask needs to take into account the difference between seqlen_q and seqlen_k - row_idx = rearrange( - torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1" - ) - col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long) - sk = ( - seqlen_k - if key_padding_mask is None - else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") - ) - causal_mask = col_idx > row_idx + sk - seqlen_q - scores = scores.masked_fill(causal_mask, -10000.0) - attention = torch.softmax(scores, dim=-1, dtype=v.dtype) - attention_drop = self.drop(attention) - output = torch.einsum("bhts,bshd->bthd", attention_drop, v) - return output - - -class LinearResidual(nn.Linear): - """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.""" - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return super().forward(input), input - - -def _update_kv_cache(kv, inference_params, layer_idx): - """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" - # Pre-allocate memory for key-values for inference. - num_heads, head_dim = kv.shape[-2:] - if layer_idx not in inference_params.key_value_memory_dict: - kv_cache = torch.empty( - inference_params.max_batch_size, - inference_params.max_seqlen, - 2, - num_heads, - head_dim, - dtype=kv.dtype, - device=kv.device, - ) - inference_params.key_value_memory_dict[layer_idx] = kv_cache - else: - kv_cache = inference_params.key_value_memory_dict[layer_idx] - # Adjust key and value for inference - batch_start = inference_params.batch_size_offset - batch_end = batch_start + kv.shape[0] - sequence_start = inference_params.seqlen_offset - sequence_end = sequence_start + kv.shape[1] - assert batch_end <= kv_cache.shape[0] - assert sequence_end <= kv_cache.shape[1] - assert kv_cache is not None - kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] 
= kv - return kv_cache[batch_start:batch_end, :sequence_end, ...] - - -class MHA(nn.Module): - """Multi-head self-attention and cross-attention""" - - def __init__( - self, - embed_dim, - num_heads, - num_heads_kv=None, - cross_attn=False, - qkv_proj_bias=True, - out_proj_bias=True, - dropout=0.0, - softmax_scale=None, - causal=False, - layer_idx=None, - dwconv=False, - rotary_emb_dim=0, - rotary_emb_base=10000.0, - rotary_emb_scale_base=None, - rotary_emb_interleaved=False, - use_alibi=False, - window_size=(-1, -1), - fused_bias_fc=False, - use_flash_attn=False, - return_residual=False, - checkpointing=False, - device=None, - dtype=None, - ) -> None: - """ - num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads. - return_residual: whether to return the input x along with the output. This is for - performance reason: for post-norm architecture, returning the input allows us - to fuse the backward of nn.Linear with the residual connection. - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.embed_dim = embed_dim - self.cross_attn = cross_attn - self.causal = causal - self.layer_idx = layer_idx - self.dwconv = dwconv - self.rotary_emb_dim = rotary_emb_dim - self.use_flash_attn = use_flash_attn - self.return_residual = return_residual - self.checkpointing = checkpointing - if use_alibi: - assert use_flash_attn, "ALiBi code path requires flash_attn" - alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device) - else: - alibi_slopes = None - if window_size != (-1, -1): - assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn" - - self.num_heads = num_heads - self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads - assert ( - self.num_heads % self.num_heads_kv == 0 - ), "num_heads must be divisible by num_heads_kv" - assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" - self.head_dim = self.embed_dim // num_heads - qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) - kv_dim = 2 * self.head_dim * self.num_heads_kv - - if self.rotary_emb_dim > 0: - assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet" - assert RotaryEmbedding is not None, "rotary_emb is not installed" - self.rotary_emb = RotaryEmbedding( - self.rotary_emb_dim, - base=rotary_emb_base, - scale_base=rotary_emb_scale_base, - interleaved=rotary_emb_interleaved, - device=device, - ) - - if fused_bias_fc and FusedDense is None: - raise ImportError("fused_dense is not installed") - linear_cls = nn.Linear if not fused_bias_fc else FusedDense - linear_resid_cls = ( - LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True) - ) - wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls - inner_attn_cls = ( - partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) - if use_flash_attn - else SelfAttention - ) - inner_cross_attn_cls = ( - partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size) - if use_flash_attn - else CrossAttention - ) - if not self.cross_attn: - self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) - else: - self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs) - self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs) - if self.dwconv: - if self.num_heads_kv == self.num_heads: - self.dwconv_qkv = nn.Conv1d( - qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim - 
) - else: - self.dwconv_q = nn.Conv1d( - embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim - ) - self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim) - self.inner_attn = inner_attn_cls( - causal=causal, - softmax_scale=softmax_scale, - attention_dropout=dropout, - ) - self.inner_cross_attn = inner_cross_attn_cls( - causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout - ) - self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): - dtype = self.out_proj.weight.dtype if dtype is None else dtype - device = self.out_proj.weight.device - return torch.empty( - batch_size, - max_seqlen, - 2, - self.num_heads_kv, - self.head_dim, - dtype=dtype, - device=device, - ) - - def _update_kv_cache(self, kv, inference_params): - """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" - assert not self.dwconv, "Generation does not support dwconv yet" - assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" - return _update_kv_cache(kv, inference_params, self.layer_idx) - - def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): - """ - Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention. - q: (batch_size, seqlen_q, nheads, head_dim) - kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) - """ - assert inference_params is not None and inference_params.seqlen_offset > 0 - assert self.use_flash_attn - if self.rotary_emb_dim > 0: - assert self.rotary_emb.scale is None, "This code path does not support xPos" - self.rotary_emb._update_cos_sin_cache( - inference_params.max_seqlen, device=q.device, dtype=q.dtype - ) - rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached - else: - rotary_cos, rotary_sin = None, None - batch = q.shape[0] - kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] - cache_seqlens = ( - inference_params.lengths_per_sample[:batch] - if inference_params.lengths_per_sample is not None - else inference_params.seqlen_offset - ) - alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) - context = flash_attn_with_kvcache( - q, - kv_cache[:, :, 0], - kv_cache[:, :, 1], - kv[:, :, 0], - kv[:, :, 1], - rotary_cos=rotary_cos, - rotary_sin=rotary_sin, - cache_seqlens=cache_seqlens, - softmax_scale=self.inner_cross_attn.softmax_scale, - causal=self.inner_cross_attn.causal, - rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, - alibi_slopes=alibi_slopes, - ) - return context - - def _update_kvcache_attention(self, q, kv, inference_params): - """Write kv to inference_params, then do attention""" - if ( - inference_params.seqlen_offset == 0 - or flash_attn_with_kvcache is None - or not self.use_flash_attn - ): - # TODO: this only uses seqlen_offset and not lengths_per_sample. 
- kv = self._update_kv_cache(kv, inference_params) - return self.inner_cross_attn(q, kv) - else: - batch = q.shape[0] - kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] - cache_seqlens = ( - inference_params.lengths_per_sample[:batch] - if inference_params.lengths_per_sample is not None - else inference_params.seqlen_offset - ) - alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) - return flash_attn_with_kvcache( - q, - kv_cache[:, :, 0], - kv_cache[:, :, 1], - kv[:, :, 0], - kv[:, :, 1], - cache_seqlens=cache_seqlens, - softmax_scale=self.inner_cross_attn.softmax_scale, - causal=self.inner_cross_attn.causal, - alibi_slopes=alibi_slopes, - ) - - def forward( - self, - x, - x_kv=None, - key_padding_mask=None, - cu_seqlens=None, - max_seqlen=None, - mixer_subset=None, - inference_params=None, - **kwargs, - ): - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if - cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total - is the is the sum of the sequence lengths in the batch. - x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x. - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into x. Only applicable when using - FlashAttention. - max_seqlen: int. Maximum sequence length in the batch. - key_padding_mask: boolean mask, True means to keep, False means to mask out. - (batch, seqlen). Only applicable when not using FlashAttention. - mixer_subset: for cross-attention only. If not None, will take a subset of x - before applying the query projection. Useful for e.g., ViT where we only care - about the CLS token in the last layer. - inference_params: for generation. Adapted from Megatron-LM (and Apex) - https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 - """ - if cu_seqlens is not None: - assert max_seqlen is not None - assert key_padding_mask is None - assert self.use_flash_attn - assert not self.dwconv - assert self.rotary_emb_dim == 0 - if key_padding_mask is not None: - assert cu_seqlens is None - assert max_seqlen is None - assert not self.use_flash_attn - if inference_params is not None: - assert key_padding_mask is None - assert cu_seqlens is None and max_seqlen is None - assert not self.dwconv - - kwargs = ( - {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs} - if self.use_flash_attn - else {"key_padding_mask": key_padding_mask, **kwargs} - ) - seqlen_offset = ( - 0 - if inference_params is None - else ( - inference_params.lengths_per_sample - if inference_params.lengths_per_sample is not None - else inference_params.seqlen_offset - ) - ) - rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None - batch, seqlen = x.shape[:2] - if not self.cross_attn and self.num_heads_kv == self.num_heads: - assert x_kv is None and mixer_subset is None - if not self.return_residual: - qkv = self.Wqkv(x) - else: - qkv, x = self.Wqkv(x) - if self.dwconv: - qkv = rearrange( - self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" - ).contiguous() - qkv = rearrange(qkv, "... (three h d) -> ... 
three h d", three=3, d=self.head_dim) - if ( - inference_params is None - or inference_params.seqlen_offset == 0 - or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) - or not self.use_flash_attn - ): - if self.rotary_emb_dim > 0: - qkv = self.rotary_emb( - qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen - ) - if inference_params is None: - if not self.checkpointing: - context = self.inner_attn(qkv, **kwargs) - else: - context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) - else: - context = self._update_kvcache_attention( - qkv[:, :, 0], qkv[:, :, 1:], inference_params - ) - else: - context = self._apply_rotary_update_kvcache_attention( - qkv[:, :, 0], qkv[:, :, 1:], inference_params - ) - else: - if self.cross_attn: - if not self.return_residual: - q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) - kv = self.Wkv(x_kv if x_kv is not None else x) - else: - if x_kv is not None: - kv, x_kv = self.Wkv(x_kv) - else: - kv, x = self.Wkv(x) - q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) - else: - assert self.num_heads_kv != self.num_heads - if not self.return_residual: - qkv = self.Wqkv(x) - else: - qkv, x = self.Wqkv(x) - q = qkv[..., : self.num_heads * self.head_dim] - kv = qkv[..., self.num_heads * self.head_dim :] - q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) - kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim) - if self.dwconv: - q = rearrange( - self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d" - ).contiguous() - kv = rearrange( - self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" - ).contiguous() - if ( - inference_params is None - or inference_params.seqlen_offset == 0 - or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) - or not self.use_flash_attn - ): - if self.rotary_emb_dim > 0: - q, kv = self.rotary_emb( - q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen - ) - if inference_params is None: - if not self.checkpointing: - context = self.inner_cross_attn(q, kv, **kwargs) - else: - context = torch.utils.checkpoint.checkpoint( - self.inner_cross_attn, q, kv, **kwargs - ) - else: - context = self._update_kvcache_attention(q, kv, inference_params) - else: - context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params) - out = self.out_proj(rearrange(context, "... h d -> ... 
(h d)")) - return out if not self.return_residual else (out, x) - - -class ParallelMHA(nn.Module): - """Multi-head self-attention and cross-attention""" - - def __init__( - self, - embed_dim, - num_heads, - process_group, - num_heads_kv=None, - qkv_proj_bias=True, - out_proj_bias=True, - dropout=0.0, - softmax_scale=None, - causal=False, - layer_idx=None, - rotary_emb_dim=0, - rotary_emb_base=10000.0, - rotary_emb_scale_base=None, - rotary_emb_interleaved=False, - use_alibi=False, - window_size=(-1, -1), - use_flash_attn=False, - checkpointing=False, - sequence_parallel=True, - device=None, - dtype=None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.embed_dim = embed_dim - self.causal = causal - self.layer_idx = layer_idx - self.rotary_emb_dim = rotary_emb_dim - self.use_flash_attn = use_flash_attn - self.checkpointing = checkpointing - self.process_group = process_group - self.world_size = process_group.size() - self.local_rank = torch.distributed.get_rank(process_group) - - self.num_heads = num_heads - assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" - - self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads - assert ( - self.num_heads % self.num_heads_kv == 0 - ), "num_heads must be divisible by num_heads_kv" - - self.num_heads_per_rank = get_dim_for_local_rank( - self.num_heads, self.world_size, self.local_rank - ) - self.num_heads_kv_per_rank = get_dim_for_local_rank( - self.num_heads_kv, self.world_size, self.local_rank - ) - self.head_dim = self.embed_dim // num_heads - qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) - - if use_alibi: - assert use_flash_attn, "ALiBi code path requires flash_attn" - num_heads_local = math.ceil(self.num_heads / self.world_size) - alibi_slopes = torch.tensor( - get_alibi_slopes(num_heads)[ - self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local - ], - device=device, - ) - else: - alibi_slopes = None - if window_size != (-1, -1): - assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn" - - if self.rotary_emb_dim > 0: - assert RotaryEmbedding is not None, "rotary_emb is not installed" - self.rotary_emb = RotaryEmbedding( - self.rotary_emb_dim, - base=rotary_emb_base, - scale_base=rotary_emb_scale_base, - interleaved=rotary_emb_interleaved, - device=device, - ) - - if ColumnParallelLinear is None or RowParallelLinear is None: - raise ImportError("fused_dense is not installed") - self.Wqkv = ColumnParallelLinear( - embed_dim, - qkv_dim, - process_group, - bias=qkv_proj_bias, - sequence_parallel=sequence_parallel, - multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2), - **factory_kwargs, - ) - inner_attn_cls = ( - partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) - if use_flash_attn - else SelfAttention - ) - inner_cross_attn_cls = ( - partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size) - if use_flash_attn - else CrossAttention - ) - self.inner_attn = inner_attn_cls( - causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout - ) - self.inner_cross_attn = inner_cross_attn_cls( - causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout - ) - self.out_proj = RowParallelLinear( - embed_dim, - embed_dim, - process_group, - bias=out_proj_bias, - sequence_parallel=sequence_parallel, - multiple_of=self.head_dim, - **factory_kwargs, - ) - - def allocate_inference_cache(self, 
batch_size, max_seqlen, dtype=None): - dtype = self.out_proj.weight.dtype if dtype is None else dtype - device = self.out_proj.weight.device - return torch.empty( - batch_size, - max_seqlen, - 2, - self.num_heads_kv_per_rank, - self.head_dim, - dtype=dtype, - device=device, - ) - - def _update_kv_cache(self, kv, inference_params): - """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" - assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" - return _update_kv_cache(kv, inference_params, self.layer_idx) - - def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): - """ - Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention. - q: (batch_size, seqlen_q, nheads, head_dim) - kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) - """ - assert inference_params is not None and inference_params.seqlen_offset > 0 - assert self.use_flash_attn - if self.rotary_emb_dim > 0: - assert self.rotary_emb.scale is None, "This code path does not support xPos" - self.rotary_emb._update_cos_sin_cache( - inference_params.max_seqlen, device=q.device, dtype=q.dtype - ) - rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached - else: - rotary_cos, rotary_sin = None, None - batch = q.shape[0] - kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] - cache_seqlens = ( - inference_params.lengths_per_sample[:batch] - if inference_params.lengths_per_sample is not None - else inference_params.seqlen_offset - ) - alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) - context = flash_attn_with_kvcache( - q, - kv_cache[:, :, 0], - kv_cache[:, :, 1], - kv[:, :, 0], - kv[:, :, 1], - rotary_cos=rotary_cos, - rotary_sin=rotary_sin, - cache_seqlens=cache_seqlens, - softmax_scale=self.inner_cross_attn.softmax_scale, - causal=self.inner_cross_attn.causal, - rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, - alibi_slopes=alibi_slopes, - ) - return context - - def _update_kvcache_attention(self, q, kv, inference_params): - """Write kv to inference_params, then do attention""" - if inference_params.seqlen_offset == 0 or not self.use_flash_attn: - # TODO: this only uses seqlen_offset and not lengths_per_sample. - kv = self._update_kv_cache(kv, inference_params) - return self.inner_cross_attn(q, kv) - else: - batch = q.shape[0] - kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] - cache_seqlens = ( - inference_params.lengths_per_sample[:batch] - if inference_params.lengths_per_sample is not None - else inference_params.seqlen_offset - ) - alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) - context = flash_attn_with_kvcache( - q, - kv_cache[:, :, 0], - kv_cache[:, :, 1], - kv[:, :, 0], - kv[:, :, 1], - cache_seqlens=cache_seqlens, - softmax_scale=self.inner_cross_attn.softmax_scale, - causal=self.inner_cross_attn.causal, - alibi_slopes=alibi_slopes, - ) - return context - - def forward(self, x, seqlen=None, inference_params=None, **kwargs): - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). - """ - qkv = self.Wqkv(x) - if seqlen is not None: - qkv = rearrange(qkv, "(b s) ... 
-> b s ...", s=seqlen) - seqlen_offset = ( - 0 - if inference_params is None - else ( - inference_params.lengths_per_sample - if inference_params.lengths_per_sample is not None - else inference_params.seqlen_offset - ) - ) - rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None - if self.num_heads_kv == self.num_heads: - qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) - if ( - inference_params is None - or inference_params.seqlen_offset == 0 - or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) - or not self.use_flash_attn - ): - if self.rotary_emb_dim > 0: - qkv = self.rotary_emb( - qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen - ) - if inference_params is None: - if not self.checkpointing: - context = self.inner_attn(qkv, **kwargs) - else: - context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) - else: - context = self._update_kvcache_attention( - qkv[:, :, 0], qkv[:, :, 1:], inference_params - ) - else: - context = self._apply_rotary_update_kvcache_attention( - qkv[:, :, 0], qkv[:, :, 1:], inference_params - ) - else: - q = rearrange( - qkv[..., : self.num_heads_per_rank * self.head_dim], - "... (h d) -> ... h d", - d=self.head_dim, - ) - kv = rearrange( - qkv[..., self.num_heads_per_rank * self.head_dim :], - "... (two hkv d) -> ... two hkv d", - two=2, - d=self.head_dim, - ) - if ( - inference_params is None - or inference_params.seqlen_offset == 0 - or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) - or not self.use_flash_attn - ): - if self.rotary_emb_dim > 0: - q, kv = self.rotary_emb( - q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen - ) - if inference_params is None: - if not self.checkpointing: - context = self.inner_cross_attn(q, kv, **kwargs) - else: - context = torch.utils.checkpoint.checkpoint( - self.inner_cross_attn, q, kv, **kwargs - ) - else: - context = self._update_kvcache_attention(q, kv, inference_params) - else: - context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params) - context = rearrange(context, "b s h d -> b s (h d)") - if seqlen is not None: - context = rearrange(context, "b s d -> (b s) d") - out = self.out_proj(context) - return out diff --git a/vllm/thirdparty_files/flash_attn/modules/mlp.py b/vllm/thirdparty_files/flash_attn/modules/mlp.py deleted file mode 100644 index 23584d3098a2..000000000000 --- a/vllm/thirdparty_files/flash_attn/modules/mlp.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright (c) 2023, Tri Dao. 
- -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.distributed import ProcessGroup - - -try: - from flash_attn.ops.activations import swiglu -except ImportError: - swiglu = None - -try: - from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear -except ImportError: - ColumnParallelLinear, RowParallelLinear = None, None - -try: - from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP -except ImportError: - FusedMLP, ParallelFusedMLP = None, None - - -class Mlp(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - activation=F.gelu, - bias1=True, - bias2=True, - return_residual=False, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features if out_features is not None else in_features - hidden_features = hidden_features if hidden_features is not None else in_features * 4 - self.return_residual = return_residual - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) - self.activation = activation - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) - - def forward(self, x): - y = self.fc1(x) - y = self.activation(y) - y = self.fc2(y) - return y if not self.return_residual else (y, x) - - -class ParallelMLP(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - activation=F.gelu, - process_group: ProcessGroup = None, - sequence_parallel=True, - bias1=True, - bias2=True, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - assert ColumnParallelLinear is not None, "Need to install fused_dense" - assert RowParallelLinear is not None, "Need to install fused_dense" - out_features = out_features if out_features is not None else in_features - hidden_features = hidden_features if hidden_features is not None else in_features * 4 - self.fc1 = ColumnParallelLinear( - in_features, - hidden_features, - process_group, - bias=bias1, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - self.activation = activation - self.fc2 = RowParallelLinear( - hidden_features, - out_features, - process_group, - bias=bias2, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - - def forward(self, x): - y = self.fc1(x) - y = self.activation(y) - y = self.fc2(y) - return y - - -class GatedMlp(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - activation=F.sigmoid, - bias1=True, - bias2=True, - multiple_of=128, - return_residual=False, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features if out_features is not None else in_features - hidden_features = ( - hidden_features if hidden_features is not None else int(8 * in_features / 3) - ) - hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of - self.return_residual = return_residual - self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs) - self.activation = activation - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) - - def forward(self, x): - y = self.fc1(x) - if self.activation == F.sigmoid: # Special case for GLU - y = F.glu(y, dim=-1) - elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU - y, gate = y.chunk(2, dim=-1) - y = swiglu(gate, y) - else: - y, 
gate = y.chunk(2, dim=-1) - y = y * self.activation(gate) - y = self.fc2(y) - return y if not self.return_residual else (y, x) - - -class ParallelGatedMlp(nn.Module): - """Parallel GatedMlp""" - - def __init__( - self, - in_features, - process_group, - hidden_features=None, - out_features=None, - activation=F.sigmoid, - bias1=True, - bias2=True, - multiple_of=128, - sequence_parallel=True, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features if out_features is not None else in_features - hidden_features = ( - hidden_features if hidden_features is not None else int(8 * in_features / 3) - ) - hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of - if ColumnParallelLinear is None or RowParallelLinear is None: - raise ImportError("fused_dense is not installed") - self.fc1 = ColumnParallelLinear( - in_features, - 2 * hidden_features, - process_group, - bias=bias1, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - self.activation = activation - self.fc2 = RowParallelLinear( - hidden_features, - out_features, - process_group, - bias=bias2, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - - def forward(self, x): - y = self.fc1(x) - if self.activation == F.sigmoid: # Special case for GLU - y = F.glu(y, dim=-1) - else: - y, gate = y.chunk(2, dim=-1) - y = y * self.activation(gate) - y = self.fc2(y) - return y diff --git a/vllm/thirdparty_files/flash_attn/ops/__init__.py b/vllm/thirdparty_files/flash_attn/ops/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/thirdparty_files/flash_attn/ops/activations.py b/vllm/thirdparty_files/flash_attn/ops/activations.py deleted file mode 100644 index b00063b6bd49..000000000000 --- a/vllm/thirdparty_files/flash_attn/ops/activations.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -# 1/sqrt(2*pi)-> 0.3989423 -# 1/sqrt(2) -> 0.70710678 -# sqrt(2/pi) -> 0.79788456 - -# this function is tanh approximation of gelu -# actual gelu is: -# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) -@torch.jit.script -def bias_gelu(y, bias): - x = bias + y - return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype) - - -# gradient of tanh approximation of gelu -# gradient of actual gelu is: -# 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) -@torch.jit.script -def bias_gelu_back(g, y, bias): - """Assume that y has shape (B, D) and bias has shape (D)""" - x = bias + y - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) - grad_y = ff * g - return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype) - - -class GeLUFunction(torch.autograd.Function): - @staticmethod - # bias is an optional argument - def forward(ctx, input, bias): - ctx.save_for_backward(input, bias) - return bias_gelu(input, bias) - - @staticmethod - def backward(ctx, grad_output): - input, bias = ctx.saved_tensors - tmp = bias_gelu_back(grad_output, input, bias) - return tmp, tmp - - -bias_gelu_impl = GeLUFunction.apply - -# this function is tanh approximation of gelu -# actual gelu is: -# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) -@torch.jit.script -def gelu_fwd(x): - return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype) - - -# gradient of tanh approximation of gelu -# gradient of actual gelu is: -# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) -@torch.jit.script -def gelu_bwd(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) - return (ff * g).to(dtype=x.dtype) - - -class FastGeLUFunction(torch.autograd.Function): - @staticmethod - # bias is an optional argument - def forward(ctx, input): - ctx.save_for_backward(input) - return gelu_fwd(input) - - @staticmethod - def backward(ctx, grad_output): - (input,) = ctx.saved_tensors - tmp = gelu_bwd(grad_output, input) - return tmp - - -fast_gelu_impl = FastGeLUFunction.apply - - -@torch.jit.script -def relu_bwd(g, x): - return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype) - - -@torch.jit.script -def sqrelu_fwd(x): - r = F.relu(x) - return (r * r).to(dtype=x.dtype) - - -@torch.jit.script -def sqrelu_bwd(g, x): - return (2.0 * g * F.relu(x)).to(dtype=x.dtype) - - -swiglu_fwd_codestring = """ -template T swiglu_fwd(T x, T y) { - return float(x) * float(y) / (1.0f + ::exp(-float(x))); -} -""" -swiglu_bwd_codestring = """ -template T swiglu_bwd(T x, T y, T g, T& dx, T& dy) { - float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); - dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); - dy = float(x) * x_sigmoid * float(g); -} -""" -swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring) -swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2) - - -class SwiGLUFunction(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, y): - ctx.save_for_backward(x, y) - return swiglu_fwd(x, y) - - @staticmethod - def backward(ctx, dout): - x, y = ctx.saved_tensors - return swiglu_bwd(x, y, dout) - -swiglu = SwiGLUFunction.apply diff --git a/vllm/thirdparty_files/flash_attn/ops/fused_dense.py b/vllm/thirdparty_files/flash_attn/ops/fused_dense.py deleted file mode 100644 index 1e45b8e60981..000000000000 --- a/vllm/thirdparty_files/flash_attn/ops/fused_dense.py +++ /dev/null @@ -1,688 +0,0 @@ -# Copyright (c) 2023, Tri Dao. 
-# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py -# We make it work with pytorch amp and with bfloat16. -# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py -from functools import partial -from typing import Optional - -# import fused_dense_cuda # from apex -import fused_dense_lib as fused_dense_cuda -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd -from torch.distributed import ProcessGroup - -from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd -from flash_attn.utils.distributed import ( - all_gather_raw, - all_reduce, - all_reduce_raw, - reduce_scatter, - reduce_scatter_raw, -) - - -class FusedDenseFunc(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True - ): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather_raw of x before doing the matmul. - """ - ctx.compute_weight_gradient = weight.requires_grad - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - - if torch.is_autocast_enabled(): - weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) - bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None - weight = weight.contiguous() - if process_group is not None and sequence_parallel: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *weight.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - output = F.linear(total_x, weight, bias) - if ctx.compute_weight_gradient: - ctx.save_for_backward(x, weight) - else: - ctx.save_for_backward(weight) - return output if not return_residual else (output, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = ctx.sequence_parallel - if ctx.compute_weight_gradient: - x, weight = ctx.saved_tensors - if process_group is not None and sequence_parallel: - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - else: - (weight,) = ctx.saved_tensors - total_x = None - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_output, weight.t()) - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight - ) - grad_input = 
grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if process_group is not None: - reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw - grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) - else: - grad_input = None - if ctx.needs_input_grad[1]: - assert ctx.compute_weight_gradient - if process_group is not None and sequence_parallel: - handle_x.wait() - grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] - ) - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None - if process_group is not None and ctx.needs_input_grad[0]: - handle_grad_input.wait() - return grad_input, grad_weight, grad_bias, None, None, None - - -def fused_dense_func( - x: Tensor, - weight: Tensor, - bias: Optional[Tensor] = None, - return_residual: bool = False, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, -): - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FusedDenseFunc.apply( - x, weight, bias, return_residual, process_group, sequence_parallel - ) - else: - assert process_group is None - out = F.linear(x, weight, bias) - return out if not return_residual else (out, x) - - -class FusedDense(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - return_residual: bool = False, - device=None, - dtype=None, - ) -> None: - super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - self.return_residual = return_residual - - def forward(self, x, process_group=None): - """ - If process_group is not None, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - """ - return fused_dense_func( - x, - self.weight, - self.bias, - return_residual=self.return_residual, - process_group=process_group, - ) - - -class ColumnParallelLinear(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - process_group: ProcessGroup, - bias: bool = True, - sequence_parallel=True, - multiple_of=1, - device=None, - dtype=None, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - if out_features % multiple_of: - raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") - multiple = out_features // multiple_of - # We want to split @multiple across world_size, but it could be an uneven split - div = multiple // world_size - mod = multiple % world_size - # The first @mod ranks get @div + 1 copies, the rest get @div copies - local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) - super().__init__( - in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype - ) - self.process_group = process_group - self.sequence_parallel = sequence_parallel - - def forward(self, x): - # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - # we do an all_gather of x before doing the matmul. - # If not, then the input is already gathered. 
- return fused_dense_func( - x, - self.weight, - self.bias, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - ) - - -class RowParallelLinear(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - process_group: ProcessGroup, - bias: bool = True, - sequence_parallel=True, - multiple_of=1, - device=None, - dtype=None, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - rank = torch.distributed.get_rank(process_group) - if in_features % multiple_of: - raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") - multiple = in_features // multiple_of - # We want to split @multiple across world_size, but it could be an uneven split - div = multiple // world_size - mod = multiple % world_size - # The first @mod ranks get @div + 1 copies, the rest get @div copies - local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) - # Only rank 0 will have bias - super().__init__( - local_multiple * multiple_of, - out_features, - bias=bias and rank == 0, - device=device, - dtype=dtype, - ) - self.process_group = process_group - self.sequence_parallel = sequence_parallel - - def forward(self, x): - """ - We're doing Tensor Parallel with sequence parallelism: we do the matmul and then - a reduce_scatter of the result. - """ - out = fused_dense_func(x, self.weight, self.bias) - reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce - return reduce_fn(out, self.process_group) - - -class FusedMLPFunc(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, - x, - weight1, - bias1, - weight2, - bias2, - activation="gelu_approx", - save_pre_act=True, - return_residual=False, - checkpoint_lvl=0, - heuristic=0, - process_group=None, - sequence_parallel=True, - ): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather of x before doing the matmul. - If sequence_parallel=False, then the input is already gathered. 
- - checkpoint_lvl: - 0: no recomputation in the bwd - 1: recompute gelu_out / relu_out in the bwd - 2: recompute pre_act and gelu_out / relu_out in the bwd - """ - assert -1 <= heuristic <= 4 - assert activation in ["gelu_approx", "relu", "sqrelu"] - if activation == "sqrelu": - assert heuristic == -1 - if not save_pre_act: - checkpoint_lvl = 2 - assert checkpoint_lvl in [0, 1, 2] - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - ctx.checkpoint_lvl = checkpoint_lvl - ctx.activation = activation - ctx.heuristic = heuristic - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - else: - total_x = x - - if torch.is_autocast_enabled(): - dtype = torch.get_autocast_gpu_dtype() - weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]] - bias1 = bias1.to(dtype=dtype) if bias1 is not None else None - bias2 = bias2.to(dtype=dtype) if bias2 is not None else None - weight1 = weight1.contiguous() - bias1 = bias1.contiguous() if bias1 is not None else None - weight2 = weight2.contiguous() - bias2 = bias2.contiguous() if bias2 is not None else None - if process_group is not None and sequence_parallel: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - if heuristic == -1: - pre_act = F.linear(total_x, weight1, bias1) - activation_fn = ( - partial(F.gelu, approximate="tanh") - if activation == "gelu_approx" - else (sqrelu_fwd if activation == "sqrelu" else F.relu) - ) - with torch.jit.fuser("fuser2"): - output1 = activation_fn(pre_act) - # This is before adding bias1 - # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1) - # with torch.jit.fuser('fuser2'): - # output1 = bias_gelu(pre_act, bias1) - else: - is_gelu = activation == "gelu_approx" - output1, *rest = fused_dense_cuda.linear_act_forward( - total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic - ) - if save_pre_act: - pre_act = rest[0] - output2 = F.linear(output1, weight2, bias2) - if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): - # For RELU the pre_act is very small (just a bit-mask) so we just save it - ctx.save_for_backward(x, weight1, weight2, pre_act, output1) - elif checkpoint_lvl == 1: - ctx.save_for_backward(x, weight1, weight2, pre_act) - elif checkpoint_lvl == 2: - ctx.save_for_backward(x, weight1, weight2, bias1) - output2 = output2.reshape(*batch_shape, output2.shape[-1]) - return output2 if not return_residual else (output2, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - checkpoint_lvl = ctx.checkpoint_lvl - activation = ctx.activation - activation_fn = ( - partial(F.gelu, approximate="tanh") - if activation == "gelu_approx" - else (sqrelu_fwd if activation == "sqrelu" else F.relu) - ) - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = 
ctx.sequence_parallel - x, weight1, weight2, *rest = ctx.saved_tensors - if process_group is None or not sequence_parallel: - total_x = x - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - if checkpoint_lvl in [0, 1]: - if process_group is not None and sequence_parallel: - total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): - pre_act, output1 = rest - elif checkpoint_lvl == 1: - (pre_act,) = rest - with torch.jit.fuser("fuser2"): - output1 = activation_fn(pre_act) - elif checkpoint_lvl == 2: - (bias1,) = rest - if process_group is not None and sequence_parallel: - total_x, _ = all_gather_raw(x, process_group) - if ctx.heuristic == -1: - pre_act = F.linear(total_x, weight1, bias1) - with torch.jit.fuser("fuser2"): - output1 = activation_fn(pre_act) - else: - output1, pre_act = fused_dense_cuda.linear_act_forward( - total_x.reshape(batch_dim, total_x.shape[-1]), - weight1, - bias1, - activation == "gelu_approx", - True, - ctx.heuristic, - ) - - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - output1 = output1.reshape(batch_dim, output1.shape[-1]) - pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1]) - if ctx.needs_input_grad[3]: - grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad( - output1, grad_output, ctx.needs_input_grad[4] - ) - else: - grad_weight2 = None - grad_bias2 = grad_output if ctx.needs_input_grad[4] else None - if ctx.heuristic == -1: - # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act) - grad_output1 = F.linear(grad_output, weight2.t()) - activation_grad_fn = ( - gelu_bwd - if activation == "gelu_approx" - else (sqrelu_bwd if activation == "sqrelu" else relu_bwd) - ) - with torch.jit.fuser("fuser2"): - grad_pre_act = activation_grad_fn(grad_output1, pre_act) - else: - # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't - # just compute gelu/relu grad - grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad( - weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic - ) - if not ctx.needs_input_grad[2]: - grad_bias1 = None - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_pre_act, weight1.t()) - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1 - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if process_group is not None: - reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw - grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) - else: - grad_input = None - if ctx.heuristic == -1: - if ctx.needs_input_grad[1]: - if process_group is not None and sequence_parallel and checkpoint_lvl != 2: - handle_x.wait() - grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), - grad_pre_act, - ctx.needs_input_grad[2], - ) - else: - grad_weight1 = None - grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None - else: - if ctx.needs_input_grad[1]: - if process_group is not None and sequence_parallel and checkpoint_lvl != 2: - handle_x.wait() - grad_weight1 = F.linear( - grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t() - ) - else: - grad_weight1 = None - if process_group is not None and ctx.needs_input_grad[0]: - handle_grad_input.wait() - return ( - grad_input, - grad_weight1, - grad_bias1, - grad_weight2, - 
grad_bias2, - None, - None, - None, - None, - None, - None, - None, - ) - - -def fused_mlp_func( - x: Tensor, - weight1: Tensor, - weight2: Tensor, - bias1: Optional[Tensor] = None, - bias2: Optional[Tensor] = None, - activation: str = "gelu_approx", - save_pre_act: bool = True, - return_residual: bool = False, - checkpoint_lvl: int = 0, - heuristic: int = 0, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, -): - assert activation in ["gelu_approx", "relu", "sqrelu"] - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu) - dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0) - if ( - x.is_cuda - and weight1.is_cuda - and weight2.is_cuda - and (bias1 is None or bias1.is_cuda) - and (bias2 is None or bias2.is_cuda) - and dtype_eligible - and dim_eligible - ): - return FusedMLPFunc.apply( - x, - weight1, - bias1, - weight2, - bias2, - activation, - save_pre_act, - return_residual, - checkpoint_lvl, - heuristic, - process_group, - sequence_parallel, - ) - else: - assert process_group is None - pre_act = F.linear(x, weight1, bias1) - activation_fn = ( - partial(F.gelu, approximate="tanh") - if activation == "gelu_approx" - else partial(F.relu, inplace=True) - ) - output1 = activation_fn(pre_act) - output2 = F.linear(output1, weight2, bias2) - return output2 if not return_residual else (output2, x) - - -class FusedMLP(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - bias1=True, - bias2=True, - activation="gelu_approx", - return_residual=False, - checkpoint_lvl=0, - heuristic="auto", - device=None, - dtype=None, - ): - """ - If process_group is not None, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul, gelu, then matmul. - Finally we do a reduce_scatter of the output. - - checkpoint_lvl (increasing lvl means slower but more memory saving): - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute pre_act and gelu_out in the bwd - heuristic: - -1: don't fuse gemm + gelu (separate kernel) - 0..4: use this heuristic for the algo section in the fused gemm + gelu - 'auto': heuristic will be picked automatically: - For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. - For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. - For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation - is slower than the unfused version. - return_residual: whether to return the input x along with the output. This is for - performance reason: for post-norm architecture, returning the input allows us - to fuse the backward of nn.Linear with the residual connection. 
- """ - assert checkpoint_lvl in [0, 1, 2] - assert activation in ["gelu_approx", "relu", "sqrelu"] - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features * 4 - self.activation = activation - self.return_residual = return_residual - self.checkpoint_lvl = checkpoint_lvl - self.heuristic = heuristic if activation != "sqrelu" else -1 - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) - - def forward(self, x, process_group=None): - dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() - if self.heuristic == "auto": - if self.activation == "gelu_approx": - if torch.cuda.get_device_capability("cuda") == (9, 0): - heuristic = -1 - else: - cuda_ver = tuple(map(int, torch.version.cuda.split("."))) - heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) - else: - heuristic = 0 - else: - heuristic = self.heuristic - out = fused_mlp_func( - x, - self.fc1.weight, - self.fc2.weight, - self.fc1.bias, - self.fc2.bias, - activation=self.activation, - save_pre_act=self.training, - return_residual=self.return_residual, - checkpoint_lvl=self.checkpoint_lvl, - heuristic=heuristic, - process_group=process_group, - ) - if self.return_residual: - out, x = out - if process_group is not None: - out = reduce_scatter(out, process_group) - return out if not self.return_residual else (out, x) - - -class ParallelFusedMLP(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - activation="gelu_approx", - process_group: ProcessGroup = None, - bias1=True, - bias2=True, - sequence_parallel=True, - checkpoint_lvl=0, - heuristic="auto", - device=None, - dtype=None, - ): - """ - process_group is required. We're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul, gelu, then matmul. - Finally we do a reduce_scatter of the output. - - checkpoint_lvl (increasing lvl means slower but more memory saving): - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute pre_act and gelu_out in the bwd - heuristic: - -1: don't fuse gemm + gelu (separate kernel) - 0..4: use this heuristic for the algo section in the fused gemm + gelu - 'auto': heuristic will be picked automatically: - For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. - For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. 
- """ - assert checkpoint_lvl in [0, 1, 2] - assert activation in ["gelu_approx", "relu", "sqrelu"] - assert process_group is not None - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features * 4 - self.activation = activation - self.process_group = process_group - self.sequence_parallel = sequence_parallel - self.checkpoint_lvl = checkpoint_lvl - self.heuristic = heuristic if activation != "sqrelu" else -1 - self.fc1 = ColumnParallelLinear( - in_features, hidden_features, process_group, bias=bias1, **factory_kwargs - ) - self.fc2 = RowParallelLinear( - hidden_features, out_features, process_group, bias=bias2, **factory_kwargs - ) - - def forward(self, x): - dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() - if self.heuristic == "auto": - if self.activation == "gelu_approx": - cuda_ver = tuple(map(int, torch.version.cuda.split("."))) - heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) - else: - heuristic = 0 - else: - heuristic = self.heuristic - out = fused_mlp_func( - x, - self.fc1.weight, - self.fc2.weight, - self.fc1.bias, - self.fc2.bias, - activation=self.activation, - save_pre_act=self.training, - checkpoint_lvl=self.checkpoint_lvl, - heuristic=heuristic, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - ) - reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce - return reduce_fn(out, self.process_group) diff --git a/vllm/thirdparty_files/flash_attn/ops/layer_norm.py b/vllm/thirdparty_files/flash_attn/ops/layer_norm.py deleted file mode 100644 index 4b6cd798fd02..000000000000 --- a/vllm/thirdparty_files/flash_attn/ops/layer_norm.py +++ /dev/null @@ -1,800 +0,0 @@ -# Copyright (c) 2022, Tri Dao. 
-# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py - -import dropout_layer_norm -import torch -from torch.nn import init - - -def maybe_align(x, alignment_in_bytes=16): - """Assume that x already has last dim divisible by alignment_in_bytes""" - # TD [2023-07-04] I'm not 100% sure that clone will align the memory - # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 - return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() - - -def _dropout_add_layer_norm_forward( - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma.numel() - x0mat = x0.view((-1, hidden_size)) - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - rowscale = rowscale.view(-1) if rowscale is not None else None - zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, - residualmat, - gamma, - beta, - rowscale, - colscale, - None, - None, - dropout_p, - epsilon, - 1.0, - 0, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma - - -def _dropout_add_layer_norm_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - dropout_p, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). - x0 must not be None if we have colscale. 
- """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(xmat.shape) - dxmat = dx.view(xmat.shape) if dx is not None else None - x0mat = x0.view((-1, hidden_size)) if x0 is not None else None - rowscale = rowscale.view(-1) if rowscale is not None else None - if colscale is not None: - assert x0 is not None, "x0 is required to compute the gradient of colscale" - dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, - dxmat, - xmat, - x0mat, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - None, - None, - dropout_p, - 1.0, - 0, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - if colscale is None: - return dx0mat, dresidualmat, dgamma, dbeta - else: - dcolscale = rest[0] - return dx0mat, dresidualmat, dgamma, dbeta, dcolscale - - -def _dropout_add_layer_norm_subset_forward( - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma.numel() - x0mat = x0.view((-1, hidden_size)) - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - x0_subset = x0_subset.view(-1) if x0_subset is not None else None - out_subset = out_subset.view(-1) if out_subset is not None else None - zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, - residualmat, - gamma, - beta, - None, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask is None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma - - -def _dropout_add_layer_norm_subset_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - colscale, - x0_subset, - out_subset, - dropout_p, - rowscale_const, - x0_numrows, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). - x0 must not be None if we have colscale. 
- """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(-1, hidden_size) - dxmat = dx.view(xmat.shape) if dx is not None else None - x0mat = x0.view((-1, hidden_size)) if x0 is not None else None - x0_subset = x0_subset.view(-1) if x0_subset is not None else None - out_subset = out_subset.view(-1) if out_subset is not None else None - if colscale is not None: - assert x0 is not None, "x0 is required to compute the gradient of colscale" - dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, - dxmat, - xmat, - x0mat, - dmask, - mu, - rsigma, - gamma, - None, - colscale, - x0_subset, - out_subset, - dropout_p, - rowscale_const, - x0_numrows, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - if colscale is None: - return dx0mat, dresidualmat, dgamma, dbeta - else: - dcolscale = rest[0] - return dx0mat, dresidualmat, dgamma, dbeta, dcolscale - - -def _dropout_add_layer_norm_parallel_residual_forward( - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32=False, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes""" - hidden_size = gamma0.numel() - x0mat = x0.view((-1, hidden_size)) - x1mat = x1.view((-1, hidden_size)) if x1 is not None else None - residualmat = residual.view((-1, hidden_size)) if residual is not None else None - ( - z0mat, - z1mat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( - x0mat, - x1mat, - residualmat, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - None, - residual_in_fp32, - is_rms_norm, - ) - # dmask0 and dmask1 are None if dropout_p == 0.0 - # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype - return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma - - -def _dropout_add_layer_norm_parallel_residual_backward( - dz0, - dz1, - dx, - x, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - is_rms_norm=False, -): - """Assume that arguments are contiguous and aligned to 16 bytes - dx == None means that it was a post-norm architecture - (x = drop(x0) + residual was not returned in the fwd). 
- """ - hidden_size = gamma0.numel() - xmat = x.view((-1, hidden_size)) - dz0mat = dz0.view(xmat.shape) - dz1mat = dz1.view(xmat.shape) if dz1 is not None else None - dxmat = dx.view(xmat.shape) if dx is not None else None - ( - dx0mat, - dx1mat, - dresidualmat, - dgamma0, - dbeta0, - dgamma1, - dbeta1, - *rest, - ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( - dz0mat, - dz1mat, - dxmat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - is_rms_norm, - ) - # dresidualmat is None if not has_residual - return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 - - -class DropoutAddLayerNormFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma = maybe_align(gamma.contiguous(), 16) - beta = maybe_align(beta.contiguous(), 16) if beta is not None else None - rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None - colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, - residual, - gamma, - beta, - rowscale, - colscale, - dropout_p, - epsilon, - residual_in_fp32, - is_rms_norm, - ) - # Only need to save x0 if we need to compute gradient wrt colscale - x0_saved = x0 if colscale is not None else None - ctx.save_for_backward( - xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta is not None - if not return_dmask: - return ( - zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape)) - ) - else: - dmask = ( - dmask.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask) - return ( - (zmat.view(x0.shape), dmask) - if not prenorm - else (zmat.view(x0.shape), xmat.view(x0.shape), dmask) - ) - - @staticmethod - def backward(ctx, dz, *args): - # assert dz.is_contiguous() - dz = maybe_align(dz.contiguous(), 16) # this happens! 
- dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors - # x0 is None if colscale is None - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - rowscale, - colscale, - dropout_p, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(x.shape) - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - dcolscale = rest[0] if colscale is not None else None - return ( - dx0, - dresidual, - dgamma, - dbeta if ctx.has_beta else None, - None, - dcolscale, - None, - None, - None, - None, - None, - None, - ) - - -class DropoutAddLayerNormSubsetFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma = maybe_align(gamma.contiguous(), 16) - beta = maybe_align(beta.contiguous(), 16) if beta is not None else None - colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( - x0, - residual, - gamma, - beta, - colscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - is_rms_norm, - ) - # Only need to save x0 if we need to compute gradient wrt colscale - x0_saved = x0 if colscale is not None else None - x_shape = (-1, *x0.shape[1:]) - ctx.save_for_backward( - xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset - ) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.rowscale_const = rowscale_const - ctx.x0_numrows = x0.shape[:-1].numel() - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta is not None - z_shape = (-1, *x0.shape[1:]) - if not return_dmask: - return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape)) - else: - z = zmat.view(z_shape) - dmask = ( - dmask.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask) - return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask) - - @staticmethod - def backward(ctx, dz, *args): - # assert dz.is_contiguous() - dz = maybe_align(dz.contiguous(), 16) # this happens! 
- dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors - # x0 is None if colscale is None - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( - dz, - dx, - x, - x0, - dmask, - mu, - rsigma, - gamma, - colscale, - x0_subset, - out_subset, - dropout_p, - ctx.rowscale_const, - ctx.x0_numrows, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(-1, *x.shape[1:]) - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - dcolscale = rest[0] if colscale is not None else None - return ( - dx0, - dresidual, - dgamma, - dbeta if ctx.has_beta else None, - dcolscale, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32=False, - prenorm=False, - is_rms_norm=False, - return_dmask=False, - ): - x0 = maybe_align(x0.contiguous(), 16) - x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None - residual = maybe_align(residual.contiguous(), 16) if residual is not None else None - gamma0 = maybe_align(gamma0.contiguous(), 16) - beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None - gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None - beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None - ( - z0mat, - z1mat, - xmat, - dmask0, - dmask1, - mu, - rsigma, - ) = _dropout_add_layer_norm_parallel_residual_forward( - x0, - x1, - residual, - gamma0, - beta0, - gamma1, - beta1, - dropout_p, - epsilon, - residual_in_fp32, - is_rms_norm, - ) - ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) - ctx.prenorm = prenorm - ctx.dropout_p = dropout_p - ctx.has_x1 = x1 is not None - ctx.has_residual = residual is not None - ctx.is_rms_norm = is_rms_norm - ctx.has_beta = beta0 is not None - z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) - if not return_dmask: - return z if not prenorm else (*z, xmat.view(x0.shape)) - else: - dmask0 = ( - dmask0.view(x0.shape) - if dropout_p > 0.0 - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - dmask1 = ( - dmask1.view(x0.shape) - if dropout_p > 0.0 and x1 is not None - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) - ) - ctx.mark_non_differentiable(dmask0) - ctx.mark_non_differentiable(dmask1) - return ( - (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1) - ) - - @staticmethod - def backward(ctx, dz0, dz1, *args): - dz0 = maybe_align(dz0.contiguous(), 16) # this happens! 
- dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None - dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None - x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors - dropout_p = ctx.dropout_p - has_x1 = ctx.has_x1 - has_residual = ctx.has_residual - ( - dx0mat, - dx1mat, - dresidualmat, - dgamma0, - dbeta0, - dgamma1, - dbeta1, - ) = _dropout_add_layer_norm_parallel_residual_backward( - dz0, - dz1, - dx, - x, - dmask0, - dmask1, - mu, - rsigma, - gamma0, - gamma1, - dropout_p, - has_x1, - has_residual, - ctx.is_rms_norm, - ) - dx0 = dx0mat.view(x.shape) - dx1 = dx1mat.view(x.shape) if dx1mat is not None else None - dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None - return ( - dx0, - dx1, - dresidual, - dgamma0, - dbeta0 if ctx.has_beta else None, - dgamma1, - dbeta1 if ctx.has_beta else None, - None, - None, - None, - None, - None, - None, - ) - - -def layer_norm(x, weight, bias, epsilon): - return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False) - - -def dropout_add_layer_norm( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - rowscale=None, - layerscale=None, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormFn.apply( - x0, - residual, - weight, - bias, - rowscale, - layerscale, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -def dropout_add_layer_norm_subset( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - layerscale=None, - x0_subset=None, - out_subset=None, - rowscale_const=1.0, - out_numrows=0, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormSubsetFn.apply( - x0, - residual, - weight, - bias, - layerscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -def dropout_add_layer_norm_parallel_residual( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. 
- """ - return DropoutAddLayerNormParallelResidualFn.apply( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - False, - return_dropout_mask, - ) - - -class DropoutAddLayerNorm(torch.nn.Module): - def __init__( - self, - hidden_size, - prenorm=False, - p=0.0, - eps=1e-5, - residual_in_fp32=False, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.prenorm = prenorm - self.p = p - self.eps = eps - self.residual_in_fp32 = residual_in_fp32 - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - init.zeros_(self.bias) - - def forward(self, x0, residual=None): - return dropout_add_layer_norm( - x0, - residual, - self.weight, - self.bias, - self.p if self.training else 0.0, - self.eps, - prenorm=self.prenorm, - residual_in_fp32=self.residual_in_fp32, - ) diff --git a/vllm/thirdparty_files/flash_attn/ops/rms_norm.py b/vllm/thirdparty_files/flash_attn/ops/rms_norm.py deleted file mode 100644 index 068348d61290..000000000000 --- a/vllm/thirdparty_files/flash_attn/ops/rms_norm.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) 2022, Tri Dao. -# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py - -import torch -from torch.nn import init - -from flash_attn.ops.layer_norm import ( - DropoutAddLayerNormFn, - DropoutAddLayerNormParallelResidualFn, - DropoutAddLayerNormSubsetFn, -) - - -def rms_norm(x, weight, epsilon): - return DropoutAddLayerNormFn.apply( - x, None, weight, None, None, None, 0.0, epsilon, False, False, True - ) - - -def dropout_add_rms_norm( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - rowscale=None, - layerscale=None, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormFn.apply( - x0, - residual, - weight, - bias, - rowscale, - layerscale, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - True, - return_dropout_mask, - ) - - -def dropout_add_rms_norm_subset( - x0, - residual, - weight, - bias, - dropout_p, - epsilon, - layerscale=None, - x0_subset=None, - out_subset=None, - rowscale_const=1.0, - out_numrows=0, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. - """ - return DropoutAddLayerNormSubsetFn.apply( - x0, - residual, - weight, - bias, - layerscale, - x0_subset, - out_subset, - dropout_p, - epsilon, - rowscale_const, - out_numrows, - residual_in_fp32, - prenorm, - True, - return_dropout_mask, - ) - - -def dropout_add_rms_norm_parallel_residual( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - """residual_in_fp32 only has an effect if residual is None. - Otherwise residual dtype is residual.dtype. 
- """ - return DropoutAddLayerNormParallelResidualFn.apply( - x0, - x1, - residual, - weight0, - bias0, - weight1, - bias1, - dropout_p, - epsilon, - residual_in_fp32, - prenorm, - True, - return_dropout_mask, - ) - - -class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - - def forward(self, x): - return rms_norm(x, self.weight, self.eps) - - -class DropoutAddRMSNorm(torch.nn.Module): - def __init__( - self, - hidden_size, - prenorm=False, - p=0.0, - eps=1e-5, - residual_in_fp32=False, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.prenorm = prenorm - self.p = p - self.eps = eps - self.residual_in_fp32 = residual_in_fp32 - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - - def forward(self, x0, residual=None): - return dropout_add_rms_norm( - x0, - residual, - self.weight, - None, - self.p if self.training else 0.0, - self.eps, - prenorm=self.prenorm, - residual_in_fp32=self.residual_in_fp32, - ) diff --git a/vllm/thirdparty_files/flash_attn/ops/triton/__init__.py b/vllm/thirdparty_files/flash_attn/ops/triton/__init__.py deleted file mode 100644 index 8b137891791f..000000000000 --- a/vllm/thirdparty_files/flash_attn/ops/triton/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/vllm/thirdparty_files/flash_attn/ops/triton/cross_entropy.py b/vllm/thirdparty_files/flash_attn/ops/triton/cross_entropy.py deleted file mode 100644 index c8111ca54826..000000000000 --- a/vllm/thirdparty_files/flash_attn/ops/triton/cross_entropy.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -from typing import Tuple, Optional, Union - -import torch - -from einops import rearrange - -import triton -import triton.language as tl - -# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for -# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent -# version of PyTorch. The following 2 lines are for backward compatibility with -# older PyTorch. -if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base - - -@triton.heuristics( - { - "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, - } -) -@triton.jit -def cross_entropy_fwd_kernel( - loss_ptr, # data ptrs - lse_ptr, - z_loss_ptr, - logits_ptr, - labels_ptr, - smoothing, - logit_scale, - lse_square_scale, - ignored_index, - total_classes, - class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes - n_cols, # shapes - n_rows, - logits_row_stride, # strides - BLOCK_SIZE: tl.constexpr, - HAS_SMOOTHING: tl.constexpr, - # if SPLIT (e.g. 
tensor parallel), don't include the LSE in the loss since it's not the final LSE - SPLIT: tl.constexpr, -): - row_idx = tl.program_id(0) - col_block_idx = tl.program_id(1) - logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) - col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - label_idx = tl.load(labels_ptr + row_idx) - logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( - tl.float32 - ) * logit_scale - max_logits = tl.max(logits, 0) - if HAS_SMOOTHING: - sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0) - lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits - tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse) - if label_idx == ignored_index: - loss = 0.0 - z_loss = 0.0 - else: - label_idx -= class_start_idx - if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min( - n_cols, (col_block_idx + 1) * BLOCK_SIZE - ): - logits_label = tl.load(logits_ptr + label_idx) * logit_scale - if HAS_SMOOTHING: - loss = ( - (lse if not SPLIT else 0.0) - - smoothing * sum_logits / total_classes - - (1 - smoothing) * logits_label - ) - else: - loss = (lse if not SPLIT else 0.0) - logits_label - else: - # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss - if HAS_SMOOTHING: - loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes) - else: - loss = 0.0 - if not SPLIT: - z_loss = lse_square_scale * lse * lse - loss += z_loss - else: - z_loss = 0.0 - tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss) - if not SPLIT: - tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss) - - -@triton.heuristics( - { - "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, - } -) -@triton.jit -def cross_entropy_bwd_kernel( - dlogits_ptr, # data ptrs - dloss_ptr, - logits_ptr, - lse_ptr, - labels_ptr, - smoothing, - logit_scale, - lse_square_scale, - ignored_index, - total_classes, - class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes - n_cols, # shapes - logits_row_stride, # strides - dlogits_row_stride, - dloss_row_stride, - BLOCK_SIZE: tl.constexpr, - HAS_SMOOTHING: tl.constexpr, -): - row_idx = tl.program_id(0) - col_block_idx = tl.program_id(1) - logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) - dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) - col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - label_idx = tl.load(labels_ptr + row_idx) - if label_idx != ignored_index: - dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) - else: - dloss = 0.0 - logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( - tl.float32 - ) * logit_scale - lse = tl.load(lse_ptr + row_idx) - probs = tl.exp(logits - lse) - probs += 2.0 * lse_square_scale * lse * probs - label_idx -= class_start_idx - if HAS_SMOOTHING: - smooth_positive = 1.0 - smoothing - smooth_negative = smoothing / total_classes - probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative - else: - probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) - tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols) - - -class CrossEntropyLoss(torch.autograd.Function): - - @staticmethod - def forward( - ctx, - logits, - labels, - smoothing=0.0, - logit_scale=1.0, - lse_square_scale=0.0, - ignored_index=-100, - inplace_backward=False, - process_group=None, - ): - n_rows, 
n_cols = logits.shape - assert labels.shape == (n_rows,) - world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) - total_classes = world_size * n_cols - rank = 0 if process_group is None else torch.distributed.get_rank(process_group) - class_start_idx = rank * n_cols - - if logits.stride(-1) != 1: - logits = logits.contiguous() - # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py - MAX_BLOCK_SIZE = 64 * 1024 - BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) - num_warps = ( - 4 - if BLOCK_SIZE < 2048 - else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) - ) - # We may split the lse computation across multiple blocks, then do a reduction - # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k) - # where having just one thread block processing more than 64k elements is slow. - split = world_size > 1 or n_cols > MAX_BLOCK_SIZE - n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE - loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,) - losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) - lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) - z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) - with torch.cuda.device(logits.device.index): - cross_entropy_fwd_kernel[(n_rows, n_splits)]( - losses, # data ptrs - lse, - z_losses, - logits, - labels, - smoothing, - logit_scale, - lse_square_scale, - ignored_index, - total_classes, - class_start_idx, - n_cols, # shapes - n_rows, - logits.stride(0), # strides - BLOCK_SIZE=BLOCK_SIZE, # constants - num_warps=num_warps, - SPLIT=split, - ) - - if split: - # If there's no smoothing, if labels are in the vocab of this partition, losses contains - # - predicted logit, and 0 otherwise. - # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains - # -0.9 * predicted logit - 0.1 * sum logit / total_classes. - # For labels not in the vocab of this partition, losses contains - # -0.1 * sum logit / total_classes. - if n_splits > 1: - lse = torch.logsumexp(lse, dim=0) - losses = losses.sum(dim=0) - if world_size > 1: - lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) - torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) - handle_losses = torch.distributed.all_reduce( - losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True - ) - lse = torch.logsumexp(lse_allgather, dim=0) - handle_losses.wait() - # After the allreduce, if there's no smoothing, the total losses are - predicted_logit, - # we just have to add the (global) lse. - # If there's smoothing=0.1, the total losses are - # -0.9 * predicted_logit - 0.1 * sum logit / total_classes. - # Again, we just have to add the (global) lse. 
- losses += lse - if lse_square_scale != 0.0: - z_losses = lse_square_scale * lse.square() - z_losses.masked_fill_(labels == ignored_index, 0.0) - losses += z_losses - else: - z_losses = torch.zeros_like(losses) - losses.masked_fill_(labels == ignored_index, 0.0) - - ctx.save_for_backward(logits, lse, labels) - ctx.mark_non_differentiable(z_losses) - ctx.smoothing = smoothing - ctx.logit_scale = logit_scale - ctx.lse_square_scale = lse_square_scale - ctx.ignored_index = ignored_index - ctx.total_classes = total_classes - ctx.class_start_idx = class_start_idx - ctx.inplace_backward = inplace_backward - - return losses, z_losses - - @staticmethod - def backward(ctx, grad_losses, grad_z_losses): - del grad_z_losses # z_losses are only for logging. - - logits, lse, labels = ctx.saved_tensors - dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) - n_rows, n_cols = logits.shape - BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) - num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) - grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) - with torch.cuda.device(logits.device.index): - cross_entropy_bwd_kernel[grid]( - dlogits, # data ptrs - grad_losses, - logits, - lse, - labels, - ctx.smoothing, - ctx.logit_scale, - ctx.lse_square_scale, - ctx.ignored_index, - ctx.total_classes, - ctx.class_start_idx, - n_cols, # shapes - logits.stride(0), # strides - dlogits.stride(0), - grad_losses.stride(0), - BLOCK_SIZE=BLOCK_SIZE, # constants - num_warps=num_warps, - ) - return dlogits, None, None, None, None, None, None, None, None - -def cross_entropy_loss( - logits: torch.Tensor, - labels: torch.Tensor, - label_smoothing: float = 0.0, - logit_scale: float = 1.0, - lse_square_scale: float = 0.0, - ignored_index=-100, - inplace_backward: bool = False, - process_group=None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Arguments: - logits: (batch, vocab_size) - labels: (batch,) - label_smoothing: float - logit_scale: float. Multiply logits by this scale before calculating the loss. - lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. - This is also referred to as "z-loss". - ignored_index: int. If labels == ignored_index, the loss is set to 0.0. - inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. - This saves memory. - process_group: if not None, we're doing Tensor Parallel: each process is responsible for - one part of the vocab. The loss will be aggregated across processes. - Returns: - losses: (batch,), float - z_losses: (batch,), float - """ - return CrossEntropyLoss.apply( - logits, - labels, - label_smoothing, - logit_scale, - lse_square_scale, - ignored_index, - inplace_backward, - process_group, - ) diff --git a/vllm/thirdparty_files/flash_attn/ops/triton/k_activations.py b/vllm/thirdparty_files/flash_attn/ops/triton/k_activations.py deleted file mode 100644 index efb83c358eb4..000000000000 --- a/vllm/thirdparty_files/flash_attn/ops/triton/k_activations.py +++ /dev/null @@ -1,162 +0,0 @@ -# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
-# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import math -from enum import Enum -from typing import Optional - -import triton -import triton.language as tl - -_sqrt2pi = math.sqrt(2.0 / math.pi) -_sqrt1_2 = math.sqrt(1.0 / 2) -_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi) - - -class Activation(str, Enum): - SquaredReLU = "squared_relu" - GeLU = "gelu" - GeLUApprox = "gelu_approx" - LeakyReLU = "leaky_relu" - ReLU = "relu" - - -def get_triton_activation_kernel(activation: Optional[Activation]): - return ( - { - Activation.ReLU: relu, - Activation.LeakyReLU: leaky_relu, - Activation.GeLU: gelu, - Activation.GeLUApprox: gelu_approx, - Activation.SquaredReLU: squared_relu, - }[activation] - if activation - else None - ) - - -def get_triton_activation_bwd_kernel(activation: Optional[Activation]): - return ( - { - Activation.ReLU: relu_grad, - Activation.LeakyReLU: leaky_relu_grad, - Activation.GeLU: gelu_grad, - Activation.GeLUApprox: gelu_approx_grad, - Activation.SquaredReLU: squared_relu_grad, - }[activation] - if activation - else None - ) - - -@triton.jit -def tanh(x): - # Tanh is just a scaled sigmoid - return 2 * tl.sigmoid(2 * x) - 1 - - -@triton.jit -def cosh(x): - exp_x = tl.exp(x) - return (exp_x + 1.0 / exp_x) * 0.5 - - -# a Triton implementation of the most used activations -# See for instance http://arxiv.org/abs/1606.08415 for an overview - -# ReLU -@triton.jit -def relu(x): - """ - ReLU_ activation function - - .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html - """ - zero = 0.0 - return tl.where(x >= 0, x, zero.to(x.dtype)) - - -@triton.jit -def relu_grad(x): - # ReLU is different from other activations - # in that it does not require the input to retrospectively compute its gradient - # here the input is the downstream gradient, and we return the upstream gradient directly - zero = 0.0 - one = 1.0 - return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype)) - - -@triton.jit -def squared_relu(x): - """ - Squared ReLU activation, as proposed in the Primer_ paper. - - .. _Primer: https://arxiv.org/abs/2109.08668 - """ - x_ = relu(x) - return (x_ * x_).to(x.dtype) - - -@triton.jit -def squared_relu_grad(x): - return tl.where(x >= 0, 2.0 * x, 0.0) - - -# Leaky ReLU -@triton.jit -def leaky_relu(x): - """ - LeakyReLU_ activation - - .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html - """ - scale = 0.01 + 0.0 - scale = scale.to(x.dtype) - return tl.where(x >= 0, x, scale * x) - - -@triton.jit -def leaky_relu_grad(x): - min_grad = 0.01 - max_grad = 1 - - min_grad = min_grad.to(x.dtype) - max_grad = max_grad.to(x.dtype) - - return tl.where(x >= 0, max_grad, min_grad) - - -@triton.jit -def gelu(x): - """Gaussian Error Linear Unit (GELU)""" - return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) - - -@triton.jit -def gelu_grad(x): - cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) - pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization - return cdf + x * pdf - - -@triton.jit -def gelu_approx(x): - """ - GeLU_ activation - Gaussian error linear unit, with tanh approximation - - .. 
_GeLU: https://arxiv.org/pdf/1606.08415.pdf - """ - return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x))) - - -@triton.jit -def gelu_approx_grad(x): - # CREDITS: Fast implementation proposed in - # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30 - tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( - 1 + tanh_out - ) diff --git a/vllm/thirdparty_files/flash_attn/ops/triton/layer_norm.py b/vllm/thirdparty_files/flash_attn/ops/triton/layer_norm.py deleted file mode 100644 index c922906e4450..000000000000 --- a/vllm/thirdparty_files/flash_attn/ops/triton/layer_norm.py +++ /dev/null @@ -1,1086 +0,0 @@ -# Copyright (c) 2024, Tri Dao. -# Implement dropout + residual + layer_norm / rms_norm. - -# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html -# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. -# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. -# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. - -import math - -import torch -import torch.nn.functional as F -from torch.cuda.amp import custom_fwd, custom_bwd - -import triton -import triton.language as tl - - -def layer_norm_ref( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - dropout_mask=None, - dropout_mask1=None, - upcast=False, -): - dtype = x.dtype - if upcast: - x = x.float() - weight = weight.float() - bias = bias.float() if bias is not None else None - residual = residual.float() if residual is not None else residual - x1 = x1.float() if x1 is not None else None - weight1 = weight1.float() if weight1 is not None else None - bias1 = bias1.float() if bias1 is not None else None - if x1 is not None: - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - if rowscale is not None: - x = x * rowscale[..., None] - if dropout_p > 0.0: - if dropout_mask is not None: - x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) - else: - x = F.dropout(x, p=dropout_p) - if x1 is not None: - if dropout_mask1 is not None: - x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) - else: - x1 = F.dropout(x1, p=dropout_p) - if x1 is not None: - x = x + x1 - if residual is not None: - x = (x + residual).to(x.dtype) - out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( - dtype - ) - if weight1 is None: - return out if not prenorm else (out, x) - else: - out1 = F.layer_norm( - x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps - ).to(dtype) - return (out, out1) if not prenorm else (out, out1, x) - - -def rms_norm_ref( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - dropout_mask=None, - dropout_mask1=None, - upcast=False, -): - dtype = x.dtype - if upcast: - x = x.float() - weight = weight.float() - bias = bias.float() if bias is not None else None - residual = residual.float() if residual is not None else residual - x1 = x1.float() if x1 is not None else None - weight1 = weight1.float() if weight1 is not None else None - bias1 = bias1.float() if bias1 is not None else None - if x1 is not None: - assert rowscale is None, 
"rowscale is not supported with parallel LayerNorm" - if rowscale is not None: - x = x * rowscale[..., None] - if dropout_p > 0.0: - if dropout_mask is not None: - x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) - else: - x = F.dropout(x, p=dropout_p) - if x1 is not None: - if dropout_mask1 is not None: - x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) - else: - x1 = F.dropout(x1, p=dropout_p) - if x1 is not None: - x = x + x1 - if residual is not None: - x = (x + residual).to(x.dtype) - rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) - out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype) - if weight1 is None: - return out if not prenorm else (out, x) - else: - out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to( - dtype - ) - return (out, out1) if not prenorm else (out, out1, x) - - -@triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), - ], - key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], -) -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) -@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) -@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) -@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) -@triton.jit -def _layer_norm_fwd_1pass_kernel( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - RESIDUAL, # pointer to the residual - X1, - W1, - B1, - Y1, - RESIDUAL_OUT, # pointer to the residual - ROWSCALE, - SEEDS, # Dropout seeds for each row - DROPOUT_MASK, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_res_row, - stride_res_out_row, - stride_x1_row, - stride_y1_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - dropout_p, # Dropout probability - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_RESIDUAL: tl.constexpr, - STORE_RESIDUAL_OUT: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_DROPOUT: tl.constexpr, - STORE_DROPOUT_MASK: tl.constexpr, - HAS_ROWSCALE: tl.constexpr, - HAS_X1: tl.constexpr, - HAS_W1: tl.constexpr, - HAS_B1: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. 
- row = tl.program_id(0) - X += row * stride_x_row - Y += row * stride_y_row - if HAS_RESIDUAL: - RESIDUAL += row * stride_res_row - if STORE_RESIDUAL_OUT: - RESIDUAL_OUT += row * stride_res_out_row - if HAS_X1: - X1 += row * stride_x1_row - if HAS_W1: - Y1 += row * stride_y1_row - # Compute mean and variance - cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + row).to(tl.float32) - x *= rowscale - if HAS_DROPOUT: - # Compute dropout mask - # 7 rounds is good enough, and reduces register pressure - keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) - if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) - if HAS_X1: - x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) - x1 *= rowscale - if HAS_DROPOUT: - # Compute dropout mask - # 7 rounds is good enough, and reduces register pressure - keep_mask = ( - tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - ) - x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) - if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) - x += x1 - if HAS_RESIDUAL: - residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) - x += residual - if STORE_RESIDUAL_OUT: - tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) - if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - else: - xbar = tl.where(cols < N, x, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) - # Normalize and apply linear transformation - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - if HAS_BIAS: - b = tl.load(B + cols, mask=mask).to(tl.float32) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - y = x_hat * w + b if HAS_BIAS else x_hat * w - # Write output - tl.store(Y + cols, y, mask=mask) - if HAS_W1: - w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) - if HAS_B1: - b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) - y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 - tl.store(Y1 + cols, y1, mask=mask) - - -def _layer_norm_fwd( - x, - weight, - bias, - eps, - residual=None, - x1=None, - weight1=None, - bias1=None, - dropout_p=0.0, - rowscale=None, - out_dtype=None, - residual_dtype=None, - is_rms_norm=False, - return_dropout_mask=False, -): - if residual is not None: - residual_dtype = residual.dtype - M, N = x.shape - assert x.stride(-1) == 1 - if residual is not None: - assert residual.stride(-1) == 1 - assert residual.shape == (M, N) - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - if x1 is not None: - assert x1.shape == x.shape - assert rowscale is None - assert x1.stride(-1) == 1 - if weight1 is not None: - assert weight1.shape == (N,) - assert weight1.stride(-1) == 1 - if bias1 is not None: - assert bias1.shape == (N,) - assert bias1.stride(-1) == 1 - if rowscale is not None: - assert rowscale.is_contiguous() - assert rowscale.shape == (M,) - # allocate output - y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) - assert y.stride(-1) == 1 - if weight1 is not None: - y1 = 
torch.empty_like(y) - assert y1.stride(-1) == 1 - else: - y1 = None - if ( - residual is not None - or (residual_dtype is not None and residual_dtype != x.dtype) - or dropout_p > 0.0 - or rowscale is not None - or x1 is not None - ): - residual_out = torch.empty( - M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype - ) - assert residual_out.stride(-1) == 1 - else: - residual_out = None - mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None - rstd = torch.empty((M,), dtype=torch.float32, device=x.device) - if dropout_p > 0.0: - seeds = torch.randint( - 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 - ) - else: - seeds = None - if return_dropout_mask and dropout_p > 0.0: - dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool) - else: - dropout_mask = None - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - with torch.cuda.device(x.device.index): - _layer_norm_fwd_1pass_kernel[(M,)]( - x, - y, - weight, - bias, - residual, - x1, - weight1, - bias1, - y1, - residual_out, - rowscale, - seeds, - dropout_mask, - mean, - rstd, - x.stride(0), - y.stride(0), - residual.stride(0) if residual is not None else 0, - residual_out.stride(0) if residual_out is not None else 0, - x1.stride(0) if x1 is not None else 0, - y1.stride(0) if y1 is not None else 0, - M, - N, - eps, - dropout_p, - is_rms_norm, - BLOCK_N, - residual is not None, - residual_out is not None, - bias is not None, - dropout_p > 0.0, - dropout_mask is not None, - rowscale is not None, - ) - # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 - if dropout_mask is not None and x1 is not None: - dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) - else: - dropout_mask1 = None - return ( - y, - y1, - mean, - rstd, - residual_out if residual_out is not None else x, - seeds, - dropout_mask, - dropout_mask1, - ) - - -@triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), - ], - key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], -) -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) -# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) -@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) -@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) -@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) -@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) -@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) -@triton.jit -def _layer_norm_bwd_kernel( - X, # pointer to the input - W, # pointer to the weights - B, # pointer to the biases - Y, # pointer to the output to be recomputed - DY, # pointer to the output gradient - DX, # pointer to the input gradient - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - DRESIDUAL, - W1, - DY1, - DX1, - DW1, - 
DB1, - DRESIDUAL_IN, - ROWSCALE, - SEEDS, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_dy_row, - stride_dx_row, - stride_dres_row, - stride_dy1_row, - stride_dx1_row, - stride_dres_in_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - dropout_p, - rows_per_program, - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_DRESIDUAL: tl.constexpr, - STORE_DRESIDUAL: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_DROPOUT: tl.constexpr, - HAS_ROWSCALE: tl.constexpr, - HAS_DY1: tl.constexpr, - HAS_DX1: tl.constexpr, - HAS_B1: tl.constexpr, - RECOMPUTE_OUTPUT: tl.constexpr, -): - # Map the program id to the elements of X, DX, and DY it should compute. - row_block_id = tl.program_id(0) - row_start = row_block_id * rows_per_program - # Do not early exit if row_start >= M, because we need to write DW and DB - cols = tl.arange(0, BLOCK_N) - mask = cols < N - X += row_start * stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += row_start * stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += row_start * stride_dres_in_row - DY += row_start * stride_dy_row - DX += row_start * stride_dx_row - if HAS_DY1: - DY1 += row_start * stride_dy1_row - if HAS_DX1: - DX1 += row_start * stride_dx1_row - if RECOMPUTE_OUTPUT: - Y += row_start * stride_y_row - w = tl.load(W + cols, mask=mask).to(tl.float32) - if RECOMPUTE_OUTPUT and HAS_BIAS: - b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) - if HAS_DY1: - w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) - dw = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_BIAS: - db = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_DY1: - dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_B1: - db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) - row_end = min((row_block_id + 1) * rows_per_program, M) - for row in range(row_start, row_end): - # Load data to SRAM - x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) - dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) - if HAS_DY1: - dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) - if not IS_RMS_NORM: - mean = tl.load(Mean + row) - rstd = tl.load(Rstd + row) - # Compute dx - xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - xhat = tl.where(mask, xhat, 0.0) - if RECOMPUTE_OUTPUT: - y = xhat * w + b if HAS_BIAS else xhat * w - tl.store(Y + cols, y, mask=mask) - wdy = w * dy - dw += dy * xhat - if HAS_BIAS: - db += dy - if HAS_DY1: - wdy += w1 * dy1 - dw1 += dy1 * xhat - if HAS_B1: - db1 += dy1 - if not IS_RMS_NORM: - c1 = tl.sum(xhat * wdy, axis=0) / N - c2 = tl.sum(wdy, axis=0) / N - dx = (wdy - (xhat * c1 + c2)) * rstd - else: - c1 = tl.sum(xhat * wdy, axis=0) / N - dx = (wdy - xhat * c1) * rstd - if HAS_DRESIDUAL: - dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) - dx += dres - # Write dx - if STORE_DRESIDUAL: - tl.store(DRESIDUAL_IN + cols, dx, mask=mask) - if HAS_DX1: - if HAS_DROPOUT: - keep_mask = ( - tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - ) - dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) - else: - dx1 = dx - tl.store(DX1 + cols, dx1, mask=mask) - if HAS_DROPOUT: - keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p - dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) - if HAS_ROWSCALE: - rowscale = tl.load(ROWSCALE + row).to(tl.float32) - dx *= rowscale - tl.store(DX + cols, dx, mask=mask) - - X += 
stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += stride_dres_in_row - if RECOMPUTE_OUTPUT: - Y += stride_y_row - DY += stride_dy_row - DX += stride_dx_row - if HAS_DY1: - DY1 += stride_dy1_row - if HAS_DX1: - DX1 += stride_dx1_row - tl.store(DW + row_block_id * N + cols, dw, mask=mask) - if HAS_BIAS: - tl.store(DB + row_block_id * N + cols, db, mask=mask) - if HAS_DY1: - tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) - if HAS_B1: - tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) - - -def _layer_norm_bwd( - dy, - x, - weight, - bias, - eps, - mean, - rstd, - dresidual=None, - dy1=None, - weight1=None, - bias1=None, - seeds=None, - dropout_p=0.0, - rowscale=None, - has_residual=False, - has_x1=False, - is_rms_norm=False, - x_dtype=None, - recompute_output=False, -): - M, N = x.shape - assert x.stride(-1) == 1 - assert dy.stride(-1) == 1 - assert dy.shape == (M, N) - if dresidual is not None: - assert dresidual.stride(-1) == 1 - assert dresidual.shape == (M, N) - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - if dy1 is not None: - assert weight1 is not None - assert dy1.shape == dy.shape - assert dy1.stride(-1) == 1 - if weight1 is not None: - assert weight1.shape == (N,) - assert weight1.stride(-1) == 1 - if bias1 is not None: - assert bias1.shape == (N,) - assert bias1.stride(-1) == 1 - if seeds is not None: - assert seeds.is_contiguous() - assert seeds.shape == (M if not has_x1 else M * 2,) - if rowscale is not None: - assert rowscale.is_contiguous() - assert rowscale.shape == (M,) - # allocate output - dx = ( - torch.empty_like(x) - if x_dtype is None - else torch.empty(M, N, dtype=x_dtype, device=x.device) - ) - dresidual_in = ( - torch.empty_like(x) - if has_residual - and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1) - else None - ) - dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None - y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None - if recompute_output: - assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" - - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count - _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) - _db = ( - torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) - if bias is not None - else None - ) - _dw1 = torch.empty_like(_dw) if weight1 is not None else None - _db1 = torch.empty_like(_db) if bias1 is not None else None - rows_per_program = math.ceil(M / sm_count) - grid = (sm_count,) - with torch.cuda.device(x.device.index): - _layer_norm_bwd_kernel[grid]( - x, - weight, - bias, - y, - dy, - dx, - _dw, - _db, - dresidual, - weight1, - dy1, - dx1, - _dw1, - _db1, - dresidual_in, - rowscale, - seeds, - mean, - rstd, - x.stride(0), - 0 if not recompute_output else y.stride(0), - dy.stride(0), - dx.stride(0), - dresidual.stride(0) if dresidual is not None else 0, - dy1.stride(0) if dy1 is not None else 0, - dx1.stride(0) if dx1 is not None else 0, - dresidual_in.stride(0) if dresidual_in is not None else 0, - M, - N, - eps, - dropout_p, - rows_per_program, - is_rms_norm, - 
BLOCK_N, - dresidual is not None, - dresidual_in is not None, - bias is not None, - dropout_p > 0.0, - ) - dw = _dw.sum(0).to(weight.dtype) - db = _db.sum(0).to(bias.dtype) if bias is not None else None - dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None - db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None - # Don't need to compute dresidual_in separately in this case - if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: - dresidual_in = dx - if has_x1 and dropout_p == 0.0: - dx1 = dx - return ( - (dx, dw, db, dresidual_in, dx1, dw1, db1) - if not recompute_output - else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) - ) - - -class LayerNormFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - return_dropout_mask=False, - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() - if residual is not None: - assert residual.shape == x_shape_og - residual = residual.reshape(-1, residual.shape[-1]) - if residual.stride(-1) != 1: - residual = residual.contiguous() - if x1 is not None: - assert x1.shape == x_shape_og - assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - x1 = x1.reshape(-1, x1.shape[-1]) - if x1.stride(-1) != 1: - x1 = x1.contiguous() - weight = weight.contiguous() - if bias is not None: - bias = bias.contiguous() - if weight1 is not None: - weight1 = weight1.contiguous() - if bias1 is not None: - bias1 = bias1.contiguous() - if rowscale is not None: - rowscale = rowscale.reshape(-1).contiguous() - residual_dtype = ( - residual.dtype - if residual is not None - else (torch.float32 if residual_in_fp32 else None) - ) - y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( - x, - weight, - bias, - eps, - residual, - x1, - weight1, - bias1, - dropout_p=dropout_p, - rowscale=rowscale, - residual_dtype=residual_dtype, - is_rms_norm=is_rms_norm, - return_dropout_mask=return_dropout_mask, - ) - ctx.save_for_backward( - residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd - ) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.dropout_p = dropout_p - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.has_x1 = x1 is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - y = y.reshape(x_shape_og) - y1 = y1.reshape(x_shape_og) if y1 is not None else None - residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None - dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None - dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None - if not return_dropout_mask: - if weight1 is None: - return y if not prenorm else (y, residual_out) - else: - return (y, y1) if not prenorm else (y, y1, residual_out) - else: - if weight1 is None: - return ( - (y, dropout_mask, dropout_mask1) - if not prenorm - else (y, residual_out, dropout_mask, dropout_mask1) - ) - else: - return ( - (y, y1, dropout_mask, dropout_mask1) - if not prenorm - else (y, y1, residual_out, dropout_mask, dropout_mask1) - ) - - @staticmethod - def backward(ctx, dy, *args): - x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors - dy = dy.reshape(-1, dy.shape[-1]) - if 
dy.stride(-1) != 1: - dy = dy.contiguous() - assert dy.shape == x.shape - if weight1 is not None: - dy1, args = args[0], args[1:] - dy1 = dy1.reshape(-1, dy1.shape[-1]) - if dy1.stride(-1) != 1: - dy1 = dy1.contiguous() - assert dy1.shape == x.shape - else: - dy1 = None - if ctx.prenorm: - dresidual = args[0] - dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - if dresidual.stride(-1) != 1: - dresidual = dresidual.contiguous() - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( - dy, - x, - weight, - bias, - ctx.eps, - mean, - rstd, - dresidual, - dy1, - weight1, - bias1, - seeds, - ctx.dropout_p, - rowscale, - ctx.has_residual, - ctx.has_x1, - ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - ) - return ( - dx.reshape(ctx.x_shape_og), - dw, - db, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, - dw1, - db1, - None, - None, - None, - None, - None, - None, - None, - ) - - -def layer_norm_fn( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - return_dropout_mask=False, -): - return LayerNormFn.apply( - x, - weight, - bias, - residual, - x1, - weight1, - bias1, - eps, - dropout_p, - rowscale, - prenorm, - residual_in_fp32, - is_rms_norm, - return_dropout_mask, - ) - - -def rms_norm_fn( - x, - weight, - bias, - residual=None, - x1=None, - weight1=None, - bias1=None, - eps=1e-6, - dropout_p=0.0, - rowscale=None, - prenorm=False, - residual_in_fp32=False, - return_dropout_mask=False, -): - return LayerNormFn.apply( - x, - weight, - bias, - residual, - x1, - weight1, - bias1, - eps, - dropout_p, - rowscale, - prenorm, - residual_in_fp32, - True, - return_dropout_mask, - ) - - -class RMSNorm(torch.nn.Module): - - def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - if dropout_p > 0.0: - self.drop = torch.nn.Dropout(dropout_p) - else: - self.drop = None - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - torch.nn.init.ones_(self.weight) - - def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): - return rms_norm_fn( - x, - self.weight, - self.bias, - residual=residual, - eps=self.eps, - dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, - prenorm=prenorm, - residual_in_fp32=residual_in_fp32, - ) - - -class LayerNormLinearFn(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() - if residual is not None: - assert residual.shape == x_shape_og - residual = residual.reshape(-1, residual.shape[-1]) - if residual.stride(-1) != 1: - residual = residual.contiguous() - norm_weight = norm_weight.contiguous() - if norm_bias is not None: - norm_bias = norm_bias.contiguous() - residual_dtype = ( - residual.dtype - if residual is not None - else (torch.float32 if residual_in_fp32 else None) - ) - y, mean, rstd, 
residual_out = _layer_norm_fwd( - x, - norm_weight, - norm_bias, - eps, - residual, - out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), - residual_dtype=residual_dtype, - is_rms_norm=is_rms_norm, - ) - y = y.reshape(x_shape_og) - dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype - linear_weight = linear_weight.to(dtype) - linear_bias = linear_bias.to(dtype) if linear_bias is not None else None - out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) - # We don't store y, will be recomputed in the backward pass to save memory - ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - ctx.linear_bias_is_none = linear_bias is None - return out if not prenorm else (out, residual_out.reshape(x_shape_og)) - - @staticmethod - @custom_bwd - def backward(ctx, dout, *args): - x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors - dout = dout.reshape(-1, dout.shape[-1]) - dy = F.linear(dout, linear_weight.t()) - dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) - if dy.stride(-1) != 1: - dy = dy.contiguous() - assert dy.shape == x.shape - if ctx.prenorm: - dresidual = args[0] - dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - if dresidual.stride(-1) != 1: - dresidual = dresidual.contiguous() - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( - dy, - x, - norm_weight, - norm_bias, - ctx.eps, - mean, - rstd, - dresidual, - ctx.has_residual, - ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - recompute_output=True, - ) - dlinear_weight = torch.einsum("bo,bi->oi", dout, y) - return ( - dx.reshape(ctx.x_shape_og), - dnorm_weight, - dnorm_bias, - dlinear_weight, - dlinear_bias, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - None, - None, - None, - None, - ) - - -def layer_norm_linear_fn( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, -): - return LayerNormLinearFn.apply( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual, - eps, - prenorm, - residual_in_fp32, - is_rms_norm, - ) diff --git a/vllm/thirdparty_files/flash_attn/ops/triton/linear.py b/vllm/thirdparty_files/flash_attn/ops/triton/linear.py deleted file mode 100644 index a8966dbc345a..000000000000 --- a/vllm/thirdparty_files/flash_attn/ops/triton/linear.py +++ /dev/null @@ -1,594 +0,0 @@ -# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py -# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py -from typing import Optional - -import torch -import triton -import triton.language as tl -from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - -from flash_attn.ops.triton.k_activations import ( - gelu, - gelu_approx, - gelu_approx_grad, - gelu_grad, - squared_relu, - squared_relu_grad, -) - -# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications - - -def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - - -def get_configs_io_bound(): - configs = [] - for num_stages in [2, 3, 4, 5, 6]: - for block_m in [16, 32]: - for block_k in [32, 64]: 
- for block_n in [32, 64, 128, 256]: - num_warps = 2 if block_n <= 64 else 4 - configs.append( - triton.Config( - { - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SPLIT_K": 1, - }, - num_stages=num_stages, - num_warps=num_warps, - ) - ) - # split_k not used - # for split_k in [2, 4, 8, 16]: - # configs.append(triton.Config( - # {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - # num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) - return configs - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - # good for int8 - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - ] - + get_configs_io_bound(), - key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], - prune_configs_by={ - "early_config_prune": early_config_prune, - "perf_model": estimate_matmul_time, - "top_k": 10, - }, -) -@triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } -) -@triton.jit -def kernel_fwd( - C, # Pointers to matrices - ACT_INPUT, - A, - B, - bias, - # Matrix dimensions - M, - N, - K, - CACHE_KEY_M, - CACHE_KEY_N, - CACHE_KEY_K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. 
stride_am is how much to increase a_ptr - # by to get the element one row down (A has M rows) - stride_cm, - # stride_cn, # Assume that stride_cn == 1 - stride_am, - stride_ak, - stride_bn, - stride_bk, - # Meta-parameters - BLOCK_M: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - # split k not used, not performant with activation, kept because early_config_prune is expecting it - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - A_ROWMAJOR: tl.constexpr, - B_COLMAJOR: tl.constexpr, - BIAS: tl.constexpr, - SAVE_ACT_INPUT: tl.constexpr, - ACTIVATION: tl.constexpr, -): - - """ - Kernel for computing Out = activation(A x W + C) - - Input has shape (M, K) - - Weight has shape (K, N) - - Bias has shape (N,) - - Output has shape (M, N) - - ActInputs (optional) has shape (M, N) - 'ActInputs' optionally saves the A x W + C intermediate for backward computations - This kernel will consolidate over K - """ - - pid = tl.program_id(axis=0) - - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - # now compute the block that each program will go through - # rm (resp. rn) denotes a range of indices - # for rows (resp. col) of C - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - # trick to avoid masking on M and N axis - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - - if A_ROWMAJOR: - A = A + (ram[:, None] * stride_am + rk[None, :]) - else: - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - if B_COLMAJOR: - B = B + (rk[:, None] + rbn[None, :] * stride_bn) - else: - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.0) - b = tl.load(B, mask=rk[:, None] < k, other=0.0) - acc += tl.dot(a, b) - - if A_ROWMAJOR: - A += BLOCK_K - else: - A += BLOCK_K * stride_ak - if B_COLMAJOR: - B += BLOCK_K - else: - B += BLOCK_K * stride_bk - - # Putting bias after the matmul (instead of before) is faster, idk why - if BIAS: - bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32) - acc += bias[None, :] - - # optional: save the activation inputs - if SAVE_ACT_INPUT: - # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn - act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] - tl.store(act_in_ptrs, acc) - - # optional: fused activation (while the data is in shared memory) - if ACTIVATION == "gelu": - acc = gelu(acc) - elif ACTIVATION == "gelu_approx": - acc = gelu_approx(acc) - elif ACTIVATION == "squared_relu": - acc = squared_relu(acc) - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # write back result - # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - C = C + rm[:, None] * stride_cm + rn[None, :] - mask = (rm < M)[:, None] & (rn < N)[None, :] - tl.store(C, acc) - - -def triton_linear_act( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - 
activation: str = "id", - save_act_input: bool = False, -) -> torch.Tensor: - """ - Compute e = activation(x @ weight.T + bias). - This wrapper kicks the `kernel_fwd` Triton kernel - :param x: input tensor - :param weight: weight matrix - :param bias: an optional bias tensor - :param activation: Activation name. Needs to be a Triton kernel. - :param act_input: an optional tensor to save the activation inputs (for backward) - :return: result tensor - """ - # if torch.is_autocast_enabled(): - # dtype = torch.get_autocast_gpu_dtype() - # x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]] - - assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] - - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - x_reshaped = x.reshape(batch_dim, n) - - if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1: - x_reshaped = x_reshaped.contiguous() - if weight.stride(0) > 1 and weight.stride(1) > 1: - weight = weight.contiguous() - bias = bias.contiguous() if bias is not None else None - - assert ( - x.dtype == weight.dtype - ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" - if bias is not None: - assert ( - x.dtype == bias.dtype - ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}" - assert ( - x_reshaped.shape[1] == weight.shape[1] - ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}" - - assert ( - bias is None or bias.shape[0] == weight.shape[0] - ), "Incompatible dimensions in between weight and bias" - - M, K = x_reshaped.shape - N, K = weight.shape - - output = torch.empty((M, N), device=x.device, dtype=x.dtype) - act_input = torch.empty_like(output) if save_act_input else None - - # 1D launch kernel where each block gets its own program. 
- grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa - - kernel_fwd[grid]( - output, - act_input, - x_reshaped, - weight, # data ptrs - bias if bias is not None else x, # auto skip bias if not present - M, # shapes - N, - K, - M // 32, # key for triton cache (limit number of compilations) - N // 32, - K // 32, - stride_cm=output.stride(0), # strides - # stride_cn=output.stride(1), - stride_am=x_reshaped.stride(0), - stride_ak=x_reshaped.stride(1), - stride_bk=weight.stride(1), - stride_bn=weight.stride(0), - BIAS=bias is not None, # optional fused bias - SAVE_ACT_INPUT=save_act_input, # optional save activation inputs - ACTIVATION=activation, # optional fused activation - A_ROWMAJOR=x_reshaped.stride(1) == 1, - B_COLMAJOR=weight.stride(1) == 1, - GROUP_M=8, # speed optimization: group the programs - ) - - if not save_act_input: - return output.reshape(*batch_shape, output.shape[-1]) - else: - return ( - output.reshape(*batch_shape, output.shape[-1]), - act_input.reshape(*batch_shape, act_input.shape[-1]), - ) - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - # good for int8 - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2 - ), - ] - + get_configs_io_bound(), - key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], - prune_configs_by={ - "early_config_prune": early_config_prune, - "perf_model": estimate_matmul_time, - "top_k": 10, - }, -) -@triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } -) -@triton.jit -def kernel_bwd( - C, # Pointers to matrices - 
ACT_INPUT, - A, - B, - # Matrix dimensions - M, - N, - K, - CACHE_KEY_M, - CACHE_KEY_N, - CACHE_KEY_K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. stride_am is how much to increase a_ptr - # by to get the element one row down (A has M rows) - stride_cm, - # stride_cn, # Assume that stride_cn == 1 - stride_am, - stride_ak, - stride_bk, - stride_bn, - # Meta-parameters - BLOCK_M: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - # split k not used, not performant with activation, kept because early_config_prune is expecting it - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - ACTIVATION: tl.constexpr, -): - - """ - Kernel for computing Out = activation(A x W + C) - - Input has shape (M, K) - - Weight has shape (K, N) - - Output has shape (M, N) - - ActInputs (optional) has shape (M, N) - 'ActInputs' optionally saves the A x W + C intermediate for backward computations - This kernel will consolidate over K - """ - - pid = tl.program_id(axis=0) - - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - # now compute the block that each program will go through - # rm (resp. rn) denotes a range of indices - # for rows (resp. col) of C - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - # trick to avoid masking on M and N axis - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.0) - b = tl.load(B, mask=rk[:, None] < k, other=0.0) - acc += tl.dot(a, b) - - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - # optional: fused activation (while the data is in shared memory) - if ACTIVATION != "id": - act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] - act_input = tl.load(act_in_ptrs).to(acc.dtype) - if ACTIVATION == "gelu": - acc *= gelu_grad(act_input) - elif ACTIVATION == "gelu_approx": - acc *= gelu_approx_grad(act_input) - elif ACTIVATION == "squared_relu": - acc *= squared_relu_grad(act_input) - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # write back result - C = C + rm[:, None] * stride_cm + rn[None, :] - mask = (rm < M)[:, None] & (rn < N)[None, :] - tl.store(C, acc, mask=mask) - - -def triton_dgrad_act( - grad_output: torch.Tensor, - weight: torch.Tensor, - activation: str = "id", - act_input: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - Compute e = activation(grad_output @ weight + bias). - This wrapper kicks the `kernel_fwd` Triton kernel - :param grad_output: input tensor - :param weight: weight matrix - :param activation: Activation name. Needs to be a Triton kernel. 
- :param act_input: an optional tensor to save the activation inputs (for backward) - :return: result tensor - """ - assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] - - batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1] - batch_dim = batch_shape.numel() - grad_output_reshaped = grad_output.reshape(batch_dim, n) - - if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1: - grad_output_reshaped = grad_output_reshaped.contiguous() - if weight.stride(0) > 1 and weight.stride(1) > 1: - weight = weight.contiguous() - - assert ( - grad_output.dtype == weight.dtype - ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}" - assert ( - grad_output_reshaped.shape[1] == weight.shape[0] - ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" - if activation != "id": - assert act_input is not None, f"act_input is required for activation {activation}" - - # M, N, K in bwd are different from M, N, K in fwd - M, K = grad_output_reshaped.shape - K, N = weight.shape - - grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype) - - # 1D launch kernel where each block gets its own program. - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa - - kernel_bwd[grid]( - grad_input, - act_input, - grad_output_reshaped, - weight, # data ptrs - M, # shapes - N, - K, - M // 32, # key for triton cache (limit number of compilations) - N // 32, - K // 32, - stride_cm=grad_input.stride(0), # strides - # stride_cn=grad_input.stride(1), - stride_am=grad_output_reshaped.stride(0), - stride_ak=grad_output_reshaped.stride(1), - stride_bk=weight.stride(0), - stride_bn=weight.stride(1), - ACTIVATION=activation, # optional fused activation - GROUP_M=8, # speed optimization: group the programs - ) - - return grad_input.reshape(*batch_shape, grad_input.shape[-1]) diff --git a/vllm/thirdparty_files/flash_attn/ops/triton/mlp.py b/vllm/thirdparty_files/flash_attn/ops/triton/mlp.py deleted file mode 100644 index b795310f1c8a..000000000000 --- a/vllm/thirdparty_files/flash_attn/ops/triton/mlp.py +++ /dev/null @@ -1,149 +0,0 @@ -# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared -# to naive implementation. 
-import fused_dense_lib as fused_dense_cuda -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.cuda.amp import custom_bwd, custom_fwd - -from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd -from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act - - -class FusedDenseSqreluDenseFunc(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0): - """checkpoint_lvl: - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute act_input and gelu_out in the bwd - """ - if torch.is_autocast_enabled(): - dtype = torch.get_autocast_gpu_dtype() - x, weight1, bias1, weight2, bias2 = [ - a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2] - ] - is_bf16 = x.dtype == torch.bfloat16 - assert checkpoint_lvl in [0, 1, 2] - x = x.contiguous() - weight1 = weight1.contiguous() - bias1 = bias1.contiguous() - weight2 = weight2.contiguous() - bias2 = bias2.contiguous() - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - if is_bf16: - act_input = fused_dense_cuda.linear_bias_forward( - x.reshape(batch_dim, n), weight1, bias1 - ) - output1 = sqrelu_fwd(act_input) - else: - save_act_input = checkpoint_lvl != 2 - result = triton_linear_act( - x.reshape(batch_dim, n), - weight1, - bias1, - activation="squared_relu", - save_act_input=save_act_input, - ) - if save_act_input: - output1, act_input = result - else: - output1 = result - output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2) - ctx.checkpoint_lvl = checkpoint_lvl - if checkpoint_lvl == 0: - ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1) - elif checkpoint_lvl == 1: - ctx.save_for_backward(x, weight1, bias1, weight2, act_input) - elif checkpoint_lvl == 2: - ctx.save_for_backward(x, weight1, bias1, weight2) - return output2.reshape(*batch_shape, output2.shape[-1]) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - grad_output = grad_output.contiguous() - checkpoint_lvl = ctx.checkpoint_lvl - x, weight1, bias1, weight2, *rest = ctx.saved_tensors - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - is_bf16 = x.dtype == torch.bfloat16 - if checkpoint_lvl == 0: - act_input, output1 = rest - elif checkpoint_lvl == 1: - (act_input,) = rest - output1 = sqrelu_fwd(act_input) - elif checkpoint_lvl == 2: - if is_bf16: - act_input = fused_dense_cuda.linear_bias_forward( - x.reshape(batch_dim, n), weight1, bias1 - ) - output1 = sqrelu_fwd(act_input) - else: - output1, act_input = triton_linear_act( - x.reshape(batch_dim, n), - weight1, - bias1, - activation="squared_relu", - save_act_input=True, - ) - - if is_bf16: - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) - grad_output1 = grad_output @ weight2 - grad_act_input = sqrelu_bwd(grad_output1, act_input) - grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( - x.reshape(batch_dim, n), weight1, grad_act_input - ) - else: - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) - grad_act_input = triton_dgrad_act( - grad_output, weight2, activation="squared_relu", act_input=act_input - ) - grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( - x.reshape(batch_dim, n), weight1, grad_act_input - ) - 
return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None - - -fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply - - -class FusedDenseSqreluDense(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - bias1=True, - bias2=True, - checkpoint_lvl=0, - device=None, - dtype=None, - ): - """ - checkpoint_lvl (increasing lvl means slower but more memory saving): - 0: no recomputation in the bwd - 1: recompute gelu_out in the bwd - 2: recompute gelu_in and gelu_out in the bwd - """ - assert checkpoint_lvl in [0, 1, 2] - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features * 4 - assert bias1 == True, "DenseSqreluDense module without bias is currently not supported" - assert bias2 == True, "DenseSqreluDense module without bias is currently not supported" - self.checkpoint_lvl = checkpoint_lvl - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) - - def forward(self, x): - assert x.is_cuda - return fused_dense_sqrelu_dense_function( - x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl - ) diff --git a/vllm/thirdparty_files/flash_attn/ops/triton/rotary.py b/vllm/thirdparty_files/flash_attn/ops/triton/rotary.py deleted file mode 100644 index 8d2e09b0c8c9..000000000000 --- a/vllm/thirdparty_files/flash_attn/ops/triton/rotary.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -from typing import Optional, Union - -import torch - -import triton -import triton.language as tl - - -# @triton.autotune( -# configs=[ -# triton.Config({"BLOCK_M": 2}), -# triton.Config({"BLOCK_M": 4}), -# triton.Config({"BLOCK_M": 8}), -# triton.Config({"BLOCK_M": 16}), -# ], -# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"], -# ) -@triton.jit -def rotary_kernel( - OUT, # Pointers to matrices - X, - COS, - SIN, - CU_SEQLENS, - SEQLEN_OFFSETS, # this could be int or a pointer - # Matrix dimensions - seqlen, - nheads, - rotary_dim, - seqlen_ro, - CACHE_KEY_SEQLEN, - # strides - stride_out_batch, - stride_out_seqlen, - stride_out_nheads, - stride_out_headdim, - stride_x_batch, - stride_x_seqlen, - stride_x_nheads, - stride_x_headdim, - # Meta-parameters - BLOCK_K: tl.constexpr, - IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, - IS_VARLEN: tl.constexpr, - INTERLEAVED: tl.constexpr, - CONJUGATE: tl.constexpr, - BLOCK_M: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) - rotary_dim_half = rotary_dim // 2 - - if not IS_VARLEN: - X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads - OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads - else: - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads - OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads - - if pid_m * BLOCK_M >= seqlen: - return - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - if not IS_SEQLEN_OFFSETS_TENSOR: - rm_cs = rm + SEQLEN_OFFSETS - else: - rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) - rk = tl.arange(0, BLOCK_K) - rk_half = tl.arange(0, BLOCK_K // 2) - - if not INTERLEAVED: - # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves 
of OUT - X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) - COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) - SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) - cos = tl.load( - COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0 - ).to(tl.float32) - sin = tl.load( - SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0 - ).to(tl.float32) - x0 = tl.load( - X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0 - ).to(tl.float32) - x1 = tl.load( - X + rotary_dim_half * stride_x_headdim, - mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), - other=0.0, - ).to(tl.float32) - if CONJUGATE: - sin = -sin - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - # write back result - OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) - tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) - tl.store( - OUT + rotary_dim_half * stride_out_headdim, - o1, - mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), - ) - else: - # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow. - # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...]. - # Loading x0 will be fast but x1 will be slow. - # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...]. - # Then we do the calculation and use tl.where to pick put the right outputs for the even - # and for the odd indices. - rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... - rk_repeat = tl.arange(0, BLOCK_K) // 2 - X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) - X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim) - COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) - SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) - cos = tl.load( - COS, - mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), - other=1.0, - ).to(tl.float32) - sin = tl.load( - SIN, - mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), - other=0.0, - ).to(tl.float32) - x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to( - tl.float32 - ) - x1 = tl.load( - X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 - ).to(tl.float32) - if CONJUGATE: - sin = -sin - x0_cos = x0 * cos - x1_sin = x1 * sin - out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) - OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) - tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) - - -def apply_rotary( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - interleaved=False, - inplace=False, - conjugate=False, -) -> torch.Tensor: - """ - Arguments: - x: (batch, seqlen, nheads, headdim) if cu_seqlens is None - else (total_seqlen, nheads, headdim). 
- cos: (seqlen_ro, rotary_dim / 2) - sin: (seqlen_ro, rotary_dim / 2) - seqlen_offsets: integer or integer tensor of size (batch,) - cu_seqlens: (batch + 1,) or None - max_seqlen: int - Returns: - y: (batch, seqlen, nheads, headdim) - """ - is_varlen = cu_seqlens is not None - if not is_varlen: - batch, seqlen, nheads, headdim = x.shape - else: - assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" - total_seqlen, nheads, headdim = x.shape - batch_p_1 = cu_seqlens.shape[0] - batch = batch_p_1 - 1 - seqlen = max_seqlen - seqlen_ro, rotary_dim = cos.shape - assert sin.shape == cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim, "rotary_dim must be <= headdim" - assert headdim <= 256, "Only support headdim <= 256" - assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" - - assert ( - cos.dtype == sin.dtype - ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" - assert ( - x.dtype == cos.dtype - ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" - - cos, sin = cos.contiguous(), sin.contiguous() - if isinstance(seqlen_offsets, torch.Tensor): - assert seqlen_offsets.shape == (batch,) - assert seqlen_offsets.dtype in [torch.int32, torch.int64] - seqlen_offsets = seqlen_offsets.contiguous() - else: - assert seqlen_offsets + seqlen <= seqlen_ro - - output = torch.empty_like(x) if not inplace else x - if rotary_dim < headdim and not inplace: - output[..., rotary_dim:].copy_(x[..., rotary_dim:]) - - BLOCK_K = ( - 32 - if rotary_dim <= 32 - else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) - ) - grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa - BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) - - # Need this, otherwise Triton tries to launch from cuda:0 and we get - # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) - with torch.cuda.device(x.device.index): - rotary_kernel[grid]( - output, # data ptrs - x, - cos, - sin, - cu_seqlens, - seqlen_offsets, - seqlen, # shapes - nheads, - rotary_dim, - seqlen_ro, - seqlen // 128, # key for triton cache (limit number of compilations) - output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 - output.stride(-3), # seqlen_stride or total_seqlen_stride - output.stride(-2), # nheads_stride - output.stride(-1), # headdim_stride - x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 - x.stride(-3), # seqlen stride or total_seqlen_stride - x.stride(-2), # nheads stride - x.stride(-1), # headdim stride - BLOCK_K, - isinstance(seqlen_offsets, torch.Tensor), - is_varlen, - interleaved, - conjugate, - BLOCK_M, - ) - return output diff --git a/vllm/thirdparty_files/flash_attn/utils/__init__.py b/vllm/thirdparty_files/flash_attn/utils/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/thirdparty_files/flash_attn/utils/benchmark.py b/vllm/thirdparty_files/flash_attn/utils/benchmark.py deleted file mode 100644 index 15b30405f209..000000000000 --- a/vllm/thirdparty_files/flash_attn/utils/benchmark.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright (c) 2023, Tri Dao. -""" Useful functions for writing test code. 
""" - -import torch -import torch.utils.benchmark as benchmark - - -def benchmark_forward( - fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs -): - """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" - if verbose: - print(desc, "- Forward pass") - - def amp_wrapper(*inputs, **kwinputs): - with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): - fn(*inputs, **kwinputs) - - t = benchmark.Timer( - stmt="fn_amp(*inputs, **kwinputs)", - globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, - num_threads=torch.get_num_threads(), - ) - m = t.timeit(repeats) - if verbose: - print(m) - return t, m - - -def benchmark_backward( - fn, - *inputs, - grad=None, - repeats=10, - desc="", - verbose=True, - amp=False, - amp_dtype=torch.float16, - **kwinputs, -): - """Use Pytorch Benchmark on the backward pass of an arbitrary function.""" - if verbose: - print(desc, "- Backward pass") - with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): - y = fn(*inputs, **kwinputs) - if type(y) is tuple: - y = y[0] - if grad is None: - grad = torch.randn_like(y) - else: - if grad.shape != y.shape: - raise RuntimeError("Grad shape does not match output shape") - - def f(*inputs, y, grad): - # Set .grad to None to avoid extra operation of gradient accumulation - for x in inputs: - if isinstance(x, torch.Tensor): - x.grad = None - y.backward(grad, retain_graph=True) - - t = benchmark.Timer( - stmt="f(*inputs, y=y, grad=grad)", - globals={"f": f, "inputs": inputs, "y": y, "grad": grad}, - num_threads=torch.get_num_threads(), - ) - m = t.timeit(repeats) - if verbose: - print(m) - return t, m - - -def benchmark_combined( - fn, - *inputs, - grad=None, - repeats=10, - desc="", - verbose=True, - amp=False, - amp_dtype=torch.float16, - **kwinputs, -): - """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" - if verbose: - print(desc, "- Forward + Backward pass") - with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): - y = fn(*inputs, **kwinputs) - if type(y) is tuple: - y = y[0] - if grad is None: - grad = torch.randn_like(y) - else: - if grad.shape != y.shape: - raise RuntimeError("Grad shape does not match output shape") - - def f(grad, *inputs, **kwinputs): - for x in inputs: - if isinstance(x, torch.Tensor): - x.grad = None - with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): - y = fn(*inputs, **kwinputs) - if type(y) is tuple: - y = y[0] - y.backward(grad, retain_graph=True) - - t = benchmark.Timer( - stmt="f(grad, *inputs, **kwinputs)", - globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs}, - num_threads=torch.get_num_threads(), - ) - m = t.timeit(repeats) - if verbose: - print(m) - return t, m - - -def benchmark_fwd_bwd( - fn, - *inputs, - grad=None, - repeats=10, - desc="", - verbose=True, - amp=False, - amp_dtype=torch.float16, - **kwinputs, -): - """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" - return ( - benchmark_forward( - fn, - *inputs, - repeats=repeats, - desc=desc, - verbose=verbose, - amp=amp, - amp_dtype=amp_dtype, - **kwinputs, - ), - benchmark_backward( - fn, - *inputs, - grad=grad, - repeats=repeats, - desc=desc, - verbose=verbose, - amp=amp, - amp_dtype=amp_dtype, - **kwinputs, - ), - ) - - -def benchmark_all( - fn, - *inputs, - grad=None, - repeats=10, - desc="", - verbose=True, - amp=False, - amp_dtype=torch.float16, - **kwinputs, -): - """Use Pytorch 
Benchmark on the forward+backward pass of an arbitrary function.""" - return ( - benchmark_forward( - fn, - *inputs, - repeats=repeats, - desc=desc, - verbose=verbose, - amp=amp, - amp_dtype=amp_dtype, - **kwinputs, - ), - benchmark_backward( - fn, - *inputs, - grad=grad, - repeats=repeats, - desc=desc, - verbose=verbose, - amp=amp, - amp_dtype=amp_dtype, - **kwinputs, - ), - benchmark_combined( - fn, - *inputs, - grad=grad, - repeats=repeats, - desc=desc, - verbose=verbose, - amp=amp, - amp_dtype=amp_dtype, - **kwinputs, - ), - ) - - -def pytorch_profiler( - fn, - *inputs, - trace_filename=None, - backward=False, - amp=False, - amp_dtype=torch.float16, - cpu=False, - verbose=True, - **kwinputs, -): - """Wrap benchmark functions in Pytorch profiler to see CUDA information.""" - if backward: - with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): - out = fn(*inputs, **kwinputs) - if type(out) is tuple: - out = out[0] - g = torch.randn_like(out) - for _ in range(30): # Warm up - if backward: - for x in inputs: - if isinstance(x, torch.Tensor): - x.grad = None - with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): - out = fn(*inputs, **kwinputs) - if type(out) is tuple: - out = out[0] - # Backward should be done outside autocast - if backward: - out.backward(g, retain_graph=True) - activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [ - torch.profiler.ProfilerActivity.CUDA - ] - with torch.profiler.profile( - activities=activities, - record_shapes=True, - # profile_memory=True, - with_stack=True, - ) as prof: - if backward: - for x in inputs: - if isinstance(x, torch.Tensor): - x.grad = None - with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): - out = fn(*inputs, **kwinputs) - if type(out) is tuple: - out = out[0] - if backward: - out.backward(g, retain_graph=True) - if verbose: - # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) - print(prof.key_averages().table(row_limit=50)) - if trace_filename is not None: - prof.export_chrome_trace(trace_filename) - - -def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs): - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - torch.cuda.synchronize() - fn(*inputs, **kwinputs) - torch.cuda.synchronize() - mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000) - if verbose: - print(f"{desc} max memory: {mem}GB") - torch.cuda.empty_cache() - return mem diff --git a/vllm/thirdparty_files/flash_attn/utils/distributed.py b/vllm/thirdparty_files/flash_attn/utils/distributed.py deleted file mode 100644 index 74c55279645c..000000000000 --- a/vllm/thirdparty_files/flash_attn/utils/distributed.py +++ /dev/null @@ -1,144 +0,0 @@ -from typing import Optional - -import torch -from torch import Tensor -from torch.distributed import ProcessGroup - -# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for -# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent -# version of PyTorch. The following 4 lines are for backward compatibility with -# older PyTorch. 
-if "all_gather_into_tensor" not in dir(torch.distributed): - torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base -if "reduce_scatter_tensor" not in dir(torch.distributed): - torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base - - -# Raw operation, does not support autograd, but does support async -def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): - world_size = torch.distributed.get_world_size(process_group) - output = torch.empty( - world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device - ) - handle = torch.distributed.all_gather_into_tensor( - output, input_.contiguous(), group=process_group, async_op=async_op - ) - return output, handle - - -# Raw operation, does not support autograd, but does support async -def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): - world_size = torch.distributed.get_world_size(process_group) - assert input_.shape[0] % world_size == 0 - output = torch.empty( - input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device - ) - handle = torch.distributed.reduce_scatter_tensor( - output, input_.contiguous(), group=process_group, async_op=async_op - ) - return output, handle - - -# Raw operation, does not support autograd, but does support async -def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): - input_ = input_.contiguous() - handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) - return input_, handle - - -class AllGatherFunc(torch.autograd.Function): - """Gather the input from sequence parallel region and concatenate.""" - - @staticmethod - def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: - ctx.process_group = process_group - output, _ = all_gather_raw(input_, process_group) - return output - - @staticmethod - def backward(ctx, grad_output: Tensor): - grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group) - return grad_input, None - - -# Supports autograd, but does not support async -all_gather = AllGatherFunc.apply - - -class ReduceScatterFunc(torch.autograd.Function): - """Reduce scatter the input from the sequence parallel region and concatenate.""" - - @staticmethod - def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: - ctx.process_group = process_group - output, _ = reduce_scatter_raw(input_, process_group) - return output - - @staticmethod - def backward(ctx, grad_output: Tensor): - grad_input, _ = all_gather_raw(grad_output, ctx.process_group) - return grad_input, None - - -# Supports autograd, but does not support async -reduce_scatter = ReduceScatterFunc.apply - - -class AllReduceFunc(torch.autograd.Function): - """Gather the input from sequence parallel region and concatenate.""" - - @staticmethod - def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: - ctx.process_group = process_group - output, _ = all_reduce_raw(input_, process_group) - return output - - @staticmethod - def backward(ctx, grad_output: Tensor): - return grad_output, None - - -# Supports autograd, but does not support async -all_reduce = AllReduceFunc.apply - - -def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup): - # We want to iterate over parameters with _shared_params=True in the same order, - # as different ranks might have different number of parameters (e.g., only rank 0 has bias). 
- pamams_shared = { - name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False) - } - for _, p in sorted(pamams_shared.items()): - with torch.no_grad(): - # Broadcast needs src to be global rank, not group rank - torch.distributed.broadcast( - p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group - ) - - -# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256 -def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup): - # We want to iterate over parameters with _sequence_parallel=True in the same order, - # as different ranks might have different number of parameters (e.g., only rank 0 has bias). - params_seqparallel = { - name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False) - } - grads = [p.grad for _, p in sorted(params_seqparallel.items())] - if grads: - with torch.no_grad(): - coalesced = torch._utils._flatten_dense_tensors(grads) - torch.distributed.all_reduce(coalesced, group=process_group) - for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)): - buf.copy_(synced) - - -def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int: - """Get the dim for the local rank derived from splitting dim on world_size processes. - - The split may not be even across the world_size processes. - """ - multiple = dim // multiple_of - div = multiple // world_size - mod = multiple % world_size - local_multiple = div + int(local_rank < mod) - return local_multiple * multiple_of diff --git a/vllm/thirdparty_files/flash_attn/utils/generation.py b/vllm/thirdparty_files/flash_attn/utils/generation.py deleted file mode 100644 index d5d1139033aa..000000000000 --- a/vllm/thirdparty_files/flash_attn/utils/generation.py +++ /dev/null @@ -1,735 +0,0 @@ -# Copyright (c) 2023, Tri Dao. 
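As a quick check on the uneven-split helper `get_dim_for_local_rank` removed just above (same arithmetic restated here for illustration): splitting dim=10 over world_size=4 gives local dims 3, 3, 2, 2, with earlier ranks absorbing the remainder.

    def get_dim_for_local_rank(dim, world_size, local_rank, multiple_of=1):
        # Distribute dim // multiple_of units as evenly as possible.
        multiple = dim // multiple_of
        div, mod = divmod(multiple, world_size)
        return (div + int(local_rank < mod)) * multiple_of

    assert [get_dim_for_local_rank(10, 4, r) for r in range(4)] == [3, 3, 2, 2]
    assert sum(get_dim_for_local_rank(10, 4, r) for r in range(4)) == 10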
-# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31 -import gc -import time -from collections import namedtuple -from dataclasses import dataclass, field -from functools import partial -from typing import Callable, Optional, Sequence, Union - -import torch -import torch.nn.functional as F -from einops import rearrange, repeat -from torch import Tensor -from torch.profiler import ProfilerActivity, profile, record_function -from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput - - -@dataclass -class InferenceParams: - """Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference.""" - - max_seqlen: int - max_batch_size: int - seqlen_offset: int = 0 - batch_size_offset: int = 0 - key_value_memory_dict: dict = field(default_factory=dict) - lengths_per_sample: Optional[Tensor] = None - - def reset(self, max_seqlen, max_batch_size): - self.max_seqlen = max_seqlen - self.max_batch_size = max_batch_size - self.seqlen_offset = 0 - if self.lengths_per_sample is not None: - self.lengths_per_sample.zero_() - - -# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py -# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 -def modify_logits_for_top_k_filtering(logits, top_k): - """Set the logits for none top-k values to -inf. Done in-place.""" - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits.masked_fill_(indices_to_remove, float("-Inf")) - - -# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py -# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 -def modify_logits_for_top_p_filtering(logits, top_p): - """Set the logits for none top-p values to -inf. Done in-place.""" - if top_p <= 0.0 or top_p >= 1.0: - return - # First sort and calculate cumulative sum of probabilities. - sorted_logits, sorted_indices = torch.sort(logits, descending=False) - cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) - # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) - sorted_indices_to_remove = cumulative_probs <= (1 - top_p) - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter( - 1, sorted_indices, sorted_indices_to_remove - ) - logits.masked_fill_(indices_to_remove, float("-inf")) - - -def sample(logits, top_k=1, top_p=0.0, temperature=1.0): - """Sample from top-k logits. - Arguments: - logits: Tensor of shape (batch_size, vocab_size) - """ - if top_k == 1: # Short-circuit for greedy decoding - return logits.argmax(dim=-1) - else: - if top_p > 0.0: - assert top_p <= 1.0, "top-p should be in (0, 1]." 
- if top_k > 0: - top_k = min(top_k, logits.size(-1)) # Safety check - logits_top, indices = torch.topk(logits, top_k, dim=-1) - if temperature != 1.0: - logits_top /= temperature - modify_logits_for_top_p_filtering(logits_top, top_p) - return indices[ - torch.arange(indices.shape[0], device=indices.device), - torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1), - ] - else: - # Clone so that when we modify for top_p we don't change the original logits - logits_top = logits / temperature if temperature != 1.0 else logits.clone() - modify_logits_for_top_p_filtering(logits_top, top_p) - return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze( - dim=-1 - ) - - -@torch.inference_mode() -def decode( - input_ids, - model, - max_length, - top_k=1, - top_p=0.0, - temperature=1.0, - eos_token_id=None, - teacher_outputs=None, - vocab_size=None, - tensor_parallel=1, - cg=False, - enable_timing=False, -): - """Decoding, either greedy or with top-k or top-p sampling. - If top-k = 0, don't limit the number of candidates (pure sampling). - Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, - then top-p. - We assume that all sequences in the same batch have the same length. - - Arguments: - input_ids: (batch, seq_len) - max_length: int - teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the - logits, the next token is taken from the teacher_outputs. Useful for testing. - Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: - sequences: (batch, max_length) - scores: tuples of (batch, vocab_size) - """ - batch_size, seqlen_og = input_ids.shape - teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 - if cg: - if not hasattr(model, "_decoding_cache"): - model._decoding_cache = None - model._decoding_cache = update_graph_cache( - model, - model._decoding_cache, - batch_size, - seqlen_og, - max_length, - tensor_parallel=tensor_parallel, - ) - inference_params = model._decoding_cache.inference_params - inference_params.reset(max_length, batch_size) - else: - inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) - - def get_logits(input_ids, inference_params): - decoding = inference_params.seqlen_offset > 0 - if decoding: - position_ids = torch.full( - (batch_size, 1), - inference_params.seqlen_offset, - dtype=torch.long, - device=input_ids.device, - ) - else: - position_ids = None - if not cg or not decoding: - logits = model( - input_ids, - position_ids=position_ids, - inference_params=inference_params, - num_last_tokens=1, - ).logits.squeeze(dim=1) - else: - logits = model._decoding_cache.run( - input_ids, position_ids, inference_params.seqlen_offset - ).squeeze(dim=1) - return logits[..., :vocab_size] if vocab_size is not None else logits - - def sample_tokens(logits, inference_params): - if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: - token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) - else: - token = teacher_outputs[:, inference_params.seqlen_offset] - # return rearrange(token, "b -> b 1") - return token.unsqueeze(1) - - def should_stop(current_token, inference_params): - if inference_params.seqlen_offset == 0: - return False - if eos_token_id is not None and (current_token == eos_token_id).all(): - return True - if inference_params.seqlen_offset >= max_length - 1: - return True - return False - - start = 
torch.cuda.Event(enable_timing=enable_timing) - end = torch.cuda.Event(enable_timing=enable_timing) - - if enable_timing: - if tensor_parallel > 1: - torch.distributed.barrier() - start.record() - scores, sequences = [], [input_ids] - while not should_stop(sequences[-1], inference_params): - scores.append(get_logits(sequences[-1], inference_params)) - inference_params.seqlen_offset += sequences[-1].shape[1] - sequences.append(sample_tokens(scores[-1], inference_params)) - if enable_timing: - end.record() - if tensor_parallel > 1: - torch.distributed.barrier() - torch.cuda.synchronize() - print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms") - output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput - return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) - - -def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0): - """Algorithm 1 from [1] - [1] Fast Inference from Transformers via Speculative Decoding - Yaniv Leviathan, Matan Kalman, Yossi Matias - https://arxiv.org/abs/2211.17192 - - Arguments: - logits: Tensor of shape (batch_size, seqlen + 1, vocab_size) - logits_draft: Tensor of shape (batch_size, seqlen, vocab_size) - tokens_draft: Tensor of shape (batch_size, seqlen) - Return: - tokens: Tensor of shape (batch_size, seqlen + 1) - num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1]. - For each sequence in the batch, the number of valid tokens that were sampled by - speculative sampling. - """ - batch, seqlen_p_1, vocab_size = logits.shape - seqlen = seqlen_p_1 - 1 - assert logits_draft.shape == (batch, seqlen, vocab_size) - assert tokens_draft.shape == (batch, seqlen) - assert tokens_draft.dtype in [torch.int64, torch.int32] - # TODO: if top_k = 1 we can simplify things and only work with indices - if top_p > 0.0: - assert top_p <= 1.0, "top-p should be in (0, 1]." - # Clone so that when we modify for top_p we don't change the original logits - logits = logits / temperature if temperature != 1.0 else logits.clone() - logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone() - if top_k > 0: - top_k = min(top_k, logits.size(-1)) # Safety check - modify_logits_for_top_k_filtering(logits, top_k) - modify_logits_for_top_k_filtering(logits_draft, top_k) - modify_logits_for_top_p_filtering(logits, top_p) - modify_logits_for_top_p_filtering(logits_draft, top_p) - probs = torch.softmax(logits, dim=-1) - probs_draft = torch.softmax(logits_draft, dim=-1) - gather = lambda probs, tokens: rearrange( - probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..." 
- ) - # (batch, seqlen) - accepted = torch.rand(batch, seqlen, device=probs.device) * gather( - probs_draft, tokens_draft - ) <= gather(probs[:, :-1], tokens_draft) - accepted_all = accepted.all(dim=-1) - # (batch,) - first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1)) - probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0) - # torch.multinomial can deal with unnormalized probabilities - # probs_diff /= probs_diff.sum(dim=-1, keepdim=True) - resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1) - resample_probs = rearrange( - resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)), - "b 1 d -> b d", - ) - resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1) # (batch,) - tokens = F.pad(tokens_draft, (0, 1)) - tokens[:, first_rejected_idx] = resample - return tokens, first_rejected_idx + 1 - - -@torch.inference_mode() -def decode_speculative( - input_ids, - model, - model_draft, - max_length, - speculative_lookahead=3, - top_k=1, - top_p=0.0, - temperature=1.0, - eos_token_id=None, - vocab_size=None, - tensor_parallel=1, - cg=False, - enable_timing=False, - debug=False, -): - """ - TD: WIP, for my own understanding, lightly tested. Only support batch_size == 1 for now. - - Speculative decoding, either greedy or with top-k or top-p sampling. - If top-k = 0, don't limit the number of candidates (pure sampling). - Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, - then top-p. - We assume that all sequences in the same batch have the same length. - - Arguments: - input_ids: (batch, seq_len) - max_length: int - Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: - sequences: (batch, max_length) - scores: tuples of (batch, vocab_size) - """ - batch_size, seqlen_og = input_ids.shape - assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1" - assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id" - if cg: - if not hasattr(model_draft, "_decoding_cache"): - model_draft._decoding_cache = None - model_draft._decoding_cache = update_graph_cache( - model_draft, - model_draft._decoding_cache, - batch_size, - seqlen_og, - max_length, - # draft model needs to process either 1 or 2 tokens at a time - decoding_seqlens=(1, 2), - tensor_parallel=tensor_parallel, - ) - inference_params_draft = model_draft._decoding_cache.inference_params - inference_params_draft.reset(max_length, batch_size) - if not hasattr(model, "_decoding_cache"): - model._decoding_cache = None - model._decoding_cache = update_graph_cache( - model, - model._decoding_cache, - batch_size, - seqlen_og, - max_length, - decoding_seqlens=range(1, speculative_lookahead + 2), - tensor_parallel=tensor_parallel, - ) - inference_params = model._decoding_cache.inference_params - inference_params.reset(max_length, batch_size) - else: - inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) - inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) - - def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False): - decoding = inference_params.seqlen_offset > 0 - if decoding: - seqlen = input_ids.shape[1] - # if inference_params.lengths_per_sample is None: - # TODO: in the case of batched decoding where each sequence has a different length, - # we need to compute the position_ids for each sequence using 
lengths_per_sample - if True: - cache_seqlens = torch.full( - (input_ids.shape[0],), - inference_params.seqlen_offset, - dtype=torch.int32, - device=input_ids.device, - ) - else: - cache_seqlens = inference_params.lengths_per_sample - position_ids = cache_seqlens[:, None] + torch.arange( - seqlen, dtype=torch.long, device=input_ids.device - ) - else: - position_ids = None - if not cg or not decoding: - logits = model( - input_ids, - position_ids=position_ids, - inference_params=inference_params, - num_last_tokens=num_last_tokens, - ).logits - else: - # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1]. - # This might not be compatible the num_last_tokens used here. - assert num_last_tokens <= input_ids.shape[1] - logits = model._decoding_cache.run( - input_ids, position_ids, inference_params.seqlen_offset - )[:, -num_last_tokens:] - return logits[..., :vocab_size] if vocab_size is not None else logits - - def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1): - """Sample `num_tokens` tokens from the model, given the previous logits. - Also return the logits of the sampled tokens. - Arguments: - input_ids: (batch, seqlen) - Return: - tokens: (batch, num_tokens) - scores: (batch, num_tokens), which contains @previous_logits and the logits of the next - (num_tokens - 1) tokens. The logits of the last token isn't computed. - """ - assert num_tokens >= 1 - sequences, scores = [input_ids], [] - for i in range(num_tokens): - scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1]) - inference_params.seqlen_offset += sequences[-1].shape[1] - sequences.append(sample_fn(scores[-1]).unsqueeze(1)) - return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1) - - sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature) - sample_fn = partial(sample, **sampling_kwargs) - get_logits_main = partial(get_logits, model=model, cg=cg) - get_logits_draft = partial(get_logits, model=model_draft, cg=cg) - sample_tokens_main = partial( - sample_tokens, - get_logits_fn=get_logits_main, - sample_fn=sample_fn, - inference_params=inference_params, - ) - sample_tokens_draft = partial( - sample_tokens, - get_logits_fn=get_logits_draft, - sample_fn=sample_fn, - inference_params=inference_params_draft, - ) - - if debug: - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained("gpt2") - if enable_timing: - if tensor_parallel > 1: - torch.distributed.barrier() - torch.cuda.synchronize() - start = time.time() - - sequences, scores = [input_ids], [] - num_main_model_calls = 0 - num_draft_tokens = 0 - num_accepted_tokens_history = [] - if seqlen_og >= max_length - 1: - # Don't do speculative sampling, just sample 1 token from the model - tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1) - sequences.append(tokens) - scores.append(scores_new) - else: - # Sample from draft model, which produces @n_spec_tokens, and @model - # will then use to produce between 1 and 1 + @n_spec_tokens tokens. - # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length. 
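How many of those draft tokens survive is decided by the accept/resample rule in `sample_speculative` above. Stripped of batching, temperature, and top-k/top-p filtering, it amounts to this single-sequence sketch (the name `accept_or_resample` and the unbatched shapes are illustrative; p holds target-model probabilities for each draft position plus one extra row, q holds draft-model probabilities):

    import torch

    def accept_or_resample(p, q, draft_tokens):
        n = draft_tokens.numel()
        for i, t in enumerate(draft_tokens.tolist()):
            # Accept draft token t with probability min(1, p[i, t] / q[i, t]).
            if torch.rand(()) * q[i, t] <= p[i, t]:
                continue
            # First rejection: resample from the residual max(p - q, 0);
            # torch.multinomial accepts unnormalized weights.
            residual = torch.clamp(p[i] - q[i], min=0.0)
            return torch.cat([draft_tokens[:i], torch.multinomial(residual, 1)]), i + 1
        # All draft tokens accepted: sample one bonus token from the extra row of p.
        return torch.cat([draft_tokens, torch.multinomial(p[n], 1)]), n + 1

The batched code above expresses the same rule with `accepted = rand * q(t) <= p(t)`, `first_rejected_idx`, and `probs_diff = clamp(p - q, min=0)`.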
- n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1) - tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens) - num_draft_tokens += n_spec_tokens - if debug: - scores_draft_ref = model_draft( - torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 - ).logits - print((scores_draft - scores_draft_ref[:, :-1]).abs().max()) - - # Evaluate the draft tokens with the model - logits = get_logits_main( - torch.cat([input_ids, tokens_draft], dim=1), - inference_params, - num_last_tokens=n_spec_tokens + 1, - ) - num_main_model_calls += 1 - if debug: - logits_ref = model( - torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 - ).logits - print((logits - logits_ref).abs().max()) - # breakpoint() - tokens, num_generated_tokens = sample_speculative( - logits, scores_draft, tokens_draft, **sampling_kwargs - ) - num_accepted_tokens_history.append(num_generated_tokens - 1) - if debug: - print(tokens) - print(num_generated_tokens) - # breakpoint() - # TODO: we're using the fact that batch_size == 1 - # TODO: check eos_token_id - sequences.append(tokens[:1, : num_generated_tokens[0]]) - scores.append(logits[:1, : num_generated_tokens[0]]) - # Note that @model has not evaluated the last sampled token yet, so we'll need to pass - # that in the next time we call @model. - num_generated = num_generated_tokens[0].item() - inference_params.seqlen_offset = seqlen_og + num_generated - 1 - inference_params_draft.seqlen_offset = ( - inference_params.seqlen_offset - 1 - if num_generated > 1 - else inference_params.seqlen_offset - ) - if debug: - cur_ids = torch.cat([input_ids, sequences[-1]], dim=1) - scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits - print((scores[-1] - scores_ref[:, :-1]).abs().max()) - # breakpoint() - - while True: - # seqlen_offset is total length generated - 1 - if inference_params.seqlen_offset >= max_length - 1: - break - if inference_params.seqlen_offset >= max_length - 2: - # Don't do speculative sampling, just sample 1 token from the model - tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1) - sequences.append(tokens) - scores.append(scores_new) - break - # Sample from draft model - n_spec_tokens = min( - speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2 - ) - # If the main model accepts all the draft tokens, plus it samples one new token, - # then at the next iteration the draft model need to evaluate the logits of the last draft - # token and the logits of the newly sampled token. So here we pass in the last 2 tokens - # of sequences[-1]. - # This exception is when the main model rejects all the draft tokens, in which case we - # will only have 1 token to pass in. 
- tokens_draft, scores_draft = sample_tokens_draft( - sequences[-1][:, -2:], num_tokens=n_spec_tokens - ) - num_draft_tokens += n_spec_tokens - if debug: - scores_draft_ref = model_draft( - torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 - ).logits - print((scores_draft - scores_draft_ref[:, :-1]).abs().max()) - # breakpoint() - # Evaluate the draft tokens with the model - logits = get_logits_main( - torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1), - inference_params, - num_last_tokens=n_spec_tokens + 1, - ) # (batch, n_spec_tokens + 1, vocab_size) - num_main_model_calls += 1 - if debug: - logits_ref = model( - torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 - ).logits - print((logits - logits_ref).abs().max()) - # breakpoint() - tokens, num_generated_tokens = sample_speculative( - logits, scores_draft, tokens_draft, **sampling_kwargs - ) - num_accepted_tokens_history.append(num_generated_tokens - 1) - if debug: - print(tokens) - print(num_generated_tokens) - # breakpoint() - sequences.append(tokens[:1, : num_generated_tokens[0]]) - scores.append(logits[:1, : num_generated_tokens[0]]) - # We've evaluated 1 token from sequences[-1][:, -1:] above, plus - # num_generated_tokens[0].item() - 1 tokens from the draft model. - num_generated = num_generated_tokens[0].item() - inference_params.seqlen_offset += num_generated - inference_params_draft.seqlen_offset = ( - inference_params.seqlen_offset - 1 - if num_generated > 1 - else inference_params.seqlen_offset - ) - if debug: - cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1) - scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits - print((scores[-1] - scores_ref[:, :-1]).abs().max()) - # breakpoint() - - if enable_timing: - if tensor_parallel > 1: - torch.distributed.barrier() - torch.cuda.synchronize() - print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") - print(f"Number of calls to main model: {num_main_model_calls}") - print( - f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%" - ) - sequences = torch.cat(sequences, dim=1) - scores = torch.cat(scores, dim=1) - if debug: - scores_ref = model(sequences).logits - print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max()) - output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput - return output_cls(sequences=sequences, scores=scores) - - -class GenerationMixin: - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - raise NotImplementedError - - def generate( - self, - input_ids, - max_length, - top_k=1, - top_p=0.0, - temperature=1.0, - return_dict_in_generate=False, - output_scores=False, - **kwargs, - ): - output = decode( - input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs - ) - if not output_scores: - output.scores = None - return output if return_dict_in_generate else output.sequences - - -def allocate_inference_cache( - max_batch_size, - max_seqlen, - nheads, - headdim, - layers: Union[int, Sequence], - device, - dtype=torch.float16, -): - assert dtype in [torch.float16, torch.bfloat16, torch.float32] - kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim) - if isinstance(layers, int): - layers = range(layers) - return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers} - - -@dataclass -class DecodingCGCache: - max_batch_size: int = 0 - max_seqlen: int = 0 - device = 
None - dtype = None - callables: dict = field(default_factory=dict) - mempool = None - inference_params: Optional[InferenceParams] = None - run: Optional[Callable] = None - - -@torch.inference_mode() -def update_graph_cache( - model, - cache, - batch_size, - seqlen_og, - max_seqlen, - decoding_seqlens=(1,), - tensor_parallel=1, - dtype=None, - n_warmups=2, -): - if cache is None: - cache = DecodingCGCache() - param_example = next(iter(model.parameters())) - device = param_example.device - if dtype is None: - dtype = param_example.dtype - if ( - (device, dtype) != (cache.device, cache.dtype) - or batch_size > cache.max_batch_size - or max_seqlen > cache.max_seqlen - ): # Invalidate the cache - cache.callables = {} - cache.mempool = None - cache.inference_params = None - gc.collect() - cache.device, cache.dtype = device, dtype - cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen - if hasattr(model, "allocate_inference_cache"): - inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) - else: - headdim = getattr( - model.config, - "head_dim", - model.config.hidden_size // model.config.num_attention_heads, - ) - inf_cache = allocate_inference_cache( - batch_size, - max_seqlen, - model.config.num_attention_heads // tensor_parallel, - headdim, - model.config.num_hidden_layers, - device, - dtype, - ) - lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) - cache.inference_params = InferenceParams( - max_seqlen=max_seqlen, - max_batch_size=batch_size, - seqlen_offset=seqlen_og, - key_value_memory_dict=inf_cache, - lengths_per_sample=lengths_per_sample, - ) - cache.mempool = torch.cuda.graphs.graph_pool_handle() - for decoding_seqlen in decoding_seqlens: - if (batch_size, decoding_seqlen) not in cache.callables: - cache.callables[batch_size, decoding_seqlen] = capture_graph( - model, - cache.inference_params, - batch_size, - max_seqlen, - decoding_seqlen=decoding_seqlen, - mempool=cache.mempool, - n_warmups=n_warmups, - ) - - def dispatch(input_ids, position_ids, seqlen): - batch_size, decoding_seqlen = input_ids.shape[:2] - return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) - - cache.run = dispatch - cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing - return cache - - -def capture_graph( - model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 -): - device = next(iter(model.parameters())).device - input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) - position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) - seqlen_offset_og = inference_params.seqlen_offset - inference_params.seqlen_offset = max_seqlen - decoding_seqlen - inference_params.lengths_per_sample[:] = inference_params.seqlen_offset - - # Warmup before capture - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - for _ in range(n_warmups): - logits = model( - input_ids, - position_ids=position_ids, - inference_params=inference_params, - num_last_tokens=decoding_seqlen, - ).logits - s.synchronize() - # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, - # which requires that graph launch and non-captured launch to not overlap (I think, - # that's how I interpret the documentation). I'm not sure if this is required. 
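The rest of `capture_graph` below follows the standard CUDA Graphs recipe: warm up on a side stream, capture a single forward pass into static buffers, then serve later calls by copying fresh inputs into those buffers and replaying. A minimal sketch of the pattern, independent of the inference-params plumbing here (generic `model` callable, single input tensor, CUDA required; names are illustrative):

    import torch

    def make_graphed_forward(model, example_input, n_warmups=2):
        # Static input buffer: a captured graph always reads/writes the same
        # memory, so later inputs must be copied into the captured tensor.
        static_in = example_input.clone()

        # Warm up on a side stream so lazy allocations happen outside capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(n_warmups):
                model(static_in)
        torch.cuda.current_stream().wait_stream(s)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out = model(static_in)

        def run(new_input):
            static_in.copy_(new_input)  # feed new data into the captured buffer
            graph.replay()              # re-launch the recorded kernel sequence
            return static_out.clone()   # clone: buffer is overwritten on next replay

        return run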
- if torch.distributed.is_initialized(): - torch.distributed.barrier() - torch.cuda.current_stream().wait_stream(s) - # Captures the graph - # To allow capture, automatically sets a side stream as the current stream in the context - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, pool=mempool): - logits = model( - input_ids, - position_ids=position_ids, - inference_params=inference_params, - num_last_tokens=decoding_seqlen, - ).logits - - def run(new_input_ids, new_position_ids, seqlen): - inference_params.lengths_per_sample[:] = seqlen - input_ids.copy_(new_input_ids) - position_ids.copy_(new_position_ids) - graph.replay() - return logits.clone() - - inference_params.seqlen_offset = seqlen_offset_og - return run diff --git a/vllm/thirdparty_files/flash_attn/utils/pretrained.py b/vllm/thirdparty_files/flash_attn/utils/pretrained.py deleted file mode 100644 index 40e76bd26923..000000000000 --- a/vllm/thirdparty_files/flash_attn/utils/pretrained.py +++ /dev/null @@ -1,79 +0,0 @@ -import os -from functools import partial - -import torch -from safetensors.torch import load_file as safe_load_file -from transformers.utils import ( - SAFE_WEIGHTS_INDEX_NAME, - SAFE_WEIGHTS_NAME, - WEIGHTS_INDEX_NAME, - WEIGHTS_NAME, -) -from transformers.utils.hub import cached_file, get_checkpoint_shard_files - - -def state_dict_from_pretrained(model_name, device=None, dtype=None): - # If not fp32, then we don't want to load directly to the GPU - mapped_device = "cpu" if dtype not in [torch.float32, None] else device - is_sharded = False - load_safe = False - resolved_archive_file = None - - weights_path = os.path.join(model_name, WEIGHTS_NAME) - weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME) - safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME) - safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME) - - if os.path.isfile(weights_path): - resolved_archive_file = cached_file( - model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False - ) - elif os.path.isfile(weights_index_path): - resolved_archive_file = cached_file( - model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False - ) - is_sharded = True - elif os.path.isfile(safe_weights_path): - resolved_archive_file = cached_file( - model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False - ) - load_safe = True - elif os.path.isfile(safe_weights_index_path): - resolved_archive_file = cached_file( - model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False - ) - is_sharded = True - load_safe = True - else: # Try loading from HF hub instead of from local files - resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, - _raise_exceptions_for_missing_entries=False) - if resolved_archive_file is None: - resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, - _raise_exceptions_for_missing_entries=False) - if resolved_archive_file is not None: - is_sharded = True - - if resolved_archive_file is None: - raise EnvironmentError(f"Model name {model_name} was not found.") - - if load_safe: - loader = partial(safe_load_file, device=mapped_device) - else: - loader = partial(torch.load, map_location=mapped_device) - - if is_sharded: - # resolved_archive_file becomes a list of files that point to the different - # checkpoint shards in this case. 
- resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( - model_name, resolved_archive_file - ) - state_dict = {} - for sharded_file in resolved_archive_file: - state_dict.update(loader(sharded_file)) - else: - state_dict = loader(resolved_archive_file) - # Convert dtype before moving to GPU to save memory - if dtype is not None: - state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} - state_dict = {k: v.to(device=device) for k, v in state_dict.items()} - return state_dict From 97bcb6f4c3bd854f625c379c5d23d62d06107b8e Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 20 Mar 2024 18:53:16 -0700 Subject: [PATCH 66/88] ip --- .buildkite/test-pipeline.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index cf84d8d981c0..6d052d0f7f4a 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -46,10 +46,6 @@ steps: commands: - pytest -v -s prefix_caching -- label: Chunked Prefill Test - commands: - - pytest -v -s chunked_prefill - - label: Samplers Test command: pytest -v -s samplers From df343509a5568afe5bd5894c1a46bf701bd2f0cc Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 20 Mar 2024 19:16:23 -0700 Subject: [PATCH 67/88] working except tests. --- benchmarks/benchmark_latency.py | 28 +----- .../kernels/benchmark_paged_attention.py | 1 - tests/chunked_prefill/test_correctness.py | 99 ------------------- tests/conftest.py | 2 - tests/core/test_scheduler.py | 27 +++-- tests/samplers/test_sampler.py | 43 -------- vllm/config.py | 4 - vllm/core/scheduler.py | 1 - vllm/engine/arg_utils.py | 4 +- vllm/engine/llm_engine.py | 2 - vllm/entrypoints/llm.py | 1 - vllm/utils.py | 1 - vllm/worker/model_runner.py | 22 ++--- vllm/worker/worker.py | 4 - 14 files changed, 23 insertions(+), 216 deletions(-) delete mode 100644 tests/chunked_prefill/test_correctness.py diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index e9e2a883db83..432c2c288086 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -10,13 +10,6 @@ from vllm import LLM, SamplingParams -SAMPLE_PROMPTS = [ - "The president of the United States is", - "Hello, my name is", - "The capital of France is", - "The future of AI is", -] - def main(args: argparse.Namespace): print(args) @@ -35,7 +28,6 @@ def main(args: argparse.Namespace): device=args.device, block_size=args.block_size, max_chunked_prefill_len=args.max_chunked_prefill_len, - max_num_prompt_seqs=args.max_num_prompt_seqs, ray_workers_use_nsight=args.ray_workers_use_nsight, ) @@ -68,25 +60,16 @@ def run_to_completion(profile_dir: Optional[str] = None): print(p.key_averages()) else: start_time = time.perf_counter() - if args.use_sample: - batch = (SAMPLE_PROMPTS * - (args.batch_size // len(SAMPLE_PROMPTS) + - 1))[:args.batch_size] - outputs = llm.generate(prompts=batch, - sampling_params=sampling_params, - use_tqdm=False) - else: - outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids, - sampling_params=sampling_params, - use_tqdm=False) + outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids, + sampling_params=sampling_params, + use_tqdm=False) end_time = time.perf_counter() if args.verbose: for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print( - f"Prompt: {prompt!r}, Generated text: {generated_text!r}" - ) + print(f"Prompt: {prompt!r}, Generated text: " + f"{generated_text!r}") latency = end_time - start_time return latency @@ -182,7 +165,6 @@ def 
run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='print generated text') parser.add_argument('--max-chunked-prefill-len', type=int, default=-1) - parser.add_argument('--max-num-prompt-seqs', type=int, default=1000) parser.add_argument( "--ray-workers-use-nsight", action='store_true', diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index a39a0fb0b2bb..20261ccdce79 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -7,7 +7,6 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random from vllm._C import ops -from vllm.model_executor.layers.attention import flash_attn_with_kvcache_paged NUM_BLOCKS = 1024 PARTITION_SIZE = 512 diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py deleted file mode 100644 index 083d68eb5e78..000000000000 --- a/tests/chunked_prefill/test_correctness.py +++ /dev/null @@ -1,99 +0,0 @@ -import gc - -import pytest -import torch - -from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel - -MODELS = [ - "JackFram/llama-68m", - # "facebook/opt-125m", -] - -TEST_PROMPTS = [ - # pylint: disable=line-too-long - "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", - "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", - "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", - # Different between page attention and flash attention. - # "Describe the basic components of a neural network and how it can be trained.", - "Write a short story about a robot that dreams for the first time.", - "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.", - "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.", - "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'", -] - - -# TODO(sang): Add chunked prefill parameters. 
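For orientation: chunked prefill, which this patch series wires up, caps how many prompt tokens are prefilled in one scheduler step via `max_chunked_prefill_len`, so a long prompt is processed over several steps and its chunks can share a batch with other sequences' decode tokens. With 4-token prompts and `max_chunked_prefill_len=2` (the numbers asserted in the test_scheduler changes further below), the schedule looks like:

    # step 1: prefill tokens [0, 1] of seq A               -> 2 tokens batched
    # step 2: prefill tokens [2, 3] of seq A               -> 2 tokens batched, A done
    # step 3: 1 decode token for A + prefill [0, 1] of B   -> 3 tokens batched
    #         (the `out.num_batched_tokens == 3` assertion below)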
-@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("max_chunked_prefill_len", [16]) -@pytest.mark.parametrize("max_num_prompt_seqs", [1, 2, 100]) -@pytest.mark.parametrize("block_size", [32]) -@pytest.mark.parametrize("tensor_parallel_size", [1, 2]) -def test_models( - hf_runner, - vllm_runner, - model: str, - dtype: str, - max_tokens: int, - max_chunked_prefill_len: int, - max_num_prompt_seqs: int, - block_size: int, - tensor_parallel_size: int, -) -> None: - """ verify the flash attention has the same output - as page attention """ - if torch.cuda.device_count() < tensor_parallel_size: - pytest.skip( - f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" - ) - print("loading page attention models..") - pg_model = vllm_runner(model, dtype=dtype) - expected_outputs = [] - - print("generating tokens...") - expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) - print("generating tokens finished") - - del pg_model - - for i in range(len(TEST_PROMPTS)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = expected_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") - - destroy_model_parallel() - gc.collect() - torch.cuda.empty_cache() - - flash_attn_output_by_batches = [] - flash_attn_model = vllm_runner(model, - dtype=dtype, - block_size=block_size, - flash_style=True, - tensor_parallel_size=tensor_parallel_size) - for i in range(10): - prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] - flash_attn_output_by_batches.append( - flash_attn_model.generate_greedy(prompts, max_tokens)) - - del flash_attn_model - destroy_model_parallel() - gc.collect() - torch.cuda.empty_cache() - - for flash_attn_outputs in flash_attn_output_by_batches: - for i in range(len(flash_attn_outputs)): - fa_output_ids, fa_output_str = flash_attn_outputs[i] - vllm_output_ids, vllm_output_str = expected_outputs[ - i % len(expected_outputs)] - print(vllm_output_str) - assert fa_output_ids == vllm_output_ids, ( - f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" - f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" - ) diff --git a/tests/conftest.py b/tests/conftest.py index e254aeb4b8df..4bc635bf4f9a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -168,7 +168,6 @@ def __init__( tensor_parallel_size: int = 1, block_size: int = 16, max_chunked_prefill_len: int = -1, - max_num_prompt_seqs: int = 1000, max_num_batched_tokens: int = 4096, **kwargs, ) -> None: @@ -182,7 +181,6 @@ def __init__( tensor_parallel_size=tensor_parallel_size, block_size=block_size, max_chunked_prefill_len=max_chunked_prefill_len, - max_num_prompt_seqs=max_num_prompt_seqs, max_num_batched_tokens=max_num_batched_tokens, **kwargs, ) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 1540f2fed9e8..482c9393041f 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -134,15 +134,12 @@ def test_scheduler_schedule_chunked_prefill(): num_seq_group = 2 max_model_len = 16 max_chunked_prefill_len = 2 - max_num_prompt_seqs = 1 scheduler_config = SchedulerConfig( 64, num_seq_group, max_model_len, - flash_style=True, - max_chunked_prefill_len=max_chunked_prefill_len, - 
max_num_prompt_seqs=max_num_prompt_seqs) - cache_config = CacheConfig(block_size, 1.0, 1) + max_chunked_prefill_len=max_chunked_prefill_len) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 scheduler = Scheduler(scheduler_config, cache_config, None) @@ -150,18 +147,15 @@ def test_scheduler_schedule_chunked_prefill(): # Add seq groups to scheduler. seq_groups: List[SequenceGroup] = [] for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - num_processed_token_ids=block_size - - 1) + _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) scheduler.add_seq_group(seq_group) seq_groups.append(seq_group) # Schedule chunk prefill. Only the first seq_group should be scheduled. seq_group_meta, out = scheduler.schedule() assert set(out.scheduled_seq_groups) == set(seq_groups[:1]) - seq_groups[0].get_num_unprefilled() == 2 - seq_groups[1].get_num_unprefilled() == 4 + assert seq_groups[0].get_num_unprefilled() == 2 + assert seq_groups[1].get_num_unprefilled() == 4 assert out.num_batched_tokens == 2 assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -170,11 +164,12 @@ def test_scheduler_schedule_chunked_prefill(): assert seq_group_meta[0].is_chunked_prefill assert seq_group_meta[0].is_prompt - # Schedule chunk prefill. Still Only the first seq_group should be scheduled. + # Schedule chunk prefill. Still Only the first seq_group should be + # scheduled. seq_group_meta, out = scheduler.schedule() assert set(out.scheduled_seq_groups) == set(seq_groups[:1]) - seq_groups[0].get_num_unprefilled() == 0 - seq_groups[1].get_num_unprefilled() == 4 + assert seq_groups[0].get_num_unprefilled() == 0 + assert seq_groups[1].get_num_unprefilled() == 4 assert out.num_batched_tokens == 2 assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -187,8 +182,8 @@ def test_scheduler_schedule_chunked_prefill(): # for chunk prefill, and the first seq_group should be select for decoding. seq_group_meta, out = scheduler.schedule() assert set(out.scheduled_seq_groups) == set(seq_groups) - seq_groups[0].get_num_unprefilled() == 0 - seq_groups[1].get_num_unprefilled() == 2 + assert seq_groups[0].get_num_unprefilled() == 0 + assert seq_groups[1].get_num_unprefilled() == 2 assert out.num_batched_tokens == 3 assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index dda100d2d69e..34894d85ecf1 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -285,49 +285,6 @@ def test_sampling(model_runner: ModelRunner): del model_runner -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_logits_processors(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - input_tensor, _, sampler, model_runner = _prepare_test(batch_size) - - # This sample logits processor gives infinite score to the i-th token, - # where i is the length of the input sequence. - # We therefore expect the output token sequence to be [0, 1, 2, ...] 
- def pick_ith(token_ids, logits): - logits[len(token_ids)] = float("inf") - return logits - - seq_group_metadata_list = [] - prompt_lens = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - is_chunked_prefill=False, - seq_data={0: SequenceData([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0, - logits_processors=[pick_ith]), - block_tables={0: [1]}, - )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) - sampler_output = sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) - for _, sequence_output in enumerate(sampler_output): - for idx, nth_output in enumerate(sequence_output.samples): - assert nth_output.output_token == idx - - del model_runner - - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_sampler_top_k_top_p(seed: int, device: str): diff --git a/vllm/config.py b/vllm/config.py index e32609549d4a..dda4bd924ba1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -539,8 +539,6 @@ class SchedulerConfig: requests. Longer requests will be chunked into multiple chunks. -1 means no chunking (disabled). This features is only supported for flash style attention. - max_num_prompt_seqs: The maximum number of prompt sequences that can be - processed in a single iteration. """ def __init__( @@ -549,7 +547,6 @@ def __init__( max_num_seqs: int, max_model_len: int, max_chunked_prefill_len: int = -1, - max_num_prompt_seqs: int = 1024, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -561,7 +558,6 @@ def __init__( self.max_model_len = max_model_len self.chunked_prefill_enabled = max_chunked_prefill_len != -1 self.max_chunked_prefill_len = max_chunked_prefill_len - self.max_num_prompt_seqs = max_num_prompt_seqs self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 823090430e6c..064585bd8bc6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -371,7 +371,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: block_tables[seq_id] = self.block_manager.get_block_table(seq) self.block_manager.access_all_blocks_in_seq(seq, now) - # SANG-TODO Update chunked prefill related info. 
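The max_chunked_prefill_len semantics documented in the SchedulerConfig change above boil down to slicing a long prefill into fixed-size chunks, with -1 meaning the whole prompt is prefilled at once. A minimal illustration of that splitting, not the scheduler's actual code path:

def prefill_chunk_sizes(prompt_len: int, max_chunked_prefill_len: int):
    # -1 disables chunking: the whole prompt is prefilled in one step.
    if max_chunked_prefill_len == -1:
        return [prompt_len]
    sizes = []
    remaining = prompt_len
    while remaining > 0:
        step = min(remaining, max_chunked_prefill_len)
        sizes.append(step)
        remaining -= step
    return sizes

assert prefill_chunk_sizes(7, -1) == [7]
assert prefill_chunk_sizes(7, 2) == [2, 2, 2, 1]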
seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=scheduler_outputs.prompt_run, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8dad83fe33ef..b5143d5920aa 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -52,7 +52,6 @@ class EngineArgs: device: str = 'auto' ray_workers_use_nsight: bool = False max_chunked_prefill_len: int = -1 - max_num_prompt_seqs: int = 256 def __post_init__(self): if self.tokenizer is None: @@ -357,8 +356,7 @@ def create_engine_configs( self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, - max_chunked_prefill_len=self.max_chunked_prefill_len, - max_num_prompt_seqs=self.max_num_prompt_seqs) + max_chunked_prefill_len=self.max_chunked_prefill_len) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 46f15c5e3e7e..f9c21c1c65c2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -615,8 +615,6 @@ def step(self) -> List[RequestOutput]: >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - # print("SANG-TODO step seq_group_metadata_list length: ", - # len(seq_group_metadata_list)) if not scheduler_outputs.is_empty(): output = self.model_executor.execute_model( seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index efdf37030f61..1f463bdaaedc 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -146,7 +146,6 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ - # print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/utils.py b/vllm/utils.py index 99a63b611259..8fa372b5f7f0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -322,7 +322,6 @@ def create_kv_caches_with_random( key_caches.append(key_cache) value_cache_shape = (num_blocks, num_heads, head_size, block_size) - value_caches = [] for _ in range(num_layers): value_cache = torch.empty(size=value_cache_shape, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index eb6dff99dc31..c26822586ae4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -58,7 +58,6 @@ def __init__( # FIXME(woosuk): This is a hack to make the tests work. Refactor this. self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None) - self.device_config = (device_config if device_config is not None else DeviceConfig()) self.device = self.device_config.device @@ -145,33 +144,31 @@ def _prepare_prompt( subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] num_chunked_prefill = 0 - # Whether or not if any seq_group has prefix cached. - # print("SANG-TODO # of requests (seq_group_metadata_list): ", - # len(seq_group_metadata_list)) + for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 seq_id = seq_ids[0] + computed_block_nums = seq_group_metadata.computed_block_nums if seq_group_metadata.is_chunked_prefill: num_chunked_prefill += 1 - # TODO(sang): Support it. + # TODO(sang): Both are the same thing and should be handled + # in the same way. 
if computed_block_nums is not None: raise RuntimeError( - "chunked prefill cannot be used with prefix caching now." - ) + "chunked prefill cannot be used with prefix caching " + "now.") seq_data = seq_group_metadata.seq_data[seq_id] prefill_start, prefill_end = seq_data.get_prefill_range() prompt_tokens = seq_data.get_token_ids()[prefill_start:prefill_end] prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) - computed_len = 0 # NOTE: This only works for oooooooxxx style attention. - computed_block_nums = seq_group_metadata.computed_block_nums if computed_block_nums is not None and len( computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window @@ -189,9 +186,6 @@ def _prepare_prompt( input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - # NOTE(sang): prefill_end is always # of prompts if chunked - # prefill is not enabled. Prefix caching is not working with - # chunked prefill now. input_positions.extend( list(range(computed_len, computed_len + prefill_end))) @@ -558,15 +552,12 @@ def prepare_input_tensors( # NOTE: We assume that all sequences in the group are all prompts or # all decodes. is_prompt = seq_group_metadata_list[0].is_prompt - # SANG-TODO set num prompt tokens and generations? # Prepare input tensors. if is_prompt: - # print("SANG-TODO execute model prompt.") (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_prompt(seq_group_metadata_list) else: - # print("SANG-TODO execute model decode.") (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_decode(seq_group_metadata_list) @@ -800,7 +791,6 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: slot_mapping=slot_mapping[:batch_size], num_chunked_prefill=0, prompt_lens=None, - num_chunked_prefill=0, prompt_lens_tensor=None, num_prompt_tokens=0, num_generation_tokens=batch_size, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 49e0bec2056b..7e566095b275 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -113,7 +113,6 @@ def profile_num_available_blocks( gpu_memory_utilization: The fraction of the total GPU memory to use. cpu_swap_space: The size of the CPU swap space in bytes. """ - # print("SANG-TODO profile_num_available_blocks") # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() @@ -143,8 +142,6 @@ def profile_num_available_blocks( gc.collect() torch.cuda.empty_cache() - # print("SANG-TODO profile_num_available_blocks done") - return num_gpu_blocks, num_cpu_blocks def init_cache_engine(self, cache_config: CacheConfig) -> None: @@ -184,7 +181,6 @@ def execute_model( blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, ) -> Optional[SamplerOutput]: - # print("SANG-TODO execute model.") if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups = len(seq_group_metadata_list) From e70e03d1ccc403b53746e6e1b49f05fb9621b8bc Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 20 Mar 2024 19:17:00 -0700 Subject: [PATCH 68/88] . 
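The slot_mapping entries built in _prepare_prompt above are flat slot indices; the reshape_and_cache kernels recover the physical location by dividing by the block size. A small sketch of that arithmetic with zero-indexed blocks and offsets (values chosen only for illustration):

def slot_to_block(slot_idx: int, block_size: int):
    # Mirrors the block_idx / block_offset computation in reshape_and_cache.
    return slot_idx // block_size, slot_idx % block_size

# With block_size = 16: slot 35 is offset 3 of block 2, slot 2 is offset 2 of
# block 0, and slot 17 is offset 1 of block 1.
assert slot_to_block(35, 16) == (2, 3)
assert slot_to_block(2, 16) == (0, 2)
assert slot_to_block(17, 16) == (1, 1)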
--- csrc/cache_kernels.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 76709202b469..7254010b8e3a 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -174,7 +174,6 @@ __global__ void reshape_and_cache_kernel( const int64_t block_idx = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; - const int n = num_heads * head_size; for (int i = threadIdx.x; i < n; i += blockDim.x) { const int64_t src_key_idx = token_idx * key_stride + i; From f89f428e1bcfd9b2f7f641d30f5a7375055e8346 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 20 Mar 2024 21:49:27 -0700 Subject: [PATCH 69/88] ip --- benchmarks/benchmark_latency.py | 18 +++--------------- tests/conftest.py | 2 -- tests/core/test_scheduler.py | 1 + vllm/engine/llm_engine.py | 1 + vllm/worker/worker.py | 1 - 5 files changed, 5 insertions(+), 18 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 432c2c288086..3e668ab08877 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -60,16 +60,10 @@ def run_to_completion(profile_dir: Optional[str] = None): print(p.key_averages()) else: start_time = time.perf_counter() - outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids, - sampling_params=sampling_params, - use_tqdm=False) + llm.generate(prompt_token_ids=dummy_prompt_token_ids, + sampling_params=sampling_params, + use_tqdm=False) end_time = time.perf_counter() - if args.verbose: - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: " - f"{generated_text!r}") latency = end_time - start_time return latency @@ -158,12 +152,6 @@ def run_to_completion(profile_dir: Optional[str] = None): type=int, default=16, help='block size of key/value cache') - parser.add_argument('--use-sample', - action='store_true', - help='use sample input instead of dummy input') - parser.add_argument('--verbose', - action='store_true', - help='print generated text') parser.add_argument('--max-chunked-prefill-len', type=int, default=-1) parser.add_argument( "--ray-workers-use-nsight", diff --git a/tests/conftest.py b/tests/conftest.py index 4bc635bf4f9a..c682b80ac39d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -168,7 +168,6 @@ def __init__( tensor_parallel_size: int = 1, block_size: int = 16, max_chunked_prefill_len: int = -1, - max_num_batched_tokens: int = 4096, **kwargs, ) -> None: self.model = LLM( @@ -181,7 +180,6 @@ def __init__( tensor_parallel_size=tensor_parallel_size, block_size=block_size, max_chunked_prefill_len=max_chunked_prefill_len, - max_num_batched_tokens=max_num_batched_tokens, **kwargs, ) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 482c9393041f..892da752c514 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -152,6 +152,7 @@ def test_scheduler_schedule_chunked_prefill(): seq_groups.append(seq_group) # Schedule chunk prefill. Only the first seq_group should be scheduled. 
+ breakpoint() seq_group_meta, out = scheduler.schedule() assert set(out.scheduled_seq_groups) == set(seq_groups[:1]) assert seq_groups[0].get_num_unprefilled() == 2 diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f9c21c1c65c2..2280481cca9c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -615,6 +615,7 @@ def step(self) -> List[RequestOutput]: >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + if not scheduler_outputs.is_empty(): output = self.model_executor.execute_model( seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 7e566095b275..81beb5ce4d8d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -141,7 +141,6 @@ def profile_num_available_blocks( self.model_runner.remove_all_loras() gc.collect() torch.cuda.empty_cache() - return num_gpu_blocks, num_cpu_blocks def init_cache_engine(self, cache_config: CacheConfig) -> None: From bf02f8e04e35fefcf18cf575cb142212dffaf124 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 20 Mar 2024 21:54:36 -0700 Subject: [PATCH 70/88] done --- tests/core/test_scheduler.py | 68 ------------------------------------ 1 file changed, 68 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 892da752c514..57f6f381e2d5 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -129,74 +129,6 @@ def test_scheduler_schedule_preempt_abort(): assert scheduler.get_num_unfinished_seq_groups() == 1 -def test_scheduler_schedule_chunked_prefill(): - block_size = 4 - num_seq_group = 2 - max_model_len = 16 - max_chunked_prefill_len = 2 - scheduler_config = SchedulerConfig( - 64, - num_seq_group, - max_model_len, - max_chunked_prefill_len=max_chunked_prefill_len) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add seq groups to scheduler. - seq_groups: List[SequenceGroup] = [] - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) - scheduler.add_seq_group(seq_group) - seq_groups.append(seq_group) - - # Schedule chunk prefill. Only the first seq_group should be scheduled. - breakpoint() - seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(seq_groups[:1]) - assert seq_groups[0].get_num_unprefilled() == 2 - assert seq_groups[1].get_num_unprefilled() == 4 - assert out.num_batched_tokens == 2 - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 1 - assert seq_group_meta[0].request_id == "0" - assert seq_group_meta[0].is_chunked_prefill - assert seq_group_meta[0].is_prompt - - # Schedule chunk prefill. Still Only the first seq_group should be - # scheduled. - seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(seq_groups[:1]) - assert seq_groups[0].get_num_unprefilled() == 0 - assert seq_groups[1].get_num_unprefilled() == 4 - assert out.num_batched_tokens == 2 - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 1 - assert seq_group_meta[0].request_id == "0" - assert not seq_group_meta[0].is_chunked_prefill - assert seq_group_meta[0].is_prompt - - # Schedule chunk prefill. 
This time the second seq_group should be selected - # for chunk prefill, and the first seq_group should be select for decoding. - seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(seq_groups) - assert seq_groups[0].get_num_unprefilled() == 0 - assert seq_groups[1].get_num_unprefilled() == 2 - assert out.num_batched_tokens == 3 - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 2 - assert seq_group_meta[0].request_id == "1" - assert seq_group_meta[0].is_chunked_prefill - assert seq_group_meta[0].is_prompt - assert seq_group_meta[1].request_id == "0" - assert not seq_group_meta[1].is_chunked_prefill - assert not seq_group_meta[1].is_prompt - - def test_scheduler_max_seqs(): block_size = 4 num_seq_group = 4 From ad43095b4a842533092e22750fc7bf27a21b199f Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 20 Mar 2024 23:02:26 -0700 Subject: [PATCH 71/88] done --- tests/core/test_scheduler.py | 2 ++ tests/test_sequence.py | 14 ++------------ 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 57f6f381e2d5..66feeec90ad9 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -56,6 +56,8 @@ def test_scheduler_schedule_simple(): cache_config.num_gpu_blocks = 8 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. for i in range(num_seq_group): _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) scheduler.add_seq_group(seq_group) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index c948abe437a1..3f665b78606c 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,16 +1,6 @@ import pytest -from vllm.sequence import (SequenceData, Sequence, SequenceGroupOutput, - SamplerOutput, SequenceOutput) - - -@pytest.fixture(name="sequence") -def create_sequence(seq_len: int, block_size: int) -> Sequence: - return Sequence( - seq_id=0, - prompt="", - prompt_token_ids=list(range(seq_len)), - block_size=block_size, - ) +from vllm.sequence import (SequenceData, SequenceGroupOutput, SamplerOutput, + SequenceOutput) @pytest.fixture From 16b61965f84213d1ce407946e9e6efdd538ce84e Mon Sep 17 00:00:00 2001 From: sang Date: Fri, 22 Mar 2024 03:33:09 -0700 Subject: [PATCH 72/88] Addressed code review. 
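The chunked-prefill scheduler test removed above pins down the intended interleaving: with two 4-token prompts and max_chunked_prefill_len = 2, group 0 is prefilled over two steps, and in the third step group 1 starts prefilling while group 0 decodes. A back-of-the-envelope check of the num_batched_tokens values it asserted, under the assumption that each step batches the scheduled prefill chunks plus one token per decoding sequence:

steps = [
    {"prefill_chunks": [2], "num_decodes": 0},  # group 0, first half of its prompt
    {"prefill_chunks": [2], "num_decodes": 0},  # group 0, second half of its prompt
    {"prefill_chunks": [2], "num_decodes": 1},  # group 1 starts prefill, group 0 decodes
]
batched = [sum(s["prefill_chunks"]) + s["num_decodes"] for s in steps]
assert batched == [2, 2, 3]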
--- .../kernels/benchmark_paged_attention.py | 5 +-- tests/test_sequence.py | 8 ++--- vllm/core/scheduler.py | 7 +++- vllm/engine/arg_utils.py | 5 --- vllm/sequence.py | 32 +++++++++++++------ vllm/worker/model_runner.py | 22 ++++++------- 6 files changed, 44 insertions(+), 35 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 20261ccdce79..d921dea1220e 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -168,10 +168,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: type=int, choices=[64, 80, 96, 112, 128, 256], default=128) - parser.add_argument("--block-size", - type=int, - choices=[16, 32, 256], - default=16) + parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--dtype", type=str, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 3f665b78606c..bfc7e1deac27 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -53,21 +53,21 @@ def test_sampler_output_eq(sample_outputs): def test_sequence_data_prefill(): seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4]) assert seq_data.get_prefill_range() == (0, 0) - assert seq_data.get_num_unprefilled() == 4 + assert seq_data.get_num_uncomputed_tokens() == 4 # advance by 2 assert seq_data.advance_prefill_range(2) == 2 - assert seq_data.get_num_unprefilled() == 2 + assert seq_data.get_num_uncomputed_tokens() == 2 assert seq_data.get_prefill_range() == (0, 2) # advance range by 3 even though there are only 2 unprefilled tokens assert seq_data.advance_prefill_range(3) == 2 - assert seq_data.get_num_unprefilled() == 0 + assert seq_data.get_num_uncomputed_tokens() == 0 assert seq_data.get_prefill_range() == (2, 4) # following advances should not change anything assert seq_data.advance_prefill_range(2) == 0 - assert seq_data.get_num_unprefilled() == 0 + assert seq_data.get_num_uncomputed_tokens() == 0 assert seq_data.get_prefill_range() == (4, 4) # append tokens and reset, simulating recompute diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 064585bd8bc6..499a3953a461 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -390,6 +390,10 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_manager.fork(parent_seq, child_seq) def free_seq(self, seq: Sequence) -> None: + """Free a sequence from a block table. + + Freed sequence can be used + """ self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: @@ -452,7 +456,8 @@ def _preempt_by_recompute( assert len(seqs) == 1 for seq in seqs: seq.status = SequenceStatus.WAITING - self.block_manager.free(seq) + self.free_seq(seq) + seq.on_recompute() # NOTE: For FCFS, we insert the preempted sequence group to the front # of the waiting queue. 
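The get_num_uncomputed_tokens docstring above counts over prompt plus output because preemption by recomputation (see _preempt_by_recompute) discards the sequence's cached state entirely, so previously generated tokens must be prefilled again. A numeric sketch with assumed lengths:

prompt_len, output_len = 4, 3       # assumed example lengths
seq_len = prompt_len + output_len

num_computed = seq_len              # fully computed while running
assert seq_len - num_computed == 0

num_computed = 0                    # reset on preemption by recompute
assert seq_len - num_computed == 7  # prompt and prior output are prefilled again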
self.waiting.appendleft(seq_group) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b5143d5920aa..d7b458833a0b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -312,11 +312,6 @@ def add_cli_args( default=-1, help='max number of prefill tokens allowed in chunked prefill' ', -1 means no limit') - parser.add_argument( - '--max-num-prompt-seqs', - type=int, - default=1024, - help='max number of prompt sequences allowed in prefill') return parser @classmethod diff --git a/vllm/sequence.py b/vllm/sequence.py index a0216379b903..6662f735cef8 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -114,6 +114,13 @@ def append_token_id(self, token_id: int, logprob: float) -> None: self.output_token_ids.append(token_id) self.cumulative_logprob += logprob + def reset_prefill_range(self) -> None: + """Reset the prefill range. It is supposed to be called when a + sequence needs to be started from the beginning. + """ + self._prefill_start = 0 + self._prefill_end = 0 + def get_len(self) -> int: return len(self.output_token_ids) + len(self.prompt_token_ids) @@ -147,13 +154,11 @@ def get_prefill_range(self) -> Tuple[int, int]: """Returns the prefill range.""" return self._prefill_start, self._prefill_end - def get_num_unprefilled(self) -> int: - """Return the number of prefil tokens that are not completed. - - Note that we use prompt_len + output_len instead of - prompt_len here. This is because during recompute - we need to prefill for both prompt and output. - """ + def get_num_uncomputed_tokens(self) -> int: + """Return the number of prefil tokens that are not computed.""" + # we use `get_len()` which includes prompt_len + output_len instead + # of prompt_len here. This is because during recompute we need to + # prefill for both prompt and output. return self.get_len() - self._prefill_end def get_last_token_id(self) -> int: @@ -231,6 +236,10 @@ def hash_of_block(self, logical_idx: int) -> int: def num_hashed_tokens_of_block(self, logical_idx: int): return logical_idx * self.block_size + self.block_size + def on_recompute(self): + """Reset the sequence states for recomputation.""" + self.data.reset_prefill_range() + def _append_logical_block(self) -> None: block = LogicalTokenBlock( block_number=len(self.logical_token_blocks), @@ -445,9 +454,12 @@ def advance_prefill_range(self, size: int) -> int: for seq in self.seqs_dict.values() ][0] - def get_num_unprefilled(self) -> int: - # All sequences in the group should have the same prompt. - return list(self.seqs_dict.values())[0].data.get_num_unprefilled() + def get_num_uncomputed_tokens(self) -> int: + # All sequences in the group should have the same prompt, so the + # number of unfinished prefill tokens are the same across all + # sequences. + return list( + self.seqs_dict.values())[0].data.get_num_uncomputed_tokens() def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: return len(self.get_seqs(status)) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c26822586ae4..0a07b3b22efb 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -165,6 +165,10 @@ def _prepare_prompt( prefill_start, prefill_end = seq_data.get_prefill_range() prompt_tokens = seq_data.get_token_ids()[prefill_start:prefill_end] prompt_len = len(prompt_tokens) + # Right now, the prefill_end is always same as the length of + # prompt. However, once chunked prefill is introduced, this + # assumption can be changed. 
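When prefix caching applies in the branch above, computed_len is derived from the number of already-computed blocks and only the remaining prompt tokens are fed to the model. A minimal sketch with assumed values:

block_size = 16
computed_block_nums = [0, 1]          # two blocks already cached (assumed)
prompt_tokens = list(range(40))       # a 40-token prompt, for illustration

computed_len = len(computed_block_nums) * block_size
remaining = prompt_tokens[computed_len:]
assert computed_len == 32
assert len(remaining) == 8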
+ assert prefill_end == seq_data.get_prompt_len() prompt_lens.append(prompt_len) computed_len = 0 @@ -175,19 +179,21 @@ def _prepare_prompt( computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) - context_len = computed_len else: prefix_block_tables.append([]) - context_len = 0 + computed_len = prefill_start + # Right now, prefill start is always 0. However, this + # assumption can be changed once chunked prefill is introduced. + assert computed_len == 0 + # actual prompt lens - context_lens.append(context_len) + context_lens.append(computed_len) subquery_lens.append(prompt_len - computed_len) input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend( - list(range(computed_len, computed_len + prefill_end))) + input_positions.extend(list(range(computed_len, prefill_end))) lora_id = seq_group_metadata.lora_int_id @@ -220,12 +226,6 @@ def _prepare_prompt( "sliding window attention") start_idx = max(0, prompt_len - self.sliding_window) - # If chunked prefill is enabled, computed_len is always 0. - # TODO(sang) This is hack. We should clean it up when - # supporting prefix cache + chunked prefill. - if computed_len == 0: - computed_len = prefill_start - for i in range(computed_len, prefill_end): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) From 5002e61908e0c5c2a758490ff2840d3bf770041b Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 25 Mar 2024 06:48:47 -0700 Subject: [PATCH 73/88] update --- vllm/attention/backends/flash_attn.py | 2 + vllm/attention/backends/xformers.py | 2 + vllm/engine/arg_utils.py | 11 +-- vllm/model_executor/input_metadata.py | 100 -------------------------- 4 files changed, 10 insertions(+), 105 deletions(-) delete mode 100644 vllm/model_executor/input_metadata.py diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ac33a917bb0a..b49501c50e7e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -71,6 +71,8 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): num_prompt_tokens: int # The number of generation tokens. Doesn't include padding. num_generation_tokens: int + # The number of chunked prefill sequences in the batch. + num_chunked_prefill: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b7eff2b598e1..8ca6b2b492c6 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -79,6 +79,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): num_prompt_tokens: int # The number of generation tokens. Doesn't include padding. num_generation_tokens: int + # The number of chunked prefill sequences in the batch. + num_chunked_prefill: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. 
# |---------- N-1 iteration --------| diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a3dc6be19267..522ce56a0a7b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -354,11 +354,12 @@ def create_engine_configs( self.tokenizer_pool_type, self.tokenizer_pool_extra_config, ), self.ray_workers_use_nsight) - scheduler_config = SchedulerConfig(self.max_num_batched_tokens, - self.max_num_seqs, - model_config.max_model_len, - delay_factor=self.scheduler_delay_factor, - max_chunked_prefill_len=self.max_chunked_prefill_len) + scheduler_config = SchedulerConfig( + self.max_num_batched_tokens, + self.max_num_seqs, + model_config.max_model_len, + delay_factor=self.scheduler_delay_factor, + max_chunked_prefill_len=self.max_chunked_prefill_len) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py deleted file mode 100644 index 2b087293cd6e..000000000000 --- a/vllm/model_executor/input_metadata.py +++ /dev/null @@ -1,100 +0,0 @@ -from dataclasses import dataclass, fields -from typing import Optional, List, Any, Dict - -import torch -from xformers.ops.fmha.attn_bias import AttentionBias - - -@dataclass -class InputMetadata: - """Metadata for input sequences. Used in PagedAttention. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor - # The number of chunked prefill sequences in the batch. - num_chunked_prefill: int - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int - """ - Definition of context_len, subquery_len, and seqlen. - |---------- N-1 iteration --------| - |---------------- N iteration ---------------------| - |- tokenA -|......................|-- newTokens ---| - |---------- context_len ----------| - |-------------------- seqlen ----------------------| - |- subquery_len -| - - WARNING: context_len has different definition depending on if it is - prefill vs decoding. When it is prefill, it doesn't include new - tokens. When it is for decoding, it includes a new token. - """ - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum context length in the batch. - max_context_len: Optional[int] - # FIXME: It is for flash attn. - # Maximum sequence length in the batch. - max_seq_len: Optional[int] - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] - # FIXME: It is for flash attn. 
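A worked example of the context_len / subquery_len / seqlen relationship and of the cumulative start locations described in these metadata comments; the numbers are illustrative only:

# Prefill: iteration N-1 already wrote 4 prompt tokens to the KV cache and
# iteration N feeds 3 new ones. Per the definitions above, context_len
# excludes the new tokens at prefill time (at decode time it would include
# the single new token).
context_len, subquery_len = 4, 3
seqlen = context_len + subquery_len
assert seqlen == 7

# subquery_start_loc is a prefix sum over the per-sequence subquery lengths.
subquery_lens = [4, 6]
subquery_start_loc = [0]
for n in subquery_lens:
    subquery_start_loc.append(subquery_start_loc[-1] + n)
assert subquery_start_loc == [0, 4, 10]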
- # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - # (batch_size,). The length of context (tokens stored in KV cache) per - # sequence. WARNING: When it is a prefill request, it doesn't include new - # tokens. When it is for decoding, it includes a new token. - context_lens: Optional[torch.Tensor] - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - use_cuda_graph: bool - kv_cache_dtype: str - - def __post_init__(self): - # Set during the execution of the first attention op. - # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. - # will not appear in the __repr__ and __init__ - self.attn_bias: Optional[List[AttentionBias]] = None - - # Cuda graph is only used for decoding now. - if self.use_cuda_graph: - assert self.num_prompt_tokens == 0 - - def asdict_zerocopy(self) -> Dict[str, Any]: - """Similar to dataclasses.asdict, but avoids deepcopying.""" - # Note that if we add dataclasses as fields, they will need - # similar handling. - return { - field.name: getattr(self, field.name) - for field in fields(self) - } From 80f51ea97a65969b6d018ef323fbfce620677c44 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 25 Mar 2024 07:33:41 -0700 Subject: [PATCH 74/88] test fix --- tests/spec_decode/test_utils.py | 2 ++ tests/test_logits_processor.py | 1 + tests/worker/test_model_runner.py | 1 + vllm/spec_decode/batch_expansion.py | 1 + 4 files changed, 5 insertions(+) diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index 19833ddb0615..b94a07571583 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -15,6 +15,7 @@ def test_get_all_seq_ids(): SequenceGroupMetadata( request_id=str(seq_id), is_prompt=True, + is_chunked_prefill=False, seq_data={ seq_id: MagicMock(), }, @@ -37,6 +38,7 @@ def fake_sequence_group_metadata(): SequenceGroupMetadata( request_id=str(i), is_prompt=True, + is_chunked_prefill=False, seq_data={ i: MagicMock(), }, diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index fe321520114f..198d84c599b3 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -70,6 +70,7 @@ def pick_ith(token_ids, logits): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, + is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=SamplingParams(temperature=0, logits_processors=[pick_ith]), diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index d28f4c1b14fd..e75975dc3892 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -136,6 +136,7 @@ def test_prepare_decode_cuda_graph(batch_size): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, + is_chunked_prefill=False, seq_data={0: SequenceData(seq_data)}, sampling_params=SamplingParams(temperature=0), block_tables={0: [1]}, diff --git 
a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 0f698fa34601..ffdb9f5d1682 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -266,6 +266,7 @@ def _create_single_target_seq_group_metadata( return SequenceGroupMetadata( request_id=seq_group_metadata.request_id, is_prompt=seq_group_metadata.is_prompt, + is_chunked_prefill=False, seq_data={ target_seq_id: SequenceData( From fa7ba35bd72b6ef76aa79f70e5b8b88c46582aeb Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 25 Mar 2024 15:46:05 -0700 Subject: [PATCH 75/88] lint --- tests/test_sequence.py | 3 ++- vllm/sequence.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index bfc7e1deac27..2921baa7a54f 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,5 +1,6 @@ import pytest -from vllm.sequence import (SequenceData, SequenceGroupOutput, SamplerOutput, + +from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) diff --git a/vllm/sequence.py b/vllm/sequence.py index 91442e115aac..aa5de8474996 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -2,7 +2,8 @@ import copy import enum from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + from vllm.block import LogicalTokenBlock from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams From 51cf7f2c88f9114b01ca5f0e0eb1078be03d0903 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 25 Mar 2024 16:10:51 -0700 Subject: [PATCH 76/88] fix broken tests. --- tests/spec_decode/utils.py | 20 +++++++++++--------- tests/worker/test_model_runner.py | 5 +++-- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 43d9f01dcad7..6cee91db050f 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -161,22 +161,24 @@ def create_seq_group_metadata_from_prompts( for i, final_len in enumerate(final_seq_lens) } + seq_data_lst = [] + for prompt_token_ids, cont_token_ids in zip(prompts, continuations): + seq_data = SequenceData(prompt_token_ids=prompt_token_ids[:], + output_token_ids=cont_token_ids[:]) + seq_data.advance_prefill_range(len(prompt_token_ids)) + seq_data_lst.append(seq_data) + return [ SequenceGroupMetadata( request_id=str(i), is_prompt=len(cont_token_ids) == 0, is_chunked_prefill=False, - seq_data={ - i: - SequenceData( - prompt_token_ids=prompt_token_ids[:], - output_token_ids=cont_token_ids[:], - ), - }, + seq_data={i: seq_data}, sampling_params=SamplingParams(temperature=0.0, ), block_tables={i: block_allocations[i][:]}, - ) for i, (prompt_token_ids, - cont_token_ids) in enumerate(zip(prompts, continuations)) + ) for i, ( + prompt_token_ids, cont_token_ids, + seq_data) in enumerate(zip(prompts, continuations, seq_data_lst)) ] diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index e7ae90a9c731..30f4715112fd 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -18,16 +18,17 @@ def test_prepare_prompt(batch_size): # make sure all tokens fit into one block prompt_len = i % (model_runner.block_size - 1) + 1 prompt_lens.append(prompt_len) - seq_data = list(range(prompt_len)) + seq_data = SequenceData(list(range(prompt_len))) seq_group_metadata_list.append( SequenceGroupMetadata( request_id=f"test_{i}", 
is_prompt=True, is_chunked_prefill=False, - seq_data={0: SequenceData(seq_data)}, + seq_data={0: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, )) + seq_data.advance_prefill_range(prompt_len) expected_selected_token_indices = [] selected_token_start_idx = 0 From cdee1c6ce5c2969a8426e435fd86f71463966b97 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 25 Mar 2024 17:05:03 -0700 Subject: [PATCH 77/88] . --- tests/samplers/test_sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 422120d63d3f..0b1baaf25771 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -269,6 +269,7 @@ def generate_test_case(): SequenceGroupMetadata( request_id="test_1", is_prompt=True, + is_chunked_prefill=False, seq_data={ next(seq_id_counter): create_sequence_data(), }, From 16e3a7dd0d2eebf74a46a88e5d2195deebc42c33 Mon Sep 17 00:00:00 2001 From: sang Date: Mon, 25 Mar 2024 18:40:45 -0700 Subject: [PATCH 78/88] done --- tests/samplers/test_sampler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 0b1baaf25771..d139d2128e8a 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -251,6 +251,7 @@ def generate_test_case(): SequenceGroupMetadata( request_id=f"test_{batch_size}", is_prompt=is_prompt, + is_chunked_prefill=False, seq_data=seq_data, sampling_params=sampling_params, block_tables={}, @@ -285,6 +286,7 @@ def generate_test_case(): SequenceGroupMetadata( request_id="test_1", is_prompt=True, + is_chunked_prefill=False, seq_data={ next(seq_id_counter): create_sequence_data(), }, @@ -300,6 +302,7 @@ def generate_test_case(): SequenceGroupMetadata( request_id="test_1", is_prompt=False, + is_chunked_prefill=False, seq_data={ next(seq_id_counter): create_sequence_data(num_generated=1), @@ -317,6 +320,7 @@ def generate_test_case(): SequenceGroupMetadata( request_id="test_1", is_prompt=False, + is_chunked_prefill=False, seq_data={ next(seq_id_counter): create_sequence_data(num_generated=1), @@ -330,6 +334,7 @@ def generate_test_case(): SequenceGroupMetadata( request_id="test_2", is_prompt=True, + is_chunked_prefill=False, seq_data={ next(seq_id_counter): create_sequence_data(), }, From e0d301c4956ff246c6b223f0e0aa99f9e6934888 Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 26 Mar 2024 18:05:27 -0700 Subject: [PATCH 79/88] remove num chunked prefill from seq group metadata --- benchmarks/benchmark_latency.py | 7 ++++++- tests/samplers/test_sampler.py | 9 --------- tests/spec_decode/test_utils.py | 2 -- tests/spec_decode/utils.py | 1 - tests/test_logits_processor.py | 1 - tests/worker/test_model_runner.py | 2 -- vllm/attention/backends/flash_attn.py | 2 -- vllm/attention/backends/xformers.py | 2 -- vllm/core/scheduler.py | 1 - vllm/sequence.py | 6 ------ vllm/spec_decode/batch_expansion.py | 1 - vllm/worker/model_runner.py | 18 +++++------------- 12 files changed, 11 insertions(+), 41 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 3e668ab08877..bd0c32ddc8c8 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -152,7 +152,12 @@ def run_to_completion(profile_dir: Optional[str] = None): type=int, default=16, help='block size of key/value cache') - parser.add_argument('--max-chunked-prefill-len', type=int, default=-1) + parser.add_argument( + '--max-chunked-prefill-len', + type=int, + default=-1, + 
help='max number of prefill tokens allowed in chunked prefill' + ', -1 means no limit') parser.add_argument( "--ray-workers-use-nsight", action='store_true', diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index d139d2128e8a..1626b7228207 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -56,7 +56,6 @@ def _do_sample( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, @@ -251,7 +250,6 @@ def generate_test_case(): SequenceGroupMetadata( request_id=f"test_{batch_size}", is_prompt=is_prompt, - is_chunked_prefill=False, seq_data=seq_data, sampling_params=sampling_params, block_tables={}, @@ -270,7 +268,6 @@ def generate_test_case(): SequenceGroupMetadata( request_id="test_1", is_prompt=True, - is_chunked_prefill=False, seq_data={ next(seq_id_counter): create_sequence_data(), }, @@ -286,7 +283,6 @@ def generate_test_case(): SequenceGroupMetadata( request_id="test_1", is_prompt=True, - is_chunked_prefill=False, seq_data={ next(seq_id_counter): create_sequence_data(), }, @@ -302,7 +298,6 @@ def generate_test_case(): SequenceGroupMetadata( request_id="test_1", is_prompt=False, - is_chunked_prefill=False, seq_data={ next(seq_id_counter): create_sequence_data(num_generated=1), @@ -320,7 +315,6 @@ def generate_test_case(): SequenceGroupMetadata( request_id="test_1", is_prompt=False, - is_chunked_prefill=False, seq_data={ next(seq_id_counter): create_sequence_data(num_generated=1), @@ -334,7 +328,6 @@ def generate_test_case(): SequenceGroupMetadata( request_id="test_2", is_prompt=True, - is_chunked_prefill=False, seq_data={ next(seq_id_counter): create_sequence_data(), }, @@ -453,7 +446,6 @@ def test_sampler_mixed(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, @@ -543,7 +535,6 @@ def test_sampler_top_k_top_p(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=SamplingParams( temperature=1, diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index c9cad5eac50c..6b6f35a1a1d0 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -15,7 +15,6 @@ def test_get_all_seq_ids(): SequenceGroupMetadata( request_id=str(seq_id), is_prompt=True, - is_chunked_prefill=False, seq_data={ seq_id: MagicMock(), }, @@ -38,7 +37,6 @@ def fake_sequence_group_metadata(): SequenceGroupMetadata( request_id=str(i), is_prompt=True, - is_chunked_prefill=False, seq_data={ i: MagicMock(), }, diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 6cee91db050f..bf8adcd19053 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -172,7 +172,6 @@ def create_seq_group_metadata_from_prompts( SequenceGroupMetadata( request_id=str(i), is_prompt=len(cont_token_ids) == 0, - is_chunked_prefill=False, seq_data={i: seq_data}, sampling_params=SamplingParams(temperature=0.0, ), block_tables={i: block_allocations[i][:]}, diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 198d84c599b3..fe321520114f 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -70,7 +70,6 @@ def pick_ith(token_ids, logits): SequenceGroupMetadata( 
request_id=f"test_{i}", is_prompt=True, - is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=SamplingParams(temperature=0, logits_processors=[pick_ith]), diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 30f4715112fd..0dd74e145709 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -23,7 +23,6 @@ def test_prepare_prompt(batch_size): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - is_chunked_prefill=False, seq_data={0: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, @@ -137,7 +136,6 @@ def test_prepare_decode_cuda_graph(batch_size): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, - is_chunked_prefill=False, seq_data={0: SequenceData(seq_data)}, sampling_params=SamplingParams(temperature=0), block_tables={0: [1]}, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 2a86751baf3c..e50d52377b8e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -72,8 +72,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): num_prompt_tokens: int # The number of generation tokens. Doesn't include padding. num_generation_tokens: int - # The number of chunked prefill sequences in the batch. - num_chunked_prefill: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 2a7311d75010..fcd903ddf5f5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -80,8 +80,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): num_prompt_tokens: int # The number of generation tokens. Doesn't include padding. num_generation_tokens: int - # The number of chunked prefill sequences in the batch. - num_chunked_prefill: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ca8ef039ceec..c173f349d75a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -382,7 +382,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=scheduler_outputs.prompt_run, - is_chunked_prefill=False, seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, diff --git a/vllm/sequence.py b/vllm/sequence.py index aa5de8474996..4ca4c87fbf9b 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -530,10 +530,6 @@ class SequenceGroupMetadata: Args: request_id: The ID of the request. is_prompt: Whether the request is at prompt stage. - is_chunked_prefill: Whether the request is at chunked prefill stage. - If a prefill request is chunked, the first ~ n-1th chunks are - chunked prefill requests. - Note that chunked_prefill is also a prompt stage. seq_data: The sequence data. (Seq id -> sequence data) sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. 
(Seq id -> list of physical block @@ -547,7 +543,6 @@ def __init__( self, request_id: str, is_prompt: bool, - is_chunked_prefill: bool, seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], @@ -558,7 +553,6 @@ def __init__( ) -> None: self.request_id = request_id self.is_prompt = is_prompt - self.is_chunked_prefill = is_chunked_prefill self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index ceee0d64b9ee..e0b75837e8a3 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -265,7 +265,6 @@ def _create_single_target_seq_group_metadata( return SequenceGroupMetadata( request_id=seq_group_metadata.request_id, is_prompt=seq_group_metadata.is_prompt, - is_chunked_prefill=False, seq_data={ target_seq_id: SequenceData( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a241afa09c6f..f88393db0041 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -149,7 +149,6 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] - num_chunked_prefill = 0 multi_modal_input_list: List[torch.Tensor] = [] for seq_group_metadata in seq_group_metadata_list: @@ -159,14 +158,11 @@ def _prepare_prompt( seq_id = seq_ids[0] computed_block_nums = seq_group_metadata.computed_block_nums - if seq_group_metadata.is_chunked_prefill: - num_chunked_prefill += 1 - # TODO(sang): Both are the same thing and should be handled - # in the same way. - if computed_block_nums is not None: - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") + if (self.scheduler_config.chunked_prefill_enabled + and computed_block_nums is not None): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") seq_data = seq_group_metadata.seq_data[seq_id] prefill_start, prefill_end = seq_data.get_prefill_range() @@ -317,7 +313,6 @@ def _prepare_prompt( slot_mapping=slot_mapping, prompt_lens=prompt_lens, prompt_lens_tensor=prompt_lens_tensor, - num_chunked_prefill=num_chunked_prefill, num_prompt_tokens=num_prompt_tokens, num_generation_tokens=0, max_subquery_len=max_subquery_len, @@ -447,7 +442,6 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, - num_chunked_prefill=0, prompt_lens=None, prompt_lens_tensor=None, num_prompt_tokens=0, @@ -728,7 +722,6 @@ def profile_run(self) -> None: seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - is_chunked_prefill=False, seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, @@ -830,7 +823,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping[:batch_size], - num_chunked_prefill=0, prompt_lens=None, prompt_lens_tensor=None, num_prompt_tokens=0, From 5e0f87e7ddc758324ae82aafa63d6a90f517cd88 Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 26 Mar 2024 19:37:18 -0700 Subject: [PATCH 80/88] change apis --- tests/test_sequence.py | 2 +- vllm/core/scheduler.py | 27 +++++++++-- vllm/engine/llm_engine.py | 15 ++++-- vllm/sequence.py | 92 ++++++++++++++++++++----------------- vllm/worker/model_runner.py | 7 ++- 5 files changed, 90 insertions(+), 53 deletions(-) diff --git a/tests/test_sequence.py 
b/tests/test_sequence.py index 2921baa7a54f..7a5e92ea7567 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -55,7 +55,7 @@ def test_sequence_data_prefill(): seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4]) assert seq_data.get_prefill_range() == (0, 0) assert seq_data.get_num_uncomputed_tokens() == 4 - + # SANG-TODO Fix. # advance by 2 assert seq_data.advance_prefill_range(2) == 2 assert seq_data.get_num_uncomputed_tokens() == 2 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c173f349d75a..382b2c58d75c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -27,11 +27,18 @@ class PreemptionMode(enum.Enum): RECOMPUTE = enum.auto() +class ScheduledSequenceGroup: + + def __init__(self, seq_group: SequenceGroup, chunk_size: int): + self.seq_group = seq_group + self.chunk_size = chunk_size + + class SchedulerOutputs: def __init__( self, - scheduled_seq_groups: Iterable[SequenceGroup], + scheduled_seq_groups: Iterable[ScheduledSequenceGroup], prompt_run: bool, num_batched_tokens: int, blocks_to_swap_in: Dict[int, int], @@ -246,10 +253,11 @@ def _schedule(self) -> SchedulerOutputs: curr_loras.add(lora_int_id) self.waiting.popleft() self._allocate(seq_group) - seq_group.advance_prefill_range(num_prompt_tokens) + # seq_group.advance_prefill_range(num_prompt_tokens) self.running.append(seq_group) num_curr_seqs += num_new_seqs - scheduled.append(seq_group) + scheduled.append( + ScheduledSequenceGroup(seq_group, num_prompt_tokens)) self.waiting.extendleft(leftover_waiting_sequences) @@ -348,7 +356,10 @@ def _schedule(self) -> SchedulerOutputs: for seq_group in self.running) scheduler_outputs = SchedulerOutputs( - scheduled_seq_groups=self.running, + scheduled_seq_groups=[ + ScheduledSequenceGroup(running_group, 1) + for running_group in self.running + ], prompt_run=False, num_batched_tokens=num_batched_tokens, blocks_to_swap_in=blocks_to_swap_in, @@ -367,17 +378,22 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] - for seq_group in scheduler_outputs.scheduled_seq_groups: + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + seq_group = scheduled_seq_group.seq_group + chunk_size = scheduled_seq_group.chunk_size + seq_group.maybe_set_first_scheduled_time(now) seq_data: Dict[int, SequenceData] = {} block_tables: Dict[int, List[int]] = {} + token_chunk_sizes: Dict[int, int] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id seq_data[seq_id] = seq.data block_tables[seq_id] = self.block_manager.get_block_table(seq) self.block_manager.access_all_blocks_in_seq(seq, now) + token_chunk_sizes[seq_id] = chunk_size seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, @@ -385,6 +401,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + token_chunk_sizes=token_chunk_sizes, lora_request=seq_group.lora_request, computed_block_nums=self.block_manager. 
get_common_computed_block_ids(seq_group), diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 144829739f68..0146fcc942aa 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -556,10 +556,14 @@ def _process_model_outputs( # If prefix caching is enabled, mark all blocks in the sequence groups # as completed so that future requests don't attempt to recompute them if self.cache_config.enable_prefix_caching: - for seq_group in scheduled_seq_groups: - self.scheduler.mark_blocks_as_computed(seq_group) - - for seq_group, outputs in zip(scheduled_seq_groups, output): + for scheduled_seq_group in scheduled_seq_groups: + self.scheduler.mark_blocks_as_computed( + scheduled_seq_group.seq_group) + + for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output): + seq_group = scheduled_seq_group.seq_group + seq_group.record_num_computed_tokens( + scheduled_seq_group.chunk_size) self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. @@ -567,7 +571,8 @@ def _process_model_outputs( # Create the outputs. request_outputs: List[RequestOutput] = [] - for seq_group in scheduled_seq_groups: + for scheduled_seq_group in scheduled_seq_groups: + seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) diff --git a/vllm/sequence.py b/vllm/sequence.py index 4ca4c87fbf9b..ae925188b620 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -2,7 +2,7 @@ import copy import enum from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union from vllm.block import LogicalTokenBlock from vllm.lora.request import LoRARequest @@ -115,20 +115,12 @@ def __init__( self.prompt_token_ids = prompt_token_ids self.output_token_ids = output_token_ids self.cumulative_logprob = 0.0 - self._prefill_start: int = 0 - self._prefill_end: int = 0 + self._num_computed_tokens = 0 def append_token_id(self, token_id: int, logprob: float) -> None: self.output_token_ids.append(token_id) self.cumulative_logprob += logprob - def reset_prefill_range(self) -> None: - """Reset the prefill range. It is supposed to be called when a - sequence needs to be started from the beginning. - """ - self._prefill_start = 0 - self._prefill_end = 0 - def get_len(self) -> int: return len(self.output_token_ids) + len(self.prompt_token_ids) @@ -141,26 +133,37 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.prompt_token_ids + self.output_token_ids - def advance_prefill_range(self, size: int) -> int: - """Advance the prefill range by the specified amount + def get_num_computed_tokens(self) -> int: + """Return the number of prefill tokens that are already computed.""" + return self._num_computed_tokens - Args: - size: The amount to advance the prefill range. - Returns: - The actual number of advanced tokens. + def record_num_computed_tokens(self, num_computed_tokens) -> int: + """Record how many tokens have computed.""" + self._num_computed_tokens = num_computed_tokens + + def reset_num_computed_tokens(self) -> None: + """Reset the number of computed tokens from this sequence. It is + supposed to be called when a sequence needs to be started from + the beginning again (e.g., sequence is preempted). """ - self._prefill_start = self._prefill_end - # The increased range could be larger than the seq length. 
- # Clamp it to the seq length. - # Note that we use prompt_len + output_len instead of - # prompt_len here. This is because during recompute - # we need to prefill for both prompt and output. - self._prefill_end = min(self._prefill_end + size, self.get_len()) - return self._prefill_end - self._prefill_start - - def get_prefill_range(self) -> Tuple[int, int]: - """Returns the prefill range.""" - return self._prefill_start, self._prefill_end + self._num_computed_tokens = 0 + + # def advance_prefill_range(self, size: int) -> int: + # """Advance the prefill range by the specified amount + + # Args: + # size: The amount to advance the prefill range. + # Returns: + # The actual number of advanced tokens. + # """ + # self._prefill_start = self._prefill_end + # # The increased range could be larger than the seq length. + # # Clamp it to the seq length. + # # Note that we use prompt_len + output_len instead of + # # prompt_len here. This is because during recompute + # # we need to prefill for both prompt and output. + # self._prefill_end = min(self._prefill_end + size, self.get_len()) + # return self._prefill_end - self._prefill_start def get_num_uncomputed_tokens(self) -> int: """Return the number of prefil tokens that are not computed.""" @@ -246,7 +249,7 @@ def num_hashed_tokens_of_block(self, logical_idx: int): def on_recompute(self): """Reset the sequence states for recomputation.""" - self.data.reset_prefill_range() + self.data.reset_num_computed_tokens() def _append_logical_block(self) -> None: block = LogicalTokenBlock( @@ -470,19 +473,23 @@ def get_unfinished_seqs(self) -> List[Sequence]: def get_finished_seqs(self) -> List[Sequence]: return [seq for seq in self.seqs_dict.values() if seq.is_finished()] - def advance_prefill_range(self, size: int) -> int: - """Advance the prefill range by the specified amount. + # def advance_prefill_range(self, size: int) -> int: + # """Advance the prefill range by the specified amount. - Args: - size: The amount to advance the prefill range. - Returns: - The actual number of advanced tokens. - """ - # All sequences in the group should have the same prompt. - return [ - seq.data.advance_prefill_range(size) - for seq in self.seqs_dict.values() - ][0] + # Args: + # size: The amount to advance the prefill range. + # Returns: + # The actual number of advanced tokens. + # """ + # # All sequences in the group should have the same prompt. + # return [ + # seq.data.advance_prefill_range(size) + # for seq in self.seqs_dict.values() + # ][0] + + def record_num_computed_tokens(self, num_computed_tokens): + for seq in self.seqs_dict.values(): + seq.data.record_num_computed_tokens(num_computed_tokens) def get_num_uncomputed_tokens(self) -> int: # All sequences in the group should have the same prompt, so the @@ -537,6 +544,7 @@ class SequenceGroupMetadata: state: Internal state tied to this sequence group. lora_request: LoRA request. multi_modal_data: Multi modal data. + token_chunk_sizes: seq_id -> token chunk size to run a model. 
""" def __init__( @@ -546,6 +554,7 @@ def __init__( seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], + token_chunk_sizes: Dict[int, int], lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, @@ -560,6 +569,7 @@ def __init__( self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state + self.token_chunk_sizes = token_chunk_sizes @property def lora_int_id(self) -> int: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f88393db0041..ec93db64785a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -156,6 +156,7 @@ def _prepare_prompt( seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 seq_id = seq_ids[0] + token_chunk_sizes = seq_group_metadata.token_chunk_sizes computed_block_nums = seq_group_metadata.computed_block_nums if (self.scheduler_config.chunked_prefill_enabled @@ -164,8 +165,11 @@ def _prepare_prompt( "chunked prefill cannot be used with prefix caching " "now.") + chunk_size = token_chunk_sizes[seq_id] seq_data = seq_group_metadata.seq_data[seq_id] - prefill_start, prefill_end = seq_data.get_prefill_range() + prefill_start = seq_data.get_num_computed_tokens() + prefill_end = min(seq_data.get_prompt_len(), + prefill_start + chunk_size) prompt_tokens = seq_data.get_token_ids()[prefill_start:prefill_end] prompt_len = len(prompt_tokens) # Right now, the prefill_end is always same as the length of @@ -725,6 +729,7 @@ def profile_run(self) -> None: seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, + token_chunk_sizes={group_id: seq_data.get_len()}, lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, multi_modal_data=fake_multi_modal_input, From 6e7264841b7e4f81ccf479c6af05b06a5f357873 Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 26 Mar 2024 20:16:04 -0700 Subject: [PATCH 81/88] cleaned --- tests/spec_decode/utils.py | 1 - tests/test_sequence.py | 23 +++++++---------- tests/worker/test_model_runner.py | 1 - vllm/core/scheduler.py | 16 +++++++++++- vllm/engine/llm_engine.py | 3 +-- vllm/sequence.py | 43 ++++--------------------------- vllm/worker/model_runner.py | 1 - 7 files changed, 31 insertions(+), 57 deletions(-) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index bf8adcd19053..f533fce14afe 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -165,7 +165,6 @@ def create_seq_group_metadata_from_prompts( for prompt_token_ids, cont_token_ids in zip(prompts, continuations): seq_data = SequenceData(prompt_token_ids=prompt_token_ids[:], output_token_ids=cont_token_ids[:]) - seq_data.advance_prefill_range(len(prompt_token_ids)) seq_data_lst.append(seq_data) return [ diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 7a5e92ea7567..544ea3ad4307 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -53,23 +53,20 @@ def test_sampler_output_eq(sample_outputs): def test_sequence_data_prefill(): seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4]) - assert seq_data.get_prefill_range() == (0, 0) assert seq_data.get_num_uncomputed_tokens() == 4 - # SANG-TODO Fix. 
+ assert seq_data.get_num_computed_tokens() == 0 # advance by 2 - assert seq_data.advance_prefill_range(2) == 2 + seq_data.add_num_computed_tokens(2) assert seq_data.get_num_uncomputed_tokens() == 2 - assert seq_data.get_prefill_range() == (0, 2) + assert seq_data.get_num_computed_tokens() == 2 - # advance range by 3 even though there are only 2 unprefilled tokens - assert seq_data.advance_prefill_range(3) == 2 - assert seq_data.get_num_uncomputed_tokens() == 0 - assert seq_data.get_prefill_range() == (2, 4) - - # following advances should not change anything - assert seq_data.advance_prefill_range(2) == 0 - assert seq_data.get_num_uncomputed_tokens() == 0 - assert seq_data.get_prefill_range() == (4, 4) + # advance by 1 + seq_data.add_num_computed_tokens(1) + assert seq_data.get_num_uncomputed_tokens() == 1 + assert seq_data.get_num_computed_tokens() == 3 # append tokens and reset, simulating recompute seq_data.append_token_id(1, logprob=0.0) + seq_data.reset_num_computed_tokens() + assert seq_data.get_num_uncomputed_tokens() == 5 + assert seq_data.get_num_computed_tokens() == 0 diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 0dd74e145709..7fb09f2dab1f 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -27,7 +27,6 @@ def test_prepare_prompt(batch_size): sampling_params=SamplingParams(temperature=0), block_tables=block_tables, )) - seq_data.advance_prefill_range(prompt_len) expected_selected_token_indices = [] selected_token_start_idx = 0 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 382b2c58d75c..46e04d936553 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -46,6 +46,21 @@ def __init__( blocks_to_copy: Dict[int, List[int]], ignored_seq_groups: List[SequenceGroup], ) -> None: + """A list of sequence groups to be scheduled as a single batch. + + Args: + scheduled_seq_groups: A tuple of scheduled sequence group and its + chunk size. + prompt_run: True if all sequence groups are in prefill phase. + If False, all sequence groups are in decoding phase. + num_batched_tokens: Total number of batched tokens. + blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block + number. + blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block + number. + blocks_to_copy: Blocks to copy. Source to a list of dest blocks. + ignored_seq_groups: Sequence groups that are going to be ignored. + """ self.scheduled_seq_groups = scheduled_seq_groups self.prompt_run = prompt_run self.num_batched_tokens = num_batched_tokens @@ -253,7 +268,6 @@ def _schedule(self) -> SchedulerOutputs: curr_loras.add(lora_int_id) self.waiting.popleft() self._allocate(seq_group) - # seq_group.advance_prefill_range(num_prompt_tokens) self.running.append(seq_group) num_curr_seqs += num_new_seqs scheduled.append( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0146fcc942aa..145ffc3f8c82 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -562,8 +562,7 @@ def _process_model_outputs( for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output): seq_group = scheduled_seq_group.seq_group - seq_group.record_num_computed_tokens( - scheduled_seq_group.chunk_size) + seq_group.add_num_computed_tokens(scheduled_seq_group.chunk_size) self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. 
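The hunk above is the heart of the accounting change: after every model step the engine advances each scheduled group's computed-token counter by the chunk size the scheduler reported for that group. A minimal, self-contained sketch of that bookkeeping, using a stand-in class rather than the real SequenceData from vllm/sequence.py:

from dataclasses import dataclass, field
from typing import List


@dataclass
class StubSequenceData:
    # Stand-in for vllm.sequence.SequenceData; only the counters sketched
    # in the diffs above are reproduced here.
    prompt_token_ids: List[int]
    output_token_ids: List[int] = field(default_factory=list)
    _num_computed_tokens: int = 0

    def get_len(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def add_num_computed_tokens(self, delta: int) -> None:
        # Called once per engine step with the chunk size that was run.
        self._num_computed_tokens += delta

    def reset_num_computed_tokens(self) -> None:
        # Preemption-by-recompute restarts the prefill from token 0.
        self._num_computed_tokens = 0

    def get_num_uncomputed_tokens(self) -> int:
        # Uses prompt + output length so a recomputed sequence also
        # prefills the tokens it had already generated.
        return self.get_len() - self._num_computed_tokens


data = StubSequenceData(prompt_token_ids=[1, 2, 3, 4])
for chunk_size in (2, 2):  # two prefill chunks of two tokens each
    data.add_num_computed_tokens(chunk_size)
assert data.get_num_uncomputed_tokens() == 0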
diff --git a/vllm/sequence.py b/vllm/sequence.py index ae925188b620..4b9a9546dbc9 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -100,8 +100,6 @@ class SequenceData: prompt_token_ids: The token IDs of the prompt. output_token_ids: The token IDs of the output. cumulative_logprob: The cumulative log probability of the output. - _prefill_start: The start index of the prefill. - _prefill_end: The end index of the prefill. """ def __init__( @@ -137,9 +135,9 @@ def get_num_computed_tokens(self) -> int: """Return the number of prefill tokens that are already computed.""" return self._num_computed_tokens - def record_num_computed_tokens(self, num_computed_tokens) -> int: + def add_num_computed_tokens(self, num_computed_tokens_delta) -> int: """Record how many tokens have computed.""" - self._num_computed_tokens = num_computed_tokens + self._num_computed_tokens += num_computed_tokens_delta def reset_num_computed_tokens(self) -> None: """Reset the number of computed tokens from this sequence. It is @@ -148,29 +146,12 @@ def reset_num_computed_tokens(self) -> None: """ self._num_computed_tokens = 0 - # def advance_prefill_range(self, size: int) -> int: - # """Advance the prefill range by the specified amount - - # Args: - # size: The amount to advance the prefill range. - # Returns: - # The actual number of advanced tokens. - # """ - # self._prefill_start = self._prefill_end - # # The increased range could be larger than the seq length. - # # Clamp it to the seq length. - # # Note that we use prompt_len + output_len instead of - # # prompt_len here. This is because during recompute - # # we need to prefill for both prompt and output. - # self._prefill_end = min(self._prefill_end + size, self.get_len()) - # return self._prefill_end - self._prefill_start - def get_num_uncomputed_tokens(self) -> int: """Return the number of prefil tokens that are not computed.""" # we use `get_len()` which includes prompt_len + output_len instead # of prompt_len here. This is because during recompute we need to # prefill for both prompt and output. - return self.get_len() - self._prefill_end + return self.get_len() - self.get_num_computed_tokens() def get_last_token_id(self) -> int: if not self.output_token_ids: @@ -473,23 +454,9 @@ def get_unfinished_seqs(self) -> List[Sequence]: def get_finished_seqs(self) -> List[Sequence]: return [seq for seq in self.seqs_dict.values() if seq.is_finished()] - # def advance_prefill_range(self, size: int) -> int: - # """Advance the prefill range by the specified amount. - - # Args: - # size: The amount to advance the prefill range. - # Returns: - # The actual number of advanced tokens. - # """ - # # All sequences in the group should have the same prompt. 
- # return [ - # seq.data.advance_prefill_range(size) - # for seq in self.seqs_dict.values() - # ][0] - - def record_num_computed_tokens(self, num_computed_tokens): + def add_num_computed_tokens(self, num_computed_tokens_delta): for seq in self.seqs_dict.values(): - seq.data.record_num_computed_tokens(num_computed_tokens) + seq.data.add_num_computed_tokens(num_computed_tokens_delta) def get_num_uncomputed_tokens(self) -> int: # All sequences in the group should have the same prompt, so the diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ec93db64785a..99b7ff61a01b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -722,7 +722,6 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) seq_data, fake_multi_modal_input = _prepare_fake_inputs( seq_len, self.vision_language_config) - seq_data.advance_prefill_range(seq_len) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, From 4f869be49e09a1bbce2609f61ba20adae661dc2c Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 26 Mar 2024 23:33:01 -0700 Subject: [PATCH 82/88] now working --- tests/basic_correctness/test_basic_correctness.py | 5 +++-- vllm/core/scheduler.py | 8 ++++---- vllm/worker/model_runner.py | 3 ++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index da0176306b4e..f9bf85053e10 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -6,14 +6,15 @@ MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", + # "meta-llama/Llama-2-7b-hf", ] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [False, True]) +# @pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("enforce_eager", [False]) def test_models( hf_runner, vllm_runner, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 46e04d936553..040c5a6f1105 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -81,13 +81,13 @@ def is_empty(self) -> bool: and not self.blocks_to_swap_out and not self.blocks_to_copy) def _sort_by_lora_ids(self) -> bool: - self.scheduled_seq_groups = sorted(self.scheduled_seq_groups, - key=lambda g: - (g.lora_int_id, g.request_id)) + self.scheduled_seq_groups = sorted( + self.scheduled_seq_groups, + key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) @property def lora_requests(self) -> Set[LoRARequest]: - return {g.lora_request for g in self.scheduled_seq_groups} + return {g.seq_group.lora_request for g in self.scheduled_seq_groups} class Scheduler: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 99b7ff61a01b..d03baf4ee280 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -839,7 +839,8 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: context_lens=context_lens[:batch_size], block_tables=block_tables[:batch_size], use_cuda_graph=True, - kv_cache_dtype=self.kv_cache_dtype) + kv_cache_dtype=self.kv_cache_dtype, + ) if self.lora_config: lora_mapping = LoRAMapping( From 4f63c5740bc9351870efc1aaa797361ca9955816 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 27 Mar 2024 03:20:50 -0700 Subject: [PATCH 83/88] update with new apis --- benchmarks/benchmark_latency.py | 12 ++--- .../test_basic_correctness.py | 5 +-- 
tests/conftest.py | 4 +- tests/core/test_scheduler.py | 22 ++++++---- tests/samplers/test_sampler.py | 29 ++++++------ tests/test_sequence.py | 4 +- tests/worker/test_model_runner.py | 35 ++++++++------- vllm/config.py | 11 ++--- vllm/core/scheduler.py | 44 ++++++++++++------- vllm/engine/arg_utils.py | 14 +++--- vllm/engine/llm_engine.py | 4 +- vllm/sequence.py | 29 ++++++++---- vllm/spec_decode/batch_expansion.py | 10 ++--- vllm/worker/model_runner.py | 19 ++++---- 14 files changed, 134 insertions(+), 108 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index bd0c32ddc8c8..caccdd4215b4 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -27,7 +27,7 @@ def main(args: argparse.Namespace): kv_cache_dtype=args.kv_cache_dtype, device=args.device, block_size=args.block_size, - max_chunked_prefill_len=args.max_chunked_prefill_len, + enable_chunked_prefill=args.enable_chunked_prefill, ray_workers_use_nsight=args.ray_workers_use_nsight, ) @@ -153,11 +153,11 @@ def run_to_completion(profile_dir: Optional[str] = None): default=16, help='block size of key/value cache') parser.add_argument( - '--max-chunked-prefill-len', - type=int, - default=-1, - help='max number of prefill tokens allowed in chunked prefill' - ', -1 means no limit') + '--enable-chunked-prefill', + type=bool, + default=False, + help='If True, the prefill requests can be chunked based on the ' + 'max_num_batched_tokens') parser.add_argument( "--ray-workers-use-nsight", action='store_true', diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index f9bf85053e10..da0176306b4e 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -6,15 +6,14 @@ MODELS = [ "facebook/opt-125m", - # "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-hf", ] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) -# @pytest.mark.parametrize("enforce_eager", [False, True]) -@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("enforce_eager", [False, True]) def test_models( hf_runner, vllm_runner, diff --git a/tests/conftest.py b/tests/conftest.py index 409b9c0c7336..cb823893c140 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -257,7 +257,7 @@ def __init__( disable_log_stats: bool = True, tensor_parallel_size: int = 1, block_size: int = 16, - max_chunked_prefill_len: int = -1, + enable_chunked_prefill: bool = False, **kwargs, ) -> None: self.model = LLM( @@ -269,7 +269,7 @@ def __init__( disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, block_size=block_size, - max_chunked_prefill_len=max_chunked_prefill_len, + enable_chunked_prefill=enable_chunked_prefill, **kwargs, ) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 58fd031c1c62..f56be2020dfe 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,6 +10,12 @@ from .utils import create_dummy_prompt +def get_sequence_groups(scheduler_output): + return [ + seq_group for seq_group, _ in scheduler_output.scheduled_seq_groups + ] + + def test_scheduler_add_seq_group(): block_size = 4 scheduler_config = SchedulerConfig(100, 64, 1) @@ -68,7 +74,7 @@ def test_scheduler_schedule_simple(): # Schedule seq groups prompts. 
num_tokens = block_size * num_seq_group seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(running) + assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_tokens assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -76,7 +82,7 @@ def test_scheduler_schedule_simple(): # Schedule seq groups generation. seq_group_meta, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set(running) + assert set(get_sequence_groups(out)) == set(running) assert out.num_batched_tokens == num_seq_group assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -100,7 +106,7 @@ def test_scheduler_schedule_preempt_abort(): # Schedule seq groups prompts. seq_group_meta, out = scheduler.schedule() - assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] + assert get_sequence_groups(out) == [seq_group_a, seq_group_b] assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -115,7 +121,7 @@ def test_scheduler_schedule_preempt_abort(): # Schedule seq groups generation and preempt seq group b. seq_group_meta, out = scheduler.schedule() - assert out.scheduled_seq_groups == [seq_group_a] + assert get_sequence_groups(out) == [seq_group_a] assert out.num_batched_tokens == 1 assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -125,7 +131,7 @@ def test_scheduler_schedule_preempt_abort(): # Abort seq group a. Re-schedule seq group b prompt with recomputation. scheduler.abort_seq_group("1") seq_group_meta, out = scheduler.schedule() - assert out.scheduled_seq_groups == [seq_group_b] + assert get_sequence_groups(out) == [seq_group_b] assert out.num_batched_tokens == 5 # 4 prompt + 1 generation. assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) @@ -155,11 +161,11 @@ def test_scheduler_max_seqs(): # Schedule seq groups prompts. _, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]]) + assert set(get_sequence_groups(out)) == set([all_seq_groups[0]]) # Schedule seq groups generation. _, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]]) + assert set(get_sequence_groups(out)) == set([all_seq_groups[0]]) # Append 2 more seq group scheduler.add_seq_group(all_seq_groups[1]) @@ -169,7 +175,7 @@ def test_scheduler_max_seqs(): # Only 1 seq group should be scheduled since max_seq_group is 2 # and one is prompting. 
_, out = scheduler.schedule() - assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]]) + assert set(get_sequence_groups(out)) == set([all_seq_groups[1]]) def test_scheduler_delay_factor(): diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 1626b7228207..a1f442654b96 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -241,7 +241,8 @@ def generate_test_case(): for _ in range(num_seqs): num_input = random.randint(1, 100) num_generated = random.randint(1, 100) if not is_prompt else 0 - seq_data[next(seq_id_counter)] = create_sequence_data( + seq_id = next(seq_id_counter) + seq_data[seq_id] = create_sequence_data( num_input=num_input, num_generated=num_generated) seq_group_penalization.append(num_generated < min_tokens) @@ -262,6 +263,8 @@ def generate_test_case(): } # define some explicit test cases for edge case behavior + seq_id = next(seq_id_counter) + seq_data = create_sequence_data() prompt_without_penalization = { "expected_penalization": [False], "seq_group_metadata_list": [ @@ -269,7 +272,7 @@ def generate_test_case(): request_id="test_1", is_prompt=True, seq_data={ - next(seq_id_counter): create_sequence_data(), + seq_id: seq_data, }, sampling_params=create_sampling_params(0), block_tables={}, @@ -277,15 +280,15 @@ def generate_test_case(): ] } + seq_id = next(seq_id_counter) + seq_data = create_sequence_data() prompt_with_penalization = { "expected_penalization": [True], "seq_group_metadata_list": [ SequenceGroupMetadata( request_id="test_1", is_prompt=True, - seq_data={ - next(seq_id_counter): create_sequence_data(), - }, + seq_data={seq_id, seq_data}, sampling_params=create_sampling_params(1), block_tables={}, ), @@ -309,18 +312,18 @@ def generate_test_case(): } stop_token_ids = [42, 99, 42, 0] # intentional duplication + decoding_seq_data = { + next(seq_id_counter): create_sequence_data(num_generated=1), + next(seq_id_counter): create_sequence_data(num_generated=100), + }, + prefill_seq_data = {next(seq_id_counter): create_sequence_data()} simple_combination = { "expected_penalization": [True, False, False], "seq_group_metadata_list": [ SequenceGroupMetadata( request_id="test_1", is_prompt=False, - seq_data={ - next(seq_id_counter): - create_sequence_data(num_generated=1), - next(seq_id_counter): - create_sequence_data(num_generated=100), - }, + seq_data=decoding_seq_data, sampling_params=create_sampling_params( 2, stop_token_ids=stop_token_ids), block_tables={}, @@ -328,9 +331,7 @@ def generate_test_case(): SequenceGroupMetadata( request_id="test_2", is_prompt=True, - seq_data={ - next(seq_id_counter): create_sequence_data(), - }, + seq_data=prefill_seq_data, sampling_params=create_sampling_params( 0, stop_token_ids=stop_token_ids), block_tables={}, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 544ea3ad4307..39674e469343 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -56,12 +56,12 @@ def test_sequence_data_prefill(): assert seq_data.get_num_uncomputed_tokens() == 4 assert seq_data.get_num_computed_tokens() == 0 # advance by 2 - seq_data.add_num_computed_tokens(2) + seq_data.update_num_computed_prefill_tokens(2) assert seq_data.get_num_uncomputed_tokens() == 2 assert seq_data.get_num_computed_tokens() == 2 # advance by 1 - seq_data.add_num_computed_tokens(1) + seq_data.update_num_computed_prefill_tokens(1) assert seq_data.get_num_uncomputed_tokens() == 1 assert seq_data.get_num_computed_tokens() == 3 diff --git a/tests/worker/test_model_runner.py 
b/tests/worker/test_model_runner.py index 7fb09f2dab1f..5b6f001f62fa 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -19,14 +19,15 @@ def test_prepare_prompt(batch_size): prompt_len = i % (model_runner.block_size - 1) + 1 prompt_lens.append(prompt_len) seq_data = SequenceData(list(range(prompt_len))) - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - )) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + ) + assert seq_group_metadata.token_chunk_size == seq_data.get_len() + seq_group_metadata_list.append(seq_group_metadata) expected_selected_token_indices = [] selected_token_start_idx = 0 @@ -131,14 +132,16 @@ def test_prepare_decode_cuda_graph(batch_size): prompt_len = i % (model_runner.block_size - 1) + 1 prompt_lens.append(prompt_len) seq_data = list(range(prompt_len)) - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={0: SequenceData(seq_data)}, - sampling_params=SamplingParams(temperature=0), - block_tables={0: [1]}, - )) + seq_data = SequenceData(seq_data) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) input_tokens, input_positions, attn_metadata, _, _, _ = ( model_runner._prepare_decode(seq_group_metadata_list)) diff --git a/vllm/config.py b/vllm/config.py index 747e3f470e86..8561b615389d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -523,10 +523,8 @@ class SchedulerConfig: and generated text). delay_factor: Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt. - max_chunked_prefill_len: The maximum length of tokens for prefill - requests. Longer requests will be chunked into multiple chunks. - -1 means no chunking (disabled). This features is only supported - for flash style attention. + enable_chunked_prefill: If True, prefill requests can be chunked based + on the remaining max_num_batched_tokens. 
""" def __init__( @@ -535,7 +533,7 @@ def __init__( max_num_seqs: int, max_model_len: int, delay_factor: float = 0.0, - max_chunked_prefill_len: int = -1, + enable_chunked_prefill: bool = False, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -546,8 +544,7 @@ def __init__( self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.delay_factor = delay_factor - self.chunked_prefill_enabled = max_chunked_prefill_len != -1 - self.max_chunked_prefill_len = max_chunked_prefill_len + self.chunked_prefill_enabled = enable_chunked_prefill self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 040c5a6f1105..5fe5ee63c0bc 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,6 +1,6 @@ import enum import time -from collections import deque +from collections import deque, namedtuple from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig @@ -27,11 +27,11 @@ class PreemptionMode(enum.Enum): RECOMPUTE = enum.auto() -class ScheduledSequenceGroup: - - def __init__(self, seq_group: SequenceGroup, chunk_size: int): - self.seq_group = seq_group - self.chunk_size = chunk_size +# seq_group: SequenceGroup to schedule. +# token_chunk_size: The number of prefill tokens to be processed in the next +# step. +ScheduledSequenceGroup = namedtuple("ScheduledSequenceGroup", + ["seq_group", "token_chunk_size"]) class SchedulerOutputs: @@ -50,7 +50,7 @@ def __init__( Args: scheduled_seq_groups: A tuple of scheduled sequence group and its - chunk size. + token chunk size. prompt_run: True if all sequence groups are in prefill phase. If False, all sequence groups are in decoding phase. num_batched_tokens: Total number of batched tokens. @@ -61,15 +61,24 @@ def __init__( blocks_to_copy: Blocks to copy. Source to a list of dest blocks. ignored_seq_groups: Sequence groups that are going to be ignored. """ + # A tuple of scheduled sequence group and its chunk size. self.scheduled_seq_groups = scheduled_seq_groups + # True if all sequence groups are in prefill phase. If False, all + # sequence groups are in decoding phase. self.prompt_run = prompt_run + # Total number of batched tokens. self.num_batched_tokens = num_batched_tokens + # Blocks to swap in. Dict of CPU -> GPU block number. self.blocks_to_swap_in = blocks_to_swap_in + # Blocks to swap out. Dict of GPU -> CPU block number. self.blocks_to_swap_out = blocks_to_swap_out + # Blocks to copy. Source to a list of dest blocks. self.blocks_to_copy = blocks_to_copy + # Sequence groups that are going to be ignored. + self.ignored_seq_groups = ignored_seq_groups + # Swap in and swap out should never happen at the same time. assert not (blocks_to_swap_in and blocks_to_swap_out) - self.ignored_seq_groups = ignored_seq_groups self.num_loras = len(self.lora_requests) if self.num_loras > 0: @@ -215,6 +224,8 @@ def _schedule(self) -> SchedulerOutputs: assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") + # get_len includes output tokens if the request has been + # preempted. 
num_prompt_tokens = waiting_seqs[0].get_len() if num_prompt_tokens > self.prompt_limit: logger.warning( @@ -271,8 +282,8 @@ def _schedule(self) -> SchedulerOutputs: self.running.append(seq_group) num_curr_seqs += num_new_seqs scheduled.append( - ScheduledSequenceGroup(seq_group, num_prompt_tokens)) - + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=num_prompt_tokens)) self.waiting.extendleft(leftover_waiting_sequences) if scheduled or ignored_seq_groups: @@ -371,7 +382,8 @@ def _schedule(self) -> SchedulerOutputs: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=[ - ScheduledSequenceGroup(running_group, 1) + ScheduledSequenceGroup(seq_group=running_group, + token_chunk_size=1) for running_group in self.running ], prompt_run=False, @@ -393,21 +405,19 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: - seq_group = scheduled_seq_group.seq_group - chunk_size = scheduled_seq_group.chunk_size - + seq_group, token_chunk_size = scheduled_seq_group seq_group.maybe_set_first_scheduled_time(now) + # seq_id -> SequenceData seq_data: Dict[int, SequenceData] = {} + # seq_id -> physical block numbers block_tables: Dict[int, List[int]] = {} - token_chunk_sizes: Dict[int, int] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id seq_data[seq_id] = seq.data block_tables[seq_id] = self.block_manager.get_block_table(seq) self.block_manager.access_all_blocks_in_seq(seq, now) - token_chunk_sizes[seq_id] = chunk_size seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, @@ -415,7 +425,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, - token_chunk_sizes=token_chunk_sizes, + token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=self.block_manager. 
get_common_computed_block_ids(seq_group), diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f20b3295d321..b52d407381fc 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -58,7 +58,7 @@ class EngineArgs: image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None scheduler_delay_factor: float = 0.0 - max_chunked_prefill_len: int = -1 + enable_chunked_prefill: bool = False def __post_init__(self): if self.tokenizer is None: @@ -345,11 +345,11 @@ def add_cli_args( help='Apply a delay (of delay factor multiplied by previous' 'prompt latency) before scheduling next prompt.') parser.add_argument( - '--max-chunked-prefill-len', - type=int, - default=-1, - help='max number of prefill tokens allowed in chunked prefill' - ', -1 means no limit') + '--enable-chunked-prefill', + type=bool, + default=False, + help='If True, the prefill requests can be chunked based on the ' + 'max_num_batched_tokens') return parser @classmethod @@ -392,7 +392,7 @@ def create_engine_configs( self.max_num_seqs, model_config.max_model_len, delay_factor=self.scheduler_delay_factor, - max_chunked_prefill_len=self.max_chunked_prefill_len) + enable_chunked_prefill=self.enable_chunked_prefill) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 145ffc3f8c82..35d91369a22e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -561,8 +561,8 @@ def _process_model_outputs( scheduled_seq_group.seq_group) for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output): - seq_group = scheduled_seq_group.seq_group - seq_group.add_num_computed_tokens(scheduled_seq_group.chunk_size) + seq_group, token_chunk_size = scheduled_seq_group + seq_group.update_num_computed_tokens(token_chunk_size) self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. diff --git a/vllm/sequence.py b/vllm/sequence.py index 4b9a9546dbc9..4f9e455a30f2 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -135,9 +135,9 @@ def get_num_computed_tokens(self) -> int: """Return the number of prefill tokens that are already computed.""" return self._num_computed_tokens - def add_num_computed_tokens(self, num_computed_tokens_delta) -> int: - """Record how many tokens have computed.""" - self._num_computed_tokens += num_computed_tokens_delta + def update_num_computed_tokens(self, num_new_computed_tokens: int) -> int: + """Update number of tokens computed so far.""" + self._num_computed_tokens += num_new_computed_tokens def reset_num_computed_tokens(self) -> None: """Reset the number of computed tokens from this sequence. It is @@ -454,9 +454,10 @@ def get_unfinished_seqs(self) -> List[Sequence]: def get_finished_seqs(self) -> List[Sequence]: return [seq for seq in self.seqs_dict.values() if seq.is_finished()] - def add_num_computed_tokens(self, num_computed_tokens_delta): + def update_num_computed_tokens(self, num_new_computed_tokens: int): + """Update number of tokens computed so far.""" for seq in self.seqs_dict.values(): - seq.data.add_num_computed_tokens(num_computed_tokens_delta) + seq.data.update_num_computed_tokens(num_new_computed_tokens) def get_num_uncomputed_tokens(self) -> int: # All sequences in the group should have the same prompt, so the @@ -508,10 +509,11 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. 
(Seq id -> list of physical block numbers) + token_chunk_size: The number of tokens to be processed. None if + chunking is not required. state: Internal state tied to this sequence group. lora_request: LoRA request. multi_modal_data: Multi modal data. - token_chunk_sizes: seq_id -> token chunk size to run a model. """ def __init__( @@ -521,7 +523,7 @@ def __init__( seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], - token_chunk_sizes: Dict[int, int], + token_chunk_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, @@ -536,12 +538,23 @@ def __init__( self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state - self.token_chunk_sizes = token_chunk_sizes + self._token_chunk_size = token_chunk_size + + if self._token_chunk_size is None: + if is_prompt: + self._token_chunk_size = list(seq_data.values())[0].get_len() + else: + self._token_chunk_size = 1 @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def token_chunk_size(self) -> int: + """Return the number of tokens to be processed (chunk size).""" + return self._token_chunk_size + class SequenceOutput: """The model output associated with a sequence. diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index e0b75837e8a3..a8e826684a36 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -261,16 +261,16 @@ def _create_single_target_seq_group_metadata( seq_data = seq_group_metadata.seq_data[seq_id] prompt_token_ids = seq_data.get_prompt_token_ids() new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] + seq_data = SequenceData( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ) return SequenceGroupMetadata( request_id=seq_group_metadata.request_id, is_prompt=seq_group_metadata.is_prompt, seq_data={ - target_seq_id: - SequenceData( - prompt_token_ids=prompt_token_ids, - output_token_ids=new_output_token_ids, - ), + target_seq_id: seq_data, }, sampling_params=seq_group_metadata.sampling_params, block_tables={ diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d03baf4ee280..530e83933919 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -156,28 +156,27 @@ def _prepare_prompt( seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 seq_id = seq_ids[0] - token_chunk_sizes = seq_group_metadata.token_chunk_sizes computed_block_nums = seq_group_metadata.computed_block_nums - if (self.scheduler_config.chunked_prefill_enabled + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled and computed_block_nums is not None): raise RuntimeError( "chunked prefill cannot be used with prefix caching " "now.") - chunk_size = token_chunk_sizes[seq_id] + token_chunk_size = seq_group_metadata.token_chunk_size seq_data = seq_group_metadata.seq_data[seq_id] - prefill_start = seq_data.get_num_computed_tokens() + computed_len = seq_data.get_num_computed_tokens() prefill_end = min(seq_data.get_prompt_len(), - prefill_start + chunk_size) - prompt_tokens = seq_data.get_token_ids()[prefill_start:prefill_end] + computed_len + token_chunk_size) + prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] prompt_len = 
len(prompt_tokens) # Right now, the prefill_end is always same as the length of - # prompt. However, once chunked prefill is introduced, this + # sequence. However, once chunked prefill is introduced, this # assumption can be changed. - assert prefill_end == seq_data.get_prompt_len() + assert prefill_end == seq_data.get_len() prompt_lens.append(prompt_len) - computed_len = 0 # NOTE: This only works for oooooooxxx style attention. if computed_block_nums is not None and len( @@ -188,7 +187,6 @@ def _prepare_prompt( prefix_block_tables.append(computed_block_nums) else: prefix_block_tables.append([]) - computed_len = prefill_start # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. assert computed_len == 0 @@ -728,7 +726,6 @@ def profile_run(self) -> None: seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, - token_chunk_sizes={group_id: seq_data.get_len()}, lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, multi_modal_data=fake_multi_modal_input, From 5c3abf40af1e3921e37f27421da0244c52947c03 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 27 Mar 2024 03:29:43 -0700 Subject: [PATCH 84/88] working! --- tests/samplers/test_sampler.py | 29 ++++++++++++++--------------- tests/spec_decode/utils.py | 19 +++++++++---------- vllm/config.py | 3 +-- vllm/core/scheduler.py | 24 +++++++++++------------- vllm/spec_decode/batch_expansion.py | 10 +++++----- vllm/worker/model_runner.py | 2 ++ 6 files changed, 42 insertions(+), 45 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index a1f442654b96..1626b7228207 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -241,8 +241,7 @@ def generate_test_case(): for _ in range(num_seqs): num_input = random.randint(1, 100) num_generated = random.randint(1, 100) if not is_prompt else 0 - seq_id = next(seq_id_counter) - seq_data[seq_id] = create_sequence_data( + seq_data[next(seq_id_counter)] = create_sequence_data( num_input=num_input, num_generated=num_generated) seq_group_penalization.append(num_generated < min_tokens) @@ -263,8 +262,6 @@ def generate_test_case(): } # define some explicit test cases for edge case behavior - seq_id = next(seq_id_counter) - seq_data = create_sequence_data() prompt_without_penalization = { "expected_penalization": [False], "seq_group_metadata_list": [ @@ -272,7 +269,7 @@ def generate_test_case(): request_id="test_1", is_prompt=True, seq_data={ - seq_id: seq_data, + next(seq_id_counter): create_sequence_data(), }, sampling_params=create_sampling_params(0), block_tables={}, @@ -280,15 +277,15 @@ def generate_test_case(): ] } - seq_id = next(seq_id_counter) - seq_data = create_sequence_data() prompt_with_penalization = { "expected_penalization": [True], "seq_group_metadata_list": [ SequenceGroupMetadata( request_id="test_1", is_prompt=True, - seq_data={seq_id, seq_data}, + seq_data={ + next(seq_id_counter): create_sequence_data(), + }, sampling_params=create_sampling_params(1), block_tables={}, ), @@ -312,18 +309,18 @@ def generate_test_case(): } stop_token_ids = [42, 99, 42, 0] # intentional duplication - decoding_seq_data = { - next(seq_id_counter): create_sequence_data(num_generated=1), - next(seq_id_counter): create_sequence_data(num_generated=100), - }, - prefill_seq_data = {next(seq_id_counter): create_sequence_data()} simple_combination = { "expected_penalization": [True, False, False], "seq_group_metadata_list": [ 
SequenceGroupMetadata( request_id="test_1", is_prompt=False, - seq_data=decoding_seq_data, + seq_data={ + next(seq_id_counter): + create_sequence_data(num_generated=1), + next(seq_id_counter): + create_sequence_data(num_generated=100), + }, sampling_params=create_sampling_params( 2, stop_token_ids=stop_token_ids), block_tables={}, @@ -331,7 +328,9 @@ def generate_test_case(): SequenceGroupMetadata( request_id="test_2", is_prompt=True, - seq_data=prefill_seq_data, + seq_data={ + next(seq_id_counter): create_sequence_data(), + }, sampling_params=create_sampling_params( 0, stop_token_ids=stop_token_ids), block_tables={}, diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index f533fce14afe..0cd9a4b1d581 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -161,22 +161,21 @@ def create_seq_group_metadata_from_prompts( for i, final_len in enumerate(final_seq_lens) } - seq_data_lst = [] - for prompt_token_ids, cont_token_ids in zip(prompts, continuations): - seq_data = SequenceData(prompt_token_ids=prompt_token_ids[:], - output_token_ids=cont_token_ids[:]) - seq_data_lst.append(seq_data) - return [ SequenceGroupMetadata( request_id=str(i), is_prompt=len(cont_token_ids) == 0, - seq_data={i: seq_data}, + seq_data={ + i: + SequenceData( + prompt_token_ids=prompt_token_ids[:], + output_token_ids=cont_token_ids[:], + ), + }, sampling_params=SamplingParams(temperature=0.0, ), block_tables={i: block_allocations[i][:]}, - ) for i, ( - prompt_token_ids, cont_token_ids, - seq_data) in enumerate(zip(prompts, continuations, seq_data_lst)) + ) for i, (prompt_token_ids, + cont_token_ids) in enumerate(zip(prompts, continuations)) ] diff --git a/vllm/config.py b/vllm/config.py index 8561b615389d..309ae09e09e5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -548,8 +548,7 @@ def __init__( self._verify_args() def _verify_args(self) -> None: - if self.max_num_batched_tokens < self.max_model_len and \ - not self.chunked_prefill_enabled: + if self.max_num_batched_tokens < self.max_model_len: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " f"smaller than max_model_len ({self.max_model_len}). " diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5fe5ee63c0bc..472843032824 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -226,11 +226,11 @@ def _schedule(self) -> SchedulerOutputs: "sequence.") # get_len includes output tokens if the request has been # preempted. 
- num_prompt_tokens = waiting_seqs[0].get_len() - if num_prompt_tokens > self.prompt_limit: + num_prefill_tokens = waiting_seqs[0].get_len() + if num_prefill_tokens > self.prompt_limit: logger.warning( - f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds limit of {self.prompt_limit}") + f"Input prompt ({num_prefill_tokens} tokens) is too " + f"long and exceeds limit of {self.prompt_limit}") for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -243,8 +243,8 @@ def _schedule(self) -> SchedulerOutputs: break elif can_allocate == AllocStatus.NEVER: logger.warning( - f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds the capacity of block_manager") + f"Input prompt ({num_prefill_tokens} tokens) is too " + f"long and exceeds the capacity of block_manager") for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -263,7 +263,7 @@ def _schedule(self) -> SchedulerOutputs: continue # If the number of batched tokens exceeds the limit, stop. - num_batched_tokens += num_prompt_tokens + num_batched_tokens += num_prefill_tokens if (num_batched_tokens > self.scheduler_config.max_num_batched_tokens): break @@ -282,8 +282,9 @@ def _schedule(self) -> SchedulerOutputs: self.running.append(seq_group) num_curr_seqs += num_new_seqs scheduled.append( - ScheduledSequenceGroup(seq_group=seq_group, - token_chunk_size=num_prompt_tokens)) + ScheduledSequenceGroup( + seq_group=seq_group, + token_chunk_size=num_prefill_tokens)) self.waiting.extendleft(leftover_waiting_sequences) if scheduled or ignored_seq_groups: @@ -444,10 +445,7 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_manager.fork(parent_seq, child_seq) def free_seq(self, seq: Sequence) -> None: - """Free a sequence from a block table. - - Freed sequence can be used - """ + """Free a sequence from a block table.""" self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index a8e826684a36..e0b75837e8a3 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -261,16 +261,16 @@ def _create_single_target_seq_group_metadata( seq_data = seq_group_metadata.seq_data[seq_id] prompt_token_ids = seq_data.get_prompt_token_ids() new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] - seq_data = SequenceData( - prompt_token_ids=prompt_token_ids, - output_token_ids=new_output_token_ids, - ) return SequenceGroupMetadata( request_id=seq_group_metadata.request_id, is_prompt=seq_group_metadata.is_prompt, seq_data={ - target_seq_id: seq_data, + target_seq_id: + SequenceData( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ), }, sampling_params=seq_group_metadata.sampling_params, block_tables={ diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 530e83933919..fe908335a0c5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -170,6 +170,7 @@ def _prepare_prompt( computed_len = seq_data.get_num_computed_tokens() prefill_end = min(seq_data.get_prompt_len(), computed_len + token_chunk_size) + # TODO(sang): Rename it after chunked prefill is introduced. 
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] prompt_len = len(prompt_tokens) # Right now, the prefill_end is always same as the length of @@ -348,6 +349,7 @@ def _prepare_decode( for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 seq_ids = list(seq_group_metadata.seq_data.keys()) lora_id = seq_group_metadata.lora_int_id From 66f3fcf1978d44045f7e8af26e7d34f11e1c259c Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 27 Mar 2024 06:49:59 -0700 Subject: [PATCH 85/88] fixed --- tests/test_sequence.py | 4 ++-- vllm/engine/llm_engine.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 39674e469343..1dec928158b1 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -56,12 +56,12 @@ def test_sequence_data_prefill(): assert seq_data.get_num_uncomputed_tokens() == 4 assert seq_data.get_num_computed_tokens() == 0 # advance by 2 - seq_data.update_num_computed_prefill_tokens(2) + seq_data.update_num_computed_tokens(2) assert seq_data.get_num_uncomputed_tokens() == 2 assert seq_data.get_num_computed_tokens() == 2 # advance by 1 - seq_data.update_num_computed_prefill_tokens(1) + seq_data.update_num_computed_tokens(1) assert seq_data.get_num_uncomputed_tokens() == 1 assert seq_data.get_num_computed_tokens() == 3 diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 35d91369a22e..03907f81492b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -686,17 +686,20 @@ def _get_stats(self, # Number of Tokens. if prompt_run: num_prompt_tokens = sum( - len(seq_group.prompt_token_ids) - for seq_group in scheduler_outputs.scheduled_seq_groups) + len(scheduled_seq_group.seq_group.prompt_token_ids) + for scheduled_seq_group in + scheduler_outputs.scheduled_seq_groups) num_generation_tokens = sum( - seq_group.num_seqs() - for seq_group in scheduler_outputs.scheduled_seq_groups) + scheduled_seq_group.seq_group.num_seqs() + for scheduled_seq_group in + scheduler_outputs.scheduled_seq_groups) else: num_generation_tokens = scheduler_outputs.num_batched_tokens # Latency Timings. time_last_iters = [] - for seq_group in scheduler_outputs.scheduled_seq_groups: + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + seq_group, _ = scheduled_seq_group # Time since last token. # (n.b. updates seq_group.metrics.last_token_time) time_last_iters.append(seq_group.get_last_latency(now)) From 9d4b65c898aa8e2aa1e74665f9d072d6e6d531e5 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 27 Mar 2024 21:42:07 -0700 Subject: [PATCH 86/88] Addressed code review. 
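For reference while reviewing, the chunking arithmetic that the earlier patches wire through _prepare_prompt can be sketched in isolation. slice_prefill_chunk below is a hypothetical helper for illustration only, not code from this series:

from typing import List, Tuple


def slice_prefill_chunk(prompt_token_ids: List[int],
                        num_computed_tokens: int,
                        token_chunk_size: int) -> Tuple[List[int], int]:
    # Mirrors the slicing in _prepare_prompt: start after the tokens that
    # are already computed, stop after at most token_chunk_size more.
    prefill_end = min(len(prompt_token_ids),
                      num_computed_tokens + token_chunk_size)
    return prompt_token_ids[num_computed_tokens:prefill_end], prefill_end


# When chunking is off, SequenceGroupMetadata defaults token_chunk_size to
# the full sequence length for prompts (and to 1 for decode), so the loop
# below would run exactly once.
prompt = list(range(10))
computed, chunks = 0, []
while computed < len(prompt):
    chunk, computed = slice_prefill_chunk(prompt, computed,
                                          token_chunk_size=4)
    chunks.append(chunk)
assert chunks == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]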
---
 benchmarks/benchmark_latency.py |  3 ++-
 tests/core/test_scheduler.py    |  4 +---
 vllm/core/scheduler.py          | 34 ++++++++++++++++++++-------------
 vllm/engine/arg_utils.py        |  3 ++-
 vllm/engine/llm_engine.py       |  5 +++--
 vllm/sequence.py                |  3 ++-
 6 files changed, 31 insertions(+), 21 deletions(-)

diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index c74b68ab2f11..da02493b17fd 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -27,7 +27,8 @@ def main(args: argparse.Namespace):
         device=args.device,
         ray_workers_use_nsight=args.ray_workers_use_nsight,
         enable_chunked_prefill=args.enable_chunked_prefill,
-        download_dir=args.download_dir)
+        download_dir=args.download_dir,
+        block_size=args.block_size)

     sampling_params = SamplingParams(
         n=args.n,
diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py
index f56be2020dfe..8d36abdc0c91 100644
--- a/tests/core/test_scheduler.py
+++ b/tests/core/test_scheduler.py
@@ -11,9 +11,7 @@


 def get_sequence_groups(scheduler_output):
-    return [
-        seq_group for seq_group, _ in scheduler_output.scheduled_seq_groups
-    ]
+    return [s for s, _ in scheduler_output.scheduled_seq_groups]


 def test_scheduler_add_seq_group():
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 472843032824..503d3f0158a4 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -1,6 +1,7 @@
 import enum
 import time
-from collections import deque, namedtuple
+from collections import deque
+from dataclasses import dataclass
 from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
@@ -30,8 +31,14 @@ class PreemptionMode(enum.Enum):
 # seq_group: SequenceGroup to schedule.
 # token_chunk_size: The number of prefill tokens to be processed in the next
 # step.
-ScheduledSequenceGroup = namedtuple("ScheduledSequenceGroup",
-                                    ["seq_group", "token_chunk_size"])
+@dataclass
+class ScheduledSequenceGroup:
+    # A sequence group that's scheduled.
+    seq_group: SequenceGroup
+    # The total chunk size (number of tokens) to process for next iteration.
+    # 1 for decoding. Same as prompt tokens for prefill, but if prefill is
+    # chunked, it can be smaller than that.
+    token_chunk_size: int


 class SchedulerOutputs:
@@ -62,25 +69,25 @@ def __init__(
             ignored_seq_groups: Sequence groups that are going to be ignored.
         """
         # A tuple of scheduled sequence group and its chunk size.
-        self.scheduled_seq_groups = scheduled_seq_groups
+        self.scheduled_seq_groups: Iterable[ScheduledSequenceGroup] = scheduled_seq_groups
         # True if all sequence groups are in prefill phase. If False, all
         # sequence groups are in decoding phase.
-        self.prompt_run = prompt_run
+        self.prompt_run: bool = prompt_run
         # Total number of batched tokens.
-        self.num_batched_tokens = num_batched_tokens
+        self.num_batched_tokens: int = num_batched_tokens
         # Blocks to swap in. Dict of CPU -> GPU block number.
-        self.blocks_to_swap_in = blocks_to_swap_in
+        self.blocks_to_swap_in: Dict[int, int] = blocks_to_swap_in
         # Blocks to swap out. Dict of GPU -> CPU block number.
-        self.blocks_to_swap_out = blocks_to_swap_out
+        self.blocks_to_swap_out: Dict[int, int] = blocks_to_swap_out
         # Blocks to copy. Source to a list of dest blocks.
-        self.blocks_to_copy = blocks_to_copy
+        self.blocks_to_copy: Dict[int, List[int]] = blocks_to_copy
         # Sequence groups that are going to be ignored.
-        self.ignored_seq_groups = ignored_seq_groups
+        self.ignored_seq_groups: List[SequenceGroup] = ignored_seq_groups

         # Swap in and swap out should never happen at the same time.
         assert not (blocks_to_swap_in and blocks_to_swap_out)

-        self.num_loras = len(self.lora_requests)
+        self.num_loras: int = len(self.lora_requests)
         if self.num_loras > 0:
             self._sort_by_lora_ids()
@@ -406,7 +413,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
         for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
-            seq_group, token_chunk_size = scheduled_seq_group
+            seq_group = scheduled_seq_group.seq_group
+            token_chunk_size = scheduled_seq_group.token_chunk_size
             seq_group.maybe_set_first_scheduled_time(now)

             # seq_id -> SequenceData
@@ -509,7 +517,7 @@ def _preempt_by_recompute(
         for seq in seqs:
             seq.status = SequenceStatus.WAITING
             self.free_seq(seq)
-            seq.on_recompute()
+            seq.reset_state_for_recompute()
         # NOTE: For FCFS, we insert the preempted sequence group to the front
         # of the waiting queue.
         self.waiting.appendleft(seq_group)
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index b52d407381fc..3609fdd0433a 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -392,7 +392,8 @@ def create_engine_configs(
             self.max_num_seqs,
             model_config.max_model_len,
             delay_factor=self.scheduler_delay_factor,
-            enable_chunked_prefill=self.enable_chunked_prefill)
+            enable_chunked_prefill=self.enable_chunked_prefill,
+        )
         lora_config = LoRAConfig(
             max_lora_rank=self.max_lora_rank,
             max_loras=self.max_loras,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 65dcbb68a71b..babda8c37cb1 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -561,7 +561,8 @@ def _process_model_outputs(
                     scheduled_seq_group.seq_group)

         for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
-            seq_group, token_chunk_size = scheduled_seq_group
+            seq_group = scheduled_seq_group.seq_group
+            token_chunk_size = scheduled_seq_group.token_chunk_size
             seq_group.update_num_computed_tokens(token_chunk_size)
             self._process_sequence_group_outputs(seq_group, outputs)
@@ -699,7 +700,7 @@ def _get_stats(self,
         # Latency Timings.
         time_last_iters = []
         for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
-            seq_group, _ = scheduled_seq_group
+            seq_group = scheduled_seq_group.seq_group
             # Time since last token.
             # (n.b. updates seq_group.metrics.last_token_time)
             time_last_iters.append(seq_group.get_last_latency(now))
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 05022998ab19..3063b4a674cf 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -113,6 +113,7 @@ def __init__(
         self.prompt_token_ids = prompt_token_ids
         self.output_token_ids = output_token_ids
         self.cumulative_logprob = 0.0
+        # The number of tokens that are computed (that run against the model).
         self._num_computed_tokens = 0

     def append_token_id(self, token_id: int, logprob: float) -> None:
@@ -229,7 +230,7 @@ def hash_of_block(self, logical_idx: int) -> int:
     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size

-    def on_recompute(self):
+    def reset_state_for_recompute(self):
         """Reset the sequence states for recomputation."""
         self.data.reset_num_computed_tokens()

From 9bdb9dce411a58cee63f5c0f1db11604450f5976 Mon Sep 17 00:00:00 2001
From: sang
Date: Thu, 28 Mar 2024 07:24:55 -0700
Subject: [PATCH 87/88] fix tests.
---
 tests/core/test_scheduler.py | 2 +-
 vllm/core/scheduler.py       | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py
index 54ca1cb4ab7e..88c2c37f4fb3 100644
--- a/tests/core/test_scheduler.py
+++ b/tests/core/test_scheduler.py
@@ -11,7 +11,7 @@


 def get_sequence_groups(scheduler_output):
-    return [s for s, _ in scheduler_output.scheduled_seq_groups]
+    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]


 def test_scheduler_add_seq_group():
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index bbe7d858adb4..04e8056aab54 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -460,8 +460,9 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # batch will have been computed before the next scheduling invocation.
         # This is because the engine assumes that a failure in model execution
         # will crash the vLLM instance / will not retry.
-        for seq_group in scheduler_outputs.scheduled_seq_groups:
-            self.block_manager.mark_blocks_as_computed(seq_group)
+        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
+            self.block_manager.mark_blocks_as_computed(
+                scheduled_seq_group.seq_group)

         return seq_group_metadata_list, scheduler_outputs

From 88126a9ac1587739bbf6412cd9963563817433cb Mon Sep 17 00:00:00 2001
From: sang
Date: Thu, 28 Mar 2024 08:19:24 -0700
Subject: [PATCH 88/88] fixed a bug

---
 vllm/worker/model_runner.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index c932f6f2218e..31fa52476af1 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -168,7 +168,9 @@ def _prepare_prompt(
             token_chunk_size = seq_group_metadata.token_chunk_size
             seq_data = seq_group_metadata.seq_data[seq_id]
             computed_len = seq_data.get_num_computed_tokens()
-            prefill_end = min(seq_data.get_prompt_len(),
+            # We should use get_len here because in case of preemption
+            # it contains output tokens.
+            prefill_end = min(seq_data.get_len(),
                               computed_len + token_chunk_size)
             # TODO(sang): Rename it after chunked prefill is introduced.
             prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
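
After this series, scheduler_outputs.scheduled_seq_groups holds
ScheduledSequenceGroup objects instead of (seq_group, token_chunk_size)
tuples, so callers switch from tuple unpacking to attribute access, as the
diffs above show. Below is a minimal, self-contained sketch of that access
pattern; _ToySeqGroup and the sample token_chunk_size values are illustrative
stand-ins, not vLLM classes.

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class _ToySeqGroup:
        """Stand-in for vllm.sequence.SequenceGroup (illustration only)."""
        num_computed_tokens: int = 0

        def update_num_computed_tokens(self, n: int) -> None:
            self.num_computed_tokens += n

    @dataclass
    class ScheduledSequenceGroup:
        seq_group: _ToySeqGroup
        # 1 for decoding; up to the prompt length for (possibly chunked) prefill.
        token_chunk_size: int

    scheduled: List[ScheduledSequenceGroup] = [
        ScheduledSequenceGroup(_ToySeqGroup(), token_chunk_size=4),  # prefill chunk
        ScheduledSequenceGroup(_ToySeqGroup(), token_chunk_size=1),  # decode step
    ]

    # Attribute access replaces the old `for seq_group, chunk in ...` unpacking.
    for scheduled_seq_group in scheduled:
        seq_group = scheduled_seq_group.seq_group
        seq_group.update_num_computed_tokens(scheduled_seq_group.token_chunk_size)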