Commit 95c7f50

skhorasganiTT authored and idjuricTT committed
[vLLM] Compatibility fixes for model generators after pulling Apr16-July22 upstream changes - removed legacy input processors and refactored for multi-modal models (#28406)
### Ticket
[N/A](#27285)

### Problem description
- Legacy input mappers/processors were removed from vLLM V0 (vllm-project/vllm#15686, vllm-project/vllm#10114). These changes are required to keep the existing integrated models compatible after pulling the upstream changes in tenstorrent/vllm#172.

### What's changed
- Removed legacy vLLM input processors from Llama3, Gemma3, and Qwen2.5-VL.
- Defined new multi-modal input processor classes for Llama3.2-11B-Vision (`MllamaMultiModalProcessor`) and for Gemma3 / Qwen2.5-VL (`MultiModalProcessor`), and added support for multi-modal limits for each.
- Moved the max seq len assertion for Llama8B to model initialization; `--max_model_len` must be set on the vLLM side for any model that supports less than the default max context length.
- Fixed a bug where the `create_multimodal_model` import was removed for Llama3.2-11B-Vision and broke the model (from 87b758d).

### Checklist
- [x] [All post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml) CI passes
- [x] [Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml) CI with demo tests passes (if applicable)
- [x] [Model regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml) CI passes (if applicable)
- [x] [Device performance regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml) CI passes (if applicable)
- [x] (For models and ops writers) [Single-card demo tests](https://github.com/tenstorrent/tt-metal/actions/workflows/single-card-demo-tests.yaml) CI passes (if applicable). See the [recommended dev flow](https://github.com/tenstorrent/tt-metal/blob/main/models/docs/MODEL_ADD.md#a-recommended-dev-flow-on-github-for-adding-new-models).
- [x] [Galaxy quick](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-quick.yaml) CI passes (if applicable)
- [x] [Galaxy demo tests, for Llama](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-demo-tests.yaml) CI passes (if applicable, because of current Llama work)
- [x] (For runtime and ops writers) [T3000 unit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-unit-tests.yaml) CI passes (if applicable, since this is run on push to main)
- [x] (For models and ops writers) [T3000 demo tests](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-demo-tests.yaml) CI passes (if applicable, since this is required for release)
- [x] New/existing tests provide coverage for changes (vLLM nightly tests: https://github.com/tenstorrent/tt-metal/actions/runs/17680447236)

---------

Signed-off-by: Salar <[email protected]>
Co-authored-by: Igor Djuric <[email protected]>
1 parent f04ca48 · commit 95c7f50
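As noted in the description, the max sequence length check for Llama8B now happens at model initialization, and any model that supports less than the default context window relies on `--max_model_len` being set on the vLLM side. A minimal sketch of what that looks like with vLLM's offline `LLM` API follows; the model id and the 64K value are placeholders rather than values taken from this commit, and the TT backend setup is omitted.

```python
# Minimal sketch (assumptions: placeholder model id and context length; TT plugin setup omitted).
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    max_model_len=65536,  # cap the context length on the vLLM side instead of asserting in the generator
)
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)
```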

File tree: 3 files changed (+246 −140 lines)
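The common thread across the changed files is swapping vLLM's removed `INPUT_REGISTRY.register_input_processor` hook for `MULTIMODAL_REGISTRY.register_processor`, which registers a processor class together with a `ProcessingInfo` class (declaring per-modality limits) and a dummy-inputs builder used for profiling. A rough sketch of that pattern in isolation is shown below; it reuses the `MultiModalProcessor` and `DummyInputsBuilder` helpers named in the diff, the limit values are examples only, and the class names prefixed with `Example` are hypothetical.

```python
# Rough sketch of the new registration pattern, not the exact code in this commit.
from typing import Mapping, Optional

from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLProcessingInfo
from vllm.multimodal import MULTIMODAL_REGISTRY

from models.tt_transformers.tt.generator_vllm import DummyInputsBuilder, MultiModalProcessor


class ExampleProcessingInfo(Qwen2_5_VLProcessingInfo):
    """Declares which modalities (and how many of each) the backend accepts."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1, "video": 0}  # example limits only


# The decorator attaches the processor, info, and dummy-inputs classes to the model class,
# replacing the removed INPUT_REGISTRY.register_input_processor hook.
@MULTIMODAL_REGISTRY.register_processor(
    MultiModalProcessor, info=ExampleProcessingInfo, dummy_inputs=DummyInputsBuilder
)
class ExampleForConditionalGeneration:  # stand-in for the real generator class
    ...
```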

models/demos/llama3_70b_galaxy/tt/generator_vllm.py

Lines changed: 0 additions & 6 deletions
@@ -9,7 +9,6 @@
 from models.demos.llama3_70b_galaxy.tt.llama_model import TtTransformer
 from models.demos.llama3_70b_galaxy.tt.model_config import LlamaOptimizations, TtModelArgs
 from models.tt_transformers.tt.generator import create_submeshes
-from vllm.inputs import INPUT_REGISTRY


 def allocate_vllm_kv_cache(kv_cache_shape, dtype, num_layers, model: TtTransformer, tt_cache_path):
@@ -88,11 +87,6 @@ def initialize_vllm_text_transformer(
     return tt_model, model_args


-def input_processor_for_llama_text(ctx, inputs):
-    return inputs
-
-
-@INPUT_REGISTRY.register_input_processor(input_processor_for_llama_text)
 class LlamaForCausalLM(Generator):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

models/demos/qwen25_vl/tt/generator_vllm.py

Lines changed: 18 additions & 6 deletions
@@ -4,21 +4,23 @@


 from types import SimpleNamespace
+from typing import Mapping, Optional

 import torch
 from loguru import logger
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VLForConditionalGeneration as Ref_Qwen2_5_VLForConditionalGeneration,
 )
-from vllm.inputs import INPUT_REGISTRY
 from vllm.model_executor.models.interfaces import SupportsMultiModal
+from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLProcessingInfo
+from vllm.multimodal import MULTIMODAL_REGISTRY

 import ttnn
 from models.demos.qwen25_vl.tt.common import merge_vision_tokens, multimodal_rope_from_hf, preprocess_inputs_prefill
 from models.demos.qwen25_vl.tt.generator import Generator as QwenVLGenerator
 from models.demos.qwen25_vl.tt.model import DropInVisionTransformer, Transformer
 from models.demos.qwen25_vl.tt.model_config import VisionModelArgs
-from models.tt_transformers.tt.generator_vllm import input_processor_for_multimodal
+from models.tt_transformers.tt.generator_vllm import DummyInputsBuilder, MultiModalProcessor
 from models.tt_transformers.tt.model_config import DecodersPrecision, ModelArgs


@@ -89,7 +91,15 @@ def __contains__(self, key):
         return key in self.__dict__


-@INPUT_REGISTRY.register_input_processor(input_processor_for_multimodal)
+class TT_Qwen2_5_VLProcessingInfo(Qwen2_5_VLProcessingInfo):
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1, "video": 0}  # [INFO] videos are not supported yet, only supporting 1 image for now
+
+
+# TODO: Eventually replace MultiModalProcessor with vllm.model_executor.models.qwen2_5_vl::Qwen2_5_VLMultiModalProcessor
+@MULTIMODAL_REGISTRY.register_processor(
+    MultiModalProcessor, info=TT_Qwen2_5_VLProcessingInfo, dummy_inputs=DummyInputsBuilder
+)
 class Qwen2_5_VLForConditionalGeneration(QwenVLGenerator, SupportsMultiModal):
     def __init__(self, *args, **kwargs):
         self.reference_model = kwargs.pop("reference_model", None)
@@ -167,15 +177,17 @@ def prefill_forward(

         # reconstruct the inputs that Qwen2.5-VL expects
         inputs = CustomNamespace()
-        inputs.input_ids = tokens.to(images[0].attention_mask.dtype)
+        inputs.input_ids = tokens.to(images[0].attention_mask.dtype) if images[0] is not None else tokens
         inputs.attention_mask = torch.concat(
             [
                 torch.nn.functional.pad(im.attention_mask, (0, padded_seq_len - im.attention_mask.shape[-1]), value=0)
-                for im in images
+                if im is not None
+                else torch.ones_like(tokens[i : i + 1], dtype=tokens.dtype)
+                for i, im in enumerate(images)
             ],
             dim=0,
         )
-        if "pixel_values" in images[0]:
+        if images[0] is not None and "pixel_values" in images[0]:
             # we currently do not support mixed inputs of text-only users and text-image users; hence checking images[0] is enough
             inputs.pixel_values = torch.concat([im.pixel_values for im in images], dim=0)
             inputs.image_grid_thw = torch.concat([im.image_grid_thw for im in images], dim=0)
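The last hunk makes prefill tolerate users without images: a user whose entry in `images` is `None` contributes a row of ones (shaped like its padded token row) to the attention mask instead of a padded image-derived mask. Below is a standalone toy illustration of just that mask expression, with made-up shapes and a stand-in for the processed image inputs; per the code comment above, the real path does not mix text-only and image users in one batch, so the mixed toy batch here only exercises the expression itself.

```python
# Toy illustration of the mask construction in the hunk above; shapes and classes are made up.
import torch

padded_seq_len = 8
tokens = torch.ones((2, padded_seq_len), dtype=torch.long)  # two users, already padded to the same length


class FakeImageInputs:
    """Stand-in for a user's processed multi-modal inputs carrying an attention mask."""

    def __init__(self, seq_len: int):
        self.attention_mask = torch.ones((1, seq_len), dtype=torch.long)


images = [FakeImageInputs(5), None]  # user 0 has an image, user 1 is text-only

attention_mask = torch.concat(
    [
        torch.nn.functional.pad(im.attention_mask, (0, padded_seq_len - im.attention_mask.shape[-1]), value=0)
        if im is not None
        else torch.ones_like(tokens[i : i + 1], dtype=tokens.dtype)
        for i, im in enumerate(images)
    ],
    dim=0,
)
print(attention_mask)  # row 0: five 1s then zero padding; row 1: all ones
```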
