Commit 3b9860e

mm dataset: refactor and add sample packing
1 parent fbca6dc commit 3b9860e

8 files changed: +1126 -484 lines changed

8 files changed

+1126
-484
lines changed
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass, field


@dataclass
class Training:
    max_images_per_batch: int = 10
    """Vision encoder batch size (N)"""
    max_patches_per_image: int = 256
    """Vision encoder sequence length (L)"""
    patch_size: int = 16
    """Patch size of the vision encoder.
    For example, with image size 256x256 and patch size 16,
    the number of visual tokens is (256/16)**2 = 256.
    """
    spatial_merge_size: int = 1
    """Spatially merge visual tokens after the encoder.
    For example, with image size 256x256, patch size 16, and spatial merge size 2,
    the number of visual tokens for the LLM is (256/16/2)**2 = 64.
    """
    packing_buffer_size: int = 0
    """Set to a value > 0 to enable sample packing.
    This controls the size of the buffer used to store training samples available for packing.
    """


# HACK: couldn't figure out how to modify the HF tokenizer's json to make these
# attributes accessible. Ideally these should be accessible from the tokenizer itself.
@dataclass
class SpecialTokens:
    img_token: str = "<|image|>"
    boi_token: str = "<|begin_of_image|>"
    eoi_token: str = "<|end_of_image|>"
    img_id: int = 1998
    boi_id: int = 1999
    eoi_id: int = 2000
    pad_id: int = 2001


@dataclass
class JobConfig:
    training: Training = field(default_factory=Training)
    special_tokens: SpecialTokens = field(default_factory=SpecialTokens)
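
To sanity-check the docstring arithmetic, a minimal standalone sketch (it mirrors the Training defaults above rather than importing anything):

# Token arithmetic from the docstrings: a 256x256 image with patch size 16.
image_size = 256
patch_size = 16
encoder_tokens = (image_size // patch_size) ** 2  # 256 visual tokens into the encoder
spatial_merge_size = 2
llm_tokens = (image_size // patch_size // spatial_merge_size) ** 2  # 64 tokens for the LLM
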
Lines changed: 295 additions & 0 deletions
@@ -0,0 +1,295 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Utility functions for image processing in multimodal datasets."""

import math
from io import BytesIO

import einops as E
import numpy as np
import requests
import torch

from PIL import Image

from torchtitan.tools.logging import logger


def process_image(
    image: str | bytes | Image.Image,
    patch_size: int = 16,
    merge_size: int = 1,
    max_patch_per_image: int = 256,
    min_patch_per_image: int = 1,
) -> torch.Tensor | None:
    """Process a single image into normalized tensor format.

    Args:
        image: PIL Image, bytes, or URL string
        patch_size: Size of each patch
        merge_size: Spatial merge size factor
        max_patch_per_image: Maximum patches allowed per image
        min_patch_per_image: Minimum patches required per image

    Returns:
        Tensor of shape (1, H, W, 3) or None if processing fails

    Note:
        - Resizes image while maintaining aspect ratio
        - Normalizes using CLIP mean/std values
        - Returns None if any processing step fails
    """
    try:
        # Convert various input formats to PIL Image
        if isinstance(image, str) and image.startswith("http"):
            response = requests.get(image, timeout=10)
            image = Image.open(BytesIO(response.content))
        elif isinstance(image, bytes):
            image = Image.open(BytesIO(image))
        elif isinstance(image, str):
            image = Image.open(image)

        if image.mode != "RGB":
            image = image.convert("RGB")

        # Resize maintaining aspect ratio
        image = resize_image_by_patch_count(
            image,
            max_patch_per_image=max_patch_per_image,
            patch_size=patch_size,
            merge_size=merge_size,
            min_patch_per_image=min_patch_per_image,
        )

        # Convert to numpy and normalize
        img_array = np.array(image)
        img_array = img_array / 255.0

        # CLIP normalization
        mean = np.array([0.48145466, 0.4578275, 0.40821073])
        std = np.array([0.26862954, 0.26130258, 0.27577711])
        img_array = (img_array - mean) / std

        # Convert to tensor (1, H, W, 3) with dummy temporal dim
        return torch.from_numpy(img_array).float().unsqueeze(0)

    except Exception as e:
        logger.warning(f"Error processing image: {e}")
        return None

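
A quick usage sketch for process_image; the file name is a placeholder, and a URL or raw bytes would work the same way:

# "sample.jpg" is a hypothetical local path; process_image returns None on failure.
pixels = process_image("sample.jpg", patch_size=16, merge_size=1, max_patch_per_image=256)
if pixels is not None:
    print(pixels.shape)  # (1, H, W, 3), with H and W multiples of patch_size * merge_size
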
def smart_resize(
    height: int,
    width: int,
    factor: int,  # should equal patch_size * merge_size
    max_patch_per_image: int,
    min_patch_per_image: int = 1,
):
    """Calculate dimensions that maintain aspect ratio and satisfy constraints."""
    if height < factor or width < factor:
        raise ValueError(
            f"height:{height} or width:{width} must be larger than factor:{factor}"
        )
    elif max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )

    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor

    # Calculate patch count from adjusted dimensions
    current_patches = (h_bar * w_bar) // (factor * factor)

    if current_patches > max_patch_per_image:
        # Scale down to fit within max patch limit
        max_area = max_patch_per_image * (factor * factor)
        beta = math.sqrt((h_bar * w_bar) / max_area)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif current_patches < min_patch_per_image:
        beta = math.sqrt(min_patch_per_image / current_patches)
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    return h_bar, w_bar

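
A worked example of the downscaling branch: a 512x1024 input at factor 16 starts at 2048 patches, so beta = sqrt(2048/256) ≈ 2.83 and both sides are floored down to multiples of 16:

h_bar, w_bar = smart_resize(height=512, width=1024, factor=16, max_patch_per_image=256)
print(h_bar, w_bar)                  # 176 352
print((h_bar * w_bar) // (16 * 16))  # 242 patches, within the 256 cap
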
def resize_image_by_patch_count(
    image: Image.Image,
    max_patch_per_image: int,
    patch_size: int,
    merge_size: int,
    min_patch_per_image: int = 1,
) -> Image.Image:
    """Resize image while maintaining aspect ratio and ensuring patch count is within [min_patch_per_image, max_patch_per_image]."""
    original_width, original_height = image.size
    factor = patch_size * merge_size

    # Calculate current number of patches
    current_patches = (original_height * original_width) // (factor * factor)

    # If patches < min_patch_per_image, scale up proportionally
    if current_patches < min_patch_per_image:
        if current_patches == 0:
            # Special case: image too small to produce any patches
            # Scale to minimum viable size (at least factor x factor)
            scale_factor = max(factor / original_width, factor / original_height)
        else:
            scale_factor = math.sqrt(min_patch_per_image / current_patches)

        new_width = int(original_width * scale_factor)
        new_height = int(original_height * scale_factor)

        resized_height, resized_width = smart_resize(
            new_height,
            new_width,
            factor,
            max_patch_per_image,
        )
        return image.resize((resized_width, resized_height))

    # If patches are within [min, max] range, just use smart_resize
    elif current_patches <= max_patch_per_image:
        resized_height, resized_width = smart_resize(
            original_height, original_width, factor, max_patch_per_image
        )
        return image.resize((resized_width, resized_height))

    # If patches > max_patch_per_image, scale down proportionally
    else:
        scale_factor = math.sqrt(max_patch_per_image / current_patches)
        new_width = int(original_width * scale_factor)
        new_height = int(original_height * scale_factor)

        resized_height, resized_width = smart_resize(
            new_height, new_width, factor, max_patch_per_image
        )
        return image.resize((resized_width, resized_height))

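
The same numbers through the PIL wrapper, using a synthetic blank image (PIL sizes are (width, height)). Note the pre-scale-then-smart_resize path lands on slightly different multiples of 16 than calling smart_resize directly, but stays within the cap:

img = Image.new("RGB", (1024, 512))
resized = resize_image_by_patch_count(img, max_patch_per_image=256, patch_size=16, merge_size=1)
print(resized.size)  # (368, 176) -> 253 patches, within the 256 cap
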
def calculate_image_tokens(
    image: Image.Image | torch.Tensor,
    patch_size: int,
    spatial_merge_size: int,
) -> tuple[int, int, int]:
    """Calculate number of tokens needed for an image."""
    if isinstance(image, torch.Tensor):
        height, width = image.shape[1:3]
    else:
        width, height = image.size

    tokens_per_row = int(width / (patch_size * spatial_merge_size))
    num_rows = int(height / (patch_size * spatial_merge_size))
    total_tokens = tokens_per_row * num_rows

    return total_tokens, tokens_per_row, num_rows

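
This is the same arithmetic as the config docstring example: a 256x256 image with patch size 16 and spatial merge size 2 yields an 8x8 grid of tokens for the LLM:

total, per_row, rows = calculate_image_tokens(
    Image.new("RGB", (256, 256)), patch_size=16, spatial_merge_size=2
)
print(total, per_row, rows)  # 64 8 8
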
def convert_to_patches(
    pixel_values: torch.Tensor,
    patch_size: int,
    temporal_patch_size: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert single image tensor to patches and generate coordinate grids.

    Args:
        pixel_values: Tensor of shape (T, H, W, C)
        patch_size: Spatial patch size (height and width)
        temporal_patch_size: Temporal patch size (default=1 for no temporal patching)

    Returns:
        patches: Tensor of shape (L, D) where:
            L = (T//temporal_patch_size) * (H//patch_size) * (W//patch_size)
            D = temporal_patch_size * patch_size * patch_size * C
        grid: Tensor of shape (L, 3) containing (t, h, w) coordinates

    Example:
        >>> x = torch.randn(4, 224, 224, 3)  # Single image with 4 frames
        >>> patches, grid = convert_to_patches(x, patch_size=14, temporal_patch_size=2)
        >>> print(patches.shape)  # (512, 1176): 512 patches, each 1176-dim
        >>> print(grid.shape)     # (512, 3): (t, h, w) coordinates
    """
    T, H, W, C = pixel_values.shape
    ps = patch_size
    ts = temporal_patch_size
    device = pixel_values.device

    # Ensure dimensions are divisible
    if T % ts != 0:
        raise ValueError(
            f"Temporal dimension {T} must be divisible by temporal_patch_size {ts}"
        )
    if H % ps != 0 or W % ps != 0:
        raise ValueError(
            f"Spatial dimensions {H},{W} must be divisible by patch_size {ps}"
        )

    patches = E.rearrange(
        pixel_values,
        "(t pt) (h ph) (w pw) c -> (t h w) (pt ph pw c)",
        pt=ts,
        ph=ps,
        pw=ps,
    )

    # Generate coordinate grid
    coords = torch.meshgrid(
        torch.arange(T // ts, device=device),
        torch.arange(H // ps, device=device),
        torch.arange(W // ps, device=device),
        indexing="ij",
    )
    grid = E.rearrange(torch.stack(coords), "coords t h w -> (t h w) coords")  # (L, 3)

    return patches, grid

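
A smaller end-to-end check of the patchify step and its coordinate grid (T=1, so no temporal patching):

x = torch.randn(1, 32, 32, 3)  # one 32x32 RGB frame
patches, grid = convert_to_patches(x, patch_size=16)
print(patches.shape)  # torch.Size([4, 768]) -> 4 patches, each 16*16*3 = 768-dim
print(grid.tolist())  # [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]] -- (t, h, w) per patch
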
def pad_patches(
    patches: torch.Tensor,  # Shape (L, D)
    grids: torch.Tensor,  # Shape (L, 3), (t, h, w) coords
    max_patches: int,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
    """Pad or truncate patches and grids to max_patches length for a single image."""
    L, D = patches.shape

    if L == max_patches:
        return patches, grids
    elif L < max_patches:
        # Pad
        pad_len = max_patches - L
        zero_patches = torch.zeros(pad_len, D, device=patches.device)
        invalid_grids = torch.full((pad_len, 3), -1, device=grids.device)
        return (
            torch.cat([patches, zero_patches], 0),
            torch.cat([grids, invalid_grids], 0),
        )
    else:
        # Truncate
        logger.error(
            f"Truncating image patches from {L} to {max_patches} should not happen."
        )
        return None, None

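
Continuing the convert_to_patches example, padding up to a fixed sequence length marks the padded grid rows as invalid (-1):

padded, padded_grid = pad_patches(patches, grid, max_patches=6)
print(padded.shape)              # torch.Size([6, 768])
print(padded_grid[-1].tolist())  # [-1, -1, -1] -- padding row
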
def pad_empty_images_to_target_batch_size(
    patches: torch.Tensor,
    grids: torch.Tensor,
    max_images: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pad vision encoder batch with blank images if needed."""
    N, L, D = patches.shape
    if N >= max_images:
        return patches, grids

    blank_count = max_images - N
    blank_patches = torch.zeros(blank_count, L, D, device=patches.device)
    blank_grids = torch.full((blank_count, L, 3), -1, device=grids.device)
    return (
        torch.cat([patches, blank_patches], dim=0),
        torch.cat([grids, blank_grids], dim=0),
    )
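
Putting the pieces together, a hedged sketch of how these utilities might chain into a vision-encoder batch; the image path is a placeholder and the constants mirror the Training defaults from the config file above, not a fixed API:

# Illustrative end-to-end flow; "sample.jpg" is a hypothetical path.
pixels = process_image("sample.jpg", patch_size=16, merge_size=1, max_patch_per_image=256)
if pixels is not None:
    patches, grid = convert_to_patches(pixels, patch_size=16)    # (L, D), (L, 3)
    patches, grid = pad_patches(patches, grid, max_patches=256)  # pad to encoder seq len (L)
    patches, grid = pad_empty_images_to_target_batch_size(
        patches.unsqueeze(0), grid.unsqueeze(0), max_images=10   # pad to encoder batch (N)
    )
    print(patches.shape)  # torch.Size([10, 256, 768])
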
