3 changes: 2 additions & 1 deletion tests/models/multimodal/processing/test_common.py
@@ -83,7 +83,7 @@ def _test_processing_correctness(
}

tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type == "mllama":
if model_config.hf_config.model_type in ("mllama", "whisper"):
# For Mllama and Whisper, the tokenizer always adds bos_token at the
# beginning of the prompt by default, causing the hf_processor to output
# incorrect token ids. So we need to use `add_special_tokens=False` here
# to leave bos_token
@@ -173,6 +173,7 @@ def _test_processing_correctness(
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
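Note on the Whisper case above: like Mllama's bos_token, the Whisper tokenizer prepends its own special prefix (and appends <|endoftext|>) when encoding, which would duplicate tokens that the transcription prompt already spells out. A minimal sketch of the effect, assuming a local transformers install (token ids and prefix length are checkpoint-specific):

# Illustrative sketch, not part of this PR: why the test passes
# add_special_tokens=False for Whisper.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large-v3")
prompt = "<|startoftranscript|>"

# Default encoding prepends the tokenizer's forced prefix and appends
# <|endoftext|>, so the prompt's own start token gets duplicated.
with_special = tokenizer(prompt).input_ids
# Without special tokens, only the literal prompt tokens remain, matching
# the token ids produced by the HF processor in the test.
without_special = tokenizer(prompt, add_special_tokens=False).input_ids

assert len(without_special) < len(with_special)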
202 changes: 128 additions & 74 deletions vllm/model_executor/models/whisper.py
@@ -4,15 +4,15 @@
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)

import numpy as np
import torch
from torch import nn
from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
WhisperProcessor)
from transformers.models.whisper.modeling_whisper import sinusoids

from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -25,11 +25,14 @@
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.audio import resample_audio
from vllm.sequence import SequenceData
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs

from .interfaces import SupportsMultiModal, SupportsTranscription
from .utils import AutoWeightsLoader, WeightsMapper, make_layers
@@ -571,72 +574,122 @@ def load_weights(self, weights: Iterable[Tuple[str,
return loaded_params


def get_max_whisper_audio_tokens(ctx: InputContext) -> int:
return ctx.model_config.hf_config.max_source_positions


def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
assert mm_counts["audio"] == 1
num_tokens = get_max_whisper_audio_tokens(ctx)
processor = cached_processor_from_config(ctx.model_config)
chunk_length = processor.feature_extractor.chunk_length
sampling_rate = processor.feature_extractor.sampling_rate
num_samples = chunk_length * sampling_rate
return DummyData(
SequenceData.from_prompt_token_counts((0, num_tokens)),
{"audio": [(np.zeros(num_samples), sampling_rate)]},
)


def input_processor_for_whisper(ctx: InputContext, inputs):
multi_modal_data = inputs["encoder"]["multi_modal_data"]
if isinstance(multi_modal_data["audio"], list):
assert len(multi_modal_data["audio"]) == 1
multi_modal_data["audio"] = multi_modal_data["audio"][0]
# Resample and process audio
audio, orig_sr = multi_modal_data["audio"]
processor = cached_processor_from_config(ctx.model_config)
target_sr = processor.feature_extractor.sampling_rate
audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr)
multi_modal_data["audio"] = (audio, target_sr)
# Pre-allocate placeholder tokens in encoder sequence
num_tokens = get_max_whisper_audio_tokens(ctx)
inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens
return inputs


def input_mapper_for_whisper(
ctx: InputContext,
multi_modal_data: Union[np.ndarray, List[np.ndarray]],
) -> MultiModalKwargs:
if not isinstance(multi_modal_data, list):
multi_modal_data = [multi_modal_data]

assert len(multi_modal_data) == 1

if len(multi_modal_data) == 0:
return MultiModalKwargs()

processor = cached_processor_from_config(ctx.model_config)
sampling_rate = processor.feature_extractor.sampling_rate

audios = [audio for audio, _ in multi_modal_data]

kwargs = processor(audios,
sampling_rate=sampling_rate,
return_tensors="pt")
kwargs["input_features"] = kwargs["input_features"].squeeze(0).to(
ctx.model_config.dtype)

return MultiModalKwargs(kwargs)


@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper)
@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_whisper_audio_tokens)
class WhisperProcessingInfo(BaseProcessingInfo):

def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig)

def get_hf_processor(self,
sampling_rate: Optional[int] = None
) -> WhisperProcessor:
return self.ctx.get_hf_processor(WhisperProcessor)

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": 1}

def get_feature_extractor(self) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor()
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor

def get_max_audio_tokens(self) -> int:
return self.get_hf_config().max_source_positions

def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"audio": self.get_max_audio_tokens()}


class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):

def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self.info.get_feature_extractor()

sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)

mm_data = {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}

return ProcessorInputs(
prompt_text="<|startoftranscript|>" * num_audios,
mm_data=mm_data,
)


class WhisperMultiModalProcessor(
EncDecMultiModalProcessor[WhisperProcessingInfo]):

def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

def create_encoder_prompt(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
return [0]

def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_data = dict(audio=mm_data.pop("audios"))
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
if "labels" in processed_outputs:
processed_outputs["input_ids"] = processed_outputs.pop("labels")
return processed_outputs

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(input_features=MultiModalFieldConfig.batched("audio"))

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
num_tokens = self.info.get_max_audio_tokens()
return [
PromptReplacement(
modality="audio",
target=[0],
replacement=[0] * num_tokens,
)
]


@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor,
info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder)
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal):
packed_modules_mapping = {
Expand Down Expand Up @@ -724,7 +777,8 @@ def _parse_and_validate_audio_input(
if not isinstance(input_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(input_features)}")
input_features = [feat.to(self.dtype) for feat in input_features]
input_features = torch.cat(
[feat.to(self.dtype) for feat in input_features])

return WhisperAudioInputs(input_features=input_features)

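For context, here is a hypothetical end-to-end invocation of the new processor path. This is a sketch under assumptions, not code from this PR: the single-dict prompt form and the librosa loader are illustrative, and exact field names can differ between vLLM versions.

# Hypothetical usage sketch: exercising the registered
# WhisperMultiModalProcessor via vLLM's offline API.
import librosa  # assumption: any loader yielding a float array works

from vllm import LLM, SamplingParams

llm = LLM(model="openai/whisper-large-v3")
audio, sr = librosa.load("sample.wav", sr=16000)

# Only the decoder prompt is supplied by the user; create_encoder_prompt()
# emits the single [0] placeholder, which _get_prompt_replacements() then
# expands to max_source_positions (1500) encoder placeholder tokens.
prompt = {
    "prompt": "<|startoftranscript|>",
    "multi_modal_data": {"audio": (audio, sr)},
}

outputs = llm.generate(prompt, SamplingParams(temperature=0, max_tokens=200))
print(outputs[0].outputs[0].text)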
9 changes: 6 additions & 3 deletions vllm/multimodal/profiling.py
@@ -166,7 +166,9 @@ def get_dummy_data(
f"({set(mm_max_tokens_per_item.keys())})")

mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
prompt_token_ids = (
mm_inputs["prompt_token_ids"] if not is_encoder_data else
mm_inputs["encoder_prompt_token_ids"]) # type: ignore
placeholders_by_modality = mm_inputs["mm_placeholders"]

total_placeholders_by_modality = {
@@ -188,7 +190,7 @@

# V0 does not support chunked prefill.
if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
if total_len > seq_len:
if total_len > seq_len and not is_encoder_data:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
@@ -201,7 +203,8 @@
total_placeholders_by_modality)

return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
seq_data=SequenceData.from_prompt_token_counts(
(0, max(seq_len, total_len))),
multi_modal_data=None,
multi_modal_placeholders=None,
)
Expand Down