-
-
Notifications
You must be signed in to change notification settings - Fork 10.6k
[V1][Performance] Implement custom serializaton for MultiModalKwargs #16279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 21 commits
Commits
Show all changes
43 commits
Select commit
Hold shift + click to select a range
7b3b6ea
[V1] Zero-copy tensor/ndarray serialization/transmission
njhill 35d1cd9
TypeAlias keyword is python >= 3.10 only
njhill f6f26b6
use highest pickle protocol
njhill 4382a16
Merge remote-tracking branch 'origin/main' into tensor-nocopy
njhill 9d91483
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill ea75bd3
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill 06efa46
Implement custom serializaton for MultiModalKwargs
p88h 543ee7b
Support all NestedTensors
p88h 87b7385
fix formatting
p88h f5db471
proper logger format
p88h 95b0600
Add unit test
njhill 910f30f
pre-commit fix
njhill 747ce1c
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill 478ce09
Fix unrecognized type decode
njhill 2d92af1
use msgspec.Raw for tensor data
p88h 7215037
Merge branch 'main' into serialize-multimodal-kwargs
p88h 7789c99
Merge branch 'main' into tensor-nocopy
p88h c1d62ad
Merge branch 'tensor-nocopy' into serialize-multimodal-kwargs
p88h b2e3219
Get rid of (some) workarounds, slightly more efficient encoding
p88h c98bf9a
style fixes
p88h 139ae1c
properly rename fields
p88h 7ea02a8
Handle scalars properly
njhill e7d010d
Optimization: encode small tensors inline.
njhill face6e4
Implement support for _items_by_modality, review fixes
p88h f946398
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill 095d4fd
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h 60797b4
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill c0c6e43
Update vllm/v1/serial_utils.py
njhill 3b978ad
Update vllm/v1/serial_utils.py
njhill 80d90a5
Update vllm/v1/serial_utils.py
njhill 6bd45dc
Update vllm/v1/serial_utils.py
njhill 97c144b
Update vllm/v1/serial_utils.py
njhill c6c2a90
Comment/docstring updates
njhill 793c39c
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h 714d615
Merge branch 'tensor-nocopy' into serialize-multimodal-kwargs
p88h aa64391
Get rid of (some) workarounds, slightly more efficient encoding
p88h 2d471bc
style fixes
p88h 9ca2552
Implement support for _items_by_modality, review fixes
p88h e1295bc
[Bugfix] Fix bug when dataset is json (#15899)
Chenyaaang 9a81901
[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (…
cyang49 e0483bc
[VLM] Avoid unnecessary dummy multimodal data during processing (#16416)
DarkLight1337 bdacbb8
Merge branch 'main' into serialize-multimodal-kwargs
p88h 7f779ef
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from collections import UserDict | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
import msgspec | ||
import numpy as np | ||
import torch | ||
|
||
from vllm.multimodal.inputs import MultiModalKwargs | ||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder | ||
|
||
|
||
class UnrecognizedType(UserDict): | ||
|
||
def __init__(self, an_int: int): | ||
super().__init__() | ||
self.an_int = an_int | ||
|
||
|
||
@dataclass | ||
class MyType: | ||
tensor1: torch.Tensor | ||
a_string: str | ||
list_of_tensors: list[torch.Tensor] | ||
numpy_array: np.ndarray | ||
unrecognized: UnrecognizedType | ||
|
||
|
||
def test_encode_decode(): | ||
"""Test encode/decode loop with zero-copy tensors.""" | ||
|
||
obj = MyType( | ||
tensor1=torch.randint(low=0, high=100, size=(10, ), dtype=torch.int32), | ||
a_string="hello", | ||
list_of_tensors=[ | ||
torch.rand((1, 10), dtype=torch.float32), | ||
torch.rand((3, 5, 4), dtype=torch.float64) | ||
], | ||
numpy_array=np.arange(20), | ||
unrecognized=UnrecognizedType(33), | ||
) | ||
|
||
encoder = MsgpackEncoder() | ||
decoder = MsgpackDecoder(MyType) | ||
|
||
encoded = encoder.encode(obj) | ||
|
||
# There should be the main buffer + 3 tensor buffers + one ndarray buffer | ||
assert len(encoded) == 5 | ||
|
||
decoded: MyType = decoder.decode(encoded) | ||
|
||
assert_equal(decoded, obj) | ||
|
||
# Test encode_into case | ||
|
||
preallocated = bytearray() | ||
|
||
encoded2 = encoder.encode_into(obj, preallocated) | ||
|
||
assert len(encoded2) == 5 | ||
assert encoded2[0] is preallocated | ||
|
||
decoded2: MyType = decoder.decode(encoded2) | ||
|
||
assert_equal(decoded2, obj) | ||
|
||
|
||
class MyRequest(msgspec.Struct): | ||
mm: Optional[list[MultiModalKwargs]] | ||
|
||
|
||
def test_multimodal_kwargs(): | ||
d = { | ||
"foo": torch.zeros(1000, dtype=torch.float16), | ||
"bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)], | ||
"baz": (torch.zeros(256, dtype=torch.int64), "i'm a tuple") | ||
} | ||
|
||
# pack mm kwargs into a mock request so that it can be decoded properly | ||
req = MyRequest(mm=[MultiModalKwargs(d)]) | ||
|
||
encoder = MsgpackEncoder() | ||
decoder = MsgpackDecoder(MyRequest) | ||
|
||
encoded = encoder.encode(req) | ||
|
||
# 5 total tensors + top level buffer | ||
assert len(encoded) == 6 | ||
|
||
total_len = sum([len(x) for x in encoded]) | ||
|
||
# expected total encoding length, should be 4384, +-20 for minor changes | ||
assert total_len >= 4364 and total_len <= 4404 | ||
|
||
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] | ||
assert torch.equal(d["foo"], decoded["foo"]) | ||
|
||
|
||
def assert_equal(obj1: MyType, obj2: MyType): | ||
assert torch.equal(obj1.tensor1, obj2.tensor1) | ||
assert obj1.a_string == obj2.a_string | ||
assert all( | ||
torch.equal(a, b) | ||
for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors)) | ||
assert np.array_equal(obj1.numpy_array, obj2.numpy_array) | ||
assert obj1.unrecognized.an_int == obj2.unrecognized.an_int |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.