8 changes: 8 additions & 0 deletions csrc/cache.h
@@ -24,6 +24,14 @@ void reshape_and_cache(
   const std::string& kv_cache_dtype,
   const float kv_scale);
 
+void reshape_and_cache_flash(
+  torch::Tensor& key,
+  torch::Tensor& value,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& slot_mapping,
+  const std::string& kv_cache_dtype);
+
 // Just for unittest
 void convert_fp8(
   torch::Tensor& src_cache,
80 changes: 80 additions & 0 deletions csrc/cache_kernels.cu
@@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
+template<typename scalar_t>
+__global__ void reshape_and_cache_flash_kernel(
+  const scalar_t* __restrict__ key,          // [num_tokens, num_heads, head_size]
+  const scalar_t* __restrict__ value,        // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ k_cache,            // [num_blocks, block_size, num_heads, head_size]
+  scalar_t* __restrict__ v_cache,            // [num_blocks, block_size, num_heads, head_size]
+  const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+  const int block_stride,
+  const int key_stride,
+  const int value_stride,
+  const int num_heads,
+  const int head_size,
+  const int block_size) {
+  const int64_t token_idx = blockIdx.x;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  // NOTE: slot_idx can be -1 if the token is padded
+  if (slot_idx < 0) {
+    return;
+  }
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+  const int n = num_heads * head_size;
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+    const int64_t tgt_value_idx = block_idx * block_stride
+                                  + block_offset * num_heads * head_size
+                                  + head_idx * head_size
+                                  + head_offset;
+    k_cache[tgt_value_idx] = key[src_key_idx];
+    v_cache[tgt_value_idx] = value[src_value_idx];
+  }
+}
 } // namespace vllm
 
 #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
@@ -275,6 +310,51 @@ void reshape_and_cache(
   }
 }
 
+void reshape_and_cache_flash(
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& k_cache,       // [num_blocks, block_size, num_heads, head_size]
+  torch::Tensor& v_cache,       // [num_blocks, block_size, num_heads, head_size]
+  torch::Tensor& slot_mapping,  // [num_tokens]
+  const std::string& kv_cache_dtype)
+{
+  // FIXME: only support auto datatype, does not support fp8
+  if (kv_cache_dtype != "auto") {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = k_cache.size(1);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+  int block_stride = k_cache.stride(0);
+  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    key.scalar_type(),
+    "reshape_and_cache_flash",
+    [&] {
+      vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key.data_ptr<scalar_t>(),
+        value.data_ptr<scalar_t>(),
+        k_cache.data_ptr<scalar_t>(),
+        v_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int64_t>(),
+        block_stride,
+        key_stride,
+        value_stride,
+        num_heads,
+        head_size,
+        block_size);
+    });
+}
+
 namespace vllm {
 
 template<typename Tout, typename Tin>
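The new kernel writes each token's key/value into a flash-style cache laid out as [num_blocks, block_size, num_heads, head_size], addressed through a flat slot index (slot = block_idx * block_size + block_offset). For readers following the index math, here is a minimal PyTorch sketch of the same write pattern; it is illustration only, not part of the PR, and the names simply follow the comments in the kernel above.

import torch

def reshape_and_cache_flash_ref(key: torch.Tensor, value: torch.Tensor,
                                k_cache: torch.Tensor, v_cache: torch.Tensor,
                                slot_mapping: torch.Tensor) -> None:
    # key/value: [num_tokens, num_heads, head_size]
    # k_cache/v_cache: [num_blocks, block_size, num_heads, head_size]
    block_size = k_cache.shape[1]
    for token_idx, slot_idx in enumerate(slot_mapping.tolist()):
        if slot_idx < 0:  # padded token: the kernel skips it
            continue
        block_idx = slot_idx // block_size
        block_offset = slot_idx % block_size
        # Equivalent to the flat tgt_value_idx computed in the CUDA kernel.
        k_cache[block_idx, block_offset] = key[token_idx]
        v_cache[block_idx, block_offset] = value[token_idx]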
4 changes: 4 additions & 0 deletions csrc/pybind.cpp
@@ -96,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "reshape_and_cache",
     &reshape_and_cache,
     "Reshape the key and value tensors and cache them");
+  cache_ops.def(
+    "reshape_and_cache_flash",
+    &reshape_and_cache_flash,
+    "Reshape the key and value tensors and cache them");
   cache_ops.def(
     "convert_fp8",
     &convert_fp8,
12 changes: 11 additions & 1 deletion tests/basic_correctness/test_basic_correctness.py
@@ -2,12 +2,15 @@
 
 Run `pytest tests/basic_correctness/test_basic_correctness.py`.
 """
+import os
+
 import pytest
 
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
 ]
+VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
 
 
 @pytest.mark.parametrize("model", MODELS)
@@ -23,11 +26,18 @@ def test_models(
     max_tokens: int,
     enforce_eager: bool,
 ) -> None:
+    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
+    if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
+        pytest.skip("Skipping non-eager test for FlashInferBackend.")
+
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
 
-    vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)
+    vllm_model = vllm_runner(model,
+                             dtype=dtype,
+                             enforce_eager=enforce_eager,
+                             gpu_memory_utilization=0.7)
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model
 
14 changes: 9 additions & 5 deletions tests/distributed/test_basic_distributed_correctness.py
@@ -18,6 +18,7 @@
 MODELS = [
     os.environ["TEST_DIST_MODEL"],
 ]
+VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -33,16 +34,19 @@ def test_models(
     dtype: str,
     max_tokens: int,
 ) -> None:
+    enforce_eager = False
+    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
      [Review comment from a Collaborator on the line above: "(can do in future PR): we should integrate this with #4548"]
+    if backend_by_env_var == "FLASHINFER":
+        enforce_eager = True
 
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
 
-    vllm_model = vllm_runner(
-        model,
-        dtype=dtype,
-        tensor_parallel_size=2,
-    )
+    vllm_model = vllm_runner(model,
+                             dtype=dtype,
+                             tensor_parallel_size=2,
+                             enforce_eager=enforce_eager)
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model
 
8 changes: 7 additions & 1 deletion tests/kernels/conftest.py
@@ -1,8 +1,14 @@
 import pytest
 
-from vllm.utils import create_kv_caches_with_random
+from vllm.utils import (create_kv_caches_with_random,
+                        create_kv_caches_with_random_flash)
 
 
 @pytest.fixture()
 def kv_cache_factory():
     return create_kv_caches_with_random
+
+
+@pytest.fixture()
+def kv_cache_factory_flashinfer():
+    return create_kv_caches_with_random_flash
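The new fixture returns vllm.utils.create_kv_caches_with_random_flash, whose implementation is not part of this diff. Judging from the layout the new kernel expects, a rough stand-in would allocate one [num_blocks, block_size, num_heads, head_size] tensor per layer for keys and for values; the real helper's signature and dtype handling may differ, so treat this as a hedged sketch only.

import torch

def create_kv_caches_with_random_flash_sketch(num_blocks, block_size, num_layers,
                                              num_heads, head_size, cache_dtype,
                                              model_dtype, device="cuda"):
    # cache_dtype is ignored in this sketch; the test below skips the fp8 case.
    shape = (num_blocks, block_size, num_heads, head_size)
    key_caches = [torch.randn(shape, dtype=model_dtype, device=device)
                  for _ in range(num_layers)]
    value_caches = [torch.randn(shape, dtype=model_dtype, device=device)
                    for _ in range(num_layers)]
    return key_caches, value_caches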
77 changes: 77 additions & 0 deletions tests/kernels/test_cache.py
@@ -5,6 +5,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm._C import cache_ops
 from vllm.utils import is_hip
 
 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@@ -191,6 +192,82 @@ def test_reshape_and_cache(
     assert torch.allclose(value_cache, cloned_value_cache)
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@torch.inference_mode()
+def test_reshape_and_cache_flash(
+    kv_cache_factory_flashinfer,
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    kv_cache_dtype: str,
+) -> None:
+    if kv_cache_dtype == "fp8":
+        pytest.skip()
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # Create a random slot mapping.
+    num_slots = block_size * num_blocks
+    slot_mapping = random.sample(range(num_slots), num_tokens)
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda')
+
+    qkv = torch.randn(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device=device)
+    _, key, value = qkv.unbind(dim=1)
+
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory_flashinfer(
+        num_blocks,
+        block_size,
+        1,
+        num_heads,
+        head_size,
+        kv_cache_dtype,
+        dtype,
+    )
+    key_cache, value_cache = key_caches[0], value_caches[0]
+
+    # Clone the KV caches.
+    cloned_key_cache = key_cache.clone()
+    cloned_value_cache = value_cache.clone()
+
+    # Call the reshape_and_cache kernel.
+    cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
+                                      slot_mapping, kv_cache_dtype)
+
+    # Run the reference implementation.
+    block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
+    block_indicies = block_indicies.cpu().tolist()
+    block_offsets = slot_mapping % block_size
+    block_offsets = block_offsets.cpu().tolist()
+    for i in range(num_tokens):
+        block_idx = block_indicies[i]
+        block_offset = block_offsets[i]
+        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
+        cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+
+    assert torch.allclose(key_cache, cloned_key_cache)
+    assert torch.allclose(value_cache, cloned_value_cache)
+
+
 @pytest.mark.parametrize("direction", COPYING_DIRECTION)
 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
12 changes: 12 additions & 0 deletions vllm/_custom_ops.py
@@ -222,6 +222,18 @@ def reshape_and_cache(
                                      slot_mapping, kv_cache_dtype, kv_scale)
 
 
+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+) -> None:
+    vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
+                                           slot_mapping, kv_cache_dtype)
+
+
 def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
                 block_mapping: torch.Tensor) -> None:
     vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
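For reference, here is a usage sketch of the new Python wrapper, mirroring the call made in tests/kernels/test_cache.py. The shapes follow the kernel comments; the specific sizes and slot values are arbitrary, and a CUDA device is required.

import torch
from vllm import _custom_ops as ops

num_tokens, num_heads, head_size = 4, 8, 64
num_blocks, block_size = 16, 16

key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.half, device="cuda")
value = torch.randn_like(key)
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.half, device="cuda")
value_cache = torch.zeros_like(key_cache)
# One slot per token; a -1 entry would mark a padded token and be skipped.
slot_mapping = torch.tensor([0, 1, 16, 17], dtype=torch.long, device="cuda")

# Only kv_cache_dtype="auto" is supported by this kernel (see the FIXME above).
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                            slot_mapping, "auto")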
13 changes: 9 additions & 4 deletions vllm/attention/backends/abstract.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
+from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
+                    TypeVar)
 
 import torch
 
@@ -15,7 +16,7 @@ def get_impl_cls() -> Type["AttentionImpl"]:
 
     @staticmethod
     @abstractmethod
-    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
+    def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
         raise NotImplementedError
 
     @staticmethod
@@ -50,13 +51,17 @@ def copy_blocks(
 class AttentionMetadataPerStage:
     """Attention metadata for a specific stage. I.e., prefill or decode."""
 
-    def asdict_zerocopy(self) -> Dict[str, Any]:
+    def asdict_zerocopy(self,
+                        skip_fields: Optional[Set[str]] = None
+                        ) -> Dict[str, Any]:
         """Similar to dataclasses.asdict, but avoids deepcopying."""
+        if skip_fields is None:
+            skip_fields = set()
         # Note that if we add dataclasses as fields, they will need
        # similar handling.
         return {
             field.name: getattr(self, field.name)
-            for field in fields(self)
+            for field in fields(self) if field.name not in skip_fields
         }
 
 
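To illustrate the new skip_fields argument, a small sketch with a hypothetical metadata dataclass follows; ToyMetadata is made up for this example, and only AttentionMetadataPerStage and asdict_zerocopy come from the diff above.

from dataclasses import dataclass

import torch

from vllm.attention.backends.abstract import AttentionMetadataPerStage


@dataclass
class ToyMetadata(AttentionMetadataPerStage):
    seq_lens: list
    block_tables: torch.Tensor


meta = ToyMetadata(seq_lens=[3, 5], block_tables=torch.zeros(2, 4))
full = meta.asdict_zerocopy()
# -> {'seq_lens': [3, 5], 'block_tables': tensor(...)}, with no deep copies made.
trimmed = meta.asdict_zerocopy(skip_fields={"block_tables"})
# -> {'seq_lens': [3, 5]}; useful when a field should be handled separately.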