Commit 176ba06

Copy memory when sending, zero copy when receiving

This helps reduce memory usage and keeps very good performance.

Signed-off-by: Staszek Pasko <[email protected]>
1 parent 2c0e9a8 commit 176ba06
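
As a rough, standalone illustration of the pattern in the commit title (this is not code from the repository; the array and variable names are invented): the sender pays for one explicit copy when it flattens an array into bytes, while the receiver can view the delivered buffer as an ndarray without copying.

import numpy as np

# Send side: one explicit copy of the array's memory into an immutable
# bytes object that is safe to hand off to the transport.
arr = np.arange(16, dtype=np.float32)
payload = arr.tobytes()

# Receive side: wrap the delivered buffer directly; np.frombuffer shares
# memory with `payload` instead of copying it.
view = np.frombuffer(payload, dtype=np.float32)
assert view.flags["OWNDATA"] is False  # no copy was made
assert not view.flags.writeable        # read-only view, hence the copy on decode for torch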

File tree

  tests/v1/test_serial_utils.py
  vllm/v1/serial_utils.py

2 files changed, +17 -14 lines

tests/v1/test_serial_utils.py

Lines changed: 7 additions & 9 deletions
@@ -105,18 +105,17 @@ def test_multimodal_kwargs():
     # pack mm kwargs into a mock request so that it can be decoded properly
     req = MyRequest(mm=[MultiModalKwargs(d)])
 
-    encoder = MsgpackEncoder(size_threshold=16 * 1024)
+    encoder = MsgpackEncoder()
     decoder = MsgpackDecoder(MyRequest)
 
     encoded = encoder.encode(req)
 
-    # Only "foo" is larger than threshold
-    assert len(encoded) == 2
+    assert len(encoded) == 6
 
     total_len = sum(len(x) for x in encoded)
 
-    # expected total encoding length, should be 24541, +-20 for minor changes
-    assert total_len >= 24521 and total_len <= 24561
+    # expected total encoding length, should be 44536, +-20 for minor changes
+    assert total_len >= 44516 and total_len <= 44556
     decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
     assert all(nested_equal(d[k], decoded[k]) for k in d)
 
@@ -150,13 +149,12 @@ def test_multimodal_items_by_modality():
 
     encoded = encoder.encode(req)
 
-    # All messages are 'small', i.e. below 256MB default
-    assert len(encoded) == 1
+    assert len(encoded) == 8
 
     total_len = sum([len(x) for x in encoded])
 
-    # expected total encoding length, should be 14287, +-20 for minor changes
-    assert total_len >= 14267 and total_len <= 14307
+    # expected total encoding length, should be 14255, +-20 for minor changes
+    assert total_len >= 14235 and total_len <= 14275
     decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
 
     # check all modalities were recovered and do some basic sanity checks
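
The new assertion counts follow from the threshold change in vllm/v1/serial_utils.py below: with the default size_threshold of 256 bytes, every tensor of at least 256 bytes is shipped as its own buffer instead of being inlined, so encode() returns the main msgpack message plus one frame per large array. A hypothetical sketch of that relationship (the arrays and sizes are invented, and it assumes the encoder accepts any msgspec-encodable object):

import numpy as np

from vllm.v1.serial_utils import MsgpackEncoder

small = np.zeros(8, dtype=np.uint8)     # 8 bytes, below the threshold -> inlined
large = np.zeros(1024, dtype=np.uint8)  # 1 KiB, above the threshold -> dedicated buffer

encoder = MsgpackEncoder()              # default size_threshold=256 after this change
encoded = encoder.encode({"small": small, "large": large})

# Under these assumptions: one main msgpack frame plus one aux buffer for `large`.
assert len(encoded) == 2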

vllm/v1/serial_utils.py

Lines changed: 10 additions & 5 deletions
@@ -31,15 +31,15 @@ class MsgpackEncoder:
     Note that unlike vanilla `msgspec` Encoders, this interface is generally
     not thread-safe when encoding tensors / numpy arrays.
 
-    By default, arrays below 256MB are serialized inline.
+    By default, arrays below 256B are serialized inline.
     Larger will get sent via dedicated messages.
     Note that this is a per-tensor limit.
 
     Sending multiple large messages via zeromq saturates memory very quickly.
     See: https://github.com/vllm-project/vllm/issues/16185
     """
 
-    def __init__(self, size_threshold=256 * 1024 * 1024):
+    def __init__(self, size_threshold=256):
         self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
         # This is used as a local stash of buffers that we can then access from
         # our custom `msgspec` hook, `enc_hook`. We don't have a way to
@@ -102,7 +102,12 @@ def _encode_ndarray(
         self, obj: np.ndarray
     ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
         assert self.aux_buffers is not None
-        arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
+        # Either copy the memoryview directly or flatten the array to bytes.
+        # Sending memoryviews is theoretically faster, but in this particular
+        # case, it triggers some unnecessary copies anyway.
+        # With this, the tensors can still be zero-copy read.
+        arr_data = obj.data.tobytes() if obj.data.c_contiguous \
+            else obj.tobytes()
         if not obj.shape or obj.nbytes < self.size_threshold:
             # Encode small arrays and scalars inline. Using this extension type
             # ensures we can avoid copying when decoding.
@@ -165,8 +170,8 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
         dtype, shape, data = arr
         # Copy from inline representation, otherwise Torch is unhappy since
         # the returned memory is non-writeable.
-        buffer = self.aux_buffers[data] if isinstance(
-            data, int) else bytearray(data).copy()
+        buffer = self.aux_buffers[data] if isinstance(data, int) \
+            else bytearray(data)
         return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
 
     def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
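
For the "zero copy when receiving" half, a sketch of how a multi-frame encoding like the one produced above might travel over zeromq; this uses plain pyzmq for illustration and is not the actual vLLM transport code:

import numpy as np
import zmq

ctx = zmq.Context.instance()
push = ctx.socket(zmq.PUSH)
pull = ctx.socket(zmq.PULL)
push.bind("inproc://frames")
pull.connect("inproc://frames")

# Send side: the frames are plain bytes objects (already copied out of the
# arrays), so letting zmq copy them once more keeps lifetime management simple.
frames = [b"header", np.arange(1024, dtype=np.uint8).tobytes()]
push.send_multipart(frames)

# Receive side: copy=False yields zmq.Frame objects whose .buffer is a
# memoryview over the received data, so the ndarray is built without a copy.
received = pull.recv_multipart(copy=False)
tensor_view = np.frombuffer(received[1].buffer, dtype=np.uint8)
assert tensor_view.nbytes == 1024

push.close()
pull.close()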
