Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions examples/offline_inference_audio_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ def run_ultravox(question: str, audio_count: int):
tokenize=False,
add_generation_prompt=True)

llm = LLM(model=model_name,
enforce_eager=True,
enable_chunked_prefill=False,
max_model_len=8192,
limit_mm_per_prompt={"audio": audio_count})
llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
return llm, prompt, stop_token_ids

Expand Down
2 changes: 2 additions & 0 deletions tests/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,7 @@ def make_test_metadata(
return attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
Expand Down Expand Up @@ -914,6 +915,7 @@ def make_test_metadata(
return attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
Expand Down
91 changes: 74 additions & 17 deletions tests/models/decoder_only/audio_language/test_ultravox.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import numpy as np
import pytest
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer, BatchEncoding

from tests.utils import RemoteOpenAIServer
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

Expand All @@ -17,6 +19,13 @@
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
HF_PLACEHOLDER = "<|audio|>"

CHUNKED_PREFILL_KWARGS = {
"enable_chunked_prefill": True,
"max_num_seqs": 2,
# Use a very small limit to exercise chunked prefill.
"max_num_batched_tokens": 16
}


@pytest.fixture(scope="session")
def audio_assets():
Expand All @@ -30,6 +39,26 @@ def audio(request):
return AudioAsset(request.param)


@pytest.fixture(params=({}, CHUNKED_PREFILL_KWARGS))
def server(request, audio_assets):
args = [
"--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
f"--limit-mm-per-prompt=audio={len(audio_assets)}"
] + [
f"--{key.replace('_','-')}={value}"
for key, value in request.param.items()
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server


@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client


def _get_prompt(audio_count, question, placeholder):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
placeholder = f"{placeholder}\n" * audio_count
Expand Down Expand Up @@ -68,8 +97,7 @@ def run_test(
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
**kwargs,
):
"""Inference result should be the same between hf and vllm."""
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
Expand All @@ -79,11 +107,8 @@ def run_test(
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
with vllm_runner(model, dtype=dtype, enforce_eager=True,
**kwargs) as vllm_model:
vllm_outputs_per_audio = [
vllm_model.generate_greedy_logprobs([vllm_prompt],
max_tokens,
Expand Down Expand Up @@ -135,18 +160,16 @@ def run_multi_audio_test(
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
**kwargs,
):
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
limit_mm_per_prompt={
"audio":
max((len(audio) for _, audio in prompts_and_audios))
}) as vllm_model:
},
**kwargs) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
[prompt for prompt, _ in prompts_and_audios],
max_tokens,
Expand All @@ -162,8 +185,9 @@ def run_multi_audio_test(
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
num_logprobs: int, vllm_kwargs: dict) -> None:

vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
Expand All @@ -175,17 +199,18 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
**vllm_kwargs,
)


@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
max_tokens: int, num_logprobs: int,
vllm_kwargs: dict) -> None:

vllm_prompt = _get_prompt(len(audio_assets),
"Describe each of the audios above.",
Expand All @@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
**vllm_kwargs,
)


@pytest.mark.asyncio
async def test_online_inference(client, audio_assets):
"""Exercises online inference with/without chunked prefill enabled."""

messages = [{
"role":
"user",
"content": [
*[{
"type": "audio_url",
"audio_url": {
"url": audio.url
}
} for audio in audio_assets],
{
"type":
"text",
"text":
f"What's happening in these {len(audio_assets)} audio clips?"
},
],
}]

chat_completion = await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
max_tokens=10)

assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
14 changes: 7 additions & 7 deletions tests/multimodal/test_processor_kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import pytest
import torch

from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs
from vllm.inputs.registry import InputRegistry
from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
InputRegistry, token_inputs)
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

Expand Down Expand Up @@ -56,7 +56,7 @@ def custom_dummy_data_factory(self,
num_crops=DEFAULT_NUM_CROPS):
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
return seq_data, None
return DummyData(seq_data, None)

with patch(
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
Expand Down Expand Up @@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
dummy_data = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == expected_seq_count
assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count


@pytest.mark.parametrize(
Expand All @@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
dummy_data = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS


### Test overrides for the max token count per multimodal instance
Expand Down
57 changes: 45 additions & 12 deletions tests/multimodal/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,50 @@ def test_repeat_and_pad_placeholder_tokens(model):
tokenizer = AutoTokenizer.from_pretrained(model)

test_cases = [
("<image>", 2, "<image><image>", [32000, 32000]),
("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
("<image><image>", [3, 2], "<image><image><image><image><image>",
[32000, 32000, 32000, 32000, 32000]),
("Image:<image>Image:<image>!", [3, 2],
"Image:<image><image><image>Image:<image><image>!",
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
]

for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
(
"<image>",
2,
"<image><image>",
[32000, 32000],
[{ "offset": 0, "length": 2 }],
),
(
"<image><image>",
2,
"<image><image><image>",
[32000, 32000, 32000],
[{ "offset": 0, "length": 2 }]),
(
"<image><image>",
[3, 2],
"<image><image><image><image><image>",
[32000, 32000, 32000, 32000, 32000],
[{ "offset": 0, "length": 3 }, { "offset": 3, "length": 2 }],
),
(
"Image:<image>Image:<image>!",
[3, 2],
"Image:<image><image><image>Image:<image><image>!",
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[{ "offset": 2, "length": 3 }, { "offset": 7, "length": 2 }],
),
(
"<image>",
[3, 2],
"<image><image><image>",
[32000, 32000, 32000],
[{ "offset": 0, "length": 3 }],
),
] # yapf: disable

for (
prompt,
repeat_count,
expected_prompt,
expected_token_ids,
expected_ranges,
) in test_cases:
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer=tokenizer,
prompt=prompt,
prompt_token_ids=tokenizer.encode(prompt,
Expand All @@ -113,3 +145,4 @@ def test_repeat_and_pad_placeholder_tokens(model):
)
assert new_prompt == expected_prompt
assert new_token_ids == expected_token_ids
assert ranges == expected_ranges
3 changes: 3 additions & 0 deletions tests/worker/test_model_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def test_model_runner_input():
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
)
model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
Expand Down Expand Up @@ -124,6 +125,7 @@ def test_embedding_model_runner_input():
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
Expand Down Expand Up @@ -174,6 +176,7 @@ def test_multi_step_model_runner_input():
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
)
frozen_model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
Expand Down
11 changes: 11 additions & 0 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import torch

from vllm.multimodal import MultiModalPlaceholderMap

if TYPE_CHECKING:
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
Expand Down Expand Up @@ -108,6 +110,15 @@ class AttentionMetadata:
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor

# The index maps that relate multi-modal embeddings to the corresponding
# placeholders.
#
# N.B. These aren't really related to attention and don't belong on this
# type -- this is just a temporary solution to make them available to
# `model_executable`.
multi_modal_placeholder_index_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]]

@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
Expand Down
3 changes: 3 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ def prefill_metadata(
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
Expand Down Expand Up @@ -243,6 +245,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
Expand Down
Loading