 import zmq
 from msgspec import msgpack
 
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalFieldElem,
-                                    MultiModalKwargs, MultiModalKwargsItem,
-                                    NestedTensors)
+from vllm import envs
+from vllm.multimodal.inputs import (BaseMultiModalField, MultiModalFieldConfig,
+                                    MultiModalFieldElem, MultiModalKwargs,
+                                    MultiModalKwargsItem, NestedTensors)
 
 CUSTOM_TYPE_PICKLE = 1
 CUSTOM_TYPE_CLOUDPICKLE = 2
@@ -39,16 +40,21 @@ class MsgpackEncoder:
     See: https://github.com/vllm-project/vllm/issues/16185
     """
 
-    def __init__(self, size_threshold=256):
+    def __init__(self, size_threshold=None):
+        if (size_threshold is None):
+            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
         self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
         # This is used as a local stash of buffers that we can then access from
         # our custom `msgspec` hook, `enc_hook`. We don't have a way to
         # pass custom data to the hook otherwise.
+        self.msg_buffer = bytearray()
         self.aux_buffers: Optional[list[bytestr]] = None
         self.size_threshold = size_threshold
 
+    # TODO - merge these constructors and remove the need for externally managed
+    # serialization buffers.
     def encode(self, obj: Any) -> Sequence[bytestr]:
-        return self.encode_into(obj, bytearray())
+        return self.encode_into(obj, self.msg_buffer)
 
     def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
         try:
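Since encode() now reuses a single msg_buffer rather than allocating a fresh bytearray per call, the frames it returns appear to alias memory that the next encode() call will overwrite, so callers need to send or copy them before re-encoding. A minimal sketch of that hazard using msgspec directly, with no vLLM imports; the buffer name here only mirrors msg_buffer:

from msgspec import msgpack

encoder = msgpack.Encoder()
msg_buffer = bytearray()  # reused across calls, like MsgpackEncoder.msg_buffer

encoder.encode_into({"step": 1}, msg_buffer)
first = bytes(msg_buffer)          # copy (or send) the frame before re-encoding
encoder.encode_into({"step": 2}, msg_buffer)
assert bytes(msg_buffer) != first  # the shared buffer was overwritten in place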
@@ -85,9 +91,8 @@ def enc_hook(self, obj: Any) -> Any:
             ret = []
             for elem in obj.values():
                 # Encode as plain dictionary + special handling for .field
-                d = asdict(elem)
-                d["field"] = elem.field.field_type()
-                ret.append(d)
+                ret.append(
+                    asdict(elem) | {"field": self._encode_field(elem.field)})
             return ret
 
         if isinstance(obj, FunctionType):
@@ -106,8 +111,7 @@ def _encode_ndarray(
         # Sending memoryviews is theoretically faster, but in this particular
         # case, it triggers some unnecessary copies anyway.
         # With this, the tensors can still be zero-copy read.
-        arr_data = obj.data.tobytes() if obj.data.c_contiguous \
-            else obj.tobytes()
+        arr_data = obj.tobytes()
         if not obj.shape or obj.nbytes < self.size_threshold:
             # Encode small arrays and scalars inline. Using this extension type
             # ensures we can avoid copying when decoding.
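The contiguity branch can be dropped because numpy's ndarray.tobytes() already covers both cases: it returns the array's bytes in C order, copying when the underlying buffer is not contiguous. A small self-contained check (the example array is illustrative only):

import numpy as np

arr = np.arange(12).reshape(3, 4)[:, ::2]   # a non-contiguous view
assert not arr.data.c_contiguous
data = arr.tobytes()                        # copies into C order regardless
assert data == np.ascontiguousarray(arr).tobytes()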
@@ -122,6 +126,13 @@ def _encode_ndarray(
         # backing buffers that we've stashed in `aux_buffers`.
         return obj.dtype.str, obj.shape, data
 
+    def _encode_field(self, field: BaseMultiModalField):
+        # Encode the field as a dictionary + special handling for .field
+        d = asdict(field)
+        # Strip first 10 characters and last 5 characters from the class name
+        # to get the field type name that matches the factory function name.
+        return (field.__class__.__name__[10:-5].lower(), *d.values())
+
 
 class MsgpackDecoder:
     """Decoder with custom torch tensor and numpy array serialization.