-
Notifications
You must be signed in to change notification settings - Fork 495
VLM: Onboarding native resolution, native aspect ratio, interleaved VLM training #1615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
2c3609b
cad098a
f893c4a
fbca6dc
3b9860e
12027be
6cf9d67
0034eaa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Vision Language Model training in `torchtitan` | ||
|
||
**under active development** | ||
|
||
This folder showcases how to train modern Vision Language Model (vlm) in torchtitan. | ||
|
||
|
||
## Features: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part is mainly describing dataloader features. Can we separate into 2 parts in README: 1) model features (What model/encoder, What you have achieved now, eg FSDP, AC, compile, and TODOs); 2) dataloader features |
||
- Native Aspect Ratio: not limited to square crops. | ||
- Native Resolution: images in a batch can have different sizes, no more image tiles and thumbnails. | ||
- Native Interleaved data: training samples can have variable number of images, interleaved with text at different position. You can train more than just a captioning model. | ||
|
||
|
||
## Design | ||
Distributed training usually does not play nice with input of varying shapes. To handle a varying number of images and image sizes, we requires two hyperparameters, image batch size `N` and image length `L` (in patches), and pad the actual image patches to this fixed size. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think including the digram in the PR description will make this part much clear! We should figure out a way to include images in README. |
||
Then we scatter the patch embeddings to their actual positions in the LLM input tokens. | ||
This result in a very simple and general interface to train modern VLM with interleaved data and native resolution & aspect ratio. | ||
By setting the appropriate dataloader hyperparameters, we can easily reduce the amount of padding tokens. | ||
We leverage Flex Attention to efficiently handle varying number of patches per image. | ||
lkhphuc marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from torchtitan.components.loss import build_cross_entropy_loss | ||
from torchtitan.components.lr_scheduler import build_lr_schedulers | ||
from torchtitan.components.optimizer import build_optimizers | ||
from torchtitan.components.tokenizer import build_hf_tokenizer | ||
from torchtitan.components.validate import build_validator | ||
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec | ||
|
||
from .datasets.mm_datasets import build_mm_dataloader | ||
from .infra.parallelize import parallelize_vlm | ||
# from .infra.pipeline import pipeline_llama | ||
from .model.args import Llama3Siglip2ModelArgs, Siglip2ModelArgs | ||
from .model.model import Llama3Siglip2Transformer | ||
|
||
__all__ = [ | ||
"parallelize_vlm", | ||
# "pipeline_llama", | ||
"Llama3Siglip2ModelArgs", | ||
"Llama3Siglip2Transformer", | ||
"llama3_siglip2_configs", | ||
] | ||
|
||
|
||
siglip2_configs = { | ||
"debugmodel": Siglip2ModelArgs( | ||
dim=128, | ||
ffn_dim=256, | ||
n_layers=4, | ||
n_heads=2, | ||
) | ||
} | ||
|
||
llama3_siglip2_configs = { | ||
"debugmodel": Llama3Siglip2ModelArgs( | ||
encoder=siglip2_configs["debugmodel"], | ||
dim=256, n_layers=6, n_heads=16, vocab_size=2048, rope_theta=500000 | ||
), | ||
"debugmodel_flex_attn": Llama3Siglip2ModelArgs( | ||
encoder=siglip2_configs["debugmodel"], | ||
dim=256, | ||
n_layers=6, | ||
n_heads=16, | ||
vocab_size=2000, | ||
rope_theta=500000, | ||
use_flex_attn=True, | ||
attn_mask_type="block_causal", | ||
), | ||
"8B": Llama3Siglip2ModelArgs( | ||
encoder=siglip2_configs["debugmodel"], | ||
dim=4096, | ||
n_layers=32, | ||
n_heads=32, | ||
n_kv_heads=8, | ||
ffn_dim_multiplier=1.3, | ||
multiple_of=1024, | ||
rope_theta=500000, | ||
), | ||
"70B": Llama3Siglip2ModelArgs( | ||
encoder=siglip2_configs["debugmodel"], | ||
dim=8192, | ||
n_layers=80, | ||
n_heads=64, | ||
n_kv_heads=8, | ||
ffn_dim_multiplier=1.3, | ||
multiple_of=4096, | ||
rope_theta=500000, | ||
), | ||
"405B": Llama3Siglip2ModelArgs( | ||
encoder=siglip2_configs["debugmodel"], | ||
dim=16384, | ||
n_layers=126, | ||
n_heads=128, | ||
n_kv_heads=8, | ||
ffn_dim_multiplier=1.2, | ||
multiple_of=4096, | ||
rope_theta=500000, | ||
), | ||
} | ||
|
||
|
||
register_train_spec( | ||
TrainSpec( | ||
name="llama3-siglip2", | ||
model_cls=Llama3Siglip2Transformer, | ||
model_args=llama3_siglip2_configs, | ||
parallelize_fn=parallelize_vlm, | ||
pipelining_fn=None, | ||
build_optimizers_fn=build_optimizers, | ||
build_lr_schedulers_fn=build_lr_schedulers, | ||
build_dataloader_fn=build_mm_dataloader, | ||
build_tokenizer_fn=build_hf_tokenizer, | ||
build_loss_fn=build_cross_entropy_loss, | ||
build_validator_fn=build_validator, | ||
# state_dict_adapter=Llama3StateDictAdapter, | ||
) | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from dataclasses import dataclass | ||
from typing import Any, Dict, List | ||
|
||
import einops as E | ||
import torch | ||
from torch.nn.utils.rnn import pad_sequence | ||
|
||
from torchtitan.tools.logging import logger | ||
|
||
|
||
IGNORE_INDEX = -100 | ||
|
||
|
||
@dataclass | ||
class MultiModalCollatorNLD: | ||
"""Collator that works with patches in NLD format (N=batch, L=patches, D=patch_features)""" | ||
|
||
padding_idx: int = 0 | ||
ignore_idx: int = IGNORE_INDEX | ||
max_images_per_batch: int = 5 | ||
max_patch_per_image: int = 256 # Maximum patches per image | ||
patch_size: int = 16 # Patch size for converting images to patches | ||
merge_size: int = 1 # Merge size for converting spatial patches to channel dim | ||
seq_len: int = 2048 | ||
|
||
def convert_to_patches( | ||
self, pixel_values: torch.Tensor | ||
) -> tuple[torch.Tensor, torch.Tensor]: | ||
"""Direct NTHWC -> NLD conversion using einops.""" | ||
N, T, H, W, C = pixel_values.shape | ||
wwwjn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ps = self.patch_size | ||
device = pixel_values.device | ||
patches = E.rearrange( | ||
pixel_values, "n t (h p1) (w p2) c -> n (t h w) (p1 p2 c)", p1=ps, p2=ps | ||
) | ||
|
||
coords = torch.meshgrid( | ||
torch.arange(T, device=device), | ||
torch.arange(H // ps, device=device), | ||
torch.arange(W // ps, device=device), | ||
indexing="ij", | ||
) | ||
grid = E.rearrange(torch.stack(coords), "coords t h w -> (t h w) coords") | ||
grid = grid.unsqueeze(0).expand(N, -1, -1) # (N, t*h*w, 3) | ||
|
||
# All patches are valid since we resize images to be divisible by patch_size | ||
return patches, grid | ||
|
||
def _pad_to_max(self, patches, grids): | ||
"""Pad or truncate to max_patch_per_image.""" | ||
N, L, D = patches.shape | ||
if L == self.max_patch_per_image: | ||
return patches, grids | ||
elif L < self.max_patch_per_image: | ||
# Pad | ||
pad_len = self.max_patch_per_image - L | ||
zero_patches = torch.zeros(N, pad_len, D, device=patches.device) | ||
invalid_grids = torch.full( | ||
(grids.shape[0], pad_len, 3), -1, device=grids.device | ||
) | ||
return torch.cat([patches, zero_patches], 1), torch.cat( | ||
[grids, invalid_grids], 1 | ||
) | ||
else: | ||
# Truncate | ||
return ( | ||
wwwjn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
patches[:, : self.max_patch_per_image], | ||
grids[:, : self.max_patch_per_image], | ||
) | ||
|
||
def __call__( | ||
self, batch: List[Dict[str, Any]] | ||
) -> tuple[Dict[str, torch.Tensor | None], torch.Tensor]: | ||
"""Encode batch with patch-based approach.""" | ||
if not batch: | ||
return None | ||
|
||
# Count images per sample and total images | ||
images_per_sample = [] | ||
for sample in batch: | ||
num_images = ( | ||
len(sample.get("pixel_values", [])) if "pixel_values" in sample else 0 | ||
) | ||
images_per_sample.append(num_images) | ||
|
||
# Remove samples from end until total images <= max_images_per_batch | ||
wwwjn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
total_images = sum(images_per_sample) | ||
while total_images > self.max_images_per_batch and batch: | ||
removed_images = images_per_sample.pop() | ||
total_images -= removed_images | ||
batch.pop() | ||
logger.warning(f"Removed sample with {removed_images} images to keep total images <= {self.max_images_per_batch}") | ||
|
||
all_images = [ | ||
img | ||
for sample in batch | ||
if "pixel_values" in sample | ||
for img in sample["pixel_values"] | ||
] | ||
|
||
if all_images: | ||
patch_list, grid_list = [], [] | ||
for img in all_images: | ||
p, g = self.convert_to_patches(img.unsqueeze(0)) | ||
lkhphuc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
p, g = self._pad_to_max(p, g) | ||
patch_list.append(p[0]) | ||
grid_list.append(g[0]) | ||
patches = torch.stack(patch_list) | ||
grids = torch.stack(grid_list) | ||
|
||
if len(all_images) < self.max_images_per_batch: | ||
blank_count = self.max_images_per_batch - len(all_images) | ||
blank_patches = torch.zeros( | ||
blank_count, | ||
self.max_patch_per_image, | ||
patches.shape[2], | ||
device=patches.device, | ||
) | ||
blank_grids = torch.full( | ||
(blank_count, self.max_patch_per_image, 3), -1, device=grids.device | ||
) | ||
patches = torch.cat([patches, blank_patches], dim=0) | ||
grids = torch.cat([grids, blank_grids], dim=0) | ||
else: | ||
patches = grids = None | ||
|
||
# Text processing | ||
input_ids = pad_sequence( | ||
lkhphuc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
[s["input_ids"] for s in batch], | ||
batch_first=True, | ||
padding_value=self.padding_idx, | ||
) | ||
labels = pad_sequence( | ||
[s["labels"] for s in batch], | ||
batch_first=True, | ||
padding_value=self.padding_idx, | ||
) | ||
|
||
# Pad along batch dimension if needed | ||
batch_size = len(batch) | ||
if input_ids.size(0) < batch_size: | ||
padding_needed = batch_size - input_ids.size(0) | ||
padding_input = ( | ||
torch.ones(padding_needed, input_ids.size(1), dtype=torch.long) | ||
* self.padding_idx | ||
) | ||
padding_labels = ( | ||
torch.ones(padding_needed, labels.size(1), dtype=torch.long) | ||
* self.padding_idx | ||
) | ||
input_ids = torch.cat([input_ids, padding_input], dim=0) | ||
labels = torch.cat([labels, padding_labels], dim=0) | ||
|
||
# Handle sequence length | ||
current_length = input_ids.size(1) | ||
desired_length = self.seq_len + 1 # Extra token for label shift and cut | ||
if current_length < desired_length: | ||
padding_length = desired_length - current_length | ||
padding_input = ( | ||
torch.ones(batch_size, padding_length, dtype=torch.long) | ||
* self.padding_idx | ||
) | ||
padding_labels = ( | ||
torch.ones(batch_size, padding_length, dtype=torch.long) | ||
* self.padding_idx | ||
) | ||
input_ids = torch.cat([input_ids, padding_input], dim=1) | ||
labels = torch.cat([labels, padding_labels], dim=1) | ||
elif current_length > self.seq_len: | ||
input_ids = input_ids[:, :desired_length] | ||
labels = labels[:, :desired_length] | ||
|
||
labels[labels == self.padding_idx] = self.ignore_idx | ||
# Cut and shift | ||
input_ids = input_ids[:, :-1] | ||
labels = labels[:, 1:] | ||
|
||
return { | ||
"input": input_ids, | ||
"pixel_values": patches, | ||
"grid_thw": grids, | ||
}, labels |
Uh oh!
There was an error while loading. Please reload this page.