Skip to content

Commit 578aab8

Browse files
committed
Add threshold env var, re-do field serialization, cleanup
addresses review comments Signed-off-by: Staszek Pasko <[email protected]>
1 parent 3461ce6 commit 578aab8

File tree

3 files changed

+32
-28
lines changed

3 files changed

+32
-28
lines changed

vllm/envs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
108108
VLLM_USE_DEEP_GEMM: bool = False
109109
VLLM_XGRAMMAR_CACHE_MB: int = 0
110+
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
110111

111112

112113
def get_default_cache_root():
@@ -704,6 +705,16 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
704705
# It can be changed with this variable if needed for some reason.
705706
"VLLM_XGRAMMAR_CACHE_MB":
706707
lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),
708+
709+
# Control the threshold for msgspec to use 'zero copy' for
710+
# serialization/deserialization of tensors. Tensors below
711+
# this limit will be encoded into the msgpack buffer, and
712+
# tensors above will instead be sent via a separate message.
713+
# While the sending side still actually copies the tensor
714+
# in all cases, on the receiving side, tensors above this
715+
# limit will actually be zero-copy decoded.
716+
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
717+
lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
707718
}
708719

709720
# end-env-vars-definition

vllm/multimodal/inputs.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -282,15 +282,6 @@ def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
282282

283283
return self._reduce_data([item.data for item in elems])
284284

285-
@abstractmethod
286-
def field_type(self) -> tuple[Any, ...]:
287-
"""
288-
Return the type of this field instance and constructor args.
289-
290-
Required for serialization.
291-
"""
292-
raise NotImplementedError
293-
294285

295286
@dataclass(frozen=True)
296287
class MultiModalBatchedField(BaseMultiModalField):
@@ -321,9 +312,6 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
321312

322313
return batch
323314

324-
def field_type(self) -> tuple[Any, ...]:
325-
return ("batched", )
326-
327315

328316
@dataclass(frozen=True)
329317
class MultiModalFlatField(BaseMultiModalField):
@@ -356,9 +344,6 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
356344

357345
return [e for elem in batch for e in elem]
358346

359-
def field_type(self) -> tuple[Any, ...]:
360-
return ("flat", self.slices)
361-
362347

363348
@dataclass(frozen=True)
364349
class MultiModalSharedField(BaseMultiModalField):
@@ -380,9 +365,6 @@ def build_elems(
380365
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
381366
return batch[0]
382367

383-
def field_type(self) -> tuple[Any, ...]:
384-
return ("shared", self.batch_size)
385-
386368

387369
class MultiModalFieldConfig:
388370

vllm/v1/serial_utils.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
import zmq
1515
from msgspec import msgpack
1616

17-
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalFieldElem,
18-
MultiModalKwargs, MultiModalKwargsItem,
19-
NestedTensors)
17+
from vllm import envs
18+
from vllm.multimodal.inputs import (BaseMultiModalField, MultiModalFieldConfig,
19+
MultiModalFieldElem, MultiModalKwargs,
20+
MultiModalKwargsItem, NestedTensors)
2021

2122
CUSTOM_TYPE_PICKLE = 1
2223
CUSTOM_TYPE_CLOUDPICKLE = 2
@@ -39,16 +40,21 @@ class MsgpackEncoder:
3940
See: https://github.com/vllm-project/vllm/issues/16185
4041
"""
4142

42-
def __init__(self, size_threshold=256):
43+
def __init__(self, size_threshold=None):
44+
if (size_threshold is None):
45+
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
4346
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
4447
# This is used as a local stash of buffers that we can then access from
4548
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
4649
# pass custom data to the hook otherwise.
50+
self.msg_buffer = bytearray()
4751
self.aux_buffers: Optional[list[bytestr]] = None
4852
self.size_threshold = size_threshold
4953

54+
# TODO - merge these constructors and remove the need for externally managed
55+
# serialization buffers.
5056
def encode(self, obj: Any) -> Sequence[bytestr]:
51-
return self.encode_into(obj, bytearray())
57+
return self.encode_into(obj, self.msg_buffer)
5258

5359
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
5460
try:
@@ -85,9 +91,8 @@ def enc_hook(self, obj: Any) -> Any:
8591
ret = []
8692
for elem in obj.values():
8793
# Encode as plain dictionary + special handling for .field
88-
d = asdict(elem)
89-
d["field"] = elem.field.field_type()
90-
ret.append(d)
94+
ret.append(
95+
asdict(elem) | {"field": self._encode_field(elem.field)})
9196
return ret
9297

9398
if isinstance(obj, FunctionType):
@@ -106,8 +111,7 @@ def _encode_ndarray(
106111
# Sending memoryviews is theoretically faster, but in this particular
107112
# case, it triggers some unnecessary copies anyway.
108113
# With this, the tensors can still be zero-copy read.
109-
arr_data = obj.data.tobytes() if obj.data.c_contiguous \
110-
else obj.tobytes()
114+
arr_data = obj.tobytes()
111115
if not obj.shape or obj.nbytes < self.size_threshold:
112116
# Encode small arrays and scalars inline. Using this extension type
113117
# ensures we can avoid copying when decoding.
@@ -122,6 +126,13 @@ def _encode_ndarray(
122126
# backing buffers that we've stashed in `aux_buffers`.
123127
return obj.dtype.str, obj.shape, data
124128

129+
def _encode_field(self, field: BaseMultiModalField):
130+
# Encode the field as a dictionary + special handling for .field
131+
d = asdict(field)
132+
# Strip first 10 characters and last 5 characters from the class name
133+
# to get the field type name that matches the factory function name.
134+
return (field.__class__.__name__[10:-5].lower(), *d.values())
135+
125136

126137
class MsgpackDecoder:
127138
"""Decoder with custom torch tensor and numpy array serialization.

0 commit comments

Comments
 (0)