1
1
# SPDX-License-Identifier: Apache-2.0
2
2
3
+ import dataclasses
3
4
import pickle
4
5
from collections .abc import Sequence
5
- from dataclasses import asdict
6
6
from inspect import isclass
7
7
from itertools import chain
8
8
from types import FunctionType
15
15
from msgspec import msgpack
16
16
17
17
from vllm import envs
18
- from vllm .multimodal .inputs import (BaseMultiModalField , MultiModalFieldConfig ,
19
- MultiModalFieldElem , MultiModalKwargs ,
20
- MultiModalKwargsItem , NestedTensors )
18
+ from vllm .multimodal .inputs import (BaseMultiModalField ,
19
+ MultiModalBatchedField ,
20
+ MultiModalFieldConfig , MultiModalFieldElem ,
21
+ MultiModalFlatField , MultiModalKwargs ,
22
+ MultiModalKwargsItem ,
23
+ MultiModalSharedField , NestedTensors )
21
24
22
25
CUSTOM_TYPE_PICKLE = 1
23
26
CUSTOM_TYPE_CLOUDPICKLE = 2
24
27
CUSTOM_TYPE_RAW_VIEW = 3
25
28
29
+ # MultiModalField class serialization type map.
30
+ # These need to list all possible field types and match them
31
+ # to factory methods in `MultiModalFieldConfig`.
32
+ MMF_CLASS_TO_FACTORY = {
33
+ MultiModalFlatField : "flat" ,
34
+ MultiModalSharedField : "shared" ,
35
+ MultiModalBatchedField : "batched" ,
36
+ }
37
+
26
38
bytestr = Union [bytes , bytearray , memoryview , zmq .Frame ]
27
39
28
40
@@ -51,20 +63,15 @@ def __init__(self, size_threshold=None):
51
63
self .aux_buffers : Optional [list [bytestr ]] = None
52
64
self .size_threshold = size_threshold
53
65
54
- # TODO - merge these constructors and remove the need for externally managed
55
- # serialization buffers.
56
66
def encode (self , obj : Any ) -> Sequence [bytestr ]:
57
- return self .encode_into (obj , self .msg_buffer )
58
-
59
- def encode_into (self , obj : Any , buf : bytearray ) -> Sequence [bytestr ]:
60
67
try :
61
68
# This `bufs` list allows us to collect direct pointers to backing
62
69
# buffers of tensors and np arrays, and return them along with the
63
70
# top-level encoded buffer instead of copying their data into the
64
71
# new buffer.
65
- self .aux_buffers = [buf ]
72
+ self .aux_buffers = [self . msg_buffer ]
66
73
bufs = self .aux_buffers
67
- self .encoder .encode_into (obj , buf )
74
+ self .encoder .encode_into (obj , self . msg_buffer )
68
75
return bufs
69
76
finally :
70
77
self .aux_buffers = None
@@ -111,11 +118,8 @@ def _encode_ndarray(
111
118
self , obj : np .ndarray
112
119
) -> tuple [str , tuple [int , ...], Union [int , memoryview ]]:
113
120
assert self .aux_buffers is not None
114
- # Either copy the memoryview directly or flatten the array to bytes.
115
- # Sending memoryviews is theoretically faster, but in this particular
116
- # case, it triggers some unnecessary copies anyway.
117
- # With this, the tensors can still be zero-copy read.
118
- arr_data = obj .tobytes ()
121
+ # If the array is non-contiguous, we need to copy it first
122
+ arr_data = obj .data if obj .data .c_contiguous else obj .tobytes ()
119
123
if not obj .shape or obj .nbytes < self .size_threshold :
120
124
# Encode small arrays and scalars inline. Using this extension type
121
125
# ensures we can avoid copying when decoding.
@@ -136,11 +140,15 @@ def _encode_nested_tensors(self, obj: Any) -> NestedTensors:
136
140
return [self ._encode_nested_tensors (x ) for x in obj ]
137
141
138
142
def _encode_field (self , field : BaseMultiModalField ):
139
- # Encode the field as a dictionary + special handling for .field
140
- d = asdict (field )
141
- # Strip first 10 characters and last 5 characters from the class name
142
- # to get the field type name that matches the factory function name.
143
- return (field .__class__ .__name__ [10 :- 5 ].lower (), * d .values ())
143
+ # Figure out the factory name for the field type.
144
+ name = MMF_CLASS_TO_FACTORY .get (field .__class__ )
145
+ if not name :
146
+ raise TypeError (f"Unsupported field type: { field .__class__ } " )
147
+ # We just need to copy all of the field values in order
148
+ # which will then be used to reconstruct the field.
149
+ field_values = (getattr (field , f .name )
150
+ for f in dataclasses .fields (field ))
151
+ return (name , * field_values )
144
152
145
153
146
154
class MsgpackDecoder :
0 commit comments