
Commit 074cd45

Isotr0py authored and shreyankg committed
[LMM] Implement merged multimodal processor for whisper (vllm-project#13278)
1 parent f54f2d4 commit 074cd45
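
For context, here is a rough end-to-end sketch of how Whisper transcription is driven through vLLM's offline entrypoint once the merged processor is registered. The engine arguments, prompt format, and the AudioAsset helper are illustrative assumptions based on vLLM's audio examples, not part of this commit.

# Hypothetical usage sketch: offline transcription with openai/whisper-large-v3.
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

llm = LLM(
    model="openai/whisper-large-v3",
    max_model_len=448,                  # decoder context length for Whisper
    limit_mm_per_prompt={"audio": 1},   # matches get_supported_mm_limits()
)

prompt = {
    "prompt": "<|startoftranscript|>",
    "multi_modal_data": {
        # (waveform, sampling_rate) tuple; AudioAsset is a bundled test clip.
        "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
    },
}

outputs = llm.generate(prompt, SamplingParams(temperature=0, max_tokens=200))
print(outputs[0].outputs[0].text)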

File tree: 4 files changed (+150, -83 lines)


tests/models/multimodal/processing/test_common.py

Lines changed: 6 additions & 5 deletions
@@ -83,11 +83,11 @@ def _test_processing_correctness(
     }
 
     tokenizer_encode_kwargs = {}
-    if model_config.hf_config.model_type == "mllama":
-        # For Mllama, tokenizer will always add bos_token at the beginning of
-        # prompt by default, causing hf_processor outputs incorrect token ids.
-        # So we need use `add_special_tokens=False` here to leave bos_token
-        # to be added by the processor.
+    if model_config.hf_config.model_type in ("mllama", "whisper"):
+        # For some encoder-decoder models, tokenizer will always add bos_token
+        # at the beginning of prompt by default, causing hf_processor outputs
+        # incorrect token ids. So we need use `add_special_tokens=False` here
+        # to leave bos_token to be added by the processor.
         tokenizer_encode_kwargs = {"add_special_tokens": False}
 
     for batch_idx in range(num_batches):
@@ -173,6 +173,7 @@ def _test_processing_correctness(
     "Qwen/Qwen2.5-VL-3B-Instruct",
     "Qwen/Qwen2-Audio-7B-Instruct",
     "fixie-ai/ultravox-v0_5-llama-3_2-1b",
+    "openai/whisper-large-v3",
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])
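
The add_special_tokens=False workaround above can be reproduced directly with the HF tokenizer; a minimal sketch, assuming the openai/whisper-large-v3 tokenizer is available locally or from the Hub:

# The Whisper tokenizer prepends its own special-token prefix by default,
# which would duplicate the tokens the HF processor already inserts.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large-v3")
text = "<|startoftranscript|>"

with_special = tokenizer(text).input_ids
without_special = tokenizer(text, add_special_tokens=False).input_ids

# `with_special` carries extra prefix/suffix ids added by the tokenizer itself;
# `without_special` contains only the ids of the literal prompt text.
print(with_special)
print(without_special)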

vllm/model_executor/models/whisper.py

Lines changed: 132 additions & 74 deletions
@@ -4,15 +4,15 @@
 from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
                     Union)
 
-import numpy as np
 import torch
 from torch import nn
+from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
+                          WhisperProcessor)
 from transformers.models.whisper.modeling_whisper import sinusoids
 
 from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -25,11 +25,14 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
-                             NestedTensors)
-from vllm.multimodal.audio import resample_audio
-from vllm.sequence import SequenceData
-from vllm.transformers_utils.processor import cached_processor_from_config
+from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
+                                   MultiModalDataParser)
+from vllm.multimodal.processing import (BaseProcessingInfo,
+                                        EncDecMultiModalProcessor,
+                                        PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 
 from .interfaces import SupportsMultiModal, SupportsTranscription
 from .utils import AutoWeightsLoader, WeightsMapper, make_layers
@@ -571,72 +574,126 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params
 
 
-def get_max_whisper_audio_tokens(ctx: InputContext) -> int:
-    return ctx.model_config.hf_config.max_source_positions
-
-
-def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int,
-                                   mm_counts: Mapping[str, int]):
-    assert mm_counts["audio"] == 1
-    num_tokens = get_max_whisper_audio_tokens(ctx)
-    processor = cached_processor_from_config(ctx.model_config)
-    chunk_length = processor.feature_extractor.chunk_length
-    sampling_rate = processor.feature_extractor.sampling_rate
-    num_samples = chunk_length * sampling_rate
-    return DummyData(
-        SequenceData.from_prompt_token_counts((0, num_tokens)),
-        {"audio": [(np.zeros(num_samples), sampling_rate)]},
-    )
-
-
-def input_processor_for_whisper(ctx: InputContext, inputs):
-    multi_modal_data = inputs["encoder"]["multi_modal_data"]
-    if isinstance(multi_modal_data["audio"], list):
-        assert len(multi_modal_data["audio"]) == 1
-        multi_modal_data["audio"] = multi_modal_data["audio"][0]
-    # Resample and process audio
-    audio, orig_sr = multi_modal_data["audio"]
-    processor = cached_processor_from_config(ctx.model_config)
-    target_sr = processor.feature_extractor.sampling_rate
-    audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr)
-    multi_modal_data["audio"] = (audio, target_sr)
-    # Pre-allocate placeholder tokens in encoder sequence
-    num_tokens = get_max_whisper_audio_tokens(ctx)
-    inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens
-    return inputs
-
-
-def input_mapper_for_whisper(
-    ctx: InputContext,
-    multi_modal_data: Union[np.ndarray, List[np.ndarray]],
-) -> MultiModalKwargs:
-    if not isinstance(multi_modal_data, list):
-        multi_modal_data = [multi_modal_data]
-
-    assert len(multi_modal_data) == 1
-
-    if len(multi_modal_data) == 0:
-        return MultiModalKwargs()
-
-    processor = cached_processor_from_config(ctx.model_config)
-    sampling_rate = processor.feature_extractor.sampling_rate
-
-    audios = [audio for audio, _ in multi_modal_data]
-
-    kwargs = processor(audios,
-                       sampling_rate=sampling_rate,
-                       return_tensors="pt")
-    kwargs["input_features"] = kwargs["input_features"].squeeze(0).to(
-        ctx.model_config.dtype)
-
-    return MultiModalKwargs(kwargs)
-
-
-@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper)
-@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
-@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
-    "audio", get_max_whisper_audio_tokens)
+class WhisperProcessingInfo(BaseProcessingInfo):
+
+    def get_hf_config(self) -> WhisperConfig:
+        return self.ctx.get_hf_config(WhisperConfig)
+
+    def get_hf_processor(self,
+                         sampling_rate: Optional[int] = None
+                         ) -> WhisperProcessor:
+        return self.ctx.get_hf_processor(WhisperProcessor)
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"audio": 1}
+
+    def get_feature_extractor(self) -> WhisperFeatureExtractor:
+        hf_processor = self.get_hf_processor()
+        feature_extractor = hf_processor.feature_extractor  # type: ignore
+        assert isinstance(feature_extractor, WhisperFeatureExtractor)
+        return feature_extractor
+
+    def get_max_audio_tokens(self) -> int:
+        return self.get_hf_config().max_source_positions
+
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        return {"audio": self.get_max_audio_tokens()}
+
+
+class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        feature_extractor = self.info.get_feature_extractor()
+
+        sampling_rate = feature_extractor.sampling_rate
+        audio_len = feature_extractor.chunk_length * sampling_rate
+        num_audios = mm_counts.get("audio", 0)
+
+        mm_data = {
+            "audio":
+            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
+        }
+
+        return ProcessorInputs(
+            prompt_text="<|startoftranscript|>" * num_audios,
+            mm_data=mm_data,
+        )
+
+
+class WhisperMultiModalProcessor(
+        EncDecMultiModalProcessor[WhisperProcessingInfo]):
+
+    def _get_data_parser(self) -> MultiModalDataParser:
+        feature_extractor = self.info.get_feature_extractor()
+        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
+
+    def create_encoder_prompt(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+    ) -> Union[str, list[int]]:
+        # Strictly speaking, whisper encoder only accept audio features.
+        # We create a dummy encoder prompt here which will be padded to
+        # num_audio_tokens. So that we can create dummy data from this
+        # for encoder profiling.
+        return [0]
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        if mm_data:
+            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
+            mm_data = dict(audio=mm_data.pop("audios"))
+            mm_kwargs = dict(
+                **mm_kwargs,
+                sampling_rate=feature_extractor.sampling_rate,
+            )
+        processed_outputs = super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+        if "labels" in processed_outputs:
+            processed_outputs["input_ids"] = processed_outputs.pop("labels")
+        return processed_outputs
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(input_features=MultiModalFieldConfig.batched("audio"))
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        num_tokens = self.info.get_max_audio_tokens()
+        return [
+            PromptReplacement(
+                modality="audio",
+                target=[0],
+                replacement=[0] * num_tokens,
+            )
+        ]
+
+
+@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor,
+                                        info=WhisperProcessingInfo,
+                                        dummy_inputs=WhisperDummyInputsBuilder)
 class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                                       SupportsMultiModal):
     packed_modules_mapping = {
@@ -724,7 +781,8 @@ def _parse_and_validate_audio_input(
         if not isinstance(input_features, (torch.Tensor, list)):
             raise ValueError("Incorrect type of audio features. "
                              f"Got type: {type(input_features)}")
-        input_features = [feat.to(self.dtype) for feat in input_features]
+        input_features = torch.cat(
+            [feat.to(self.dtype) for feat in input_features])
 
         return WhisperAudioInputs(input_features=input_features)
 
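To see roughly what _call_hf_processor hands back on the audio branch, here is a small sketch that calls the HF WhisperProcessor the same way: with the extractor's native sampling rate and a full 30-second chunk. The exact mel-bin count is model-dependent; 128 bins for large-v3 is an assumption worth verifying.

import numpy as np
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
feature_extractor = processor.feature_extractor

sampling_rate = feature_extractor.sampling_rate               # 16 kHz
audio = np.zeros(feature_extractor.chunk_length * sampling_rate,
                 dtype=np.float32)                            # one 30 s chunk

out = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
# A batched log-mel spectrogram, e.g. (1, 128, 3000) for large-v3, which
# _get_mm_fields_config above maps to one "audio" item per batch entry.
print(out["input_features"].shape)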

vllm/multimodal/processing.py

Lines changed: 4 additions & 1 deletion
@@ -1297,7 +1297,10 @@ def create_encoder_prompt(
         prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
     ) -> Union[str, list[int]]:
-        """Create input prompt for the encoder."""
+        """
+        Create input prompt for the encoder. HF processor will be applied on
+        this prompt during profiling and generation.
+        """
         raise NotImplementedError
 
     def apply(

vllm/multimodal/profiling.py

Lines changed: 8 additions & 3 deletions
@@ -166,8 +166,12 @@ def get_dummy_data(
                 f"({set(mm_max_tokens_per_item.keys())})")
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
-        prompt_token_ids = mm_inputs["prompt_token_ids"]
         placeholders_by_modality = mm_inputs["mm_placeholders"]
+        # For encoder-decoder models, use encoder prompt token ids instead of
+        # decoder prompt to construct dummy seq_data for encoder profiling.
+        prompt_token_ids = (
+            mm_inputs["prompt_token_ids"] if not is_encoder_data else
+            mm_inputs["encoder_prompt_token_ids"])  # type: ignore
 
         total_placeholders_by_modality = {
             modality: sum(item["length"] for item in placeholders)
@@ -188,7 +192,7 @@
 
         # V0 does not support chunked prefill.
         if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
-            if total_len > seq_len:
+            if total_len > seq_len and not is_encoder_data:
                 logger.warning(
                     "The context length (%d) of the model is too short "
                     "to hold the multi-modal embeddings in the worst case "
@@ -201,7 +205,8 @@
                     total_placeholders_by_modality)
 
             return DummyData(
-                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
+                seq_data=SequenceData.from_prompt_token_counts(
+                    (0, max(seq_len, total_len))),
                 multi_modal_data=None,
                 multi_modal_placeholders=None,
             )
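
A toy recomputation of the new dummy-length rule follows; the 1500/448 figures are whisper-large-v3's max_source_positions and a typical decoder max_model_len, quoted from the HF config rather than from this diff.

# Toy illustration, not vLLM code: with encoder profiling, total_len comes from
# the encoder prompt (max_source_positions placeholder tokens), which can exceed
# the decoder seq_len, so the dummy sequence is sized to the larger of the two.
def dummy_seq_data_len(seq_len: int, total_len: int) -> int:
    return max(seq_len, total_len)

max_source_positions = 1500   # whisper-large-v3 encoder positions
decoder_seq_len = 448         # typical max_model_len for Whisper decoding

print(dummy_seq_data_len(decoder_seq_len, max_source_positions))  # -> 1500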
