Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/v1/test_serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,15 @@ def test_multimodal_kwargs():

total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)

# expected total encoding length, should be 44536, +-20 for minor changes
assert total_len >= 44516 and total_len <= 44556
# expected total encoding length, should be 44559, +-20 for minor changes
assert total_len >= 44539 and total_len <= 44579
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
assert all(nested_equal(d[k], decoded[k]) for k in d)


def test_multimodal_items_by_modality():
e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
dtype=torch.int16),
dtype=torch.bfloat16),
MultiModalBatchedField())
e2 = MultiModalFieldElem(
"video",
Expand Down
34 changes: 30 additions & 4 deletions vllm/v1/serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:

def enc_hook(self, obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
return self._encode_ndarray(obj.numpy())
return self._encode_tensor(obj)

# Fall back to pickle for object or void kind ndarrays.
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
Expand Down Expand Up @@ -133,9 +133,26 @@ def _encode_ndarray(
# backing buffers that we've stashed in `aux_buffers`.
return obj.dtype.str, obj.shape, data

def _encode_tensor(
self, obj: torch.Tensor
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None
# this creates a copy of the tensor
obj = obj.contiguous() if not obj.is_contiguous() else obj
# view the tensor as a 1D array of bytes
arr = obj.view([obj.numel()]).view(torch.uint8).numpy()
if obj.nbytes < self.size_threshold:
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
else:
# Otherwise encode index of backing buffer to avoid copy.
data = len(self.aux_buffers)
self.aux_buffers.append(arr.data)
dt = str(obj.dtype)[6:] # remove 'torch.' prefix
return dt, obj.shape, data

def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
if isinstance(nt, torch.Tensor):
return self._encode_ndarray(nt.numpy())
return self._encode_tensor(nt)
if isinstance(nt, (int, float)):
# Although it violates NestedTensors type, MultiModalKwargs
# values are sometimes floats.
Expand Down Expand Up @@ -186,7 +203,7 @@ def dec_hook(self, t: type, obj: Any) -> Any:
if issubclass(t, np.ndarray):
return self._decode_ndarray(obj)
if issubclass(t, torch.Tensor):
return torch.from_numpy(self._decode_ndarray(obj))
return self._decode_tensor(obj)
if issubclass(t, MultiModalKwargs):
if isinstance(obj, list):
return MultiModalKwargs.from_items(
Expand All @@ -205,6 +222,15 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
else bytearray(data)
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)

def _decode_tensor(self, arr: Any) -> torch.Tensor:
dtype, shape, data = arr
# Copy from inline representation, otherwise Torch is unhappy since
# the returned memory is non-writeable.
buffer = self.aux_buffers[data] if isinstance(data, int) \
else bytearray(data)
arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=[len(buffer)])
return torch.from_numpy(arr).view(getattr(torch, dtype)).view(shape)

def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
decoded_items = []
for item in obj:
Expand All @@ -228,7 +254,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
if not isinstance(obj, list):
raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
if obj and isinstance(obj[0], str):
return torch.from_numpy(self._decode_ndarray(obj))
return self._decode_tensor(obj)
return [self._decode_nested_tensors(x) for x in obj]

def ext_hook(self, code: int, data: memoryview) -> Any:
Expand Down
Loading