Commit f893c4a

[vlm] infra to get debugmodel running

1 parent: cad098a

File tree: 9 files changed, +592 −3 lines

tests/assets/tokenizer/tokenizer.json

Lines changed: 4 additions & 1 deletion

```diff
@@ -2029,7 +2029,10 @@
     "land": 1994,
     "?\n": 1995,
     " respect": 1996,
-    "ances": 1997
+    "ances": 1997,
+    "<|image|>": 1998,
+    "<|begin_of_image|>": 1999,
+    "<|end_of_image|>": 2000
   },
   "merges": [
   ]
```

tests/assets/tokenizer/tokenizer_config.json

Lines changed: 27 additions & 0 deletions

```diff
@@ -15,11 +15,38 @@
       "rstrip": false,
       "single_word": false,
       "special": true
+    },
+    "1998": {
+      "content": "<|image|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "1999": {
+      "content": "<|begin_of_image|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "2000": {
+      "content": "<|end_of_image|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
     }
   },
   "bos_token": "<|begin_of_text|>",
   "clean_up_tokenization_spaces": true,
   "eos_token": "<|end_of_text|>",
+  "img_token": "<|image|>",
+  "boi_token": "<|begin_of_image|>",
+  "eoi_token": "<|end_of_image|>",
   "model_input_names": [
     "input_ids",
     "attention_mask"
```

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -7,3 +7,4 @@
 import torchtitan.experiments.llama4  # noqa: F401
 import torchtitan.experiments.qwen3
 import torchtitan.experiments.simple_fsdp  # noqa: F401
+import torchtitan.experiments.vlm  # noqa: F401
```

torchtitan/experiments/vlm/README.md

Lines changed: 19 additions & 0 deletions

```diff
@@ -0,0 +1,19 @@
+# Vision Language Model training in `torchtitan`
+
+**under active development**
+
+This folder showcases how to train a modern Vision Language Model (VLM) in torchtitan.
+
+
+## Features:
+- Native Aspect Ratio: not limited to square crops.
+- Native Resolution: images in a batch can have different sizes; no more image tiles and thumbnails.
+- Native Interleaved data: training samples can have a variable number of images, interleaved with text at arbitrary positions. You can train more than just a captioning model.
+
+
+## Design
+Distributed training usually does not play nicely with inputs of varying shapes. To handle a varying number of images and image sizes, we require two hyperparameters, the image batch size `N` and the image length `L` (in patches), and pad the actual image patches to this fixed size.
+Then we scatter the patch embeddings to their actual positions in the LLM input tokens.
+This results in a very simple and general interface for training modern VLMs with interleaved data and native resolution & aspect ratio.
+By setting the appropriate dataloader hyperparameters, we can easily reduce the number of padding tokens.
+We leverage Flex Attention to efficiently handle a varying number of patches per image.
```
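The pad-and-scatter design described in this README can be sketched in a few lines. This is a minimal illustration with hypothetical names and shapes, not the actual torchtitan implementation: patch embeddings are padded to a fixed `N * L` rows per sample, and the real (non-padding) rows are written into the LLM sequence at the `<|image|>` placeholder positions.

```python
import torch

def scatter_patch_embeddings(
    token_embeds: torch.Tensor,      # (B, S, D): LLM input token embeddings
    patch_embeds: torch.Tensor,      # (B, N * L, D): vision encoder output, padded
    image_token_mask: torch.Tensor,  # (B, S) bool: True at <|image|> positions
) -> torch.Tensor:
    """Write each real patch embedding into its placeholder slot in the sequence."""
    out = token_embeds.clone()
    for b in range(token_embeds.size(0)):
        slots = image_token_mask[b].nonzero(as_tuple=True)[0]
        # Assumes the dataloader packs real patches first and padding last,
        # so the first `slots.numel()` rows are the real patch embeddings.
        out[b, slots] = patch_embeds[b, : slots.numel()]
    return out

# Toy usage: batch of 2 sequences, each with 8 image placeholder tokens.
B, S, D, N, L = 2, 16, 64, 1, 8
mask = torch.zeros(B, S, dtype=torch.bool)
mask[:, 4:12] = True
fused = scatter_patch_embeddings(torch.randn(B, S, D), torch.randn(B, N * L, D), mask)
```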
torchtitan/experiments/vlm/__init__.py

Lines changed: 101 additions & 0 deletions
```diff
@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.lr_scheduler import build_lr_schedulers
+from torchtitan.components.optimizer import build_optimizers
+from torchtitan.components.tokenizer import build_hf_tokenizer
+from torchtitan.components.validate import build_validator
+from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
+
+from .datasets.mm_datasets import build_mm_dataloader
+from .infra.parallelize import parallelize_vlm
+# from .infra.pipeline import pipeline_llama
+from .model.args import Llama3Siglip2ModelArgs, Siglip2ModelArgs
+from .model.model import Llama3Siglip2Transformer
+
+__all__ = [
+    "parallelize_vlm",
+    # "pipeline_llama",
+    "Llama3Siglip2ModelArgs",
+    "Llama3Siglip2Transformer",
+    "llama3_siglip2_configs",
+]
+
+
+siglip2_configs = {
+    "debugmodel": Siglip2ModelArgs(
+        dim=128,
+        ffn_dim=256,
+        n_layers=4,
+        n_heads=2,
+    )
+}
+
+llama3_siglip2_configs = {
+    "debugmodel": Llama3Siglip2ModelArgs(
+        encoder=siglip2_configs["debugmodel"],
+        dim=256, n_layers=6, n_heads=16, vocab_size=2048, rope_theta=500000
+    ),
+    "debugmodel_flex_attn": Llama3Siglip2ModelArgs(
+        encoder=siglip2_configs["debugmodel"],
+        dim=256,
+        n_layers=6,
+        n_heads=16,
+        vocab_size=2000,
+        rope_theta=500000,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
+    ),
+    "8B": Llama3Siglip2ModelArgs(
+        encoder=siglip2_configs["debugmodel"],
+        dim=4096,
+        n_layers=32,
+        n_heads=32,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.3,
+        multiple_of=1024,
+        rope_theta=500000,
+    ),
+    "70B": Llama3Siglip2ModelArgs(
+        encoder=siglip2_configs["debugmodel"],
+        dim=8192,
+        n_layers=80,
+        n_heads=64,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.3,
+        multiple_of=4096,
+        rope_theta=500000,
+    ),
+    "405B": Llama3Siglip2ModelArgs(
+        encoder=siglip2_configs["debugmodel"],
+        dim=16384,
+        n_layers=126,
+        n_heads=128,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.2,
+        multiple_of=4096,
+        rope_theta=500000,
+    ),
+}
+
+
+register_train_spec(
+    TrainSpec(
+        name="llama3-siglip2",
+        model_cls=Llama3Siglip2Transformer,
+        model_args=llama3_siglip2_configs,
+        parallelize_fn=parallelize_vlm,
+        pipelining_fn=None,
+        build_optimizers_fn=build_optimizers,
+        build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_mm_dataloader,
+        build_tokenizer_fn=build_hf_tokenizer,
+        build_loss_fn=build_cross_entropy_loss,
+        build_validator_fn=build_validator,
+        # state_dict_adapter=Llama3StateDictAdapter,
+    )
+)
```
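Once registered, the spec is addressable by name. A hedged sketch of how the debug model would then be selected, assuming the `get_train_spec` lookup that mirrors `register_train_spec` in `torchtitan.protocols.train_spec` and torchtitan's usual name/flavor selection:

```python
from torchtitan.protocols.train_spec import get_train_spec

# Equivalent to a job config selecting name = "llama3-siglip2",
# flavor = "debugmodel" in its [model] section.
spec = get_train_spec("llama3-siglip2")
model_args = spec.model_args["debugmodel"]
model = spec.model_cls(model_args)  # assumes the model is constructed from its args
```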
