22from collections import UserDict , defaultdict
33from collections .abc import Mapping , Sequence
44from dataclasses import dataclass
5- from typing import Any , Literal , TypedDict , TypeVar , Union , cast , final
5+ from typing import (Any , Literal , Optional , TypedDict , TypeVar , Union , cast ,
6+ final )
67
78import numpy as np
89import torch
1112from transformers import BatchFeature
1213from typing_extensions import NotRequired , TypeAlias
1314
14- from vllm .utils import JSONTree , is_list_of , json_map_leaves
15+ from vllm .utils import JSONTree , full_groupby , is_list_of , json_map_leaves
1516
1617_T = TypeVar ("_T" )
1718
@@ -160,11 +161,8 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
160161
161162
162163@dataclass (frozen = True )
163- class MultiModalFieldItem :
164- """
165- Contains metadata and data in :class:`MultiModalKwargs`
166- corresponding to a data item in :class:`MultiModalDataItems`.
167- """
164+ class MultiModalFieldElem :
165+ """Contains metadata and data of an item in :class:`MultiModalKwargs`."""
168166 field : "BaseMultiModalField"
169167 data : NestedTensors
170168
@@ -186,34 +184,34 @@ class BaseMultiModalField(ABC):
186184 def _reduce_data (self , batch : list [NestedTensors ]) -> NestedTensors :
187185 raise NotImplementedError
188186
def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
    """Wrap ``data`` in a :class:`MultiModalFieldElem` that points to this field."""
    return MultiModalFieldElem(self, data)
191189
def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
    """
    Merge multiple instances of :class:`MultiModalFieldElem` together.

    The ``data`` of every element is collected into a list, combined by
    :meth:`_reduce_data`, and wrapped in a new element via
    :meth:`_build_elem`.

    Raises:
        ValueError: If the elements in ``batch`` refer to more than one
            distinct field.
    """
    fields = [item.field for item in batch]
    # Elements of different fields cannot be combined by one reduction rule.
    if len(set(fields)) > 1:
        raise ValueError(f"Cannot merge different {fields=}")

    data = self._reduce_data([item.data for item in batch])

    return self._build_elem(data)
201199
202200
@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
    """
    A :class:`BaseMultiModalField` implementation where an element in the batch
    is obtained by indexing into the first dimension of the underlying data.
    """

    def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
        """Build one element for each entry produced by iterating ``batch``."""
        return [self._build_elem(item) for item in batch]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Combine per-element data back into batched data.

        When ``batch`` is non-empty, passes the ``is_list_of`` check for
        :class:`torch.Tensor` and every entry has the same shape, the entries
        are stacked with :func:`torch.stack`. Otherwise ``batch`` is returned
        unchanged.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            first_shape = batch[0].shape
            # torch.stack requires all tensors to have an identical shape.
            if all(elem.shape == first_shape for elem in batch):
                return torch.stack(batch)

        return batch
@@ -222,24 +220,24 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
    """
    A :class:`BaseMultiModalField` implementation where an element in the batch
    is obtained by slicing along the first dimension of the underlying data.
    """

    def build_elems(
        self,
        batch: NestedTensors,
        slices: Sequence[slice],
    ) -> list[MultiModalFieldElem]:
        """Build one element per entry of ``slices``, holding ``batch[slice_]``."""
        return [self._build_elem(batch[slice_]) for slice_ in slices]

    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
        """
        Combine per-element data back into flat data.

        When ``batch`` is non-empty, passes the ``is_list_of`` check for
        :class:`torch.Tensor` and every entry agrees on all dimensions after
        the first, the entries are joined with :func:`torch.concat`.
        Otherwise the entries are flattened by one level into a plain list.
        """
        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
            first_shape = batch[0].shape
            # Only the leading dimension may differ between the entries.
            if all(elem.shape[1:] == first_shape[1:] for elem in batch):
                return torch.concat(batch)

        return [e for elem in batch for e in elem]
243241
244242
245243class MultiModalFieldConfig :
@@ -267,115 +265,111 @@ def __init__(
267265 ) -> None :
268266 super ().__init__ ()
269267
270- self ._field_cls = field_cls
271- self ._modality = modality
272- self ._field_config = field_config
268+ self .field_cls = field_cls
269+ self .modality = modality
270+ self .field_config = field_config
273271
def build_elems(
    self,
    key: str,
    batch: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
    """
    Instantiate the configured field for ``key`` and split ``batch`` with it.

    The field is created as ``field_cls(key=key, modality=self.modality)``;
    ``self.field_config`` is forwarded as keyword arguments to the field's
    own ``build_elems``.
    """
    field = self.field_cls(key=key, modality=self.modality)
    # The accepted keyword arguments depend on the concrete field class,
    # hence the type-checker suppression.
    return field.build_elems(batch, **self.field_config)  # type: ignore
281279
282280
283- class MultiModalKwargs (UserDict [str , NestedTensors ]):
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    A collection of :class:`MultiModalFieldElem`
    corresponding to a data item in :class:`MultiModalDataItems`.
    """

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]) -> "MultiModalKwargsItem":
        """Build an item that maps each element's ``field.key`` to the element."""
        return MultiModalKwargsItem({elem.field.key: elem for elem in elems})

    @property
    def modality(self) -> str:
        """
        The single modality shared by every element in this item.

        The assertion fails unless the elements report exactly one modality
        between them.
        """
        modalities = {elem.field.modality for elem in self.data.values()}
        assert len(modalities) == 1, f"Found different modalities={modalities}"
        return next(iter(modalities))
296296
297- Example:
298297
299- .. code-block:: python
300-
301- # All items belong to the "image" modality
302- items_by_key={
303- "pixel_values": [a, b, c, d], # "image" modality
304- "image_grid_thw": [e, f, g, h], # "image" modality
305- "pixel_values_video": [h, i, j], # "video" modality
306- "video_grid_thw": [k, l, m], # "video" modality
307- }
298+ # NOTE: UserDict is for V0 compatibility.
299+ # V1 should access individual items via `get_item`.
300+ class MultiModalKwargs (UserDict [str , NestedTensors ]):
301+ """
302+ A dictionary that represents the keyword arguments to
303+ :meth:`~torch.nn.Module.forward`.
308304
309- - The keyword arguments belonging to the first image are
310- :code:`{"pixel_values": a, "image_grid_thw": e}`.
311- - The keyword arguments belonging to the second video are
312- :code:`{"pixel_values_video": i, "video_grid_thw": l}`.
305+ The metadata :code:`items` enables us to obtain the keyword arguments
306+ corresponding to each data item in :class:`MultiModalDataItems`, via
307+ :meth:`get_item` and :meth:`get_items`.
313308 """
314309
@staticmethod
def from_hf_inputs(
    hf_inputs: BatchFeature,
    config_by_key: Mapping[str, MultiModalFieldConfig],
) -> "MultiModalKwargs":
    """
    Construct a :class:`MultiModalKwargs` from HF processor outputs.

    Each configured key found in ``hf_inputs`` is split into elements, the
    elements are regrouped into one :class:`MultiModalKwargsItem` per index
    within each modality, and the items are passed to :meth:`from_items`.

    Raises:
        ValueError: If the keys of one modality produce different numbers
            of elements.
    """
    # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
    # We assume that those fields are not used in vLLM
    elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
    keys_by_modality = defaultdict[str, set[str]](set)
    for key, config in config_by_key.items():
        batch = hf_inputs.get(key)
        if batch is not None:
            elems = config.build_elems(key, batch)
            # Keys that yield no elements are dropped entirely.
            if len(elems) > 0:
                elems_by_key[key] = elems
                keys_by_modality[config.modality].add(key)

    items = list[MultiModalKwargsItem]()
    for modality, keys in keys_by_modality.items():
        elems_in_modality = {k: elems_by_key[k] for k in keys}
        batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

        if len(set(batch_sizes.values())) > 1:
            raise ValueError(
                f"Cannot merge different batch sizes for {modality=}! "
                f"Found: {batch_sizes=}")

        # All sizes agree at this point, so any one of them will do.
        batch_size = next(iter(batch_sizes.values()))
        for item_idx in range(batch_size):
            elems = [v[item_idx] for v in elems_in_modality.values()]
            items.append(MultiModalKwargsItem.from_elems(elems))

    return MultiModalKwargs.from_items(items)
334343
@staticmethod
def from_items(items: Sequence[MultiModalKwargsItem]) -> "MultiModalKwargs":
    """
    Construct a new :class:`MultiModalKwargs` from multiple items.

    Elements are gathered per key across all items, then merged with the
    ``reduce`` method of the first element's field to produce the data for
    that key. The original ``items`` are kept on the result.
    """
    elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
    for item in items:
        for key, elem in item.items():
            elems_by_key[key].append(elem)

    data = {
        key: elems[0].field.reduce(elems).data
        for key, elems in elems_by_key.items() if len(elems) > 0
    }

    return MultiModalKwargs(data, items=items)
349358
def __init__(
    self,
    data: Mapping[str, NestedTensors],
    *,
    items: Optional[Sequence[MultiModalKwargsItem]] = None,
) -> None:
    """
    Args:
        data: The keyword arguments stored in the underlying dictionary.
        items: Optional per-item breakdown of ``data``. When omitted (or
            empty), no items are recorded for any modality.
    """
    super().__init__(data)

    # Group the items by their `modality` property.
    # NOTE(review): `full_groupby` presumably yields (modality, items)
    # pairs, since its result is fed to dict(); confirm against vllm.utils.
    items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
    self._items_by_modality = dict(items_by_modality)
@property
def modalities(self):
    """The keys view of the modalities for which items were recorded."""
    return self._items_by_modality.keys()
379373
380374 @staticmethod
381375 def _try_stack (nested_tensors : NestedTensors ) -> NestedTensors :
@@ -452,58 +446,44 @@ def as_kwargs(
def __eq__(self, other: object) -> bool:
    """
    Compare two instances for equality.

    Returns ``False`` unless ``other`` is an instance of this object's class
    and both hold equal per-modality items. Beyond that, both must have the
    same keys and every value must match under ``nested_tensors_equal``.
    """
    if not isinstance(other, self.__class__):
        return False
    if self._items_by_modality != other._items_by_modality:
        return False

    ks = self.keys()
    return (ks == other.keys()
            and all(nested_tensors_equal(self[k], other[k]) for k in ks))
461455
462- def get_item (self , key : str , item_index : int ) -> MultiModalFieldItem :
463- return self ._items_by_key [key ][item_index ]
def _validate_modality(self, method_name: str, modality: str) -> None:
    """
    Check that items for ``modality`` can be looked up.

    Args:
        method_name: Name of the calling method, used in the error message.
        modality: The modality about to be accessed.

    Raises:
        RuntimeError: If no items are recorded at all.
        KeyError: If items exist, but none for ``modality``.
    """
    if not self._items_by_modality:
        raise RuntimeError(
            f"`{method_name}` is not supported when "
            "MultiModalKwargs is not initialized with `items`")

    if modality not in self._items_by_modality:
        available_modalities = set(self._items_by_modality.keys())
        raise KeyError(f"Modality {modality!r} not found. "
                       f"Available modalities: {available_modalities}")
478466
479- keys_to_gather = self ._keys_by_modality [modality ]
def get_item_count(self, modality: str) -> int:
    """
    Get the number of items belonging to a modality.

    The modality is first checked by :meth:`_validate_modality`, whose
    errors propagate to the caller.
    """
    self._validate_modality("get_item_count", modality)
    return len(self._items_by_modality[modality])
480471
481- return {
482- key : self .get_item (key , item_index )
483- for key in keys_to_gather if key in self
484- }
def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
    """
    Get the keyword arguments corresponding to an item identified by
    its modality and index.

    The modality is first checked by :meth:`_validate_modality`, whose
    errors propagate to the caller. ``item_index`` is used to index the
    items recorded for that modality.
    """
    self._validate_modality("get_item", modality)
    return self._items_by_modality[modality][item_index]
485479
486- @staticmethod
487- def from_items_by_modality (
488- items_by_modality : Mapping [str , list [Mapping [str ,
489- MultiModalFieldItem ]]],
490- * ,
491- enable_sanity_checks : bool = False ,
492- ) -> "MultiModalKwargs" :
def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
    """
    Get the keyword arguments corresponding to each item belonging to
    a modality.

    The modality is first checked by :meth:`_validate_modality`, whose
    errors propagate to the caller. The stored sequence itself is
    returned, not a copy.
    """
    self._validate_modality("get_items", modality)
    return self._items_by_modality[modality]
507487
508488
509489MultiModalPlaceholderDict = Mapping [str , Sequence [PlaceholderRange ]]
0 commit comments