Skip to content

Commit 7f08435

Browse files
committed
Serialize tensors using int8 views
Allows to support arbitrary types like bfloat16 Signed-off-by: Staszek Pasko <[email protected]>
1 parent 3d3ab36 commit 7f08435

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

tests/v1/test_serial_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,15 @@ def test_multimodal_kwargs():
114114

115115
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
116116

117-
# expected total encoding length, should be 44536, +-20 for minor changes
118-
assert total_len >= 44516 and total_len <= 44556
117+
# expected total encoding length, should be 44559, +-20 for minor changes
118+
assert total_len >= 44539 and total_len <= 44579
119119
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
120120
assert all(nested_equal(d[k], decoded[k]) for k in d)
121121

122122

123123
def test_multimodal_items_by_modality():
124124
e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
125-
dtype=torch.int16),
125+
dtype=torch.bfloat16),
126126
MultiModalBatchedField())
127127
e2 = MultiModalFieldElem(
128128
"video",

vllm/v1/serial_utils.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
8080

8181
def enc_hook(self, obj: Any) -> Any:
8282
if isinstance(obj, torch.Tensor):
83-
return self._encode_ndarray(obj.numpy())
83+
return self._encode_tensor(obj)
8484

8585
# Fall back to pickle for object or void kind ndarrays.
8686
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
@@ -133,9 +133,26 @@ def _encode_ndarray(
133133
# backing buffers that we've stashed in `aux_buffers`.
134134
return obj.dtype.str, obj.shape, data
135135

136+
def _encode_tensor(
137+
self, obj: torch.Tensor
138+
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
139+
assert self.aux_buffers is not None
140+
# this creates a copy of the tensor
141+
obj = obj.contiguous() if not obj.is_contiguous() else obj
142+
# view the tensor as a 1D array of bytes
143+
arr = obj.view([obj.numel()]).view(torch.uint8).numpy()
144+
if obj.nbytes < self.size_threshold:
145+
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
146+
else:
147+
# Otherwise encode index of backing buffer to avoid copy.
148+
data = len(self.aux_buffers)
149+
self.aux_buffers.append(arr.data)
150+
dt = str(obj.dtype)[6:] # remove 'torch.' prefix
151+
return dt, obj.shape, data
152+
136153
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
137154
if isinstance(nt, torch.Tensor):
138-
return self._encode_ndarray(nt.numpy())
155+
return self._encode_tensor(nt)
139156
if isinstance(nt, (int, float)):
140157
# Although it violates NestedTensors type, MultiModalKwargs
141158
# values are sometimes floats.
@@ -186,7 +203,7 @@ def dec_hook(self, t: type, obj: Any) -> Any:
186203
if issubclass(t, np.ndarray):
187204
return self._decode_ndarray(obj)
188205
if issubclass(t, torch.Tensor):
189-
return torch.from_numpy(self._decode_ndarray(obj))
206+
return self._decode_tensor(obj)
190207
if issubclass(t, MultiModalKwargs):
191208
if isinstance(obj, list):
192209
return MultiModalKwargs.from_items(
@@ -205,6 +222,15 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
205222
else bytearray(data)
206223
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
207224

225+
def _decode_tensor(self, arr: Any) -> torch.Tensor:
226+
dtype, shape, data = arr
227+
# Copy from inline representation, otherwise Torch is unhappy since
228+
# the returned memory is non-writeable.
229+
buffer = self.aux_buffers[data] if isinstance(data, int) \
230+
else bytearray(data)
231+
arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=[len(buffer)])
232+
return torch.from_numpy(arr).view(getattr(torch, dtype)).view(shape)
233+
208234
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
209235
decoded_items = []
210236
for item in obj:
@@ -228,7 +254,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
228254
if not isinstance(obj, list):
229255
raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
230256
if obj and isinstance(obj[0], str):
231-
return torch.from_numpy(self._decode_ndarray(obj))
257+
return self._decode_tensor(obj)
232258
return [self._decode_nested_tensors(x) for x in obj]
233259

234260
def ext_hook(self, code: int, data: memoryview) -> Any:

0 commit comments

Comments
 (0)