
Commit 2c3609b

[vlm] add Obelics interleaved dataloader
1 parent 08b8b24 commit 2c3609b

File tree

2 files changed: +610 −0

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Dict, List

import einops as E
import torch
from torch.nn.utils.rnn import pad_sequence

from torchtitan.tools.logging import logger


IGNORE_INDEX = -100  # matches the default ignore_index of torch.nn.CrossEntropyLoss


@dataclass
class MultiModalCollatorNLD:
    """Collator that works with patches in NLD format (N=batch, L=patches, D=patch_features)."""

    padding_idx: int = 0
    ignore_idx: int = IGNORE_INDEX
    max_images_per_batch: int = 5
    max_patch_per_image: int = 256  # Maximum patches per image
    patch_size: int = 16  # Patch size for converting images to patches
    merge_size: int = 1  # Merge size for converting spatial patches to channel dim
    seq_len: int = 2048

    def convert_to_patches(
        self, pixel_values: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Direct NTHWC -> NLD conversion using einops."""
        N, T, H, W, C = pixel_values.shape
        ps = self.patch_size
        device = pixel_values.device
        patches = E.rearrange(
            pixel_values, "n t (h p1) (w p2) c -> n (t h w) (p1 p2 c)", p1=ps, p2=ps
        )
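        # Shape check (illustrative, not in the original commit): a 1x1x64x64x3
        # input with ps=16 yields patches of shape (1, 1*4*4, 16*16*3) = (1, 16, 768).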

        coords = torch.meshgrid(
            torch.arange(T, device=device),
            torch.arange(H // ps, device=device),
            torch.arange(W // ps, device=device),
            indexing="ij",
        )
        grid = E.rearrange(torch.stack(coords), "coords t h w -> (t h w) coords")
        grid = grid.unsqueeze(0).expand(N, -1, -1)  # (N, t*h*w, 3)

        # All patches are valid since we resize images to be divisible by patch_size
        return patches, grid

    def _pad_to_max(self, patches, grids):
        """Pad or truncate to max_patch_per_image."""
        N, L, D = patches.shape
        if L == self.max_patch_per_image:
            return patches, grids
        elif L < self.max_patch_per_image:
            # Pad (match patches' dtype so torch.cat works for non-fp32 inputs)
            pad_len = self.max_patch_per_image - L
            zero_patches = torch.zeros(
                N, pad_len, D, device=patches.device, dtype=patches.dtype
            )
            invalid_grids = torch.full(
                (grids.shape[0], pad_len, 3), -1, device=grids.device
            )
            return torch.cat([patches, zero_patches], 1), torch.cat(
                [grids, invalid_grids], 1
            )
        else:
            # Truncate
            return (
                patches[:, : self.max_patch_per_image],
                grids[:, : self.max_patch_per_image],
            )

    def __call__(
        self, batch: List[Dict[str, Any]]
    ) -> tuple[Dict[str, torch.Tensor | None], torch.Tensor] | None:
        """Encode batch with patch-based approach."""
        if not batch:
            return None

        # Count images per sample and total images
        images_per_sample = [len(s.get("pixel_values", [])) for s in batch]

        # Remove samples from end until total images <= max_images_per_batch
        total_images = sum(images_per_sample)
        while total_images > self.max_images_per_batch and batch:
            removed_images = images_per_sample.pop()
            total_images -= removed_images
            batch.pop()
            logger.warning(
                f"Removed sample with {removed_images} images to keep "
                f"total images <= {self.max_images_per_batch}"
            )

        all_images = [
            img
            for sample in batch
            if "pixel_values" in sample
            for img in sample["pixel_values"]
        ]

        if all_images:
            patch_list, grid_list = [], []
            for img in all_images:
                p, g = self.convert_to_patches(img.unsqueeze(0))
                p, g = self._pad_to_max(p, g)
                patch_list.append(p[0])
                grid_list.append(g[0])
            patches = torch.stack(patch_list)
            grids = torch.stack(grid_list)

            if len(all_images) < self.max_images_per_batch:
                blank_count = self.max_images_per_batch - len(all_images)
                blank_patches = torch.zeros(
                    blank_count,
                    self.max_patch_per_image,
                    patches.shape[2],
                    device=patches.device,
                    dtype=patches.dtype,
                )
                blank_grids = torch.full(
                    (blank_count, self.max_patch_per_image, 3), -1, device=grids.device
                )
                patches = torch.cat([patches, blank_patches], dim=0)
                grids = torch.cat([grids, blank_grids], dim=0)
        else:
            patches = grids = None

        # Text processing
        input_ids = pad_sequence(
            [s["input_ids"] for s in batch],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        labels = pad_sequence(
            [s["labels"] for s in batch],
            batch_first=True,
            padding_value=self.padding_idx,
        )

        # Pad along batch dimension if needed (defensive: pad_sequence already
        # returns len(batch) rows, so this branch should not normally trigger)
        batch_size = len(batch)
        if input_ids.size(0) < batch_size:
            padding_needed = batch_size - input_ids.size(0)
            padding_input = torch.full(
                (padding_needed, input_ids.size(1)), self.padding_idx, dtype=torch.long
            )
            padding_labels = torch.full(
                (padding_needed, labels.size(1)), self.padding_idx, dtype=torch.long
            )
            input_ids = torch.cat([input_ids, padding_input], dim=0)
            labels = torch.cat([labels, padding_labels], dim=0)

        # Handle sequence length
        current_length = input_ids.size(1)
        desired_length = self.seq_len + 1  # Extra token for label shift and cut
        if current_length < desired_length:
            padding_length = desired_length - current_length
            padding_input = torch.full(
                (batch_size, padding_length), self.padding_idx, dtype=torch.long
            )
            padding_labels = torch.full(
                (batch_size, padding_length), self.padding_idx, dtype=torch.long
            )
            input_ids = torch.cat([input_ids, padding_input], dim=1)
            labels = torch.cat([labels, padding_labels], dim=1)
        elif current_length > self.seq_len:
            input_ids = input_ids[:, :desired_length]
            labels = labels[:, :desired_length]

        # Mask padding in labels (assumes padding_idx never occurs as a real label)
        labels[labels == self.padding_idx] = self.ignore_idx
        # Cut and shift
        input_ids = input_ids[:, :-1]
        labels = labels[:, 1:]

        return {
            "input": input_ids,
            "pixel_values": patches,
            "grid_thw": grids,
        }, labels
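
A minimal usage sketch (not part of the commit; the image size, token ids, and collator arguments below are made up for illustration). In practice the collator would be passed as collate_fn to a DataLoader; here it is called directly on one hand-built sample:

import torch

collator = MultiModalCollatorNLD(patch_size=16, max_images_per_batch=5, seq_len=2048)

# One sample: a single 64x64 RGB frame in THWC layout plus pre-tokenized text.
# Token ids are drawn from [1, 1000) so none collide with padding_idx=0.
sample = {
    "pixel_values": [torch.rand(1, 64, 64, 3)],  # T=1, H=W=64, C=3
    "input_ids": torch.randint(1, 1000, (32,)),
    "labels": torch.randint(1, 1000, (32,)),
}

inputs, labels = collator([sample])
print(inputs["input"].shape)         # torch.Size([1, 2048])
print(inputs["pixel_values"].shape)  # torch.Size([5, 256, 768]) after image/patch padding
print(inputs["grid_thw"].shape)      # torch.Size([5, 256, 3])
print(labels.shape)                  # torch.Size([1, 2048])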
