`torchtitan/experiments/vlm/README.md`
This folder showcases how to train a modern Vision Language Model (VLM) in torchtitan.
## Design
Distributed training usually does not play nicely with inputs of varying shapes. To handle a varying number of images and image sizes, we require two hyperparameters, image batch size `N` and image length `L` (in patches), and pad the actual image patches to this fixed size.
Then we scatter the patch embeddings to their actual positions in the LLM input tokens.
<img width="1398" height="840" alt="Screenshot 2025-08-21 at 16 21 57" src="https://github.com/user-attachments/assets/63fcbbc1-c587-4a63-8246-411cb72f5789" />
- After `tok_embedding`, we obtain tokens of shape `BxS`.
- After `encoder`, we obtain visual tokens of shape `NxL`.
- We extract only the valid visual tokens.
- Then we scatter those tokens to their actual positions in the LLM input tokens (see the sketch below).
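A minimal sketch of this extract-and-scatter step, assuming a placeholder token id marks image-patch positions in the LLM input and a boolean `valid` mask marks the real (non-padding) patches; the names and shapes are illustrative, not the actual torchtitan code:

```python
import torch

B, S, D = 2, 16, 64       # LLM batch size, sequence length, embedding dim (illustrative)
N, L = 3, 8               # image batch size N and image length L in patches
IMG_TOKEN_ID = 32_000     # hypothetical placeholder token id for image patches

input_ids = torch.randint(0, 1000, (B, S))
input_ids[0, 2:10] = IMG_TOKEN_ID             # sample 0 contains one 8-patch image
tokens = torch.randn(B, S, D)                 # after tok_embedding: B x S x D
visual_tokens = torch.randn(N, L, D)          # after encoder, padded to N x L x D
valid = torch.zeros(N, L, dtype=torch.bool)
valid[0, :8] = True                           # only 8 of the N*L patch slots are real

# Extract the valid visual tokens and scatter them, in order, into the
# image-token positions, so the LLM always sees a dense B x S x D tensor.
image_pos = (input_ids == IMG_TOKEN_ID).unsqueeze(-1)   # B x S x 1 boolean mask
tokens = tokens.masked_scatter(image_pos, visual_tokens[valid])
```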
This results in a very simple and general interface to train modern VLMs with interleaved data and native resolution & aspect ratio:
- Depending on the data mixture, we can set the dataloader's hyperparameters `N, L` so that there is minimal empty image padding (in the batch dimension).
- Use modern PyTorch features (FlexAttention, compile, etc.) to efficiently handle a different attention mask per image (padding in the sequence dimension).
- Interface nicely with TP, PP, etc.
## Implementation
### Dataloader
This approach requires the dataloader to handle the following aspects:
- [x] Interleave the precise number of image tokens into the input tokens, based on the encoder's patch size and the input images' sizes.
- [x] Convert images/videos to a 1D sequence of patches (see the sketch after this checklist):
  - `rearrange(pixels, 'n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)', pt=temporal_ps, ph=patch_size, pw=patch_size)`
  - Pad all image patch sequences to a fixed length and return `pixel_values.shape == [N, L, D]`.
- [x] Return `grid_thw.shape == [N, L, 3]` to keep track of the location indices of each patch in its image. Padding images can be tracked in the same tensor with the value `-1`.
- [x] LLM Sample / Document Packing.
- [x] Captioning dataset: CC12M
- [x] Interleaved dataset: Obelics
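A rough sketch of the patchify-and-pad path under the conventions above, assuming `einops` is available; `patchify`, `collate_images`, and the zero-padding of `pixel_values` are illustrative assumptions, not the actual dataloader code:

```python
import torch
from einops import rearrange

patch_size, temporal_ps = 16, 1   # spatial / temporal patch sizes (illustrative)
N, L = 4, 256                     # image batch size N and image length L in patches

def patchify(pixels: torch.Tensor) -> torch.Tensor:
    # pixels: [n, T, H, W, C] -> [n, num_patches, patch_dim]
    return rearrange(
        pixels,
        "n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)",
        pt=temporal_ps, ph=patch_size, pw=patch_size,
    )

def collate_images(images: list[torch.Tensor], grids: list[torch.Tensor]):
    # images: per-image [1, T, H, W, C] pixel tensors of varying sizes
    # grids:  per-image [num_patches, 3] (t, h, w) patch coordinates
    D = images[0].shape[-1] * temporal_ps * patch_size * patch_size
    pixel_values = torch.zeros(N, L, D)
    grid_thw = torch.full((N, L, 3), -1, dtype=torch.long)  # -1 marks padding patches
    for i, (img, grid) in enumerate(zip(images, grids)):
        patches = patchify(img)[0]            # [num_patches, D]
        n = min(patches.shape[0], L)
        pixel_values[i, :n] = patches[:n]
        grid_thw[i, :n] = grid[:n]
    return pixel_values, grid_thw             # [N, L, D] and [N, L, 3]
```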
### Model
We also need a pretrained vision encoder that supports native resolution and aspect ratio. Until recently, relatively few vision encoders had this capability; examples include Siglip2, AimV2, and most recently DINOv3.
- [ ] Currently we support the Siglip2 encoder using the positional-embedding interpolation approach.
  - [x] Base modelling code.
  - [ ] Weights conversion and loading from HF.
- [x] FSDP for both encoder and decoder.
- [x] Context Parallel for the LLM only, since we will use FlexAttention for the encoder.
- [ ] FlexAttention with a different sequence length per image (see the sketch below).
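The last item could be handled with a FlexAttention `mask_mod` that restricts attention to patches of the same image. Below is a rough sketch, assuming a CUDA device and per-patch image ids for a flattened `N * L` encoder sequence (with `-1` for padding); this is an illustration, not the final implementation:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda"                      # assumes a CUDA device is available
N, L = 4, 256                        # image batch size N, image length L in patches
# Per-patch image ids for the flattened N*L sequence; -1 marks padding patches.
image_ids = torch.arange(N, device=device).repeat_interleave(L)
image_ids[3 * L + 100:] = -1         # e.g. the last image only has 100 real patches

def image_mask(b, h, q_idx, kv_idx):
    # A patch may only attend to patches of the same, non-padding image.
    return (image_ids[q_idx] == image_ids[kv_idx]) & (image_ids[q_idx] >= 0)

# One block mask shared across batch and heads (B=None, H=None).
block_mask = create_block_mask(
    image_mask, B=None, H=None, Q_LEN=N * L, KV_LEN=N * L, device=device
)
# The encoder's attention then runs as:
#   out = flex_attention(q, k, v, block_mask=block_mask)
```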
0 commit comments