@@ -1,24 +1,19 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
 import json
 import os
 import warnings
+from abc import ABC, abstractmethod
 
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
 
 from typing import Any, Callable, Dict, Optional, Union
-from abc import ABC, abstractmethod
 
 import torch
 import torch.nn as nn
 
 from torch import Tensor
-from torch.distributed._tensor import Replicate, Shard, DTensor
+from torch.distributed._tensor import DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -31,11 +26,12 @@
 from torchchat.utils.build_utils import find_multiple, get_precision
 
 from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
-from torchtune.modules.model_fusion import DeepFusionModel
 from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
+from torchtune.modules.model_fusion import DeepFusionModel
 
 config_path = Path(f"{str(Path(__file__).parent)}/model_params")
 
+
 def identity(**kwargs):
     if len(kwargs) != 1:
         raise ValueError("Only one argument is expected")
@@ -56,8 +52,17 @@ def forward(self, image_features):
         hidden_states = self.linear_2(hidden_states)
         return hidden_states
 
+
 class ConcateFusion(nn.Module):
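+    """Fuse a vision encoder with a text decoder by concatenation (as in
+    LLaVA-style models): token embeddings computed before the image, the
+    projected image features, and token embeddings computed after the image
+    are joined along the sequence dimension and fed to the decoder."""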
-    def __init__(self, encoder: nn.Module, decoder: nn.Module, token_embedding_name="tok_embeddings", mm_proj_in_channels=1024, mm_proj_out_channels=4096, mm_proj_activation=nn.GELU):
+    def __init__(
+        self,
+        encoder: nn.Module,
+        decoder: nn.Module,
+        token_embedding_name="tok_embeddings",
+        mm_proj_in_channels=1024,
+        mm_proj_out_channels=4096,
+        mm_proj_activation=nn.GELU,
+    ):
         super().__init__()
         self.encoder = encoder
         self.decoder = decoder
@@ -67,36 +72,53 @@ def __init__(self, encoder: nn.Module, decoder: nn.Module, token_embedding_name=
         self.tok_embeddings = getattr(self.decoder, token_embedding_name)
 
-        # set the embedding layer in decoder to None to jump the embedding layer over in decoder
-        self.decoder.__setattr__(token_embedding_name) = None
+        # Clear the decoder's own embedding layer so embedding is skipped there;
+        # this module embeds the tokens (and image features) itself.
+        setattr(self.decoder, token_embedding_name, None)
 
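+        # The projector maps encoder features to the decoder's embedding width;
+        # the defaults (1024 -> 4096) assume a CLIP ViT-L/14-style vision tower
+        # paired with a Llama-style decoder using 4096-dim embeddings.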
-        self.mm_projector = MultiModalProjector(ProjectorArgs(in_channels=mm_proj_in_channels, out_channels=mm_proj_out_channels, activation=mm_proj_activation))
+        self.mm_projector = MultiModalProjector(
+            ProjectorArgs(
+                in_channels=mm_proj_in_channels,
+                out_channels=mm_proj_out_channels,
+                activation=mm_proj_activation,
+            )
+        )
 
-    def forward(self,
+    def forward(
+        self,
         tokens: Tensor,
         *,
         post_tokens: Optional[Tensor] = None,
         encoder_input: Optional[Tensor] = None,
         encoder_mask: Optional[torch.Tensor] = None,
-        input_pos: Optional[torch.Tensor] = None,) -> Tensor:
+        input_pos: Optional[torch.Tensor] = None,
+    ) -> Tensor:
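+        # NOTE: encoder_mask and input_pos are accepted for interface
+        # compatibility but are not used in the body below yet.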
-        if encoder_input:
+        # Truth-testing a multi-element Tensor raises a RuntimeError, so
+        # compare against None instead of relying on truthiness.
+        if encoder_input is not None:
             encoder_output = self.encoder(
                 encoder_input,
             )
         else:
             encoder_output = None
 
-        decoder_input = self._get_decoder_input(tokens, encoder_input=encoder_input, post_tokens=post_tokens)
+        decoder_input = self._get_decoder_input(
+            tokens, encoder_output=encoder_output, post_tokens=post_tokens
+        )
         return self.decoder(decoder_input)
 
-    def _get_decoder_input(self, tokens: Tensor, *, encoder_input: Optional[Tensor], post_tokens: Optional[Tensor]):
-        assert bool(encoder_input) == bool(post_tokens), "encoder_input and post_tokens must be both None or not None"
+    def _get_decoder_input(
+        self,
+        tokens: Tensor,
+        *,
+        encoder_output: Optional[Tensor],
+        post_tokens: Optional[Tensor],
+    ):
+        # bool() on a multi-element Tensor raises, so test against None.
+        assert (encoder_output is None) == (
+            post_tokens is None
+        ), "encoder_output and post_tokens must be both None or not None"
-        if encoder_input is None:
+        if encoder_output is None:
             return self.tok_embeddings(tokens)
         else:
             pre_img_embed = self.tok_embeddings(tokens)
+            # Project the image features into the decoder's embedding space so
+            # the three pieces can be concatenated along the sequence dimension.
+            image_embeds = self.mm_projector(encoder_output)
             post_img_embed = self.tok_embeddings(post_tokens)
             return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
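+
+# A minimal usage sketch (hypothetical variable names; assumes a vision tower
+# producing 1024-dim features and a decoder exposing a `tok_embeddings` layer):
+#
+#   model = ConcateFusion(encoder=vision_tower, decoder=llama_decoder)
+#   logits = model(
+#       prompt_tokens,                # tokens before the image slot
+#       post_tokens=suffix_tokens,    # tokens after the image slot
+#       encoder_input=pixel_values,   # preprocessed image batch
+#   )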