Commit cad098a

[vlm] model code w/ siglip2 encoder
tmp
1 parent 2c3609b commit cad098a

3 files changed: +355 −0 lines changed
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from torchtitan.models.llama3 import TransformerModelArgs as Llama3Args


@dataclass
class Siglip2ModelArgs:
    dim: int = 768
    ffn_dim: int = 3072
    n_layers: int = 12
    n_heads: int = 12

    n_pos_embs: int = 16  # Number of positional embeddings per side (height and width)
    n_channels: int = 3  # RGB channels
    patch_size: int = 16

    layer_norm_eps: float = 1e-6
    use_flex_attn: bool = True
    attn_mask_type: str = "causal"


@dataclass
class Llama3Siglip2ModelArgs(Llama3Args):
    encoder: Siglip2ModelArgs = field(default_factory=Siglip2ModelArgs)
    img_token_id: int = 1998
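
Editor's note: a minimal sketch of how these dataclasses compose; the overridden values below are illustrative, not defaults from this commit.

    # Illustrative composition sketch (not part of this commit).
    encoder_args = Siglip2ModelArgs(dim=768, n_layers=12, patch_size=16)
    model_args = Llama3Siglip2ModelArgs(
        dim=4096,                  # LLM hidden size (hypothetical value)
        encoder=encoder_args,
        img_token_id=1998,         # placeholder token id for image patches
    )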
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import einops as E
import torch
from torch import nn

from torchtitan.models.attention import init_attention_mask
from torchtitan.models.llama3 import Transformer as Llama3

from .args import Llama3Siglip2ModelArgs
from .siglip2 import VisionTransformer


class Projector(nn.Module):
    """Project the Encoder embedding to the LLM embedding."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(in_dim, in_dim)
        self.w2 = nn.Linear(in_dim, out_dim)
        self.init_weights()

    def forward(self, x_NLD: torch.Tensor):
        x_NLD = self.w1(x_NLD)
        x_NLD = nn.functional.silu(x_NLD)
        x_NLD = self.w2(x_NLD)
        return x_NLD

    def init_weights(self):
        nn.init.xavier_uniform_(self.w1.weight)
        if self.w1.bias is not None:
            nn.init.zeros_(self.w1.bias)
        nn.init.xavier_uniform_(self.w2.weight)
        if self.w2.bias is not None:
            nn.init.zeros_(self.w2.bias)

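Editor's note: a tiny shape check for the projector; the dimensions below are hypothetical.

    # Hypothetical dimensions: SigLIP2 width 768 -> LLM width 4096.
    import torch

    proj = Projector(in_dim=768, out_dim=4096)
    i_NLD = torch.randn(2, 196, 768)     # (num_images, patches_per_image, encoder_dim)
    print(proj(i_NLD).shape)             # torch.Size([2, 196, 4096])
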
class Llama3Siglip2Transformer(Llama3):
    def __init__(self, model_args: Llama3Siglip2ModelArgs):
        super().__init__(model_args)
        self.model_args = model_args
        self.encoder = VisionTransformer(model_args.encoder)
        self.projector = Projector(
            in_dim=model_args.encoder.dim, out_dim=model_args.dim
        )
        self.n_pixels_per_token = model_args.encoder.patch_size**2
        self.init_encoder_weights()

    def init_encoder_weights(self, buffer_device=None):
        super().init_weights(buffer_device=buffer_device)
        if self.encoder is not None:
            self.encoder.init_weights()
        if self.projector is not None:
            self.projector.init_weights()

    def _scatter_img_tokens(self, h_BSD, tokens_BS, i_NLD, i_mask_NL, img_id=None):
        img_id = img_id if img_id is not None else self.model_args.img_token_id
        B, S, D = h_BSD.shape
        # Locate the image placeholder tokens in the LLM input; make the mask
        # broadcastable with h_BSD
        img_mask_h_BSD = E.repeat(tokens_BS == img_id, "b s -> b s 1")
        # Keep only valid (non-padded) visual tokens; the result is flattened
        i_flatten = torch.masked_select(i_NLD, mask=i_mask_NL.unsqueeze(-1))

        assert i_flatten.numel() // D == img_mask_h_BSD.sum(), (
            f"Number of visual embeddings {i_flatten.numel() // D} does not match "
            f"number of image placeholders in the input tokens {img_mask_h_BSD.sum()}"
        )
        h_BSD.masked_scatter_(mask=img_mask_h_BSD, source=i_flatten)
        return h_BSD

    def forward(
        self,
        tokens: torch.Tensor,
        eos_id: int | None = None,
        input_batch: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        grid_thw: torch.Tensor | None = None,
    ):
        if self.model_args.use_flex_attn:
            init_attention_mask(
                input_batch if input_batch is not None else tokens, eos_id=self.eos_id
            )

        # Passthrough for nonexistent layers; allows easy configuration of
        # pipeline-parallel stages
        h_BSD = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        if self.encoder is not None:
            grid_hw = grid_thw[:, :, 1:]  # Siglip2 only supports image (h, w)
            pixel_masks = E.reduce(grid_hw != -1, "n l hw -> n l", reduction="all")
            i_NLD = self.encoder(pixel_values, pixel_masks, grid_hw)
            i_NLD = self.projector(i_NLD)
            h_BSD = self._scatter_img_tokens(h_BSD, tokens, i_NLD, pixel_masks)

        for layer in self.layers.values():
            h_BSD = layer(h_BSD, self.freqs_cis)

        h_BSD = self.norm(h_BSD) if self.norm else h_BSD
        output = self.output(h_BSD) if self.output else h_BSD
        return output
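
Editor's note: to make the scatter step concrete, a standalone sketch of the masked_scatter_ mechanism used in _scatter_img_tokens; shapes and the placeholder id are illustrative, not from this commit.

    # Standalone sketch of the masked_scatter_ step (illustrative shapes and ids).
    import torch

    img_token_id = 1998
    tokens_BS = torch.tensor([[5, 1998, 1998, 7, 9]])     # (B=1, S=5)
    h_BSD = torch.zeros(1, 5, 4)                           # LLM token embeddings, D=4
    i_flatten = torch.arange(2 * 4, dtype=torch.float32)   # two visual embeddings, flattened

    img_mask = (tokens_BS == img_token_id).unsqueeze(-1)   # (B, S, 1), broadcasts over D
    h_BSD.masked_scatter_(img_mask, i_flatten)             # rows 1 and 2 now hold the visual embeddings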
Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import einops as E
import torch
import torch.nn.functional as F
from torch import nn

from torchtitan.models.attention import build_attention, init_attention_mask

from .args import Siglip2ModelArgs


def resize_positional_embeddings(
    pos_embs_HWD: torch.Tensor,
    spatial_shapes_N2: torch.Tensor,
    max_length: int,
) -> torch.Tensor:
    """
    Resize the learned 2D positional embeddings to image-specific sizes and pad to a fixed length.

    Args:
        pos_embs_HWD (`torch.Tensor`):
            Position embeddings of shape (height, width, embed_dim)
        spatial_shapes_N2 (`torch.LongTensor`):
            Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
        max_length (`int`):
            Length to which the resized positional embeddings are padded

    Returns:
        `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
    """
    _, _, D = pos_embs_HWD.shape
    B, _ = spatial_shapes_N2.shape

    resized_embs_BLD = torch.empty(
        (B, max_length, D),
        device=pos_embs_HWD.device,
        dtype=pos_embs_HWD.dtype,
    )

    # TODO: group images by size and interpolate per group,
    # or cache the interpolation output so we do this once per size
    for i in range(B):
        height, width = spatial_shapes_N2[i].tolist()
        if (height + width) == 0:  # Skip empty padding images
            continue

        resized_emb = F.interpolate(
            E.rearrange(pos_embs_HWD, "h w d -> 1 d h w"),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )

        resized_emb_LD = E.rearrange(resized_emb, "1 d h w -> (h w) d")
        resized_embs_BLD[i, : int(height * width)] = resized_emb_LD

    return resized_embs_BLD

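Editor's note: a shape-only sanity check of the helper above; the values are illustrative.

    # Shape-only sanity check (illustrative values).
    import torch

    pos_embs_HWD = torch.randn(16, 16, 768)              # learned 16x16 positional grid
    spatial_shapes_N2 = torch.tensor([[8, 12], [0, 0]])  # one 8x12 image, one padding entry
    out = resize_positional_embeddings(pos_embs_HWD, spatial_shapes_N2, max_length=256)
    print(out.shape)                                     # torch.Size([2, 256, 768])
    # Positions past h*w (and the padding entry) are left uninitialized by torch.empty.
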
class VisionEmbeddings(nn.Module):
    def __init__(self, args: Siglip2ModelArgs):
        super().__init__()
        self.patch_embedding = nn.Linear(
            in_features=args.n_channels * args.patch_size * args.patch_size,
            out_features=args.dim,
        )
        self.position_embedding = nn.Embedding(args.n_pos_embs**2, args.dim)
        self.n_pos_embs = args.n_pos_embs

    def init_weights(self):
        nn.init.trunc_normal_(self.patch_embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.position_embedding.weight)

    def forward(self, pixels_NLD: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
        # Apply patch embeddings to already-patchified pixel values
        patch_embeds_NLD = self.patch_embedding(pixels_NLD)

        # Get resized and padded positional embeddings
        pos_emb_HWD = self.position_embedding.weight.reshape(
            self.n_pos_embs, self.n_pos_embs, -1
        )
        spatial_h = E.reduce(grid_hw[:, :, 0], "n l -> n", reduction="max") + 1
        spatial_w = E.reduce(grid_hw[:, :, 1], "n l -> n", reduction="max") + 1
        spatial_shapes = torch.stack([spatial_h, spatial_w], dim=-1).long()
        resized_positional_embeddings = resize_positional_embeddings(
            pos_emb_HWD,
            spatial_shapes,
            max_length=pixels_NLD.shape[1],
        )
        # Add positional embeddings to patch embeddings
        embeddings = patch_embeds_NLD + resized_positional_embeddings
        return embeddings

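Editor's note: a sketch of how an image might be patchified into the (num_patches, channels * patch * patch) layout that the linear patch_embedding expects; the exact channel/patch ordering in the real data pipeline is an assumption here.

    # Illustrative patchify sketch (not part of this commit).
    import einops as E
    import torch

    patch = 16
    img_CHW = torch.randn(3, 224, 224)                   # a single RGB image
    pixels_LD = E.rearrange(
        img_CHW, "c (h p1) (w p2) -> (h w) (c p1 p2)", p1=patch, p2=patch
    )
    print(pixels_LD.shape)                               # torch.Size([196, 768])
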
class Attention(nn.Module):
    """
    Multi-head attention module.

    Args:
        args (Siglip2ModelArgs): Model configuration arguments.

    Attributes:
        dim (int): Model embedding dimension.
        head_dim (int): Dimension size of each attention head.
        q_proj (Linear): Linear transformation for queries.
        k_proj (Linear): Linear transformation for keys.
        v_proj (Linear): Linear transformation for values.
        out_proj (Linear): Linear transformation for output.

    """

    def __init__(self, args: Siglip2ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        self.q_proj = nn.Linear(self.dim, self.dim)
        self.k_proj = nn.Linear(self.dim, self.dim)
        self.v_proj = nn.Linear(self.dim, self.dim)
        self.out_proj = nn.Linear(self.dim, self.dim)

        self.attn = build_attention(
            use_flex_attn=True, attn_mask_type=args.attn_mask_type
        )

    def forward(self, x: torch.Tensor):
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Use self.head_dim instead of `n_heads` to infer the actual number of
        # local heads from the sizes of xq, xk, and xv, as TP may have sharded them
        # after the above linear ops.
        xq = E.rearrange(xq, "b l (h d) -> b h l d", d=self.head_dim)
        xk = E.rearrange(xk, "b l (h d) -> b h l d", d=self.head_dim)
        xv = E.rearrange(xv, "b l (h d) -> b h l d", d=self.head_dim)

        output = self.attn(xq, xk, xv)
        output = E.rearrange(output, "b h l d -> b l (h d)").contiguous()

        return self.out_proj(output)

    def init_weights(self):
        for linear in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)

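Editor's note: the head-split convention above can be checked in isolation; the sizes below are illustrative.

    # Illustrative check of the head-split rearrange: the number of (local) heads
    # is inferred from the last dimension and head_dim, so it stays correct even
    # when TP has sharded the projection outputs.
    import einops as E
    import torch

    head_dim = 64
    xq = torch.randn(2, 10, 256)                         # (batch, seq, local_dim)
    xq_heads = E.rearrange(xq, "b l (h d) -> b h l d", d=head_dim)
    print(xq_heads.shape)                                # torch.Size([2, 4, 10, 64])
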
class FeedForward(nn.Module):
    def __init__(self, args: Siglip2ModelArgs):
        super().__init__()
        self.fc1 = nn.Linear(args.dim, args.ffn_dim)
        self.fc2 = nn.Linear(args.ffn_dim, args.dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = F.gelu(x, approximate="tanh")
        x = self.fc2(x)
        return x

    def init_weights(self):
        nn.init.trunc_normal_(self.fc1.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.fc2.weight, mean=0.0, std=0.02)

class TransformerLayer(nn.Module):
    def __init__(self, args: Siglip2ModelArgs):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps)
        self.self_attn = Attention(args)
        self.layer_norm2 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps)
        self.mlp = FeedForward(args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.self_attn(self.layer_norm1(x))
        x = x + self.mlp(self.layer_norm2(x))
        return x

    def init_weights(self):
        self.layer_norm1.reset_parameters()
        self.layer_norm2.reset_parameters()
        self.self_attn.init_weights()
        self.mlp.init_weights()

class VisionTransformer(nn.Module):
    def __init__(self, args: Siglip2ModelArgs):
        super().__init__()
        self.args = args
        self.eos_id = 11

        self.embeddings = VisionEmbeddings(args)
        self.layers = nn.ModuleDict(
            {str(idx): TransformerLayer(args) for idx in range(args.n_layers)}
        )
        self.post_layernorm = nn.LayerNorm(args.dim, eps=args.layer_norm_eps)

    def forward(
        self,
        pixel_values_NLD: torch.FloatTensor,
        pixel_masks_NL: torch.BoolTensor,
        grid_hw: torch.LongTensor,
    ):
        init_attention_mask(pixel_masks_NL, eos_id=self.eos_id)

        h = self.embeddings(pixel_values_NLD, grid_hw)

        for layer in self.layers.values():
            h = layer(h)
        h = self.post_layernorm(h)

        return h

    def init_weights(self):
        self.embeddings.init_weights()
        for layer in self.layers.values():
            layer.init_weights()
        self.post_layernorm.reset_parameters()
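
Editor's note: a hedged sketch (not part of this commit) of the input layout the encoder's forward consumes, inferred from how grid_hw and pixel_masks are used above; the 8x12 patch grid and padding length are illustrative.

    # Illustrative input layout for VisionTransformer.forward.
    import torch

    n_channels, patch_size, L = 3, 16, 128                 # one image of 8x12 = 96 patches, padded to L
    pixel_values_NLD = torch.randn(1, L, n_channels * patch_size**2)
    grid_hw = torch.full((1, L, 2), -1, dtype=torch.long)  # -1 marks padded patch slots
    grid_hw[0, :96, 0] = torch.arange(96) // 12             # per-patch row index (0..7)
    grid_hw[0, :96, 1] = torch.arange(96) % 12              # per-patch column index (0..11)
    pixel_masks_NL = grid_hw[..., 0] != -1                  # (1, L) valid-patch mask

    # These three tensors are what the encoder consumes:
    # encoder(pixel_values_NLD, pixel_masks_NL, grid_hw)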
