|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 | from collections import UserDict
|
3 | 3 | from dataclasses import dataclass
|
| 4 | +from typing import Optional |
4 | 5 |
|
| 6 | +import msgspec |
5 | 7 | import numpy as np
|
6 | 8 | import torch
|
7 | 9 |
|
| 10 | +from vllm.multimodal.inputs import (MultiModalBatchedField, |
| 11 | + MultiModalFieldElem, MultiModalKwargs, |
| 12 | + MultiModalKwargsItem, NestedTensors) |
8 | 13 | from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
9 | 14 |
|
10 | 15 |
|
@@ -70,6 +75,99 @@ def test_encode_decode():
|
70 | 75 | assert_equal(decoded2, obj)
|
71 | 76 |
|
72 | 77 |
|
| 78 | +class MyRequest(msgspec.Struct): |
| 79 | + mm: Optional[list[MultiModalKwargs]] |
| 80 | + |
| 81 | + |
| 82 | +def test_multimodal_kwargs(): |
| 83 | + d = { |
| 84 | + "foo": |
| 85 | + torch.zeros(1000, dtype=torch.float16), |
| 86 | + "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)], |
| 87 | + "baz": [ |
| 88 | + torch.rand((256), dtype=torch.float16), |
| 89 | + [ |
| 90 | + torch.rand((1, 12), dtype=torch.float32), |
| 91 | + torch.rand((3, 5, 7), dtype=torch.float64), |
| 92 | + ], [torch.rand((4, 4), dtype=torch.float16)] |
| 93 | + ], |
| 94 | + } |
| 95 | + |
| 96 | + # pack mm kwargs into a mock request so that it can be decoded properly |
| 97 | + req = MyRequest(mm=[MultiModalKwargs(d)]) |
| 98 | + |
| 99 | + encoder = MsgpackEncoder() |
| 100 | + decoder = MsgpackDecoder(MyRequest) |
| 101 | + |
| 102 | + encoded = encoder.encode(req) |
| 103 | + |
| 104 | + # 8 total tensors + top level buffer |
| 105 | + assert len(encoded) == 6 |
| 106 | + |
| 107 | + total_len = sum([len(x) for x in encoded]) |
| 108 | + |
| 109 | + # expected total encoding length, should be 4440, +-20 for minor changes |
| 110 | + assert total_len >= 4420 and total_len <= 4460 |
| 111 | + decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] |
| 112 | + assert all(nested_equal(d[k], decoded[k]) for k in d) |
| 113 | + |
| 114 | + |
| 115 | +def test_multimodal_items_by_modality(): |
| 116 | + e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000, |
| 117 | + dtype=torch.int16), |
| 118 | + MultiModalBatchedField()) |
| 119 | + e2 = MultiModalFieldElem( |
| 120 | + "video", |
| 121 | + "v0", |
| 122 | + [torch.zeros(1000, dtype=torch.int8) for _ in range(4)], |
| 123 | + MultiModalBatchedField(), |
| 124 | + ) |
| 125 | + e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000, |
| 126 | + dtype=torch.int32), |
| 127 | + MultiModalBatchedField()) |
| 128 | + e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000, |
| 129 | + dtype=torch.int32), |
| 130 | + MultiModalBatchedField()) |
| 131 | + audio = MultiModalKwargsItem.from_elems([e1]) |
| 132 | + video = MultiModalKwargsItem.from_elems([e2]) |
| 133 | + image = MultiModalKwargsItem.from_elems([e3, e4]) |
| 134 | + mm = MultiModalKwargs.from_items([audio, video, image]) |
| 135 | + |
| 136 | + # pack mm kwargs into a mock request so that it can be decoded properly |
| 137 | + req = MyRequest([mm]) |
| 138 | + |
| 139 | + encoder = MsgpackEncoder() |
| 140 | + decoder = MsgpackDecoder(MyRequest) |
| 141 | + |
| 142 | + encoded = encoder.encode(req) |
| 143 | + |
| 144 | + # 5 total tensors + top level buffer |
| 145 | + assert len(encoded) == 8 |
| 146 | + |
| 147 | + total_len = sum([len(x) for x in encoded]) |
| 148 | + |
| 149 | + # expected total encoding length, should be 7507, +-20 for minor changes |
| 150 | + assert total_len >= 7487 and total_len <= 7527 |
| 151 | + decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] |
| 152 | + |
| 153 | + # check all modalities were recovered and do some basic sanity checks |
| 154 | + assert len(decoded.modalities) == 3 |
| 155 | + images = decoded.get_items("image") |
| 156 | + assert len(images) == 1 |
| 157 | + assert len(images[0].items()) == 2 |
| 158 | + assert list(images[0].keys()) == ["i0", "i1"] |
| 159 | + |
| 160 | + # check the tensor contents and layout in the main dict |
| 161 | + assert all(nested_equal(mm[k], decoded[k]) for k in mm) |
| 162 | + |
| 163 | + |
| 164 | +def nested_equal(a: NestedTensors, b: NestedTensors): |
| 165 | + if isinstance(a, torch.Tensor): |
| 166 | + return torch.equal(a, b) |
| 167 | + else: |
| 168 | + return all([nested_equal(x, y) for (x, y) in zip(a, b)]) |
| 169 | + |
| 170 | + |
73 | 171 | def assert_equal(obj1: MyType, obj2: MyType):
|
74 | 172 | assert torch.equal(obj1.tensor1, obj2.tensor1)
|
75 | 173 | assert obj1.a_string == obj2.a_string
|
|
0 commit comments