Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchtitan/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
import torchtitan.experiments.llama4 # noqa: F401
import torchtitan.experiments.qwen3
import torchtitan.experiments.simple_fsdp # noqa: F401
import torchtitan.experiments.vlm # noqa: F401
19 changes: 19 additions & 0 deletions torchtitan/experiments/vlm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Vision Language Model training in `torchtitan`

**under active development**

This folder showcases how to train modern Vision Language Model (vlm) in torchtitan.


## Features:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is mainly describing dataloader features. Can we separate into 2 parts in README: 1) model features (What model/encoder, What you have achieved now, eg FSDP, AC, compile, and TODOs); 2) dataloader features

- Native Aspect Ratio: not limited to square crops.
- Native Resolution: images in a batch can have different sizes, no more image tiles and thumbnails.
- Native Interleaved data: training samples can have variable number of images, interleaved with text at different position. You can train more than just a captioning model.


## Design
Distributed training usually does not play nice with input of varying shapes. To handle a varying number of images and image sizes, we requires two hyperparameters, image batch size `N` and image length `L` (in patches), and pad the actual image patches to this fixed size.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think including the digram in the PR description will make this part much clear! We should figure out a way to include images in README.

Then we scatter the patch embeddings to their actual positions in the LLM input tokens.
This result in a very simple and general interface to train modern VLM with interleaved data and native resolution & aspect ratio.
By setting the appropriate dataloader hyperparameters, we can easily reduce the amount of padding tokens.
We leverage FlexAttention to efficiently handle varying number of patches per image.
101 changes: 101 additions & 0 deletions torchtitan/experiments/vlm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .datasets.mm_datasets import build_mm_dataloader
from .infra.parallelize import parallelize_vlm
# from .infra.pipeline import pipeline_llama
from .model.args import Llama3Siglip2ModelArgs, Siglip2ModelArgs
from .model.model import Llama3Siglip2Transformer

__all__ = [
"parallelize_vlm",
# "pipeline_llama",
"Llama3Siglip2ModelArgs",
"Llama3Siglip2Transformer",
"llama3_siglip2_configs",
]


siglip2_configs = {
"debugmodel": Siglip2ModelArgs(
dim=128,
ffn_dim=256,
n_layers=4,
n_heads=2,
)
}

llama3_siglip2_configs = {
"debugmodel": Llama3Siglip2ModelArgs(
encoder=siglip2_configs["debugmodel"],
dim=256, n_layers=6, n_heads=16, vocab_size=2048, rope_theta=500000
),
"debugmodel_flex_attn": Llama3Siglip2ModelArgs(
encoder=siglip2_configs["debugmodel"],
dim=256,
n_layers=6,
n_heads=16,
vocab_size=2000,
rope_theta=500000,
use_flex_attn=True,
attn_mask_type="block_causal",
),
"8B": Llama3Siglip2ModelArgs(
encoder=siglip2_configs["debugmodel"],
dim=4096,
n_layers=32,
n_heads=32,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
multiple_of=1024,
rope_theta=500000,
),
"70B": Llama3Siglip2ModelArgs(
encoder=siglip2_configs["debugmodel"],
dim=8192,
n_layers=80,
n_heads=64,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
multiple_of=4096,
rope_theta=500000,
),
"405B": Llama3Siglip2ModelArgs(
encoder=siglip2_configs["debugmodel"],
dim=16384,
n_layers=126,
n_heads=128,
n_kv_heads=8,
ffn_dim_multiplier=1.2,
multiple_of=4096,
rope_theta=500000,
),
}


register_train_spec(
TrainSpec(
name="llama3-siglip2",
model_cls=Llama3Siglip2Transformer,
model_args=llama3_siglip2_configs,
parallelize_fn=parallelize_vlm,
pipelining_fn=None,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_mm_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
build_validator_fn=build_validator,
# state_dict_adapter=Llama3StateDictAdapter,
)
)
48 changes: 48 additions & 0 deletions torchtitan/experiments/vlm/assets/job_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field


@dataclass
class Training:
max_images_per_batch: int = 10
"""Vision encoder batch size (N)"""
max_patches_per_image: int = 256
"""Vision encoder sequence length (L)"""
patch_size: int = 16
""" Patch size of the vision encoder.
For example, image size 256x256, patch size 16
Number of visual tokens is: (256/16)**2=256
"""
spatial_merge_size: int = 1
""" Spatially merge visual tokens after encoder.
For example: image size 256x256, patch size 16, spaitl merge size is 2
Number of visual tokens for the LLM: (256/16/2)**2 = 8
"""
packing_buffer_size: int = 0
""" Set to a value >0 to enable sample packing.
This control the buffer uses to store training samples avaliable for packing.
"""


# HACK: couldn't figure out how to modify the HF tokenizer's json
# to make these attribute accesible. Ideally these should be accesible from the tokenizer itself.
@dataclass
class SpecialTokens:
img_token: str = "<|image|>"
boi_token: str = "<|begin_of_image|>"
eoi_token: str = "<|end_of_image|>"
img_id: int = 1998
boi_id: int = 1999
eoi_id: int = 2000
pad_id: int = 2001


@dataclass
class JobConfig:
training: Training = field(default_factory=Training)
special_tokens: SpecialTokens = field(default_factory=SpecialTokens)
Loading