
Commit 07db6af

now use MiniCPM_o_2_6Tokenizer and MiniCPM_o_2_6TokenizerFast; move MiniCPMVImageProcessor to image_processing_minicpm; remove token parameters in MiniCPMVImageProcessor
1 parent bceee7d commit 07db6af

File tree

7 files changed: +491 −308 lines changed


src/transformers/models/auto/tokenization_auto.py

Lines changed: 1 addition & 1 deletion

@@ -347,7 +347,7 @@
         ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
         ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
         ("mgp-str", ("MgpstrTokenizer", None)),
-        ("minicpm_o_2_6", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
+        ("minicpm_o_2_6", ("MiniCPM_o_2_6Tokenizer", "MiniCPM_o_2_6TokenizerFast" if is_tokenizers_available() else None)),
         (
             "minimax",
             (
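With this mapping in place, AutoTokenizer resolves the minicpm_o_2_6 model type to the new classes. A minimal sketch (the checkpoint id is illustrative; any checkpoint whose config sets model_type="minicpm_o_2_6" would behave the same):

from transformers import AutoTokenizer

# The mapping above selects the fast tokenizer when the `tokenizers`
# backend is installed, and falls back to MiniCPM_o_2_6Tokenizer otherwise.
tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-o-2_6")
print(type(tokenizer).__name__)  # expected: MiniCPM_o_2_6TokenizerFast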

src/transformers/models/minicpm_o_2_6/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -21,9 +21,11 @@
 
 if TYPE_CHECKING:
     from .configuration_minicpm_o_2_6 import *
+    from .image_processing_minicpm import *
     from .modeling_minicpm_o_2_6 import *
     from .processing_minicpm_o_2_6 import *
     from .tokenization_minicpm_o_2_6_fast import *
+    from .tokenization_minicpm_o_2_6 import *
 else:
     import sys
 
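Both new modules are re-exported from the package root through the usual lazy-module machinery. A minimal sketch of the resulting import surface (assuming each module's __all__ is exported as in this commit):

# After this commit, both classes are importable directly:
from transformers.models.minicpm_o_2_6 import MiniCPMVImageProcessor
from transformers.models.minicpm_o_2_6 import MiniCPM_o_2_6Tokenizer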

src/transformers/models/minicpm_o_2_6/image_processing_minicpm.py (new file)

Lines changed: 339 additions & 0 deletions

@@ -0,0 +1,339 @@
# coding=utf-8
# Copyright 2025 The OpenBMB Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

import numpy as np
import PIL
import PIL.Image
import PIL.ImageSequence
import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import ChannelDimension
from transformers.image_utils import infer_channel_dimension_format
from transformers.image_utils import is_torch_tensor
from transformers.image_utils import to_numpy_array
from transformers.image_utils import valid_images
from transformers.utils import is_torch_device
from transformers.utils import is_torch_dtype
from transformers.utils import requires_backends
from transformers.utils import TensorType
from .processing_minicpm_o_2_6 import MiniCPMOBatchFeature

def recursive_converter(converter, value):
    if isinstance(value, list):
        new_value = []
        for v in value:
            new_value += [recursive_converter(converter, v)]
        return new_value
    else:
        return converter(value)

class MiniCPMVImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(self, max_slice_nums=9, scale_resolution=448, patch_size=14, **kwargs):
        super().__init__(**kwargs)
        self.max_slice_nums = max_slice_nums
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size

        self.use_image_id = kwargs.pop("use_image_id", False)
        self.image_feature_size = kwargs.pop("image_feature_size", 64)

        self.slice_mode = kwargs.pop("slice_mode", True)

        self.mean = np.array(kwargs.pop("norm_mean", [0.5, 0.5, 0.5]))
        self.std = np.array(kwargs.pop("norm_std", [0.5, 0.5, 0.5]))
        self.version = kwargs.pop("version", 2.0)

    def ensure_divide(self, length, patch_size):
        # Round to the nearest multiple of patch_size, but never below one patch.
        return max(round(length / patch_size) * patch_size, patch_size)

    def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False):
        # Rescale so the area is roughly scale_resolution**2 while preserving the
        # aspect ratio, then snap both sides to multiples of patch_size.
        width, height = original_size
        if (width * height > scale_resolution * scale_resolution) or allow_upscale:
            r = width / height
            height = int(scale_resolution / math.sqrt(r))
            width = int(height * r)
        best_width = self.ensure_divide(width, patch_size)
        best_height = self.ensure_divide(height, patch_size)
        return (best_width, best_height)

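    # Worked example (assumed input): a 1000x750 image with the defaults
    # scale_resolution=448, patch_size=14. Since 1000*750 > 448*448, it is
    # rescaled with r = 4/3: height = int(448 / sqrt(4/3)) = 387 and
    # width = int(387 * 4/3) = 516, then snapped to multiples of 14,
    # giving (best_width, best_height) = (518, 392).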
    def get_refine_size(self, original_size, grid, scale_resolution, patch_size, allow_upscale=False):
        # Size the full image so that every grid cell gets its own best resize.
        width, height = original_size
        grid_x, grid_y = grid

        refine_width = self.ensure_divide(width, grid_x)
        refine_height = self.ensure_divide(height, grid_y)

        grid_width = refine_width / grid_x
        grid_height = refine_height / grid_y

        best_grid_size = self.find_best_resize(
            (grid_width, grid_height), scale_resolution, patch_size, allow_upscale=allow_upscale
        )
        refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
        return refine_size

    def split_to_patches(self, image, grid):
        # Crop the image into grid[0] x grid[1] equally sized patches.
        patches = []
        width, height = image.size
        grid_x = int(width / grid[0])
        grid_y = int(height / grid[1])
        for i in range(0, height, grid_y):
            images = []
            for j in range(0, width, grid_x):
                box = (j, i, j + grid_x, i + grid_y)
                patch = image.crop(box)
                images.append(patch)
            patches.append(images)
        return patches

    def slice_image(self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
        original_size = image.size
        source_image = None
        best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
        patches = []

        if best_grid is None:
            # no need to slice; upsample
            best_size = self.find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
            source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
        else:
            # source image: down-sample and ensure it is divisible by patch_size
            best_resize = self.find_best_resize(original_size, scale_resolution, patch_size)
            source_image = image.copy().resize(best_resize, resample=Image.Resampling.BICUBIC)
            refine_size = self.get_refine_size(
                original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
            )
            refine_image = image.resize(refine_size, resample=Image.Resampling.BICUBIC)
            patches = self.split_to_patches(refine_image, best_grid)

        return source_image, patches, best_grid

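    # NOTE (assumption): the placeholder helpers below read special tokens
    # (slice_start, unk_token, im_start, im_id_start, ...) from self.tokenizer,
    # which __init__ never sets. Presumably the processor that owns both
    # objects attaches a tokenizer to this image processor before calling them.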
def get_grid_placeholder(self, grid):
139+
if grid is None:
140+
return ""
141+
slice_image_placeholder = (
142+
self.tokenizer.slice_start + self.tokenizer.unk_token * self.image_feature_size + self.tokenizer.slice_end
143+
)
144+
145+
cols = grid[0]
146+
rows = grid[1]
147+
slices = []
148+
for i in range(rows):
149+
lines = []
150+
for j in range(cols):
151+
lines.append(slice_image_placeholder)
152+
slices.append("".join(lines))
153+
154+
slice_placeholder = "\n".join(slices)
155+
return slice_placeholder
156+
157+
def get_image_id_placeholder(self, idx=0):
158+
return f"{self.tokenizer.im_id_start}{idx}{self.tokenizer.im_id_end}"
159+
    def get_sliced_images(self, image, max_slice_nums=None):
        slice_images = []

        if not self.slice_mode:
            return [image]

        max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
        assert max_slice_nums > 0
        source_image, patches, sliced_grid = self.slice_image(
            image, max_slice_nums, self.scale_resolution, self.patch_size  # defaults: 9, 448, 14
        )

        slice_images.append(source_image)
        if len(patches) > 0:
            for i in range(len(patches)):
                for j in range(len(patches[0])):
                    slice_images.append(patches[i][j])
        return slice_images

    def get_sliced_grid(self, image_size, max_slice_nums, never_split=False):
        original_width, original_height = image_size
        log_ratio = math.log(original_width / original_height)
        ratio = original_width * original_height / (self.scale_resolution * self.scale_resolution)
        multiple = min(math.ceil(ratio), max_slice_nums)
        if multiple <= 1 or never_split:
            return None
        candidate_split_grids_nums = []
        for i in [multiple - 1, multiple, multiple + 1]:
            if i == 1 or i > max_slice_nums:
                continue
            candidate_split_grids_nums.append(i)

        candidate_grids = []
        for split_grids_nums in candidate_split_grids_nums:
            m = 1
            while m <= split_grids_nums:
                if split_grids_nums % m == 0:
                    candidate_grids.append([m, split_grids_nums // m])
                m += 1

        best_grid = [1, 1]
        min_error = float("inf")
        for grid in candidate_grids:
            error = abs(log_ratio - math.log(grid[0] / grid[1]))
            if error < min_error:
                best_grid = grid
                min_error = error

        return best_grid

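    # Worked example (assumed input): for a 1000x750 image with the defaults,
    # ratio = 750000 / 448**2 ≈ 3.74, so multiple = 4 and the candidate grid
    # counts are {3, 4, 5}. Among their factorizations, [2, 2] minimizes
    # |log(4/3) - log(cols/rows)|, so the image is sliced into a 2x2 grid.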
    def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
        max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
        assert max_slice_nums > 0
        grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)

        image_placeholder = (
            self.tokenizer.im_start + self.tokenizer.unk_token * self.image_feature_size + self.tokenizer.im_end
        )
        use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
        if use_image_id:
            final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
        else:
            final_placeholder = image_placeholder

        if self.slice_mode:
            final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid)
        return final_placeholder

    def to_pil_image(self, image, rescale=None) -> PIL.Image.Image:
        """
        Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis
        if needed.

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                The image to convert to the PIL Image format.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
                default to `True` if the image type is a floating type, `False` otherwise.
        """
        if isinstance(image, PIL.Image.Image):
            return image
        if is_torch_tensor(image):
            image = image.numpy()

        if isinstance(image, np.ndarray):
            if rescale is None:
                # rescale defaults to the array being of floating type.
                rescale = isinstance(image.flat[0], np.floating)
            # If the channel has been moved to the first dim, we put it back at the end.
            if image.ndim == 3 and image.shape[0] in [1, 3]:
                image = image.transpose(1, 2, 0)
            if rescale:
                image = image * 255
            image = image.astype(np.uint8)
            return PIL.Image.fromarray(image)
        return image

    def reshape_by_patch(self, image):
        """
        :param image: shape [3, H, W]; both H and W must be multiples of self.patch_size
        :return: [3, patch_size, H*W/patch_size]
        """
        image = torch.from_numpy(image)
        patch_size = self.patch_size
        patches = torch.nn.functional.unfold(image, (patch_size, patch_size), stride=(patch_size, patch_size))

        patches = patches.reshape(image.size(0), patch_size, patch_size, -1)
        patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1)
        return patches.numpy()

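    # Shape walkthrough (assumed input): a [3, 392, 518] slice with
    # patch_size=14. unfold yields [3*14*14, 28*37] = [588, 1036]; the
    # reshape/permute steps then produce [3, 14, 14*1036] = [3, 14, 14504],
    # which matches H*W/patch_size = 392*518/14 = 14504.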
    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
        do_pad: Optional[bool] = True,
        max_slice_nums: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> MiniCPMOBatchFeature:
        # Normalize the input into a batch: a list of per-sample image lists.
        if isinstance(images, Image.Image):
            images_list = [[images]]
        elif isinstance(images[0], Image.Image):
            images_list = [images]
        else:
            images_list = images

        new_images_list = []
        image_sizes_list = []
        tgt_sizes_list = []

        for _images in images_list:
            if _images is None or len(_images) == 0:
                new_images_list.append([])
                image_sizes_list.append([])
                tgt_sizes_list.append([])
                continue
            if not valid_images(_images):
                raise ValueError(
                    "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                    "torch.Tensor, tf.Tensor or jax.ndarray."
                )

            _images = [self.to_pil_image(image).convert("RGB") for image in _images]
            input_data_format = infer_channel_dimension_format(np.array(_images[0]))

            new_images = []
            image_sizes = [image.size for image in _images]
            tgt_sizes = []
            for image in _images:
                # Slice, rescale to [0, 1], normalize, move channels first,
                # then flatten each slice into patch columns.
                image_patches = self.get_sliced_images(image, max_slice_nums)
                image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches]
                image_patches = [
                    self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format)
                    for image in image_patches
                ]
                image_patches = [
                    to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
                    for image in image_patches
                ]
                for slice_image in image_patches:
                    new_images.append(self.reshape_by_patch(slice_image))
                    tgt_sizes.append(
                        np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
                    )

            if tgt_sizes:
                tgt_sizes = np.vstack(tgt_sizes)

            new_images_list.append(new_images)
            image_sizes_list.append(image_sizes)
            tgt_sizes_list.append(tgt_sizes)
        return MiniCPMOBatchFeature(
            data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
            tensor_type=return_tensors,
        )

AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)


__all__ = ["MiniCPMVImageProcessor"]
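Taken end to end, the processor maps raw images to patch-flattened pixel tensors plus per-slice grid sizes. A minimal usage sketch (values are illustrative; assumes the module path inferred above and that MiniCPMOBatchFeature supports dict-style access like a standard BatchFeature):

import numpy as np
from PIL import Image

from transformers.models.minicpm_o_2_6.image_processing_minicpm import MiniCPMVImageProcessor

processor = MiniCPMVImageProcessor(max_slice_nums=9, scale_resolution=448, patch_size=14)
image = Image.fromarray(np.random.randint(0, 255, (750, 1000, 3), dtype=np.uint8))  # 1000x750

batch = processor.preprocess(image)
# A 1000x750 input slices into a 2x2 grid, so we expect 5 entries
# (1 downsampled source + 4 patches) and a (5, 2) tgt_sizes array.
print(len(batch["pixel_values"][0]), batch["tgt_sizes"][0].shape)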
