From 7b6b7ba3e1e334b07ac7c99b6db6bd113fce7c02 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Thu, 10 Apr 2025 22:52:19 +0200 Subject: [PATCH 01/21] Implement efficient serialization of MultiModalKwargs In addition to serializing base Tensors, this now allows to pass Tensors embedded in MultiModalKwargs correctly. Handles both V0 and V1 style args. Improves memory usage with large multimodal payloads by a further 50% (but still not on par with single-threaded behavior). Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 98 +++++++++++++++++++++++++++++++++++ vllm/v1/serial_utils.py | 49 ++++++++++++++++++ 2 files changed, 147 insertions(+) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 0fc3b074533d..fc0dff33ab62 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -1,10 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 from collections import UserDict from dataclasses import dataclass +from typing import Optional +import msgspec import numpy as np import torch +from vllm.multimodal.inputs import (MultiModalBatchedField, + MultiModalFieldElem, MultiModalKwargs, + MultiModalKwargsItem, NestedTensors) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -70,6 +75,99 @@ def test_encode_decode(): assert_equal(decoded2, obj) +class MyRequest(msgspec.Struct): + mm: Optional[list[MultiModalKwargs]] + + +def test_multimodal_kwargs(): + d = { + "foo": + torch.zeros(1000, dtype=torch.float16), + "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)], + "baz": [ + torch.rand((256), dtype=torch.float16), + [ + torch.rand((1, 12), dtype=torch.float32), + torch.rand((3, 5, 7), dtype=torch.float64), + ], [torch.rand((4, 4), dtype=torch.float16)] + ], + } + + # pack mm kwargs into a mock request so that it can be decoded properly + req = MyRequest(mm=[MultiModalKwargs(d)]) + + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(MyRequest) + + encoded = encoder.encode(req) + + # 8 total tensors + top level buffer + assert len(encoded) == 6 + + total_len = sum([len(x) for x in encoded]) + + # expected total encoding length, should be 4440, +-20 for minor changes + assert total_len >= 4420 and total_len <= 4460 + decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] + assert all(nested_equal(d[k], decoded[k]) for k in d) + + +def test_multimodal_items_by_modality(): + e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000, + dtype=torch.int16), + MultiModalBatchedField()) + e2 = MultiModalFieldElem( + "video", + "v0", + [torch.zeros(1000, dtype=torch.int8) for _ in range(4)], + MultiModalBatchedField(), + ) + e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000, + dtype=torch.int32), + MultiModalBatchedField()) + e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000, + dtype=torch.int32), + MultiModalBatchedField()) + audio = MultiModalKwargsItem.from_elems([e1]) + video = MultiModalKwargsItem.from_elems([e2]) + image = MultiModalKwargsItem.from_elems([e3, e4]) + mm = MultiModalKwargs.from_items([audio, video, image]) + + # pack mm kwargs into a mock request so that it can be decoded properly + req = MyRequest([mm]) + + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(MyRequest) + + encoded = encoder.encode(req) + + # 5 total tensors + top level buffer + assert len(encoded) == 8 + + total_len = sum([len(x) for x in encoded]) + + # expected total encoding length, should be 7507, +-20 for minor changes + assert total_len >= 7487 and total_len <= 7527 + decoded: MultiModalKwargs = 
decoder.decode(encoded).mm[0] + + # check all modalities were recovered and do some basic sanity checks + assert len(decoded.modalities) == 3 + images = decoded.get_items("image") + assert len(images) == 1 + assert len(images[0].items()) == 2 + assert list(images[0].keys()) == ["i0", "i1"] + + # check the tensor contents and layout in the main dict + assert all(nested_equal(mm[k], decoded[k]) for k in mm) + + +def nested_equal(a: NestedTensors, b: NestedTensors): + if isinstance(a, torch.Tensor): + return torch.equal(a, b) + else: + return all([nested_equal(x, y) for (x, y) in zip(a, b)]) + + def assert_equal(obj1: MyType, obj2: MyType): assert torch.equal(obj1.tensor1, obj2.tensor1) assert obj1.a_string == obj2.a_string diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 99b352fdef80..53e37ee961cd 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -2,7 +2,9 @@ import pickle from collections.abc import Sequence +from dataclasses import asdict from inspect import isclass +from itertools import chain from types import FunctionType from typing import Any, Optional, Union @@ -12,6 +14,9 @@ import zmq from msgspec import msgpack +from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs, + MultiModalKwargsItem, NestedTensors) + CUSTOM_TYPE_PICKLE = 1 CUSTOM_TYPE_CLOUDPICKLE = 2 @@ -64,6 +69,26 @@ def enc_hook(self, obj: Any) -> Any: if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): return self._encode_ndarray(obj) + if isinstance(obj, MultiModalKwargs): + mm: MultiModalKwargs = obj + if mm.modalities: + # ignore the main dict, it will be re-indexed. + # pass a list of MultiModalKwargsItem, then see below + # Any tensors *not* indexed by modality will be ignored. + return [mm.get_items(m) for m in mm.modalities] + # just return the main dict if there are no modalities + return {k: v for k, v in obj.items()} + + if isinstance(obj, MultiModalKwargsItem): + # Encode as plain dictionary + special handling for '.field' + rd = {} + for k, v in obj.items(): + vv = asdict(v) + vv['field'] = pickle.dumps(v.field, + protocol=pickle.HIGHEST_PROTOCOL) + rd[k] = vv + return rd + if isinstance(obj, FunctionType): # `pickle` is generally faster than cloudpickle, but can have # problems serializing methods. 
@@ -121,6 +146,12 @@ def dec_hook(self, t: type, obj: Any) -> Any: if isclass(t): if issubclass(t, np.ndarray): return self._decode_ndarray(obj) + if issubclass(t, MultiModalKwargs) and isinstance(obj, dict): + return MultiModalKwargs( + {k: self._decode_nested(obj[k]) + for k in obj}) + if issubclass(t, MultiModalKwargs) and isinstance(obj, list): + return MultiModalKwargs.from_items(self._decode_items(obj)) if issubclass(t, torch.Tensor): return torch.from_numpy(self._decode_ndarray(obj)) return obj @@ -130,6 +161,24 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray: buffer = self.aux_buffers[data] if isinstance(data, int) else data return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) + def _decode_items(self, obj: list) -> list[MultiModalKwargsItem]: + all = [] + for item in chain.from_iterable(obj): + elems = [] + for v in item.values(): + v['data'] = self._decode_nested(v['data']) + v['field'] = pickle.loads(v['field']) + elems.append(MultiModalFieldElem(**v)) + all.append(MultiModalKwargsItem.from_elems(elems)) + return all + + def _decode_nested(self, obj: Any) -> NestedTensors: + if isinstance(obj, list) and isinstance(obj[0], str): + return torch.from_numpy(self._decode_ndarray(obj)) + if isinstance(obj, list): + return [self._decode_nested(x) for x in obj] + raise TypeError(f"Unexpected NestedArray contents: {obj}") + def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_PICKLE: return pickle.loads(data) From 4bdd16e6068f26540c66492c3dc92837a2025953 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Staszek=20Pa=C5=9Bko?= Date: Fri, 11 Apr 2025 13:38:21 +0200 Subject: [PATCH 02/21] Apply suggestions from code review Co-authored-by: Cyrus Leung Co-authored-by: Nick Hill Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 4 ++-- vllm/v1/serial_utils.py | 24 ++++++++++++------------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index fc0dff33ab62..ff2b1a662343 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -104,7 +104,7 @@ def test_multimodal_kwargs(): # 8 total tensors + top level buffer assert len(encoded) == 6 - total_len = sum([len(x) for x in encoded]) + total_len = sum(len(x) for x in encoded) # expected total encoding length, should be 4440, +-20 for minor changes assert total_len >= 4420 and total_len <= 4460 @@ -165,7 +165,7 @@ def nested_equal(a: NestedTensors, b: NestedTensors): if isinstance(a, torch.Tensor): return torch.equal(a, b) else: - return all([nested_equal(x, y) for (x, y) in zip(a, b)]) + return all(nested_equal(x, y) for x, y in zip(a, b)) def assert_equal(obj1: MyType, obj2: MyType): diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 53e37ee961cd..df30c76b76ad 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -75,9 +75,9 @@ def enc_hook(self, obj: Any) -> Any: # ignore the main dict, it will be re-indexed. # pass a list of MultiModalKwargsItem, then see below # Any tensors *not* indexed by modality will be ignored. 
- return [mm.get_items(m) for m in mm.modalities] + return mm._items_by_modality.values() # just return the main dict if there are no modalities - return {k: v for k, v in obj.items()} + return dict(mm) if isinstance(obj, MultiModalKwargsItem): # Encode as plain dictionary + special handling for '.field' @@ -146,12 +146,12 @@ def dec_hook(self, t: type, obj: Any) -> Any: if isclass(t): if issubclass(t, np.ndarray): return self._decode_ndarray(obj) - if issubclass(t, MultiModalKwargs) and isinstance(obj, dict): + if issubclass(t, MultiModalKwargs) + if isinstance(obj, list): + return MultiModalKwargs.from_items(self._decode_items(obj)) return MultiModalKwargs( - {k: self._decode_nested(obj[k]) - for k in obj}) - if issubclass(t, MultiModalKwargs) and isinstance(obj, list): - return MultiModalKwargs.from_items(self._decode_items(obj)) + {k: self._decode_nested(v) + for k in obj.items()}) if issubclass(t, torch.Tensor): return torch.from_numpy(self._decode_ndarray(obj)) return obj @@ -172,12 +172,12 @@ def _decode_items(self, obj: list) -> list[MultiModalKwargsItem]: all.append(MultiModalKwargsItem.from_elems(elems)) return all - def _decode_nested(self, obj: Any) -> NestedTensors: - if isinstance(obj, list) and isinstance(obj[0], str): + def _decode_nested_tensors(self, obj: Any) -> NestedTensors: + if not isinstance(obj, list): + raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}") + if obj and isinstance(obj[0], str): return torch.from_numpy(self._decode_ndarray(obj)) - if isinstance(obj, list): - return [self._decode_nested(x) for x in obj] - raise TypeError(f"Unexpected NestedArray contents: {obj}") + return [self._decode_nested_tensors(x) for x in obj] def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_PICKLE: From e5931afe14334480d1dcdb3239c075b85410594c Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Fri, 11 Apr 2025 13:43:42 +0200 Subject: [PATCH 03/21] Additional fixes after code review Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index df30c76b76ad..57f936462c8d 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -146,14 +146,14 @@ def dec_hook(self, t: type, obj: Any) -> Any: if isclass(t): if issubclass(t, np.ndarray): return self._decode_ndarray(obj) - if issubclass(t, MultiModalKwargs) + if issubclass(t, torch.Tensor): + return torch.from_numpy(self._decode_ndarray(obj)) + if issubclass(t, MultiModalKwargs): if isinstance(obj, list): - return MultiModalKwargs.from_items(self._decode_items(obj)) + return MultiModalKwargs.from_items(self._decode_mm_items(obj)) return MultiModalKwargs( {k: self._decode_nested(v) for k in obj.items()}) - if issubclass(t, torch.Tensor): - return torch.from_numpy(self._decode_ndarray(obj)) return obj def _decode_ndarray(self, arr: Any) -> np.ndarray: @@ -161,7 +161,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray: buffer = self.aux_buffers[data] if isinstance(data, int) else data return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) - def _decode_items(self, obj: list) -> list[MultiModalKwargsItem]: + def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: all = [] for item in chain.from_iterable(obj): elems = [] From 664158469f4b9c05404d876ec4fc61ed8e56cfd3 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Fri, 11 Apr 2025 13:49:20 +0200 Subject: [PATCH 04/21] Fix some broken bits & reformat Signed-off-by: Staszek 
Pasko --- vllm/v1/serial_utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 57f936462c8d..5cd687b32f3a 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -75,7 +75,7 @@ def enc_hook(self, obj: Any) -> Any: # ignore the main dict, it will be re-indexed. # pass a list of MultiModalKwargsItem, then see below # Any tensors *not* indexed by modality will be ignored. - return mm._items_by_modality.values() + return list(mm._items_by_modality.values()) # just return the main dict if there are no modalities return dict(mm) @@ -150,10 +150,12 @@ def dec_hook(self, t: type, obj: Any) -> Any: return torch.from_numpy(self._decode_ndarray(obj)) if issubclass(t, MultiModalKwargs): if isinstance(obj, list): - return MultiModalKwargs.from_items(self._decode_mm_items(obj)) - return MultiModalKwargs( - {k: self._decode_nested(v) - for k in obj.items()}) + return MultiModalKwargs.from_items( + self._decode_mm_items(obj)) + return MultiModalKwargs({ + k: self._decode_nested_tensors(v) + for k, v in obj.items() + }) return obj def _decode_ndarray(self, arr: Any) -> np.ndarray: @@ -166,7 +168,7 @@ def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: for item in chain.from_iterable(obj): elems = [] for v in item.values(): - v['data'] = self._decode_nested(v['data']) + v['data'] = self._decode_nested_tensors(v['data']) v['field'] = pickle.loads(v['field']) elems.append(MultiModalFieldElem(**v)) all.append(MultiModalKwargsItem.from_elems(elems)) From a94df99013b9ec97c2b31f757fec628ae1ebe646 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Fri, 11 Apr 2025 15:17:32 +0200 Subject: [PATCH 05/21] Add custom support for MultiModalFieldConfig, less pickle Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 4 ++-- vllm/multimodal/inputs.py | 18 ++++++++++++++++++ vllm/v1/serial_utils.py | 33 ++++++++++++++++++++------------- 3 files changed, 40 insertions(+), 15 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index ff2b1a662343..ac7e3101c478 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -146,8 +146,8 @@ def test_multimodal_items_by_modality(): total_len = sum([len(x) for x in encoded]) - # expected total encoding length, should be 7507, +-20 for minor changes - assert total_len >= 7487 and total_len <= 7527 + # expected total encoding length, should be 7263, +-20 for minor changes + assert total_len >= 7243 and total_len <= 7283 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] # check all modalities were recovered and do some basic sanity checks diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 53729799b629..98a46a5b3dfb 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -282,6 +282,15 @@ def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors: return self._reduce_data([item.data for item in elems]) + @abstractmethod + def field_type(self) -> tuple[str, ...]: + """ + Return the type of this field instance and constructor args. + + Required for serialization. 
+ """ + raise NotImplementedError + @dataclass(frozen=True) class MultiModalBatchedField(BaseMultiModalField): @@ -312,6 +321,9 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: return batch + def field_type(self) -> tuple[str, ...]: + return ("batched") + @dataclass(frozen=True) class MultiModalFlatField(BaseMultiModalField): @@ -344,6 +356,9 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: return [e for elem in batch for e in elem] + def field_type(self) -> tuple[str, ...]: + return ("flat", self.slices) + @dataclass(frozen=True) class MultiModalSharedField(BaseMultiModalField): @@ -365,6 +380,9 @@ def build_elems( def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: return batch[0] + def field_type(self) -> tuple[str, ...]: + return ("shared", self.batch_size) + class MultiModalFieldConfig: diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 5cd687b32f3a..bb8d26bc1dce 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -14,8 +14,9 @@ import zmq from msgspec import msgpack -from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, NestedTensors) +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalFieldElem, + MultiModalKwargs, MultiModalKwargsItem, + NestedTensors) CUSTOM_TYPE_PICKLE = 1 CUSTOM_TYPE_CLOUDPICKLE = 2 @@ -80,14 +81,13 @@ def enc_hook(self, obj: Any) -> Any: return dict(mm) if isinstance(obj, MultiModalKwargsItem): - # Encode as plain dictionary + special handling for '.field' - rd = {} - for k, v in obj.items(): - vv = asdict(v) - vv['field'] = pickle.dumps(v.field, - protocol=pickle.HIGHEST_PROTOCOL) - rd[k] = vv - return rd + ret = [] + for elem in obj.values(): + # Encode as plain dictionary + special handling for .field + d = asdict(elem) + d["field"] = elem.field.field_type() + ret.append(d) + return ret if isinstance(obj, FunctionType): # `pickle` is generally faster than cloudpickle, but can have @@ -167,9 +167,16 @@ def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: all = [] for item in chain.from_iterable(obj): elems = [] - for v in item.values(): - v['data'] = self._decode_nested_tensors(v['data']) - v['field'] = pickle.loads(v['field']) + for v in item: + v["data"] = self._decode_nested_tensors(v["data"]) + # Reconstruct the field processor using MultiModalFieldConfig + field = v["field"] + if isinstance(field, list) and len(field) > 1: + v["field"] = getattr(MultiModalFieldConfig, + field[0])(None, **field[1:]).field + else: + v["field"] = getattr(MultiModalFieldConfig, + field)(None).field elems.append(MultiModalFieldElem(**v)) all.append(MultiModalKwargsItem.from_elems(elems)) return all From 57467e2d65e6d04388d0430d3e72aca8d714000d Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Fri, 11 Apr 2025 15:50:33 +0200 Subject: [PATCH 06/21] Too many stars. Test for other field types. 
Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 4 ++-- vllm/v1/serial_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index ac7e3101c478..94b5e7da0557 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -9,7 +9,7 @@ from vllm.multimodal.inputs import (MultiModalBatchedField, MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, NestedTensors) + MultiModalKwargsItem, MultiModalSharedField, NestedTensors) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -124,7 +124,7 @@ def test_multimodal_items_by_modality(): ) e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000, dtype=torch.int32), - MultiModalBatchedField()) + MultiModalSharedField(4)) e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000, dtype=torch.int32), MultiModalBatchedField()) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index bb8d26bc1dce..350810e378ab 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -173,7 +173,7 @@ def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: field = v["field"] if isinstance(field, list) and len(field) > 1: v["field"] = getattr(MultiModalFieldConfig, - field[0])(None, **field[1:]).field + field[0])(None, *field[1:]).field else: v["field"] = getattr(MultiModalFieldConfig, field)(None).field From d993e421a05c027fd495bbb7af3ade85de83db21 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Fri, 11 Apr 2025 18:44:41 +0200 Subject: [PATCH 07/21] Set zero-copy threshold to 256MB. Also copy out tensors. Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 31 +++++++++++++++++++------------ vllm/v1/serial_utils.py | 31 +++++++++++++++++-------------- 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 94b5e7da0557..4d92853e992f 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -9,7 +9,8 @@ from vllm.multimodal.inputs import (MultiModalBatchedField, MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, MultiModalSharedField, NestedTensors) + MultiModalKwargsItem, + MultiModalSharedField, NestedTensors) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -47,7 +48,7 @@ def test_encode_decode(): unrecognized=UnrecognizedType(33), ) - encoder = MsgpackEncoder() + encoder = MsgpackEncoder(size_threshold=256) decoder = MsgpackDecoder(MyType) encoded = encoder.encode(obj) @@ -82,7 +83,7 @@ class MyRequest(msgspec.Struct): def test_multimodal_kwargs(): d = { "foo": - torch.zeros(1000, dtype=torch.float16), + torch.zeros(20000, dtype=torch.float16), "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)], "baz": [ torch.rand((256), dtype=torch.float16), @@ -96,18 +97,18 @@ def test_multimodal_kwargs(): # pack mm kwargs into a mock request so that it can be decoded properly req = MyRequest(mm=[MultiModalKwargs(d)]) - encoder = MsgpackEncoder() + encoder = MsgpackEncoder(size_threshold=16 * 1024) decoder = MsgpackDecoder(MyRequest) encoded = encoder.encode(req) - # 8 total tensors + top level buffer - assert len(encoded) == 6 + # Only "foo" is larger than threshold + assert len(encoded) == 2 total_len = sum(len(x) for x in encoded) - # expected total encoding length, should be 4440, +-20 for minor changes - assert total_len >= 4420 and total_len <= 4460 + # expected total encoding length, should be 24541, +-20 for minor changes + assert 
total_len >= 24521 and total_len <= 24561 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] assert all(nested_equal(d[k], decoded[k]) for k in d) @@ -141,13 +142,13 @@ def test_multimodal_items_by_modality(): encoded = encoder.encode(req) - # 5 total tensors + top level buffer - assert len(encoded) == 8 + # All messages are 'small', i.e. below 256MB default + assert len(encoded) == 1 total_len = sum([len(x) for x in encoded]) - # expected total encoding length, should be 7263, +-20 for minor changes - assert total_len >= 7243 and total_len <= 7283 + # expected total encoding length, should be 14252, +-20 for minor changes + assert total_len >= 14232 and total_len <= 14272 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] # check all modalities were recovered and do some basic sanity checks @@ -176,3 +177,9 @@ def assert_equal(obj1: MyType, obj2: MyType): for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors)) assert np.array_equal(obj1.numpy_array, obj2.numpy_array) assert obj1.unrecognized.an_int == obj2.unrecognized.an_int + + +if __name__ == "__main__": + test_encode_decode() + test_multimodal_kwargs() + test_multimodal_items_by_modality() diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 350810e378ab..f19b2900617b 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -21,9 +21,6 @@ CUSTOM_TYPE_PICKLE = 1 CUSTOM_TYPE_CLOUDPICKLE = 2 -# TODO calibrate this size -INLINE_BUF_SIZE_THRESHOLD = 256 - bytestr = Union[bytes, bytearray, memoryview, zmq.Frame] @@ -32,29 +29,32 @@ class MsgpackEncoder: Note that unlike vanilla `msgspec` Encoders, this interface is generally not thread-safe when encoding tensors / numpy arrays. + + By default, arrays below 256MB are serialized inline. + Larger will get sent via dedicated messages. + Note that this is a per-tensor limit. + + Sending multiple large messages via zeromq saturates memory very quickly. + See: https://github.com/vllm-project/vllm/issues/16185 """ - def __init__(self): + def __init__(self, size_threshold=256 * 1024 * 1024): self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) # This is used as a local stash of buffers that we can then access from # our custom `msgspec` hook, `enc_hook`. We don't have a way to # pass custom data to the hook otherwise. self.aux_buffers: Optional[list[bytestr]] = None + self.size_threshold = size_threshold def encode(self, obj: Any) -> Sequence[bytestr]: + return self.encode_into(obj, bytearray()) + + def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: try: - self.aux_buffers = bufs = [b''] - bufs[0] = self.encoder.encode(obj) # This `bufs` list allows us to collect direct pointers to backing # buffers of tensors and np arrays, and return them along with the # top-level encoded buffer instead of copying their data into the # new buffer. - return bufs - finally: - self.aux_buffers = None - - def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: - try: self.aux_buffers = [buf] bufs = self.aux_buffers self.encoder.encode_into(obj, buf) @@ -101,7 +101,7 @@ def _encode_ndarray( self, obj: np.ndarray ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None - if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD: + if not obj.shape or obj.nbytes < self.size_threshold: # Encode small arrays and scalars inline. 
data = obj.data else: @@ -160,7 +160,10 @@ def dec_hook(self, t: type, obj: Any) -> Any: def _decode_ndarray(self, arr: Any) -> np.ndarray: dtype, shape, data = arr - buffer = self.aux_buffers[data] if isinstance(data, int) else data + # Copy from inline representation, otherwise Torch is unhappy since + # the returned memory is non-writeable. + buffer = self.aux_buffers[data] if isinstance( + data, int) else bytearray(data).copy() return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: From 3401429de804413d4bd94cc55a8aa115fed8d91e Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Fri, 11 Apr 2025 19:14:09 +0200 Subject: [PATCH 08/21] Make mypy happy, and also simplify field type restore Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 4 ++-- vllm/multimodal/inputs.py | 10 +++++----- vllm/v1/serial_utils.py | 8 ++------ 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 4d92853e992f..d2c67bbb6ec9 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -147,8 +147,8 @@ def test_multimodal_items_by_modality(): total_len = sum([len(x) for x in encoded]) - # expected total encoding length, should be 14252, +-20 for minor changes - assert total_len >= 14232 and total_len <= 14272 + # expected total encoding length, should be 14287, +-20 for minor changes + assert total_len >= 14267 and total_len <= 14307 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] # check all modalities were recovered and do some basic sanity checks diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 98a46a5b3dfb..bfa6596ef175 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -283,7 +283,7 @@ def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors: return self._reduce_data([item.data for item in elems]) @abstractmethod - def field_type(self) -> tuple[str, ...]: + def field_type(self) -> tuple[Any, ...]: """ Return the type of this field instance and constructor args. 
@@ -321,8 +321,8 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: return batch - def field_type(self) -> tuple[str, ...]: - return ("batched") + def field_type(self) -> tuple[Any, ...]: + return ("batched", ) @dataclass(frozen=True) @@ -356,7 +356,7 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: return [e for elem in batch for e in elem] - def field_type(self) -> tuple[str, ...]: + def field_type(self) -> tuple[Any, ...]: return ("flat", self.slices) @@ -380,7 +380,7 @@ def build_elems( def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: return batch[0] - def field_type(self) -> tuple[str, ...]: + def field_type(self) -> tuple[Any, ...]: return ("shared", self.batch_size) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index f19b2900617b..609d536fa3de 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -174,12 +174,8 @@ def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: v["data"] = self._decode_nested_tensors(v["data"]) # Reconstruct the field processor using MultiModalFieldConfig field = v["field"] - if isinstance(field, list) and len(field) > 1: - v["field"] = getattr(MultiModalFieldConfig, - field[0])(None, *field[1:]).field - else: - v["field"] = getattr(MultiModalFieldConfig, - field)(None).field + ctor = getattr(MultiModalFieldConfig, field[0]) + v["field"] = ctor(None, *field[1:]).field elems.append(MultiModalFieldElem(**v)) all.append(MultiModalKwargsItem.from_elems(elems)) return all From 57e19227c7602fdeb7f7ed624fd0ca675e39ceb2 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Sat, 12 Apr 2025 10:49:40 +0200 Subject: [PATCH 09/21] style fix Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index c1a927921b42..b71973701ffa 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -22,7 +22,6 @@ CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_RAW_VIEW = 3 - bytestr = Union[bytes, bytearray, memoryview, zmq.Frame] From 176ba0609d56a173fa797c2f24903a774758e3af Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Sun, 13 Apr 2025 13:16:15 +0200 Subject: [PATCH 10/21] Copy memory when sending, zero copy when receiving This helps reduce memory usage and keeps very good performance. 
Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 16 +++++++--------- vllm/v1/serial_utils.py | 15 ++++++++++----- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 1349d988094a..57a19a99093d 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -105,18 +105,17 @@ def test_multimodal_kwargs(): # pack mm kwargs into a mock request so that it can be decoded properly req = MyRequest(mm=[MultiModalKwargs(d)]) - encoder = MsgpackEncoder(size_threshold=16 * 1024) + encoder = MsgpackEncoder() decoder = MsgpackDecoder(MyRequest) encoded = encoder.encode(req) - # Only "foo" is larger than threshold - assert len(encoded) == 2 + assert len(encoded) == 6 total_len = sum(len(x) for x in encoded) - # expected total encoding length, should be 24541, +-20 for minor changes - assert total_len >= 24521 and total_len <= 24561 + # expected total encoding length, should be 44536, +-20 for minor changes + assert total_len >= 44516 and total_len <= 44556 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] assert all(nested_equal(d[k], decoded[k]) for k in d) @@ -150,13 +149,12 @@ def test_multimodal_items_by_modality(): encoded = encoder.encode(req) - # All messages are 'small', i.e. below 256MB default - assert len(encoded) == 1 + assert len(encoded) == 8 total_len = sum([len(x) for x in encoded]) - # expected total encoding length, should be 14287, +-20 for minor changes - assert total_len >= 14267 and total_len <= 14307 + # expected total encoding length, should be 14255, +-20 for minor changes + assert total_len >= 14235 and total_len <= 14275 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] # check all modalities were recovered and do some basic sanity checks diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index b71973701ffa..fbdc71562344 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -31,7 +31,7 @@ class MsgpackEncoder: Note that unlike vanilla `msgspec` Encoders, this interface is generally not thread-safe when encoding tensors / numpy arrays. - By default, arrays below 256MB are serialized inline. + By default, arrays below 256B are serialized inline. Larger will get sent via dedicated messages. Note that this is a per-tensor limit. @@ -39,7 +39,7 @@ class MsgpackEncoder: See: https://github.com/vllm-project/vllm/issues/16185 """ - def __init__(self, size_threshold=256 * 1024 * 1024): + def __init__(self, size_threshold=256): self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) # This is used as a local stash of buffers that we can then access from # our custom `msgspec` hook, `enc_hook`. We don't have a way to @@ -102,7 +102,12 @@ def _encode_ndarray( self, obj: np.ndarray ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None - arr_data = obj.data if obj.data.c_contiguous else obj.tobytes() + # Either copy the memoryview directly or flatten the array to bytes. + # Sending memoryviews is theoretically faster, but in this particular + # case, it triggers some unnecessary copies anyway. + # With this, the tensors can still be zero-copy read. + arr_data = obj.data.tobytes() if obj.data.c_contiguous \ + else obj.tobytes() if not obj.shape or obj.nbytes < self.size_threshold: # Encode small arrays and scalars inline. Using this extension type # ensures we can avoid copying when decoding. 
@@ -165,8 +170,8 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray: dtype, shape, data = arr # Copy from inline representation, otherwise Torch is unhappy since # the returned memory is non-writeable. - buffer = self.aux_buffers[data] if isinstance( - data, int) else bytearray(data).copy() + buffer = self.aux_buffers[data] if isinstance(data, int) \ + else bytearray(data) return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: From 578aab87a3464ae4afde34ae2651e2b0cb29747c Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Tue, 15 Apr 2025 09:41:15 +0200 Subject: [PATCH 11/21] Add threshold env var, re-do field serialization, cleanup addresses review comments Signed-off-by: Staszek Pasko --- vllm/envs.py | 11 +++++++++++ vllm/multimodal/inputs.py | 18 ------------------ vllm/v1/serial_utils.py | 31 +++++++++++++++++++++---------- 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index f80bf878f79c..d32968c3d173 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -107,6 +107,7 @@ VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 + VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 def get_default_cache_root(): @@ -704,6 +705,16 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # It can be changed with this variable if needed for some reason. "VLLM_XGRAMMAR_CACHE_MB": lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")), + + # Control the threshold for msgspec to use 'zero copy' for + # serialization/deserialization of tensors. Tensors below + # this limit will be encoded into the msgpack buffer, and + # tensors above will instead be sent via a separate message. + # While the sending side still actually copies the tensor + # in all cases, on the receiving side, tensors above this + # limit will actually be zero-copy decoded. + "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": + lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")), } # end-env-vars-definition diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index bfa6596ef175..53729799b629 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -282,15 +282,6 @@ def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors: return self._reduce_data([item.data for item in elems]) - @abstractmethod - def field_type(self) -> tuple[Any, ...]: - """ - Return the type of this field instance and constructor args. - - Required for serialization. 
- """ - raise NotImplementedError - @dataclass(frozen=True) class MultiModalBatchedField(BaseMultiModalField): @@ -321,9 +312,6 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: return batch - def field_type(self) -> tuple[Any, ...]: - return ("batched", ) - @dataclass(frozen=True) class MultiModalFlatField(BaseMultiModalField): @@ -356,9 +344,6 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: return [e for elem in batch for e in elem] - def field_type(self) -> tuple[Any, ...]: - return ("flat", self.slices) - @dataclass(frozen=True) class MultiModalSharedField(BaseMultiModalField): @@ -380,9 +365,6 @@ def build_elems( def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: return batch[0] - def field_type(self) -> tuple[Any, ...]: - return ("shared", self.batch_size) - class MultiModalFieldConfig: diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index fbdc71562344..ab8c32d1be99 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -14,9 +14,10 @@ import zmq from msgspec import msgpack -from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalFieldElem, - MultiModalKwargs, MultiModalKwargsItem, - NestedTensors) +from vllm import envs +from vllm.multimodal.inputs import (BaseMultiModalField, MultiModalFieldConfig, + MultiModalFieldElem, MultiModalKwargs, + MultiModalKwargsItem, NestedTensors) CUSTOM_TYPE_PICKLE = 1 CUSTOM_TYPE_CLOUDPICKLE = 2 @@ -39,16 +40,21 @@ class MsgpackEncoder: See: https://github.com/vllm-project/vllm/issues/16185 """ - def __init__(self, size_threshold=256): + def __init__(self, size_threshold=None): + if (size_threshold is None): + size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) # This is used as a local stash of buffers that we can then access from # our custom `msgspec` hook, `enc_hook`. We don't have a way to # pass custom data to the hook otherwise. + self.msg_buffer = bytearray() self.aux_buffers: Optional[list[bytestr]] = None self.size_threshold = size_threshold + # TODO - merge these constructors and remove the need for externally managed + # serialization buffers. def encode(self, obj: Any) -> Sequence[bytestr]: - return self.encode_into(obj, bytearray()) + return self.encode_into(obj, self.msg_buffer) def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: try: @@ -85,9 +91,8 @@ def enc_hook(self, obj: Any) -> Any: ret = [] for elem in obj.values(): # Encode as plain dictionary + special handling for .field - d = asdict(elem) - d["field"] = elem.field.field_type() - ret.append(d) + ret.append( + asdict(elem) | {"field": self._encode_field(elem.field)}) return ret if isinstance(obj, FunctionType): @@ -106,8 +111,7 @@ def _encode_ndarray( # Sending memoryviews is theoretically faster, but in this particular # case, it triggers some unnecessary copies anyway. # With this, the tensors can still be zero-copy read. - arr_data = obj.data.tobytes() if obj.data.c_contiguous \ - else obj.tobytes() + arr_data = obj.tobytes() if not obj.shape or obj.nbytes < self.size_threshold: # Encode small arrays and scalars inline. Using this extension type # ensures we can avoid copying when decoding. @@ -122,6 +126,13 @@ def _encode_ndarray( # backing buffers that we've stashed in `aux_buffers`. 
return obj.dtype.str, obj.shape, data + def _encode_field(self, field: BaseMultiModalField): + # Encode the field as a dictionary + special handling for .field + d = asdict(field) + # Strip first 10 characters and last 5 characters from the class name + # to get the field type name that matches the factory function name. + return (field.__class__.__name__[10:-5].lower(), *d.values()) + class MsgpackDecoder: """Decoder with custom torch tensor and numpy array serialization. From 936c95ea18d4d49a111c68620772a55d2a932133 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Tue, 15 Apr 2025 20:52:12 +0200 Subject: [PATCH 12/21] remove asdict() which involves object deep copy. Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index ab8c32d1be99..8aca0ed4eb41 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -91,8 +91,12 @@ def enc_hook(self, obj: Any) -> Any: ret = [] for elem in obj.values(): # Encode as plain dictionary + special handling for .field - ret.append( - asdict(elem) | {"field": self._encode_field(elem.field)}) + ret.append({ + "modality": elem.modality, + "key": elem.key, + "data": self._encode_nested_tensors(elem.data), + "field": self._encode_field(elem.field), + }) return ret if isinstance(obj, FunctionType): @@ -126,6 +130,11 @@ def _encode_ndarray( # backing buffers that we've stashed in `aux_buffers`. return obj.dtype.str, obj.shape, data + def _encode_nested_tensors(self, obj: Any) -> NestedTensors: + if isinstance(obj, torch.Tensor): + return self._encode_ndarray(obj.numpy()) + return [self._encode_nested_tensors(x) for x in obj] + def _encode_field(self, field: BaseMultiModalField): # Encode the field as a dictionary + special handling for .field d = asdict(field) From 7cf549205fa62b74fbbf8aa857cfe37d1048e3c1 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Tue, 15 Apr 2025 21:50:21 +0200 Subject: [PATCH 13/21] Bring back zero-copy, plus more review updates Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 12 ++++----- vllm/v1/engine/core.py | 6 ++--- vllm/v1/serial_utils.py | 50 ++++++++++++++++++++--------------- 3 files changed, 37 insertions(+), 31 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 57a19a99093d..7f0ba80e0b86 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -70,14 +70,12 @@ def test_encode_decode(): assert_equal(decoded, obj) - # Test encode_into case + # Test whether MsgpackEncoder properly reuses the buffers. 
- preallocated = bytearray() - - encoded2 = encoder.encode_into(obj, preallocated) + encoded2 = encoder.encode(obj) assert len(encoded2) == 6 - assert encoded2[0] is preallocated + assert encoded2[0] is encoded[0] decoded2: MyType = decoder.decode(encoded2) @@ -112,7 +110,7 @@ def test_multimodal_kwargs(): assert len(encoded) == 6 - total_len = sum(len(x) for x in encoded) + total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) # expected total encoding length, should be 44536, +-20 for minor changes assert total_len >= 44516 and total_len <= 44556 @@ -151,7 +149,7 @@ def test_multimodal_items_by_modality(): assert len(encoded) == 8 - total_len = sum([len(x) for x in encoded]) + total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) # expected total encoding length, should be 14255, +-20 for minor changes assert total_len >= 14235 and total_len <= 14275 diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f642e51001a8..54390c330eb2 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -507,15 +507,15 @@ def process_output_socket(self, output_path: str, engine_index: int): """Output socket IO thread.""" # Msgpack serialization encoding. + # The wrapper keeps an internal encoding buffer that avoids + # creating a new buffer for each encode call. encoder = MsgpackEncoder() - # Reuse send buffer. - buffer = bytearray() with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket: while True: outputs = self.output_queue.get() outputs.engine_index = engine_index - buffers = encoder.encode_into(outputs, buffer) + buffers = encoder.encode(outputs) socket.send_multipart(buffers, copy=False) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 8aca0ed4eb41..fdbc8a309cbf 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +import dataclasses import pickle from collections.abc import Sequence -from dataclasses import asdict from inspect import isclass from itertools import chain from types import FunctionType @@ -15,14 +15,26 @@ from msgspec import msgpack from vllm import envs -from vllm.multimodal.inputs import (BaseMultiModalField, MultiModalFieldConfig, - MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, NestedTensors) +from vllm.multimodal.inputs import (BaseMultiModalField, + MultiModalBatchedField, + MultiModalFieldConfig, MultiModalFieldElem, + MultiModalFlatField, MultiModalKwargs, + MultiModalKwargsItem, + MultiModalSharedField, NestedTensors) CUSTOM_TYPE_PICKLE = 1 CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_RAW_VIEW = 3 +# MultiModealField class serialization type map. +# These need to list all possible field types and match them +# to factory methods in `MultiModalFieldConfig`. +MMF_CLASS_TO_FACTORY = { + MultiModalFlatField: "flat", + MultiModalSharedField: "shared", + MultiModalBatchedField: "batched", +} + bytestr = Union[bytes, bytearray, memoryview, zmq.Frame] @@ -51,20 +63,15 @@ def __init__(self, size_threshold=None): self.aux_buffers: Optional[list[bytestr]] = None self.size_threshold = size_threshold - # TODO - merge these constructors and remove the need for externally managed - # serialization buffers. 
def encode(self, obj: Any) -> Sequence[bytestr]: - return self.encode_into(obj, self.msg_buffer) - - def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: try: # This `bufs` list allows us to collect direct pointers to backing # buffers of tensors and np arrays, and return them along with the # top-level encoded buffer instead of copying their data into the # new buffer. - self.aux_buffers = [buf] + self.aux_buffers = [self.msg_buffer] bufs = self.aux_buffers - self.encoder.encode_into(obj, buf) + self.encoder.encode_into(obj, self.msg_buffer) return bufs finally: self.aux_buffers = None @@ -111,11 +118,8 @@ def _encode_ndarray( self, obj: np.ndarray ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None - # Either copy the memoryview directly or flatten the array to bytes. - # Sending memoryviews is theoretically faster, but in this particular - # case, it triggers some unnecessary copies anyway. - # With this, the tensors can still be zero-copy read. - arr_data = obj.tobytes() + # If the array is non-contiguous, we need to copy it first + arr_data = obj.data if obj.data.c_contiguous else obj.tobytes() if not obj.shape or obj.nbytes < self.size_threshold: # Encode small arrays and scalars inline. Using this extension type # ensures we can avoid copying when decoding. @@ -136,11 +140,15 @@ def _encode_nested_tensors(self, obj: Any) -> NestedTensors: return [self._encode_nested_tensors(x) for x in obj] def _encode_field(self, field: BaseMultiModalField): - # Encode the field as a dictionary + special handling for .field - d = asdict(field) - # Strip first 10 characters and last 5 characters from the class name - # to get the field type name that matches the factory function name. - return (field.__class__.__name__[10:-5].lower(), *d.values()) + # Figure out the factory name for the field type. + name = MMF_CLASS_TO_FACTORY.get(field.__class__) + if not name: + raise TypeError(f"Unsupported field type: {field.__class__}") + # We just need to copy all of the field values in order + # which will be then used to reconstruct the field. + field_values = (getattr(field, f.name) + for f in dataclasses.fields(field)) + return (name, *field_values) class MsgpackDecoder: From 12c9d8bb459d0fc0c113253ee2703e795328d41d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Staszek=20Pa=C5=9Bko?= Date: Tue, 15 Apr 2025 23:24:14 +0200 Subject: [PATCH 14/21] Apply suggestions from code review Co-authored-by: Nick Hill Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 39 +++++++++++++++++---------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index fdbc8a309cbf..9c07f9e55c66 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -52,8 +52,8 @@ class MsgpackEncoder: See: https://github.com/vllm-project/vllm/issues/16185 """ - def __init__(self, size_threshold=None): - if (size_threshold is None): + def __init__(self, size_threshold: Optional[int] = None): + if size_threshold is None: size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) # This is used as a local stash of buffers that we can then access from @@ -84,27 +84,22 @@ def enc_hook(self, obj: Any) -> Any: if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): return self._encode_ndarray(obj) - if isinstance(obj, MultiModalKwargs): - mm: MultiModalKwargs = obj - if mm.modalities: - # ignore the main dict, it will be re-indexed. 
- # pass a list of MultiModalKwargsItem, then see below - # Any tensors *not* indexed by modality will be ignored. - return list(mm._items_by_modality.values()) - # just return the main dict if there are no modalities + if isinstance(obj, MultiModalKwargs): + mm: MultiModalKwargs = obj + if not mm.modalities: + # just return the main dict if there are no modalities. return dict(mm) - - if isinstance(obj, MultiModalKwargsItem): - ret = [] - for elem in obj.values(): - # Encode as plain dictionary + special handling for .field - ret.append({ - "modality": elem.modality, - "key": elem.key, - "data": self._encode_nested_tensors(elem.data), - "field": self._encode_field(elem.field), - }) - return ret + + # ignore the main dict, it will be re-indexed. + # Encode a list of MultiModalKwargsItems as plain dicts + # + special handling for .field. + # Any tensors *not* indexed by modality will be ignored. + return [{ + "modality": elem.modality, + "key": elem.key, + "data": self._encode_nested_tensors(elem.data), + "field": self._encode_field(elem.field), + } for item in mm._items_by_modality.values() for elem in item] if isinstance(obj, FunctionType): # `pickle` is generally faster than cloudpickle, but can have From 8bda83cea41af56ada98afab2dc08cd4abd34b7e Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Tue, 15 Apr 2025 23:40:01 +0200 Subject: [PATCH 15/21] fix review edits Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 9c07f9e55c66..405843f90f18 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -4,7 +4,6 @@ import pickle from collections.abc import Sequence from inspect import isclass -from itertools import chain from types import FunctionType from typing import Any, Optional, Union @@ -84,22 +83,24 @@ def enc_hook(self, obj: Any) -> Any: if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): return self._encode_ndarray(obj) - if isinstance(obj, MultiModalKwargs): - mm: MultiModalKwargs = obj - if not mm.modalities: - # just return the main dict if there are no modalities. - return dict(mm) - - # ignore the main dict, it will be re-indexed. - # Encode a list of MultiModalKwargsItems as plain dicts - # + special handling for .field. - # Any tensors *not* indexed by modality will be ignored. - return [{ - "modality": elem.modality, - "key": elem.key, - "data": self._encode_nested_tensors(elem.data), - "field": self._encode_field(elem.field), - } for item in mm._items_by_modality.values() for elem in item] + if isinstance(obj, MultiModalKwargs): + mm: MultiModalKwargs = obj + if not mm.modalities: + # just return the main dict if there are no modalities. + return dict(mm) + + # ignore the main dict, it will be re-indexed. + # Encode a list of MultiModalKwargsItems as plain dicts + # + special handling for .field. + # Any tensors *not* indexed by modality will be ignored. 
+ return [[{ + "modality": elem.modality, + "key": elem.key, + "data": self._encode_nested_tensors(elem.data), + "field": self._encode_field(elem.field), + } for elem in item.values()] + for itemlist in mm._items_by_modality.values() + for item in itemlist] if isinstance(obj, FunctionType): # `pickle` is generally faster than cloudpickle, but can have @@ -199,7 +200,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray: def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: all = [] - for item in chain.from_iterable(obj): + for item in obj: elems = [] for v in item: v["data"] = self._decode_nested_tensors(v["data"]) From 678cba185049a9f4757052f1e1569b6c60c19d66 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Tue, 15 Apr 2025 23:54:58 +0200 Subject: [PATCH 16/21] revert encode_into changes Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 8 +++++--- vllm/v1/engine/core.py | 6 +++--- vllm/v1/serial_utils.py | 16 ++++++++++++---- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 7f0ba80e0b86..e58d3c403c19 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -70,12 +70,14 @@ def test_encode_decode(): assert_equal(decoded, obj) - # Test whether MsgpackEncoder properly reuses the buffers. + # Test encode_into case - encoded2 = encoder.encode(obj) + preallocated = bytearray() + + encoded2 = encoder.encode_into(obj, preallocated) assert len(encoded2) == 6 - assert encoded2[0] is encoded[0] + assert encoded2[0] is preallocated decoded2: MyType = decoder.decode(encoded2) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 54390c330eb2..f642e51001a8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -507,15 +507,15 @@ def process_output_socket(self, output_path: str, engine_index: int): """Output socket IO thread.""" # Msgpack serialization encoding. - # The wrapper keeps an internal encoding buffer that avoids - # creating a new buffer for each encode call. encoder = MsgpackEncoder() + # Reuse send buffer. + buffer = bytearray() with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket: while True: outputs = self.output_queue.get() outputs.engine_index = engine_index - buffers = encoder.encode(outputs) + buffers = encoder.encode_into(outputs, buffer) socket.send_multipart(buffers, copy=False) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 405843f90f18..e415567b919d 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -58,19 +58,27 @@ def __init__(self, size_threshold: Optional[int] = None): # This is used as a local stash of buffers that we can then access from # our custom `msgspec` hook, `enc_hook`. We don't have a way to # pass custom data to the hook otherwise. - self.msg_buffer = bytearray() self.aux_buffers: Optional[list[bytestr]] = None self.size_threshold = size_threshold def encode(self, obj: Any) -> Sequence[bytestr]: try: + self.aux_buffers = bufs = [b''] + bufs[0] = self.encoder.encode(obj) # This `bufs` list allows us to collect direct pointers to backing # buffers of tensors and np arrays, and return them along with the # top-level encoded buffer instead of copying their data into the # new buffer. 
- self.aux_buffers = [self.msg_buffer] - bufs = self.aux_buffers - self.encoder.encode_into(obj, self.msg_buffer) + return bufs + finally: + self.aux_buffers = None + + # TODO: would be nice to make this automatic + def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: + try: + self.aux_buffers = [buf] + bufs = [buf] + self.encoder.encode_into(obj, buf) return bufs finally: self.aux_buffers = None From f8d26df407cfdd79be01f5731866009f3c991740 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Staszek=20Pa=C5=9Bko?= Date: Wed, 16 Apr 2025 07:28:29 +0200 Subject: [PATCH 17/21] Apply suggestions from code review Signed-off-by: Staszek Pasko Co-authored-by: Nick Hill Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index e415567b919d..55f2eb80d137 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -213,9 +213,9 @@ def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: for v in item: v["data"] = self._decode_nested_tensors(v["data"]) # Reconstruct the field processor using MultiModalFieldConfig - field = v["field"] - ctor = getattr(MultiModalFieldConfig, field[0]) - v["field"] = ctor(None, *field[1:]).field + factory_meth_name, *field_args = v["field"] + factory_meth = getattr(MultiModalFieldConfig, factory_meth_name) + v["field"] = factory_meth(None, *field_args).field elems.append(MultiModalFieldElem(**v)) all.append(MultiModalKwargsItem.from_elems(elems)) return all From bce2f0755271f0889ebda9266b0efec0d32dbb01 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Wed, 16 Apr 2025 07:28:16 +0200 Subject: [PATCH 18/21] Small fixes Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 55f2eb80d137..812099a400f5 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -43,12 +43,8 @@ class MsgpackEncoder: Note that unlike vanilla `msgspec` Encoders, this interface is generally not thread-safe when encoding tensors / numpy arrays. - By default, arrays below 256B are serialized inline. - Larger will get sent via dedicated messages. - Note that this is a per-tensor limit. - - Sending multiple large messages via zeromq saturates memory very quickly. - See: https://github.com/vllm-project/vllm/issues/16185 + By default, arrays below 256B are serialized inline Larger will get sent + via dedicated messages. Note that this is a per-tensor limit. 
""" def __init__(self, size_threshold: Optional[int] = None): @@ -77,7 +73,7 @@ def encode(self, obj: Any) -> Sequence[bytestr]: def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: try: self.aux_buffers = [buf] - bufs = [buf] + bufs = self.aux_buffers self.encoder.encode_into(obj, buf) return bufs finally: @@ -105,7 +101,7 @@ def enc_hook(self, obj: Any) -> Any: "modality": elem.modality, "key": elem.key, "data": self._encode_nested_tensors(elem.data), - "field": self._encode_field(elem.field), + "field": self._encode_mm_field(elem.field), } for elem in item.values()] for itemlist in mm._items_by_modality.values() for item in itemlist] @@ -143,7 +139,7 @@ def _encode_nested_tensors(self, obj: Any) -> NestedTensors: return self._encode_ndarray(obj.numpy()) return [self._encode_nested_tensors(x) for x in obj] - def _encode_field(self, field: BaseMultiModalField): + def _encode_mm_field(self, field: BaseMultiModalField): # Figure out the factory name for the field type. name = MMF_CLASS_TO_FACTORY.get(field.__class__) if not name: From 7511262098add43c242ea5044cabf0048c319d57 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Wed, 16 Apr 2025 07:41:28 +0200 Subject: [PATCH 19/21] style Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 812099a400f5..08600592bd70 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -210,7 +210,8 @@ def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: v["data"] = self._decode_nested_tensors(v["data"]) # Reconstruct the field processor using MultiModalFieldConfig factory_meth_name, *field_args = v["field"] - factory_meth = getattr(MultiModalFieldConfig, factory_meth_name) + factory_meth = getattr(MultiModalFieldConfig, + factory_meth_name) v["field"] = factory_meth(None, *field_args).field elems.append(MultiModalFieldElem(**v)) all.append(MultiModalKwargsItem.from_elems(elems)) From 48ab2d9de9997773e0bb6aa0a3efdb56a1504dec Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Wed, 16 Apr 2025 16:59:28 +0200 Subject: [PATCH 20/21] remove unnecessary comment Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 08600592bd70..f27d84e1a72e 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -69,7 +69,6 @@ def encode(self, obj: Any) -> Sequence[bytestr]: finally: self.aux_buffers = None - # TODO: would be nice to make this automatic def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: try: self.aux_buffers = [buf] From 281f0f149756e0be0529cd7ee093b36c41884795 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 16 Apr 2025 16:18:09 -0700 Subject: [PATCH 21/21] Accommodate floats in NestedTensors Signed-off-by: Nick Hill --- vllm/v1/serial_utils.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index f27d84e1a72e..4f7987ee46a6 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -25,10 +25,10 @@ CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_RAW_VIEW = 3 -# MultiModealField class serialization type map. +# MultiModalField class serialization type map. # These need to list all possible field types and match them # to factory methods in `MultiModalFieldConfig`. 
-MMF_CLASS_TO_FACTORY = { +MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = { MultiModalFlatField: "flat", MultiModalSharedField: "shared", MultiModalBatchedField: "batched", @@ -133,10 +133,14 @@ def _encode_ndarray( # backing buffers that we've stashed in `aux_buffers`. return obj.dtype.str, obj.shape, data - def _encode_nested_tensors(self, obj: Any) -> NestedTensors: - if isinstance(obj, torch.Tensor): - return self._encode_ndarray(obj.numpy()) - return [self._encode_nested_tensors(x) for x in obj] + def _encode_nested_tensors(self, nt: NestedTensors) -> Any: + if isinstance(nt, torch.Tensor): + return self._encode_ndarray(nt.numpy()) + if isinstance(nt, (int, float)): + # Although it violates NestedTensors type, MultiModalKwargs + # values are sometimes floats. + return nt + return [self._encode_nested_tensors(x) for x in nt] def _encode_mm_field(self, field: BaseMultiModalField): # Figure out the factory name for the field type. @@ -147,7 +151,7 @@ def _encode_mm_field(self, field: BaseMultiModalField): # which will be then used to reconstruct the field. field_values = (getattr(field, f.name) for f in dataclasses.fields(field)) - return (name, *field_values) + return name, *field_values class MsgpackDecoder: @@ -202,7 +206,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray: return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: - all = [] + decoded_items = [] for item in obj: elems = [] for v in item: @@ -213,10 +217,14 @@ def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: factory_meth_name) v["field"] = factory_meth(None, *field_args).field elems.append(MultiModalFieldElem(**v)) - all.append(MultiModalKwargsItem.from_elems(elems)) - return all + decoded_items.append(MultiModalKwargsItem.from_elems(elems)) + return decoded_items def _decode_nested_tensors(self, obj: Any) -> NestedTensors: + if isinstance(obj, (int, float)): + # Although it violates NestedTensors type, MultiModalKwargs + # values are sometimes floats. + return obj if not isinstance(obj, list): raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}") if obj and isinstance(obj[0], str):
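For reference, the new tests above already double as usage documentation; the sketch below condenses them into a minimal encode/decode round trip. It assumes the patched vllm/v1/serial_utils.py from this series; MyRequest is the same kind of mock msgspec.Struct wrapper the tests use (not a real vLLM request type), and the size_threshold value here is arbitrary, chosen only to show the inline-vs-separate-frame split; by default it comes from VLLM_MSGPACK_ZERO_COPY_THRESHOLD (256 bytes).

# Minimal round-trip sketch based on tests/v1/test_serial_utils.py.
# Assumes the patched vllm.v1.serial_utils from this series; MyRequest is a
# mock wrapper struct like the one in the tests, not a real vLLM type.
from typing import Optional

import msgspec
import torch

from vllm.multimodal.inputs import MultiModalKwargs
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder


class MyRequest(msgspec.Struct):
    mm: Optional[list[MultiModalKwargs]]


# Tensors at or above size_threshold are returned as separate buffers
# instead of being inlined into the msgpack message.
encoder = MsgpackEncoder(size_threshold=1024)
decoder = MsgpackDecoder(MyRequest)

req = MyRequest(mm=[MultiModalKwargs({
    "small": torch.zeros(8, dtype=torch.int8),        # 8 B, encoded inline
    "large": torch.zeros(4096, dtype=torch.float16),  # 8 KiB, own frame
})])

# encode() returns a sequence of buffers: the top-level msgpack message plus
# one entry per large tensor; they are sent together, e.g. via
# socket.send_multipart(frames, copy=False) as in v1/engine/core.py.
frames = encoder.encode(req)

decoded = decoder.decode(frames).mm[0]
assert torch.equal(decoded["large"], req.mm[0]["large"])
assert torch.equal(decoded["small"], req.mm[0]["small"])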