|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +"""Utility functions for image processing in multimodal datasets.""" |
| 8 | + |
| 9 | +import math |
| 10 | +from io import BytesIO |
| 11 | + |
| 12 | +import einops as E |
| 13 | +import numpy as np |
| 14 | +import requests |
| 15 | +import torch |
| 16 | + |
| 17 | +from PIL import Image |
| 18 | + |
| 19 | +from torchtitan.tools.logging import logger |
| 20 | + |
| 21 | + |
| 22 | +def process_image( |
| 23 | + image: str | bytes | Image.Image, |
| 24 | + patch_size: int = 16, |
| 25 | + merge_size: int = 1, |
| 26 | + max_patch_per_image: int = 256, |
| 27 | + min_patch_per_image: int = 1, |
| 28 | +) -> torch.Tensor | None: |
| 29 | + """Process a single image into normalized tensor format. |
| 30 | +
|
| 31 | + Args: |
| 32 | + image: PIL Image, bytes, or URL string |
| 33 | + patch_size: Size of each patch |
| 34 | + merge_size: Spatial Merge size factor |
| 35 | + max_patch_per_image: Maximum patches allowed per image |
| 36 | + min_dimension: Minimum dimension for width/height |
| 37 | +
|
| 38 | + Returns: |
| 39 | + Tensor of shape (1, H, W, 3) or None if processing fails |
| 40 | +
|
| 41 | + Note: |
| 42 | + - Resizes image while maintaining aspect ratio |
| 43 | + - Normalizes using CLIP mean/std values |
| 44 | + - Returns None if any processing step fails |
| 45 | + """ |
| 46 | + try: |
| 47 | + # Convert various input formats to PIL Image |
| 48 | + if isinstance(image, str) and image.startswith("http"): |
| 49 | + response = requests.get(image, timeout=10) |
| 50 | + image = Image.open(BytesIO(response.content)) |
| 51 | + elif isinstance(image, bytes): |
| 52 | + image = Image.open(BytesIO(image)) |
| 53 | + elif isinstance(image, str): |
| 54 | + image = Image.open(image) |
| 55 | + |
| 56 | + if image.mode != "RGB": |
| 57 | + image = image.convert("RGB") |
| 58 | + |
| 59 | + # Resize maintaining aspect ratio |
| 60 | + image = resize_image_by_patch_count( |
| 61 | + image, |
| 62 | + max_patch_per_image=max_patch_per_image, |
| 63 | + patch_size=patch_size, |
| 64 | + merge_size=merge_size, |
| 65 | + min_patch_per_image=min_patch_per_image, |
| 66 | + ) |
| 67 | + |
| 68 | + # Convert to numpy and normalize |
| 69 | + img_array = np.array(image) |
| 70 | + img_array = img_array / 255.0 |
| 71 | + |
| 72 | + # CLIP normalization |
| 73 | + mean = np.array([0.48145466, 0.4578275, 0.40821073]) |
| 74 | + std = np.array([0.26862954, 0.26130258, 0.27577711]) |
| 75 | + img_array = (img_array - mean) / std |
| 76 | + |
| 77 | + # Convert to tensor (1, H, W, 3) with dummy temporal dim |
| 78 | + return torch.from_numpy(img_array).float().unsqueeze(0) |
| 79 | + |
| 80 | + except Exception as e: |
| 81 | + logger.warning(f"Error processing image: {e}") |
| 82 | + return None |
| 83 | + |
| 84 | + |
| 85 | +def smart_resize( |
| 86 | + height: int, |
| 87 | + width: int, |
| 88 | + factor: int, # should be equal patch_size * merge_size |
| 89 | + max_patch_per_image: int, |
| 90 | + min_patch_per_image: int = 1, |
| 91 | +): |
| 92 | + """Calculate dimensions that maintain aspect ratio and satisfy constraints.""" |
| 93 | + if height < factor or width < factor: |
| 94 | + raise ValueError( |
| 95 | + f"height:{height} or width:{width} must be larger than factor:{factor}" |
| 96 | + ) |
| 97 | + elif max(height, width) / min(height, width) > 200: |
| 98 | + raise ValueError( |
| 99 | + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" |
| 100 | + ) |
| 101 | + |
| 102 | + h_bar = round(height / factor) * factor |
| 103 | + w_bar = round(width / factor) * factor |
| 104 | + |
| 105 | + # Calculate patch count from adjusted dimensions |
| 106 | + current_patches = (h_bar * w_bar) // (factor * factor) |
| 107 | + |
| 108 | + if current_patches > max_patch_per_image: |
| 109 | + # Scale down to fit within max patch limit |
| 110 | + max_area = max_patch_per_image * (factor * factor) |
| 111 | + beta = math.sqrt((h_bar * w_bar) / max_area) |
| 112 | + h_bar = math.floor(height / beta / factor) * factor |
| 113 | + w_bar = math.floor(width / beta / factor) * factor |
| 114 | + elif current_patches < min_patch_per_image: |
| 115 | + beta = math.sqrt(min_patch_per_image / current_patches) |
| 116 | + h_bar = math.ceil(height * beta / factor) * factor |
| 117 | + w_bar = math.ceil(width * beta / factor) * factor |
| 118 | + |
| 119 | + return h_bar, w_bar |
| 120 | + |
| 121 | + |
| 122 | +def resize_image_by_patch_count( |
| 123 | + image: Image.Image, |
| 124 | + max_patch_per_image: int, |
| 125 | + patch_size: int, |
| 126 | + merge_size: int, |
| 127 | + min_patch_per_image: int = 1, |
| 128 | +) -> Image.Image: |
| 129 | + """Resize image while maintaining aspect ratio and ensuring patch count is within [min_patch_per_image, max_patch_per_image].""" |
| 130 | + original_width, original_height = image.size |
| 131 | + factor = patch_size * merge_size |
| 132 | + |
| 133 | + # Calculate current number of patches |
| 134 | + current_patches = (original_height * original_width) // (factor * factor) |
| 135 | + |
| 136 | + # If patches < min_patch_per_image, scale up proportionally |
| 137 | + if current_patches < min_patch_per_image: |
| 138 | + if current_patches == 0: |
| 139 | + # Special case: image too small to produce any patches |
| 140 | + # Scale to minimum viable size (at least factor x factor) |
| 141 | + scale_factor = max(factor / original_width, factor / original_height) |
| 142 | + else: |
| 143 | + scale_factor = math.sqrt(min_patch_per_image / current_patches) |
| 144 | + |
| 145 | + new_width = int(original_width * scale_factor) |
| 146 | + new_height = int(original_height * scale_factor) |
| 147 | + |
| 148 | + resized_height, resized_width = smart_resize( |
| 149 | + new_height, |
| 150 | + new_width, |
| 151 | + factor, |
| 152 | + max_patch_per_image, |
| 153 | + ) |
| 154 | + return image.resize((resized_width, resized_height)) |
| 155 | + |
| 156 | + # If patches are within [min, max] range, just use smart_resize |
| 157 | + elif current_patches <= max_patch_per_image: |
| 158 | + resized_height, resized_width = smart_resize( |
| 159 | + original_height, original_width, factor, max_patch_per_image |
| 160 | + ) |
| 161 | + return image.resize((resized_width, resized_height)) |
| 162 | + |
| 163 | + # If patches > max_patch_per_image, scale down proportionally |
| 164 | + else: |
| 165 | + scale_factor = math.sqrt(max_patch_per_image / current_patches) |
| 166 | + new_width = int(original_width * scale_factor) |
| 167 | + new_height = int(original_height * scale_factor) |
| 168 | + |
| 169 | + resized_height, resized_width = smart_resize( |
| 170 | + new_height, new_width, factor, max_patch_per_image |
| 171 | + ) |
| 172 | + return image.resize((resized_width, resized_height)) |
| 173 | + |
| 174 | + |
| 175 | +def calculate_image_tokens( |
| 176 | + image: Image.Image | torch.Tensor, |
| 177 | + patch_size: int, |
| 178 | + spatial_merge_size: int, |
| 179 | +) -> tuple[int, int, int]: |
| 180 | + """Calculate number of tokens needed for an image.""" |
| 181 | + if isinstance(image, torch.Tensor): |
| 182 | + height, width = image.shape[1:3] |
| 183 | + else: |
| 184 | + width, height = image.size |
| 185 | + |
| 186 | + tokens_per_row = int(width / (patch_size * spatial_merge_size)) |
| 187 | + num_rows = int(height / (patch_size * spatial_merge_size)) |
| 188 | + total_tokens = tokens_per_row * num_rows |
| 189 | + |
| 190 | + return total_tokens, tokens_per_row, num_rows |
| 191 | + |
| 192 | + |
| 193 | +def convert_to_patches( |
| 194 | + pixel_values: torch.Tensor, |
| 195 | + patch_size: int, |
| 196 | + temporal_patch_size: int = 1, |
| 197 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 198 | + """Convert single image tensor to patches and generate coordinate grids. |
| 199 | +
|
| 200 | + Args: |
| 201 | + pixel_values: Tensor of shape (T, H, W, C) |
| 202 | + patch_size: Spatial patch size (height and width) |
| 203 | + temporal_patch_size: Temporal patch size (default=1 for no temporal patching) |
| 204 | +
|
| 205 | + Returns: |
| 206 | + patches: Tensor of shape (L, D) where: |
| 207 | + L = (T//temporal_patch_size) * (H//patch_size) * (W//patch_size) |
| 208 | + D = temporal_patch_size * patch_size * patch_size * C |
| 209 | + grid: Tensor of shape (L, 3) containing (t, h, w) coordinates |
| 210 | +
|
| 211 | + Example: |
| 212 | + >>> x = torch.randn(4, 224, 224, 3) # Single image with 4 frames |
| 213 | + >>> patches, grid = convert_to_patches(x, patch_size=14, temporal_patch_size=2) |
| 214 | + >>> print(patches.shape) # (512, 1176) # 512 patches, each 1176-dim |
| 215 | + >>> print(grid.shape) # (512, 3) # (t,h,w) coordinates |
| 216 | + """ |
| 217 | + T, H, W, C = pixel_values.shape |
| 218 | + ps = patch_size |
| 219 | + ts = temporal_patch_size |
| 220 | + device = pixel_values.device |
| 221 | + |
| 222 | + # Ensure dimensions are divisible |
| 223 | + if T % ts != 0: |
| 224 | + raise ValueError( |
| 225 | + f"Temporal dimension {T} must be divisible by temporal_patch_size {ts}" |
| 226 | + ) |
| 227 | + if H % ps != 0 or W % ps != 0: |
| 228 | + raise ValueError( |
| 229 | + f"Spatial dimensions {H},{W} must be divisible by patch_size {ps}" |
| 230 | + ) |
| 231 | + |
| 232 | + patches = E.rearrange( |
| 233 | + pixel_values, |
| 234 | + "(t pt) (h ph) (w pw) c -> (t h w) (pt ph pw c)", |
| 235 | + pt=ts, |
| 236 | + ph=ps, |
| 237 | + pw=ps, |
| 238 | + ) |
| 239 | + |
| 240 | + # Generate coordinate grid |
| 241 | + coords = torch.meshgrid( |
| 242 | + torch.arange(T // ts, device=device), |
| 243 | + torch.arange(H // ps, device=device), |
| 244 | + torch.arange(W // ps, device=device), |
| 245 | + indexing="ij", |
| 246 | + ) |
| 247 | + grid = E.rearrange(torch.stack(coords), "coords t h w -> (t h w) coords") # (L, 3) |
| 248 | + |
| 249 | + return patches, grid |
| 250 | + |
| 251 | + |
| 252 | +def pad_patches( |
| 253 | + patches: torch.Tensor, # Shape L,D |
| 254 | + grids: torch.Tensor, # Shape L,3(thw) |
| 255 | + max_patches: int, |
| 256 | +) -> tuple[torch.Tensor | None, torch.Tensor | None]: |
| 257 | + """Pad or truncate patches and grids to max_patches length for single image.""" |
| 258 | + L, D = patches.shape |
| 259 | + |
| 260 | + if L == max_patches: |
| 261 | + return patches, grids |
| 262 | + elif L < max_patches: |
| 263 | + # Pad |
| 264 | + pad_len = max_patches - L |
| 265 | + zero_patches = torch.zeros(pad_len, D, device=patches.device) |
| 266 | + invalid_grids = torch.full((pad_len, 3), -1, device=grids.device) |
| 267 | + return ( |
| 268 | + torch.cat([patches, zero_patches], 0), |
| 269 | + torch.cat([grids, invalid_grids], 0), |
| 270 | + ) |
| 271 | + else: |
| 272 | + # Truncate |
| 273 | + logger.error( |
| 274 | + f"Truncating Image Patches from {L} to {max_patches} should not happen." |
| 275 | + ) |
| 276 | + return None, None |
| 277 | + |
| 278 | + |
| 279 | +def pad_empty_images_to_target_batch_size( |
| 280 | + patches: torch.Tensor, |
| 281 | + grids: torch.Tensor, |
| 282 | + max_images: int, |
| 283 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 284 | + """Pad vision encoder batch with blank images if needed.""" |
| 285 | + N, L, D = patches.shape |
| 286 | + if N >= max_images: |
| 287 | + return patches, grids |
| 288 | + |
| 289 | + blank_count = max_images - N |
| 290 | + blank_patches = torch.zeros(blank_count, L, D, device=patches.device) |
| 291 | + blank_grids = torch.full((blank_count, L, 3), -1, device=grids.device) |
| 292 | + return ( |
| 293 | + torch.cat([patches, blank_patches], dim=0), |
| 294 | + torch.cat([grids, blank_grids], dim=0), |
| 295 | + ) |
0 commit comments