# SPDX-License-Identifier: Apache-2.0
"""Tests for Llama4's multimodal preprocessing kwargs."""

import pytest

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.tokenizer import encode_tokens

from ....conftest import _ImageAssets
from ...utils import build_model_context


@pytest.mark.parametrize("model_id",
                         ["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
@pytest.mark.parametrize("mm_processor_kwargs", [{}])
@pytest.mark.parametrize("num_imgs", [1, 5])
@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False])
@pytest.mark.parametrize("tokenized_prompt", [True, False])
def test_processor_override(
    image_assets: _ImageAssets,
    model_id: str,
    mm_processor_kwargs: dict,
    num_imgs: int,
    disable_mm_preprocessor_cache: bool,
    tokenized_prompt: bool,
):
| 27 | + """Ensure llama4 processor works properly.""" |
| 28 | + ctx = build_model_context( |
| 29 | + model_id, |
| 30 | + mm_processor_kwargs=mm_processor_kwargs, |
| 31 | + limit_mm_per_prompt={"image": num_imgs}, |
| 32 | + disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, |
| 33 | + ) |
| 34 | + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) |
| 35 | + config = processor.info.get_hf_config() |
| 36 | + tokenizer = processor.info.get_tokenizer() |
| 37 | + hf_processor = processor.info.get_hf_processor() |
| 38 | + vocab = tokenizer.get_vocab() |
| 39 | + |
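    # Chat-style prompt with one "<|image|>" placeholder per image; the
    # processor should expand each placeholder into the full image token
    # sequence (begin/end image tokens, tile separators, and patch tokens).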
    prompt = "<|begin_of_text|><|header_start|>user<|header_end|>" \
        + "<|image|>" * num_imgs \
        + "<|eot|><|header_start|>assistant<|header_end|>"
    mm_data = {
        "image": [
            image_assets[(i % len(image_assets))].pil_image
            for i in range(num_imgs)
        ]
    }
    if tokenized_prompt:
        prompt = encode_tokens(tokenizer, prompt)

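    # Run preprocessing end to end; the result carries the expanded prompt
    # token ids, the multimodal kwargs, and the placeholder ranges checked
    # below.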
    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
    mm_kwargs = processed_inputs["mm_kwargs"]

    # placeholder replacements
    prompt_token_ids = processed_inputs["prompt_token_ids"]
    assert prompt_token_ids.count(config.boi_token_index) == num_imgs
    assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
    assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
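    # Count the separator tokens expected for each image's
    # (tiles_y, tiles_x) grid: a tile_token between horizontally adjacent
    # tiles and a tile_global_token per tile row; single-tile images emit
    # neither.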
    aspect_ratios = mm_kwargs["aspect_ratios"]
    num_x_separators = num_y_separators = 0
    for tiles_y, tiles_x in aspect_ratios:
        if tiles_x * tiles_y > 1:
            num_x_separators += (tiles_x - 1) * tiles_y
            num_y_separators += tiles_y
    assert prompt_token_ids.count(vocab[hf_processor.tile_token]) \
        == num_x_separators
    assert prompt_token_ids.count(vocab[hf_processor.tile_global_token]) \
        == num_y_separators

    # image token offsets
    img_locs = processed_inputs["mm_placeholders"].get("image", [])
    assert len(img_locs) == num_imgs
    assert [img_loc["offset"] for img_loc in img_locs] == \
        [i for i, v in enumerate(prompt_token_ids)
         if v == config.boi_token_index]

    # patch sizes and masks
    assert prompt_token_ids.count(config.image_token_index) \
        == sum(img_patch.sum() for img_patch in mm_kwargs["embed_is_patch"])
    patch_token_id = vocab[hf_processor.img_patch_token]
    num_patches = prompt_token_ids.count(patch_token_id)
    mm_counts = {"image": num_imgs}
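    # On average, an image should not contribute more patch tokens than the
    # per-item maximum the processor info reports for a 32768-token context.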
    assert num_patches / num_imgs <= \
        processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"]
    num_patches_per_chunk = processor.info.get_patch_per_chunk(
        config.vision_config)
    assert prompt_token_ids.count(config.image_token_index) \
        == mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
    assert mm_kwargs["pixel_values"].shape[0] \
        == mm_kwargs["patches_per_image"].sum()

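    # Each image's embed_is_patch mask must have exactly one entry per token
    # of the expansion string generated for its aspect ratio.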
    for embed_is_patch, aspect_ratio in zip(mm_kwargs["embed_is_patch"],
                                            mm_kwargs["aspect_ratios"]):
        assert embed_is_patch.shape[0] == \
            len(tokenizer.encode(
                hf_processor._prompt_split_image(
                    aspect_ratio, num_patches_per_chunk),
                add_special_tokens=False))