|
12 | 12 | import torch |
13 | 13 | import torch.nn as nn |
14 | 14 | import torch.nn.functional as F |
| 15 | +from qwen_vl_utils import smart_resize |
15 | 16 | from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel |
16 | 17 | from transformers.cache_utils import Cache |
17 | 18 | from transformers.modeling_outputs import ( |
@@ -1026,69 +1027,31 @@ def get_specializations( |
1026 | 1027 | logger.warning( |
1027 | 1028 | f"Setting height and width to be {height} and {width} respectively, as it was neither passed nor found in vision_config" |
1028 | 1029 | ) |
| 1030 | + height = [height] if isinstance(height, int) else height |
| 1031 | + width = [width] if isinstance(width, int) else width |
| 1032 | + |
1029 | 1033 | prefill_seq_len = prefill_seq_len if prefill_seq_len else 128 |
1030 | 1034 | ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN |
1031 | 1035 | channel = 3 |
1032 | 1036 | patch_size = self.config.vision_config.patch_size |
1033 | 1037 | temporal_patch_size = self.config.vision_config.temporal_patch_size |
1034 | 1038 |
|
1035 | | - # Modified from qwen_vl_utils/vision_process.py |
1036 | 1039 | IMAGE_FACTOR = 28 |
1037 | | - MAX_RATIO = 200 |
1038 | 1040 | IMAGE_MIN_TOKEN_NUM = 4 |
1039 | 1041 | IMAGE_MAX_TOKEN_NUM = 16384 |
1040 | | - |
1041 | | - def round_by_factor(number: int, factor: int) -> int: |
1042 | | - """Returns the closest integer to 'number' that is divisible by 'factor'.""" |
1043 | | - return round(number / factor) * factor |
1044 | | - |
1045 | | - def ceil_by_factor(number: int, factor: int) -> int: |
1046 | | - """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" |
1047 | | - return math.ceil(number / factor) * factor |
1048 | | - |
1049 | | - def floor_by_factor(number: int, factor: int) -> int: |
1050 | | - """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" |
1051 | | - return math.floor(number / factor) * factor |
1052 | | - |
1053 | | - def smart_resize( |
1054 | | - height: int, |
1055 | | - width: int, |
1056 | | - factor: int = IMAGE_FACTOR, |
1057 | | - min_pixels: Optional[int] = None, |
1058 | | - max_pixels: Optional[int] = None, |
1059 | | - ) -> tuple[int, int]: |
1060 | | - """ |
1061 | | - Rescales the image so that the following conditions are met: |
1062 | | -
|
1063 | | - 1. Both dimensions (height and width) are divisible by 'factor'. |
1064 | | - 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. |
1065 | | - 3. The aspect ratio of the image is maintained as closely as possible. |
1066 | | - """ |
1067 | | - max_pixels = max_pixels if max_pixels is not None else (IMAGE_MAX_TOKEN_NUM * factor ** 2) |
1068 | | - min_pixels = min_pixels if min_pixels is not None else (IMAGE_MIN_TOKEN_NUM * factor ** 2) |
1069 | | - assert max_pixels >= min_pixels, "The max_pixels of image must be greater than or equal to min_pixels." |
1070 | | - if max(height, width) / min(height, width) > MAX_RATIO: |
1071 | | - raise ValueError( |
1072 | | - f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" |
1073 | | - ) |
1074 | | - h_bar = max(factor, round_by_factor(height, factor)) |
1075 | | - w_bar = max(factor, round_by_factor(width, factor)) |
1076 | | - if h_bar * w_bar > max_pixels: |
1077 | | - beta = math.sqrt((height * width) / max_pixels) |
1078 | | - h_bar = floor_by_factor(height / beta, factor) |
1079 | | - w_bar = floor_by_factor(width / beta, factor) |
1080 | | - elif h_bar * w_bar < min_pixels: |
1081 | | - beta = math.sqrt(min_pixels / (height * width)) |
1082 | | - h_bar = ceil_by_factor(height * beta, factor) |
1083 | | - w_bar = ceil_by_factor(width * beta, factor) |
1084 | | - return h_bar, w_bar |
| 1042 | + min_pixels = IMAGE_MIN_TOKEN_NUM * IMAGE_FACTOR**2 |
| 1043 | + max_pixels = IMAGE_MAX_TOKEN_NUM * IMAGE_FACTOR**2 |
| 1044 | + mm_processor_kwargs = compiler_options.pop("mm_processor_kwargs", None) |
| 1045 | + if mm_processor_kwargs: |
| 1046 | + min_pixels = mm_processor_kwargs.get("min_pixels", min_pixels) |
| 1047 | + max_pixels = mm_processor_kwargs.get("max_pixels", max_pixels) |
1085 | 1048 |
|
1086 | 1049 | vision = [] |
1087 | 1050 | min_vision_size = ctx_len |
1088 | | - height = [height] if isinstance(height, int) else height |
1089 | | - width = [width] if isinstance(width, int) else width |
1090 | 1051 | for h, w in zip(height, width): |
1091 | | - resized_height, resized_width = smart_resize(height=h, width=w) |
| 1052 | + resized_height, resized_width = smart_resize( |
| 1053 | + height=h, width=w, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels |
| 1054 | + ) |
1092 | 1055 | grid_h, grid_w = resized_height // patch_size, resized_width // patch_size |
1093 | 1056 | grid_height = grid_h * grid_w |
1094 | 1057 | grid_width = patch_size * patch_size * temporal_patch_size * channel |
|
0 commit comments