Merged
38 changes: 21 additions & 17 deletions csrc/cpu/cache.cpp
@@ -3,6 +3,12 @@

#include "cpu_types.hpp"

#if defined(__x86_64__)
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
#else
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
#endif

namespace {
template <typename scalar_t>
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
@@ -95,13 +101,12 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
}

const int element_num_per_block = key_caches[0][0].numel();
VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
});
DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
});
}

void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
@@ -118,16 +123,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
int key_stride = key.stride(0);
int value_stride = value.stride(0);

VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
reshape_and_cache_cpu_impl<scalar_t>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
value_stride, num_heads, head_size, block_size, x);
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
});
DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
reshape_and_cache_cpu_impl<scalar_t>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
num_heads, head_size, block_size, x);
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
});
}

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
9 changes: 9 additions & 0 deletions csrc/cpu/cpu_types_x86.hpp
@@ -16,9 +16,18 @@ namespace vec_op {
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)

#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))

#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
2 changes: 1 addition & 1 deletion docs/source/getting_started/installation/cpu.md
@@ -189,7 +189,7 @@ vLLM CPU backend supports the following vLLM features:
- Model Quantization (`INT8 W8A8, AWQ, GPTQ`)
- Chunked-prefill
- Prefix-caching
- FP8-E5M2 KV-Caching (TODO)
- FP8-E5M2 KV cache

## Related runtime environment variables

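For reference, a minimal offline-inference sketch that exercises the new FP8-E5M2 KV cache on the CPU backend (illustrative only, not part of this diff; the model name and dtypes mirror the CPU test added below):

```python
# Sketch: select the FP8-E5M2 KV cache via the standard kv_cache_dtype argument.
# On the CPU backend, "fp8_e4m3" would be cast to "fp8_e5m2" (see vllm/platforms/cpu.py).
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    dtype="bfloat16",           # fp16 is cast to bf16 when a non-auto KV cache dtype is set
    kv_cache_dtype="fp8_e5m2",  # newly supported on the CPU backend
    max_model_len=1024,
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```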
4 changes: 2 additions & 2 deletions tests/basic_correctness/test_chunked_prefill.py
@@ -257,7 +257,7 @@ def test_with_prefix_caching(


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False])
@@ -294,7 +294,7 @@ def test_models_cpu(
@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("chunk_size", [30, 32])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_with_prefix_caching_cpu(
61 changes: 61 additions & 0 deletions tests/models/decoder_only/language/test_fp8.py
@@ -11,6 +11,7 @@

from tests.kernels.utils import override_backend_env_variable
from tests.quantization.utils import is_quant_method_supported
from vllm.platforms import current_platform

from ...utils import check_logprobs_close

@@ -93,3 +94,63 @@ def test_models(
name_0="fp16_kv_cache",
name_1="fp8_kv_cache",
)


@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(),
reason="test for the CPU backend.")
@pytest.mark.parametrize(
"kv_cache_dtype,base_model,test_model",
[
# Test BF16 checkpoint w. fp8_e5m2 kv-cache.
("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-1B-Instruct"),
])
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
@pytest.mark.parametrize("max_tokens", [4])
# Due to low-precision numerical divergence, this test is too sensitive for
# the async postprocessor
@pytest.mark.parametrize("disable_async_output_proc", [True])
def test_cpu_models(
vllm_runner,
example_prompts,
kv_cache_dtype: str,
base_model: str,
test_model: str,
max_tokens: int,
disable_async_output_proc: bool,
) -> None:
"""
Only checks that log probs match, to cover the discrepancy in
numerically sensitive kernels.
"""

MAX_MODEL_LEN = 1024
NUM_LOG_PROBS = 8

with vllm_runner(
base_model,
max_model_len=MAX_MODEL_LEN,
dtype="bfloat16",
kv_cache_dtype="auto",
disable_async_output_proc=disable_async_output_proc,
) as vllm_model:
baseline_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS)

with vllm_runner(
test_model,
max_model_len=MAX_MODEL_LEN,
dtype="bfloat16",
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
) as vllm_model:
test_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS)

check_logprobs_close(
outputs_0_lst=baseline_outputs,
outputs_1_lst=test_outputs,
name_0="bf16_kv_cache",
name_1="fp8_kv_cache",
)
9 changes: 5 additions & 4 deletions vllm/attention/backends/torch_sdpa.py
@@ -17,7 +17,7 @@
is_quantized_kv_cache)
# yapf: enable
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.logger import init_logger
from vllm.utils import make_tensor_with_pad
@@ -431,10 +431,11 @@ def __init__(
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if is_quantized_kv_cache(kv_cache_dtype):

if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex:
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
"Torch SDPA backend FP8 KV cache requires "
"intel_extension_for_pytorch support.")
self.attn_type = attn_type

def forward(
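For context, the `_use_ipex` flag imported above appears to be a module-level availability check in `vllm.attention.ops.ipex_attn`, along these lines (a sketch under that assumption, not the verbatim source):

```python
# Sketch: _use_ipex is True only when intel_extension_for_pytorch is importable,
# so the FP8 KV cache path in the Torch SDPA backend is gated on IPEX support.
try:
    import intel_extension_for_pytorch  # noqa: F401
    _use_ipex = True
except ImportError:
    _use_ipex = False
```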
23 changes: 12 additions & 11 deletions vllm/platforms/cpu.py
@@ -60,16 +60,25 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
if not model_config.enforce_eager:
logger.warning(
"CUDA graph is not supported on CPU, fallback to the eager "
"mode.")
model_config.enforce_eager = True

cache_config = vllm_config.cache_config

if cache_config and cache_config.block_size is None:
cache_config.block_size = 16

if cache_config.cache_dtype == "fp8_e4m3":
cache_config.cache_dtype = "fp8_e5m2"
Comment on lines +77 to +78

@Isotr0py (Member) commented on Mar 13, 2025:
It seems that chunked_prefill/prefix_caching doesn't work with the fp8 KV cache; perhaps add a check with a better error message here?

[rank0]:   File "/data/develop-projects/github-repos/vllm-cpu/vllm/attention/backends/torch_sdpa.py", line 544, in forward
[rank0]:     ipex_modules.PagedAttention.flash_attn_varlen_func(
[rank0]:   File "/data/develop-projects/github-repos/vllm-cpu/.venv/lib/python3.12/site-packages/intel_extension_for_pytorch/llm/modules/mha_fusion.py", line 622, in flash_attn_varlen_func
[rank0]:     ).flash_attn_varlen_func(
[rank0]:       ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/develop-projects/github-repos/vllm-cpu/.venv/lib/python3.12/site-packages/intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py", line 415, in flash_attn_varlen_func
[rank0]:     torch.ops.torch_ipex.flash_attn_varlen_func(
[rank0]:   File "/data/develop-projects/github-repos/vllm-cpu/.venv/lib/python3.12/site-packages/torch/_ops.py", line 1123, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: expected scalar type BFloat16 but found Float8_e5m2
Processed prompts:   0%|                                                                                         | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
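A minimal sketch of such a check (hypothetical, not part of this PR), reusing the `scheduler_config` and `cache_config` fields already referenced in this file:

```python
# Hypothetical guard in CpuPlatform.check_and_update_config: fail fast with a
# clear message instead of hitting the IPEX dtype error shown above at runtime.
scheduler_config = vllm_config.scheduler_config
if (cache_config.cache_dtype != "auto"
        and (scheduler_config.chunked_prefill_enabled
             or cache_config.enable_prefix_caching)):
    raise NotImplementedError(
        "FP8 KV cache on the CPU backend does not work with "
        "chunked-prefill or prefix-caching yet; disable them or use "
        "kv_cache_dtype='auto'.")
```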

logger.warning(
"CPU backend doesn't support fp8_e4m3 KV cache type, "
"cast to fp8_e5m2.")

if (cache_config.cache_dtype != "auto"
and model_config.dtype == torch.half):
logger.warning("FP8 KV cache on the CPU backend does not"
" support fp16 for now, casting to bf16.")
model_config.dtype = torch.bfloat16

kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE

if kv_cache_space >= 0:
@@ -85,14 +94,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
f" {kv_cache_space}, expect a positive integer value.")

scheduler_config = vllm_config.scheduler_config
if ((scheduler_config.chunked_prefill_enabled
or cache_config.enable_prefix_caching)
and model_config.dtype == torch.half):
logger.warning("Chunked-prefill on the CPU backend only does not"
" support fp16 for now, cast to bf16.")
model_config.dtype = torch.bfloat16

parallel_config = vllm_config.parallel_config
if (parallel_config.distributed_executor_backend is not None
and parallel_config.distributed_executor_backend != "mp"):
5 changes: 4 additions & 1 deletion vllm/worker/cpu_worker.py
@@ -53,8 +53,11 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,

if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
self.dtype = torch.float8_e5m2
else:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
raise NotImplementedError(f"Unsupported KV cache type "
f"{cache_config.cache_dtype}.")

# Get attention backend.
self.attn_backend = get_attn_backend(