
Commit ff0e59d

[CI/Build] Improve Tensor Schema test speed by avoiding engine core initialization (#23357)
Signed-off-by: Isotr0py <[email protected]>
1 parent: b557136
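The commit replaces full engine startup in the tensor-schema test with direct model construction inside a lightweight distributed context. A minimal sketch of the pattern with toy stand-ins (ToyModel and the simplified helper below are hypothetical; the real initialize_dummy_model added in this commit also sets up a single-rank distributed environment, the current vLLM config, and the default torch dtype):

# Toy illustration only, not vLLM's real API.
from contextlib import contextmanager


class ToyModel:

    def _parse_and_validate_image_input(self, **kwargs: object) -> dict:
        if "pixel_values" not in kwargs:
            raise ValueError("missing pixel_values")
        return {"type": "image"}


@contextmanager
def initialize_dummy_model(model_cls):
    model = model_cls()  # construct directly; no engine, no KV caches
    try:
        yield model
    finally:
        del model  # the real helper also calls cleanup_dist_env_and_memory()


with initialize_dummy_model(ToyModel) as model:
    model._parse_and_validate_image_input(pixel_values=[1.0])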

7 files changed: +157, -115 lines

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 2 deletions
@@ -566,8 +566,7 @@ steps:
   - tests/models/multimodal
   commands:
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
-  - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
-  - pytest -v -s models/multimodal/processing/test_tensor_schema.py
+  - pytest -v -s models/multimodal/processing

 - label: Multi-Modal Models Test (Standard)
   mirror_hardwares: [amdexperimental]
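Because the reworked schema test below no longer initializes an engine core, it can run in the same pytest invocation as the other processing tests instead of a separate step.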

tests/models/multimodal/processing/test_tensor_schema.py

Lines changed: 52 additions & 80 deletions
@@ -1,30 +1,31 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import tempfile
 from collections.abc import Iterable
+from contextlib import contextmanager
 from functools import partial
 from typing import Any, Union
-from unittest.mock import patch

 import numpy as np
 import pytest
+import torch.nn as nn
 from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
                                                        UserMessage)
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from PIL import Image

-from vllm.config import ModelConfig
-from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
+from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
+from vllm.distributed import (cleanup_dist_env_and_memory,
+                              init_distributed_environment,
+                              initialize_model_parallel)
 from vllm.inputs import InputProcessingContext
-from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
-                             MultiModalKwargs)
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
 from vllm.multimodal.processing import BaseMultiModalProcessor
 from vllm.multimodal.utils import group_mm_kwargs_by_modality
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
-from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads
-from vllm.v1.core.kv_cache_utils import get_kv_cache_config
-from vllm.v1.engine.core import EngineCore as V1EngineCore
+from vllm.utils import is_list_of

-from ....conftest import VllmRunner
 from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
 from ...utils import dummy_hf_overrides
@@ -137,6 +138,27 @@ def create_batched_mm_kwargs(
     return group_mm_kwargs_by_modality(items)


+@contextmanager
+def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig):
+    temp_file = tempfile.mkstemp()[1]
+    init_distributed_environment(
+        world_size=1,
+        rank=0,
+        distributed_init_method=f"file://{temp_file}",
+        local_rank=0,
+        backend="nccl",
+    )
+    initialize_model_parallel(tensor_model_parallel_size=1)
+    vllm_config = VllmConfig(model_config=model_config)
+    with set_current_vllm_config(vllm_config=vllm_config):
+        with set_default_torch_dtype(model_config.dtype):
+            model = model_cls(vllm_config=vllm_config)
+            yield model
+
+    del model
+    cleanup_dist_env_and_memory()
+
+
 def get_model_id_to_test(
         model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
     filtered_results = []
@@ -155,8 +177,7 @@ def get_model_id_to_test(
 @pytest.mark.parametrize(
     "model_arch, model_id",
     get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
-def test_model_tensor_schema(model_arch: str, model_id: str,
-                             vllm_runner: type[VllmRunner], monkeypatch):
+def test_model_tensor_schema(model_arch: str, model_id: str):
     if model_arch in ARCH_TO_SKIP:
         pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
     if model_id in REPO_ID_TO_SKIP:
@@ -177,14 +198,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
         tokenizer_mode=model_info.tokenizer_mode,
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
-        hf_overrides=model_info.hf_overrides,
+        hf_overrides=hf_overrides_fn,
     )
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]

-    if not any(
-            hasattr(model_cls, f"_parse_and_validate_{m}_input")
-            for m in ["image", "video", "audio"]):
+    inputs_parse_methods = []
+    for attr_name in dir(model_cls):
+        attr = getattr(model_cls, attr_name)
+        if hasattr(attr, "__annotations__"):
+            return_type = attr.__annotations__.get("return", None)
+            if return_type is not None and "Input" in str(return_type):
+                inputs_parse_methods.append(attr_name)
+
+    if not any(inputs_parse_methods):
         pytest.skip(f"{model_arch} does not support tensor schema validation.")

     ctx = InputProcessingContext(
@@ -197,68 +224,13 @@
         modality: 3 if limit is None else limit
         for modality, limit in supported_mm_limits.items()
     }
-
-    # Avoid calling model.forward()
-    def _initialize_kv_caches_v0(self) -> None:
-        self.cache_config.num_gpu_blocks = 0
-        self.cache_config.num_cpu_blocks = 0
-
-    def _initialize_kv_caches_v1(self, vllm_config):
-        kv_cache_specs = self.model_executor.get_kv_cache_specs()
-        scheduler_kv_cache_config = get_kv_cache_config(
-            vllm_config,
-            kv_cache_specs[0],
-            10 * GiB_bytes,
-        )
-
-        # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
-        return 1, 0, scheduler_kv_cache_config
-
-    with (patch.object(V0LLMEngine, "_initialize_kv_caches",
-                       _initialize_kv_caches_v0),
-          patch.object(V1EngineCore, "_initialize_kv_caches",
-                       _initialize_kv_caches_v1), monkeypatch.context() as m):
-        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
-        if model_info.v0_only:
-            m.setenv("VLLM_USE_V1", "0")
-
-        # TODO(Isotr0py): Can we avoid initializing engine?
-        with (
-                set_default_torch_num_threads(1),
-                vllm_runner(
-                    model_id,
-                    tokenizer_name=model_info.tokenizer,
-                    tokenizer_mode=model_info.tokenizer_mode,
-                    revision=model_info.revision,
-                    trust_remote_code=model_info.trust_remote_code,
-                    max_model_len=model_info.max_model_len,
-                    load_format="dummy",
-                    hf_overrides=hf_overrides_fn,
-                    limit_mm_per_prompt=limit_mm_per_prompt,
-                    enforce_eager=True,
-                ) as vllm_model,
-        ):
-            model_config = vllm_model.llm.llm_engine.model_config
-            llm_engine = vllm_model.llm.llm_engine
-
-            if hasattr(llm_engine, "processor"):
-                # v1 processor
-                mm_registry = llm_engine.processor.mm_registry
-            else:
-                # v0 input_preprocessor
-                mm_registry = llm_engine.input_preprocessor.mm_registry
-
-            processor = mm_registry.create_processor(model_config)
-
-            def validate_model_input(model, modality: str,
-                                     mm_kwargs: MultiModalKwargs):
-                method_name = f"_parse_and_validate_{modality}_input"
-                if hasattr(model, method_name):
-                    getattr(model, method_name)(**mm_kwargs)
-
-            for modality, _, mm_kwargs in create_batched_mm_kwargs(
-                    model_config, processor):
-                valid_func = partial(validate_model_input,
-                                     modality=modality,
-                                     mm_kwargs=mm_kwargs)
-                vllm_model.apply_model(valid_func)
+    model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
+    processor = factories.build_processor(ctx, cache=None)
+
+    with initialize_dummy_model(model_cls, model_config) as model:
+        for modality, _, mm_kwargs in create_batched_mm_kwargs(
+                model_config, processor):
+            for method_name in inputs_parse_methods:
+                print(f"Testing `{method_name}` with modality={modality} "
+                      f"and mm_kwargs{list(mm_kwargs.keys())}")
+                getattr(model, method_name)(modality=modality, **mm_kwargs)
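Instead of hardcoding _parse_and_validate_{image,video,audio}_input names, the rewritten test discovers parser methods by scanning return annotations for a type name containing "Input". A self-contained sketch of that scan, using a hypothetical DemoModel:

# Hypothetical model class used only to demonstrate the scan.
from typing import Optional, TypedDict


class DemoImageInputs(TypedDict):
    pixel_values: list[float]


class DemoModel:

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[DemoImageInputs]:
        return None

    def forward(self, x: int) -> int:  # return annotation lacks "Input"
        return x


# Same logic as the new test: collect every attribute whose return
# annotation mentions "Input".
inputs_parse_methods = []
for attr_name in dir(DemoModel):
    attr = getattr(DemoModel, attr_name)
    if hasattr(attr, "__annotations__"):
        return_type = attr.__annotations__.get("return", None)
        if return_type is not None and "Input" in str(return_type):
            inputs_parse_methods.append(attr_name)

print(inputs_parse_methods)  # ['_parse_and_validate_image_input']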

vllm/model_executor/models/granite_speech.py

Lines changed: 1 addition & 1 deletion
@@ -549,7 +549,7 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:

         raise ValueError("Only audio modality is supported")

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
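Giving prefix a default value lets the model be constructed as model_cls(vllm_config=vllm_config), which is exactly how the new initialize_dummy_model helper instantiates it.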

vllm/model_executor/models/mllama.py

Lines changed: 2 additions & 1 deletion
@@ -1371,7 +1371,8 @@ def unpack_data(self,
             output_tensor[i, :t.size(0)] = t
         return output_tensor

-    def _parse_and_validate_image_input(self, **kwargs: object):
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[MllamaImagePixelInputs]:
         # tensor with the same shape will be batched together by
         # MultiModalKwargs.batch, so pixel_values here can be:
         # - list[torch.Tensor]:
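The added return annotation is load-bearing for the reworked test: parser methods are discovered by checking whether the return type name contains "Input", so an unannotated _parse_and_validate_image_input would no longer be exercised. The phi4mm change below serves the same purpose.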

vllm/model_executor/models/ovis.py

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ class OvisImagePatchInputs(TypedDict):
     `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
     """

-    inducator_tokens: torch.Tensor
+    indicator_tokens: torch.Tensor
     """
    Shape:
     `(batch_size * (num_patches + 1))`

vllm/model_executor/models/ovis2_5.py

Lines changed: 98 additions & 28 deletions
@@ -3,7 +3,7 @@
 """ PyTorch Ovis model."""
 from collections.abc import Iterable, Mapping
 from functools import partial
-from typing import Optional, Union
+from typing import Literal, Optional, TypedDict, Union

 import torch
 import torch.nn as nn
@@ -50,6 +50,27 @@
 }


+class OvisVideoPatchInputs(TypedDict):
+    type: Literal["video_patches"]
+    flat_data: torch.Tensor
+    """
+    Shape:
+    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
+    """
+
+    indicator_tokens: torch.Tensor
+    """
+    Shape:
+    `(batch_size * (num_patches + 1))`
+    """
+
+    patches_per_image: list[int]
+    """
+    List of number of total patches for each frame in the video.
+    This is used to restore the first two dimensions of `flat_data`.
+    """
+
+
 def _ovis2_5_field_config():
     return dict(pixel_values=MultiModalFieldConfig.batched("image"),
                 grids=MultiModalFieldConfig.batched("image"),
@@ -429,17 +450,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.get_language_model().make_empty_intermediate_tensors)

-    def _parse_and_validate_visual_input(
-            self, is_video,
-            **kwargs: object) -> Optional[OvisImagePatchInputs]:
-        if is_video:
-            pixel_values = kwargs.pop("video_pixel_values", None)
-            indicator_tokens = kwargs.pop("video_indicator_tokens", None)
-            grids = kwargs.pop("video_grids", None)
-        else:
-            pixel_values = kwargs.pop("pixel_values", None)
-            indicator_tokens = kwargs.pop("indicator_tokens", None)
-            grids = kwargs.pop("grids", None)
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
+        pixel_values = kwargs.pop("pixel_values", None)
+        indicator_tokens = kwargs.pop("indicator_tokens", None)
+        grids = kwargs.pop("grids", None)
         if pixel_values is None and indicator_tokens is None:
             return None

@@ -466,8 +481,40 @@

         raise AssertionError("This line should be unreachable.")

+    def _parse_and_validate_video_input(
+            self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
+        pixel_values = kwargs.pop("video_pixel_values", None)
+        indicator_tokens = kwargs.pop("video_indicator_tokens", None)
+        grids = kwargs.pop("video_grids", None)
+        if pixel_values is None and indicator_tokens is None:
+            return None
+
+        if pixel_values is not None and indicator_tokens is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+
+            if not isinstance(indicator_tokens, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of indicator_tokens. "
+                                 f"Got type: {type(indicator_tokens)}")
+
+            return OvisVideoPatchInputs(
+                type="video_patches",
+                flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
+                patches_per_image=[
+                    x.shape[0] // (self.config.vit_config.hidden_stride**2)
+                    for x in flatten_bn(pixel_values)
+                ],
+                indicator_tokens=flatten_bn(flatten_bn(indicator_tokens),
+                                            concat=True),
+                grids=flatten_bn(flatten_bn(grids), concat=True),
+            )
+
+        raise AssertionError("This line should be unreachable.")
+
     def _process_image_input(
-            self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
+            self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
+    ) -> MultiModalEmbeddings:
         image_patches_flat = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"]
         indicator_tokens = image_input["indicator_tokens"]
@@ -500,21 +547,44 @@
             torch.cat(vision_embeddings_per_image, dim=0))
         return tuple(vision_embeddings)

-    def get_multimodal_embeddings(
-            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
-        embeddings = []
-
-        # NOTE: _parse_and_validate_visual_input has side-effects and pops
-        # keys from kwargs. We process images first, then videos.
-        image_input = self._parse_and_validate_visual_input(False, **kwargs)
-        if image_input:
-            embeddings.extend(self._process_image_input(image_input))
-
-        video_input = self._parse_and_validate_visual_input(True, **kwargs)
-        if video_input:
-            embeddings.extend(self._process_image_input(video_input))
-
-        return tuple(embeddings) if embeddings else None
+    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+        modalities = {}
+
+        # Preserve the order of modalities if there are multiple of them
+        # from the order of kwargs.
+        for input_key in kwargs:
+            if input_key in ("pixel_values", "indicator_tokens",
+                             "grids") and "images" not in modalities:
+                modalities["images"] = self._parse_and_validate_image_input(
+                    **kwargs)
+            if input_key in ("video_pixel_values", "video_indicator_tokens",
+                             "video_grids") and "videos" not in modalities:
+                modalities["videos"] = self._parse_and_validate_video_input(
+                    **kwargs)
+
+        return modalities
+
+    def get_multimodal_embeddings(self,
+                                  **kwargs: object) -> MultiModalEmbeddings:
+
+        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+        if not modalities:
+            return []
+
+        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+        # NOTE: It is important to iterate over the keys in this dictionary
+        # to preserve the order of the modalities.
+        for modality in modalities:
+            if modality == "images":
+                image_input = modalities["images"]
+                vision_embeddings = self._process_image_input(image_input)
+                multimodal_embeddings += vision_embeddings
+            if modality == "videos":
+                video_input = modalities["videos"]
+                video_embeddings = self._process_image_input(video_input)
+                multimodal_embeddings += video_embeddings
+
+        return multimodal_embeddings

     def get_input_embeddings(
         self,
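The ordering contract in the NOTE comments above rests on Python dicts and **kwargs preserving insertion order: whichever modality's keys arrive first in kwargs is parsed first, and its embeddings are emitted first. A toy sketch of that behavior (parse_inputs is a stand-in, not the vLLM method):

# Stand-in for _parse_and_validate_multimodal_inputs: the modalities
# dict is keyed in the order the multimodal kwargs arrive.
def parse_inputs(**kwargs: object) -> dict:
    modalities = {}
    for input_key in kwargs:
        if input_key == "pixel_values" and "images" not in modalities:
            modalities["images"] = "image-input"
        if input_key == "video_pixel_values" and "videos" not in modalities:
            modalities["videos"] = "video-input"
    return modalities


print(list(parse_inputs(video_pixel_values=1, pixel_values=2)))
# ['videos', 'images'] -> video embeddings would be emitted first
print(list(parse_inputs(pixel_values=2, video_pixel_values=1)))
# ['images', 'videos'] -> image embeddings would be emitted first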

vllm/model_executor/models/phi4mm.py

Lines changed: 2 additions & 2 deletions
@@ -1031,8 +1031,8 @@ def _process_audio_input(self, audio_input: Phi4MMAudioInputs,
         ]
         return audio_embeds

-    def _parse_and_validate_image_input(self,
-                                        **kwargs: object) -> Optional[dict]:
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]:
         input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
         if input_image_embeds is None:
             return None
