1 change: 1 addition & 0 deletions torchtitan/experiments/__init__.py
@@ -7,3 +7,4 @@
import torchtitan.experiments.llama4 # noqa: F401
import torchtitan.experiments.qwen3
import torchtitan.experiments.simple_fsdp # noqa: F401
import torchtitan.experiments.vlm # noqa: F401
58 changes: 58 additions & 0 deletions torchtitan/experiments/vlm/README.md
@@ -0,0 +1,58 @@
# Vision Language Model training in `torchtitan`

**under active development**

This folder showcases how to train modern Vision Language Models (VLMs) in `torchtitan`.


## Features:
> **Reviewer comment (Contributor):** This part mainly describes dataloader features. Can we separate it into 2 parts in the README: 1) model features (what model/encoder, what is achieved now, e.g. FSDP, AC, compile, and TODOs); 2) dataloader features?

- Native Aspect Ratio: not limited to square crops.
- Native Resolution: images in a batch can have different sizes, no more image tiles and thumbnails.
- Native Interleaved data: training samples can have a variable number of images, interleaved with text at different positions. You can train more than just a captioning model.


## Design
Distributed training usually does not play nicely with inputs of varying shapes. To handle a varying number of images and image sizes, we require two hyperparameters, image batch size `N` and image length `L` (in patches), and pad the actual image patches to this fixed size.
> **Reviewer comment (Contributor):** I think including the diagram from the PR description would make this part much clearer! We should figure out a way to include images in the README.

Then we scatter the patch embeddings to their actual positions in the LLM input tokens.

<img width="1398" height="840" alt="Screenshot 2025-08-21 at 16 21 57" src="https://github.com/user-attachments/assets/63fcbbc1-c587-4a63-8246-411cb72f5789" />

- After `tok_embedding`, we obtain tokens of shape `BxS`.
- After `encoder`, we obtain visual tokens of shape `NxL`.
- We extract only the valid visual tokens.
- Then we scatter those tokens to their actual positions in the LLM input tokens, as sketched below.
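
The scatter step can be illustrated with `masked_scatter` (a minimal sketch, not the exact code in this PR; the shapes, dummy token ids, and the 10-valid-patches-per-image setup are made up for the example, while `1998` matches `img_id` in `SpecialTokens` below):

```python
import torch

B, S, D = 2, 32, 256   # LLM batch size, sequence length, model dim
N, L = 4, 16           # image batch size, padded patches per image
IMG_ID = 1998          # <|image|> placeholder id (img_id in SpecialTokens)

h = torch.randn(B, S, D)                 # after tok_embedding: B x S x D
visual = torch.randn(N, L, D)            # after encoder:       N x L x D
valid = torch.zeros(N, L, dtype=torch.bool)
valid[:, :10] = True                     # pretend each image has 10 real patches

tokens = torch.full((B, S), 7)           # dummy text token ids
tokens[:, 2:22] = IMG_ID                 # 2 * 20 placeholders == valid.sum() == 40

image_pos = tokens == IMG_ID             # [B, S] placeholder positions
valid_visual = visual[valid]             # [num_valid, D], image padding dropped
assert int(image_pos.sum()) == valid_visual.shape[0]
# Fill the <|image|> positions, in order, with the valid visual tokens
# (assumes the dataloader emits images in the same order as their placeholders).
h = h.masked_scatter(image_pos.unsqueeze(-1), valid_visual)
```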


This results in a very simple and general interface for training modern VLMs with interleaved data and native resolution & aspect ratio:
- Depending on the data mixture, we can set the dataloader's hyperparameters `N, L` to minimize empty image padding (in the batch dimension).
- Use modern PyTorch features (FlexAttention, compile, etc.) to efficiently handle a different attention mask per image (padding in the sequence dimension); see the sketch after this list.
- Interfaces nicely with TP, PP, etc.
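
For the sequence-dimension padding in the encoder, a block mask built from per-image valid lengths could look like the following (a minimal sketch under assumed shapes, not this PR's encoder code; requires a recent PyTorch with FlexAttention, and the `valid_lens` tensor is hypothetical):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda" if torch.cuda.is_available() else "cpu"
N, L, H, D = 4, 256, 8, 64                                    # image batch, padded length, heads, head dim
valid_lens = torch.tensor([256, 100, 37, 8], device=device)   # real patch count per image

def padding_mask_mod(b, h, q_idx, kv_idx):
    # Each image attends only among its own valid (non-padded) patches.
    return (q_idx < valid_lens[b]) & (kv_idx < valid_lens[b])

block_mask = create_block_mask(padding_mask_mod, B=N, H=None, Q_LEN=L, KV_LEN=L, device=device)

q = k = v = torch.randn(N, H, L, D, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)  # [N, H, L, D]; padded positions are discarded downstream
```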


## Implementation

### Dataloader
This approach requires the dataloader to handle the following aspects:
- [x] Interleave the correct number of image tokens into the input tokens, based on the encoder's patch size and each input image's size.
- [x] Convert images/videos to a 1D sequence of patches:
  - `rearrange(pixels, 'n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)', pt=temporal_ps, ph=patch_size, pw=patch_size)`
  - Pad all image patch sequences to a fixed length and return `pixel_values.shape == [N, L, D]`
- [x] Return `grid_thw.shape == [N, L, 3]` to keep track of the location indices of each patch in its image. Padding patches can be tracked in the same tensor with the value `-1` (see the sketch after this list).
- [x] LLM Sample / Document Packing.
- [x] Captioning dataset: CC12M
- [x] Interleaved dataset: Obelics
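
A minimal sketch of the patchify-pad-and-track step for a single image (illustrative only, not the PR's dataloader; assumes `einops` and a 224x160 image with `temporal_ps = 1`):

```python
import torch
from einops import rearrange

patch_size, temporal_ps, L = 16, 1, 256
pixels = torch.randn(1, temporal_ps, 224, 160, 3)   # n t h w c, native resolution and aspect ratio

patches = rearrange(
    pixels,
    "n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)",
    pt=temporal_ps, ph=patch_size, pw=patch_size,
)                                                    # [1, 140, 768]

# (t, h, w) index of every patch, then pad both tensors to length L.
t, h, w = 1, 224 // patch_size, 160 // patch_size    # 1 x 14 x 10 = 140 patches
coords = torch.stack(torch.meshgrid(
    torch.arange(t), torch.arange(h), torch.arange(w), indexing="ij"
), dim=-1).reshape(-1, 3)                            # [140, 3]

num_pad = L - patches.shape[1]
pixel_values = torch.cat([patches, patches.new_zeros(1, num_pad, patches.shape[-1])], dim=1)  # [1, L, 768]
grid_thw = torch.cat([coords, coords.new_full((num_pad, 3), -1)]).unsqueeze(0)                # [1, L, 3], -1 = padding
```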



### Model
We also need a pretrained vision encoder that supports native resolution and aspect ratio. Until recently relatively few vision encoders had this capability; examples include Siglip2, AimV2, and most recently DINOv3.
- [ ] Currently we support the Siglip2 encoder using a positional embedding interpolation approach (see the sketch after this list).
  - [x] Base modelling code.
  - [ ] Weight conversion and loading from HF.
- [x] FSDP for both Encoder and Decoder
- [x] Context Parallel for LLM only, since we will use FlexAttention for Encoder.
- [ ] FlexAttention with a different sequence length per image.
- [ ] Compile for Encoder + Decoder
- [ ] Tensor Parallel
- [ ] Pipeline Parallel
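
The positional embedding interpolation approach can be sketched as follows (an illustration only; `resize_pos_embed` is a hypothetical helper, not the Siglip2 module in this PR):

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Resize [P*P, D] embeddings learned on a square P x P grid to an h x w grid."""
    P = int(pos_embed.shape[0] ** 0.5)
    D = pos_embed.shape[-1]
    grid = pos_embed.reshape(1, P, P, D).permute(0, 3, 1, 2)          # [1, D, P, P]
    grid = F.interpolate(grid, size=(h, w), mode="bilinear", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(h * w, D)                 # [h*w, D]

pos = torch.randn(16 * 16, 768)           # pretrained 16x16 grid
native = resize_pos_embed(pos, 14, 10)    # e.g. a 224x160 image with patch size 16
```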

101 changes: 101 additions & 0 deletions torchtitan/experiments/vlm/__init__.py
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .datasets.mm_datasets import build_mm_dataloader
from .infra.parallelize import parallelize_vlm
# from .infra.pipeline import pipeline_llama
from .model.args import Llama3Siglip2ModelArgs, Siglip2ModelArgs
from .model.model import Llama3Siglip2Transformer

__all__ = [
    "parallelize_vlm",
    # "pipeline_llama",
    "Llama3Siglip2ModelArgs",
    "Llama3Siglip2Transformer",
    "llama3_siglip2_configs",
]


siglip2_configs = {
    "debugmodel": Siglip2ModelArgs(
        dim=128,
        ffn_dim=256,
        n_layers=4,
        n_heads=2,
    )
}

llama3_siglip2_configs = {
    "debugmodel": Llama3Siglip2ModelArgs(
        encoder=siglip2_configs["debugmodel"],
        dim=256, n_layers=6, n_heads=16, vocab_size=2048, rope_theta=500000
    ),
    "debugmodel_flex_attn": Llama3Siglip2ModelArgs(
        encoder=siglip2_configs["debugmodel"],
        dim=256,
        n_layers=6,
        n_heads=16,
        vocab_size=2000,
        rope_theta=500000,
        use_flex_attn=True,
        attn_mask_type="block_causal",
    ),
    "8B": Llama3Siglip2ModelArgs(
        encoder=siglip2_configs["debugmodel"],
        dim=4096,
        n_layers=32,
        n_heads=32,
        n_kv_heads=8,
        ffn_dim_multiplier=1.3,
        multiple_of=1024,
        rope_theta=500000,
    ),
    "70B": Llama3Siglip2ModelArgs(
        encoder=siglip2_configs["debugmodel"],
        dim=8192,
        n_layers=80,
        n_heads=64,
        n_kv_heads=8,
        ffn_dim_multiplier=1.3,
        multiple_of=4096,
        rope_theta=500000,
    ),
    "405B": Llama3Siglip2ModelArgs(
        encoder=siglip2_configs["debugmodel"],
        dim=16384,
        n_layers=126,
        n_heads=128,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=4096,
        rope_theta=500000,
    ),
}


register_train_spec(
    TrainSpec(
        name="llama3-siglip2",
        model_cls=Llama3Siglip2Transformer,
        model_args=llama3_siglip2_configs,
        parallelize_fn=parallelize_vlm,
        pipelining_fn=None,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_mm_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        build_validator_fn=build_validator,
        # state_dict_adapter=Llama3StateDictAdapter,
    )
)
48 changes: 48 additions & 0 deletions torchtitan/experiments/vlm/assets/job_config.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field


@dataclass
class Training:
    max_images_per_batch: int = 10
    """Vision encoder batch size (N)"""
    max_patches_per_image: int = 256
    """Vision encoder sequence length (L)"""
    patch_size: int = 16
    """Patch size of the vision encoder.
    For example, with image size 256x256 and patch size 16,
    the number of visual tokens is (256/16)**2 = 256.
    """
    spatial_merge_size: int = 1
    """Spatially merge visual tokens after the encoder.
    For example, with image size 256x256, patch size 16, and spatial merge size 2,
    the number of visual tokens for the LLM is (256/16/2)**2 = 64.
    """
    packing_buffer_size: int = 0
    """Set to a value > 0 to enable sample packing.
    This controls the size of the buffer used to store training samples available for packing.
    """


# HACK: couldn't figure out how to modify the HF tokenizer's json
# to make these attributes accessible. Ideally these should be accessible from the tokenizer itself.
@dataclass
class SpecialTokens:
    img_token: str = "<|image|>"
    boi_token: str = "<|begin_of_image|>"
    eoi_token: str = "<|end_of_image|>"
    img_id: int = 1998
    boi_id: int = 1999
    eoi_id: int = 2000
    pad_id: int = 2001


@dataclass
class JobConfig:
    training: Training = field(default_factory=Training)
    special_tokens: SpecialTokens = field(default_factory=SpecialTokens)
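
A quick illustrative check (not part of this file) of how `patch_size` and `spatial_merge_size` determine the number of visual tokens the LLM sees per image:

```python
def llm_visual_tokens(h: int, w: int, patch_size: int, spatial_merge_size: int) -> int:
    # Patches per side after the encoder, further reduced by spatial merging.
    return (h // patch_size // spatial_merge_size) * (w // patch_size // spatial_merge_size)

assert llm_visual_tokens(256, 256, patch_size=16, spatial_merge_size=1) == 256
assert llm_visual_tokens(256, 256, patch_size=16, spatial_merge_size=2) == 64
```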