diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index e58d3c403c1..df9832fc4e4 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -47,6 +47,10 @@ def test_encode_decode(): torch.rand((1, 10), dtype=torch.float32), torch.rand((3, 5, 4000), dtype=torch.float64), torch.tensor(1984), # test scalar too + # Make sure to test bf16 which numpy doesn't support. + torch.rand((3, 5, 1000), dtype=torch.bfloat16), + torch.tensor([float("-inf"), float("inf")] * 1024, + dtype=torch.bfloat16), ], numpy_array=np.arange(512), unrecognized=UnrecognizedType(33), @@ -64,7 +68,7 @@ def test_encode_decode(): # There should be the main buffer + 4 large tensor buffers # + 1 large numpy array. "large" is <= 512 bytes. # The two small tensors are encoded inline. - assert len(encoded) == 6 + assert len(encoded) == 8 decoded: MyType = decoder.decode(encoded) @@ -76,7 +80,7 @@ def test_encode_decode(): encoded2 = encoder.encode_into(obj, preallocated) - assert len(encoded2) == 6 + assert len(encoded2) == 8 assert encoded2[0] is preallocated decoded2: MyType = decoder.decode(encoded2) @@ -114,15 +118,15 @@ def test_multimodal_kwargs(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) - # expected total encoding length, should be 44536, +-20 for minor changes - assert total_len >= 44516 and total_len <= 44556 + # expected total encoding length, should be 44559, +-20 for minor changes + assert total_len >= 44539 and total_len <= 44579 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] assert all(nested_equal(d[k], decoded[k]) for k in d) def test_multimodal_items_by_modality(): - e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000, - dtype=torch.int16), + e1 = MultiModalFieldElem("audio", "a0", + torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField()) e2 = MultiModalFieldElem( "video", diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 4f7987ee46a..a3ad8cb9209 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -80,7 +80,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: def enc_hook(self, obj: Any) -> Any: if isinstance(obj, torch.Tensor): - return self._encode_ndarray(obj.numpy()) + return self._encode_tensor(obj) # Fall back to pickle for object or void kind ndarrays. if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): @@ -133,9 +133,27 @@ def _encode_ndarray( # backing buffers that we've stashed in `aux_buffers`. return obj.dtype.str, obj.shape, data + def _encode_tensor( + self, obj: torch.Tensor + ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: + assert self.aux_buffers is not None + # this creates a copy of the tensor if it's not already contiguous + obj = obj.contiguous() + # view the tensor as a 1D array of bytes + arr = obj.view((obj.numel(), )).view(torch.uint8).numpy() + if obj.nbytes < self.size_threshold: + # Smaller tensors are encoded inline, just like ndarrays. + data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) + else: + # Otherwise encode index of backing buffer to avoid copy. + data = len(self.aux_buffers) + self.aux_buffers.append(arr.data) + dtype = str(obj.dtype)[6:] # remove 'torch.' prefix + return dtype, obj.shape, data + def _encode_nested_tensors(self, nt: NestedTensors) -> Any: if isinstance(nt, torch.Tensor): - return self._encode_ndarray(nt.numpy()) + return self._encode_tensor(nt) if isinstance(nt, (int, float)): # Although it violates NestedTensors type, MultiModalKwargs # values are sometimes floats. @@ -186,7 +204,7 @@ def dec_hook(self, t: type, obj: Any) -> Any: if issubclass(t, np.ndarray): return self._decode_ndarray(obj) if issubclass(t, torch.Tensor): - return torch.from_numpy(self._decode_ndarray(obj)) + return self._decode_tensor(obj) if issubclass(t, MultiModalKwargs): if isinstance(obj, list): return MultiModalKwargs.from_items( @@ -199,11 +217,24 @@ def dec_hook(self, t: type, obj: Any) -> Any: def _decode_ndarray(self, arr: Any) -> np.ndarray: dtype, shape, data = arr - # Copy from inline representation, otherwise Torch is unhappy since - # the returned memory is non-writeable. + # zero-copy decode. We assume the ndarray will not be kept around, + # as it now locks the whole received message buffer in memory. + buffer = self.aux_buffers[data] if isinstance(data, int) else data + return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) + + def _decode_tensor(self, arr: Any) -> torch.Tensor: + dtype, shape, data = arr + # Copy from inline representation, to decouple the memory storage + # of the message from the original buffer. And also make Torch + # not complain about a readonly memoryview. buffer = self.aux_buffers[data] if isinstance(data, int) \ else bytearray(data) - return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) + # Create numpy wrapper around the bytes + arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), )) + torch_dtype = getattr(torch, dtype) + assert isinstance(torch_dtype, torch.dtype) + # Convert back to proper shape & type + return torch.from_numpy(arr).view(torch_dtype).view(shape) def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: decoded_items = [] @@ -228,7 +259,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: if not isinstance(obj, list): raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}") if obj and isinstance(obj[0], str): - return torch.from_numpy(self._decode_ndarray(obj)) + return self._decode_tensor(obj) return [self._decode_nested_tensors(x) for x in obj] def ext_hook(self, code: int, data: memoryview) -> Any: