|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +from functools import partial |
| 4 | +from typing import Any |
| 5 | +from unittest.mock import patch |
| 6 | + |
| 7 | +import pytest |
| 8 | +from transformers import PretrainedConfig |
| 9 | + |
| 10 | +from vllm.config import ModelConfig |
| 11 | +from vllm.engine.llm_engine import LLMEngine as V0LLMEngine |
| 12 | +from vllm.inputs import InputProcessingContext |
| 13 | +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs |
| 14 | +from vllm.multimodal.processing import BaseMultiModalProcessor |
| 15 | +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config |
| 16 | +from vllm.utils import GiB_bytes, set_default_torch_num_threads |
| 17 | +from vllm.v1.core.kv_cache_utils import get_kv_cache_config |
| 18 | +from vllm.v1.engine.core import EngineCore as V1EngineCore |
| 19 | + |
| 20 | +from ...conftest import VllmRunner |
| 21 | +from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS |
| 22 | + |
| 23 | +ARCH_TO_SKIP = { |
| 24 | + "MolmoForCausalLM": "incompatible requirements", |
| 25 | + "MiniMaxVL01ForConditionalGeneration": "broken model", |
| 26 | +} |
| 27 | + |
| 28 | + |
| 29 | +def create_batched_mm_kwargs( |
| 30 | + model_config: ModelConfig, |
| 31 | + processor: BaseMultiModalProcessor, |
| 32 | +) -> MultiModalKwargs: |
| 33 | + processing_info = processor.info |
| 34 | + dummy_inputs = processor.dummy_inputs |
| 35 | + supported_mm_limits = processing_info.get_supported_mm_limits() |
| 36 | + mm_counts = { |
| 37 | + modality: 3 if limit is None else limit |
| 38 | + for modality, limit in supported_mm_limits.items() |
| 39 | + } |
| 40 | + processor_inputs = dummy_inputs.get_dummy_processor_inputs( |
| 41 | + seq_len=model_config.max_model_len, |
| 42 | + mm_counts=mm_counts, |
| 43 | + ) |
| 44 | + mm_kwargs = processor.apply( |
| 45 | + prompt=processor_inputs.prompt, |
| 46 | + mm_data=processor_inputs.mm_data, |
| 47 | + hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, |
| 48 | + tokenization_kwargs=processor_inputs.tokenization_kwargs, |
| 49 | + )["mm_kwargs"] |
| 50 | + mm_kwargs = MultiModalKwargs.batch([mm_kwargs]) |
| 51 | + return mm_kwargs |
| 52 | + |
| 53 | + |
| 54 | +# Avoid OOM and reduce initialization time by only using 1 layer |
| 55 | +def hf_overrides(hf_config: PretrainedConfig, |
| 56 | + exist_overrides: dict[str, Any]) -> PretrainedConfig: |
| 57 | + hf_config.update(exist_overrides) |
| 58 | + text_config = hf_config.get_text_config() |
| 59 | + # Ensure at least 2 expert per group |
| 60 | + # Since `grouped_topk` assumes top-2 |
| 61 | + n_group = getattr(text_config, 'n_group', None) |
| 62 | + num_experts = n_group * 2 if n_group is not None else 2 |
| 63 | + # we use three layers for Gemma-3n to check |
| 64 | + # both normal layer and kv_shared_layer |
| 65 | + text_config.update({ |
| 66 | + "num_layers": 1, |
| 67 | + "num_hidden_layers": 1, |
| 68 | + "num_experts": num_experts, |
| 69 | + "num_experts_per_tok": 2, |
| 70 | + "num_local_experts": num_experts, |
| 71 | + # Otherwise there will not be any expert layers |
| 72 | + "first_k_dense_replace": 0, |
| 73 | + # To avoid OOM on DeepSeek-V3 |
| 74 | + "n_routed_experts": num_experts, |
| 75 | + # For Gemma-3n |
| 76 | + "num_kv_shared_layers": 1, |
| 77 | + }) |
| 78 | + if hasattr(hf_config, "vision_config"): |
| 79 | + hf_config.vision_config.update({ |
| 80 | + "num_layers": 1, |
| 81 | + "num_hidden_layers": 1, |
| 82 | + }) |
| 83 | + # e.g.: ibm-granite/granite-speech-3.3-2b |
| 84 | + if hasattr(hf_config, "encoder_config"): |
| 85 | + hf_config.encoder_config.update({ |
| 86 | + "num_layers": 1, |
| 87 | + "num_hidden_layers": 1, |
| 88 | + }) |
| 89 | + # e.g.: Qwen/Qwen2-Audio-7B-Instruct |
| 90 | + if hasattr(hf_config, "audio_config"): |
| 91 | + hf_config.audio_config.update({ |
| 92 | + "num_layers": 1, |
| 93 | + "num_hidden_layers": 1, |
| 94 | + "encoder_layers": 1, |
| 95 | + }) |
| 96 | + return hf_config |
| 97 | + |
| 98 | + |
| 99 | +@pytest.mark.core_model |
| 100 | +@pytest.mark.parametrize("model_arch", list(_MULTIMODAL_EXAMPLE_MODELS.keys())) |
| 101 | +def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner], |
| 102 | + monkeypatch): |
| 103 | + if model_arch in ARCH_TO_SKIP: |
| 104 | + pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}") |
| 105 | + |
| 106 | + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) |
| 107 | + model_info.check_available_online(on_fail="skip") |
| 108 | + |
| 109 | + model_id = model_info.default |
| 110 | + |
| 111 | + hf_overrides_fn = partial(hf_overrides, |
| 112 | + exist_overrides=model_info.hf_overrides) |
| 113 | + |
| 114 | + model_config = ModelConfig( |
| 115 | + model_id, |
| 116 | + tokenizer=model_info.tokenizer or model_id, |
| 117 | + tokenizer_mode=model_info.tokenizer_mode, |
| 118 | + revision=model_info.revision, |
| 119 | + trust_remote_code=model_info.trust_remote_code, |
| 120 | + hf_overrides=model_info.hf_overrides, |
| 121 | + ) |
| 122 | + model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) |
| 123 | + factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] |
| 124 | + |
| 125 | + if not any( |
| 126 | + hasattr(model_cls, f"_parse_and_validate_{m}_input") |
| 127 | + for m in ["image", "video", "audio"]): |
| 128 | + pytest.skip(f"{model_arch} does not support tensor schema validation.") |
| 129 | + |
| 130 | + ctx = InputProcessingContext( |
| 131 | + model_config, |
| 132 | + tokenizer=cached_tokenizer_from_config(model_config), |
| 133 | + ) |
| 134 | + processing_info = factories.info(ctx) |
| 135 | + supported_mm_limits = processing_info.get_supported_mm_limits() |
| 136 | + limit_mm_per_prompt = { |
| 137 | + modality: 3 if limit is None else limit |
| 138 | + for modality, limit in supported_mm_limits.items() |
| 139 | + } |
| 140 | + |
| 141 | + # Avoid calling model.forward() |
| 142 | + def _initialize_kv_caches_v0(self) -> None: |
| 143 | + self.cache_config.num_gpu_blocks = 0 |
| 144 | + self.cache_config.num_cpu_blocks = 0 |
| 145 | + |
| 146 | + def _initialize_kv_caches_v1(self, vllm_config): |
| 147 | + kv_cache_specs = self.model_executor.get_kv_cache_specs() |
| 148 | + scheduler_kv_cache_config = get_kv_cache_config( |
| 149 | + vllm_config, |
| 150 | + kv_cache_specs[0], |
| 151 | + 10 * GiB_bytes, |
| 152 | + ) |
| 153 | + |
| 154 | + # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config |
| 155 | + return 1, 0, scheduler_kv_cache_config |
| 156 | + |
| 157 | + with (patch.object(V0LLMEngine, "_initialize_kv_caches", |
| 158 | + _initialize_kv_caches_v0), |
| 159 | + patch.object(V1EngineCore, "_initialize_kv_caches", |
| 160 | + _initialize_kv_caches_v1), monkeypatch.context() as m): |
| 161 | + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") |
| 162 | + if model_info.v0_only: |
| 163 | + m.setenv("VLLM_USE_V1", "0") |
| 164 | + |
| 165 | + with ( |
| 166 | + set_default_torch_num_threads(1), |
| 167 | + vllm_runner( |
| 168 | + model_id, |
| 169 | + tokenizer_name=model_info.tokenizer, |
| 170 | + tokenizer_mode=model_info.tokenizer_mode, |
| 171 | + revision=model_info.revision, |
| 172 | + trust_remote_code=model_info.trust_remote_code, |
| 173 | + max_model_len=model_info.max_model_len, |
| 174 | + load_format="dummy", |
| 175 | + hf_overrides=hf_overrides_fn, |
| 176 | + limit_mm_per_prompt=limit_mm_per_prompt, |
| 177 | + enforce_eager=True, |
| 178 | + ) as vllm_model, |
| 179 | + ): |
| 180 | + model_config = vllm_model.llm.llm_engine.model_config |
| 181 | + llm_engine = vllm_model.llm.llm_engine |
| 182 | + |
| 183 | + if hasattr(llm_engine, "processor"): |
| 184 | + # v1 processor |
| 185 | + mm_registry = llm_engine.processor.mm_registry |
| 186 | + else: |
| 187 | + # v0 input_preprocessor |
| 188 | + mm_registry = llm_engine.input_preprocessor.mm_registry |
| 189 | + |
| 190 | + processor = mm_registry.create_processor(model_config) |
| 191 | + mm_kwargs = create_batched_mm_kwargs(model_config, processor) |
| 192 | + |
| 193 | + def validate_model_input(model): |
| 194 | + for modality in ("audio", "image", "video"): |
| 195 | + method_name = f"_parse_and_validate_{modality}_input" |
| 196 | + if hasattr(model, method_name): |
| 197 | + getattr(model, method_name)(**mm_kwargs) |
| 198 | + |
| 199 | + vllm_model.apply_model(validate_model_input) |
0 commit comments