Commit 1df44c3

wip flash-infer
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent a6c0438 commit 1df44c3

File tree: 2 files changed (+100 -4 lines)

examples/deepseek-chat.py
vllm/attention/backends/flashinfer_mla.py

examples/deepseek-chat.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    trust_remote_code=True,
)
sampling_params = SamplingParams(temperature=0.5)


def print_outputs(outputs):
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    print("-" * 80)


print("=" * 80)

# In this script, we demonstrate how to pass input to the chat method:

conversation = [
    {
        "role": "system",
        "content": "You are a helpful assistant"
    },
    {
        "role": "user",
        "content": "Hello"
    },
    {
        "role": "assistant",
        "content": "Hello! How can I assist you today?"
    },
    {
        "role": "user",
        "content": "Write an essay about the importance of higher education.",
    },
]
outputs = llm.chat(conversation,
                   sampling_params=sampling_params,
                   use_tqdm=False)
print_outputs(outputs)

# You can run batch inference with llm.chat API
conversation = [
    {
        "role": "system",
        "content": "You are a helpful assistant"
    },
    {
        "role": "user",
        "content": "Hello"
    },
    {
        "role": "assistant",
        "content": "Hello! How can I assist you today?"
    },
    {
        "role": "user",
        "content": "Write an essay about the importance of higher education.",
    },
]
conversations = [conversation for _ in range(10)]

# We turn on tqdm progress bar to verify it's indeed running batch inference
outputs = llm.chat(messages=conversations,
                   sampling_params=sampling_params,
                   use_tqdm=True)
print_outputs(outputs)

# A chat template can be optionally supplied.
# If not, the model will use its default chat template.

# with open('template_falcon_180b.jinja', "r") as f:
#     chat_template = f.read()

# outputs = llm.chat(
#     conversations,
#     sampling_params=sampling_params,
#     use_tqdm=False,
#     chat_template=chat_template,
# )
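
Note on the commented-out block at the end: llm.chat also accepts a chat_template argument, as those comments show, and the template can just as well be an inline Jinja string instead of a file such as template_falcon_180b.jinja. A minimal runnable sketch of that variant (the template text below is invented purely for illustration and is not a template shipped with any model):

# Hypothetical inline Jinja chat template, purely for illustration.
chat_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "assistant:"
)

outputs = llm.chat(
    conversations,
    sampling_params=sampling_params,
    use_tqdm=False,
    chat_template=chat_template,
)
print_outputs(outputs)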

vllm/attention/backends/flashinfer_mla.py

Lines changed: 15 additions & 4 deletions
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import copy
 from contextlib import contextmanager
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type
@@ -125,7 +126,7 @@ def graph_clone(self, batch_size: int):
         assert self._is_graph_capturing
         state = self.__class__(self.runner)
         state._workspace_buffer = self._graph_decode_workspace_buffer
-        state._decode_wrapper = self._graph_decode_wrapper
+        state._decode_wrapper = copy.copy(self._graph_decode_wrapper)
         return state
 
     def graph_capture_get_metadata_for_batch(
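
The graph_clone hunk above switches from handing every cloned state the same decode-wrapper object to giving each clone a shallow copy of it. A generic sketch of what that changes, using an invented stand-in class rather than the real FlashInfer wrapper API: the clone becomes a distinct object with its own attribute slots, while attributes such as a workspace buffer still reference the same underlying storage.

import copy


class DummyWrapper:
    """Invented stand-in; not the real FlashInfer decode wrapper."""

    def __init__(self, workspace):
        self.workspace = workspace   # large buffer, meant to be shared
        self.last_plan = None        # per-use mutable state


original = DummyWrapper(workspace=bytearray(1024))

aliased = original             # old behavior: every clone is the same object
cloned = copy.copy(original)   # new behavior: a distinct object per clone

cloned.last_plan = {"batch_size": 8}
print(original.last_plan)                      # None: the original is untouched
print(cloned.workspace is original.workspace)  # True: the buffer is still shared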
@@ -197,10 +198,12 @@ def begin_forward(self, model_input):
         # In case of multistep chunked-prefill, there might be prefill requests
         # scheduled while CUDA graph mode is enabled. We don't run graph in that
         # case.
+        print("begin_forward", model_input.input_tokens.shape[0])
         if use_cuda_graph and is_decode:
             batch_size = model_input.input_tokens.shape[0]
             state = (self.runner.graph_runners[model_input.virtual_engine]
                      [batch_size].attn_state)
+            print("choosing decode_wrapper", batch_size)
         model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
         model_input.attn_metadata.begin_forward()
 
@@ -421,9 +424,17 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             self.paged_kv_indptr.extend([self.paged_kv_indptr[-1]] *
                                         cuda_graph_pad_size)
             self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
-            query_start_loc_host = torch.functional.F.pad(
-                query_start_loc_host, (cuda_graph_pad_size + 1, ),
-                value=query_start_loc_host[-1].item())
+
+            print(cuda_graph_pad_size + 1 - query_start_loc_host.shape[0],
+                  cuda_graph_pad_size + 1, query_start_loc_host.shape[0])
+            if cuda_graph_pad_size + 1 > query_start_loc_host.shape[0]:
+                query_start_loc_host = torch.cat(
+                    (query_start_loc_host,
+                     torch.full((cuda_graph_pad_size + 1 -
+                                 query_start_loc_host.shape[0], ),
+                                fill_value=query_start_loc_host[-1].item(),
+                                dtype=torch.int32,
+                                device="cpu")))
 
         if len(self.paged_kv_indptr) > 0:
             # extend to the maximum number of blocks as returned by the
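
The padding change in build() is easier to see in isolation: the new branch tops query_start_loc_host up to a total of cuda_graph_pad_size + 1 entries by repeating the last cumulative offset, and leaves the tensor alone when it is already long enough. A standalone sketch with made-up values:

import torch

# Made-up values purely for illustration.
cuda_graph_pad_size = 6
query_start_loc_host = torch.tensor([0, 3, 7], dtype=torch.int32)

# Mirrors the branch added above: pad only when the tensor is shorter than
# the padded CUDA-graph batch expects, repeating the final offset.
if cuda_graph_pad_size + 1 > query_start_loc_host.shape[0]:
    query_start_loc_host = torch.cat(
        (query_start_loc_host,
         torch.full((cuda_graph_pad_size + 1 - query_start_loc_host.shape[0], ),
                    fill_value=query_start_loc_host[-1].item(),
                    dtype=torch.int32,
                    device="cpu")))

print(query_start_loc_host)
# tensor([0, 3, 7, 7, 7, 7, 7], dtype=torch.int32)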
