4 | 4 | from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, |
5 | 5 | Union) |
6 | 6 |
7 | | -import numpy as np |
8 | 7 | import torch |
9 | 8 | from torch import nn |
| 9 | +from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor, |
| 10 | + WhisperProcessor) |
10 | 11 | from transformers.models.whisper.modeling_whisper import sinusoids |
11 | 12 |
12 | 13 | from vllm.attention import Attention, AttentionMetadata, AttentionType |
13 | 14 | from vllm.config import CacheConfig, VllmConfig |
14 | 15 | from vllm.distributed import get_tensor_model_parallel_world_size |
15 | | -from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext |
16 | 16 | from vllm.logger import init_logger |
17 | 17 | from vllm.model_executor.layers.activation import get_act_fn |
18 | 18 | from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
25 | 25 | from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead |
26 | 26 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
27 | 27 | from vllm.model_executor.sampling_metadata import SamplingMetadata |
28 | | -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, |
29 | | - NestedTensors) |
30 | | -from vllm.multimodal.audio import resample_audio |
31 | | -from vllm.sequence import SequenceData |
32 | | -from vllm.transformers_utils.processor import cached_processor_from_config |
| 28 | +from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors |
| 29 | +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs |
| 30 | +from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems, |
| 31 | + MultiModalDataParser) |
| 32 | +from vllm.multimodal.processing import (BaseProcessingInfo, |
| 33 | + EncDecMultiModalProcessor, |
| 34 | + PromptReplacement) |
| 35 | +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs |
33 | 36 |
34 | 37 | from .interfaces import SupportsMultiModal, SupportsTranscription |
35 | 38 | from .utils import AutoWeightsLoader, WeightsMapper, make_layers |
@@ -571,72 +574,126 @@ def load_weights(self, weights: Iterable[Tuple[str, |
571 | 574 | return loaded_params |
572 | 575 |
573 | 576 |
574 | | -def get_max_whisper_audio_tokens(ctx: InputContext) -> int: |
575 | | - return ctx.model_config.hf_config.max_source_positions |
576 | | - |
577 | | - |
578 | | -def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, |
579 | | - mm_counts: Mapping[str, int]): |
580 | | - assert mm_counts["audio"] == 1 |
581 | | - num_tokens = get_max_whisper_audio_tokens(ctx) |
582 | | - processor = cached_processor_from_config(ctx.model_config) |
583 | | - chunk_length = processor.feature_extractor.chunk_length |
584 | | - sampling_rate = processor.feature_extractor.sampling_rate |
585 | | - num_samples = chunk_length * sampling_rate |
586 | | - return DummyData( |
587 | | - SequenceData.from_prompt_token_counts((0, num_tokens)), |
588 | | - {"audio": [(np.zeros(num_samples), sampling_rate)]}, |
589 | | - ) |
590 | | - |
591 | | - |
592 | | -def input_processor_for_whisper(ctx: InputContext, inputs): |
593 | | - multi_modal_data = inputs["encoder"]["multi_modal_data"] |
594 | | - if isinstance(multi_modal_data["audio"], list): |
595 | | - assert len(multi_modal_data["audio"]) == 1 |
596 | | - multi_modal_data["audio"] = multi_modal_data["audio"][0] |
597 | | - # Resample and process audio |
598 | | - audio, orig_sr = multi_modal_data["audio"] |
599 | | - processor = cached_processor_from_config(ctx.model_config) |
600 | | - target_sr = processor.feature_extractor.sampling_rate |
601 | | - audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) |
602 | | - multi_modal_data["audio"] = (audio, target_sr) |
603 | | - # Pre-allocate placeholder tokens in encoder sequence |
604 | | - num_tokens = get_max_whisper_audio_tokens(ctx) |
605 | | - inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens |
606 | | - return inputs |
607 | | - |
608 | | - |
609 | | -def input_mapper_for_whisper( |
610 | | - ctx: InputContext, |
611 | | - multi_modal_data: Union[np.ndarray, List[np.ndarray]], |
612 | | -) -> MultiModalKwargs: |
613 | | - if not isinstance(multi_modal_data, list): |
614 | | - multi_modal_data = [multi_modal_data] |
615 | | - |
616 | | - assert len(multi_modal_data) == 1 |
617 | | - |
618 | | - if len(multi_modal_data) == 0: |
619 | | - return MultiModalKwargs() |
620 | | - |
621 | | - processor = cached_processor_from_config(ctx.model_config) |
622 | | - sampling_rate = processor.feature_extractor.sampling_rate |
623 | | - |
624 | | - audios = [audio for audio, _ in multi_modal_data] |
625 | | - |
626 | | - kwargs = processor(audios, |
627 | | - sampling_rate=sampling_rate, |
628 | | - return_tensors="pt") |
629 | | - kwargs["input_features"] = kwargs["input_features"].squeeze(0).to( |
630 | | - ctx.model_config.dtype) |
631 | | - |
632 | | - return MultiModalKwargs(kwargs) |
633 | | - |
634 | | - |
635 | | -@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper) |
636 | | -@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper) |
637 | | -@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) |
638 | | -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( |
639 | | - "audio", get_max_whisper_audio_tokens) |
| 577 | +class WhisperProcessingInfo(BaseProcessingInfo): |
| 578 | + |
| 579 | + def get_hf_config(self) -> WhisperConfig: |
| 580 | + return self.ctx.get_hf_config(WhisperConfig) |
| 581 | + |
| 582 | + def get_hf_processor(self, |
| 583 | + sampling_rate: Optional[int] = None |
| 584 | + ) -> WhisperProcessor: |
| 585 | + return self.ctx.get_hf_processor(WhisperProcessor) |
| 586 | + |
| 587 | + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: |
| 588 | + return {"audio": 1} |
| 589 | + |
| 590 | + def get_feature_extractor(self) -> WhisperFeatureExtractor: |
| 591 | + hf_processor = self.get_hf_processor() |
| 592 | + feature_extractor = hf_processor.feature_extractor # type: ignore |
| 593 | + assert isinstance(feature_extractor, WhisperFeatureExtractor) |
| 594 | + return feature_extractor |
| 595 | + |
| 596 | + def get_max_audio_tokens(self) -> int: |
| 597 | + return self.get_hf_config().max_source_positions |
| 598 | + |
| 599 | + def get_mm_max_tokens_per_item( |
| 600 | + self, |
| 601 | + seq_len: int, |
| 602 | + mm_counts: Mapping[str, int], |
| 603 | + ) -> Mapping[str, int]: |
| 604 | + return {"audio": self.get_max_audio_tokens()} |
| 605 | + |
| 606 | + |
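For context, both numbers used above come straight from the Hugging Face artifacts: the per-item token budget is the encoder's `max_source_positions`, and the dummy audio length is `chunk_length * sampling_rate`. A quick sanity-check sketch (the `openai/whisper-small` checkpoint name and the printed values are assumptions about a stock Whisper checkpoint, not something this diff pins down):

```python
from transformers import WhisperConfig, WhisperFeatureExtractor

# Checkpoint name is illustrative; any released Whisper checkpoint behaves the same.
config = WhisperConfig.from_pretrained("openai/whisper-small")
extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")

# Per-audio placeholder budget returned by get_max_audio_tokens().
print(config.max_source_positions)  # 1500 on the released checkpoints

# Samples in one dummy clip: chunk_length (seconds) * sampling_rate (Hz).
print(extractor.chunk_length * extractor.sampling_rate)  # 30 * 16000 = 480000
```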
| 607 | +class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): |
| 608 | + |
| 609 | + def get_dummy_processor_inputs( |
| 610 | + self, |
| 611 | + seq_len: int, |
| 612 | + mm_counts: Mapping[str, int], |
| 613 | + ) -> ProcessorInputs: |
| 614 | + feature_extractor = self.info.get_feature_extractor() |
| 615 | + |
| 616 | + sampling_rate = feature_extractor.sampling_rate |
| 617 | + audio_len = feature_extractor.chunk_length * sampling_rate |
| 618 | + num_audios = mm_counts.get("audio", 0) |
| 619 | + |
| 620 | + mm_data = { |
| 621 | + "audio": |
| 622 | + self._get_dummy_audios(length=audio_len, num_audios=num_audios) |
| 623 | + } |
| 624 | + |
| 625 | + return ProcessorInputs( |
| 626 | + prompt_text="<|startoftranscript|>" * num_audios, |
| 627 | + mm_data=mm_data, |
| 628 | + ) |
| 629 | + |
| 630 | + |
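The dummy builder only has to hand the memory profiler something shaped like a worst-case request: one full-length clip per audio slot plus the transcription start token. A minimal stand-in for what `_get_dummy_audios` is assumed to produce (silent float arrays of the requested length):

```python
import numpy as np

# One silent 30 s clip at 16 kHz: chunk_length * sampling_rate samples.
audio_len = 30 * 16000
dummy_audio = np.zeros(audio_len, dtype=np.float32)

# The profiler then processes the same structure a real request would carry:
# prompt_text="<|startoftranscript|>", mm_data={"audio": [dummy_audio]}.
```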
| 631 | +class WhisperMultiModalProcessor( |
| 632 | + EncDecMultiModalProcessor[WhisperProcessingInfo]): |
| 633 | + |
| 634 | + def _get_data_parser(self) -> MultiModalDataParser: |
| 635 | + feature_extractor = self.info.get_feature_extractor() |
| 636 | + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) |
| 637 | + |
| 638 | + def create_encoder_prompt( |
| 639 | + self, |
| 640 | + prompt: Union[str, list[int]], |
| 641 | + mm_data: MultiModalDataDict, |
| 642 | + ) -> Union[str, list[int]]: |
| 643 | +        # Strictly speaking, the Whisper encoder only accepts audio
| 644 | +        # features. We create a dummy encoder prompt here, which will be
| 645 | +        # padded to num_audio_tokens, so that we can create dummy data
| 646 | +        # from it for encoder profiling.
| 647 | + return [0] |
| 648 | + |
| 649 | + def _call_hf_processor( |
| 650 | + self, |
| 651 | + prompt: str, |
| 652 | + mm_data: Mapping[str, object], |
| 653 | + mm_kwargs: Mapping[str, object], |
| 654 | + ) -> BatchFeature: |
| 655 | + if mm_data: |
| 656 | + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) |
| 657 | + mm_data = dict(audio=mm_data.pop("audios")) |
| 658 | + mm_kwargs = dict( |
| 659 | + **mm_kwargs, |
| 660 | + sampling_rate=feature_extractor.sampling_rate, |
| 661 | + ) |
| 662 | + processed_outputs = super()._call_hf_processor( |
| 663 | + prompt=prompt, |
| 664 | + mm_data=mm_data, |
| 665 | + mm_kwargs=mm_kwargs, |
| 666 | + ) |
| 667 | + if "labels" in processed_outputs: |
| 668 | + processed_outputs["input_ids"] = processed_outputs.pop("labels") |
| 669 | + return processed_outputs |
| 670 | + |
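The `labels` -> `input_ids` rename above reflects how `WhisperProcessor` behaves when given text: it tokenizes the decoder prompt under the training-style `labels` key, while the vLLM processing pipeline expects `input_ids`. A rough sketch of the underlying HF call (checkpoint name is illustrative):

```python
import numpy as np
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small")
audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz

out = processor(audio=audio,
                sampling_rate=16000,
                text="<|startoftranscript|>",
                return_tensors="pt")
print(out.keys())  # input_features (log-mel features) and labels (decoder ids)
```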
| 671 | + def _get_mm_fields_config( |
| 672 | + self, |
| 673 | + hf_inputs: BatchFeature, |
| 674 | + hf_processor_mm_kwargs: Mapping[str, object], |
| 675 | + ) -> Mapping[str, MultiModalFieldConfig]: |
| 676 | + return dict(input_features=MultiModalFieldConfig.batched("audio")) |
| 677 | + |
| 678 | + def _get_prompt_replacements( |
| 679 | + self, |
| 680 | + mm_items: MultiModalDataItems, |
| 681 | + hf_processor_mm_kwargs: Mapping[str, object], |
| 682 | + out_mm_kwargs: MultiModalKwargs, |
| 683 | + ) -> list[PromptReplacement]: |
| 684 | + num_tokens = self.info.get_max_audio_tokens() |
| 685 | + return [ |
| 686 | + PromptReplacement( |
| 687 | + modality="audio", |
| 688 | + target=[0], |
| 689 | + replacement=[0] * num_tokens, |
| 690 | + ) |
| 691 | + ] |
| 692 | + |
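Taken together with `create_encoder_prompt`, the replacement above is what pads the encoder side: the single dummy token `[0]` is swapped for `max_source_positions` placeholder tokens so the encoder sequence length matches the number of audio embeddings. Roughly (the value 1500 is assumed from the standard Whisper config, not read from this diff):

```python
encoder_prompt = [0]                     # from create_encoder_prompt()
num_audio_tokens = 1500                  # config.max_source_positions
encoder_prompt = [0] * num_audio_tokens  # after PromptReplacement is applied
assert len(encoder_prompt) == num_audio_tokens
```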
| 693 | + |
| 694 | +@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor, |
| 695 | + info=WhisperProcessingInfo, |
| 696 | + dummy_inputs=WhisperDummyInputsBuilder) |
640 | 697 | class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, |
641 | 698 | SupportsMultiModal): |
642 | 699 | packed_modules_mapping = { |
@@ -724,7 +781,8 @@ def _parse_and_validate_audio_input( |
724 | 781 | if not isinstance(input_features, (torch.Tensor, list)): |
725 | 782 | raise ValueError("Incorrect type of audio features. " |
726 | 783 | f"Got type: {type(input_features)}") |
727 | | - input_features = [feat.to(self.dtype) for feat in input_features] |
| 784 | + input_features = torch.cat( |
| 785 | + [feat.to(self.dtype) for feat in input_features]) |
728 | 786 |
729 | 787 | return WhisperAudioInputs(input_features=input_features) |
730 | 788 |
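End to end, the registered processor lets Whisper run through the regular multimodal path. A rough offline-inference sketch, assuming the explicit encoder/decoder prompt format vLLM uses for encoder-decoder models (field names and the checkpoint are illustrative and may differ across vLLM versions):

```python
import numpy as np
from vllm import LLM, SamplingParams

llm = LLM(model="openai/whisper-small")

audio = np.zeros(16000, dtype=np.float32)  # replace with a real waveform
prompt = {
    "encoder_prompt": {
        "prompt": "",
        "multi_modal_data": {"audio": (audio, 16000)},
    },
    "decoder_prompt": "<|startoftranscript|>",
}

outputs = llm.generate(prompt, SamplingParams(temperature=0, max_tokens=128))
print(outputs[0].outputs[0].text)
```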