
Commit bb8c8f7

SolitaryThinker authored and LeiWang1999 committed
[Bugfix] fix flashinfer cudagraph capture for PP (vllm-project#6708)
Signed-off-by: LeiWang1999 <[email protected]>
1 parent c1523f3 commit bb8c8f7

File tree

2 files changed: +31 -7 lines changed


tests/distributed/test_pipeline_parallel.py

Lines changed: 24 additions & 0 deletions
@@ -61,3 +61,27 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
         tp_args.append("--enforce-eager")
 
     compare_two_settings(MODEL_NAME, pp_args, tp_args)
+
+
+@pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [
+    (2, "JackFram/llama-160m"),
+])
+@pytest.mark.parametrize("ATTN_BACKEND", [
+    "FLASH_ATTN",
+    "FLASHINFER",
+])
+def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
+    cudagraph_args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "float16",
+        "--pipeline-parallel-size",
+        str(PP_SIZE),
+        "--distributed-executor-backend",
+        "ray",
+    ]
+    os.environ["VLLM_ATTENTION_BACKEND"] = ATTN_BACKEND
+
+    eager_args = cudagraph_args + ["--enforce-eager"]
+
+    compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)

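The new test runs the same model twice per attention backend via compare_two_settings: once with --enforce-eager and once with CUDA graphs enabled, under pipeline-parallel size 2, and checks that the outputs match. For readers who want to reproduce the captured configuration outside the test harness, a minimal sketch follows; the LLM keyword arguments mirror the CLI flags above and are an assumption about the installed vLLM release, not part of this commit.

import os

# Select the attention backend before the engine is built,
# mirroring the test's use of VLLM_ATTENTION_BACKEND.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # or "FLASH_ATTN"

from vllm import LLM, SamplingParams

# Keyword arguments mirror the test's CLI flags; exact names may
# differ between vLLM releases.
llm = LLM(
    model="JackFram/llama-160m",
    dtype="float16",
    pipeline_parallel_size=2,
    distributed_executor_backend="ray",
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)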
vllm/worker/model_runner.py

Lines changed: 7 additions & 7 deletions
@@ -1040,9 +1040,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                     self.parallel_config.pipeline_parallel_size):
                 for batch_size in reversed(batch_size_capture_list):
                     if self.attn_backend.get_name() == "flashinfer":
-                        indptr_buffer = indptr_buffer[:batch_size + 1]
-                        last_page_len_buffer = last_page_len_buffer[:
-                                                                    batch_size]
+                        _indptr_buffer = indptr_buffer[:batch_size + 1]
+                        _last_page_len_buffer = last_page_len_buffer[:
+                                                                     batch_size]
 
                         num_qo_heads = (
                             self.model_config.get_num_attention_heads(
@@ -1055,8 +1055,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                             use_tensor_cores = False
                         decode_wrapper = \
                             CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
-                            decode_workspace_buffer, indptr_buffer,
-                            indices_buffer, last_page_len_buffer, "NHD",
+                            decode_workspace_buffer, _indptr_buffer,
+                            indices_buffer, _last_page_len_buffer, "NHD",
                             use_tensor_cores)
                         kv_cache_dtype = get_kv_cache_torch_dtype(
                             self.kv_cache_dtype, self.model_config.dtype)
@@ -1131,10 +1131,10 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                         self.model, self.attn_backend.get_name())
 
                     if self.attn_backend.get_name() == "flashinfer":
-                        graph_runner.flashinfer_indptr_buffer = indptr_buffer
+                        graph_runner.flashinfer_indptr_buffer = _indptr_buffer
                         graph_runner.flashinfer_indices_buffer = indices_buffer
                         graph_runner.flashinfer_last_page_len_buffer = \
-                            last_page_len_buffer
+                            _last_page_len_buffer
                         graph_runner.flashinfer_decode_workspace_buffer = \
                             decode_workspace_buffer
                         graph_runner.flashinfer_decode_wrapper = \

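Why the rename matters: capture_model now loops over every virtual engine (pipeline stage) as well as every batch size. In the old code the per-batch slice was bound back to indptr_buffer and last_page_len_buffer themselves, so once the first virtual engine finished, those names pointed at tensors already truncated to the smallest batch size, and the next pipeline stage captured graphs against buffers that were too short. Slicing into fresh names (_indptr_buffer, _last_page_len_buffer) keeps the full-size buffers intact across iterations. A standalone sketch of the pitfall, using made-up sizes rather than vLLM's real ones:

import torch

# Stand-in for indptr_buffer: sized for the largest batch (8) plus one.
indptr_buffer = torch.zeros(9, dtype=torch.int32)

for virtual_engine in range(2):              # pipeline_parallel_size = 2
    for batch_size in reversed([1, 2, 4, 8]):
        # Buggy pattern: rebind the shared name to its own slice.
        indptr_buffer = indptr_buffer[:batch_size + 1]
        print(virtual_engine, batch_size, indptr_buffer.shape[0])

# After the first virtual engine the buffer is down to length 2, so the
# second stage gets a length-2 slice even for batch_size 8.
# The fix slices into a fresh name instead and leaves the original intact:
#     _indptr_buffer = indptr_buffer[:batch_size + 1]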