
Commit 996357e

[VLM] Separate out profiling-related logic (#11746)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 2a622d7 commit 996357e

17 files changed: +1015 -718 lines changed

tests/multimodal/test_processing.py

Lines changed: 4 additions & 3 deletions
@@ -586,17 +586,18 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
     )
 
     processor = processor_factory(ctx, cache=None)
+    profiler = processor.profiling_info
 
     mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
-    processor.get_supported_mm_limits = mock_supported_mm_limits
+    profiler.get_supported_mm_limits = mock_supported_mm_limits
 
     if is_valid:
         exc_ctx = nullcontext()
     else:
         exc_ctx = pytest.raises(ValueError, match="this model only supports")
 
     with exc_ctx:
-        processor._get_and_validate_dummy_mm_counts()
+        profiler.get_mm_limits()
 
 
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@@ -723,7 +724,7 @@ def _test_processing_cache_correctness(
     }
 
     mm_counts = {k: len(vs) for k, vs in mm_data.items()}
-    prompt = baseline_processor._get_dummy_processor_inputs(
+    prompt = baseline_processor.profiling_info.get_dummy_processor_inputs(
         model_config.max_model_len,
         mm_counts,
     ).prompt_text
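
The test changes above show the new call path: profiling helpers that used to be private methods on the processor are now reached through its profiling_info attribute. A minimal before/after sketch, assuming `processor` is a BaseMultiModalProcessor built the same way as in the test; `seq_len` and `mm_counts` are placeholders:

# Old call sites (removed in this commit):
#   processor._get_and_validate_dummy_mm_counts()
#   processor._get_dummy_processor_inputs(seq_len, mm_counts)

# New call sites, going through the profiling object:
profiler = processor.profiling_info          # a BaseProfilingInfo instance
mm_limits = profiler.get_mm_limits()         # validated per-modality limits
dummy = profiler.get_dummy_processor_inputs(seq_len, mm_counts)
prompt_text = dummy.prompt_text              # dummy prompt used for profiling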

vllm/model_executor/models/aria.py

Lines changed: 47 additions & 32 deletions
@@ -24,8 +24,9 @@
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                     NestedTensors)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
+                                        MultiModalDataItems, ProcessingMixin,
                                         PromptReplacement)
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
                                                   AriaVisionConfig)
@@ -444,18 +445,58 @@ def build_mm_projector(config: PretrainedConfig):
     )
 
 
-class AriaMultiModalProcessor(BaseMultiModalProcessor):
+class AriaProcessingMixin(ProcessingMixin):
 
-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None}
+    def _get_hf_config(self):
+        return self.ctx.get_hf_config()
+
+    def _get_vision_config(self) -> AriaVisionConfig:
+        return self._get_hf_config().vision_config
 
     def _get_num_image_tokens(self) -> int:
-        hf_config = self.ctx.get_hf_config()
+        hf_config = self._get_hf_config()
         return max(hf_config.projector_patch_to_query_dict.values())
 
+
+class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
         return {"image": self._get_num_image_tokens()}
 
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        vision_config = self._get_vision_config()
+
+        max_image_size = vision_config.image_size
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
+
+        hf_processor = self._get_hf_processor()
+        image_token: str = hf_processor.image_token  # type: ignore
+
+        return ProcessorInputs(
+            prompt_text=image_token * num_images,
+            mm_data=mm_data,
+        )
+
+
+class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return AriaProfilingInfo(self.ctx)
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -472,7 +513,7 @@ def _get_prompt_replacements(
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_config = self.ctx.get_hf_config()
+        hf_config = self._get_hf_config()
         image_token_id = hf_config.image_token_index
 
         num_image_tokens = self._get_num_image_tokens()
@@ -485,32 +526,6 @@ def _get_prompt_replacements(
             )
         ]
 
-    def _get_dummy_processor_inputs(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
-        hf_config = self.ctx.get_hf_config()
-        vision_config: AriaVisionConfig = hf_config.vision_config
-
-        max_image_size = vision_config.image_size
-        num_images = mm_counts.get("image", 0)
-
-        mm_data = {
-            "image":
-            self._get_dummy_images(width=max_image_size,
-                                   height=max_image_size,
-                                   num_images=num_images)
-        }
-
-        hf_processor = self._get_hf_processor()
-        image_token: str = hf_processor.image_token  # type: ignore
-
-        return ProcessorInputs(
-            prompt_text=image_token * num_images,
-            mm_data=mm_data,
-        )
-
 
 @MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
 class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
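
The aria.py diff shows the structure this commit applies to each model: shared HF config/processor accessors move into a ProcessingMixin subclass, the profiling hooks (supported limits, max tokens per item, dummy inputs) move into a BaseProfilingInfo subclass, and the multimodal processor only declares which profiler it uses via _get_profiling_info. A rough skeleton for a hypothetical model follows; the MyModel* names and the image size/token constants are purely illustrative, and a real processor would also implement _get_mm_fields_config and _get_prompt_replacements as in the diffs below:

from collections.abc import Mapping
from typing import Optional

from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingMixin
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs


class MyModelProcessingMixin(ProcessingMixin):
    # Helpers shared by the processor and the profiler.

    def _get_hf_config(self):
        return self.ctx.get_hf_config()


class MyModelProfilingInfo(MyModelProcessingMixin, BaseProfilingInfo):
    # Profiling-only logic, kept out of the processor itself.

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": 64}  # illustrative constant

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)
        mm_data = {
            "image":
            self._get_dummy_images(width=336, height=336,
                                   num_images=num_images)
        }
        return ProcessorInputs(prompt_text="<image>" * num_images,
                               mm_data=mm_data)


class MyModelMultiModalProcessor(MyModelProcessingMixin,
                                 BaseMultiModalProcessor):
    # The processor now only declares which profiling info it uses.

    def _get_profiling_info(self) -> BaseProfilingInfo:
        return MyModelProfilingInfo(self.ctx)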

vllm/model_executor/models/blip2.py

Lines changed: 44 additions & 34 deletions
@@ -4,8 +4,8 @@
 
 import torch
 import torch.nn as nn
-from transformers import (BatchFeature, Blip2Config, Blip2Processor,
-                          Blip2QFormerConfig, apply_chunking_to_forward)
+from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
+                          apply_chunking_to_forward)
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
@@ -18,8 +18,9 @@
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
+                                        MultiModalDataItems, ProcessingMixin,
                                         PromptReplacement)
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
 from .blip import BlipVisionModel
@@ -396,20 +397,52 @@ def forward(
         return sequence_output
 
 
-class Blip2MultiModalProcessor(BaseMultiModalProcessor):
+class Blip2ProcessingMixin(ProcessingMixin):
 
-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": 1}
+    def _get_hf_config(self):
+        return self.ctx.get_hf_config(Blip2Config)
 
     def _get_num_image_tokens(self) -> int:
-        hf_config = self.ctx.get_hf_config(Blip2Config)
+        hf_config = self._get_hf_config()
         return hf_config.num_query_tokens
 
+
+class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
+
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
         return {"image": self._get_num_image_tokens()}
 
-    def _get_hf_processor(self) -> Blip2Processor:
-        return self.ctx.get_hf_processor(Blip2Processor)
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        hf_config = self._get_hf_config()
+        vision_config = hf_config.vision_config
+
+        max_image_size = vision_config.image_size
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
+
+
+class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return Blip2ProfilingInfo(self.ctx)
 
     def _get_mm_fields_config(
         self,
@@ -427,13 +460,13 @@ def _get_prompt_replacements(
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        max_image_tokens = self._get_num_image_tokens()
+        num_image_tokens = self._get_num_image_tokens()
 
         return [
             PromptReplacement(
                 modality="image",
                 target="</s>",
-                replacement="<image>" * max_image_tokens + "</s>",
+                replacement="<image>" * num_image_tokens + "</s>",
             )
         ]
 
@@ -457,29 +490,6 @@ def apply(
 
         return result
 
-    def _get_dummy_processor_inputs(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
-        hf_config = self.ctx.get_hf_config(Blip2Config)
-        vision_config = hf_config.vision_config
-
-        max_image_size = vision_config.image_size
-        num_images = mm_counts.get("image", 0)
-
-        mm_data = {
-            "image":
-            self._get_dummy_images(width=max_image_size,
-                                   height=max_image_size,
-                                   num_images=num_images)
-        }
-
-        return ProcessorInputs(
-            prompt_text="",
-            mm_data=mm_data,
-        )
-
 
 @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
 class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

vllm/model_executor/models/chameleon.py

Lines changed: 43 additions & 29 deletions
@@ -31,8 +31,9 @@
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
+                                        MultiModalDataItems, ProcessingMixin,
                                         PromptReplacement)
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import print_warning_once
 
@@ -48,20 +49,55 @@ class ChameleonImagePixelInputs(TypedDict):
     """Shape: `(batch_size * num_images, num_channels, height, width)`"""
 
 
-class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
+class ChameleonProcessingMixin(ProcessingMixin):
 
-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": 1}
+    def _get_hf_config(self):
+        return self.ctx.get_hf_config(ChameleonConfig)
+
+    def _get_hf_processor(self):
+        return self.ctx.get_hf_processor(ChameleonProcessor)
 
     def _get_num_image_tokens(self) -> int:
         processor = self._get_hf_processor()
         return processor.image_seq_length
 
+
+class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
+
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
         return {"image": self._get_num_image_tokens()}
 
-    def _get_hf_processor(self) -> ChameleonProcessor:
-        return self.ctx.get_hf_processor(ChameleonProcessor)
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        config = self._get_hf_config()
+
+        width = height = config.vq_config.resolution
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=width,
+                                   height=height,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="<image>" * num_images,
+            mm_data=mm_data,
+        )
+
+
+class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
+                                   BaseMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return ChameleonProfilingInfo(self.ctx)
 
     def _get_mm_fields_config(
         self,
@@ -76,7 +112,7 @@ def _get_prompt_replacements(
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        processor = self._get_hf_processor()
+        processor = self._get_hf_processor(**hf_processor_mm_kwargs)
 
         return [
             PromptReplacement(
@@ -90,28 +126,6 @@ def _get_prompt_replacements(
             )
         ]
 
-    def _get_dummy_processor_inputs(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
-        config = self.ctx.get_hf_config(ChameleonConfig)
-
-        width = height = config.vq_config.resolution
-        num_images = mm_counts.get("image", 0)
-
-        mm_data = {
-            "image":
-            self._get_dummy_images(width=width,
-                                   height=height,
-                                   num_images=num_images)
-        }
-
-        return ProcessorInputs(
-            prompt_text="<image>" * num_images,
-            mm_data=mm_data,
-        )
-
     def apply(
         self,
         prompt_text: str,
