# coding=utf-8
# Copyright 2025 The OpenBMB Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

import numpy as np
import PIL
import PIL.Image
import PIL.ImageSequence
import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import ChannelDimension
from transformers.image_utils import infer_channel_dimension_format
from transformers.image_utils import is_torch_tensor
from transformers.image_utils import to_numpy_array
from transformers.image_utils import valid_images
from transformers.utils import is_torch_device
from transformers.utils import is_torch_dtype
from transformers.utils import requires_backends
from transformers.utils import TensorType
from .processing_minicpm_o_2_6 import MiniCPMOBatchFeature


def recursive_converter(converter, value):
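    """Apply `converter` to every leaf of a (possibly nested) list, preserving the nesting structure."""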
    if isinstance(value, list):
        new_value = []
        for v in value:
            new_value += [recursive_converter(converter, v)]
        return new_value
    else:
        return converter(value)


class MiniCPMVImageProcessor(BaseImageProcessor):
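    """
    Image processor for MiniCPM-V models.

    Large images are optionally sliced into a grid of sub-images so that every slice stays close to
    `scale_resolution`, and every slice is resized so its sides are multiples of `patch_size` before being
    normalized and flattened into ViT patches by `reshape_by_patch`.

    Args:
        max_slice_nums (`int`, *optional*, defaults to 9):
            Maximum number of grid cells a single image may be split into.
        scale_resolution (`int`, *optional*, defaults to 448):
            Target resolution each slice is scaled towards.
        patch_size (`int`, *optional*, defaults to 14):
            Vision patch size; all output sizes are multiples of this value.
    """
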
    model_input_names = ["pixel_values"]

    def __init__(self, max_slice_nums=9, scale_resolution=448, patch_size=14, **kwargs):
        super().__init__(**kwargs)
        self.max_slice_nums = max_slice_nums
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size

        self.use_image_id = kwargs.pop("use_image_id", False)
        self.image_feature_size = kwargs.pop("image_feature_size", 64)

        self.slice_mode = kwargs.pop("slice_mode", True)

        self.mean = np.array(kwargs.pop("norm_mean", [0.5, 0.5, 0.5]))
        self.std = np.array(kwargs.pop("norm_std", [0.5, 0.5, 0.5]))
        self.version = kwargs.pop("version", 2.0)

    def ensure_divide(self, length, patch_size):
        return max(round(length / patch_size) * patch_size, patch_size)

    def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False):
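        """
        Return a `(width, height)` close to the aspect ratio of `original_size` whose area is roughly at most
        `scale_resolution ** 2` and whose sides are multiples of `patch_size`. When `allow_upscale` is False,
        images already below the target area are only snapped to the patch grid.
        """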
        width, height = original_size
        if (width * height > scale_resolution * scale_resolution) or allow_upscale:
            r = width / height
            height = int(scale_resolution / math.sqrt(r))
            width = int(height * r)
        best_width = self.ensure_divide(width, patch_size)
        best_height = self.ensure_divide(height, patch_size)
        return (best_width, best_height)

    def get_refine_size(self, original_size, grid, scale_resolution, patch_size, allow_upscale=False):
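        """
        Compute the size of the high-resolution "refine" image: snap the original size to the grid, pick the best
        per-cell size with `find_best_resize`, then scale back up by the grid dimensions so every cell is
        patch-aligned.
        """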
        width, height = original_size
        grid_x, grid_y = grid

        refine_width = self.ensure_divide(width, grid_x)
        refine_height = self.ensure_divide(height, grid_y)

        grid_width = refine_width / grid_x
        grid_height = refine_height / grid_y

        best_grid_size = self.find_best_resize(
            (grid_width, grid_height), scale_resolution, patch_size, allow_upscale=allow_upscale
        )
        refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
        return refine_size

    def split_to_patches(self, image, grid):
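        """Crop `image` into a `grid[0] x grid[1]` matrix of equally sized patches, returned as a list of rows."""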
        patches = []
        width, height = image.size
        grid_x = int(width / grid[0])
        grid_y = int(height / grid[1])
        for i in range(0, height, grid_y):
            images = []
            for j in range(0, width, grid_x):
                box = (j, i, j + grid_x, i + grid_y)
                patch = image.crop(box)
                images.append(patch)
            patches.append(images)
        return patches

    def slice_image(self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
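        """
        Slice `image` according to the grid chosen by `get_sliced_grid`. Returns the (possibly down-sampled) source
        image, the list of patch rows (empty when no slicing is needed), and the chosen grid (or None).
        """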
        original_size = image.size
        source_image = None
        best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
        patches = []

        if best_grid is None:
            # no need to slice; upscale the whole image to the best size
            best_size = self.find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
            source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
        else:
            # source image: down-sampled and guaranteed to be divisible by patch_size
            best_resize = self.find_best_resize(original_size, scale_resolution, patch_size)
            source_image = image.copy().resize(best_resize, resample=Image.Resampling.BICUBIC)
            refine_size = self.get_refine_size(
                original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
            )
            refine_image = image.resize(refine_size, resample=Image.Resampling.BICUBIC)
            patches = self.split_to_patches(refine_image, best_grid)

        return source_image, patches, best_grid

    def get_grid_placeholder(self, grid):
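        """
        Build the text placeholder for a sliced image: one slice placeholder per grid cell, with rows joined by
        newlines. Expects a `tokenizer` attribute to have been attached to this image processor by the surrounding
        processor.
        """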
        if grid is None:
            return ""
        slice_image_placeholder = (
            self.tokenizer.slice_start + self.tokenizer.unk_token * self.image_feature_size + self.tokenizer.slice_end
        )

        cols = grid[0]
        rows = grid[1]
        slices = []
        for i in range(rows):
            lines = []
            for j in range(cols):
                lines.append(slice_image_placeholder)
            slices.append("".join(lines))

        slice_placeholder = "\n".join(slices)
        return slice_placeholder

    def get_image_id_placeholder(self, idx=0):
        return f"{self.tokenizer.im_id_start}{idx}{self.tokenizer.im_id_end}"

    def get_sliced_images(self, image, max_slice_nums=None):
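        """
        Return the list of images the model consumes for one input image: the resized source image followed by the
        refine patches in row-major order, or just the original image when `slice_mode` is disabled.
        """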
        slice_images = []

        if not self.slice_mode:
            return [image]

        max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
        assert max_slice_nums > 0
        source_image, patches, sliced_grid = self.slice_image(
            image,
            max_slice_nums,  # default: 9
            self.scale_resolution,  # default: 448
            self.patch_size,  # default: 14
        )

        slice_images.append(source_image)
        if len(patches) > 0:
            for i in range(len(patches)):
                for j in range(len(patches[0])):
                    slice_images.append(patches[i][j])
        return slice_images

    def get_sliced_grid(self, image_size, max_slice_nums, never_split=False):
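        """
        Choose the slicing grid `[columns, rows]` whose aspect ratio best matches the image, among grids whose cell
        count is close to `ceil(area / scale_resolution ** 2)` and at most `max_slice_nums`. Returns None when no
        slicing is needed or `never_split` is set.
        """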
        original_width, original_height = image_size
        log_ratio = math.log(original_width / original_height)
        ratio = original_width * original_height / (self.scale_resolution * self.scale_resolution)
        multiple = min(math.ceil(ratio), max_slice_nums)
        if multiple <= 1 or never_split:
            return None
        candidate_split_grids_nums = []
        for i in [multiple - 1, multiple, multiple + 1]:
            if i == 1 or i > max_slice_nums:
                continue
            candidate_split_grids_nums.append(i)

        candidate_grids = []
        for split_grids_nums in candidate_split_grids_nums:
            m = 1
            while m <= split_grids_nums:
                if split_grids_nums % m == 0:
                    candidate_grids.append([m, split_grids_nums // m])
                m += 1

        best_grid = [1, 1]
        min_error = float("inf")
        for grid in candidate_grids:
            error = abs(log_ratio - math.log(grid[0] / grid[1]))
            if error < min_error:
                best_grid = grid
                min_error = error

        return best_grid

    def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
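        """
        Build the full text placeholder for one image: the image placeholder (optionally preceded by an image-id
        placeholder when `use_image_id` is enabled) followed, in slice mode, by the grid placeholder for its slices.
        """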
        max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
        assert max_slice_nums > 0
        grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)

        image_placeholder = (
            self.tokenizer.im_start + self.tokenizer.unk_token * self.image_feature_size + self.tokenizer.im_end
        )
        use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
        if use_image_id:
            final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
        else:
            final_placeholder = image_placeholder

        if self.slice_mode:
            final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid)
        return final_placeholder

    def to_pil_image(self, image, rescale=None) -> PIL.Image.Image:
        """
        Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis
        if needed.

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                The image to convert to the PIL Image format.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
                default to `True` if the image type is a floating type, `False` otherwise.
        """
        if isinstance(image, PIL.Image.Image):
            return image
        if is_torch_tensor(image):
            image = image.numpy()

        if isinstance(image, np.ndarray):
            if rescale is None:
                # rescale defaults to True when the array is of floating type
                rescale = isinstance(image.flat[0], np.floating)
            # if the channel dimension has been moved to the front, put it back at the end
            if image.ndim == 3 and image.shape[0] in [1, 3]:
                image = image.transpose(1, 2, 0)
            if rescale:
                image = image * 255
            image = image.astype(np.uint8)
            return PIL.Image.fromarray(image)
        return image

    def reshape_by_patch(self, image):
        """
        Flatten an image into ViT patches using `self.patch_size`.

        :param image: numpy array of shape [3, H, W]
        :return: numpy array of shape [3, patch_size, H * W / patch_size]
        """
        image = torch.from_numpy(image)
        patch_size = self.patch_size
        patches = torch.nn.functional.unfold(image, (patch_size, patch_size), stride=(patch_size, patch_size))

        patches = patches.reshape(image.size(0), patch_size, patch_size, -1)
        patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1)
        return patches.numpy()

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
        do_pad: Optional[bool] = True,
        max_slice_nums: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> MiniCPMOBatchFeature:
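        """
        Preprocess one image, a list of images, or a batch (list of lists) of images.

        Args:
            images: a single `PIL.Image.Image`, a list of images, or a list of lists of images (one list per sample).
            do_pad: accepted for API compatibility; not used by this image processor.
            max_slice_nums: overrides `self.max_slice_nums` for this call when given.
            return_tensors: tensor type forwarded to `MiniCPMOBatchFeature`.

        Returns:
            `MiniCPMOBatchFeature` with `pixel_values` (patch-flattened slices per image), `image_sizes` (original
            `(width, height)` per image) and `tgt_sizes` (patch-grid `(height, width)` per slice).
        """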
        if isinstance(images, Image.Image):
            images_list = [[images]]
        elif isinstance(images[0], Image.Image):
            images_list = [images]
        else:
            images_list = images

        new_images_list = []
        image_sizes_list = []
        tgt_sizes_list = []

        for _images in images_list:
            if _images is None or len(_images) == 0:
                new_images_list.append([])
                image_sizes_list.append([])
                tgt_sizes_list.append([])
                continue
            if not valid_images(_images):
                raise ValueError(
                    "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                    "torch.Tensor, tf.Tensor or jax.ndarray."
                )

            _images = [self.to_pil_image(image).convert("RGB") for image in _images]
            input_data_format = infer_channel_dimension_format(np.array(_images[0]))

            new_images = []
            image_sizes = [image.size for image in _images]
            tgt_sizes = []
            for image in _images:
                image_patches = self.get_sliced_images(image, max_slice_nums)
                image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches]
                image_patches = [
                    self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format)
                    for image in image_patches
                ]
                image_patches = [
                    to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
                    for image in image_patches
                ]
                for slice_image in image_patches:
                    new_images.append(self.reshape_by_patch(slice_image))
                    tgt_sizes.append(
                        np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
                    )

            if tgt_sizes:
                tgt_sizes = np.vstack(tgt_sizes)

            new_images_list.append(new_images)
            image_sizes_list.append(image_sizes)
            tgt_sizes_list.append(tgt_sizes)
        return MiniCPMOBatchFeature(
            data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
            tensor_type=return_tensors,
        )


AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)


__all__ = ["MiniCPMVImageProcessor"]
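

# Minimal usage sketch (illustrative, not part of the original module): builds a dummy PIL image and runs the
# processor end to end with the default slicing parameters. Assumes the sibling `processing_minicpm_o_2_6` module
# is importable so that `MiniCPMOBatchFeature` resolves.
if __name__ == "__main__":
    processor = MiniCPMVImageProcessor(max_slice_nums=9, scale_resolution=448, patch_size=14)
    dummy = Image.new("RGB", (1024, 768), color=(127, 127, 127))
    features = processor.preprocess(dummy)
    # one inner list per input image: the source slice plus any refine patches
    print(len(features["pixel_values"][0]))
    # (num_slices, 2): patch-grid (height, width) of each slice
    print(features["tgt_sizes"][0].shape)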