
Commit 622f8ab

[Bugfix] bugfix and add model test for flashinfer fp8 kv cache. (#8013)
1 parent 1248e85 commit 622f8ab

2 files changed: +109 -5 lines changed
Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+# flake8: noqa
+"""Tests fp8 models against ground truth generation
+This verifies the flashinfer backend with fp8
+quantization and fp8 KV Cache without scaling
+factors Note: these tests will only pass on H100 GPU.
+"""
+import os
+from typing import List
+
+import pytest
+from transformers import AutoTokenizer
+
+from tests.quantization.utils import is_quant_method_supported
+from vllm import LLM, SamplingParams
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+MAX_MODEL_LEN = 1024
+
+MODELS = [
+    "nm-testing/Meta-Llama-3-8B-Instruct-FP8",
+]
+
+EXPECTED_STRS_MAP = {
+    "nm-testing/Meta-Llama-3-8B-Instruct-FP8": {
+        "auto": [
+            'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+            'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
+            'In the sterile, metallic halls of the robotics lab, a peculiar phenomenon occurred. Zeta-5',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, mushi o',
+        ],
+        "fp8": [
+            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
+            'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
+        ]
+    }
+}
+
+
+# This test compares against golden strings for exact match since
+# there is no baseline implementation to compare against
+# and is unstable w.r.t specifics of the fp8 implementation or
+# the hardware being run on.
+# No assert to prevent it from breaking the build
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="fp8 is not supported on this GPU type.")
+@pytest.mark.parametrize("model_name", MODELS)
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
+@pytest.mark.parametrize("backend", ["XFORMERS", "FLASHINFER"])
+def test_models(example_prompts, model_name, kv_cache_dtype, backend) -> None:
+    # Note that the golden strings may not work for FLASHINFER Backend.
+    # The intention is to test the path
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+    model = LLM(model=model_name,
+                max_model_len=MAX_MODEL_LEN,
+                trust_remote_code=True,
+                quantization="fp8",
+                kv_cache_dtype=kv_cache_dtype)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    formatted_prompts = [
+        tokenizer.apply_chat_template([{
+            "role": "user",
+            "content": prompt
+        }],
+                                      tokenize=False,
+                                      add_generation_prompt=True)
+        for prompt in example_prompts
+    ]
+
+    params = SamplingParams(max_tokens=20, temperature=0)
+    generations: List[str] = []
+    # Note: these need to be run 1 at a time due to numerical precision,
+    # since the expected strs were generated this way.
+    for prompt in formatted_prompts:
+        outputs = model.generate(prompt, params)
+        generations.append(outputs[0].outputs[0].text)
+    del model
+
+    print(f"Testing: {model_name} with kv_cache_dtype: {kv_cache_dtype}")
+    expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
+    for i in range(len(example_prompts)):
+        generated_str = generations[i]
+        expected_str = expected_strs[i]
+        print(f"generated_str\n: {generated_str}")
+        print(f"expected_str\n: {expected_str}")

vllm/attention/backends/flashinfer.py

Lines changed: 13 additions & 5 deletions

@@ -186,9 +186,13 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
                 self._graph_decode_workspace_buffer, _indptr_buffer,
                 self._graph_indices_buffer, _last_page_len_buffer, "NHD",
                 use_tensor_cores)
+        if self.runner.kv_cache_dtype.startswith("fp8"):
+            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                self.runner.kv_cache_dtype)
+        else:
+            kv_cache_dtype = get_kv_cache_torch_dtype(
+                self.runner.kv_cache_dtype, self.runner.model_config.dtype)
 
-        kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
-            self.runner.kv_cache_dtype)
         paged_kv_indptr_tensor_host = torch.arange(0,
                                                    batch_size + 1,
                                                    dtype=torch.int32)
@@ -349,7 +353,7 @@ def begin_forward(self):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use vllm's rope.
                 pos_encoding_mode="NONE",
-            )
+                data_type=self.data_type)
 
     def asdict_zerocopy(self,
                         skip_fields: Optional[Set[str]] = None
@@ -586,8 +590,12 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         paged_kv_indptr_tensor = None
        paged_kv_last_page_len_tensor = None
 
-        kv_cache_dtype = get_kv_cache_torch_dtype(
-            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+        if self.runner.kv_cache_dtype.startswith("fp8"):
+            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                self.runner.kv_cache_dtype)
+        else:
+            kv_cache_dtype = get_kv_cache_torch_dtype(
+                self.runner.kv_cache_dtype, self.runner.model_config.dtype)
 
         return FlashInferMetadata(
             num_prefills=self.num_prefills,
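
Both hunks that touch kv_cache_dtype apply the same rule: when the configured kv_cache_dtype names an fp8 format, ask FlashInferBackend.get_fp8_dtype_for_flashinfer for the corresponding torch fp8 dtype; otherwise fall back to get_kv_cache_torch_dtype (previously, graph capture always took the fp8 path and build() always took the non-fp8 path). The sketch below restates that branch as a standalone helper for illustration only; select_flashinfer_kv_cache_dtype, _FP8_KV_CACHE_DTYPES, and the exact fp8 string-to-dtype mapping are assumptions, not code from this commit.

import torch

# Illustrative stand-in for FlashInferBackend.get_fp8_dtype_for_flashinfer:
# an assumed mapping from vLLM's fp8 kv_cache_dtype strings to torch fp8 dtypes.
_FP8_KV_CACHE_DTYPES = {
    "fp8": torch.float8_e4m3fn,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}


def select_flashinfer_kv_cache_dtype(kv_cache_dtype: str,
                                     model_dtype: torch.dtype) -> torch.dtype:
    """Hypothetical helper mirroring the branch added in both hunks above."""
    if kv_cache_dtype.startswith("fp8"):
        # fp8 KV cache: use the torch fp8 dtype FlashInfer expects.
        return _FP8_KV_CACHE_DTYPES[kv_cache_dtype]
    # Non-fp8 ("auto"): follow the model dtype, which is what
    # get_kv_cache_torch_dtype resolves to in the "auto" case.
    return model_dtype


# "fp8" selects an fp8 cache dtype; "auto" follows the model dtype.
assert select_flashinfer_kv_cache_dtype("fp8", torch.float16) is torch.float8_e4m3fn
assert select_flashinfer_kv_cache_dtype("auto", torch.float16) is torch.float16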
