diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 624de2a2debc..76e32e3a7f58 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -804,6 +804,37 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
     )
 
 
+# LFM2-VL
+def run_lfm2_vl(questions: list[str], modality: str) -> ModelRequestData:
+    model_name = "LiquidAI/LFM2-VL-1.6B"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=5,
+        limit_mm_per_prompt={modality: 1},
+    )
+
+    if modality == "image":
+        placeholder = "<image>"
+    else:
+        raise ValueError(f"Unsupported modality: {modality}")
+
+    prompts = [
+        (
+            "<|startoftext|>system\nYou are a helpful assistant.<|im_end|>\n"
+            f"<|startoftext|>user\n{placeholder}"
+            f"{question}<|im_end|>\n"
+            "<|startoftext|>assistant\n"
+        )
+        for question in questions
+    ]
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # LightOnOCR
 def run_lightonocr(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
@@ -1827,6 +1858,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
     "keye_vl": run_keye_vl,
     "keye_vl1_5": run_keye_vl1_5,
     "kimi_vl": run_kimi_vl,
+    "lfm2_vl": run_lfm2_vl,
     "lightonocr": run_lightonocr,
     "llama4": run_llama4,
     "llava": run_llava,
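# --- Editor's note (illustrative, not part of the patch) ----------------------
# A minimal sketch of how the new "lfm2_vl" entry could be exercised directly
# through vLLM's offline API, assuming the LiquidAI/LFM2-VL-1.6B checkpoint is
# reachable. The chat markup mirrors the example above and is not authoritative;
# the stand-in image and the helper name are hypothetical.
from PIL import Image

from vllm import LLM, SamplingParams


def _demo_lfm2_vl() -> None:
    llm = LLM(
        model="LiquidAI/LFM2-VL-1.6B",
        max_model_len=4096,
        limit_mm_per_prompt={"image": 1},
    )
    prompt = (
        "<|startoftext|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|startoftext|>user\n<image>Describe this image.<|im_end|>\n"
        "<|startoftext|>assistant\n"
    )
    image = Image.new("RGB", (512, 384), color="gray")  # stand-in input image
    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": image}},
        SamplingParams(max_tokens=64, temperature=0.0),
    )
    print(outputs[0].outputs[0].text)
# ------------------------------------------------------------------------------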
diff --git a/vllm/model_executor/models/lfm2_vl.py b/vllm/model_executor/models/lfm2_vl.py
new file mode 100644
index 000000000000..89fb4dad5145
--- /dev/null
+++ b/vllm/model_executor/models/lfm2_vl.py
@@ -0,0 +1,731 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import itertools
+import math
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Literal
+
+import torch
+import torch.nn as nn
+from transformers import BatchFeature
+from transformers.activations import ACT2FN
+from transformers.models.lfm2_vl import Lfm2VlProcessor
+from transformers.models.lfm2_vl.configuration_lfm2_vl import Lfm2VlConfig
+from transformers.models.lfm2_vl.image_processing_lfm2_vl_fast import (
+    Lfm2VlImageProcessorFast,
+    find_closest_aspect_ratio,
+    round_by_factor,
+)
+
+from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateDtypeCalculator,
+    MambaStateShapeCalculator,
+)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    PromptReplacement,
+    PromptUpdateDetails,
+)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .interfaces import (
+    IsHybrid,
+    MultiModalEmbeddings,
+    SupportsLoRA,
+    SupportsMultiModal,
+    SupportsPP,
+)
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
+
+
+class Lfm2VLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - d: Maximum number of patches per image (padded)
+        - fd: Number of features per patch (num_channels * patch_size**2)
+    """
+
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", "d", "fd")]
+    pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bn", "d")]
+    spatial_shapes: Annotated[torch.Tensor, TensorShape("bn", 2)]
+    num_patches: torch.Tensor
+
+
+LFM2VLImageInputs = Lfm2VLImagePixelInputs
+
+
+class Lfm2VLProcessingInfo(BaseProcessingInfo):
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(Lfm2VlConfig)
+
+    def get_hf_processor(self, **kwargs):
+        return self.ctx.get_hf_processor(Lfm2VlProcessor, **kwargs)
+
+    def get_image_processor(self, **kwargs: object) -> Lfm2VlImageProcessorFast:
+        return self.get_hf_processor(**kwargs).image_processor
+
+    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+        return {"image": None}
+
+    def get_image_size_with_most_features(self) -> ImageSize:
+        processor = self.get_image_processor()
+        max_image_tokens = processor.max_image_tokens
+        encoder_patch_size = processor.encoder_patch_size
+        downsample_factor = processor.downsample_factor
+        max_pixels = max_image_tokens * (encoder_patch_size**2) * (downsample_factor**2)
+        side = int(math.sqrt(max_pixels))
+        return ImageSize(width=side, height=side)
+
+    def _is_image_too_large(
+        self,
+        height: int,
+        width: int,
+        max_image_tokens: int,
+        encoder_patch_size: int,
+        downsample_factor: int,
+        max_pixels_tolerance: float,
+    ) -> bool:
+        """Check if the image is too large to be processed as one tile."""
+        total_factor = encoder_patch_size * downsample_factor
+
+        h_bar = max(encoder_patch_size, round_by_factor(height, total_factor))
+        w_bar = max(encoder_patch_size, round_by_factor(width, total_factor))
+        return (
+            h_bar * w_bar
+            > max_image_tokens
+            * encoder_patch_size**2
+            * downsample_factor**2
+            * max_pixels_tolerance
+        )
+
+    def smart_resize(
+        self,
+        height: int,
+        width: int,
+        downsample_factor: int,
+        min_image_tokens: int,
+        max_image_tokens: int,
+        encoder_patch_size: int,
+    ) -> tuple[int, int]:
+        total_factor = encoder_patch_size * downsample_factor
+        smart_resize_min_pixels = (
+            min_image_tokens * encoder_patch_size**2 * downsample_factor**2
+        )
+        smart_resize_max_pixels = (
+            max_image_tokens * encoder_patch_size**2 * downsample_factor**2
+        )
+
+        h_bar = max(total_factor, round_by_factor(height, total_factor))
+        w_bar = max(total_factor, round_by_factor(width, total_factor))
+
+        if h_bar * w_bar > smart_resize_max_pixels:
+            beta = math.sqrt((height * width) / smart_resize_max_pixels)
+            h_bar = max(
+                total_factor, math.floor(height / beta / total_factor) * total_factor
+            )
+            w_bar = max(
+                total_factor, math.floor(width / beta / total_factor) * total_factor
+            )
+        elif h_bar * w_bar < smart_resize_min_pixels:
+            beta = math.sqrt(smart_resize_min_pixels / (height * width))
+            h_bar = math.ceil(height * beta / total_factor) * total_factor
+            w_bar = math.ceil(width * beta / total_factor) * total_factor
+
+        return w_bar, h_bar
+
+    def _target_ratios(self, min_tiles: int, max_tiles: int) -> list[tuple[int, int]]:
+        ratios = [
+            (w, h)
+            for n in range(min_tiles, max_tiles + 1)
+            for w in range(1, n + 1)
+            for h in range(1, n + 1)
+            if min_tiles <= w * h <= max_tiles
+        ]
+        return sorted(set(ratios), key=lambda x: x[0] * x[1])
+
+    def _get_grid_layout(
+        self,
+        height: int,
+        width: int,
+        min_tiles: int,
max_tiles: int, + tile_size: int, + ) -> tuple[int, int]: + aspect_ratio = width / height + target_ratios = self._target_ratios(min_tiles, max_tiles) + # find best matching grid configuration + grid_width, grid_height = find_closest_aspect_ratio( + aspect_ratio, target_ratios, width, height, tile_size + ) + total_patches = grid_width * grid_height + return grid_width, grid_height, total_patches + + def _get_image_feature_grid_size( + self, + image_width: int, + image_height: int, + processor: Lfm2VlProcessor | None, + ) -> tuple[int, int]: + if processor is None: + processor = self.get_image_processor() + + downsample_factor = processor.image_processor.downsample_factor + encoder_patch_size = processor.image_processor.encoder_patch_size + max_pixels_tolerance = processor.image_processor.max_pixels_tolerance + min_tiles = processor.image_processor.min_tiles + max_tiles = processor.image_processor.max_tiles + max_image_tokens = processor.image_processor.max_image_tokens + tile_size = processor.image_processor.tile_size + + do_image_splitting = not min_tiles == max_tiles == 1 + is_image_large = self._is_image_too_large( + height=image_height, + width=image_width, + max_image_tokens=max_image_tokens, + encoder_patch_size=encoder_patch_size, + downsample_factor=downsample_factor, + max_pixels_tolerance=max_pixels_tolerance, + ) + + # Big image will be cropped into patches and small images are just resized + if is_image_large and do_image_splitting: + grid_width, grid_height, total_patches = self._get_grid_layout( + image_height, + image_width, + min_tiles=min_tiles, + max_tiles=max_tiles, + tile_size=tile_size, + ) + else: + grid_width = grid_height = total_patches = 1 + + if grid_width * grid_height != 1: # Thumbnail + total_patches += 1 + + return grid_width, grid_height, total_patches + + def get_num_patches( + self, + *, + image_width: int, + image_height: int, + processor: Lfm2VlProcessor | None, + ) -> int: + _, _, total_patches = self._get_image_feature_grid_size( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + return total_patches + + def get_image_repl( + self, + image_width: int, + image_height: int, + spatial_shapes: torch.Tensor, + processor: Lfm2VlProcessor | None, + ) -> str: + if processor is None: + processor = self.get_hf_processor() + + grid_placeholder = "<|img_row_{n_h}_col_{n_w}|>" + image_token = processor.image_token + image_start_token = processor.image_start_token + image_end_token = processor.image_end_token + image_thumbnail_token = processor.image_thumbnail_token + + num_thumbnail_tokens, num_tokens_per_tile = self.get_num_image_tokens( + spatial_shapes=spatial_shapes, + processor=processor, + ) + tile_img_placeholder = grid_placeholder + (image_token * num_tokens_per_tile) + + grid_w, grid_h, _ = self._get_image_feature_grid_size( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + + if grid_w > 1 or grid_h > 1: + tiles_placeholder: list[str] = [ + tile_img_placeholder.format(n_h=i + 1, n_w=j + 1) + for i in range(grid_h) + for j in range(grid_w) + ] + + if num_thumbnail_tokens > 0: + tiles_placeholder.append( + image_thumbnail_token + (image_token * num_thumbnail_tokens) + ) + else: + tiles_placeholder = [image_token * num_thumbnail_tokens] + + placeholder = "".join( + itertools.chain([image_start_token], tiles_placeholder, [image_end_token]) + ) + return placeholder + + def get_num_image_tokens( + self, + *, + spatial_shapes: torch.Tensor, + processor: Lfm2VlProcessor | None, + ) -> tuple[int, 
int]: + tile_size = processor.image_processor.tile_size + downsample_factor = processor.image_processor.downsample_factor + encoder_patch_size = processor.image_processor.encoder_patch_size + num_thumbnail_tokens = spatial_shapes[-1].prod() // (downsample_factor**2) + num_patches_tile = tile_size // encoder_patch_size + dwn_num_patches_tile = math.ceil(num_patches_tile / downsample_factor) + num_tiles_tokens = dwn_num_patches_tile * dwn_num_patches_tile + return num_thumbnail_tokens, num_tiles_tokens + + +class Lfm2VLDummyInputsBuilder(BaseDummyInputsBuilder[Lfm2VLProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + processor = self.info.get_hf_processor() + image_token = processor.image_token + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + } + + +class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + # Text-only input not supported in composite processor + if not (images := mm_data.get("images", [])): + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + mm_kwargs, + tok_kwargs, + ) + + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", ImageProcessorItems) + ) + image_sizes = [ + parsed_images.get_image_size(i) for i in range(len(parsed_images)) + ] + hf_processor = self.info.get_hf_processor(**mm_kwargs) + + num_patches = [ + self.info.get_num_patches( + image_width=size.width, + image_height=size.height, + processor=hf_processor, + ) + for size in image_sizes + ] + processed_outputs["num_patches"] = torch.tensor(num_patches) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_patches = hf_inputs.get("num_patches", torch.empty(0)) + + return dict[str, MultiModalFieldConfig]( + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), + pixel_attention_mask=MultiModalFieldConfig.flat_from_sizes( + "image", num_patches + ), + spatial_shapes=MultiModalFieldConfig.flat_from_sizes("image", num_patches), + num_patches=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptReplacement]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token = hf_processor.image_token + + def get_image_replacement_lfm2vl(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + 
+            out_item = out_mm_kwargs["image"][item_idx]
+            spatial_shapes = out_item["spatial_shapes"].data
+            assert isinstance(spatial_shapes, torch.Tensor)
+            image_repl = self.info.get_image_repl(
+                image_width=image_size.width,
+                image_height=image_size.height,
+                spatial_shapes=spatial_shapes,
+                processor=hf_processor,
+            )
+            return PromptUpdateDetails.select_text(
+                image_repl,
+                embed_text=image_token,
+            )
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=image_token,
+                replacement=get_image_replacement_lfm2vl,
+            )
+        ]
+
+
+class Lfm2VLMultiModalProjector(nn.Module):
+    def __init__(
+        self, config: Lfm2VlConfig, use_data_parallel: bool = False, prefix: str = ""
+    ):
+        super().__init__()
+        self.use_data_parallel = use_data_parallel
+
+        in_channels = config.vision_config.hidden_size * (config.downsample_factor**2)
+        self.factor = config.downsample_factor
+        self.layer_norm = nn.LayerNorm(in_channels)
+        self.linear_1 = nn.Linear(
+            in_channels,
+            config.projector_hidden_size,
+            bias=config.projector_bias,
+        )
+        self.act = ACT2FN[config.projector_hidden_act]
+        self.linear_2 = nn.Linear(
+            config.projector_hidden_size,
+            config.text_config.hidden_size,
+            bias=config.projector_bias,
+        )
+
+    def forward(self, image_features: torch.Tensor):
+        image_features = self.pixel_unshuffle(image_features)
+        image_features = self.layer_norm(image_features)
+        hidden_states = self.linear_1(image_features)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+        return hidden_states
+
+    def pixel_unshuffle(self, hidden_states: torch.Tensor):
+        batch_size, width, height, channels = hidden_states.size()
+        hidden_states = hidden_states.reshape(
+            batch_size, width, height // self.factor, channels * self.factor
+        )
+        hidden_states = hidden_states.permute(0, 2, 1, 3)
+        hidden_states = hidden_states.reshape(
+            batch_size,
+            height // self.factor,
+            width // self.factor,
+            channels * self.factor**2,
+        )
+        hidden_states = hidden_states.permute(0, 2, 1, 3)
+        return hidden_states
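# Editor's note (illustrative, not part of the patch): a quick shape check of the
# pixel-unshuffle step used by Lfm2VLMultiModalProjector above. With
# downsample_factor = 2, an 8x8 grid of 16-dim vision features becomes a 4x4 grid
# of 64-dim features: spatial resolution drops by the factor while the channel
# dimension grows by factor**2. Tensor sizes here are arbitrary toy values.
import torch


def _pixel_unshuffle_demo() -> None:
    factor = 2
    x = torch.randn(1, 8, 8, 16)  # (batch, width, height, channels)
    b, w, h, c = x.shape
    x = x.reshape(b, w, h // factor, c * factor)
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(b, h // factor, w // factor, c * factor**2)
    x = x.permute(0, 2, 1, 3)
    assert x.shape == (1, 4, 4, 64)


_pixel_unshuffle_demo()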
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    Lfm2VLMultiModalProcessor,
+    info=Lfm2VLProcessingInfo,
+    dummy_inputs=Lfm2VLDummyInputsBuilder,
+)
+class Lfm2VLForConditionalGeneration(
+    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, IsHybrid
+):
+    merge_by_field_config = True
+
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "lm_head.": "language_model.lm_head.",
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+        }
+    )
+
+    @classmethod
+    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+        if modality.startswith("image"):
+            return "<image>"
+
+        raise ValueError("Only image modality is supported")
+
+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, ...]:
+        return MambaStateDtypeCalculator.short_conv_state_dtype(
+            vllm_config.model_config.dtype,
+            vllm_config.cache_config.mamba_cache_dtype,
+        )
+
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[tuple[int, int]]:
+        """Calculate shapes for LFM2's convolutional cache.
+
+        Args:
+            vllm_config: vLLM config
+
+        Returns:
+            Tuple containing:
+            - conv_state_shape: Shape for convolutional state cache
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_language_config = vllm_config.model_config.hf_config.text_config
+
+        return MambaStateShapeCalculator.short_conv_state_shape(
+            tp_world_size=parallel_config.tensor_parallel_size,
+            intermediate_size=hf_language_config.conv_dim,
+            conv_kernel=hf_language_config.conv_L_cache,
+        )
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
+        super().__init__()
+        config: Lfm2VlConfig = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+        vision_config = config.vision_config
+
+        self.config = config
+        self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
+
+        if vision_config.model_type == "siglip2_vision_model":
+            from vllm.model_executor.models.siglip2 import Siglip2Model
+
+            self.vision_tower = Siglip2Model(
+                config=vision_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.vit",
+                use_data_parallel=self.use_data_parallel,
+            )
+        else:
+            raise ValueError(
+                f"Unsupported visual tokenizer model_type: {vision_config.model_type}"
+            )
+
+        self.multi_modal_projector = Lfm2VLMultiModalProjector(
+            config=config,
+            use_data_parallel=self.use_data_parallel,
+            prefix=f"{prefix}.multi_modal_projector",
+        )
+
+        if config.model_type == "lfm2_vl":
+            architectures = ["Lfm2ForCausalLM"]
+        elif config.model_type == "lfm2_vl_moe":
+            architectures = ["Lfm2MoeForCausalLM"]
+        else:
+            architectures = None
+
+        self.language_model = init_vllm_registered_model(
+            vllm_config=vllm_config,
+            hf_config=config.text_config,
+            prefix=maybe_prefix(prefix, "language"),
+            architectures=architectures,
+        )
+
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors
+        )
+
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> LFM2VLImageInputs | None:
+        pixel_values = kwargs.pop("pixel_values", None)
+        pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
+        spatial_shapes = kwargs.pop("spatial_shapes", None)
+        num_patches = kwargs.pop("num_patches", None)
+        if pixel_values is None:
+            return None
+
+        return LFM2VLImageInputs(
+            type="pixel_values",
+            pixel_values=pixel_values,
+            pixel_attention_mask=pixel_attention_mask,
+            spatial_shapes=spatial_shapes,
+            num_patches=num_patches,
+        )
+
+    def image_pixels_to_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        spatial_shapes: torch.Tensor,
+        pixel_attention_mask: torch.Tensor,
+        num_patches: torch.Tensor,
+    ) -> list[torch.Tensor]:
+        pixel_values = pixel_values.to(
+            dtype=self.vision_tower.vision_model.embeddings.patch_embedding.weight.dtype
+        )  # fp16 compatibility
+
+        image_outputs = self.vision_tower(
+            pixel_values=pixel_values,
+            pixel_attention_mask=pixel_attention_mask,
+            spatial_shapes=spatial_shapes,
+        )
+
+        img_feature_lengths = pixel_attention_mask.sum(dim=1)
+
+        image_features = []
+
+        for img_idx, feature_len in enumerate(img_feature_lengths.tolist()):
+            feature = image_outputs[img_idx, :feature_len]
+
+            # reshape to original height and width
+            feature_org_h, feature_org_w = spatial_shapes[img_idx].tolist()
+            feature = feature.reshape(1, feature_org_h, feature_org_w, -1)
+
+            # project the image representation
+            img_embedding = self.multi_modal_projector(feature)
+
+            #
flatten here to handle variable length in naflex + img_embedding = img_embedding.reshape(-1, img_embedding.size(-1)) + image_features.append(img_embedding) + + return image_features + + def _process_image_input( + self, + image_input: LFM2VLImageInputs, + ) -> torch.Tensor | list[torch.Tensor]: + pixel_values = image_input["pixel_values"] + pixel_attention_mask = image_input["pixel_attention_mask"] + spatial_shapes = image_input["spatial_shapes"] + num_patches = image_input["num_patches"] + + image_features = self.image_pixels_to_features( + pixel_values, + spatial_shapes=spatial_shapes, + pixel_attention_mask=pixel_attention_mask, + num_patches=num_patches, + ) + + # total num patches with token length per patch + patch_token_lengths = torch.as_tensor( + [f.size(0) for f in image_features], + device=image_features[0].device, + dtype=torch.long, + ) + image_features = torch.cat(image_features, dim=0) + + # cumulative token boundaries per patch: [0, L0, L0+L1, ..., sum_i Li] + token_cum = torch.cat( + ( + torch.zeros(1, device=image_features.device, dtype=torch.long), + patch_token_lengths.cumsum(0), + ), + dim=0, + ) + + batched_features: list[torch.Tensor] = [] + patch_start_idx = 0 + for count in num_patches.tolist(): + patch_end_idx = patch_start_idx + count + token_start = token_cum[patch_start_idx].item() + token_end = token_cum[patch_end_idx].item() + batched_features.append(image_features[token_start:token_end]) + patch_start_idx = patch_end_idx + + return batched_features + + def get_multimodal_embeddings( + self, **kwargs: object + ) -> MultiModalEmbeddings | None: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + return self._process_image_input(image_input) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.language_model.compute_logits(hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="multi_modal_projector", + tower_model="vision_tower", + ) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 494398760620..09ee9a85a464 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -309,6 +309,7 @@ ), "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 + "Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"), # noqa: E501 "LightOnOCRForConditionalGeneration": ( "lightonocr", "LightOnOCRForConditionalGeneration", diff --git a/vllm/model_executor/models/siglip2.py b/vllm/model_executor/models/siglip2.py new 
file mode 100644 index 000000000000..413c6b4abb46 --- /dev/null +++ b/vllm/model_executor/models/siglip2.py @@ -0,0 +1,421 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Implementation of Siglip2VisionModel intended to be only used +within a vision language model.""" + +from collections.abc import Iterable + +import torch +from torch import nn +from torch.nn import functional as F +from transformers import Siglip2VisionConfig + +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from .vision import run_dp_sharded_vision_model + + +class Siglip2VisionEmbeddings(nn.Module): + def __init__(self, config: Siglip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Linear( + in_features=config.num_channels * self.patch_size * self.patch_size, + out_features=self.embed_dim, + ) + self.num_patches = config.num_patches + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + + @staticmethod + def resize_positional_embeddings( + positional_embeddings: torch.Tensor, + spatial_shapes: torch.LongTensor, + max_length: int, + ) -> torch.Tensor: + """ + Resize positional embeddings to image-specific size and pad to a fixed size. 
+ + Args: + positional_embeddings (`torch.Tensor`): + Position embeddings of shape (height, width, embed_dim) + spatial_shapes (`torch.LongTensor`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + max_length (`int`): + Maximum length of the positional embeddings to pad resized positional embeddings to + + Returns: + `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim) + """ + batch_size = spatial_shapes.shape[0] + embed_dim = positional_embeddings.shape[-1] + source_dtype = positional_embeddings.dtype + + resulted_positional_embeddings = torch.empty( + (batch_size, max_length, embed_dim), + device=positional_embeddings.device, + dtype=source_dtype, + ) + + # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation + positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0) + + # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU + if positional_embeddings.device.type == "cpu": + positional_embeddings = positional_embeddings.to(torch.float32) + + for i in range(batch_size): + # (1, dim, height, width) -> (1, dim, target_height, target_width) + height, width = spatial_shapes[i] + resized_embeddings = F.interpolate( + positional_embeddings, + size=(height, width), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + # (1, dim, target_height, target_width) -> (target_height * target_width, dim) + resized_embeddings = resized_embeddings.reshape( + embed_dim, height * width + ).transpose(0, 1) + + # Cast to original dtype + resized_embeddings = resized_embeddings.to(source_dtype) + + resulted_positional_embeddings[i, : height * width] = resized_embeddings + resulted_positional_embeddings[i, height * width :] = resized_embeddings[0] + + return resulted_positional_embeddings + + def forward( + self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor + ) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor`): + Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size) + spatial_shapes (`list[tuple[int, int]]`): + Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to + """ + + # Apply patch embeddings to already patchified pixel values + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) + + # Get positional resized and padded positional embeddings + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, -1 + ) + resized_positional_embeddings = self.resize_positional_embeddings( + positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1] + ) + + # Add positional embeddings to patch embeddings + embeddings = patch_embeds + resized_positional_embeddings + return embeddings + + +class Siglip2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" 
{self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0 + self.num_heads_per_partition = self.num_heads // tp_size + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.attn = MultiHeadAttention( + self.num_heads_per_partition, self.head_dim, self.scale + ) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj( + hidden_states + ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim + query_states, key_states, value_states = qkv.chunk(3, dim=-1) + + # Use unified MultiHeadAttention implementation + out = self.attn(query_states, key_states, value_states) + attn_output, _ = self.out_proj(out) + return attn_output + + +class Siglip2MLP(nn.Module): + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + # TODO(Isotr0py): Enable data parallel after we support + # disabling TP on parallel linear layer + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Siglip2EncoderLayer(nn.Module): + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = Siglip2Attention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel, + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + + def forward( + self, + hidden_states: torch.Tensor, + # attention_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + hidden_states: Input tensor of shape (batch, seq_len, embed_dim). + cu_seqlens: Cumulative sequence lengths tensor. + position_embeddings: Position embeddings tensor. 
+ """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + # attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Siglip2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` + self attention layers. Each layer is a [`Siglip2EncoderLayer`]. + + Args: + config: PretrainedConfig + """ + + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + Siglip2EncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{idx}", + use_data_parallel=use_data_parallel, + ) + for idx in range(config.num_hidden_layers) + ] + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states) + hidden_states = layer_outputs + return hidden_states + + +class Siglip2VisionTransformer(nn.Module): + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + embed_dim = config.hidden_size + self.config = config + self.use_data_parallel = use_data_parallel + self.embeddings = Siglip2VisionEmbeddings(config) + self.encoder = Siglip2Encoder( + config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + use_data_parallel=use_data_parallel, + ) + num_hidden_layers = config.num_hidden_layers + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." + ) + + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values: torch.FloatTensor, + # attention_mask: torch.Tensor, + spatial_shapes: torch.LongTensor, + ) -> torch.Tensor: + r""" + spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) + of the input images. 
+ """ + hidden_states = self.embeddings(pixel_values, spatial_shapes) + if self.use_data_parallel: + encoder_outputs = run_dp_sharded_vision_model(hidden_states, self.encoder) + else: + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + last_hidden_state = self.post_layernorm(encoder_outputs) + return last_hidden_state + + +class Siglip2Model(torch.nn.Module): + def __init__( + self, + config: Siglip2VisionConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + + self.vision_model = Siglip2VisionTransformer( + config, + quant_config=quant_config, + prefix=f"{prefix}.vision_model", + use_data_parallel=use_data_parallel, + ) + + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_attention_mask: torch.Tensor, + spatial_shapes: torch.LongTensor, + ) -> torch.Tensor: + return self.vision_model( + pixel_values=pixel_values, + # attention_mask=pixel_attention_mask, + spatial_shapes=spatial_shapes, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params
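# Editor's note (illustrative, not part of the patch): how the
# stacked_params_mapping in Siglip2Model.load_weights above folds the separate
# HF q/k/v projections into the fused qkv_proj parameter. The checkpoint weight
# name below and the helper itself are hypothetical, shown only to make the
# renaming rule concrete.
def _rename_for_fused_qkv(name: str) -> tuple[str, str | None]:
    """Return (renamed parameter name, shard id) for a checkpoint weight name."""
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # non-attention weights keep their original name


assert _rename_for_fused_qkv(
    "vision_model.encoder.layers.0.self_attn.q_proj.weight"
) == ("vision_model.encoder.layers.0.self_attn.qkv_proj.weight", "q")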