Skip to content

Commit 8bda83c

Browse files
committed
fix review edits
Signed-off-by: Staszek Pasko <[email protected]>
1 parent 12c9d8b commit 8bda83c

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

vllm/v1/serial_utils.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import pickle
55
from collections.abc import Sequence
66
from inspect import isclass
7-
from itertools import chain
87
from types import FunctionType
98
from typing import Any, Optional, Union
109

@@ -84,22 +83,24 @@ def enc_hook(self, obj: Any) -> Any:
8483
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
8584
return self._encode_ndarray(obj)
8685

87-
if isinstance(obj, MultiModalKwargs):
88-
mm: MultiModalKwargs = obj
89-
if not mm.modalities:
90-
# just return the main dict if there are no modalities.
91-
return dict(mm)
92-
93-
# ignore the main dict, it will be re-indexed.
94-
# Encode a list of MultiModalKwargsItems as plain dicts
95-
# + special handling for .field.
96-
# Any tensors *not* indexed by modality will be ignored.
97-
return [{
98-
"modality": elem.modality,
99-
"key": elem.key,
100-
"data": self._encode_nested_tensors(elem.data),
101-
"field": self._encode_field(elem.field),
102-
} for item in mm._items_by_modality.values() for elem in item]
86+
if isinstance(obj, MultiModalKwargs):
87+
mm: MultiModalKwargs = obj
88+
if not mm.modalities:
89+
# just return the main dict if there are no modalities.
90+
return dict(mm)
91+
92+
# ignore the main dict, it will be re-indexed.
93+
# Encode a list of MultiModalKwargsItems as plain dicts
94+
# + special handling for .field.
95+
# Any tensors *not* indexed by modality will be ignored.
96+
return [[{
97+
"modality": elem.modality,
98+
"key": elem.key,
99+
"data": self._encode_nested_tensors(elem.data),
100+
"field": self._encode_field(elem.field),
101+
} for elem in item.values()]
102+
for itemlist in mm._items_by_modality.values()
103+
for item in itemlist]
103104

104105
if isinstance(obj, FunctionType):
105106
# `pickle` is generally faster than cloudpickle, but can have
@@ -199,7 +200,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
199200

200201
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
201202
all = []
202-
for item in chain.from_iterable(obj):
203+
for item in obj:
203204
elems = []
204205
for v in item:
205206
v["data"] = self._decode_nested_tensors(v["data"])

0 commit comments

Comments
 (0)