1+ from functools import partial
12from typing import cast
23
4+ import numpy as np
35import pytest
4-
5- from vllm .multimodal .processing import (PromptReplacement , _PlaceholderInfo ,
6- find_text_matches , find_token_matches ,
7- iter_placeholders , iter_token_matches ,
6+ from PIL import Image
7+
8+ from vllm .config import ModelConfig
9+ from vllm .inputs import InputProcessingContext
10+ from vllm .multimodal import MULTIMODAL_REGISTRY
11+ from vllm .multimodal .processing import (ProcessingCache , PromptReplacement ,
12+ _PlaceholderInfo , find_text_matches ,
13+ find_token_matches , iter_placeholders ,
14+ iter_token_matches ,
815 replace_text_matches ,
916 replace_token_matches )
17+ from vllm .multimodal .utils import cached_get_tokenizer
1018from vllm .transformers_utils .tokenizer import AnyTokenizer
1119from vllm .utils import full_groupby
1220
@@ -457,6 +465,7 @@ def test_find_replace_tokens(
457465 ),
458466 ]
459467)
468+ # yapf: enable
460469def test_iter_placeholders (
461470 repl_by_key ,
462471 prompt ,
@@ -475,11 +484,199 @@ def test_iter_placeholders(
475484 prompt_repls ,
476485 prompt ,
477486 # Effectively match all occurrences in the prompt
478- {key : 3 for key in repl_by_key },
479- ))
487+ {key : 3
488+ for key in repl_by_key },
489+ ))
480490
481491 # Only displayed on error
482492 print ("result:" , result )
483493
484494 # Manually constructed results
485495 assert result == expected
496+
497+
def _rand_img(rng: np.random.RandomState, min_wh: int, max_wh: int):
    """Build a random RGB ``PIL.Image`` from draws taken out of ``rng``.

    Args:
        rng: Random state that supplies both the side lengths and the pixels.
        min_wh: Inclusive lower bound for each side length.
        max_wh: Exclusive upper bound for each side length
            (``RandomState.randint`` excludes its ``high`` argument).

    Returns:
        The image created by ``Image.fromarray`` from a ``uint8`` array
        with three channels.
    """
    # `Image.fromarray` treats the leading axis as rows (height) and the
    # second axis as columns (width), so name the draws accordingly.
    h, w = rng.randint(min_wh, max_wh, size=(2, ))
    # `high` is exclusive: 256 is required for the value 255 to be reachable.
    arr = rng.randint(0, 256, size=(h, w, 3), dtype=np.uint8)
    return Image.fromarray(arr)
502+
503+
504+ def _rand_video (
505+ rng : np .random .RandomState ,
506+ min_frames : int ,
507+ max_frames : int ,
508+ min_wh : int ,
509+ max_wh : int ,
510+ ):
511+ # Temporary workaround for https://github.com/huggingface/transformers/issues/35412
512+ num_frames = rng .randint (min_frames , max_frames )
513+ num_frames = (num_frames // 2 ) * 2
514+
515+ w , h = rng .randint (min_wh , max_wh , size = (2 , ))
516+ return rng .randint (0 , 255 , size = (num_frames , w , h , 3 ), dtype = np .uint8 )
517+
518+
519+ def _rand_audio (
520+ rng : np .random .RandomState ,
521+ min_len : int ,
522+ max_len : int ,
523+ sr : int ,
524+ ):
525+ audio_len = rng .randint (min_len , max_len )
526+ return rng .rand (audio_len ), sr
527+
528+
def _test_processing_cache_correctness(
    model_id: str,
    modalities: set[str],
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """Assert that a cache-backed processor matches an uncached one.

    Two processors are built from the same factory and context, one with
    ``cache=None`` and one with a ``ProcessingCache``. For every batch, both
    are applied to the same randomly assembled multimodal data and their
    results are compared for equality.

    Args:
        model_id: Model identifier, also used as the tokenizer name.
        modalities: Modalities to generate data for; each must be one of
            ``"image"``, ``"video"`` or ``"audio"``.
        hit_rate: Per-item probability of using the fixed input for that
            modality instead of a freshly generated random one.
        num_batches: Number of batches to run through both processors.
        simplify_rate: Per-batch probability of dropping empty modalities and
            unwrapping single-item lists before calling the processors.

    Raises:
        AssertionError: If the two processors disagree on any batch.
    """
    if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
        hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
    else:
        hf_overrides = {}

    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=True,
        seed=0,
        dtype="float16",
        revision=None,
        hf_overrides=hf_overrides,
    )
    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )
    # Ensure that it can fit all of the data
    cache = ProcessingCache(capacity=1 << 30)

    baseline_processor = processor_factory(ctx, cache=None)
    cached_processor = processor_factory(ctx, cache=cache)

    rng = np.random.RandomState(0)

    # Fixed inputs that are reused across batches when a "hit" is drawn.
    input_to_hit = {
        "image": Image.new("RGB", size=(128, 128)),
        "video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
        "audio": (np.zeros((512, )), 16000),
    }
    # Generators for fresh random inputs, all drawing from the shared `rng`.
    input_factory = {
        "image":
        partial(_rand_img, rng, min_wh=128, max_wh=256),
        "video":
        partial(_rand_video,
                rng,
                min_frames=2,
                max_frames=8,
                min_wh=128,
                max_wh=256),
        "audio":
        partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
    }
    input_max_count = {
        "image": 3,
        "video": 3,
        "audio": 3,
    }

    # Set iteration order for strings is not stable between interpreter runs
    # (string hashing is randomised), which would change the order in which
    # `rng` is consumed. Fix the order so the generated data is reproducible.
    ordered_modalities = sorted(modalities)

    for batch_idx in range(num_batches):
        mm_data = {
            k:
            [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
             # `randint` excludes its upper bound; +1 makes the maximum
             # count itself reachable.
             for _ in range(rng.randint(input_max_count[k] + 1))]
            for k in ordered_modalities
        }

        # The prompt is built from the item counts before any simplification.
        mm_counts = {k: len(vs) for k, vs in mm_data.items()}
        prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text

        # Drop unnecessary keys and test single -> multi conversion
        if rng.rand() < simplify_rate:
            for k in list(mm_data.keys()):
                if not mm_data[k]:
                    del mm_data[k]
                elif len(mm_data[k]) == 1:
                    mm_data[k] = mm_data[k][0]

        baseline_result = baseline_processor.apply(
            prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )
        cached_result = cached_processor.apply(
            prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )

        assert baseline_result == cached_result, (
            f"Failed ({batch_idx=}, {mm_data=})")
623+
624+
# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
    ("llava-hf/llava-1.5-7b-hf", {"image"}),
    ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}),
    ("mistral-community/pixtral-12b", {"image"}),
    ("Qwen/Qwen2-VL-2B-Instruct", {"image", "video"}),
    ("Qwen/Qwen2-Audio-7B-Instruct", {"audio"}),
    ("fixie-ai/ultravox-v0_3", {"audio"}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness(
    model_id: str,
    modalities: set[str],
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """Run the shared cache-correctness check for each parametrized case."""
    # Everything is forwarded unchanged; the shared helper owns the logic.
    _test_processing_cache_correctness(
        model_id=model_id,
        modalities=modalities,
        hit_rate=hit_rate,
        num_batches=num_batches,
        simplify_rate=simplify_rate,
    )
652+
653+
# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
    ("microsoft/Phi-3-vision-128k-instruct", {"image"}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness_phi3v(
    model_id: str,
    modalities: set[str],
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """Run the shared cache-correctness check for Phi-3-vision.

    The image processor is loaded explicitly before the shared check runs,
    as part of the workaround referenced below.
    """
    # HACK - this is an attempted workaround for the following bug
    # https://github.com/huggingface/transformers/issues/34307
    # NOTE(review): `AutoProcessor` is never referenced afterwards; presumably
    # it is imported only for its import side effect - confirm before removing.
    from transformers import AutoImageProcessor, AutoProcessor  # noqa: F401

    AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)

    _test_processing_cache_correctness(
        model_id=model_id,
        modalities=modalities,
        hit_rate=hit_rate,
        num_batches=num_batches,
        simplify_rate=simplify_rate,
    )
0 commit comments