From 7f08435d72e1679e6d0c0f9238303372441845bd Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Fri, 18 Apr 2025 21:22:43 +0200 Subject: [PATCH] Serialize tensors using int8 views Allows to support arbitrary types like bfloat16 Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 6 +++--- vllm/v1/serial_utils.py | 34 ++++++++++++++++++++++++++++++---- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index e58d3c403c19..b2fe2a73945f 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -114,15 +114,15 @@ def test_multimodal_kwargs(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) - # expected total encoding length, should be 44536, +-20 for minor changes - assert total_len >= 44516 and total_len <= 44556 + # expected total encoding length, should be 44559, +-20 for minor changes + assert total_len >= 44539 and total_len <= 44579 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] assert all(nested_equal(d[k], decoded[k]) for k in d) def test_multimodal_items_by_modality(): e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000, - dtype=torch.int16), + dtype=torch.bfloat16), MultiModalBatchedField()) e2 = MultiModalFieldElem( "video", diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 4f7987ee46a6..4926b0e073e3 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -80,7 +80,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: def enc_hook(self, obj: Any) -> Any: if isinstance(obj, torch.Tensor): - return self._encode_ndarray(obj.numpy()) + return self._encode_tensor(obj) # Fall back to pickle for object or void kind ndarrays. if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): @@ -133,9 +133,26 @@ def _encode_ndarray( # backing buffers that we've stashed in `aux_buffers`. return obj.dtype.str, obj.shape, data + def _encode_tensor( + self, obj: torch.Tensor + ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: + assert self.aux_buffers is not None + # this creates a copy of the tensor + obj = obj.contiguous() if not obj.is_contiguous() else obj + # view the tensor as a 1D array of bytes + arr = obj.view([obj.numel()]).view(torch.uint8).numpy() + if obj.nbytes < self.size_threshold: + data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) + else: + # Otherwise encode index of backing buffer to avoid copy. + data = len(self.aux_buffers) + self.aux_buffers.append(arr.data) + dt = str(obj.dtype)[6:] # remove 'torch.' prefix + return dt, obj.shape, data + def _encode_nested_tensors(self, nt: NestedTensors) -> Any: if isinstance(nt, torch.Tensor): - return self._encode_ndarray(nt.numpy()) + return self._encode_tensor(nt) if isinstance(nt, (int, float)): # Although it violates NestedTensors type, MultiModalKwargs # values are sometimes floats. @@ -186,7 +203,7 @@ def dec_hook(self, t: type, obj: Any) -> Any: if issubclass(t, np.ndarray): return self._decode_ndarray(obj) if issubclass(t, torch.Tensor): - return torch.from_numpy(self._decode_ndarray(obj)) + return self._decode_tensor(obj) if issubclass(t, MultiModalKwargs): if isinstance(obj, list): return MultiModalKwargs.from_items( @@ -205,6 +222,15 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray: else bytearray(data) return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) + def _decode_tensor(self, arr: Any) -> torch.Tensor: + dtype, shape, data = arr + # Copy from inline representation, otherwise Torch is unhappy since + # the returned memory is non-writeable. + buffer = self.aux_buffers[data] if isinstance(data, int) \ + else bytearray(data) + arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=[len(buffer)]) + return torch.from_numpy(arr).view(getattr(torch, dtype)).view(shape) + def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: decoded_items = [] for item in obj: @@ -228,7 +254,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: if not isinstance(obj, list): raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}") if obj and isinstance(obj[0], str): - return torch.from_numpy(self._decode_ndarray(obj)) + return self._decode_tensor(obj) return [self._decode_nested_tensors(x) for x in obj] def ext_hook(self, code: int, data: memoryview) -> Any: