Commit 38759e2

skhorasganiTT authored and idjuricTT committed
[vLLM] Compatibility fixes for model generators after pulling Apr16-July22 upstream changes - removed legacy input processors and refactored for multi-modal models (#28406)
### Ticket

[N/A](#27285)

### Problem description

- Legacy input mappers/processors were removed from vLLM V0 (vllm-project/vllm#15686, vllm-project/vllm#10114). These changes are required to keep existing integrated models compatible after pulling the upstream changes in tenstorrent/vllm#172.

### What's changed

- Removed legacy vLLM input processors from Llama3, Gemma3, and Qwen2.5-VL.
- Defined new multi-modal input processor classes for Llama3.2-11B-Vision (`MllamaMultiModalProcessor`) and for Gemma3 / Qwen2.5-VL (`MultiModalProcessor`), and added support for multi-modal limits for each.
- Moved the max sequence length assertion for Llama8B to model initialization; `--max_model_len` must be set on the vLLM side for any model that supports less than the default max context length.
- Fixed a bug where the `create_multimodal_model` import was removed for Llama3.2-11B-Vision and broke the model (from 87b758d).

### Checklist

- [x] [All post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml) CI passes
- [x] [Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml) CI with demo tests passes (if applicable)
- [x] [Model regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml) CI passes (if applicable)
- [x] [Device performance regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml) CI passes (if applicable)
- [x] (For models and ops writers) [Single-card demo tests](https://github.com/tenstorrent/tt-metal/actions/workflows/single-card-demo-tests.yaml) CI passes (if applicable). See the [recommended dev flow](https://github.com/tenstorrent/tt-metal/blob/main/models/docs/MODEL_ADD.md#a-recommended-dev-flow-on-github-for-adding-new-models).
- [x] [Galaxy quick](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-quick.yaml) CI passes (if applicable)
- [x] [Galaxy demo tests, for Llama](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-demo-tests.yaml) CI passes (if applicable, because of current Llama work)
- [x] (For runtime and ops writers) [T3000 unit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-unit-tests.yaml) CI passes (if applicable, since this is run on push to main)
- [x] (For models and ops writers) [T3000 demo tests](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-demo-tests.yaml) CI passes (if applicable, since this is required for release)
- [x] New/Existing tests provide coverage for changes: vLLM nightly tests - https://github.com/tenstorrent/tt-metal/actions/runs/17680447236

---------

Signed-off-by: Salar <[email protected]>
Co-authored-by: Igor Djuric <[email protected]>
1 parent 6b89da5 commit 38759e2
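
The note above that `--max_model_len` must now be set on the vLLM side (since the max sequence length assertion moved into model initialization) translates to an explicit engine argument when serving. Below is a minimal sketch using the standard vLLM offline `LLM` entrypoint; the model id and length value are illustrative and not taken from this PR.

```python
# Hedged sketch: with the legacy input processors gone, the context limit is
# no longer clamped for you, so pass it explicitly to the vLLM engine.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # illustrative model id
    max_model_len=8192,  # illustrative value; must not exceed what the TT model was built for
)

outputs = llm.generate(["What does Tenstorrent build?"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```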

File tree

3 files changed: +246, -191 lines changed


models/demos/llama3_70b_galaxy/tt/generator_vllm.py

Lines changed: 0 additions & 57 deletions
@@ -88,63 +88,6 @@ def initialize_vllm_text_transformer(
     return tt_model, model_args


-def initialize_vllm_text_transformer_qwen(
-    hf_config,
-    tt_data_parallel,
-    mesh_device,
-    max_batch_size,
-    max_seq_len,
-    n_layers=None,
-    dtype=ttnn.bfloat8_b,
-    optimizations=LlamaOptimizations.performance,
-):
-    submesh_devices = create_submeshes(mesh_device, tt_data_parallel)
-    # Load model args, weights
-    model_args = []
-    for submesh in submesh_devices:
-        model_args_i = TtQwenModelArgs(
-            submesh,
-            instruct=(
-                "Instruct" in hf_config._name_or_path or "DeepSeek-R1-Distill-Llama-70B" in hf_config._name_or_path
-            ),
-            max_batch_size=max_batch_size // tt_data_parallel,
-            # optimizations=optimizations,
-            max_seq_len=max_seq_len,
-        )
-
-        if n_layers is not None:
-            model_args_i.n_layers = n_layers
-
-        model_args.append(model_args_i)
-
-    state_dict = model_args[0].load_state_dict()
-
-    tt_model = []
-    for i, submesh in enumerate(submesh_devices):
-        tt_model_i = TtTransformer(
-            args=model_args[i],
-            mesh_device=submesh,
-            dtype=ttnn.bfloat8_b,
-            state_dict=state_dict,
-            weight_cache_path=model_args[i].weight_cache_path(ttnn.bfloat8_b),
-            use_paged_kv_cache=True,
-            mode="prefill",
-            enable_prefetcher_performance_mode=True,
-        )
-        tt_model.append(tt_model_i)
-
-    return tt_model, model_args
-
-
-def input_processor_for_llama_text(ctx, inputs):
-    return inputs
-
-
-def input_processor_for_qwen_text(ctx, inputs):
-    return inputs
-
-
-# @INPUT_REGISTRY.register_input_processor(input_processor_for_llama_text)
 class LlamaForCausalLM(Generator):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
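
Worth noting: the deleted `initialize_vllm_text_transformer_qwen` was a near copy of `initialize_vllm_text_transformer` with `TtQwenModelArgs` swapped in. If a Qwen variant is needed again, one helper parameterized by the model-args class would avoid reintroducing the duplicate. A rough sketch follows, assuming both args classes accept the constructor arguments used in the removed code and reusing this file's existing imports (`ttnn`, `create_submeshes`, `TtTransformer`); the function name and `model_args_cls` parameter are hypothetical and not part of this commit.

```python
# Hypothetical consolidation sketch, not part of this commit.
def initialize_vllm_text_transformer_for(
    model_args_cls,  # e.g. TtQwenModelArgs, or the Llama args class
    hf_config,
    tt_data_parallel,
    mesh_device,
    max_batch_size,
    max_seq_len,
    n_layers=None,
    dtype=ttnn.bfloat8_b,
):
    submesh_devices = create_submeshes(mesh_device, tt_data_parallel)

    # Build per-submesh model args, mirroring the removed Qwen-specific helper.
    model_args = []
    for submesh in submesh_devices:
        args_i = model_args_cls(
            submesh,
            instruct=(
                "Instruct" in hf_config._name_or_path
                or "DeepSeek-R1-Distill-Llama-70B" in hf_config._name_or_path
            ),
            max_batch_size=max_batch_size // tt_data_parallel,
            max_seq_len=max_seq_len,
        )
        if n_layers is not None:
            args_i.n_layers = n_layers
        model_args.append(args_i)

    # Weights are loaded once and shared across the data-parallel submeshes.
    state_dict = model_args[0].load_state_dict()
    tt_model = [
        TtTransformer(
            args=model_args[i],
            mesh_device=submesh,
            dtype=dtype,
            state_dict=state_dict,
            weight_cache_path=model_args[i].weight_cache_path(dtype),
            use_paged_kv_cache=True,
            mode="prefill",
            enable_prefetcher_performance_mode=True,
        )
        for i, submesh in enumerate(submesh_devices)
    ]
    return tt_model, model_args
```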

models/demos/qwen25_vl/tt/generator_vllm.py

Lines changed: 18 additions & 6 deletions
@@ -4,21 +4,23 @@


 from types import SimpleNamespace
+from typing import Mapping, Optional

 import torch
 from loguru import logger
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VLForConditionalGeneration as Ref_Qwen2_5_VLForConditionalGeneration,
 )
-from vllm.inputs import INPUT_REGISTRY
 from vllm.model_executor.models.interfaces import SupportsMultiModal
+from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLProcessingInfo
+from vllm.multimodal import MULTIMODAL_REGISTRY

 import ttnn
 from models.demos.qwen25_vl.tt.common import merge_vision_tokens, multimodal_rope_from_hf, preprocess_inputs_prefill
 from models.demos.qwen25_vl.tt.generator import Generator as QwenVLGenerator
 from models.demos.qwen25_vl.tt.model import DropInVisionTransformer, Transformer
 from models.demos.qwen25_vl.tt.model_config import VisionModelArgs
-from models.tt_transformers.tt.generator_vllm import input_processor_for_multimodal
+from models.tt_transformers.tt.generator_vllm import DummyInputsBuilder, MultiModalProcessor
 from models.tt_transformers.tt.model_config import DecodersPrecision, ModelArgs


@@ -89,7 +91,15 @@ def __contains__(self, key):
         return key in self.__dict__


-@INPUT_REGISTRY.register_input_processor(input_processor_for_multimodal)
+class TT_Qwen2_5_VLProcessingInfo(Qwen2_5_VLProcessingInfo):
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1, "video": 0}  # [INFO] videos are not supported yet, only supporting 1 image for now
+
+
+# TODO: Eventually replace MultiModalProcessor with vllm.model_executor.models.qwen2_5_vl::Qwen2_5_VLMultiModalProcessor
+@MULTIMODAL_REGISTRY.register_processor(
+    MultiModalProcessor, info=TT_Qwen2_5_VLProcessingInfo, dummy_inputs=DummyInputsBuilder
+)
 class Qwen2_5_VLForConditionalGeneration(QwenVLGenerator, SupportsMultiModal):
     def __init__(self, *args, **kwargs):
         self.reference_model = kwargs.pop("reference_model", None)


@@ -167,15 +177,17 @@ def prefill_forward(

         # reconstruct the inputs that Qwen2.5-VL expects
         inputs = CustomNamespace()
-        inputs.input_ids = tokens.to(images[0].attention_mask.dtype)
+        inputs.input_ids = tokens.to(images[0].attention_mask.dtype) if images[0] is not None else tokens
         inputs.attention_mask = torch.concat(
             [
                 torch.nn.functional.pad(im.attention_mask, (0, padded_seq_len - im.attention_mask.shape[-1]), value=0)
-                for im in images
+                if im is not None
+                else torch.ones_like(tokens[i : i + 1], dtype=tokens.dtype)
+                for i, im in enumerate(images)
             ],
             dim=0,
         )
-        if "pixel_values" in images[0]:
+        if images[0] is not None and "pixel_values" in images[0]:
             # we currently do not support mixed inputs of text-only users and text-image users; hence checking images[0] is enough
             inputs.pixel_values = torch.concat([im.pixel_values for im in images], dim=0)
             inputs.image_grid_thw = torch.concat([im.image_grid_thw for im in images], dim=0)
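
For readers skimming the last hunk: the new attention-mask comprehension lets text-only users (a `None` entry in `images`) ride along with image users by substituting an all-ones mask of the padded length. A tiny self-contained illustration with made-up tensors; the names mirror the diff, but the data is dummy.

```python
# Standalone illustration of the padded-mask construction above, using dummy data.
# Assumes two users: user 0 sent an image (with its own attention mask), user 1 is text-only (None).
from types import SimpleNamespace

import torch

padded_seq_len = 8
tokens = torch.ones(2, padded_seq_len, dtype=torch.long)  # dummy token ids for 2 users
images = [
    SimpleNamespace(attention_mask=torch.ones(1, 5, dtype=torch.long)),  # image user, 5 real tokens
    None,                                                                # text-only user
]

attention_mask = torch.concat(
    [
        torch.nn.functional.pad(im.attention_mask, (0, padded_seq_len - im.attention_mask.shape[-1]), value=0)
        if im is not None
        else torch.ones_like(tokens[i : i + 1], dtype=tokens.dtype)
        for i, im in enumerate(images)
    ],
    dim=0,
)
print(attention_mask)
# tensor([[1, 1, 1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1, 1]])
```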
