Skip to content
Closed
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
7b3b6ea
[V1] Zero-copy tensor/ndarray serialization/transmission
njhill Feb 24, 2025
35d1cd9
TypeAlias keyword is python >= 3.10 only
njhill Feb 25, 2025
f6f26b6
use highest pickle protocol
njhill Mar 1, 2025
4382a16
Merge remote-tracking branch 'origin/main' into tensor-nocopy
njhill Mar 15, 2025
9d91483
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill Mar 25, 2025
ea75bd3
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill Apr 5, 2025
06efa46
Implement custom serializaton for MultiModalKwargs
p88h Apr 8, 2025
543ee7b
Support all NestedTensors
p88h Apr 8, 2025
87b7385
fix formatting
p88h Apr 8, 2025
f5db471
proper logger format
p88h Apr 8, 2025
95b0600
Add unit test
njhill Apr 8, 2025
910f30f
pre-commit fix
njhill Apr 8, 2025
747ce1c
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill Apr 9, 2025
478ce09
Fix unrecognized type decode
njhill Apr 9, 2025
2d92af1
use msgspec.Raw for tensor data
p88h Apr 9, 2025
7215037
Merge branch 'main' into serialize-multimodal-kwargs
p88h Apr 9, 2025
7789c99
Merge branch 'main' into tensor-nocopy
p88h Apr 9, 2025
c1d62ad
Merge branch 'tensor-nocopy' into serialize-multimodal-kwargs
p88h Apr 9, 2025
b2e3219
Get rid of (some) workarounds, slightly more efficient encoding
p88h Apr 9, 2025
c98bf9a
style fixes
p88h Apr 9, 2025
139ae1c
properly rename fields
p88h Apr 9, 2025
7ea02a8
Handle scalars properly
njhill Apr 9, 2025
e7d010d
Optimization: encode small tensors inline.
njhill Apr 9, 2025
face6e4
Implement support for _items_by_modality, review fixes
p88h Apr 9, 2025
f946398
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill Apr 10, 2025
095d4fd
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 10, 2025
60797b4
Merge remote-tracking branch 'refs/remotes/origin/main' into tensor-n…
njhill Apr 10, 2025
c0c6e43
Update vllm/v1/serial_utils.py
njhill Apr 10, 2025
3b978ad
Update vllm/v1/serial_utils.py
njhill Apr 10, 2025
80d90a5
Update vllm/v1/serial_utils.py
njhill Apr 10, 2025
6bd45dc
Update vllm/v1/serial_utils.py
njhill Apr 10, 2025
97c144b
Update vllm/v1/serial_utils.py
njhill Apr 10, 2025
c6c2a90
Comment/docstring updates
njhill Apr 10, 2025
793c39c
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 10, 2025
714d615
Merge branch 'tensor-nocopy' into serialize-multimodal-kwargs
p88h Apr 10, 2025
aa64391
Get rid of (some) workarounds, slightly more efficient encoding
p88h Apr 9, 2025
2d471bc
style fixes
p88h Apr 9, 2025
9ca2552
Implement support for _items_by_modality, review fixes
p88h Apr 9, 2025
e1295bc
[Bugfix] Fix bug when dataset is json (#15899)
Chenyaaang Apr 10, 2025
9a81901
[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (…
cyang49 Apr 10, 2025
e0483bc
[VLM] Avoid unnecessary dummy multimodal data during processing (#16416)
DarkLight1337 Apr 10, 2025
bdacbb8
Merge branch 'main' into serialize-multimodal-kwargs
p88h Apr 10, 2025
7f779ef
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions tests/v1/test_serial_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
from collections import UserDict
from dataclasses import dataclass
from typing import Optional

import msgspec
import numpy as np
import torch

from vllm.multimodal.inputs import MultiModalKwargs
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder


class UnrecognizedType(UserDict):
    """A ``UserDict`` subclass that carries one extra integer attribute.

    Used as the ``unrecognized`` field of ``MyType`` in this test module.
    NOTE(review): the name suggests this stands in for a type the msgpack
    encoder has no dedicated handling for -- confirm against serial_utils.
    """

    def __init__(self, an_int: int):
        # The underlying mapping is left empty; only the attribute is set.
        super().__init__()
        self.an_int = an_int


@dataclass
class MyType:
    """Composite payload used to exercise the encode/decode round trip."""

    # A single torch tensor.
    tensor1: torch.Tensor
    # A plain string field.
    a_string: str
    # Several torch tensors held in a list.
    list_of_tensors: list[torch.Tensor]
    # A numpy array.
    numpy_array: np.ndarray
    # An instance of the UserDict-based helper type defined in this module.
    unrecognized: UnrecognizedType


def test_encode_decode():
    """Test encode/decode loop with zero-copy tensors.

    Builds a ``MyType`` instance, round-trips it through ``MsgpackEncoder``
    and ``MsgpackDecoder`` using both ``encode`` and ``encode_into``, and
    checks the number of produced buffers and the decoded field values.
    """

    obj = MyType(
        tensor1=torch.randint(low=0, high=100, size=(10, ), dtype=torch.int32),
        a_string="hello",
        list_of_tensors=[
            torch.rand((1, 10), dtype=torch.float32),
            torch.rand((3, 5, 4), dtype=torch.float64)
        ],
        numpy_array=np.arange(20),
        unrecognized=UnrecognizedType(33),
    )

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(MyType)

    encoded = encoder.encode(obj)

    # There should be the main buffer + 3 tensor buffers + one ndarray buffer
    assert len(encoded) == 5

    decoded: MyType = decoder.decode(encoded)

    assert_equal(decoded, obj)

    # Test encode_into case

    preallocated = bytearray()

    encoded2 = encoder.encode_into(obj, preallocated)

    assert len(encoded2) == 5
    # The caller-supplied bytearray must be returned as the first buffer.
    assert encoded2[0] is preallocated

    decoded2: MyType = decoder.decode(encoded2)

    assert_equal(decoded2, obj)


class MyRequest(msgspec.Struct):
    """Mock request struct wrapping an optional list of MultiModalKwargs."""

    # Multimodal kwargs carried by the request; may be None.
    mm: Optional[list[MultiModalKwargs]]


def test_multimodal_kwargs():
    """Round-trip a MultiModalKwargs wrapped in a mock request.

    Checks the number of encoded buffers, the total encoded size (within a
    small tolerance), and that the ``"foo"`` tensor survives decoding.
    """
    d = {
        "foo": torch.zeros(1000, dtype=torch.float16),
        "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
        "baz": (torch.zeros(256, dtype=torch.int64), "i'm a tuple")
    }

    # pack mm kwargs into a mock request so that it can be decoded properly
    req = MyRequest(mm=[MultiModalKwargs(d)])

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(MyRequest)

    encoded = encoder.encode(req)

    # 5 total tensors + top level buffer
    assert len(encoded) == 6

    total_len = sum(len(x) for x in encoded)

    # expected total encoding length, should be 4384, +-20 for minor changes
    assert 4364 <= total_len <= 4404

    decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
    assert torch.equal(d["foo"], decoded["foo"])


def assert_equal(obj1: "MyType", obj2: "MyType"):
    """Assert that two ``MyType``-shaped objects hold equal field values.

    Compares ``tensor1`` and each entry of ``list_of_tensors`` with
    ``torch.equal``, ``numpy_array`` with ``np.array_equal``, ``a_string``
    with ``==``, and the ``an_int`` attribute of ``unrecognized``.

    Raises:
        AssertionError: if any field differs, including when the two
            ``list_of_tensors`` have different lengths.
    """
    assert torch.equal(obj1.tensor1, obj2.tensor1)
    assert obj1.a_string == obj2.a_string
    # zip() stops at the shorter input, so without this length check a
    # decoded object with missing (or extra) tensors would compare as equal.
    assert len(obj1.list_of_tensors) == len(obj2.list_of_tensors)
    assert all(
        torch.equal(a, b)
        for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
    assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
    assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
8 changes: 4 additions & 4 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,14 +490,14 @@ def process_input_socket(self, input_path: str, engine_index: int):

while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
type_frame, *data_frames = socket.recv_multipart(copy=False)
request_type = EngineCoreRequestType(bytes(type_frame.buffer))

# Deserialize the request data.
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frame.buffer)
request = decoder.decode(data_frames)

# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
Expand All @@ -514,8 +514,8 @@ def process_output_socket(self, output_path: str, engine_index: int):
while True:
outputs = self.output_queue.get()
outputs.engine_index = engine_index
encoder.encode_into(outputs, buffer)
socket.send(buffer, copy=False)
buffers = encoder.encode_into(outputs, buffer)
socket.send_multipart(buffers, copy=False)


ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
Expand Down
26 changes: 13 additions & 13 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import BackgroundProcHandle

logger = init_logger(__name__)
Expand Down Expand Up @@ -505,8 +505,8 @@ def process_outputs_socket():
# shutdown signal, exit thread.
break

frame = out_socket.recv(copy=False)
outputs = decoder.decode(frame.buffer)
frames = out_socket.recv_multipart(copy=False)
outputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
Expand All @@ -529,7 +529,7 @@ def get_output(self) -> EngineCoreOutputs:
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
# (Identity, RequestType, SerializedRequest)
msg = (self.core_engine.identity, request_type.value,
self.encoder.encode(request))
*self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False)

def call_utility(self, method: str, *args) -> Any:
Expand Down Expand Up @@ -633,8 +633,8 @@ def _ensure_output_queue_task(self):

async def process_outputs_socket():
while True:
(frame, ) = await output_socket.recv_multipart(copy=False)
outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
frames = await output_socket.recv_multipart(copy=False)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
Expand Down Expand Up @@ -666,12 +666,12 @@ def _send_input(self,
if engine is None:
engine = self.core_engine

message = (request_type.value, self.encoder.encode(request))
message = (request_type.value, *self.encoder.encode(request))
return self._send_input_message(message, engine)

def _send_input_message(self, message: tuple[bytes, bytes],
def _send_input_message(self, message: tuple[bytestr, ...],
engine: CoreEngine) -> Awaitable[None]:
message = (engine.identity, ) + message # type: ignore[assignment]
message = (engine.identity, ) + message
return self.input_socket.send_multipart(message, copy=False)

async def call_utility_async(self, method: str, *args) -> Any:
Expand All @@ -684,8 +684,8 @@ async def _call_utility_async(self, method: str, *args,
call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
message = (EngineCoreRequestType.UTILITY.value,
self.encoder.encode((call_id, method, args)))
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
(call_id, method, args)))
await self._send_input_message(message, engine)
self._ensure_output_queue_task()
return await future
Expand Down Expand Up @@ -760,7 +760,7 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],

# Control message used for triggering dp idle mode loop.
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
self.encoder.encode(None))
*self.encoder.encode(None))

self.num_engines_running = 0
self.reqs_in_flight: dict[str, CoreEngine] = {}
Expand Down Expand Up @@ -794,7 +794,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
# tokenized.
request.prompt = None

msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))
msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))

chosen_engine = self.get_core_engine_for_request()
self.reqs_in_flight[request.request_id] = chosen_engine
Expand Down
Loading