From 7b34fe17b0847af229b0e39c768bf8480b9725d0 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 18 Apr 2025 09:56:42 -0700
Subject: [PATCH] [BugFix] Support bf16 in zero-copy tensor serialization

Numpy doesn't support bfloat16, so we convert to/from an fp16 view.

Signed-off-by: Nick Hill
---
 tests/v1/test_serial_utils.py |  8 ++++++--
 vllm/v1/serial_utils.py       | 31 +++++++++++++++++++++++++++----
 2 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py
index e58d3c403c19..5ad2ced057e3 100644
--- a/tests/v1/test_serial_utils.py
+++ b/tests/v1/test_serial_utils.py
@@ -46,6 +46,10 @@ def test_encode_decode():
         list_of_tensors=[
             torch.rand((1, 10), dtype=torch.float32),
             torch.rand((3, 5, 4000), dtype=torch.float64),
+            # Make sure to test bf16, which numpy doesn't support.
+            torch.rand((3, 5, 1000), dtype=torch.bfloat16),
+            torch.tensor([float("-inf"), float("inf")] * 1024,
+                         dtype=torch.bfloat16),
             torch.tensor(1984),  # test scalar too
         ],
         numpy_array=np.arange(512),
@@ -64,7 +68,7 @@ def test_encode_decode():
     # There should be the main buffer + 4 large tensor buffers
     # + 1 large numpy array. "large" is <= 512 bytes.
     # The two small tensors are encoded inline.
-    assert len(encoded) == 6
+    assert len(encoded) == 8
 
     decoded: MyType = decoder.decode(encoded)
 
@@ -76,7 +80,7 @@ def test_encode_decode():
 
     encoded2 = encoder.encode_into(obj, preallocated)
 
-    assert len(encoded2) == 6
+    assert len(encoded2) == 8
     assert encoded2[0] is preallocated
 
     decoded2: MyType = decoder.decode(encoded2)
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index 4f7987ee46a6..1cee78ca431a 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -36,6 +36,11 @@
 
 bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
 
+NP_FP16_STR = torch.tensor(0, dtype=torch.float16,
+                           device="cpu").numpy().dtype.str
+# Special dtype string for bf16, which numpy doesn't support.
+ENC_BF16_STR = "!bf16"
+
 
 class MsgpackEncoder:
     """Encoder with custom torch tensor and numpy array serialization.
@@ -80,7 +85,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
 
     def enc_hook(self, obj: Any) -> Any:
         if isinstance(obj, torch.Tensor):
-            return self._encode_ndarray(obj.numpy())
+            return self._encode_tensor(obj)
 
         # Fall back to pickle for object or void kind ndarrays.
         if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
@@ -113,6 +118,15 @@ def enc_hook(self, obj: Any) -> Any:
         return msgpack.Ext(CUSTOM_TYPE_PICKLE,
                            pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
 
+    def _encode_tensor(
+        self, obj: torch.Tensor
+    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
+        if obj.dtype != torch.bfloat16:
+            return self._encode_ndarray(obj.numpy())
+        # Numpy doesn't support bf16, so send as an fp16 view.
+        _, shape, data = self._encode_ndarray(obj.view(torch.float16).numpy())
+        return ENC_BF16_STR, shape, data
+
     def _encode_ndarray(
         self, obj: np.ndarray
     ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
@@ -135,7 +149,7 @@ def _encode_ndarray(
 
     def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
         if isinstance(nt, torch.Tensor):
-            return self._encode_ndarray(nt.numpy())
+            return self._encode_tensor(nt)
         if isinstance(nt, (int, float)):
             # Although it violates NestedTensors type, MultiModalKwargs
             # values are sometimes floats.
@@ -186,7 +200,7 @@ def dec_hook(self, t: type, obj: Any) -> Any:
             if issubclass(t, np.ndarray):
                 return self._decode_ndarray(obj)
             if issubclass(t, torch.Tensor):
-                return torch.from_numpy(self._decode_ndarray(obj))
+                return self._decode_tensor(obj)
             if issubclass(t, MultiModalKwargs):
                 if isinstance(obj, list):
                     return MultiModalKwargs.from_items(
@@ -197,6 +211,15 @@ def dec_hook(self, t: type, obj: Any) -> Any:
                 })
         return obj
 
+    def _decode_tensor(self, arr: Any) -> torch.Tensor:
+        dtype, shape, data = arr
+        if dtype != ENC_BF16_STR:
+            return torch.from_numpy(self._decode_ndarray(arr))
+        # Numpy doesn't support bf16, so convert back from the fp16 view.
+        arr = NP_FP16_STR, shape, data
+        tensor = torch.from_numpy(self._decode_ndarray(arr))
+        return tensor.view(torch.bfloat16)
+
     def _decode_ndarray(self, arr: Any) -> np.ndarray:
         dtype, shape, data = arr
         # Copy from inline representation, otherwise Torch is unhappy since
@@ -228,7 +251,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
         if not isinstance(obj, list):
             raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
         if obj and isinstance(obj[0], str):
-            return torch.from_numpy(self._decode_ndarray(obj))
+            return self._decode_tensor(obj)
         return [self._decode_nested_tensors(x) for x in obj]
 
     def ext_hook(self, code: int, data: memoryview) -> Any:
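
For reference, the conversion this patch relies on can be reproduced in
isolation: because numpy has no bfloat16 dtype, a bf16 tensor is reinterpreted
as an fp16 view for transport, tagged with a sentinel dtype string, and
reinterpreted back to bf16 on the decode side. The sketch below is illustrative
only (it assumes recent torch and numpy); the encode/decode helpers and the
bytes-based transport are stand-ins and not part of the MsgpackEncoder /
MsgpackDecoder API in vllm/v1/serial_utils.py.

# Standalone sketch of the bf16 <-> fp16-view round trip (illustrative only).
import numpy as np
import torch

ENC_BF16_STR = "!bf16"  # sentinel dtype tag for bf16 payloads
NP_FP16_STR = torch.tensor(0, dtype=torch.float16).numpy().dtype.str


def encode(t: torch.Tensor) -> tuple[str, tuple[int, ...], bytes]:
    if t.dtype == torch.bfloat16:
        # numpy can't represent bf16, so reinterpret the 16-bit payload as fp16.
        arr = t.view(torch.float16).numpy()
        return ENC_BF16_STR, arr.shape, arr.tobytes()
    arr = t.numpy()
    return arr.dtype.str, arr.shape, arr.tobytes()


def decode(enc: tuple[str, tuple[int, ...], bytes]) -> torch.Tensor:
    dtype, shape, data = enc
    np_dtype = NP_FP16_STR if dtype == ENC_BF16_STR else dtype
    arr = np.frombuffer(data, dtype=np_dtype).reshape(shape).copy()
    t = torch.from_numpy(arr)
    # Reinterpret the fp16 bits back as bf16 when the sentinel tag is present.
    return t.view(torch.bfloat16) if dtype == ENC_BF16_STR else t


x = torch.tensor([float("-inf"), 1.5, float("inf")], dtype=torch.bfloat16)
assert torch.equal(decode(encode(x)), x)

Because both view() calls only reinterpret the raw 16-bit payload, special
values such as +/-inf survive the round trip bit-exactly, which is what the
added bfloat16 inf/-inf test case in test_encode_decode exercises.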