Changes from all commits (38 commits, all by rkooo567):
06fe872  [1/n] Support efficient reshape caching. (Feb 28, 2024)
9a0b6be  [2/n] support flash attention kernel (Feb 28, 2024)
6947167  oss flash attention works (Feb 28, 2024)
4769a26  in progress (Feb 28, 2024)
963db44  flash attn enabled. (Feb 29, 2024)
2b9c36b  ip (Feb 29, 2024)
2c1bb6c  support every model (Feb 29, 2024)
2bb5e62  Fixed broken tests. (Feb 29, 2024)
78bb887  ip (Feb 29, 2024)
74ac900  seems to work. (Mar 1, 2024)
71bdada  . (Mar 1, 2024)
d4c3b5d  ip? (Mar 1, 2024)
baef7c6  block tables updated correctly (Mar 1, 2024)
a12ec68  hopefully tests pass (Mar 1, 2024)
0d8785f  Merge branch 'main' into chunked-prefill-3 (Mar 3, 2024)
08c8541  . (Mar 3, 2024)
3bac9af  ip (Mar 3, 2024)
31aa920  ip (Mar 4, 2024)
2049b35  . (Mar 4, 2024)
ef679d7  . (Mar 4, 2024)
71bda97  . (Mar 4, 2024)
4e00e7f  done? (Mar 4, 2024)
7fd70f2  Merge branch 'main' into chunked-prefill-3 (Mar 5, 2024)
9177d54  Merge branch 'main' into chunked-prefill-3 (Mar 6, 2024)
c0384a4  Refactor 2d query to 1d query (Mar 6, 2024)
6032edf  ., (Mar 6, 2024)
c1ab0b0  done (Mar 6, 2024)
f48dc72  Addressed code review. (Mar 7, 2024)
769b2b4  working (Mar 7, 2024)
4a20f4a  Merge branch 'main' into 1dquery (Mar 7, 2024)
f7347b8  working (Mar 7, 2024)
d931725  Merge branch 'main' into 1dquery (Mar 7, 2024)
f91d73e  fix lora (Mar 8, 2024)
f7d79da  fixed (Mar 8, 2024)
851c018  Merge branch 'main' into 1dquery (Mar 8, 2024)
406f1d4  fix (Mar 8, 2024)
9442e8f  Merge branch 'main' into chunked-prefill-3 (Mar 8, 2024)
3da31eb  Merge branch '1dquery' into chunked-prefill-3 (Mar 8, 2024)
4 changes: 4 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -46,6 +46,10 @@ steps:
commands:
- pytest -v -s prefix_caching

- label: Chunked Prefill Test
commands:
- pytest -v -s chunked_prefill

- label: Samplers Test
command: pytest -v -s samplers --forked

43 changes: 40 additions & 3 deletions benchmarks/benchmark_latency.py
@@ -10,6 +10,13 @@

from vllm import LLM, SamplingParams

SAMPLE_PROMPTS = [
"The president of the United States is",
"Hello, my name is",
"The capital of France is",
"The future of AI is",
]


def main(args: argparse.Namespace):
print(args)
@@ -26,6 +33,8 @@ def main(args: argparse.Namespace):
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
device=args.device,
block_size=args.block_size,
flash_style=args.flash_style,
ray_workers_use_nsight=args.ray_workers_use_nsight,
)

@@ -58,10 +67,25 @@ def run_to_completion(profile_dir: Optional[str] = None):
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)
if args.use_sample:
batch = (SAMPLE_PROMPTS *
(args.batch_size // len(SAMPLE_PROMPTS) +
1))[:args.batch_size]
outputs = llm.generate(prompts=batch,
sampling_params=sampling_params,
use_tqdm=False)
else:
outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
if args.verbose:
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(
f"Prompt: {prompt!r}, Generated text: {generated_text!r}"
)
latency = end_time - start_time
return latency

@@ -146,6 +170,19 @@ def run_to_completion(profile_dir: Optional[str] = None):
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument('--flash-style',
action='store_true',
help='enable flash attention')
parser.add_argument('--block-size',
type=int,
default=16,
help='block size of key/value cache')
parser.add_argument('--use-sample',
action='store_true',
help='use sample input instead of dummy input')
parser.add_argument('--verbose',
action='store_true',
help='print generated text')
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
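The new --flash-style and --block-size flags above are forwarded straight to the engine constructor. A minimal sketch of the equivalent programmatic use, assuming the flash_style and block_size keyword arguments this PR threads through to LLM (model name and prompt are illustrative):

from vllm import LLM, SamplingParams

# Engine with the flash-attention-style KV cache introduced in this PR.
llm = LLM(
    model="JackFram/llama-68m",  # illustrative; any supported model
    flash_style=True,            # enable the flash-attention code path
    block_size=16,               # block size of the key/value cache
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)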
37 changes: 27 additions & 10 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -7,6 +7,7 @@

from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm._C import ops
from vllm.model_executor.layers.attention import flash_attn_with_kvcache_paged

NUM_BLOCKS = 1024
PARTITION_SIZE = 512
@@ -64,14 +65,17 @@ def main(
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)

# Create the KV cache.
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
block_size,
1,
num_kv_heads,
head_size,
kv_cache_dtype,
dtype,
device=device)
flash_style = version == "flash"
key_caches, value_caches = create_kv_caches_with_random(
NUM_BLOCKS,
block_size,
1,
num_kv_heads,
head_size,
kv_cache_dtype,
dtype,
device=device,
flash_style=flash_style)
key_cache, value_cache = key_caches[0], value_caches[0]

# Prepare for the paged attention kernel.
@@ -131,6 +135,16 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
alibi_slopes,
kv_cache_dtype,
)
elif version == "flash":
flash_attn_with_kvcache_paged(
query.view(num_seqs, 1, num_query_heads, head_size),
key_cache,
value_cache,
scale,
block_tables,
context_lens,
alibi_slopes=alibi_slopes,
)
else:
raise ValueError(f"Invalid version: {version}")
torch.cuda.synchronize()
@@ -158,7 +172,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
description="Benchmark the paged attention kernel.")
parser.add_argument("--version",
type=str,
choices=["v1", "v2"],
choices=["v1", "v2", "flash"],
default="v2")
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--context-len", type=int, default=4096)
@@ -168,7 +182,10 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
type=int,
choices=[64, 80, 96, 112, 128, 256],
default=128)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--block-size",
type=int,
choices=[16, 32, 256],
default=16)
parser.add_argument("--use-alibi", action="store_true")
parser.add_argument("--dtype",
type=str,
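For orientation, a shape-only sketch of the tensors the new flash version of the benchmark operates on; the sizes are made up for illustration, the cache layout follows the comments in csrc/cache_kernels.cu below, and the query reshape mirrors the call in run_cuda_benchmark:

import torch

num_blocks, block_size, num_kv_heads, head_size = 1024, 16, 8, 128
num_seqs, num_query_heads = 8, 8

# Flash-style KV cache: key and value caches share one layout.
key_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.empty_like(key_cache)

# Decode-time query is viewed as [num_seqs, 1, num_query_heads, head_size]
# before being handed to flash_attn_with_kvcache_paged.
query = torch.randn(num_seqs, num_query_heads, head_size,
                    dtype=torch.float16, device="cuda")
query = query.view(num_seqs, 1, num_query_heads, head_size)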
7 changes: 7 additions & 0 deletions csrc/cache.h
@@ -23,6 +23,13 @@ void reshape_and_cache(
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);

void reshape_and_cache_flash(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping);

// Just for unittest
void convert_fp8_e5m2(
torch::Tensor& src_cache,
87 changes: 87 additions & 0 deletions csrc/cache_kernels.cu
@@ -174,6 +174,7 @@ __global__ void reshape_and_cache_kernel(
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;


const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
@@ -269,6 +270,92 @@ void reshape_and_cache(

namespace vllm {

// Flash-attention-style cache function where the key/value caches
// have the same shape: [num_blocks, block_size, num_heads, head_size].
template<typename scalar_t>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, head_size]
scalar_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
// Padding token that should be ignored.
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;

const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;

const int head_idx = i / head_size;
const int head_offset = i % head_size;

// Target idx for kv are the same.
const int64_t tgt_idx = block_idx * block_size * num_heads * head_size
+ block_offset * num_heads * head_size
+ head_idx * head_size
+ head_offset;

scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
// TODO(sang): Support ENABLE_FP8_E5M2.
key_cache[tgt_idx] = tgt_key;
value_cache[tgt_idx] = tgt_value;
}
}

} // namespace vllm

// TODO(sang): Support kv_cache_dtype
void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping)
{
int num_tokens_padded = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(1);

int key_stride = key.stride(0);
int value_stride = value.stride(0);

dim3 grid(num_tokens_padded);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
"reshape_and_cache_flash_kernel",
[&] {
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(),
key_stride,
value_stride,
num_heads,
head_size,
block_size);
});
}

namespace vllm {

template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
const Tin* __restrict__ src_cache,
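As a readability aid, a PyTorch sketch of what reshape_and_cache_flash_kernel computes per token (reference semantics only, not the kernel itself; shapes as documented in the comments above):

import torch

def reshape_and_cache_flash_ref(
    key: torch.Tensor,          # [num_tokens, num_heads, head_size]
    value: torch.Tensor,        # [num_tokens, num_heads, head_size]
    key_cache: torch.Tensor,    # [num_blocks, block_size, num_heads, head_size]
    value_cache: torch.Tensor,  # [num_blocks, block_size, num_heads, head_size]
    slot_mapping: torch.Tensor, # [num_tokens], int64
) -> None:
    block_size = key_cache.shape[1]
    for token_idx, slot_idx in enumerate(slot_mapping.tolist()):
        if slot_idx < 0:
            continue  # padding token, skipped by the kernel as well
        block_idx = slot_idx // block_size
        block_offset = slot_idx % block_size
        # Key and value land at the same [block_idx, block_offset] position.
        key_cache[block_idx, block_offset] = key[token_idx]
        value_cache[block_idx, block_offset] = value[token_idx]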
4 changes: 4 additions & 0 deletions csrc/pybind.cpp
@@ -81,6 +81,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"reshape_and_cache_flash",
&reshape_and_cache_flash,
"Reshape the flash-style key and value tensors and cache them.");
cache_ops.def(
"convert_fp8_e5m2",
&convert_fp8_e5m2,
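From Python, the new op is reachable through the compiled extension, mirroring the existing reshape_and_cache binding; a small sketch with made-up sizes:

import torch
from vllm._C import cache_ops

num_tokens, num_heads, head_size = 4, 8, 128
num_blocks, block_size = 16, 16

key = torch.randn(num_tokens, num_heads, head_size,
                  dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.zeros_like(key_cache)
# One flat slot index per token; -1 would mark a padding token to skip.
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")

cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                  slot_mapping)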
1 change: 1 addition & 0 deletions requirements.txt
@@ -14,3 +14,4 @@ pynvml == 11.5.0
triton >= 2.1.0
outlines >= 0.0.27
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
flash-attn >= 2.5.0 # Required for chunked prefill.
86 changes: 86 additions & 0 deletions tests/chunked_prefill/test_correctness.py
@@ -0,0 +1,86 @@
import gc

import pytest
import torch

from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel

MODELS = [
"JackFram/llama-68m",
# "facebook/opt-125m",
]

TEST_PROMPTS = [
# pylint: disable=line-too-long
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
# Output differs between paged attention and flash attention.
# "Describe the basic components of a neural network and how it can be trained.",
"Write a short story about a robot that dreams for the first time.",
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
]


# TODO(sang): Add chunked prefill parameters.
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("block_size", [32])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_models(
vllm_runner,
model: str,
dtype: str,
max_tokens: int,
block_size: int,
tensor_parallel_size: int,
) -> None:
""" verify the flash attention has the same output
as page attention """
if torch.cuda.device_count() < tensor_parallel_size:
pytest.skip(
f"{torch.cuda.device_count()=} is smaller than {tensor_parallel_size=}"
)
print("loading page attention models..")
pg_model = vllm_runner(model, dtype=dtype)
expected_outputs = []

print("generating tokens...")
expected_outputs.extend(pg_model.generate_greedy(TEST_PROMPTS, max_tokens))
print("generating tokens finished")

del pg_model

destroy_model_parallel()
gc.collect()
torch.cuda.empty_cache()

flash_attn_output_by_batches = []
flash_attn_model = vllm_runner(model,
dtype=dtype,
block_size=block_size,
flash_style=True,
tensor_parallel_size=tensor_parallel_size)
for i in range(10):
prompts = [TEST_PROMPTS[j % len(TEST_PROMPTS)] for j in range(i)]
flash_attn_output_by_batches.append(
flash_attn_model.generate_greedy(prompts, max_tokens))

del flash_attn_model
destroy_model_parallel()
gc.collect()
torch.cuda.empty_cache()

for flash_attn_outputs in flash_attn_output_by_batches:
for i in range(len(flash_attn_outputs)):
fa_output_ids, fa_output_str = flash_attn_outputs[i]
vllm_output_ids, vllm_output_str = expected_outputs[
i % len(expected_outputs)]
print(vllm_output_str)
assert fa_output_ids == vllm_output_ids, (
f"Test{i}:\flash ids: {fa_output_ids}\nvLLM ids: {vllm_output_ids}"
f"Test{i}:\nflash output: {fa_output_str!r}\nvLLM output: {vllm_output_str!r}"
)
4 changes: 4 additions & 0 deletions tests/conftest.py
@@ -165,6 +165,8 @@ def __init__(
dtype: str = "half",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 32,
flash_style: bool = False,
**kwargs,
) -> None:
self.model = LLM(
@@ -175,6 +177,8 @@ def __init__(
swap_space=0,
disable_log_stats=disable_log_stats,
tensor_parallel_size=tensor_parallel_size,
flash_style=flash_style,
block_size=block_size,
**kwargs,
)
