
Commit 5053bc7

tomeras91 authored and jimpang committed
[Bugfix] Mamba cache Cuda Graph padding (vllm-project#6214)
1 parent 91b5255 commit 5053bc7

File tree

2 files changed: +30 -2 lines changed


tests/models/test_jamba.py

Lines changed: 28 additions & 0 deletions

@@ -1,5 +1,7 @@
 import pytest
 
+from vllm.worker.model_runner import _get_graph_batch_size
+
 MODELS = ["ai21labs/Jamba-tiny-random"]
 
 
@@ -32,6 +34,32 @@ def test_models(
         f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [20])
+def test_mamba_cache_cg_padding(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    # This test is for verifying that mamba cache is padded to CG captured
+    # batch size. If it's not, a torch RuntimeError will be raised because
+    # tensor dimensions aren't compatible
+    while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
+        example_prompts.append(example_prompts[0])
+
+    try:
+        with vllm_runner(model, dtype=dtype) as vllm_model:
+            vllm_model.generate_greedy(example_prompts, max_tokens)
+    except RuntimeError:
+        pytest.fail(
+            "Couldn't run batch size which is not equal to a Cuda Graph "
+            "captured batch size. "
+            "Could be related to mamba cache not padded correctly")
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 def test_state_cleanup(
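The new test leans on _get_graph_batch_size to pick a prompt count that is not itself a CUDA-graph captured batch size. For reference, a minimal sketch of that padding rule, assuming the 1/2/4-then-multiples-of-8 capture sizes vLLM's model runner used around this release; the helper name graph_batch_size and the alignment constant are illustrative, and exact values may differ between versions.

# Sketch of the CUDA-graph batch-size padding rule (illustrative, not the
# vLLM source). Small batches pad to 1, 2 or 4; larger ones round up to a
# multiple of the capture alignment.
_BATCH_SIZE_ALIGNMENT = 8  # assumed capture alignment


def graph_batch_size(batch_size: int) -> int:
    """Return the captured batch size that `batch_size` would pad up to."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT) * _BATCH_SIZE_ALIGNMENT


# e.g. 7 prompts would run under a graph captured for batch size 8, which is
# exactly the mismatch the test above provokes on purpose.
assert graph_batch_size(7) == 8
assert graph_batch_size(3) == 4
assert graph_batch_size(9) == 16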

vllm/model_executor/models/jamba.py

Lines changed: 2 additions & 2 deletions

@@ -788,12 +788,12 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
             key in kwargs
             for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-        batch_size = len(request_ids_to_seq_ids)
+        cg_batch_size = input_buffers['input_ids'].shape[0]
         (
             current_mamba_cache,
             indices,
         ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  batch_size)
+                                                  cg_batch_size)
         self.current_indices = indices
         finished_requests_ids = kwargs["finished_requests_ids"]
         self._release_mamba_cache(finished_requests_ids)
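The fix itself is the two-line change above: the per-step mamba cache slice is now sized by the batch the CUDA graph was captured with (input_buffers['input_ids'].shape[0]) rather than by the number of live requests. A hedged sketch of the failure mode follows, using hypothetical shapes; the conv-state dimensions are made up and do not reflect Jamba's real cache layout.

import torch

# Why the cache must be padded to the captured batch size: copying a
# 7-row cache into buffers captured for batch 8 raises a RuntimeError
# about mismatched tensor sizes.
cg_batch_size = 8             # batch size the CUDA graph was captured with
num_requests = 7              # live requests in this decode step
conv_dim, conv_width = 16, 4  # hypothetical mamba conv-state dimensions

captured_buffer = torch.zeros(cg_batch_size, conv_dim, conv_width)
unpadded_cache = torch.randn(num_requests, conv_dim, conv_width)

try:
    captured_buffer.copy_(unpadded_cache)  # sized by len(requests): fails
except RuntimeError as err:
    print("unpadded cache rejected:", err)

padded_cache = torch.zeros(cg_batch_size, conv_dim, conv_width)
padded_cache[:num_requests] = unpadded_cache
captured_buffer.copy_(padded_cache)        # sized by cg_batch_size: works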
