Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions tests/v1/test_serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def test_encode_decode():
torch.rand((1, 10), dtype=torch.float32),
torch.rand((3, 5, 4000), dtype=torch.float64),
torch.tensor(1984), # test scalar too
# Make sure to test bf16 which numpy doesn't support.
torch.rand((3, 5, 1000), dtype=torch.bfloat16),
torch.tensor([float("-inf"), float("inf")] * 1024,
dtype=torch.bfloat16),
],
numpy_array=np.arange(512),
unrecognized=UnrecognizedType(33),
Expand All @@ -64,7 +68,7 @@ def test_encode_decode():
# There should be the main buffer + 6 large tensor buffers
# + 1 large numpy array. "large" is <= 512 bytes.
# The two small tensors are encoded inline.
assert len(encoded) == 6
assert len(encoded) == 8

decoded: MyType = decoder.decode(encoded)

Expand All @@ -76,7 +80,7 @@ def test_encode_decode():

encoded2 = encoder.encode_into(obj, preallocated)

assert len(encoded2) == 6
assert len(encoded2) == 8
assert encoded2[0] is preallocated

decoded2: MyType = decoder.decode(encoded2)
Expand Down Expand Up @@ -114,15 +118,15 @@ def test_multimodal_kwargs():

total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)

# expected total encoding length, should be 44536, +-20 for minor changes
assert total_len >= 44516 and total_len <= 44556
# expected total encoding length, should be 44559, +-20 for minor changes
assert total_len >= 44539 and total_len <= 44579
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
assert all(nested_equal(d[k], decoded[k]) for k in d)


def test_multimodal_items_by_modality():
e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
dtype=torch.int16),
e1 = MultiModalFieldElem("audio", "a0",
torch.zeros(1000, dtype=torch.bfloat16),
MultiModalBatchedField())
e2 = MultiModalFieldElem(
"video",
Expand Down
45 changes: 38 additions & 7 deletions vllm/v1/serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:

def enc_hook(self, obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
return self._encode_ndarray(obj.numpy())
return self._encode_tensor(obj)

# Fall back to pickle for object or void kind ndarrays.
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
Expand Down Expand Up @@ -133,9 +133,27 @@ def _encode_ndarray(
# backing buffers that we've stashed in `aux_buffers`.
return obj.dtype.str, obj.shape, data

def _encode_tensor(
self, obj: torch.Tensor
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None
# this creates a copy of the tensor if it's not already contiguous
obj = obj.contiguous()
# view the tensor as a 1D array of bytes
arr = obj.view((obj.numel(), )).view(torch.uint8).numpy()
if obj.nbytes < self.size_threshold:
# Smaller tensors are encoded inline, just like ndarrays.
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
else:
# Otherwise encode index of backing buffer to avoid copy.
data = len(self.aux_buffers)
self.aux_buffers.append(arr.data)
dtype = str(obj.dtype)[6:] # remove 'torch.' prefix
return dtype, obj.shape, data

def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
if isinstance(nt, torch.Tensor):
return self._encode_ndarray(nt.numpy())
return self._encode_tensor(nt)
if isinstance(nt, (int, float)):
# Although it violates NestedTensors type, MultiModalKwargs
# values are sometimes floats.
Expand Down Expand Up @@ -186,7 +204,7 @@ def dec_hook(self, t: type, obj: Any) -> Any:
if issubclass(t, np.ndarray):
return self._decode_ndarray(obj)
if issubclass(t, torch.Tensor):
return torch.from_numpy(self._decode_ndarray(obj))
return self._decode_tensor(obj)
if issubclass(t, MultiModalKwargs):
if isinstance(obj, list):
return MultiModalKwargs.from_items(
Expand All @@ -199,11 +217,24 @@ def dec_hook(self, t: type, obj: Any) -> Any:

def _decode_ndarray(self, arr: Any) -> np.ndarray:
# Rebuild an ndarray from its encoded (dtype, shape, data) triple.
dtype, shape, data = arr
# An int `data` is an index into self.aux_buffers; anything else is used
# directly as the backing buffer.
# Zero-copy decode: nothing is copied here, so the returned ndarray is
# backed by that buffer. We assume the ndarray will not be kept around,
# as it now locks the whole received message buffer in memory.
buffer = self.aux_buffers[data] if isinstance(data, int) else data
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)

def _decode_tensor(self, arr: Any) -> torch.Tensor:
    """Rebuild a tensor from its encoded ``(dtype, shape, data)`` triple.

    Args:
        arr: A ``(dtype, shape, data)`` triple. ``dtype`` is the name of a
            ``torch`` dtype attribute (e.g. ``"bfloat16"``); ``data`` is
            either an int index into ``self.aux_buffers`` or the inline
            bytes of the tensor.

    Returns:
        A tensor of the given dtype and shape backed by the decoded bytes.
    """
    dtype, shape, data = arr
    if isinstance(data, int):
        buffer = self.aux_buffers[data]
    else:
        # Copy from inline representation, to decouple the memory storage
        # of the message from the original buffer. And also make Torch
        # not complain about a readonly memoryview.
        buffer = bytearray(data)
    torch_dtype = getattr(torch, dtype)
    assert isinstance(torch_dtype, torch.dtype)
    # Wrap the raw bytes as uint8 first: numpy cannot represent every torch
    # dtype, so the reinterpretation to the real dtype is done by torch.
    raw = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), ))
    # Convert back to proper shape & type
    return torch.from_numpy(raw).view(torch_dtype).view(shape)

def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
decoded_items = []
Expand All @@ -228,7 +259,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
if not isinstance(obj, list):
raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
if obj and isinstance(obj[0], str):
return torch.from_numpy(self._decode_ndarray(obj))
return self._decode_tensor(obj)
return [self._decode_nested_tensors(x) for x in obj]

def ext_hook(self, code: int, data: memoryview) -> Any:
Expand Down