
Commit 43d87ec

Implement efficient serialization of MultiModalKwargs
In addition to serializing base Tensors, this now correctly passes Tensors embedded in MultiModalKwargs, handling both V0- and V1-style args. It improves memory usage with large multimodal payloads by a further 50% (but is still not on par with single-threaded behavior).
1 parent 56d4aef commit 43d87ec
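
A minimal round-trip sketch of the new code path, based on the tests added below (the `Request` struct and the `pixel_values` key are illustrative, not part of this change):

    import msgspec
    import torch

    from vllm.multimodal.inputs import MultiModalKwargs
    from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder


    class Request(msgspec.Struct):
        mm: list[MultiModalKwargs]


    req = Request(
        mm=[MultiModalKwargs({"pixel_values": torch.zeros(16, dtype=torch.float16)})])

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(Request)

    # encode() returns a sequence of buffers: the top-level msgpack payload
    # plus additional buffers carrying tensor data.
    buffers = encoder.encode(req)
    decoded = decoder.decode(buffers)
    assert torch.equal(decoded.mm[0]["pixel_values"], req.mm[0]["pixel_values"])

Because encode() returns a list of buffers rather than a single bytes object, the caller can presumably ship them as a multipart message (e.g. over zmq, which serial_utils already imports) and hand the full list back to the decoder without extra copies.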

2 files changed: +147 −0 lines changed

tests/v1/test_serial_utils.py

Lines changed: 98 additions & 0 deletions
@@ -1,10 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 from collections import UserDict
 from dataclasses import dataclass
+from typing import Optional

+import msgspec
 import numpy as np
 import torch

+from vllm.multimodal.inputs import (MultiModalBatchedField,
+                                    MultiModalFieldElem, MultiModalKwargs,
+                                    MultiModalKwargsItem, NestedTensors)
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder


@@ -70,6 +75,99 @@ def test_encode_decode():
     assert_equal(decoded2, obj)


+class MyRequest(msgspec.Struct):
+    mm: Optional[list[MultiModalKwargs]]
+
+
+def test_multimodal_kwargs():
+    d = {
+        "foo":
+        torch.zeros(1000, dtype=torch.float16),
+        "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
+        "baz": [
+            torch.rand((256), dtype=torch.float16),
+            [
+                torch.rand((1, 12), dtype=torch.float32),
+                torch.rand((3, 5, 7), dtype=torch.float64),
+            ], [torch.rand((4, 4), dtype=torch.float16)]
+        ],
+    }
+
+    # pack mm kwargs into a mock request so that it can be decoded properly
+    req = MyRequest(mm=[MultiModalKwargs(d)])
+
+    encoder = MsgpackEncoder()
+    decoder = MsgpackDecoder(MyRequest)
+
+    encoded = encoder.encode(req)
+
+    # 8 total tensors + top level buffer
+    assert len(encoded) == 6
+
+    total_len = sum([len(x) for x in encoded])
+
+    # expected total encoding length, should be 4440, +-20 for minor changes
+    assert total_len >= 4420 and total_len <= 4460
+    decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
+    assert all(nested_equal(d[k], decoded[k]) for k in d)
+
+
+def test_multimodal_items_by_modality():
+    e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
+                                                        dtype=torch.int16),
+                             MultiModalBatchedField())
+    e2 = MultiModalFieldElem(
+        "video",
+        "v0",
+        [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
+        MultiModalBatchedField(),
+    )
+    e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000,
+                                                        dtype=torch.int32),
+                             MultiModalBatchedField())
+    e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000,
+                                                        dtype=torch.int32),
+                             MultiModalBatchedField())
+    audio = MultiModalKwargsItem.from_elems([e1])
+    video = MultiModalKwargsItem.from_elems([e2])
+    image = MultiModalKwargsItem.from_elems([e3, e4])
+    mm = MultiModalKwargs.from_items([audio, video, image])
+
+    # pack mm kwargs into a mock request so that it can be decoded properly
+    req = MyRequest([mm])
+
+    encoder = MsgpackEncoder()
+    decoder = MsgpackDecoder(MyRequest)
+
+    encoded = encoder.encode(req)
+
+    # 5 total tensors + top level buffer
+    assert len(encoded) == 8
+
+    total_len = sum([len(x) for x in encoded])
+
+    # expected total encoding length, should be 7507, +-20 for minor changes
+    assert total_len >= 7487 and total_len <= 7527
+    decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
+
+    # check all modalities were recovered and do some basic sanity checks
+    assert len(decoded.modalities) == 3
+    images = decoded.get_items("image")
+    assert len(images) == 1
+    assert len(images[0].items()) == 2
+    assert list(images[0].keys()) == ["i0", "i1"]
+
+    # check the tensor contents and layout in the main dict
+    assert all(nested_equal(mm[k], decoded[k]) for k in mm)
+
+
+def nested_equal(a: NestedTensors, b: NestedTensors):
+    if isinstance(a, torch.Tensor):
+        return torch.equal(a, b)
+    else:
+        return all([nested_equal(x, y) for (x, y) in zip(a, b)])
+
+
 def assert_equal(obj1: MyType, obj2: MyType):
     assert torch.equal(obj1.tensor1, obj2.tensor1)
     assert obj1.a_string == obj2.a_string

vllm/v1/serial_utils.py

Lines changed: 49 additions & 0 deletions
@@ -2,7 +2,9 @@

 import pickle
 from collections.abc import Sequence
+from dataclasses import asdict
 from inspect import isclass
+from itertools import chain
 from types import FunctionType
 from typing import Any, Optional, Union

@@ -12,6 +14,9 @@
 import zmq
 from msgspec import msgpack

+from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
+                                    MultiModalKwargsItem, NestedTensors)
+
 CUSTOM_TYPE_PICKLE = 1
 CUSTOM_TYPE_CLOUDPICKLE = 2

@@ -64,6 +69,26 @@ def enc_hook(self, obj: Any) -> Any:
         if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
             return self._encode_ndarray(obj)

+        if isinstance(obj, MultiModalKwargs):
+            mm: MultiModalKwargs = obj
+            if mm.modalities:
+                # ignore the main dict, it will be re-indexed.
+                # pass a list of MultiModalKwargsItem, then see below
+                # Any tensors *not* indexed by modality will be ignored.
+                return [mm.get_items(m) for m in mm.modalities]
+            # just return the main dict if there are no modalities
+            return {k: v for k, v in obj.items()}
+
+        if isinstance(obj, MultiModalKwargsItem):
+            # Encode as plain dictionary + special handling for '.field'
+            rd = {}
+            for k, v in obj.items():
+                vv = asdict(v)
+                vv['field'] = pickle.dumps(v.field,
+                                           protocol=pickle.HIGHEST_PROTOCOL)
+                rd[k] = vv
+            return rd
+
         if isinstance(obj, FunctionType):
             # `pickle` is generally faster than cloudpickle, but can have
             # problems serializing methods.
@@ -121,6 +146,12 @@ def dec_hook(self, t: type, obj: Any) -> Any:
         if isclass(t):
             if issubclass(t, np.ndarray):
                 return self._decode_ndarray(obj)
+            if issubclass(t, MultiModalKwargs) and isinstance(obj, dict):
+                return MultiModalKwargs(
+                    {k: self._decode_nested(obj[k])
+                     for k in obj})
+            if issubclass(t, MultiModalKwargs) and isinstance(obj, list):
+                return MultiModalKwargs.from_items(self._decode_items(obj))
             if issubclass(t, torch.Tensor):
                 return torch.from_numpy(self._decode_ndarray(obj))
             return obj
@@ -130,6 +161,24 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
         buffer = self.aux_buffers[data] if isinstance(data, int) else data
         return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)

+    def _decode_items(self, obj: list) -> list[MultiModalKwargsItem]:
+        all = []
+        for item in chain.from_iterable(obj):
+            elems = []
+            for v in item.values():
+                v['data'] = self._decode_nested(v['data'])
+                v['field'] = pickle.loads(v['field'])
+                elems.append(MultiModalFieldElem(**v))
+            all.append(MultiModalKwargsItem.from_elems(elems))
+        return all
+
+    def _decode_nested(self, obj: Any) -> NestedTensors:
+        if isinstance(obj, list) and isinstance(obj[0], str):
+            return torch.from_numpy(self._decode_ndarray(obj))
+        if isinstance(obj, list):
+            return [self._decode_nested(x) for x in obj]
+        raise TypeError(f"Unexpected NestedArray contents: {obj}")
+
     def ext_hook(self, code: int, data: memoryview) -> Any:
         if code == CUSTOM_TYPE_PICKLE:
             return pickle.loads(data)
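
For reference, the decoder hunks above imply the on-the-wire layout for each tensor leaf: a [dtype, shape, data] triple, where data is either an inline byte buffer or an integer index into the out-of-band aux_buffers list, and nested lists of tensors are recursed into element by element. A standalone sketch of that decode path under those assumptions (decode_tensor and decode_nested are hypothetical helpers, not part of the change):

    import numpy as np
    import torch


    def decode_tensor(arr, aux_buffers):
        # arr is assumed to be a [dtype, shape, data] triple produced by the
        # encoder; data is either raw bytes or an index into aux_buffers.
        dtype, shape, data = arr
        buffer = aux_buffers[data] if isinstance(data, int) else data
        return torch.from_numpy(
            np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape))


    def decode_nested(obj, aux_buffers):
        # Mirrors _decode_nested: a list starting with a dtype string is a
        # tensor leaf; any other list is recursed into.
        if isinstance(obj, list) and isinstance(obj[0], str):
            return decode_tensor(obj, aux_buffers)
        if isinstance(obj, list):
            return [decode_nested(x, aux_buffers) for x in obj]
        raise TypeError(f"Unexpected NestedTensors contents: {obj}")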
