diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 15f971b66e3b..fc5d0b2092e3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -46,6 +46,10 @@ steps: commands: - pytest -v -s prefix_caching +- label: Chunked Prefill Test + commands: + - pytest -v -s chunked_prefill + - label: Samplers Test command: pytest -v -s samplers --forked diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 2fdc08c5c26d..a75ca427506f 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -10,6 +10,13 @@ from vllm import LLM, SamplingParams +SAMPLE_PROMPTS = [ + "The president of the United States is", + "Hello, my name is", + "The capital of France is", + "The future of AI is", +] + def main(args: argparse.Namespace): print(args) @@ -26,6 +33,8 @@ def main(args: argparse.Namespace): enforce_eager=args.enforce_eager, kv_cache_dtype=args.kv_cache_dtype, device=args.device, + block_size=args.block_size, + flash_style=args.flash_style, ray_workers_use_nsight=args.ray_workers_use_nsight, ) @@ -58,10 +67,25 @@ def run_to_completion(profile_dir: Optional[str] = None): print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(prompt_token_ids=dummy_prompt_token_ids, - sampling_params=sampling_params, - use_tqdm=False) + if args.use_sample: + batch = (SAMPLE_PROMPTS * + (args.batch_size // len(SAMPLE_PROMPTS) + + 1))[:args.batch_size] + outputs = llm.generate(prompts=batch, + sampling_params=sampling_params, + use_tqdm=False) + else: + outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids, + sampling_params=sampling_params, + use_tqdm=False) end_time = time.perf_counter() + if args.verbose: + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print( + f"Prompt: {prompt!r}, Generated text: {generated_text!r}" + ) latency = end_time - start_time return latency @@ -146,6 +170,19 @@ def run_to_completion(profile_dir: Optional[str] = None): default="cuda", choices=["cuda"], help='device type for vLLM execution, supporting CUDA only currently.') + parser.add_argument('--flash-style', + action='store_true', + help='enable flash attention') + parser.add_argument('--block-size', + type=int, + default=16, + help='block size of key/value cache') + parser.add_argument('--use-sample', + action='store_true', + help='use sample input instead of dummy input') + parser.add_argument('--verbose', + action='store_true', + help='print generated text') parser.add_argument( "--ray-workers-use-nsight", action='store_true', diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index d921dea1220e..7b3c6405e259 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -7,6 +7,7 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random from vllm._C import ops +from vllm.model_executor.layers.attention import flash_attn_with_kvcache_paged NUM_BLOCKS = 1024 PARTITION_SIZE = 512 @@ -64,14 +65,17 @@ def main( block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) # Create the KV cache. - key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, - block_size, - 1, - num_kv_heads, - head_size, - kv_cache_dtype, - dtype, - device=device) + flash_style = version == "flash" + key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + flash_style=flash_style) key_cache, value_cache = key_caches[0], value_caches[0] # Prepare for the paged attention kernel. @@ -131,6 +135,16 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: alibi_slopes, kv_cache_dtype, ) + elif version == "flash": + flash_attn_with_kvcache_paged( + query.view(num_seqs, 1, num_query_heads, head_size), + key_cache, + value_cache, + scale, + block_tables, + context_lens, + alibi_slopes=alibi_slopes, + ) else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() @@ -158,7 +172,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: description="Benchmark the paged attention kernel.") parser.add_argument("--version", type=str, - choices=["v1", "v2"], + choices=["v1", "v2", "flash"], default="v2") parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--context-len", type=int, default=4096) @@ -168,7 +182,10 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: type=int, choices=[64, 80, 96, 112, 128, 256], default=128) - parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) + parser.add_argument("--block-size", + type=int, + choices=[16, 32, 256], + default=16) parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--dtype", type=str, diff --git a/csrc/cache.h b/csrc/cache.h index 765e231abd26..1bca7e4e39a9 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -23,6 +23,13 @@ void reshape_and_cache( torch::Tensor& slot_mapping, const std::string& kv_cache_dtype); +void reshape_and_cache_flash( + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping); + // Just for unittest void convert_fp8_e5m2( torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 7254010b8e3a..06da455117b0 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -174,6 +174,7 @@ __global__ void reshape_and_cache_kernel( const int64_t block_idx = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size; for (int i = threadIdx.x; i < n; i += blockDim.x) { const int64_t src_key_idx = token_idx * key_stride + i; @@ -269,6 +270,92 @@ void reshape_and_cache( namespace vllm { +// flash-attention style cache funciton where key/value caches +// has the same shape of [num_blocks, block_size, num_heads, head_size] +template +__global__ void reshape_and_cache_flash_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, head_size] + scalar_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + + // Target idx for kv are the same. + const int64_t tgt_idx = block_idx * block_size * num_heads * head_size + + block_offset * num_heads * head_size + + head_idx * head_size + + head_offset; + + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + // TODO(sang): Support ENABLE_FP8_E5M2. + key_cache[tgt_idx] = tgt_key; + value_cache[tgt_idx] = tgt_value; + } +} + +} // namespace vllm + +// TODO(sang): Support kv_cache_dtype +void reshape_and_cache_flash( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping) +{ + int num_tokens_padded = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(1); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + dim3 grid(num_tokens_padded); + dim3 block(std::min(num_heads * head_size, 512)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), + "reshape_and_cache_flash_kernel", + [&] { + vllm::reshape_and_cache_flash_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + slot_mapping.data_ptr(), + key_stride, + value_stride, + num_heads, + head_size, + block_size); + }); +} + +namespace vllm { + template __global__ void convert_fp8_e5m2_kernel( const Tin* __restrict__ src_cache, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 4b6ade756639..a9ff2e8f5830 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -81,6 +81,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "reshape_and_cache", &reshape_and_cache, "Reshape the key and value tensors and cache them"); + cache_ops.def( + "reshape_and_cache_flash", + &reshape_and_cache_flash, + "Reshape the flash-style key and value tensors and cache them."); cache_ops.def( "convert_fp8_e5m2", &convert_fp8_e5m2, diff --git a/requirements.txt b/requirements.txt index 05ec2e804e13..7aff2658f9d9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ pynvml == 11.5.0 triton >= 2.1.0 outlines >= 0.0.27 cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. +flash-attn >= 2.5.0 # Required for chunked prefill. diff --git a/tests/chunked_prefill/test_correctness.py b/tests/chunked_prefill/test_correctness.py new file mode 100644 index 000000000000..6b38bb97b41d --- /dev/null +++ b/tests/chunked_prefill/test_correctness.py @@ -0,0 +1,86 @@ +import gc + +import pytest +import torch + +from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel + +MODELS = [ + "JackFram/llama-68m", + # "facebook/opt-125m", +] + +TEST_PROMPTS = [ + # pylint: disable=line-too-long + "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", + "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", + "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", + # Different between page attention and flash attention. + # "Describe the basic components of a neural network and how it can be trained.", + "Write a short story about a robot that dreams for the first time.", + "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.", + "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.", + "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'", +] + + +# TODO(sang): Add chunked prefill parameters. +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("block_size", [32]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_models( + vllm_runner, + model: str, + dtype: str, + max_tokens: int, + block_size: int, + tensor_parallel_size: int, +) -> None: + """ verify the flash attention has the same output + as page attention """ + if torch.cuda.device_count() < tensor_parallel_size: + pytest.skip( + f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}" + ) + print("loading page attention models..") + pg_model = vllm_runner(model, dtype=dtype) + expected_outputs = [] + + print("generating tokens...") + expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens)) + print("generating tokens finished") + + del pg_model + + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + + flash_attn_output_by_batches = [] + flash_attn_model = vllm_runner(model, + dtype=dtype, + block_size=block_size, + flash_style=True, + tensor_parallel_size=tensor_parallel_size) + for i in range(10): + prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)] + flash_attn_output_by_batches.append( + flash_attn_model.generate_greedy(prompts, max_tokens)) + + del flash_attn_model + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + + for flash_attn_outputs in flash_attn_output_by_batches: + for i in range(len(flash_attn_outputs)): + fa_output_ids, fa_output_str = flash_attn_outputs[i] + vllm_output_ids, vllm_output_str = expected_outputs[ + i % len(expected_outputs)] + print(vllm_output_str) + assert fa_output_ids == vllm_output_ids, ( + f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}" + f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}" + ) diff --git a/tests/conftest.py b/tests/conftest.py index 6eb8159837d5..f23f2117b4b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -165,6 +165,8 @@ def __init__( dtype: str = "half", disable_log_stats: bool = True, tensor_parallel_size: int = 1, + block_size: int = 32, + flash_style: bool = False, **kwargs, ) -> None: self.model = LLM( @@ -175,6 +177,8 @@ def __init__( swap_space=0, disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, + flash_style=flash_style, + block_size=block_size, **kwargs, ) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index ebfeb8ba0481..397101fa8610 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -10,7 +10,7 @@ def test_scheduler_add_seq_group(): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1, 256) + scheduler_config = SchedulerConfig(100, 64, 1) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -26,7 +26,7 @@ def test_scheduler_add_seq_group(): def test_scheduler_abort_seq_group(): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1, 256) + scheduler_config = SchedulerConfig(100, 64, 1) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -50,7 +50,7 @@ def test_scheduler_schedule_simple(): block_size = 4 num_seq_group = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len, 256) + scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -64,10 +64,10 @@ def test_scheduler_schedule_simple(): running.append(seq_group) # Schedule seq groups prompts. + num_tokens = block_size * num_seq_group seq_group_meta, out = scheduler.schedule() assert set(out.scheduled_seq_groups) == set(running) - assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs( - )[0].get_len() + assert out.num_batched_tokens == num_tokens assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == num_seq_group @@ -84,7 +84,7 @@ def test_scheduler_schedule_simple(): def test_scheduler_schedule_preempt_abort(): block_size = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, 2, max_model_len, 256) + scheduler_config = SchedulerConfig(64, 2, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 @@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort(): # Schedule seq groups prompts. seq_group_meta, out = scheduler.schedule() assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] - assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2 + assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == 2 @@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort(): scheduler.abort_seq_group("1") seq_group_meta, out = scheduler.schedule() assert out.scheduled_seq_groups == [seq_group_b] - assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len() + assert out.num_batched_tokens == 5 # 4 prompt + 1 generation. assert (not out.blocks_to_copy and not out.blocks_to_swap_in and not out.blocks_to_swap_out) assert len(seq_group_meta) == 1 @@ -136,7 +136,7 @@ def test_scheduler_max_seqs(): num_seq_group = 4 max_seq_group = 2 max_model_len = 16 - scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len, 256) + scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index d8dc74bc7b00..d2de4105b7f3 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -224,3 +224,73 @@ def test_swap_blocks( dist_key_caches[0][dst].cpu()) assert torch.allclose(src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_reshape_and_cache_flash( + kv_cache_factory, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Create a random slot mapping. + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device='cuda') + _, key, value = qkv.unbind(dim=1) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(num_blocks, + block_size, + 1, + num_heads, + head_size, + dtype, + seed, + flash_style=True) + assert len(key_caches) == 1 and len(value_caches) == 1 + key_cache, value_cache = key_caches[0], value_caches[0] + + # Clone the KV caches. + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Call the reshape_and_cache kernel. + cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping) + + # Run the reference implementation. + block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') + block_indicies = block_indicies.cpu().tolist() + block_offsets = slot_mapping % block_size + block_offsets = block_offsets.cpu().tolist() + for i in range(num_tokens): + block_idx = block_indicies[i] + block_offset = block_offsets[i] + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py new file mode 100644 index 000000000000..66a668cb7dd5 --- /dev/null +++ b/tests/kernels/test_flash_attention.py @@ -0,0 +1,633 @@ +import random +from typing import List, Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.attention import ( + flash_single_query_cached_kv_attention, + flash_multi_query_cached_kv_attention_varlen, +) +from vllm.utils import get_max_shared_memory_bytes + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +NUM_BLOCKS = 128 # Arbitrary values for testing +PARTITION_SIZE = 512 + +DTYPES = [torch.half, torch.bfloat16] +NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing +NUM_HEADS_SMALL = NUM_HEADS +HEAD_SIZES = [32, 64, 128, 256] +BLOCK_SIZES = [32, 64, 256] +USE_ALIBI = [False, True] +SEEDS = [0] +PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] + + +def pad_attention_inputs( + pad_config: Tuple[int, int], + block_size: int, + query: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Pad the attention inputs to the specified batch size and context length. + """ + pad_batch_size, pad_max_context_len = pad_config + if pad_batch_size == 0: + return query, block_tables, context_lens, max_context_len + target_batch_size = ( + (query.shape[0] - 1) % pad_batch_size + 1) * pad_batch_size + target_block_size = pad_max_context_len // block_size + 1 + padded_query = F.pad(query, + (0, 0, 0, 0, 0, target_batch_size - query.shape[0])) + padded_block_table = F.pad(block_tables, + (0, target_block_size - block_tables.shape[1], + 0, target_batch_size - block_tables.shape[0])) + padded_context_lens = F.pad(context_lens, + (0, target_batch_size - context_lens.shape[0])) + return padded_query, padded_block_table, padded_context_lens, pad_max_context_len + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], + flash_style: bool = False, +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[-2] + head_size = value_cache.shape[-1] + block_size = value_cache.shape[-3] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + context_lens = context_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + context_len = int(context_lens[i]) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + if flash_style: + k = key_cache[block_number, block_offset, :, :] + else: + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys.append(k) + + if flash_style: + v = value_cache[block_number, block_offset, :, :] + else: + v = value_cache[block_number, :, :, block_offset] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. + position_ids = torch.arange(context_len, device="cuda").int() + alibi_bias = (position_ids - context_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", [False]) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("pad_config", [(0, 0)]) +@torch.inference_mode() +def test_flash_paged_attention( + kv_cache_factory, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + seed: int, + pad_config: Tuple[int, int], +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, + dtype=torch.float, + device="cuda") + + max_seq_len = MAX_SEQ_LEN if not pad_config[0] else (pad_config[1] - 1000) + context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] + context_lens[-1] = max_seq_len + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + dtype, + seed, + flash_style=True) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Call the paged attention kernel. + output = torch.empty_like(query) + + padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ + pad_attention_inputs(pad_config, block_size, query, + block_tables, context_lens, max_context_len) + + flash_single_query_cached_kv_attention( + output, + padded_query.view(num_seqs, 1, num_query_heads, head_size), + key_cache, + value_cache, + scale, + padded_block_table, + padded_context_lens, + alibi_slopes, + ) + + # Run the reference implementation. + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + context_lens, + scale, + alibi_slopes, + flash_style=True, + ) + + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +def ref_multi_query_kv_attention( + cu_seq_lens: List[int], + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + ref_outputs = [] + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + + # Create attention mask. + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype, device="cuda") + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +def ref_multi_query_kv_attention_padded( + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + cu_seq_lens: List[int], + context_lens: List[int], + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + block_size = value_cache.shape[-3] + ref_outputs = [] + + for i in range(num_seqs): + q_start_idx = cu_seq_lens[i] + q_end_idx = cu_seq_lens[i + 1] + seq_len = q_end_idx - q_start_idx + + context_len = context_lens[i] + + block_table = block_tables[i] + keys = [] + values = [] + + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, block_offset, :, :] + keys.append(k) + + v = value_cache[block_number, block_offset, :, :] + values.append(v) + + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + q = query[q_start_idx:q_end_idx, :, :] + k = keys[:context_len, :, :] + v = values[:context_len, :, :] + + assert seq_len <= context_len + + # pad q if seq_len is less than context_len + # this is for correct calculation of attention. + if seq_len < context_len: + indices = [i % seq_len for i in range(context_len - seq_len)] + q_left_pad = q[indices, :, :] + q = torch.cat([q_left_pad, q], dim=0) + + # Create attention mask. + attn_mask = torch.triu(torch.ones(context_len, + context_len, + dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype, device="cuda") + + ref_output = ref_masked_attention( + q, + k, + v, + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output[-seq_len:, :, :]) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +def is_a100(): + return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 + + +if not is_a100(): + NUM_HEADS_SMALL = [(16, 16), (16, 8)] + MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS_SMALL) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("version", ["flash"]) +@pytest.mark.parametrize("chunked_prefill", [False, True]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@torch.inference_mode() +def test_multi_query_kv_attention( + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + version: str, + seed: int, + chunked_prefill: bool, + block_size: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. + # As the xformers library is already tested with its own tests, we can use + # a smaller MAX_SEQ_LEN here. + max_len = min(MAX_SEQ_LEN, 4096) + + seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] + max_seq_len = max(seq_lens) + + if chunked_prefill: + # context length will be different from seq_len if chunked_prefill is + # true. + context_lens = random.sample(range(max_seq_len, max_len), num_seqs) + else: + context_lens = seq_lens + max_context_len = max(context_lens) + + num_tokens = sum(seq_lens) + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + + cu_context_lens = [0] + for context_len in context_lens: + cu_context_lens.append(cu_context_lens[-1] + context_len) + + print(f"cu_seq_lens={cu_seq_lens}, cu_context_lens={cu_context_lens}") + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + num_queries_per_kv = num_query_heads // num_kv_heads + + value_cache = torch.empty(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + key_cache = torch.empty(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + query = torch.empty(num_tokens, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + value_cache.uniform_(-scale, scale) + key_cache.uniform_(-scale, scale) + query.uniform_(-scale, scale) + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + output = torch.empty_like(query) + + if version == "flash": + flash_multi_query_cached_kv_attention_varlen( + output, + query, + key_cache, + value_cache, + scale, + block_tables, + torch.cuda.IntTensor(cu_seq_lens), + torch.cuda.IntTensor(cu_context_lens), + block_size, + max_seq_len, + max_context_len, + None, + ) + else: + raise AssertionError(f"{version=} is not supported") + + ref_output = ref_multi_query_kv_attention_padded( + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + cu_seq_lens, + context_lens, + scale, + dtype, + ) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +# @pytest.mark.parametrize("num_heads", [40]) +# @pytest.mark.parametrize("head_size", [64]) +# @pytest.mark.parametrize("block_size", [32]) +# @pytest.mark.parametrize("num_blocks", [128]) +# @pytest.mark.parametrize("dtype", [torch.half]) +# @pytest.mark.parametrize("seed", SEEDS) +# @torch.inference_mode() +# def test_e2e( +# kv_cache_factory, +# num_heads: int, +# head_size: int, +# block_size: int, +# num_blocks: int, +# dtype: torch.dtype, +# seed: int, +# ) -> None: +# from vllm._C import cache_ops +# batch_size = 2 +# seqlen = 29 +# num_tokens = batch_size * seqlen +# random.seed(seed) +# torch.random.manual_seed(seed) +# torch.cuda.manual_seed(seed) + +# block_tables = [] +# for _ in range(batch_size): +# block_table = [ +# random.randint(0, NUM_BLOCKS - 1) +# ] +# block_tables.append(block_table) +# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") +# # Create a random slot mapping. +# slot_mapping = [] +# for i in range(0, 29): +# block_number = int(block_tables[i // block_size]) +# block_offset = i % block_size +# slot = block_number * block_size + block_offset +# slot_mapping.append(slot) +# # for _ in range(23, 29): +# # slot_mapping.append(-1) +# for i in range(0, 29): +# block_number = int(block_tables[1]) +# block_offset = i % block_size +# slot = block_number * block_size + block_offset +# slot_mapping.append(slot) +# # slot_mapping = random.sample(range(num_slots), num_tokens) +# slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + +# qkv = torch.randn(num_tokens, +# 3, +# num_heads, +# head_size, +# dtype=dtype, +# device='cuda') +# query, key, value = qkv.unbind(dim=1) +# # query[23:29] = 0 +# # key[23:29] = 0 +# # value[23:29] = 0 +# # Create the KV caches. + +# key_caches, value_caches = kv_cache_factory(num_blocks, +# block_size, +# 1, +# num_heads, +# head_size, +# dtype, +# seed, +# flash_style=True) +# assert len(key_caches) == 1 and len(value_caches) == 1 +# key_cache, value_cache = key_caches[0], value_caches[0] +# # Call the reshape_and_cache kernel. +# cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, +# slot_mapping) + +# cloned_key_cache = key_cache.clone() +# cloned_value_cache = value_cache.clone() + +# # Run the reference implementation. +# block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') +# block_indicies = block_indicies.cpu().tolist() +# block_offsets = slot_mapping % block_size +# block_offsets = block_offsets.cpu().tolist() +# for i in range(num_tokens): +# block_idx = block_indicies[i] +# block_offset = block_offsets[i] +# print("block_idx", block_idx) +# print("block_offset", block_offset) +# if block_idx != -1: +# cloned_key_cache[block_idx, block_offset, :, :] = key[i] +# cloned_value_cache[block_idx, block_offset, :, :] = value[i] +# assert torch.allclose(key_cache, cloned_key_cache) +# assert torch.allclose(value_cache, cloned_value_cache) + +# for i in range(58): +# print(i) +# block_idx = block_indicies[i] +# block_offset = block_offsets[i] +# torch.allclose(key[i], key_cache[block_idx][block_offset]) + +# from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask +# from xformers import ops as xops +# seqlen = query.shape[1] +# attn_bias = BlockDiagonalCausalMask.from_seqlens( +# [seqlen] * batch_size) +# scale = float(1.0 / (head_size**0.5)) +# output = xops.memory_efficient_attention_forward( +# query.unsqueeze(0), +# key.unsqueeze(0), +# value.unsqueeze(0), +# attn_bias=None, +# p=0.0, +# scale=scale, +# ) + +# num_tokens, num_heads, head_size = query.shape +# from flash_attn import flash_attn_func +# output1 = flash_single_query_cached_kv_attention( +# None, +# query.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# key_cache, +# value_cache, +# scale, +# block_tables, +# torch.tensor([23, 29], dtype=torch.int, device='cuda'), +# alibi_slopes=None, +# ) +# output2 = flash_attn_func( +# # query.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# # key.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# # value.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# query.unsqueeze(0), +# key.unsqueeze(0), +# value.unsqueeze(0), +# # key_cache, +# # value_cache, +# softmax_scale=scale, +# # block_tables, +# # torch.tensor([23, 29], dtype=torch.int, device='cuda'), +# # alibi_slopes=None, +# causal=True, +# ) +# output3 = flash_attn_func( +# query.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# key.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# value.view(batch_size, num_tokens // batch_size, num_heads, head_size), +# # key_cache, +# # value_cache, +# softmax_scale=scale, +# # block_tables, +# # torch.tensor([23, 29], dtype=torch.int, device='cuda'), +# # alibi_slopes=None, +# causal=True, +# ) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index e881cd1ec375..1cfaffd5007f 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -113,7 +113,6 @@ def test_contexted_kv_attention( # to V_cache[num_blocks, num_kv_heads, head_size, block_size] v_cache = v_cache.view(-1, block_size, num_kv_heads, head_size).permute(0, 2, 3, 1).contiguous() - # Warm up the Triton kernel by calling it once before actually measuring generation time context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, b_start_loc, b_seq_len, b_ctx_len, max_input_len) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 31a7c716afbf..e4538de35169 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -25,7 +25,7 @@ def test_worker_apply_lora(sql_lora_files): revision=None, ), parallel_config=ParallelConfig(1, 1, False), - scheduler_config=SchedulerConfig(32, 32, 32, 256), + scheduler_config=SchedulerConfig(32, 32, 32), device_config=DeviceConfig("cuda"), local_rank=0, rank=0, diff --git a/tests/models/test_models.py b/tests/models/test_models.py index fb567e837d28..c268b6fd4868 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -6,26 +6,27 @@ MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", - "mistralai/Mistral-7B-v0.1", - "Deci/DeciLM-7b", - "tiiuae/falcon-7b", - "gpt2", - "bigcode/tiny_starcoder_py", - "EleutherAI/gpt-j-6b", - "EleutherAI/pythia-70m", - "bigscience/bloom-560m", - "mosaicml/mpt-7b", - "microsoft/phi-2", - "stabilityai/stablelm-3b-4e1t", - "allenai/OLMo-1B", - "bigcode/starcoder2-3b", + # "meta-llama/Llama-2-7b-hf", + # "mistralai/Mistral-7B-v0.1", + # "Deci/DeciLM-7b", + # "tiiuae/falcon-7b", + # "gpt2", + # "bigcode/tiny_starcoder_py", + # "EleutherAI/gpt-j-6b", + # "EleutherAI/pythia-70m", + # "bigscience/bloom-560m", + # "mosaicml/mpt-7b", + # "microsoft/phi-2", + # "stabilityai/stablelm-3b-4e1t", + # "allenai/OLMo-1B", + # "bigcode/starcoder2-3b", ] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("enforce_eager", [False, True]) def test_models( hf_runner, vllm_runner, @@ -33,12 +34,13 @@ def test_models( model: str, dtype: str, max_tokens: int, + enforce_eager: bool, ) -> None: hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype) + vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/prompts/example.txt b/tests/prompts/example.txt index e1b97bc6eee7..cef4d1d76873 100644 --- a/tests/prompts/example.txt +++ b/tests/prompts/example.txt @@ -5,4 +5,4 @@ Describe the basic components of a neural network and how it can be trained. Write a short story about a robot that dreams for the first time. Analyze the impact of the COVID-19 pandemic on global economic structures and future business models. Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies. -Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' +Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' \ No newline at end of file diff --git a/tests/worker/spec_decode/test_multi_step_worker.py b/tests/worker/spec_decode/test_multi_step_worker.py index ea5480290357..15a0546813eb 100644 --- a/tests/worker/spec_decode/test_multi_step_worker.py +++ b/tests/worker/spec_decode/test_multi_step_worker.py @@ -90,8 +90,8 @@ def test_same_output_for_single_step(): num_gpu_blocks, seed, ) - multi_step_worker.model_runner = worker.model_runner - multi_step_worker.cache_engine = worker.cache_engine + # multi_step_worker.model_runner = worker.model_runner + # multi_step_worker.cache_engine = worker.cache_engine num_steps = 1 @@ -258,4 +258,4 @@ def test_same_output_for_multi_step(): for multi_step_logprobs, single_step_logprobs in zip( multi_step_output_logprobs, single_step_output_logprobs): assert_logprobs_dict_allclose(multi_step_logprobs, - single_step_logprobs) + single_step_logprobs) \ No newline at end of file diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index f44895a728c7..55a078230b46 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -2,7 +2,14 @@ import torch from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT +from vllm.config import ModelConfig + + +# Make sure the result is aligned. +def round_up_to_next_multiple_of_batch_size(n): + batch_size = _BATCH_SIZE_ALIGNMENT + return ((n + 7) // batch_size) * batch_size def test_prepare_prompt(): @@ -12,6 +19,7 @@ def test_prepare_prompt(): batch_size = random.randint(1, 256) prompt_lens = [] seq_group_metadata_list = [] + block_tables = {0: [1]} for i in range(batch_size): # make sure all tokens fit into one block prompt_len = i % (model_runner.block_size - 1) + 1 @@ -23,24 +31,75 @@ def test_prepare_prompt(): is_prompt=True, seq_data={0: SequenceData(seq_data)}, sampling_params=SamplingParams(temperature=0), - block_tables={0: [1]}, + block_tables=block_tables, )) expected_selected_token_indices = [] selected_token_start_idx = 0 - max_seq_len = max(prompt_lens) for prompt_len in prompt_lens: expected_selected_token_indices.append(selected_token_start_idx + prompt_len - 1) - selected_token_start_idx += max_seq_len - input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = ( + selected_token_start_idx += prompt_len + input_tokens, input_positions, input_metadata, return_prompt_lens, _, _, _, _ = ( model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens + + # Verify input metadata is correct for prompts. + device = model_runner.device + assert input_metadata.is_prompt is True + assert torch.allclose(input_metadata.prompt_lens, + torch.tensor(prompt_lens, device=device)) + assert input_metadata.num_prompt_tokens == sum(prompt_lens) + assert input_metadata.num_generation_tokens == 0 + assert input_metadata.max_seq_len == max(prompt_lens) + # build start_loc + start_idx = 0 + start_loc = [start_idx] + # start_loc is padded. + for prompt_len in prompt_lens: + start_idx += prompt_len + start_loc.append(start_idx) + assert torch.allclose( + input_metadata.start_loc, + torch.tensor(start_loc, dtype=torch.long, device=device)) + assert input_metadata.max_context_len is None + # TODO(sang): The current definition of context_lens is the + # number of k/v that are already cached (before this run). + # It is inconsistent with decoding. + assert torch.allclose( + input_metadata.context_lens, + torch.zeros(input_metadata.context_lens.shape[0], + dtype=torch.int, + device=device)) + + expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.allclose(input_metadata.block_tables, expected) + # Cuda graph should not be used for prerill. + assert input_metadata.use_cuda_graph is False + assert input_metadata.kv_cache_dtype == "auto" + assert input_metadata.num_valid_tokens == round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)) + + assert input_tokens.shape == (round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)), ) + assert input_positions.shape == (round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)), ) + torch.testing.assert_close(input_tokens, input_positions) + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - assert input_tokens.shape == (batch_size, max_seq_len) - assert input_positions.shape == (batch_size, max_seq_len) + assert input_tokens.shape == (round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)), ) + assert input_positions.shape == (round_up_to_next_multiple_of_batch_size( + sum(prompt_lens)), ) + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) torch.testing.assert_close(input_tokens, input_positions) actual = sampling_metadata.selected_token_indices @@ -48,3 +107,93 @@ def test_prepare_prompt(): device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) + + +def test_prepare_decode_cuda_graph(): + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=False, + ) + model_runner = ModelRunner(model_config, None, None, None, None) + model_runner.set_block_size(16) + + # Make sure the result is aligned. + def round_up_to_next_multiple_of_batch_size(n): + batch_size = _BATCH_SIZE_ALIGNMENT + return ((n + 7) // batch_size) * batch_size + + batch_size = random.randint(1, 256) + prompt_lens = [] + seq_group_metadata_list = [] + for i in range(batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_lens.append(prompt_len) + seq_data = list(range(prompt_len)) + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: SequenceData(seq_data)}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + )) + + input_tokens, input_positions, input_metadata, _, _, _ = ( + model_runner._prepare_decode(seq_group_metadata_list)) + + # Verify input metadata is correct for prompts. + device = model_runner.device + assert input_metadata.is_prompt is False + assert input_metadata.prompt_lens is None + assert input_metadata.num_prompt_tokens == 0 + assert input_metadata.num_generation_tokens == ( + round_up_to_next_multiple_of_batch_size(len(seq_group_metadata_list))) + assert input_metadata.max_seq_len is None + assert input_metadata.start_loc is None + assert input_metadata.max_context_len == max(prompt_lens) + assert torch.allclose( + input_metadata.context_lens[:len(prompt_lens)], + torch.tensor(prompt_lens, dtype=torch.int, device=device)) + + # block table's first index corresponds to each batch, meaning in + # decoding it is each token. + assert input_metadata.block_tables.shape[0] == len(input_tokens) + # Block table's second dim correspondsd to each token's block number. + # It is padded up to + assert input_metadata.block_tables.shape[1] == ( + model_runner.get_max_block_per_batch()) + # Cuda graph should not be used for prerill. + assert input_metadata.use_cuda_graph is True + assert input_metadata.kv_cache_dtype == "auto" + assert input_metadata.num_valid_tokens == ( + round_up_to_next_multiple_of_batch_size(len(seq_group_metadata_list))) + + assert input_tokens.shape == (round_up_to_next_multiple_of_batch_size( + len(seq_group_metadata_list)), ) + assert input_positions.shape == (round_up_to_next_multiple_of_batch_size( + len(seq_group_metadata_list)), ) + torch.testing.assert_close(input_tokens, input_positions) + + # Verify Sampling + expected_selected_token_indices = [] + selected_token_start_idx = 0 + for prompt_len in prompt_lens: + expected_selected_token_indices.append(selected_token_start_idx) + selected_token_start_idx += 1 + sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens) + actual = sampling_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) diff --git a/vllm/config.py b/vllm/config.py index ef9a920f29c2..292ba5222eef 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -60,6 +60,7 @@ class ModelConfig: max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. + flash_style: Enable flash style page attention. """ def __init__( @@ -79,7 +80,7 @@ def __init__( quantization: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, - max_logprobs: int = 5, + flash_style: bool = False, ) -> None: self.model = model self.tokenizer = tokenizer @@ -95,6 +96,7 @@ def __init__( self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture self.max_logprobs = max_logprobs + self.flash_style = flash_style if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, @@ -306,6 +308,7 @@ def __init__( cache_dtype: str, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, + flash_style: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -313,6 +316,7 @@ def __init__( self.cache_dtype = cache_dtype self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching + self.flash_style = flash_style self._verify_args() self._verify_cache_dtype() @@ -330,6 +334,15 @@ def _verify_args(self) -> None: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + if self.flash_style: + logger.info("Flash attention enabled.") + if self.block_size > 32: + # Flash style attention only supports block size >=256 for now. + # https://github.com/Dao-AILab/flash-attention/pull/824 will fix it. + raise ValueError( + "Flash style attention only supports block size <= 32. Got" + f"{self.block_size }") + def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass @@ -456,7 +469,6 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). - max_paddings: Maximum number of paddings to be added to a batch. """ def __init__( @@ -464,7 +476,6 @@ def __init__( max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, - max_paddings: int, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -474,7 +485,6 @@ def __init__( self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len - self.max_paddings = max_paddings self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c96c6d62ef19..3ec5216a0f18 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -173,12 +173,12 @@ def _schedule(self) -> SchedulerOutputs: curr_loras = set( seq_group.lora_int_id for seq_group in self.running) if self.lora_enabled else None - seq_lens: List[int] = [] # Optimization: We do not sort the waiting queue since the preempted # sequence groups are added to the front and the new sequence groups # are added to the back. leftover_waiting_sequences = deque() + num_batched_tokens = 0 while self.waiting: seq_group = self.waiting[0] waiting_seqs = seq_group.get_seqs( @@ -223,8 +223,7 @@ def _schedule(self) -> SchedulerOutputs: continue # If the number of batched tokens exceeds the limit, stop. - new_seq_lens = seq_lens + [num_prompt_tokens] - num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) + num_batched_tokens += num_prompt_tokens if (num_batched_tokens > self.scheduler_config.max_num_batched_tokens): break @@ -236,11 +235,6 @@ def _schedule(self) -> SchedulerOutputs: self.scheduler_config.max_num_seqs): break - num_paddings = num_batched_tokens - sum(new_seq_lens) - if num_paddings > self.scheduler_config.max_paddings: - break - seq_lens = new_seq_lens - if lora_int_id > 0: curr_loras.add(lora_int_id) self.waiting.popleft() @@ -255,8 +249,7 @@ def _schedule(self) -> SchedulerOutputs: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, prompt_run=True, - num_batched_tokens=len(seq_lens) * - max(seq_lens) if seq_lens else 0, + num_batched_tokens=num_batched_tokens, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c3dccdd5bb50..030d525e75e9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -30,7 +30,6 @@ class EngineArgs: gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 - max_paddings: int = 256 max_logprobs: int = 5 # OpenAI default value disable_log_stats: bool = False revision: Optional[str] = None @@ -46,6 +45,7 @@ class EngineArgs: lora_extra_vocab_size: int = 256 lora_dtype = 'auto' max_cpu_loras: Optional[int] = None + flash_style: bool = False device: str = 'auto' ray_workers_use_nsight: bool = False @@ -209,10 +209,6 @@ def add_cli_args( type=int, default=EngineArgs.max_num_seqs, help='maximum number of sequences per iteration') - parser.add_argument('--max-paddings', - type=int, - default=EngineArgs.max_paddings, - help='maximum number of paddings in a batch') parser.add_argument( '--max-logprobs', type=int, @@ -287,6 +283,9 @@ def add_cli_args( default=EngineArgs.device, choices=["auto", "cuda", "neuron"], help='Device type for vLLM execution.') + parser.add_argument('--flash-style', + action='store_true', + help='use flash attention.') return parser @classmethod @@ -308,12 +307,14 @@ def create_engine_configs( self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.enforce_eager, self.max_context_len_to_capture, - self.max_logprobs) + self.max_logprobs, + self.flash_style) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, model_config.get_sliding_window(), - self.enable_prefix_caching) + self.enable_prefix_caching, + self.flash_style) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, @@ -322,8 +323,7 @@ def create_engine_configs( self.ray_workers_use_nsight) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, - model_config.max_model_len, - self.max_paddings) + model_config.max_model_len) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8484014c9a13..91a7c38bd889 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -798,7 +798,6 @@ def _process_model_outputs( # Log stats. if self.log_stats: self.stat_logger.log(self._get_stats(scheduler_outputs)) - return request_outputs def step(self) -> List[RequestOutput]: @@ -853,6 +852,8 @@ def step(self) -> List[RequestOutput]: >>> break """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + # print("SANG-TODO step seq_group_metadata_list length: ", + # len(seq_group_metadata_list)) if not scheduler_outputs.is_empty(): # Execute the model. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1f463bdaaedc..efdf37030f61 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -146,6 +146,7 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ + # print("SANG-TODO generate: ", prompts, prompt_token_ids) if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index f0a88ac8e27f..f3589d04a977 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -6,13 +6,29 @@ class InputMetadata: """Metadata for input sequences. Used in PagedAttention. + NOTE: Any python object stored here is not updated when it is + cuda-graph replays. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + Args: prompt_lens: Lengths of prompts. - slot_mapping: The address to write the new KV to of each token. + slot_mapping: The index of each token mapped into a physical block + in block tables. E.g., if block_size is 32, 35 means it is in + the block number 1, 3rd index. + num_prompt_tokens: The number of tokens in the prompts. This might + include padding. + num_generation_tokens: The number of tokens in the generation sequences. + This might include padding. max_context_len: The maximum context length. context_lens: the length of attention context for each sequence. + I.e., the number of tokens that have attended so far. block_tables: The block tables. (Seq id -> list of physical block) kv_cache_dtype: Data type to store kv cache. + num_prompt_tokens: The number of tokens in the prompts. This might + include padding. + num_generation_tokens: The number of tokens in the generation sequences. + This might include padding. """ def __init__( @@ -20,6 +36,8 @@ def __init__( is_prompt: bool, slot_mapping: torch.Tensor, prompt_lens: Optional[torch.Tensor], + num_prompt_tokens: int, + num_generation_tokens: int, max_seq_len: Optional[int], start_loc: Optional[torch.Tensor], max_context_len: Optional[int], @@ -27,28 +45,85 @@ def __init__( block_tables: Optional[torch.Tensor], use_cuda_graph: bool, kv_cache_dtype: str, + flash_style: bool, ) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens + self.num_prompt_tokens = num_prompt_tokens + self.num_generation_tokens = num_generation_tokens self.max_seq_len = max_seq_len self.start_loc = start_loc self.max_context_len = max_context_len self.slot_mapping = slot_mapping + # Index: The batched sequence's index. + # Value: The length of attention context. + # NOTE(sang): When it is prefill/decoding, + # the definition is different. For prefill, + # it means the the length of KV that are cached + # excluding the new KVs. In decoding, this + # includes a new KV. self.context_lens = context_lens self.block_tables = block_tables self.use_cuda_graph = use_cuda_graph self.kv_cache_dtype = kv_cache_dtype + self.flash_style = flash_style # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. self.attn_bias = None + # Number of valid tokens. It includes paddings. + # See attention.py for precise definition. + self.num_valid_tokens = slot_mapping.shape[0] + + # SANG-TODO + # # Prompt related metadata + # # This value might include padding if CudaGraph is enabled. + # self.num_prompts = len(prompt_lens) + # # This value is the source of truth. + # self.num_prompts_tensor = torch.cuda.IntTensor([self.num_prompts]) + # # This value might include padding if CudaGraph is enabled. + # self.num_prompt_tokens = num_prompt_tokens + # self.prompt_lens_tensor = torch.cuda.IntTensor(self.prompt_lens) + # self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 + + # # Cumulative prompt lengths for each prompt in the input + # # tensor. + # self.cum_prompt_query_lens = torch.zeros( + # self.num_prompts + 1, + # device=self.prompt_lens_tensor.device, + # dtype=torch.int32) + # # Cumulative context lengths. + # self.cum_prompt_context_lens = torch.zeros( + # self.num_prompts + 1, + # device=self.prompt_lens_tensor.device, + # dtype=torch.int32) + + # torch.cumsum(self.prompt_lens_tensor, + # dim=0, + # dtype=self.cum_prompt_query_lens.dtype, + # out=self.cum_prompt_query_lens[1:]) + + # # TODO: this will be different once we support chunked prefills. + # self.cum_prompt_context_lens = self.cum_prompt_query_lens + # self.max_context_len = max(self.max_context_len, self.max_prompt_len) + + # # Generation related metadata + # # This value might include padding if CudaGraph is enabled. + # self.num_generation_tokens = num_generation_tokens + # # This is the source of truth for the number of generation tokens. + # self.num_generation_tokens_tensor = torch.cuda.IntTensor( + # [num_generation_tokens]) def __repr__(self) -> str: return ("InputMetadata(" f"is_prompt={self.is_prompt}, " f"max_context_len={self.max_context_len}, " + f"num_generation_tokens={self.num_generation_tokens}, " + f"num_prompt_tokens={self.num_prompt_tokens}, " f"slot_mapping={self.slot_mapping}, " f"context_lens={self.context_lens}, " f"block_tables={self.block_tables}, " f"use_cuda_graph={self.use_cuda_graph}, " - f"kv_cache_dtype={self.kv_cache_dtype})") + f"flash_style={self.flash_style} " + f"kv_cache_dtype={self.kv_cache_dtype}) " + f"num_valid_tokens={self.num_valid_tokens}") diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 5a3a7b2dbaee..6d829395883a 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -20,8 +20,8 @@ class SiluAndMul(nn.Module): The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. Shapes: - x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) - return: (batch_size, seq_len, d) or (num_tokens, d) + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) """ def _forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 830e82e10f7a..58f1bbd99aad 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -5,7 +5,7 @@ import torch.nn as nn from vllm.model_executor.input_metadata import InputMetadata -from vllm.utils import is_hip +# from vllm.utils import is_hip class Attention(nn.Module): @@ -13,11 +13,27 @@ class Attention(nn.Module): This class takes query, key, and value tensors as input. The input tensors can either contain prompt tokens or generation tokens. + + + If the input tensors contain prompt tokens, the layout is as follows: + |<---------------------- num_valid_tokens ---------------------->| + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->| + + Otherwise, the layout is as follows: + |<------------------ num_valid_tokens ------------------->| + |<------- num_generation_tokens (M) ------->| + |<--generation_0-->|...|<--generation_M-1-->|<--padding-->| + + The prompts might have different lengths, while the generation tokens always + have length 1. The paddings are appended to make the input length a multiple + of 8, which is desirable for Tensor Cores. + The class does the following: 1. Store the input key and value tensors in the KV cache. 2. Perform (multi-head/multi-query/grouped-query) attention. - 3. Return the output tensor. + 3. Output a flattened 1D tensor. """ def __init__( @@ -30,8 +46,9 @@ def __init__( sliding_window: Optional[int] = None, ) -> None: super().__init__() - if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and - torch.get_default_dtype() in (torch.float16, torch.bfloat16)): + # if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and + # torch.get_default_dtype() in (torch.float16, torch.bfloat16)): + if False: # Ampere or later NVIDIA GPUs. # NOTE(woosuk): FlashAttention does not support FP32. from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index 512f4e49c7eb..20a3a072bb3b 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -53,18 +53,18 @@ def forward( """Forward pass with FlashAttention and PagedAttention. Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] input_metadata: metadata for the inputs. Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ - batch_size, seq_len, hidden_size = query.shape + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -74,27 +74,25 @@ def forward( # If key_cache and value_cache are not provided, the new key and value # vectors will not be cached. This happens during the initial memory # profiling run. - if key_cache is not None and value_cache is not None: - PagedAttentionImpl.reshape_and_cache(key, value, key_cache, - value_cache, input_metadata) + PagedAttentionImpl.reshape_and_cache(key, value, key_cache, + value_cache, input_metadata) if input_metadata.is_prompt: # Prompt run. if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # normal attention - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) - output = flash_attn_func( - query, - key, - value, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) + output = torch.empty_like(query) + query = query.unflatten(0, (num_tokens, )) + key = key.unflatten(0, (num_tokens, )) + value = value.unflatten(0, (num_tokens, )) + output = flash_attn_func(query, + key, + value, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes) else: # prefix-enabled attention output = PagedAttentionImpl.forward_prefix( @@ -104,8 +102,6 @@ def forward( key_cache, value_cache, input_metadata, - self.num_heads, - self.num_kv_heads, self.alibi_slopes, ) else: @@ -121,4 +117,4 @@ def forward( ) # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) + return output.view(num_tokens, hidden_size) diff --git a/vllm/model_executor/layers/attention/backends/xformers.py b/vllm/model_executor/layers/attention/backends/xformers.py index bad2a648b670..bbe77134ce12 100644 --- a/vllm/model_executor/layers/attention/backends/xformers.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -11,10 +11,10 @@ from vllm.model_executor.layers.attention.ops.paged_attn import ( PagedAttentionImpl) from vllm.utils import is_hip +from vllm._C import cache_ops class XFormersBackend: - def __init__( self, num_heads: int, @@ -43,6 +43,8 @@ def __init__( self.use_ref_attention = _check_use_ref_attention() + # def _update_cache(self): + def forward( self, query: torch.Tensor, @@ -55,33 +57,41 @@ def forward( """Forward pass with xFormers and PagedAttention. Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] + query: shape = [num_tokens num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, - block_size, x] + block_size, x]. None if it is a profiling run. value_cache: shape = [num_blocks, num_kv_heads, head_size, - block_size] + block_size]. None if it is a profiling run. input_metadata: metadata for the inputs. Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ - batch_size, seq_len, hidden_size = query.shape + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) + output = torch.empty_like(query) # Reshape the keys and values and store them in the cache. # If key_cache and value_cache are not provided, the new key and value # vectors will not be cached. This happens during the initial memory # profiling run. if key_cache is not None and value_cache is not None: - PagedAttentionImpl.reshape_and_cache(key, value, key_cache, + if input_metadata.flash_style: + # print("SANG-TODO reshape cache flash.") + cache_ops.reshape_and_cache_flash( + key, value, key_cache, value_cache, + input_metadata.slot_mapping.flatten()) + else: + PagedAttentionImpl.reshape_and_cache(key, value, key_cache, value_cache, input_metadata) if input_metadata.is_prompt: # Prompt run. + # Unless there's a prefix, context lens is all 0 for prefill. if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): # normal attention @@ -109,15 +119,15 @@ def forward( if input_metadata.attn_bias is None: if self.alibi_slopes is None: attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size) + input_metadata.prompt_lens.tolist()) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) input_metadata.attn_bias = attn_bias else: input_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, batch_size, - seq_len, query.dtype) + self.alibi_slopes, self.num_kv_heads, query.dtype, + input_metadata) if self.use_ref_attention: output = _ref_masked_attention( @@ -133,7 +143,7 @@ def forward( # with input tensor's size and stride (at least one # dimension spans across two contiguous subspaces). # Use reshape instead. - return output.reshape(batch_size, seq_len, hidden_size) + return output.reshape(num_tokens, hidden_size) # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. @@ -142,9 +152,9 @@ def forward( key = key.unsqueeze(0) value = value.unsqueeze(0) else: - query = query.unflatten(0, (batch_size, seq_len)) - key = key.unflatten(0, (batch_size, seq_len)) - value = value.unflatten(0, (batch_size, seq_len)) + query = query.unflatten(0, (num_tokens)) + key = key.unflatten(0, (num_tokens)) + value = value.unflatten(0, (num_tokens)) out = xops.memory_efficient_attention_forward( query, @@ -157,6 +167,117 @@ def forward( (is_hip()) else None, ) output = out.view_as(query) +# # if key_cache is not None and value_cache is not None: +# # output2 = flash_single_query_cached_kv_attention( +# # None, +# # query.view(batch_size, seq_len, self.num_heads, self.head_size), +# # key_cache, +# # value_cache, +# # self.scale, +# # input_metadata.block_tables, +# # input_metadata.context_lens, +# # alibi_slopes=self.alibi_slopes, +# # ) +# # output3 = flash_attn_func( +# # query.view(batch_size, seq_len, self.num_heads, self.head_size), +# # key.view(batch_size, seq_len, self.num_heads, self.head_size), +# # value.view(batch_size, seq_len, self.num_heads, self.head_size), +# # softmax_scale=self.scale, +# # causal=True, +# # ).view_as(query) +# # output4 = flash_attn_func( +# # query, +# # key, +# # value, +# # softmax_scale=self.scale, +# # causal=True, +# # ).view_as(query) +# # if key_cache is not None and value_cache is not None: +# # output2 = flash_attn_with_kvcache_paged( +# # query.view(batch_size, seq_len, self.num_heads, +# # self.head_size), +# # key_cache, +# # value_cache, +# # self.scale, +# # input_metadata.block_tables, +# # input_metadata.context_lens, +# # self.alibi_slopes, +# # ) +# else: +# if input_metadata.flash_style: +# output = flash_single_query_cached_kv_attention( +# None, +# query.view(batch_size, seq_len, self.num_heads, +# self.head_size), +# key_cache, +# value_cache, +# self.scale, +# input_metadata.block_tables, +# input_metadata.context_lens, +# alibi_slopes=self.alibi_slopes, +# ) +# # output = flash_multi_query_cached_kv_attention_varlen( +# # None, +# # query, +# # key_cache, +# # value_cache, +# # self.scale, +# # self.block_tables, +# # input_metadata.start_loc, +# # input_metadata.start_loc, +# # max_query_len, +# # max_context_len, +# # alibi_slopes=self.alibi_slopes, +# # ) +# else: +# # print("SANG-TODO context attention") +# # prefix-enabled attention +# output = torch.empty_like(query) +# context_attention_fwd( +# query, +# key, +# value, +# output, +# key_cache, +# value_cache, +# input_metadata. +# block_tables, # [BS, max_block_per_request] +# input_metadata.start_loc, +# input_metadata.prompt_lens, +# input_metadata.context_lens, +# input_metadata.max_seq_len, +# getattr(self, "alibi_slopes", None), +# ) + +# else: +# # Decoding run. +# if input_metadata.flash_style: +# # output = flash_attn_with_kvcache_paged( +# # query.view(batch_size, seq_len, self.num_heads, +# # self.head_size), key_cache, value_cache, +# # self.scale, input_metadata.block_tables, +# # input_metadata.context_lens, self.alibi_slopes) +# output = flash_single_query_cached_kv_attention( +# None, +# query.view(batch_size, 1, self.num_heads, self.head_size), +# key_cache, +# value_cache, +# self.scale, +# input_metadata.block_tables, +# input_metadata.context_lens, +# alibi_slopes=self.alibi_slopes, +# ) +# else: +# output = _paged_attention( +# query, +# key_cache, +# value_cache, +# input_metadata, +# self.num_kv_heads, +# self.scale, +# self.alibi_slopes, +# ) +# ======= else: # prefix-enabled attention @@ -182,41 +303,39 @@ def forward( ) # Reshape the output tensor. - return output.view(batch_size, seq_len, hidden_size) + return output.view(-1, self.num_heads * self.head_size) def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, - batch_size: int, - seq_len: int, dtype: torch.dtype, + input_metadata: InputMetadata, ) -> LowerTriangularMaskWithTensorBias: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - - # When using custom attention bias, xformers requires the bias to - # be sliced from a tensor whose length is a multiple of 8. - padded_len = (seq_len + 7) // 8 * 8 - num_heads = alibi_slopes.shape[0] - bias = torch.empty( - batch_size, - num_heads, - seq_len, - padded_len, - device=alibi_slopes.device, - dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias.mul_(alibi_slopes[:, None, None]) - if num_heads != num_kv_heads: - bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) - attn_bias = LowerTriangularMaskWithTensorBias(bias) - return attn_bias + for prompt_len in input_metadata.prompt_lens: + bias = torch.arange(prompt_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + padded_len = (prompt_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + prompt_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :prompt_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + attn_bias = LowerTriangularMaskWithTensorBias(bias) + return attn_bias def _check_use_ref_attention() -> bool: @@ -253,3 +372,209 @@ def _ref_masked_attention( attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) out = torch.einsum("hqk,khd->qhd", attn_weights, value) return out + + +# OSS version. +# def flash_attn_with_kvcache_paged( +# query: torch.Tensor, +# key_cache: torch.Tensor, +# value_cache: torch.Tensor, +# scale: float, +# block_tables: torch.Tensor, +# context_lens: torch.Tensor, +# alibi_slopes: Optional[torch.Tensor] = None, +# ) -> torch.Tensor: +# """Similar to vLLM's page attention, calculates a single token's attention +# based on key/value caches. The main difference is this uses flash attention +# style key-value caches. + +# Arguments: +# See https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py +# for other arguments. + +# Returns: +# output: [num_tokens, num_heads, head_size] +# """ +# block_size = value_cache.shape[1] +# assert block_size % 256 == 0, ("only support block_size divisible by 256. " +# f"Current block size: {block_size}") +# _, _, num_heads, head_size = query.shape +# out = flash_attn_with_kvcache( +# query, +# key_cache, +# value_cache, +# # Inplace update is slow. We don't use it. +# # We assume kvcache is already updated before +# # calling this API. +# None, # key +# None, # value +# cache_seqlens=context_lens, +# block_table=block_tables, +# softmax_scale=scale, +# causal=True, +# alibi_slopes=alibi_slopes, +# num_splits=0, +# ) + +# # num_tokens == batch_size * seqlen +# return out.view(-1, num_heads, head_size) + + +def flash_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + actual_batch_size: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Similar to vLLM's page attention, calculates a single token's attention + based on key/value caches. The main difference is this uses flash attention + style key-value caches. + + Arguments: + output: [num_padded_tokens, num_heads, head_size], output tensor + to write. if None an new output tensor will be created. + query: [batch_size, num_padded_tokens, num_heads, head_size], query tensor. + key_cache: [num_blocks, block_size, num_heads, head_size], key cache. + value_cache: [num_blocks, block_size, num_heads, head_size], value + cache. + scale: attention scale. + block_tables: [num_padded_tokens, max_context_len / block_size], + block tables. + context_lens: [num_padded_tokens], context lengths. + block_size: block size. + alibi_slopes: unused. + actual_batch_size: [1] actual batch size. + + Returns: + output: [num_padded_tokens, num_heads, head_size] + """ + block_size = value_cache.shape[1] + assert block_size >= 32, "only support block_size >= 32 for flash attention" + # TODO: support alibi_slopes + assert alibi_slopes is None, "doesn't support alibi_slopes" + batch_size, seqlen, num_heads, head_size = query.shape + assert seqlen == 1, ( + "Single query attention can be only used for decoding phase.") + num_tokens = batch_size * seqlen + out = flash_attn_with_page_attention( + query, + key_cache, + value_cache, + block_tables, + None, # key + None, # value + None, # cos + None, # sin + context_lens, + None, # cache_batch_idx + softmax_scale=scale, + causal=True, + window_size=(-1, -1), + rotary_interleaved=False, + num_splits=0, + actual_batch_size=actual_batch_size, + ) + if output is not None: + # in case that output is padded, only copy the valid part. + output[:num_tokens].copy_(out.view(num_tokens, num_heads, head_size)) + return out.view(num_tokens, num_heads, head_size) + + +def flash_multi_query_cached_kv_attention_varlen( + output: Optional[torch.Tensor], + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + scale: float, + block_tables: torch.Tensor, + cum_seqlens_q: torch.Tensor, + cum_context_len: torch.Tensor, + max_query_len: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + actual_batch_size: Optional[torch.Tensor] = None, +): + """Efficient multi-query paged attention based on flash attention. + It calculates attentions of list of sequences packed in a single batch, + indexed by cum_seqlens_q where the seq_i's index is + [cum_seqlens_q[i], cum_seqlensq[i+1]]. + Similarlly, the length of context is stored in cum_seqlens_k with similar + fashions. + It also supports calculating attention incrementally, where context length + is longer than sequence length. + + Arguments: + output: [num_padded_tokens, num_heads, head_size], output tensor to + write to. if None an new output tensor will be created. + + prefill -> always 1 batch + query, head_size, head_dim + + varlen -> provides cumulative lengths of queries + + decoding -> always 1 batch + 1 * number_of_batch, head_size, head_dim + + query: [num_padded_tokens, num_heads, head_size], query tensor. + key_cache: [num_blocks, block_size, num_heads, head_size], key cache. + value_cache: [num_blocks, block_size, num_heads, head_size], + value cache. + scale: attention scale. + block_tables: [num_padded_tokens, max_context_len / block_size], + block tables. + - these two are the same if no chunked prefill (for prefill) + - If you do chunked prefill, it may be different + - when? see attention mask setting + - Each iteration, it can have smaller # of queries than (because it is chunked) + tokens we should attend to (context len). + - cum_seqlens_q: actual query length (actual prompt length). + - cum_context_len: context len that needs to be attended (subquery). + cum_seqlens_q: [padded_batch_size + 1], cumulative sequence lengths + of query. + cum_context_len: [padded_batch_size + 1], cumulative lengths + of context. + block_size: block size. + max_query_len: max query length. + max_context_len: max context length. + alibi_slopes: unused. + actual_batch_size: [1] actual batch size. + + Returns: + output: [num_padded_tokens, num_heads, head_size] + """ + block_size = value_cache.shape[1] + assert block_size >= 32, "only support block_size >= 32 for flash attention" + # TODO: support alibi_slopes + assert alibi_slopes is None, "doesn't support alibi_slopes" + + num_tokens, _, _ = query.shape + out = flash_attn_varlen_with_page_attention( + query, + key_cache, + value_cache, + block_tables, + cum_seqlens_q, + cum_context_len, + max_query_len, + max_context_len, + None, # key + None, # value + None, # cos_cache + None, # sin_cache + None, # cache_batch_idx + softmax_scale=scale, + causal=True, + window_size=(-1, -1), + rotary_interleaved=False, + num_splits=0, + actual_batch_size=actual_batch_size, + ) + + if output is not None: + output[:num_tokens].copy_(out) + return out diff --git a/vllm/model_executor/layers/attention/ops/prefix_prefill.py b/vllm/model_executor/layers/attention/ops/prefix_prefill.py index 70f09224f1cf..c6054de2b718 100644 --- a/vllm/model_executor/layers/attention/ops/prefix_prefill.py +++ b/vllm/model_executor/layers/attention/ops/prefix_prefill.py @@ -696,7 +696,7 @@ def context_attention_fwd(q, ) return - _fwd_kernel[grid]( + _fwd_kernel_flash_attn_v2[grid]( q, k, v, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 320cb443524c..bb5aa281568f 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -125,7 +125,6 @@ def _prune_hidden_states( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) return hidden_states.index_select(0, sampling_metadata.selected_token_indices) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index cb64d80c8147..5d4a460074c9 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -29,7 +29,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: architectures = ["QuantMixtralForCausalLM"] for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) + model_cls = ModelRegistry.load_model_cls(arch, model_config) if model_cls is not None: return model_cls raise ValueError( diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 75c2ae1e9f48..1e9e7b76659d 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -5,6 +5,7 @@ from vllm.logger import init_logger from vllm.utils import is_hip, is_neuron +from vllm.config import ModelConfig logger = init_logger(__name__) @@ -69,7 +70,8 @@ class ModelRegistry: @staticmethod - def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + def load_model_cls(model_arch: str, + model_config: ModelConfig) -> Optional[Type[nn.Module]]: if model_arch not in _MODELS: return None if is_hip(): @@ -92,7 +94,9 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: module_name = _NEURON_SUPPORTED_MODELS[model_arch] module = importlib.import_module( f"vllm.model_executor.models.{module_name}") - return getattr(module, model_cls_name, None) + model_cls = getattr(module, model_cls_name, None) + + return model_cls @staticmethod def get_supported_archs() -> List[str]: diff --git a/vllm/utils.py b/vllm/utils.py index 5b94067cec77..2d99c93df187 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -254,6 +254,7 @@ def create_kv_caches_with_random( model_dtype: Optional[Union[str, torch.dtype]] = None, seed: Optional[int] = 0, device: Optional[str] = "cuda", + flash_style: bool = False, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -280,7 +281,11 @@ def create_kv_caches_with_random( scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + if flash_style: + key_cache_shape = (num_blocks, block_size, num_heads, head_size) + else: + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, + x) key_caches = [] for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, @@ -295,7 +300,11 @@ def create_kv_caches_with_random( f"Does not support key cache of type {cache_dtype}") key_caches.append(key_cache) - value_cache_shape = (num_blocks, num_heads, head_size, block_size) + if flash_style: + value_cache_shape = (num_blocks, block_size, num_heads, head_size) + else: + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches = [] for _ in range(num_layers): value_cache = torch.empty(size=value_cache_shape, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 880299783935..2ca359e6261d 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -60,19 +60,33 @@ def __init__( def get_key_block_shape(self) -> Tuple[int, int, int, int]: element_size = torch.tensor([], dtype=self.dtype).element_size() x = 16 // element_size - return ( - self.num_heads, - self.head_size // x, - self.block_size, - x, - ) + if self.cache_config.flash_style: + return ( + self.block_size, + self.num_heads, + self.head_size, + ) + else: + return ( + self.num_heads, + self.head_size // x, + self.block_size, + x, + ) def get_value_block_shape(self) -> Tuple[int, int, int]: - return ( - self.num_heads, - self.head_size, - self.block_size, - ) + if self.cache_config.flash_style: + return ( + self.block_size, + self.num_heads, + self.head_size, + ) + else: + return ( + self.num_heads, + self.head_size, + self.block_size, + ) def allocate_gpu_cache(self) -> List[KVCache]: gpu_cache: List[KVCache] = [] diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9023b0c59b3f..def58835298b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -28,9 +28,12 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] _PAD_SLOT_ID = -1 LORA_WARMUP_RANK = 8 -# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. +_BATCH_SIZE_ALIGNMENT = 8 +# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. -_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) +] class ModelRunner: @@ -55,6 +58,9 @@ def __init__( # FIXME(woosuk): This is a hack to make the tests work. Refactor this. self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None) + self.flash_style = (self.model_config.flash_style + if model_config is not None else False) + self.device_config = (device_config if device_config is not None else DeviceConfig()) self.device = self.device_config.device @@ -110,8 +116,7 @@ def load_model(self) -> None: ), "Model does not have embedding_padding_modules" self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens + - self.scheduler_config.max_paddings, vocab_size, + self.scheduler_config.max_num_batched_tokens, vocab_size, self.lora_config, self.device, self.model.embedding_modules, self.model.embedding_padding_modules) self.model = self.lora_manager.create_lora_manager(self.model) @@ -119,10 +124,13 @@ def load_model(self) -> None: def set_block_size(self, block_size: int) -> None: self.block_size = block_size - max_num_blocks = (self.max_context_len_to_capture + block_size - - 1) // block_size self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + dtype=np.int32) + + def get_max_block_per_batch(self): + block_size = self.block_size + return (self.max_context_len_to_capture + block_size - 1) // block_size def _prepare_prompt( self, @@ -130,9 +138,9 @@ def _prepare_prompt( ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() @@ -141,6 +149,8 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] + # print("SANG-TODO # of requests (seq_group_metadata_list): ", + # len(seq_group_metadata_list)) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -151,6 +161,7 @@ def _prepare_prompt( prompt_tokens = seq_data.get_token_ids() prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) + computed_len = 0 # NOTE: This only works for oooooooxxx style attention. @@ -161,16 +172,18 @@ def _prepare_prompt( computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) + context_len = computed_len else: prefix_block_tables.append([]) + context_len = 0 # actual prompt lens - context_lens.append(computed_len) + context_lens.append(context_len) subquery_lens.append(prompt_len - computed_len) - input_tokens.append(prompt_tokens) + input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.append( + input_positions.extend( list(range(computed_len, computed_len + len(prompt_tokens)))) lora_id = seq_group_metadata.lora_int_id @@ -178,7 +191,7 @@ def _prepare_prompt( if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping.append([lora_id] * (prompt_len - computed_len)) + lora_index_mapping += [lora_id] * (prompt_len - computed_len) lora_prompt_mapping.extend( [lora_id] * (prompt_len - computed_len @@ -187,11 +200,10 @@ def _prepare_prompt( if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. - slot_mapping.append([_PAD_SLOT_ID] * prompt_len) + slot_mapping.extend([_PAD_SLOT_ID] * prompt_len) continue # Compute the slot mapping. - slot_mapping.append([]) block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, prompt_len - sliding_window). @@ -206,35 +218,35 @@ def _prepare_prompt( start_idx = max(0, prompt_len - self.sliding_window) for i in range(computed_len, prompt_len): if i < start_idx: - slot_mapping[-1].append(_PAD_SLOT_ID) + slot_mapping.append(_PAD_SLOT_ID) continue block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping[-1].append(slot) + slot_mapping.append(slot) max_prompt_len = max(subquery_lens) + num_prompt_tokens = len(input_tokens) assert max_prompt_len > 0 - input_tokens = _make_tensor_with_pad(input_tokens, - max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - input_positions = _make_tensor_with_pad(input_positions, - max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) - lora_index_mapping = [ - _pad_to_max(mapping, max_prompt_len, pad=0) - for mapping in lora_index_mapping - ] + + # Pad tokens to better utilize tensor cores although + # cuda graph is not enabled. + input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, + pad=0, + dtype=torch.long, + device=self.device) + input_positions = _make_tensor_with_pad_for_alignment( + input_positions, pad=0, dtype=torch.long, device=self.device) + slot_mapping = _make_tensor_with_pad_for_alignment(slot_mapping, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) + lora_index_mapping = _pad_to_alignment(lora_index_mapping, + _get_graph_batch_size( + len(lora_index_mapping)), + pad=0) + context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -247,19 +259,27 @@ def _prepare_prompt( dtype=torch.int, device=self.device, ) - start_loc_tensor = torch.arange(0, - len(prompt_lens) * max_prompt_len, - max_prompt_len, - dtype=torch.long, - device=self.device) prompt_lens_tensor = torch.tensor(prompt_lens, dtype=torch.long, device=self.device) + # Cumulative index of each prompt. [prompt_lens + 1] + # [0, 0+1th, 0+1th+2nd, ...] + start_loc_tensor = torch.zeros(prompt_lens_tensor.shape[0] + 1, + dtype=torch.long, + device=self.device) + + torch.cumsum(prompt_lens_tensor, + dim=0, + dtype=start_loc_tensor.dtype, + out=start_loc_tensor[1:]) + input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping, prompt_lens=prompt_lens_tensor, + num_prompt_tokens=num_prompt_tokens, + num_generation_tokens=0, max_seq_len=max_prompt_len, start_loc=start_loc_tensor, max_context_len=None, @@ -267,6 +287,7 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, + flash_style=self.flash_style, ) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, @@ -278,9 +299,9 @@ def _prepare_decode( ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] @@ -299,11 +320,11 @@ def _prepare_decode( for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) + input_tokens.append(generation_token) seq_len = seq_data.get_len() position = seq_len - 1 - input_positions.append([position]) + input_positions.append(position) context_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) @@ -313,8 +334,8 @@ def _prepare_decode( block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping.append([slot]) - lora_index_mapping.append([lora_id]) + slot_mapping.append(slot) + lora_index_mapping.append(lora_id) lora_prompt_mapping.append(lora_id) if self.sliding_window is not None: @@ -323,6 +344,9 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + # vLLM uses cuda graph only for decoding requests. + # See `capture_model` API for more details. + # For decoding requests, batch_size == input_tokens. batch_size = len(input_tokens) max_context_len = max(context_lens) use_captured_graph = ( @@ -335,32 +359,36 @@ def _prepare_decode( graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size for _ in range(graph_batch_size - batch_size): - input_tokens.append([]) - input_positions.append([]) - slot_mapping.append([]) - context_lens.append(1) + input_tokens.append(0) + input_positions.append(0) + slot_mapping.append(_PAD_SLOT_ID) + context_lens.append(0) block_tables.append([]) batch_size = graph_batch_size - input_tokens = _make_tensor_with_pad(input_tokens, - max_len=1, - pad=0, - dtype=torch.long, - device=self.device) - input_positions = _make_tensor_with_pad(input_positions, - max_len=1, - pad=0, - dtype=torch.long, - device=self.device) - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_len=1, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) + # Pad tokens to better utilize tensor cores although + # cuda graph is not enabled. + input_tokens = _make_tensor_with_pad_for_alignment(input_tokens, + pad=0, + dtype=torch.long, + device=self.device) + input_positions = _make_tensor_with_pad_for_alignment( + input_positions, pad=0, dtype=torch.long, device=self.device) + slot_mapping = _make_tensor_with_pad_for_alignment(slot_mapping, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) + if use_captured_graph: + # When using cuda-graph all these tensors should be + # padded. + assert context_lens.shape[0] == input_tokens.shape[0] + assert context_lens.shape[0] == input_positions.shape[0] + assert context_lens.shape[0] == slot_mapping.shape[0] + if use_captured_graph: # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -380,14 +408,17 @@ def _prepare_decode( device=self.device, ) - lora_index_mapping = [ - _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping - ] + lora_index_mapping = _pad_to_alignment(lora_index_mapping, + _get_graph_batch_size( + len(lora_index_mapping)), + pad=0) input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping, prompt_lens=None, + num_prompt_tokens=0, + num_generation_tokens=len(input_tokens), max_seq_len=None, start_loc=None, max_context_len=max_context_len, @@ -395,6 +426,7 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, + flash_style=self.flash_style, ) return (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) @@ -413,7 +445,6 @@ def _prepare_sample( categorized_sample_indices_start_idx = 0 pin_memory = not self.in_wsl and not self.device_config.is_neuron - max_subquery_len = max(subquery_lens) if subquery_lens else 1 for i, seq_group_metadata in enumerate(seq_group_metadata_list): seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params @@ -438,7 +469,7 @@ def _prepare_sample( selected_token_start_idx + subquery_len - 1)) selected_token_indices.append(selected_token_start_idx + subquery_len - 1) - selected_token_start_idx += max_subquery_len + selected_token_start_idx += subquery_len if sampling_params.seed is not None: seq_group_metadata.state.generator = torch.Generator( @@ -494,12 +525,15 @@ def prepare_input_tensors( # NOTE: We assume that all sequences in the group are all prompts or # all decodes. is_prompt = seq_group_metadata_list[0].is_prompt + # SANG-TODO set num prompt tokens and generations? # Prepare input tensors. if is_prompt: + # print("SANG-TODO execute model prompt.") (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_prompt(seq_group_metadata_list) else: + # print("SANG-TODO execute model decode.") (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_decode(seq_group_metadata_list) @@ -510,11 +544,8 @@ def prepare_input_tensors( subquery_lens) if self.lora_config: - flat_lora_index_mapping = [ - item for sublist in lora_index_mapping for item in sublist - ] lora_mapping = LoRAMapping( - flat_lora_index_mapping, + lora_index_mapping, lora_prompt_mapping, ) else: @@ -527,6 +558,8 @@ def prepare_input_tensors( "is_prompt": input_metadata.is_prompt, "slot_mapping": input_metadata.slot_mapping, "prompt_lens": input_metadata.prompt_lens, + "num_prompt_tokens": input_metadata.num_prompt_tokens, + "num_generation_tokens": input_metadata.num_generation_tokens, "max_seq_len": input_metadata.max_seq_len, "start_loc": input_metadata.start_loc, "max_context_len": input_metadata.max_context_len, @@ -550,6 +583,8 @@ def prepare_input_tensors( is_prompt=metadata_dict["is_prompt"], slot_mapping=metadata_dict["slot_mapping"], prompt_lens=metadata_dict["prompt_lens"], + num_prompt_tokens=metadata_dict["num_prompt_tokens"], + num_generation_tokens=metadata_dict["num_generation_tokens"], max_seq_len=metadata_dict["max_seq_len"], start_loc=metadata_dict["start_loc"], max_context_len=metadata_dict["max_context_len"], @@ -557,7 +592,7 @@ def prepare_input_tensors( block_tables=metadata_dict["block_tables"], use_cuda_graph=metadata_dict["use_cuda_graph"], kv_cache_dtype=metadata_dict["kv_cache_dtype"], - ) + flash_style=self.flash_style, sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -687,6 +722,18 @@ def list_loras(self) -> Set[int]: @torch.inference_mode() def capture_model(self, kv_caches: List[KVCache]) -> None: + """Cuda graph capture a model. + + Note that cuda graph performance gain is negligible if number + of batched tokens are less than 200. And since Cuda graph + requires fixed sized tensors, supporting large/variable batch + size requires high GPU memory overhead. Thus, vLLM only captures + decoding requests. Mixed batch (chunked prefill + decoding) or + prefill requests are not captured. + + this API assumes the shape is N batches of tokens flattened to 1D + tensor, where is token's seqlen is 1. + """ # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never # deleted before the CUDA graphs. self.cupy_nccl_backend = cupy_utils.get_nccl_backend() @@ -705,12 +752,11 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # Prepare dummy inputs. These will be reused for all batch sizes. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() - input_positions = torch.zeros(max_batch_size, 1, - dtype=torch.long).cuda() - slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() + input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() + input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() + slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) - context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + context_lens = torch.zeros(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() graph_batch_size = _get_graph_batch_size( @@ -734,6 +780,8 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: is_prompt=False, slot_mapping=slot_mapping[:batch_size], prompt_lens=None, + num_prompt_tokens=0, + num_generation_tokens=0, max_seq_len=None, start_loc=None, max_context_len=self.max_context_len_to_capture, @@ -741,7 +789,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: block_tables=block_tables[:batch_size], use_cuda_graph=True, kv_cache_dtype=self.kv_cache_dtype, - ) + flash_style=self.flash_style) if self.lora_config: lora_mapping = LoRAMapping( @@ -849,7 +897,6 @@ def forward( non_blocking=True) self.input_buffers["block_tables"].copy_(input_metadata.block_tables, non_blocking=True) - # Run the graph. self.graph.replay() @@ -869,11 +916,31 @@ def _maybe_cupy_nccl(): yield +def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]: + return x + [pad] * ((-len(x)) % multiple_of) + + def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: assert len(x) <= max_len return x + [pad] * (max_len - len(x)) +def _make_tensor_with_pad_for_alignment( + x: List[int], + pad: int, + dtype: torch.dtype, + device: Optional[Union[str, torch.device]], +) -> torch.Tensor: + """Create a tensor of a given list x with padding. + It adds paddings to align with graph batch size. See + _get_graph_batch_size for more details. + """ + batch_size = len(x) + batch_size = _get_graph_batch_size(batch_size) + padded_x = _pad_to_alignment(x, batch_size, pad) + return torch.tensor(padded_x, dtype=dtype, device=device) + + def _make_tensor_with_pad( x: List[List[int]], max_len: int, @@ -886,12 +953,18 @@ def _make_tensor_with_pad( def _get_graph_batch_size(batch_size: int) -> int: + """Returns the padded batch size given actual batch size. + + Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, + 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... + """ if batch_size <= 2: return batch_size elif batch_size <= 4: return 4 else: - return (batch_size + 7) // 8 * 8 + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) def _async_h2d( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 157e8c45836b..c49ceb92b0e3 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -20,6 +20,8 @@ from vllm.worker.model_runner import ModelRunner from vllm.lora.request import LoRARequest +MAX_INT_32 = 2**31 - 1 + class Worker: """A worker class that executes (a partition of) the model on a GPU. @@ -114,6 +116,7 @@ def profile_num_available_blocks( gpu_memory_utilization: The fraction of the total GPU memory to use. cpu_swap_space: The size of the CPU swap space in bytes. """ + # print("SANG-TODO profile_num_available_blocks") # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() @@ -142,6 +145,16 @@ def profile_num_available_blocks( self.model_runner.remove_all_loras() gc.collect() torch.cuda.empty_cache() + + # Flash attention only allows max(int32) slots. + # TODO: Remove this once we support int64 in flash-attention. + if self.model_config.flash_style: + num_gpu_blocks = min(num_gpu_blocks, + MAX_INT_32 // cache_block_size) + num_cpu_blocks = min(num_cpu_blocks, + MAX_INT_32 // cache_block_size) + # print("SANG-TODO profile_num_available_blocks done") + return num_gpu_blocks, num_cpu_blocks def init_cache_engine(self, cache_config: CacheConfig) -> None: @@ -193,6 +206,7 @@ def execute_model( blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, ) -> Optional[SamplerOutput]: + # print("SANG-TODO execute model.") if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups = len(seq_group_metadata_list)