11# SPDX-License-Identifier: Apache-2.0
22
3+ import copy
34from functools import partial
5+ from typing import Optional
46
57import numpy as np
68import pytest
@@ -21,6 +23,7 @@ def _test_processing_correctness(
2123 hit_rate : float ,
2224 num_batches : int ,
2325 simplify_rate : float ,
26+ ignore_mm_keys : Optional [list [str ]] = None ,
2427):
2528 model_info = HF_EXAMPLE_MODELS .find_hf_info (model_id )
2629 model_info .check_available_online (on_fail = "skip" )
@@ -123,26 +126,32 @@ def _test_processing_correctness(
123126 hf_processor_mm_kwargs = {},
124127 )
125128
126- assert baseline_result == cached_result , (
127- f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" )
129+ assert _drop_mm_kwargs_keys (
130+ baseline_result , ignore_mm_keys ) == _drop_mm_kwargs_keys (
131+ cached_result , ignore_mm_keys ), (
132+ f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" )
128133
129134 baseline_tokenized_result = baseline_processor .apply (
130135 tokenizer .encode (prompt , ** tokenizer_encode_kwargs ),
131136 mm_data = mm_data ,
132137 hf_processor_mm_kwargs = {},
133138 )
134139
135- assert baseline_result == baseline_tokenized_result , (
136- f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" )
140+ assert _drop_mm_kwargs_keys (
141+ baseline_result , ignore_mm_keys ) == _drop_mm_kwargs_keys (
142+ baseline_tokenized_result , ignore_mm_keys ), (
143+ f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" )
137144
138145 cached_tokenized_result = cached_processor .apply (
139146 tokenizer .encode (prompt , ** tokenizer_encode_kwargs ),
140147 mm_data = mm_data ,
141148 hf_processor_mm_kwargs = {},
142149 )
143150
144- assert cached_result == cached_tokenized_result , (
145- f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" )
151+ assert _drop_mm_kwargs_keys (
152+ cached_result , ignore_mm_keys ) == _drop_mm_kwargs_keys (
153+ cached_tokenized_result , ignore_mm_keys ), (
154+ f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" )
146155
147156
148157# yapf: disable
@@ -173,7 +182,7 @@ def _test_processing_correctness(
173182 "Qwen/Qwen2-VL-2B-Instruct" ,
174183 "Qwen/Qwen2.5-VL-3B-Instruct" ,
175184 "Qwen/Qwen2-Audio-7B-Instruct" ,
176- "fixie-ai/ultravox-v0_4 " ,
185+ "fixie-ai/ultravox-v0_5-llama-3_2-1b " ,
177186 "openai/whisper-large-v3" ,
178187 "google/paligemma-3b-mix-224" ,
179188 "google/paligemma2-3b-ft-docci-448" ,
@@ -188,11 +197,19 @@ def test_processing_correctness(
188197 num_batches : int ,
189198 simplify_rate : float ,
190199):
200+ ignore_mm_keys = None
201+ if 'ultravox' in model_id :
202+ # In Ultravox, the audio_features can be different depending on padding
203+ # The slight difference should not be a problem though, since
204+ # attention_mask lets us ignore the difference.
205+ ignore_mm_keys = ['audio_features' ]
206+
191207 _test_processing_correctness (
192208 model_id ,
193209 hit_rate = hit_rate ,
194210 num_batches = num_batches ,
195211 simplify_rate = simplify_rate ,
212+ ignore_mm_keys = ignore_mm_keys ,
196213 )
197214
198215
@@ -221,3 +238,29 @@ def test_processing_correctness_phi3v(
221238 num_batches = num_batches ,
222239 simplify_rate = simplify_rate ,
223240 )
241+
242+
243+ def _drop_mm_kwargs_keys (result : dict ,
244+ ignore_mm_keys : Optional [list [str ]] = None ) -> dict :
245+ """Drop specified keys from result['mm_kwargs'].
246+
247+ This is mainly to avoid doing exact match of audio_features in ultravox.
248+
249+ Args:
250+ result: Result to drop keys from
251+ ignore_mm_keys: List of keys to ignore, e.g. ['audio_features']
252+ """
253+ if not ignore_mm_keys :
254+ return result
255+
256+ if 'mm_kwargs' in result :
257+ result = copy .deepcopy (result )
258+ mm_kwargs = result ['mm_kwargs' ]
259+ for key in ignore_mm_keys :
260+ mm_kwargs .pop (key , None )
261+ for items in mm_kwargs ._items_by_modality .values ():
262+ for item in items :
263+ for key in ignore_mm_keys :
264+ item .pop (key , None )
265+
266+ return result
0 commit comments