 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import tempfile
 from collections.abc import Iterable
+from contextlib import contextmanager
 from functools import partial
 from typing import Any, Union
-from unittest.mock import patch

 import numpy as np
 import pytest
+import torch.nn as nn
 from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
                                                        UserMessage)
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from PIL import Image

-from vllm.config import ModelConfig
-from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
+from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
+from vllm.distributed import (cleanup_dist_env_and_memory,
+                              init_distributed_environment,
+                              initialize_model_parallel)
 from vllm.inputs import InputProcessingContext
-from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
-                             MultiModalKwargs)
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
 from vllm.multimodal.processing import BaseMultiModalProcessor
 from vllm.multimodal.utils import group_mm_kwargs_by_modality
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
-from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads
-from vllm.v1.core.kv_cache_utils import get_kv_cache_config
-from vllm.v1.engine.core import EngineCore as V1EngineCore
+from vllm.utils import is_list_of

-from ....conftest import VllmRunner
 from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
 from ...utils import dummy_hf_overrides

@@ -137,6 +138,27 @@ def create_batched_mm_kwargs(
     return group_mm_kwargs_by_modality(items)


+@contextmanager
+def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig):
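+    # Spin up a one-process distributed environment (world size 1, rank 0,
+    # rendezvousing over a temporary file) so that parallel layers can be
+    # constructed; the model itself is built without loading any checkpoint.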
+    temp_file = tempfile.mkstemp()[1]
+    init_distributed_environment(
+        world_size=1,
+        rank=0,
+        distributed_init_method=f"file://{temp_file}",
+        local_rank=0,
+        backend="nccl",
+    )
+    initialize_model_parallel(tensor_model_parallel_size=1)
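+    # Instantiate the model under the target vLLM config and its default
+    # dtype so that layers pick up the active configuration at build time.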
+    vllm_config = VllmConfig(model_config=model_config)
+    with set_current_vllm_config(vllm_config=vllm_config):
+        with set_default_torch_dtype(model_config.dtype):
+            model = model_cls(vllm_config=vllm_config)
+    yield model
+
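+    # Drop the model reference and tear down the distributed state so that
+    # later parametrized test cases start from a clean environment.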
+    del model
+    cleanup_dist_env_and_memory()
+
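+# A minimal usage sketch (`MyMultiModalModel` and the model id below are
+# hypothetical; any registered vLLM model class works the same way):
+#
+#     config = ModelConfig("org/my-multimodal-model")
+#     with initialize_dummy_model(MyMultiModalModel, config) as model:
+#         model._parse_and_validate_image_input(**mm_kwargs)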
+
 def get_model_id_to_test(
         model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
     filtered_results = []
@@ -155,8 +177,7 @@ def get_model_id_to_test(
 @pytest.mark.parametrize(
     "model_arch, model_id",
     get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
-def test_model_tensor_schema(model_arch: str, model_id: str,
-                             vllm_runner: type[VllmRunner], monkeypatch):
+def test_model_tensor_schema(model_arch: str, model_id: str):
     if model_arch in ARCH_TO_SKIP:
         pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
     if model_id in REPO_ID_TO_SKIP:
@@ -177,14 +198,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
         tokenizer_mode=model_info.tokenizer_mode,
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
-        hf_overrides=model_info.hf_overrides,
+        hf_overrides=hf_overrides_fn,
     )
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]

-    if not any(
-            hasattr(model_cls, f"_parse_and_validate_{m}_input")
-            for m in ["image", "video", "audio"]):
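+    # Discover input-parsing methods by annotation rather than by a fixed
+    # name list: any attribute whose return annotation mentions "Input" is
+    # treated as a tensor-schema parser to exercise.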
+    inputs_parse_methods = []
+    for attr_name in dir(model_cls):
+        attr = getattr(model_cls, attr_name)
+        if hasattr(attr, "__annotations__"):
+            return_type = attr.__annotations__.get("return", None)
+            if return_type is not None and "Input" in str(return_type):
+                inputs_parse_methods.append(attr_name)
+
+    if not any(inputs_parse_methods):
         pytest.skip(f"{model_arch} does not support tensor schema validation.")

     ctx = InputProcessingContext(
@@ -197,68 +224,13 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
         modality: 3 if limit is None else limit
         for modality, limit in supported_mm_limits.items()
     }
-
-    # Avoid calling model.forward()
-    def _initialize_kv_caches_v0(self) -> None:
-        self.cache_config.num_gpu_blocks = 0
-        self.cache_config.num_cpu_blocks = 0
-
-    def _initialize_kv_caches_v1(self, vllm_config):
-        kv_cache_specs = self.model_executor.get_kv_cache_specs()
-        scheduler_kv_cache_config = get_kv_cache_config(
-            vllm_config,
-            kv_cache_specs[0],
-            10 * GiB_bytes,
-        )
-
-        # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
-        return 1, 0, scheduler_kv_cache_config
-
-    with (patch.object(V0LLMEngine, "_initialize_kv_caches",
-                       _initialize_kv_caches_v0),
-          patch.object(V1EngineCore, "_initialize_kv_caches",
-                       _initialize_kv_caches_v1), monkeypatch.context() as m):
-        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
-        if model_info.v0_only:
-            m.setenv("VLLM_USE_V1", "0")
-
-        # TODO(Isotr0py): Can we avoid initializing engine?
-        with (
-                set_default_torch_num_threads(1),
-                vllm_runner(
-                    model_id,
-                    tokenizer_name=model_info.tokenizer,
-                    tokenizer_mode=model_info.tokenizer_mode,
-                    revision=model_info.revision,
-                    trust_remote_code=model_info.trust_remote_code,
-                    max_model_len=model_info.max_model_len,
-                    load_format="dummy",
-                    hf_overrides=hf_overrides_fn,
-                    limit_mm_per_prompt=limit_mm_per_prompt,
-                    enforce_eager=True,
-                ) as vllm_model,
-        ):
-            model_config = vllm_model.llm.llm_engine.model_config
-            llm_engine = vllm_model.llm.llm_engine
-
-            if hasattr(llm_engine, "processor"):
-                # v1 processor
-                mm_registry = llm_engine.processor.mm_registry
-            else:
-                # v0 input_preprocessor
-                mm_registry = llm_engine.input_preprocessor.mm_registry
-
-            processor = mm_registry.create_processor(model_config)
-
-            def validate_model_input(model, modality: str,
-                                     mm_kwargs: MultiModalKwargs):
-                method_name = f"_parse_and_validate_{modality}_input"
-                if hasattr(model, method_name):
-                    getattr(model, method_name)(**mm_kwargs)
-
-            for modality, _, mm_kwargs in create_batched_mm_kwargs(
-                    model_config, processor):
-                valid_func = partial(validate_model_input,
-                                     modality=modality,
-                                     mm_kwargs=mm_kwargs)
-                vllm_model.apply_model(valid_func)
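+    # Build the processor directly from its registered factory and validate
+    # against a dummy-initialized model, instead of booting a full engine.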
+    model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
+    processor = factories.build_processor(ctx, cache=None)
+
+    with initialize_dummy_model(model_cls, model_config) as model:
+        for modality, _, mm_kwargs in create_batched_mm_kwargs(
+                model_config, processor):
+            for method_name in inputs_parse_methods:
+                print(f"Testing `{method_name}` with modality={modality} "
+                      f"and mm_kwargs={list(mm_kwargs.keys())}")
+                getattr(model, method_name)(modality=modality, **mm_kwargs)