@@ -653,7 +653,7 @@ def dummy(modality: str):
653
653
def from_elems (elems : Sequence [MultiModalFieldElem ]):
654
654
return MultiModalKwargsItem ({elem .key : elem for elem in elems })
655
655
656
- def __init__ (self , data : Mapping [str , MultiModalFieldElem ]) -> None :
656
+ def __init__ (self , data : Mapping [str , MultiModalFieldElem ] = {} ) -> None :
657
657
super ().__init__ (data )
658
658
659
659
modalities = {elem .modality for elem in self .data .values ()}
@@ -668,9 +668,7 @@ def get_data(self) -> Mapping[str, NestedTensors]:
668
668
return {key : elem .data for key , elem in self .items ()}
669
669
670
670
671
- # NOTE: UserDict is for V0 compatibility.
672
- # V1 should access individual items via `get_item`.
673
- class MultiModalKwargs (UserDict [str , NestedTensors ]):
671
+ class MultiModalKwargs :
674
672
"""
675
673
A dictionary that represents the keyword arguments to
676
674
[`torch.nn.Module.forward`][].
@@ -714,40 +712,16 @@ def from_hf_inputs(
714
712
elems = [v [item_idx ] for v in elems_in_modality .values ()]
715
713
items .append (MultiModalKwargsItem .from_elems (elems ))
716
714
717
- return MultiModalKwargs . from_items (items )
715
+ return MultiModalKwargs (items )
718
716
719
- @staticmethod
720
- def from_items (
721
- items : Sequence [MultiModalKwargsItem ],
722
- * ,
723
- pin_memory : bool = False ,
724
- ):
725
- """Construct a new
726
- [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
727
- from multiple items."""
728
- elems_by_key = defaultdict [str , list [MultiModalFieldElem ]](list )
729
- for item in items :
730
- for key , elem in item .items ():
731
- elems_by_key [key ].append (elem )
732
-
733
- data = {
734
- key : elems [0 ].field .reduce_data (elems , pin_memory = pin_memory )
735
- for key , elems in elems_by_key .items () if len (elems ) > 0
736
- }
737
-
738
- return MultiModalKwargs (data , items = items )
739
-
740
- def __init__ (
741
- self ,
742
- data : Mapping [str , NestedTensors ],
743
- * ,
744
- items : Optional [Sequence [MultiModalKwargsItem ]] = None ,
745
- ) -> None :
746
- super ().__init__ (data )
717
def __init__(self, items: Sequence[MultiModalKwargsItem] = ()) -> None:
    """Store `items` grouped by modality and start with no cached data.

    Args:
        items: The per-item keyword arguments to hold. Defaults to an
            empty tuple, which yields an object with no modalities.
    """
    super().__init__()

    # Cache slot for the batched mapping; `get_data` fills it on demand.
    self._data: Optional[Mapping[str, NestedTensors]] = None

    # Items are bucketed by their `modality` attribute.
    self._items_by_modality = dict(
        full_groupby(items, key=lambda item: item.modality))
724
+
751
725
@property
def modalities(self):
    """Return the modality names for which this object holds items."""
    by_modality = self._items_by_modality
    return by_modality.keys()
@@ -839,22 +813,41 @@ def as_kwargs(
839
813
840
814
return cast (BatchedTensorInputs , json_mapped )
841
815
842
- def __delitem__ (self , key : str ) -> None :
843
- super ().__delitem__ (key )
816
def keys(self):
    """Return the keys of the mapping produced by `get_data`."""
    batched = self.get_data()
    return batched.keys()
818
+
819
def values(self):
    """Return the values of the mapping produced by `get_data`."""
    batched = self.get_data()
    return batched.values()
821
+
822
def items(self):
    """Return the (key, value) pairs of the mapping produced by `get_data`."""
    batched = self.get_data()
    return batched.items()
824
+
825
def get(self, key: str, /, default=None):
    """Look up `key` in the mapping produced by `get_data`.

    Returns `default` (``None`` unless given) when the key is absent.
    """
    batched = self.get_data()
    return batched.get(key, default)
827
+
828
def pop(self, key: str, *args, **kwargs):
    """Remove `key` from the batched data and from every underlying item.

    Any extra arguments are forwarded to `dict.pop` on a copy of the mapping
    returned by `get_data`, so the usual semantics apply: a missing key
    raises `KeyError` unless a default is supplied, in which case the
    default is returned.

    Returns:
        Whatever `dict.pop` produced for `key` (the stored value or the
        supplied default).

    Raises:
        KeyError: If `key` is not in the batched data and no default was
            given. In that case no item is modified.
    """
    # Pop from a copy; the cached mapping is invalidated below anyway.
    data = dict(self.get_data())
    res = data.pop(key, *args, **kwargs)

    # A key present in the batched data need not be present in every item,
    # so always pop with a default here. Forwarding the caller's arguments
    # would raise KeyError on the first item lacking the key, after earlier
    # items had already been modified.
    for items in self._items_by_modality.values():
        for item in items:
            item.pop(key, None)

    # Drop the cache so the next `get_data` call rebuilds it from the items.
    self._data = None

    return res
839
+
840
def __iter__(self):
    """Iterate over the mapping produced by `get_data`."""
    batched = self.get_data()
    return iter(batched)
842
+
843
def __getitem__(self, key: str):
    """Return the entry for `key` from the mapping produced by `get_data`."""
    batched = self.get_data()
    return batched[key]
848
845
849
846
def __eq__(self, other: object) -> bool:
    """Compare two instances by their per-modality items.

    Objects that are not instances of this object's class never compare
    equal.
    """
    same_type = isinstance(other, self.__class__)
    return same_type and self._items_by_modality == other._items_by_modality
858
851
859
852
def _validate_modality (self , method_name : str , modality : str ) -> None :
860
853
if not self ._items_by_modality :
@@ -888,6 +881,25 @@ def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
888
881
self ._validate_modality ("get_items" , modality )
889
882
return self ._items_by_modality [modality ]
890
883
884
+ def get_data (self ,
885
+ * ,
886
+ pin_memory : bool = False ) -> Mapping [str , NestedTensors ]:
887
+ if self ._data is not None :
888
+ return self ._data
889
+
890
+ elems_by_key = defaultdict [str , list [MultiModalFieldElem ]](list )
891
+ for items in self ._items_by_modality .values ():
892
+ for item in items :
893
+ for key , elem in item .items ():
894
+ elems_by_key [key ].append (elem )
895
+
896
+ data = {
897
+ key : elems [0 ].field .reduce_data (elems , pin_memory = pin_memory )
898
+ for key , elems in elems_by_key .items () if len (elems ) > 0
899
+ }
900
+ self ._data = data
901
+ return data
902
+
891
903
892
904
MultiModalPlaceholderDict : TypeAlias = Mapping [str , Sequence [PlaceholderRange ]]
893
905
"""
0 commit comments