
Commit 92a8bbb

wip flashinfer mla backend
Signed-off-by: Lucas Wilkinson <[email protected]>

wip flash-infer

Signed-off-by: Lucas Wilkinson <[email protected]>

wip debugging

Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent e7ef74e commit 92a8bbb

File tree

7 files changed: +808 -112 lines changed


examples/deepseek-chat.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    trust_remote_code=True,
)
sampling_params = SamplingParams(temperature=0.5)


def print_outputs(outputs):
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    print("-" * 80)


print("=" * 80)

# In this script, we demonstrate how to pass input to the chat method:

conversation = [
    {
        "role": "system",
        "content": "You are a helpful assistant"
    },
    {
        "role": "user",
        "content": "Hello"
    },
    {
        "role": "assistant",
        "content": "Hello! How can I assist you today?"
    },
    {
        "role": "user",
        "content": "Write an essay about the importance of higher education.",
    },
]
outputs = llm.chat(conversation,
                   sampling_params=sampling_params,
                   use_tqdm=False)
print_outputs(outputs)

# You can run batch inference with llm.chat API
conversation = [
    {
        "role": "system",
        "content": "You are a helpful assistant"
    },
    {
        "role": "user",
        "content": "Hello"
    },
    {
        "role": "assistant",
        "content": "Hello! How can I assist you today?"
    },
    {
        "role": "user",
        "content": "Write an essay about the importance of higher education.",
    },
]
conversations = [conversation for _ in range(10)]

# We turn on tqdm progress bar to verify it's indeed running batch inference
outputs = llm.chat(messages=conversations,
                   sampling_params=sampling_params,
                   use_tqdm=True)
print_outputs(outputs)

# A chat template can be optionally supplied.
# If not, the model will use its default chat template.

# with open('template_falcon_180b.jinja', "r") as f:
#     chat_template = f.read()

# outputs = llm.chat(
#     conversations,
#     sampling_params=sampling_params,
#     use_tqdm=False,
#     chat_template=chat_template,
# )
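The example above prints full prompts and completions via print_outputs. As a purely illustrative addition that is not part of this commit, a small helper like the sketch below could summarize the batched results; it relies only on an attribute path the example already uses (output.outputs[0].text) and on llm.chat returning one result per submitted conversation, in order.

from typing import List


def summarize_batch(conversations: List[list], outputs: List) -> None:
    # Hypothetical helper (not in the commit): print a one-line summary for
    # each batched chat result.
    assert len(outputs) == len(conversations), "one result per conversation"
    for i, output in enumerate(outputs):
        # Reuse the attribute path the example's print_outputs accesses.
        print(f"[{i}] completion chars: {len(output.outputs[0].text)}")

If desired, it could be called as summarize_batch(conversations, outputs) right after the batched llm.chat call.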

vllm/attention/backends/flashinfer.py

Lines changed: 7 additions & 72 deletions
@@ -32,12 +32,13 @@
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionState, AttentionType)
-from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+from vllm.attention.backends.utils import (PAD_SLOT_ID, PerLayerParameters,
+                                            compute_slot_mapping,
                                             compute_slot_mapping_start_idx,
+                                            infer_global_hyperparameters,
                                             is_block_tables_empty)
-from vllm.attention.layer import Attention
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import get_current_vllm_config
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
 
@@ -106,72 +107,6 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
         raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
 
 
-@dataclass
-class PerLayerParameters:
-    """
-    Currently, FlashInfer backend only support models in which all layers share
-    the same values for the following hyperparameters.
-    """
-
-    window_left: int
-    logits_soft_cap: Optional[float]
-    sm_scale: float
-
-
-def get_per_layer_parameters(
-        vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
-    """
-    Scan all attention layers and determine some hyperparameters
-    to use during `plan`.
-    """
-
-    layers = vllm_config.compilation_config.static_forward_context
-    per_layer_params: Dict[str, PerLayerParameters] = {}
-
-    for key, layer in layers.items():
-        assert isinstance(layer, Attention)
-
-        impl = layer.impl
-        assert isinstance(impl, FlashInferImpl)
-
-        # Infer hyperparameters from the attention layer
-        window_size = impl.sliding_window
-        window_left = window_size[0] if window_size is not None else -1
-        logits_soft_cap = impl.logits_soft_cap
-        sm_scale = impl.scale
-
-        per_layer_params[key] = PerLayerParameters(window_left,
-                                                   logits_soft_cap, sm_scale)
-
-    return per_layer_params
-
-
-def infer_global_hyperparameters(
-        per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
-    """
-    Currently, FlashInfer backend only support models in which all layers share
-    the same values for the following hyperparameters:
-    - `window_left`
-    - `logits_soft_cap`
-    - `sm_scale`
-
-    So this function asserts that all layers share the same values for these
-    hyperparameters and returns the global values.
-    """
-
-    assert len(per_layer_params) > 0, "No attention layers found in the model."
-
-    param_sets = list(per_layer_params.values())
-    global_params = param_sets[0]
-    for params in param_sets:
-        assert params == global_params, (
-            "FlashInfer backend currently only supports models in which all "
-            "layers share the same values for the following hyperparameters: "
-            "`window_left`, `logits_soft_cap`, `sm_scale`.")
-
-    return global_params
-
-
 class FlashInferState(AttentionState):
 
     def __init__(self, runner):
@@ -293,8 +228,8 @@ def graph_capture_get_metadata_for_batch(
             batch_size + 1,
             dtype=torch.int32)
 
-        global_params = infer_global_hyperparameters(
-            get_per_layer_parameters(self.vllm_config))
+        global_params = infer_global_hyperparameters(self.vllm_config,
+                                                     FlashInferImpl)
 
         attn_metadata = self.runner.attn_backend.make_metadata(
             num_prefills=0,
@@ -652,7 +587,7 @@ def prepare(self):
         # - `logits_soft_cap`
         # - `sm_scale`
         inferred_params = infer_global_hyperparameters(
-            get_per_layer_parameters(self.vllm_config))
+            self.vllm_config, FlashInferImpl)
         self.global_hyperparameters = inferred_params
         self.window_left = inferred_params.window_left
         self.logits_soft_cap = inferred_params.logits_soft_cap
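For context, the helpers removed above appear to have been relocated to vllm/attention/backends/utils.py in a generalized form: the new imports pull PerLayerParameters and infer_global_hyperparameters from that module, and both call sites now pass (self.vllm_config, FlashInferImpl) instead of a precomputed per-layer dict. The sketch below reconstructs what such a generalized helper could look like, based only on the removed code and the new call sites; the parameter name cls, the exact signature, and the module layout are assumptions, not taken from this diff.

# Hedged sketch of the relocated helpers (assumed to live in
# vllm/attention/backends/utils.py); the logic mirrors the code removed above,
# with the attention-impl class passed in instead of being hard-coded to
# FlashInferImpl. The two-argument signature is inferred from the new call
# site `infer_global_hyperparameters(self.vllm_config, FlashInferImpl)`.
from dataclasses import dataclass
from typing import Dict, Optional, Type

from vllm.attention.layer import Attention
from vllm.config import VllmConfig


@dataclass
class PerLayerParameters:
    """Per-layer hyperparameters consumed during FlashInfer `plan`."""
    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float


def infer_global_hyperparameters(vllm_config: VllmConfig,
                                 cls: Type) -> PerLayerParameters:
    """Scan all attention layers whose impl is `cls`, assert they share the
    same hyperparameters, and return the shared values (same checks as the
    removed per-layer code)."""
    layers = vllm_config.compilation_config.static_forward_context
    per_layer_params: Dict[str, PerLayerParameters] = {}
    for key, layer in layers.items():
        assert isinstance(layer, Attention)
        impl = layer.impl
        assert isinstance(impl, cls)
        # Infer hyperparameters from the attention layer, as before.
        window_size = impl.sliding_window
        window_left = window_size[0] if window_size is not None else -1
        per_layer_params[key] = PerLayerParameters(window_left,
                                                   impl.logits_soft_cap,
                                                   impl.scale)

    assert len(per_layer_params) > 0, "No attention layers found in the model."
    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]
    for params in param_sets:
        assert params == global_params, (
            "All layers must share `window_left`, `logits_soft_cap`, and "
            "`sm_scale`.")
    return global_params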
