From 7f08435d72e1679e6d0c0f9238303372441845bd Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Fri, 18 Apr 2025 21:22:43 +0200 Subject: [PATCH 01/10] Serialize tensors using int8 views Allows to support arbitrary types like bfloat16 Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 6 +++--- vllm/v1/serial_utils.py | 34 ++++++++++++++++++++++++++++++---- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index e58d3c403c19..b2fe2a73945f 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -114,15 +114,15 @@ def test_multimodal_kwargs(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) - # expected total encoding length, should be 44536, +-20 for minor changes - assert total_len >= 44516 and total_len <= 44556 + # expected total encoding length, should be 44559, +-20 for minor changes + assert total_len >= 44539 and total_len <= 44579 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] assert all(nested_equal(d[k], decoded[k]) for k in d) def test_multimodal_items_by_modality(): e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000, - dtype=torch.int16), + dtype=torch.bfloat16), MultiModalBatchedField()) e2 = MultiModalFieldElem( "video", diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 4f7987ee46a6..4926b0e073e3 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -80,7 +80,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: def enc_hook(self, obj: Any) -> Any: if isinstance(obj, torch.Tensor): - return self._encode_ndarray(obj.numpy()) + return self._encode_tensor(obj) # Fall back to pickle for object or void kind ndarrays. if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): @@ -133,9 +133,26 @@ def _encode_ndarray( # backing buffers that we've stashed in `aux_buffers`. return obj.dtype.str, obj.shape, data + def _encode_tensor( + self, obj: torch.Tensor + ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: + assert self.aux_buffers is not None + # this creates a copy of the tensor + obj = obj.contiguous() if not obj.is_contiguous() else obj + # view the tensor as a 1D array of bytes + arr = obj.view([obj.numel()]).view(torch.uint8).numpy() + if obj.nbytes < self.size_threshold: + data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) + else: + # Otherwise encode index of backing buffer to avoid copy. + data = len(self.aux_buffers) + self.aux_buffers.append(arr.data) + dt = str(obj.dtype)[6:] # remove 'torch.' prefix + return dt, obj.shape, data + def _encode_nested_tensors(self, nt: NestedTensors) -> Any: if isinstance(nt, torch.Tensor): - return self._encode_ndarray(nt.numpy()) + return self._encode_tensor(nt) if isinstance(nt, (int, float)): # Although it violates NestedTensors type, MultiModalKwargs # values are sometimes floats. @@ -186,7 +203,7 @@ def dec_hook(self, t: type, obj: Any) -> Any: if issubclass(t, np.ndarray): return self._decode_ndarray(obj) if issubclass(t, torch.Tensor): - return torch.from_numpy(self._decode_ndarray(obj)) + return self._decode_tensor(obj) if issubclass(t, MultiModalKwargs): if isinstance(obj, list): return MultiModalKwargs.from_items( @@ -205,6 +222,15 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray: else bytearray(data) return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) + def _decode_tensor(self, arr: Any) -> torch.Tensor: + dtype, shape, data = arr + # Copy from inline representation, otherwise Torch is unhappy since + # the returned memory is non-writeable. + buffer = self.aux_buffers[data] if isinstance(data, int) \ + else bytearray(data) + arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=[len(buffer)]) + return torch.from_numpy(arr).view(getattr(torch, dtype)).view(shape) + def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: decoded_items = [] for item in obj: @@ -228,7 +254,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: if not isinstance(obj, list): raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}") if obj and isinstance(obj[0], str): - return torch.from_numpy(self._decode_ndarray(obj)) + return self._decode_tensor(obj) return [self._decode_nested_tensors(x) for x in obj] def ext_hook(self, code: int, data: memoryview) -> Any: From 49941eac8008d4d20392452e9a9d543d0b7ea3d5 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Fri, 18 Apr 2025 21:45:21 +0200 Subject: [PATCH 02/10] Formatting Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index b2fe2a73945f..6612ced48cf0 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -121,8 +121,8 @@ def test_multimodal_kwargs(): def test_multimodal_items_by_modality(): - e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000, - dtype=torch.bfloat16), + e1 = MultiModalFieldElem("audio", "a0", + torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField()) e2 = MultiModalFieldElem( "video", From 29daef4f20c7e8cb585f5f670358b7e5fb68a62b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Staszek=20Pa=C5=9Bko?= Date: Fri, 18 Apr 2025 22:08:00 +0200 Subject: [PATCH 03/10] Apply suggestions from code review Co-authored-by: Nick Hill Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 4926b0e073e3..5dbbf1bb481f 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -140,15 +140,15 @@ def _encode_tensor( # this creates a copy of the tensor obj = obj.contiguous() if not obj.is_contiguous() else obj # view the tensor as a 1D array of bytes - arr = obj.view([obj.numel()]).view(torch.uint8).numpy() + arr = obj.view((obj.numel(),)).view(torch.uint8).numpy() if obj.nbytes < self.size_threshold: data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) else: # Otherwise encode index of backing buffer to avoid copy. data = len(self.aux_buffers) self.aux_buffers.append(arr.data) - dt = str(obj.dtype)[6:] # remove 'torch.' prefix - return dt, obj.shape, data + dtype = str(obj.dtype)[6:] # remove 'torch.' prefix + return dtype, obj.shape, data def _encode_nested_tensors(self, nt: NestedTensors) -> Any: if isinstance(nt, torch.Tensor): @@ -228,8 +228,10 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor: # the returned memory is non-writeable. buffer = self.aux_buffers[data] if isinstance(data, int) \ else bytearray(data) - arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=[len(buffer)]) - return torch.from_numpy(arr).view(getattr(torch, dtype)).view(shape) + arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer),)) + torch_dtype = getattr(torch, dtype) + assert isinstance(torch_dtype, torch.dtype) + return torch.from_numpy(arr).view(torch_dtype).view(shape) def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: decoded_items = [] From b57879501ae0a07537b5eb34c3aa1213c4c26458 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Fri, 18 Apr 2025 22:11:01 +0200 Subject: [PATCH 04/10] add more bf16 tests from #16860 Signed-off-by: Staszek Pasko --- tests/v1/test_serial_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 6612ced48cf0..df9832fc4e48 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -47,6 +47,10 @@ def test_encode_decode(): torch.rand((1, 10), dtype=torch.float32), torch.rand((3, 5, 4000), dtype=torch.float64), torch.tensor(1984), # test scalar too + # Make sure to test bf16 which numpy doesn't support. + torch.rand((3, 5, 1000), dtype=torch.bfloat16), + torch.tensor([float("-inf"), float("inf")] * 1024, + dtype=torch.bfloat16), ], numpy_array=np.arange(512), unrecognized=UnrecognizedType(33), @@ -64,7 +68,7 @@ def test_encode_decode(): # There should be the main buffer + 4 large tensor buffers # + 1 large numpy array. "large" is <= 512 bytes. # The two small tensors are encoded inline. - assert len(encoded) == 6 + assert len(encoded) == 8 decoded: MyType = decoder.decode(encoded) @@ -76,7 +80,7 @@ def test_encode_decode(): encoded2 = encoder.encode_into(obj, preallocated) - assert len(encoded2) == 6 + assert len(encoded2) == 8 assert encoded2[0] is preallocated decoded2: MyType = decoder.decode(encoded2) From cd4a2d4558bb1db7d22b75ba854da777c44a64e4 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Fri, 18 Apr 2025 22:21:06 +0200 Subject: [PATCH 05/10] style Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 5dbbf1bb481f..ba49944dc6d3 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -140,7 +140,7 @@ def _encode_tensor( # this creates a copy of the tensor obj = obj.contiguous() if not obj.is_contiguous() else obj # view the tensor as a 1D array of bytes - arr = obj.view((obj.numel(),)).view(torch.uint8).numpy() + arr = obj.view((obj.numel(), )).view(torch.uint8).numpy() if obj.nbytes < self.size_threshold: data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) else: @@ -228,7 +228,7 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor: # the returned memory is non-writeable. buffer = self.aux_buffers[data] if isinstance(data, int) \ else bytearray(data) - arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer),)) + arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), )) torch_dtype = getattr(torch, dtype) assert isinstance(torch_dtype, torch.dtype) return torch.from_numpy(arr).view(torch_dtype).view(shape) From e45c80a4bcf2b3faab815f48223d819cdbaafdb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Staszek=20Pa=C5=9Bko?= Date: Fri, 18 Apr 2025 23:23:54 +0200 Subject: [PATCH 06/10] Update vllm/v1/serial_utils.py Co-authored-by: Nick Hill Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index ba49944dc6d3..6d7fcad09bff 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -138,7 +138,7 @@ def _encode_tensor( ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None # this creates a copy of the tensor - obj = obj.contiguous() if not obj.is_contiguous() else obj + obj = obj.contiguous() # view the tensor as a 1D array of bytes arr = obj.view((obj.numel(), )).view(torch.uint8).numpy() if obj.nbytes < self.size_threshold: From d5903c621c48c30ea4c969e5e0c02a27cc28dead Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Staszek=20Pa=C5=9Bko?= Date: Fri, 18 Apr 2025 23:49:17 +0200 Subject: [PATCH 07/10] Apply suggestions from code review Co-authored-by: Nick Hill Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 6d7fcad09bff..218def087ea3 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -137,7 +137,7 @@ def _encode_tensor( self, obj: torch.Tensor ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None - # this creates a copy of the tensor + # this creates a copy of the tensor if it's not already contiguous obj = obj.contiguous() # view the tensor as a 1D array of bytes arr = obj.view((obj.numel(), )).view(torch.uint8).numpy() From 8d65edf31f82bfa86a2273e8789427e071881eb8 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Fri, 18 Apr 2025 23:56:35 +0200 Subject: [PATCH 08/10] Comments Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 218def087ea3..8775d10bc9a9 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -142,6 +142,7 @@ def _encode_tensor( # view the tensor as a 1D array of bytes arr = obj.view((obj.numel(), )).view(torch.uint8).numpy() if obj.nbytes < self.size_threshold: + # Smaller tensors are encoded inline, just like ndarrays. data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) else: # Otherwise encode index of backing buffer to avoid copy. @@ -216,21 +217,25 @@ def dec_hook(self, t: type, obj: Any) -> Any: def _decode_ndarray(self, arr: Any) -> np.ndarray: dtype, shape, data = arr - # Copy from inline representation, otherwise Torch is unhappy since - # the returned memory is non-writeable. + # Copy from inline representation, to decouple the memory storage + # of the message from the original buffer. Not needed in the + # auxillary buffers case. buffer = self.aux_buffers[data] if isinstance(data, int) \ else bytearray(data) return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) def _decode_tensor(self, arr: Any) -> torch.Tensor: dtype, shape, data = arr - # Copy from inline representation, otherwise Torch is unhappy since - # the returned memory is non-writeable. + # Copy from inline representation, to decouple the memory storage + # of the message from the original buffer. And also make Torch + # not complain about a readonly memoryview. buffer = self.aux_buffers[data] if isinstance(data, int) \ else bytearray(data) + # Create numpy wrapper around the bytes arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), )) torch_dtype = getattr(torch, dtype) assert isinstance(torch_dtype, torch.dtype) + # Convert back to proper shape & type return torch.from_numpy(arr).view(torch_dtype).view(shape) def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: From f4e13284c0183c3fa95e7c7e5852e385327b13a8 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Sat, 19 Apr 2025 00:03:32 +0200 Subject: [PATCH 09/10] fix typo Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 8775d10bc9a9..ecf79653f999 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -219,7 +219,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray: dtype, shape, data = arr # Copy from inline representation, to decouple the memory storage # of the message from the original buffer. Not needed in the - # auxillary buffers case. + # auxiliary buffers case. buffer = self.aux_buffers[data] if isinstance(data, int) \ else bytearray(data) return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) From dcaa7cc3a2825b9600b936cfd0557983100efeb0 Mon Sep 17 00:00:00 2001 From: Staszek Pasko Date: Sat, 19 Apr 2025 08:18:07 +0200 Subject: [PATCH 10/10] Skip copy ndarray in the decode path It's now separate from tensors Signed-off-by: Staszek Pasko --- vllm/v1/serial_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index ecf79653f999..a3ad8cb92096 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -217,11 +217,9 @@ def dec_hook(self, t: type, obj: Any) -> Any: def _decode_ndarray(self, arr: Any) -> np.ndarray: dtype, shape, data = arr - # Copy from inline representation, to decouple the memory storage - # of the message from the original buffer. Not needed in the - # auxiliary buffers case. - buffer = self.aux_buffers[data] if isinstance(data, int) \ - else bytearray(data) + # zero-copy decode. We assume the ndarray will not be kept around, + # as it now locks the whole received message buffer in memory. + buffer = self.aux_buffers[data] if isinstance(data, int) else data return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) def _decode_tensor(self, arr: Any) -> torch.Tensor: