Skip to content

Commit 4bdd16e

Browse files
p88hDarkLight1337njhill
committed
Apply suggestions from code review
Co-authored-by: Cyrus Leung <[email protected]> Co-authored-by: Nick Hill <[email protected]> Signed-off-by: Staszek Pasko <[email protected]>
1 parent 7b6b7ba commit 4bdd16e

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

tests/v1/test_serial_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_multimodal_kwargs():
104104
# 8 total tensors + top level buffer
105105
assert len(encoded) == 6
106106

107-
total_len = sum([len(x) for x in encoded])
107+
total_len = sum(len(x) for x in encoded)
108108

109109
# expected total encoding length, should be 4440, +-20 for minor changes
110110
assert total_len >= 4420 and total_len <= 4460
@@ -165,7 +165,7 @@ def nested_equal(a: NestedTensors, b: NestedTensors):
165165
if isinstance(a, torch.Tensor):
166166
return torch.equal(a, b)
167167
else:
168-
return all([nested_equal(x, y) for (x, y) in zip(a, b)])
168+
return all(nested_equal(x, y) for x, y in zip(a, b))
169169

170170

171171
def assert_equal(obj1: MyType, obj2: MyType):

vllm/v1/serial_utils.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ def enc_hook(self, obj: Any) -> Any:
7575
# ignore the main dict, it will be re-indexed.
7676
# pass a list of MultiModalKwargsItem, then see below
7777
# Any tensors *not* indexed by modality will be ignored.
78-
return [mm.get_items(m) for m in mm.modalities]
78+
return mm._items_by_modality.values()
7979
# just return the main dict if there are no modalities
80-
return {k: v for k, v in obj.items()}
80+
return dict(mm)
8181

8282
if isinstance(obj, MultiModalKwargsItem):
8383
# Encode as plain dictionary + special handling for '.field'
@@ -146,12 +146,12 @@ def dec_hook(self, t: type, obj: Any) -> Any:
146146
if isclass(t):
147147
if issubclass(t, np.ndarray):
148148
return self._decode_ndarray(obj)
149-
if issubclass(t, MultiModalKwargs) and isinstance(obj, dict):
149+
if issubclass(t, MultiModalKwargs)
150+
if isinstance(obj, list):
151+
return MultiModalKwargs.from_items(self._decode_items(obj))
150152
return MultiModalKwargs(
151-
{k: self._decode_nested(obj[k])
152-
for k in obj})
153-
if issubclass(t, MultiModalKwargs) and isinstance(obj, list):
154-
return MultiModalKwargs.from_items(self._decode_items(obj))
153+
{k: self._decode_nested(v)
154+
for k in obj.items()})
155155
if issubclass(t, torch.Tensor):
156156
return torch.from_numpy(self._decode_ndarray(obj))
157157
return obj
@@ -172,12 +172,12 @@ def _decode_items(self, obj: list) -> list[MultiModalKwargsItem]:
172172
all.append(MultiModalKwargsItem.from_elems(elems))
173173
return all
174174

175-
def _decode_nested(self, obj: Any) -> NestedTensors:
176-
if isinstance(obj, list) and isinstance(obj[0], str):
175+
def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
176+
if not isinstance(obj, list):
177+
raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
178+
if obj and isinstance(obj[0], str):
177179
return torch.from_numpy(self._decode_ndarray(obj))
178-
if isinstance(obj, list):
179-
return [self._decode_nested(x) for x in obj]
180-
raise TypeError(f"Unexpected NestedArray contents: {obj}")
180+
return [self._decode_nested_tensors(x) for x in obj]
181181

182182
def ext_hook(self, code: int, data: memoryview) -> Any:
183183
if code == CUSTOM_TYPE_PICKLE:

0 commit comments

Comments
 (0)