# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import einops as E
import torch
import torch.nn.functional as F
from torch import nn

from torchtitan.models.attention import build_attention, init_attention_mask

from .args import Siglip2ModelArgs


def resize_positional_embeddings(
    pos_embs_HWD: torch.Tensor,
    spatial_shapes_N2: torch.Tensor,
    max_length: int,
) -> torch.Tensor:
    """
    Resize the learned 2D positional embeddings to each image's spatial shape and pad to a fixed length.

    Args:
        pos_embs_HWD (`torch.Tensor`):
            Position embeddings of shape (height, width, embed_dim)
        spatial_shapes_N2 (`torch.LongTensor`):
            Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
        max_length (`int`):
            Maximum length to pad the resized positional embeddings to

    Returns:
        `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
    """
    _, _, D = pos_embs_HWD.shape
    B, _ = spatial_shapes_N2.shape

    # Allocate the padded output; rows beyond each image's height * width
    # (and rows of padding images skipped below) are left uninitialized.
    resized_embs_BLD = torch.empty(
        (B, max_length, D),
        device=pos_embs_HWD.device,
        dtype=pos_embs_HWD.dtype,
    )

    # TODO: group images by size and interpolate once per group,
    # or cache the interpolation output so we only do this once per size.
    for i in range(B):
        height, width = spatial_shapes_N2[i].tolist()
        if (height + width) == 0:  # Skip empty padding images
            continue

        resized_emb = F.interpolate(
            E.rearrange(pos_embs_HWD, "h w d -> 1 d h w"),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )

        resized_emb_LD = E.rearrange(resized_emb, "1 d h w -> (h w) d")
        resized_embs_BLD[i, : int(height * width)] = resized_emb_LD

    return resized_embs_BLD


class VisionEmbeddings(nn.Module):
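    """Embeds already-patchified pixel values and adds 2D positional embeddings resized to each image's patch grid."""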
    def __init__(self, args: Siglip2ModelArgs):
        super().__init__()
        self.patch_embedding = nn.Linear(
            in_features=args.n_channels * args.patch_size * args.patch_size,
            out_features=args.dim,
        )
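        # Learned positional embeddings over a fixed n_pos_embs x n_pos_embs grid,
        # resized to each image's patch grid in forward()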
        self.position_embedding = nn.Embedding(args.n_pos_embs**2, args.dim)
        self.n_pos_embs = args.n_pos_embs

    def init_weights(self):
        nn.init.trunc_normal_(self.patch_embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.position_embedding.weight)

    def forward(self, pixels_NLD: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
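        """
        Args:
            pixels_NLD: Already-patchified pixel values of shape (N, L, n_channels * patch_size**2).
            grid_hw: Per-patch (row, col) grid coordinates of shape (N, L, 2).

        Returns:
            Patch embeddings with positional embeddings added, of shape (N, L, dim).
        """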
        # Apply patch embeddings to already patchified pixel values
        patch_embeds_NLD = self.patch_embedding(pixels_NLD)

        # Get resized and padded positional embeddings
        pos_emb_HWD = self.position_embedding.weight.reshape(
            self.n_pos_embs, self.n_pos_embs, -1
        )
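        # Per-image patch-grid extent: the largest (row, col) coordinate plus one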
        spatial_h = E.reduce(grid_hw[:, :, 0], "n l -> n", reduction="max") + 1
        spatial_w = E.reduce(grid_hw[:, :, 1], "n l -> n", reduction="max") + 1
        spatial_shapes = torch.stack([spatial_h, spatial_w], dim=-1).long()
        resized_positional_embeddings = resize_positional_embeddings(
            pos_emb_HWD,
            spatial_shapes,
            max_length=pixels_NLD.shape[1],
        )
        # Add positional embeddings to patch embeddings
        embeddings = patch_embeds_NLD + resized_positional_embeddings
        return embeddings


class Attention(nn.Module):
    """
    Multi-head self-attention module.

    Args:
        args (Siglip2ModelArgs): Model configuration arguments.

    Attributes:
        dim (int): Model embedding dimension.
        head_dim (int): Dimension size of each attention head.
        q_proj (Linear): Linear transformation for queries.
        k_proj (Linear): Linear transformation for keys.
        v_proj (Linear): Linear transformation for values.
        out_proj (Linear): Linear transformation for output.
    """

    def __init__(self, args: Siglip2ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        self.q_proj = nn.Linear(self.dim, self.dim)
        self.k_proj = nn.Linear(self.dim, self.dim)
        self.v_proj = nn.Linear(self.dim, self.dim)
        self.out_proj = nn.Linear(self.dim, self.dim)

        # FlexAttention-based attention op with the configured mask type
        self.attn = build_attention(
            use_flex_attn=True, attn_mask_type=args.attn_mask_type
        )

    def forward(self, x: torch.Tensor):
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Use self.head_dim instead of n_heads to infer the actual number of
        # local heads from the sizes of xq, xk, and xv, since TP may have
        # sharded them after the linear projections above.
        xq = E.rearrange(xq, "b l (h d) -> b h l d", d=self.head_dim)
        xk = E.rearrange(xk, "b l (h d) -> b h l d", d=self.head_dim)
        xv = E.rearrange(xv, "b l (h d) -> b h l d", d=self.head_dim)

        output = self.attn(xq, xk, xv)
        output = E.rearrange(output, "b h l d -> b l (h d)").contiguous()

        return self.out_proj(output)

    def init_weights(self):
        for linear in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)


class FeedForward(nn.Module):
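    """Two-layer feed-forward network (MLP) with a tanh-approximated GELU activation."""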
    def __init__(self, args: Siglip2ModelArgs):
        super().__init__()
        self.fc1 = nn.Linear(args.dim, args.ffn_dim)
        self.fc2 = nn.Linear(args.ffn_dim, args.dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = F.gelu(x, approximate="tanh")
        x = self.fc2(x)
        return x

    def init_weights(self):
        nn.init.trunc_normal_(self.fc1.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.fc2.weight, mean=0.0, std=0.02)


class TransformerLayer(nn.Module):
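    """Pre-LayerNorm transformer encoder block: self-attention and MLP, each wrapped in a residual connection."""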
    def __init__(self, args: Siglip2ModelArgs):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps)
        self.self_attn = Attention(args)
        self.layer_norm2 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps)
        self.mlp = FeedForward(args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.self_attn(self.layer_norm1(x))
        x = x + self.mlp(self.layer_norm2(x))
        return x

    def init_weights(self):
        self.layer_norm1.reset_parameters()
        self.layer_norm2.reset_parameters()
        self.self_attn.init_weights()
        self.mlp.init_weights()


class VisionTransformer(nn.Module):
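    """Siglip2 vision encoder: patch and positional embeddings, a stack of transformer layers, and a final LayerNorm."""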
    def __init__(self, args: Siglip2ModelArgs):
        super().__init__()
        self.args = args
        self.eos_id = 11

        self.embeddings = VisionEmbeddings(args)
        self.layers = nn.ModuleDict(
            {str(idx): TransformerLayer(args) for idx in range(args.n_layers)}
        )
        self.post_layernorm = nn.LayerNorm(args.dim, eps=args.layer_norm_eps)

    def forward(
        self,
        pixel_values_NLD: torch.FloatTensor,
        pixel_masks_NL: torch.BoolTensor,
        grid_hw: torch.LongTensor,
    ):
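        """
        Args:
            pixel_values_NLD: Patchified pixel values of shape (N, L, n_channels * patch_size**2).
            pixel_masks_NL: Boolean mask of shape (N, L) over patches, used to initialize
                the attention mask.
            grid_hw: Per-patch (row, col) grid coordinates of shape (N, L, 2).

        Returns:
            Encoded patch features of shape (N, L, dim).
        """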
        init_attention_mask(pixel_masks_NL, eos_id=self.eos_id)

        h = self.embeddings(pixel_values_NLD, grid_hw)

        for layer in self.layers.values():
            h = layer(h)
        h = self.post_layernorm(h)

        return h

    def init_weights(self):
        self.embeddings.init_weights()
        for layer in self.layers.values():
            layer.init_weights()
        self.post_layernorm.reset_parameters()