Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
7b3b6ea
[V1] Zero-copy tensor/ndarray serialization/transmission
njhill Feb 24, 2025
35d1cd9
TypeAlias keyword is python >= 3.10 only
njhill Feb 25, 2025
f6f26b6
use highest pickle protocol
njhill Mar 1, 2025
4382a16
Merge remote-tracking branch 'origin/main' into tensor-nocopy
njhill Mar 15, 2025
9d91483
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill Mar 25, 2025
ea75bd3
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill Apr 5, 2025
06efa46
Implement custom serialization for MultiModalKwargs
p88h Apr 8, 2025
543ee7b
Support all NestedTensors
p88h Apr 8, 2025
87b7385
fix formatting
p88h Apr 8, 2025
f5db471
proper logger format
p88h Apr 8, 2025
95b0600
Add unit test
njhill Apr 8, 2025
910f30f
pre-commit fix
njhill Apr 8, 2025
747ce1c
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill Apr 9, 2025
478ce09
Fix unrecognized type decode
njhill Apr 9, 2025
2d92af1
use msgspec.Raw for tensor data
p88h Apr 9, 2025
7215037
Merge branch 'main' into serialize-multimodal-kwargs
p88h Apr 9, 2025
7789c99
Merge branch 'main' into tensor-nocopy
p88h Apr 9, 2025
c1d62ad
Merge branch 'tensor-nocopy' into serialize-multimodal-kwargs
p88h Apr 9, 2025
b2e3219
Get rid of (some) workarounds, slightly more efficient encoding
p88h Apr 9, 2025
c98bf9a
style fixes
p88h Apr 9, 2025
139ae1c
properly rename fields
p88h Apr 9, 2025
7ea02a8
Handle scalars properly
njhill Apr 9, 2025
e7d010d
Optimization: encode small tensors inline.
njhill Apr 9, 2025
face6e4
Implement support for _items_by_modality, review fixes
p88h Apr 9, 2025
f946398
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill Apr 10, 2025
095d4fd
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 10, 2025
60797b4
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill Apr 10, 2025
c0c6e43
Update vllm/v1/serial_utils.py
njhill Apr 10, 2025
3b978ad
Update vllm/v1/serial_utils.py
njhill Apr 10, 2025
80d90a5
Update vllm/v1/serial_utils.py
njhill Apr 10, 2025
6bd45dc
Update vllm/v1/serial_utils.py
njhill Apr 10, 2025
97c144b
Update vllm/v1/serial_utils.py
njhill Apr 10, 2025
c6c2a90
Comment/docstring updates
njhill Apr 10, 2025
793c39c
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 10, 2025
714d615
Merge branch 'tensor-nocopy' into serialize-multimodal-kwargs
p88h Apr 10, 2025
aa64391
Get rid of (some) workarounds, slightly more efficient encoding
p88h Apr 9, 2025
2d471bc
style fixes
p88h Apr 9, 2025
9ca2552
Implement support for _items_by_modality, review fixes
p88h Apr 9, 2025
e1295bc
[Bugfix] Fix bug when dataset is json (#15899)
Chenyaaang Apr 10, 2025
9a81901
[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (…
cyang49 Apr 10, 2025
e0483bc
[VLM] Avoid unnecessary dummy multimodal data during processing (#16416)
DarkLight1337 Apr 10, 2025
bdacbb8
Merge branch 'main' into serialize-multimodal-kwargs
p88h Apr 10, 2025
7f779ef
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added images/01.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/02.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/03.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/04.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/05.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/06.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/07.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/08.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/09.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/10.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/11.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/12.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/13.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/14.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/r01.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/r02.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/r03.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/r04.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/r05.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/r06.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/r07.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/r08.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/r09.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/r10.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/r11.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/r12.jpg
Binary file added images/r13.jpg
Binary file added images/r14.jpg
Binary file added images/r15.jpg
Binary file added images/r16.jpg
Binary file added images/result.jpg
100 changes: 99 additions & 1 deletion tests/v1/test_serial_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
from collections import UserDict
from dataclasses import dataclass
from typing import Optional

import msgspec
import numpy as np
import torch

from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem, NestedTensors)
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder


Expand Down Expand Up @@ -36,7 +41,7 @@ def test_encode_decode():
list_of_tensors=[
torch.rand((1, 10), dtype=torch.float32),
torch.rand((3, 5, 4000), dtype=torch.float64),
torch.tensor(1984), # test scalar too
torch.tensor(1984), # test scalar too,
],
numpy_array=np.arange(512),
unrecognized=UnrecognizedType(33),
Expand Down Expand Up @@ -70,6 +75,99 @@ def test_encode_decode():
assert_equal(decoded2, obj)


class MyRequest(msgspec.Struct):
    """Mock request struct used by the multimodal serialization tests.

    Wrapping the kwargs in a typed struct field lets the decoder decode
    them as ``MultiModalKwargs``.
    """
    mm: Optional[list[MultiModalKwargs]]


def test_multimodal_kwargs():
    """Round-trip a MultiModalKwargs built from a plain nested-tensor dict."""
    foo = torch.zeros(1000, dtype=torch.float16)
    bar = [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)]
    baz = [
        torch.rand(256, dtype=torch.float16),
        [
            torch.rand((1, 12), dtype=torch.float32),
            torch.rand((3, 5, 7), dtype=torch.float64),
        ],
        [torch.rand((4, 4), dtype=torch.float16)],
    ]
    tensors = {"foo": foo, "bar": bar, "baz": baz}

    # pack mm kwargs into a mock request so that it can be decoded properly
    request = MyRequest(mm=[MultiModalKwargs(tensors)])

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(MyRequest)

    frames = encoder.encode(request)

    # 8 total tensors + top level buffer
    assert len(frames) == 9

    total_len = sum(len(frame) for frame in frames)

    # expected total encoding length, should be 4362, +-20 for minor changes
    assert 4342 <= total_len <= 4382

    decoded: MultiModalKwargs = decoder.decode(frames).mm[0]
    for key in tensors:
        assert nested_equal(tensors[key], decoded[key])


def test_multimodal_items_by_modality():
    """Round-trip a MultiModalKwargs assembled from per-modality items.

    Checks the number of encoded frames, the approximate encoded size,
    that every modality is recovered, and that tensor contents survive
    the encode/decode cycle.
    """
    e1 = MultiModalFieldElem("audio", "a0",
                             torch.zeros(1000, dtype=torch.int16),
                             MultiModalBatchedField())
    e2 = MultiModalFieldElem(
        "video",
        "v0",
        [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
        MultiModalBatchedField(),
    )
    e3 = MultiModalFieldElem("image", "i0",
                             torch.zeros(1000, dtype=torch.int32),
                             MultiModalBatchedField())
    e4 = MultiModalFieldElem("image", "i1",
                             torch.zeros(1000, dtype=torch.int32),
                             MultiModalBatchedField())
    audio = MultiModalKwargsItem.from_elems([e1])
    video = MultiModalKwargsItem.from_elems([e2])
    image = MultiModalKwargsItem.from_elems([e3, e4])
    mm = MultiModalKwargs.from_items([audio, video, image])

    # pack mm kwargs into a mock request so that it can be decoded properly
    req = MyRequest([mm])

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(MyRequest)

    encoded = encoder.encode(req)

    # 7 total tensors (1 audio + 4 video + 2 image) + top level buffer
    assert len(encoded) == 8

    total_len = sum(len(x) for x in encoded)

    # expected total encoding length, should be 7507, +-20 for minor changes
    assert 7487 <= total_len <= 7527
    decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]

    # check all modalities were recovered and do some basic sanity checks
    assert len(decoded.modalities) == 3
    images = decoded.get_items("image")
    assert len(images) == 1
    assert len(images[0].items()) == 2
    assert list(images[0].keys()) == ["i0", "i1"]

    # check the tensor contents and layout in the main dict
    assert all(nested_equal(mm[k], decoded[k]) for k in mm)


def nested_equal(a: NestedTensors, b: NestedTensors):
    """Return True if two nested tensor structures are element-wise equal.

    A tensor is compared with ``torch.equal``; any other value is treated
    as a sequence and compared item by item, recursively.

    Args:
        a: Reference structure (a tensor or a nested sequence of tensors).
        b: Structure to compare against ``a``.

    Returns:
        True when both structures have the same layout and equal tensors,
        False otherwise.
    """
    if isinstance(a, torch.Tensor):
        # A tensor can only equal another tensor; avoid a TypeError from
        # torch.equal when the other side is a sequence.
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
    # zip() stops at the shorter input, so a length mismatch would
    # otherwise go unnoticed.
    if len(a) != len(b):
        return False
    return all(nested_equal(x, y) for x, y in zip(a, b))


def assert_equal(obj1: MyType, obj2: MyType):
assert torch.equal(obj1.tensor1, obj2.tensor1)
assert obj1.a_string == obj2.a_string
Expand Down
56 changes: 53 additions & 3 deletions vllm/v1/serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import pickle
from collections.abc import Sequence
from dataclasses import asdict
from inspect import isclass
from itertools import chain
from types import FunctionType
from typing import Any, Optional, Union

Expand All @@ -12,12 +14,17 @@
import zmq
from msgspec import msgpack

from vllm.logger import init_logger
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem, NestedTensors)

# msgpack extension type codes identifying payloads that were serialized
# with pickle / cloudpickle rather than as native msgpack types.
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2

# TODO calibrate this size
# NOTE(review): presumably the byte-size cutoff below which array data is
# encoded inline instead of as a separate backing buffer; the usage is not
# visible in this hunk, so confirm.
INLINE_BUF_SIZE_THRESHOLD = 256

# Module-level logger.
logger = init_logger(__name__)
# Alias for the bytes-like buffer types, including zmq frames.
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]


Expand Down Expand Up @@ -64,6 +71,25 @@ def enc_hook(self, obj: Any) -> Any:
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
return self._encode_ndarray(obj)

if isinstance(obj, MultiModalKwargs):
mm: MultiModalKwargs = obj
if mm.modalities:
# ignore the main dict, it will be re-indexed.
# pass a list of MultiModalKwargsItem, then see below
# Any tensors *not* indexed by modality will be ignored.
return [mm.get_items(m) for m in mm.modalities]
# just return the main dict if there are no modalities
return {k: v for k, v in obj.items()}

if isinstance(obj, MultiModalKwargsItem):
rd = {}
for k, v in obj.items():
vv = asdict(v)
vv['field'] = pickle.dumps(v.field,
protocol=pickle.HIGHEST_PROTOCOL)
rd[k] = vv
return rd

if isinstance(obj, FunctionType):
# `pickle` is generally faster than cloudpickle, but can have
# problems serializing methods.
Expand All @@ -87,7 +113,7 @@ def _encode_ndarray(
# We serialize the ndarray as a tuple of native types.
# The data is either inlined if small, or an index into a list of
# backing buffers that we've stashed in `aux_buffers`.
return obj.dtype.str, obj.shape, data
return (obj.dtype.str, obj.shape, data)


class MsgpackDecoder:
Expand Down Expand Up @@ -121,15 +147,39 @@ def dec_hook(self, t: type, obj: Any) -> Any:
if isclass(t):
if issubclass(t, np.ndarray):
return self._decode_ndarray(obj)
if issubclass(t, MultiModalKwargs) and isinstance(obj, dict):
return MultiModalKwargs(
{k: self._decode_nested(obj[k])
for k in obj})
if issubclass(t, MultiModalKwargs) and isinstance(obj, list):
return MultiModalKwargs.from_items(self._decode_items(obj))
if issubclass(t, torch.Tensor):
return torch.from_numpy(self._decode_ndarray(obj))
return obj

def _decode_ndarray(self, arr: Any) -> np.ndarray:
dtype, shape, data = arr
def _decode_ndarray(self, obj: Any) -> np.ndarray:
(dtype, shape, data) = obj
buffer = self.aux_buffers[data] if isinstance(data, int) else data
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)

def _decode_items(self, obj: list) -> list[MultiModalKwargsItem]:
    """Rebuild ``MultiModalKwargsItem`` objects from their encoded form.

    Args:
        obj: A list of lists of encoded items. Each encoded item is a
            dict whose values are the keyword arguments of a
            ``MultiModalFieldElem``, with ``data`` in encoded
            nested-tensor form and ``field`` pickled. The dicts are
            updated in place while decoding.

    Returns:
        A flat list of the decoded items, in the order encountered.
    """
    decoded_items = []
    for item in chain.from_iterable(obj):
        elems = []
        for v in item.values():
            v['data'] = self._decode_nested(v['data'])
            # NOTE(security): pickle.loads can execute arbitrary code;
            # this is only safe when the encoded message comes from a
            # trusted peer.
            v['field'] = pickle.loads(v['field'])
            elems.append(MultiModalFieldElem(**v))
        decoded_items.append(MultiModalKwargsItem.from_elems(elems))
    return decoded_items

def _decode_nested(self, obj: Any) -> NestedTensors:
if isinstance(obj, list) and isinstance(obj[0], str):
return torch.from_numpy(self._decode_ndarray(obj))
if isinstance(obj, list):
return [self._decode_nested(x) for x in obj]
raise TypeError(f"Unexpected NestedArray contents: {obj}")

def ext_hook(self, code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)
Expand Down