 
 
 from types import SimpleNamespace
+from typing import Mapping, Optional
 
 import torch
 from loguru import logger
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VLForConditionalGeneration as Ref_Qwen2_5_VLForConditionalGeneration,
 )
-from vllm.inputs import INPUT_REGISTRY
 from vllm.model_executor.models.interfaces import SupportsMultiModal
+from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLProcessingInfo
+from vllm.multimodal import MULTIMODAL_REGISTRY
 
 import ttnn
 from models.demos.qwen25_vl.tt.common import merge_vision_tokens, multimodal_rope_from_hf, preprocess_inputs_prefill
 from models.demos.qwen25_vl.tt.generator import Generator as QwenVLGenerator
 from models.demos.qwen25_vl.tt.model import DropInVisionTransformer, Transformer
 from models.demos.qwen25_vl.tt.model_config import VisionModelArgs
-from models.tt_transformers.tt.generator_vllm import input_processor_for_multimodal
+from models.tt_transformers.tt.generator_vllm import DummyInputsBuilder, MultiModalProcessor
 from models.tt_transformers.tt.model_config import DecodersPrecision, ModelArgs
 
 
@@ -89,7 +91,15 @@ def __contains__(self, key):
         return key in self.__dict__
 
 
-@INPUT_REGISTRY.register_input_processor(input_processor_for_multimodal)
+class TT_Qwen2_5_VLProcessingInfo(Qwen2_5_VLProcessingInfo):
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1, "video": 0}  # [INFO] videos are not supported yet, only supporting 1 image for now
+
+
+# TODO: Eventually replace MultiModalProcessor with vllm.model_executor.models.qwen2_5_vl::Qwen2_5_VLMultiModalProcessor
+@MULTIMODAL_REGISTRY.register_processor(
+    MultiModalProcessor, info=TT_Qwen2_5_VLProcessingInfo, dummy_inputs=DummyInputsBuilder
+)
 class Qwen2_5_VLForConditionalGeneration(QwenVLGenerator, SupportsMultiModal):
     def __init__(self, *args, **kwargs):
         self.reference_model = kwargs.pop("reference_model", None)
@@ -167,15 +177,17 @@ def prefill_forward(
 
         # reconstruct the inputs that Qwen2.5-VL expects
         inputs = CustomNamespace()
-        inputs.input_ids = tokens.to(images[0].attention_mask.dtype)
+        inputs.input_ids = tokens.to(images[0].attention_mask.dtype) if images[0] is not None else tokens
         inputs.attention_mask = torch.concat(
             [
                 torch.nn.functional.pad(im.attention_mask, (0, padded_seq_len - im.attention_mask.shape[-1]), value=0)
-                for im in images
+                if im is not None
+                else torch.ones_like(tokens[i : i + 1], dtype=tokens.dtype)
+                for i, im in enumerate(images)
             ],
             dim=0,
         )
-        if "pixel_values" in images[0]:
+        if images[0] is not None and "pixel_values" in images[0]:
             # we currently do not support mixed inputs of text-only users and text-image users; hence checking images[0] is enough
             inputs.pixel_values = torch.concat([im.pixel_values for im in images], dim=0)
             inputs.image_grid_thw = torch.concat([im.image_grid_thw for im in images], dim=0)