Skip to content

Commit 5c32143

Browse files
[Refactor] Defer tensor data construction in MultiModalKwargs (#23030)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 94096a4 commit 5c32143

File tree

12 files changed

+73
-104
lines changed

12 files changed

+73
-104
lines changed

tests/multimodal/test_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _dummy_item(modality: str, size_by_key: dict[str, int]):
2525

2626

2727
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
28-
return MultiModalKwargs.from_items([
28+
return MultiModalKwargs([
2929
_dummy_item(modality, size_by_key)
3030
for modality, size_by_key in size_by_key_modality.items()
3131
])

tests/v1/test_serial_utils.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -100,38 +100,6 @@ class MyRequest(msgspec.Struct):
100100

101101

102102
def test_multimodal_kwargs():
103-
d = {
104-
"foo":
105-
torch.zeros(20000, dtype=torch.float16),
106-
"bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
107-
"baz": [
108-
torch.rand((256), dtype=torch.float16),
109-
[
110-
torch.rand((1, 12), dtype=torch.float32),
111-
torch.rand((3, 5, 7), dtype=torch.float64),
112-
], [torch.rand((4, 4), dtype=torch.float16)]
113-
],
114-
}
115-
116-
# pack mm kwargs into a mock request so that it can be decoded properly
117-
req = MyRequest(mm=[MultiModalKwargs(d)])
118-
119-
encoder = MsgpackEncoder()
120-
decoder = MsgpackDecoder(MyRequest)
121-
122-
encoded = encoder.encode(req)
123-
124-
assert len(encoded) == 6
125-
126-
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
127-
128-
# expected total encoding length, should be 44559, +-20 for minor changes
129-
assert 44539 <= total_len <= 44579
130-
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
131-
assert all(nested_equal(d[k], decoded[k]) for k in d)
132-
133-
134-
def test_multimodal_items_by_modality():
135103
e1 = MultiModalFieldElem("audio", "a0",
136104
torch.zeros(1000, dtype=torch.bfloat16),
137105
MultiModalBatchedField())
@@ -151,7 +119,7 @@ def test_multimodal_items_by_modality():
151119
audio = MultiModalKwargsItem.from_elems([e1])
152120
video = MultiModalKwargsItem.from_elems([e2])
153121
image = MultiModalKwargsItem.from_elems([e3, e4])
154-
mm = MultiModalKwargs.from_items([audio, video, image])
122+
mm = MultiModalKwargs([audio, video, image])
155123

156124
# pack mm kwargs into a mock request so that it can be decoded properly
157125
req = MyRequest([mm])

vllm/inputs/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,6 @@ def dummy_data_for_profiling(
240240

241241
return DummyData(
242242
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
243-
multi_modal_data=dec_data.multi_modal_data,
243+
multi_modal_data=dec_data.multi_modal_data.get_data(),
244244
multi_modal_placeholders=dec_data.multi_modal_placeholders,
245245
)

vllm/model_executor/models/prithvi_geospatial_mae.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def apply(
136136
type="multimodal",
137137
prompt=prompt,
138138
prompt_token_ids=[1],
139-
mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
139+
mm_kwargs=MultiModalKwargs(multimodal_kwargs_items),
140140
mm_hashes=None,
141141
mm_placeholders=mm_placeholders,
142142
)

vllm/multimodal/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def from_seq_group(
9999
seq_mm_placeholders = seq_group.multi_modal_placeholders
100100

101101
if not seq_mm_data or not seq_mm_placeholders:
102-
return MultiModalKwargs({}), {}
102+
return MultiModalKwargs(), {}
103103

104104
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
105105

vllm/multimodal/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def get_leaf_size(
4646
) -> int:
4747
# MultiModalKwargs is not a subclass of dict
4848
if isinstance(leaf, MultiModalKwargs):
49-
return cls.get_item_size(leaf.data, debug=debug)
49+
return cls.get_item_size(leaf.get_data(), debug=debug)
5050

5151
# MultiModalKwargsItem is not a subclass of dict
5252
if isinstance(leaf, MultiModalKwargsItem):

vllm/multimodal/inputs.py

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def dummy(modality: str):
653653
def from_elems(elems: Sequence[MultiModalFieldElem]):
654654
return MultiModalKwargsItem({elem.key: elem for elem in elems})
655655

656-
def __init__(self, data: Mapping[str, MultiModalFieldElem]) -> None:
656+
def __init__(self, data: Mapping[str, MultiModalFieldElem] = {}) -> None:
657657
super().__init__(data)
658658

659659
modalities = {elem.modality for elem in self.data.values()}
@@ -668,9 +668,7 @@ def get_data(self) -> Mapping[str, NestedTensors]:
668668
return {key: elem.data for key, elem in self.items()}
669669

670670

671-
# NOTE: UserDict is for V0 compatibility.
672-
# V1 should access individual items via `get_item`.
673-
class MultiModalKwargs(UserDict[str, NestedTensors]):
671+
class MultiModalKwargs:
674672
"""
675673
A dictionary that represents the keyword arguments to
676674
[`torch.nn.Module.forward`][].
@@ -714,40 +712,16 @@ def from_hf_inputs(
714712
elems = [v[item_idx] for v in elems_in_modality.values()]
715713
items.append(MultiModalKwargsItem.from_elems(elems))
716714

717-
return MultiModalKwargs.from_items(items)
715+
return MultiModalKwargs(items)
718716

719-
@staticmethod
720-
def from_items(
721-
items: Sequence[MultiModalKwargsItem],
722-
*,
723-
pin_memory: bool = False,
724-
):
725-
"""Construct a new
726-
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
727-
from multiple items."""
728-
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
729-
for item in items:
730-
for key, elem in item.items():
731-
elems_by_key[key].append(elem)
732-
733-
data = {
734-
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
735-
for key, elems in elems_by_key.items() if len(elems) > 0
736-
}
737-
738-
return MultiModalKwargs(data, items=items)
739-
740-
def __init__(
741-
self,
742-
data: Mapping[str, NestedTensors],
743-
*,
744-
items: Optional[Sequence[MultiModalKwargsItem]] = None,
745-
) -> None:
746-
super().__init__(data)
717+
def __init__(self, items: Sequence[MultiModalKwargsItem] = ()) -> None:
718+
super().__init__()
747719

748-
items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
720+
items_by_modality = full_groupby(items, key=lambda x: x.modality)
749721
self._items_by_modality = dict(items_by_modality)
750722

723+
self._data: Optional[Mapping[str, NestedTensors]] = None
724+
751725
@property
752726
def modalities(self):
753727
return self._items_by_modality.keys()
@@ -839,22 +813,41 @@ def as_kwargs(
839813

840814
return cast(BatchedTensorInputs, json_mapped)
841815

842-
def __delitem__(self, key: str) -> None:
843-
super().__delitem__(key)
816+
def keys(self):
817+
return self.get_data().keys()
818+
819+
def values(self):
820+
return self.get_data().values()
821+
822+
def items(self):
823+
return self.get_data().items()
824+
825+
def get(self, key: str, /, default=None):
826+
return self.get_data().get(key, default)
827+
828+
def pop(self, key: str, *args, **kwargs):
829+
data = dict(self.get_data())
830+
res = data.pop(key, *args, **kwargs)
844831

845832
for items in self._items_by_modality.values():
846833
for item in items:
847-
item.pop(key, None)
834+
item.pop(key, *args, **kwargs)
835+
836+
self._data = None
837+
838+
return res
839+
840+
def __iter__(self):
841+
return iter(self.get_data())
842+
843+
def __getitem__(self, key: str):
844+
return self.get_data()[key]
848845

849846
def __eq__(self, other: object) -> bool:
850847
if not isinstance(other, self.__class__):
851848
return False
852-
if self._items_by_modality != other._items_by_modality:
853-
return False
854849

855-
ks = self.keys()
856-
return (ks == other.keys()
857-
and all(nested_tensors_equal(self[k], other[k]) for k in ks))
850+
return self._items_by_modality == other._items_by_modality
858851

859852
def _validate_modality(self, method_name: str, modality: str) -> None:
860853
if not self._items_by_modality:
@@ -888,6 +881,25 @@ def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
888881
self._validate_modality("get_items", modality)
889882
return self._items_by_modality[modality]
890883

884+
def get_data(self,
885+
*,
886+
pin_memory: bool = False) -> Mapping[str, NestedTensors]:
887+
if self._data is not None:
888+
return self._data
889+
890+
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
891+
for items in self._items_by_modality.values():
892+
for item in items:
893+
for key, elem in item.items():
894+
elems_by_key[key].append(elem)
895+
896+
data = {
897+
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
898+
for key, elems in elems_by_key.items() if len(elems) > 0
899+
}
900+
self._data = data
901+
return data
902+
891903

892904
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
893905
"""

vllm/multimodal/processing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1480,7 +1480,7 @@ def _cached_apply_hf_processor(
14801480
mm_missing_kwargs=mm_missing_kwargs,
14811481
)
14821482

1483-
mm_kwargs = MultiModalKwargs.from_items([
1483+
mm_kwargs = MultiModalKwargs([
14841484
item for cache_items in mm_cache_items_merged.values()
14851485
for item in cache_items
14861486
])

vllm/multimodal/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -402,20 +402,22 @@ def group_mm_kwargs_by_modality(
402402
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
403403
items_lst = list(items)
404404

405-
# mm_kwargs_group = MultiModalKwargs.from_items(items_lst,
406-
# pin_memory=pin_memory)
405+
# mm_kwargs_group = MultiModalKwargs(items_lst) \
406+
# .get_data(pin_memory=pin_memory)
407407

408408
# if device is not None:
409-
# mm_kwargs_group = json_map_leaves(lambda x: x.to(device=device),
410-
# mm_kwargs_group.data)
409+
# mm_kwargs_group = json_map_leaves(
410+
# lambda x: x.to(device=device),
411+
# mm_kwargs_group,
412+
# )
411413

412414
# TODO: Once V0 is removed, we can use the merging logic above
413415
# to avoid creating an extra batch dimension (except for fields
414416
# that are meant to be stacked anyway).
415417
# We will also need to update each model to remove `flatten_bn`.
416418
mm_kwargs_group = MultiModalKwargs.as_kwargs(
417419
MultiModalKwargs.batch(
418-
[MultiModalKwargs.from_items([item]) for item in items_lst],
420+
[MultiModalKwargs([item]) for item in items_lst],
419421
pin_memory=pin_memory,
420422
),
421423
device=device,

vllm/sequence.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ def multi_modal_data(self) -> MultiModalKwargs:
524524
if self.inputs["type"] == "multimodal":
525525
return self.inputs["mm_kwargs"]
526526

527-
return MultiModalKwargs({})
527+
return MultiModalKwargs()
528528

529529
@property
530530
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
@@ -780,7 +780,7 @@ def multi_modal_data(self) -> MultiModalKwargs:
780780
return self.first_seq.multi_modal_data
781781
elif self.encoder_seq is not None:
782782
return self.encoder_seq.multi_modal_data
783-
return MultiModalKwargs({})
783+
return MultiModalKwargs()
784784

785785
@property
786786
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:

0 commit comments

Comments
 (0)