Skip to content

Commit 7cf5492

Browse files
committed
Bring back zero-copy, plus more review updates
Signed-off-by: Staszek Pasko <[email protected]>
1 parent 936c95e commit 7cf5492

File tree

3 files changed

+37
-31
lines changed

3 files changed

+37
-31
lines changed

tests/v1/test_serial_utils.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -70,14 +70,12 @@ def test_encode_decode():
7070

7171
assert_equal(decoded, obj)
7272

73-
# Test encode_into case
73+
# Test whether MsgpackEncoder properly reuses the buffers.
7474

75-
preallocated = bytearray()
76-
77-
encoded2 = encoder.encode_into(obj, preallocated)
75+
encoded2 = encoder.encode(obj)
7876

7977
assert len(encoded2) == 6
80-
assert encoded2[0] is preallocated
78+
assert encoded2[0] is encoded[0]
8179

8280
decoded2: MyType = decoder.decode(encoded2)
8381

@@ -112,7 +110,7 @@ def test_multimodal_kwargs():
112110

113111
assert len(encoded) == 6
114112

115-
total_len = sum(len(x) for x in encoded)
113+
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
116114

117115
# expected total encoding length, should be 44536, +-20 for minor changes
118116
assert total_len >= 44516 and total_len <= 44556
@@ -151,7 +149,7 @@ def test_multimodal_items_by_modality():
151149

152150
assert len(encoded) == 8
153151

154-
total_len = sum([len(x) for x in encoded])
152+
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
155153

156154
# expected total encoding length, should be 14255, +-20 for minor changes
157155
assert total_len >= 14235 and total_len <= 14275

vllm/v1/engine/core.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -507,15 +507,15 @@ def process_output_socket(self, output_path: str, engine_index: int):
507507
"""Output socket IO thread."""
508508

509509
# Msgpack serialization encoding.
510+
# The wrapper keeps an internal encoding buffer that avoids
511+
# creating a new buffer for each encode call.
510512
encoder = MsgpackEncoder()
511-
# Reuse send buffer.
512-
buffer = bytearray()
513513

514514
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
515515
while True:
516516
outputs = self.output_queue.get()
517517
outputs.engine_index = engine_index
518-
buffers = encoder.encode_into(outputs, buffer)
518+
buffers = encoder.encode(outputs)
519519
socket.send_multipart(buffers, copy=False)
520520

521521

vllm/v1/serial_utils.py

Lines changed: 29 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import dataclasses
34
import pickle
45
from collections.abc import Sequence
5-
from dataclasses import asdict
66
from inspect import isclass
77
from itertools import chain
88
from types import FunctionType
@@ -15,14 +15,26 @@
1515
from msgspec import msgpack
1616

1717
from vllm import envs
18-
from vllm.multimodal.inputs import (BaseMultiModalField, MultiModalFieldConfig,
19-
MultiModalFieldElem, MultiModalKwargs,
20-
MultiModalKwargsItem, NestedTensors)
18+
from vllm.multimodal.inputs import (BaseMultiModalField,
19+
MultiModalBatchedField,
20+
MultiModalFieldConfig, MultiModalFieldElem,
21+
MultiModalFlatField, MultiModalKwargs,
22+
MultiModalKwargsItem,
23+
MultiModalSharedField, NestedTensors)
2124

2225
CUSTOM_TYPE_PICKLE = 1
2326
CUSTOM_TYPE_CLOUDPICKLE = 2
2427
CUSTOM_TYPE_RAW_VIEW = 3
2528

29+
# MultiModalField class serialization type map.
30+
# These need to list all possible field types and match them
31+
# to factory methods in `MultiModalFieldConfig`.
32+
MMF_CLASS_TO_FACTORY = {
33+
MultiModalFlatField: "flat",
34+
MultiModalSharedField: "shared",
35+
MultiModalBatchedField: "batched",
36+
}
37+
2638
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
2739

2840

@@ -51,20 +63,15 @@ def __init__(self, size_threshold=None):
5163
self.aux_buffers: Optional[list[bytestr]] = None
5264
self.size_threshold = size_threshold
5365

54-
# TODO - merge these constructors and remove the need for externally managed
55-
# serialization buffers.
5666
def encode(self, obj: Any) -> Sequence[bytestr]:
57-
return self.encode_into(obj, self.msg_buffer)
58-
59-
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
6067
try:
6168
# This `bufs` list allows us to collect direct pointers to backing
6269
# buffers of tensors and np arrays, and return them along with the
6370
# top-level encoded buffer instead of copying their data into the
6471
# new buffer.
65-
self.aux_buffers = [buf]
72+
self.aux_buffers = [self.msg_buffer]
6673
bufs = self.aux_buffers
67-
self.encoder.encode_into(obj, buf)
74+
self.encoder.encode_into(obj, self.msg_buffer)
6875
return bufs
6976
finally:
7077
self.aux_buffers = None
@@ -111,11 +118,8 @@ def _encode_ndarray(
111118
self, obj: np.ndarray
112119
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
113120
assert self.aux_buffers is not None
114-
# Either copy the memoryview directly or flatten the array to bytes.
115-
# Sending memoryviews is theoretically faster, but in this particular
116-
# case, it triggers some unnecessary copies anyway.
117-
# With this, the tensors can still be zero-copy read.
118-
arr_data = obj.tobytes()
121+
# If the array is non-contiguous, we need to copy it first
122+
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
119123
if not obj.shape or obj.nbytes < self.size_threshold:
120124
# Encode small arrays and scalars inline. Using this extension type
121125
# ensures we can avoid copying when decoding.
@@ -136,11 +140,15 @@ def _encode_nested_tensors(self, obj: Any) -> NestedTensors:
136140
return [self._encode_nested_tensors(x) for x in obj]
137141

138142
def _encode_field(self, field: BaseMultiModalField):
139-
# Encode the field as a dictionary + special handling for .field
140-
d = asdict(field)
141-
# Strip first 10 characters and last 5 characters from the class name
142-
# to get the field type name that matches the factory function name.
143-
return (field.__class__.__name__[10:-5].lower(), *d.values())
143+
# Figure out the factory name for the field type.
144+
name = MMF_CLASS_TO_FACTORY.get(field.__class__)
145+
if not name:
146+
raise TypeError(f"Unsupported field type: {field.__class__}")
147+
# We just need to copy all of the field values in order
148+
# which will then be used to reconstruct the field.
149+
field_values = (getattr(field, f.name)
150+
for f in dataclasses.fields(field))
151+
return (name, *field_values)
144152

145153

146154
class MsgpackDecoder:

0 commit comments

Comments
 (0)