2828from torchtune .models .flamingo import flamingo_decoder , flamingo_vision_encoder
2929from torchtune .models .llama3_1 ._component_builders import llama3_1 as llama3_1_builder
3030from torchtune .modules .model_fusion import DeepFusionModel
31+ from torchtune .models .clip import clip_vision_encoder
3132
3233config_path = Path (f"{ str (Path (__file__ ).parent )} /model_params" )
3334
3435
class QuickGELUActivation(nn.Module):
    """Sigmoid-based approximation of GELU: ``x * sigmoid(1.702 * x)``.

    Quick to evaluate, but less accurate than the exact GELU.
    See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input):
        """Return ``input`` scaled element-wise by ``sigmoid(1.702 * input)``."""
        gate = torch.sigmoid(1.702 * input)
        return input * gate
43+
44+
3545def identity (** kwargs ):
3646 if len (kwargs ) != 1 :
3747 raise ValueError ("Only one argument is expected" )
@@ -99,24 +109,25 @@ def forward(
99109 encoder_output = None
100110
101111 decoder_input = self ._get_decoder_input (
102- tokens , encoder_input = encoder_input , post_tokens = post_tokens
112+ tokens , encoder_output = encoder_output , post_tokens = post_tokens
103113 )
104114 return self .decoder (decoder_input )
105115
def _get_decoder_input(
    self,
    tokens: Tensor,
    *,
    encoder_output: Optional[Tensor],
    post_tokens: Optional[Tensor],
):
    """Build the embedding sequence that is passed to the decoder.

    Args:
        tokens: Token ids embedded with ``self.tok_embeddings``. When an
            encoder output is given, these embeddings come first.
        encoder_output: Encoder features, projected with
            ``self.mm_projector``; ``None`` when there is no encoder input.
        post_tokens: Token ids whose embeddings are placed after the
            projected encoder features; ``None`` when ``encoder_output``
            is ``None``.

    Returns:
        ``self.tok_embeddings(tokens)`` when ``encoder_output`` is ``None``;
        otherwise the token embeddings, projected encoder features and
        post-token embeddings concatenated along dimension 1.

    Raises:
        AssertionError: If exactly one of ``encoder_output`` and
            ``post_tokens`` is ``None``.
    """
    # Compare None-ness rather than truthiness: bool() on a tensor with more
    # than one element raises RuntimeError, and an empty tensor is not None.
    assert (encoder_output is None) == (
        post_tokens is None
    ), "encoder_output and post_tokens must be both None or not None"

    if encoder_output is None:
        return self.tok_embeddings(tokens)

    pre_img_embed = self.tok_embeddings(tokens)
    image_embeds = self.mm_projector(encoder_output)
    post_img_embed = self.tok_embeddings(post_tokens)
    return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
122133
@@ -261,7 +272,7 @@ class ModelArgs:
261272
262273 def __init__ (
263274 self ,
264- transformer_args : Union [TransformerArgs , Dict [str , TransformerArgs ]],
275+ transformer_args : Union [TransformerArgs , Dict [str , Dict [ str , Any ] ]],
265276 model_type : ModelType = ModelType .TextOnly ,
266277 ) -> None :
267278 self ._sanity_check (transformer_args , model_type )
@@ -275,7 +286,7 @@ def __init__(
275286
276287 def _sanity_check (
277288 self ,
278- transformer_args : Union [TransformerArgs , Dict [str , TransformerArgs ]],
289+ transformer_args : Union [TransformerArgs , Dict [str , Dict [ str , Any ] ]],
279290 model_type : ModelType ,
280291 ) -> None :
281292 assert isinstance (model_type , ModelType )
@@ -393,12 +404,20 @@ def build_model(self) -> nn.Module:
393404 modules = {}
394405 for name , module_class in recipe .modules .items ():
395406 if isinstance (config_args := self .config .transformer_args [name ], dict ):
407+ config_args = self ._replace_know_params (config_args )
396408 modules [name ] = module_class (** config_args )
397409 else :
398410 modules [name ] = module_class (config_args )
399411
400412 return recipe .fusion_class (** modules )
401413
def _replace_know_params(self, params):
    """Return a copy of ``params`` with known string placeholders resolved.

    Recognized placeholder strings and their replacements:
        ``"QuickGELUActivation()"`` -> a new ``QuickGELUActivation`` instance
        ``"True"`` / ``"False"``    -> the corresponding ``bool``

    All other values, including non-string ones, are passed through as-is.

    Args:
        params: Mapping of keyword-argument names to configured values.

    Returns:
        A new dict with the same keys; ``params`` itself is not modified.
    """
    literals = {"False": False, "True": True}
    resolved = {}
    for key, value in params.items():
        # Only strings can be placeholders. The isinstance guard also keeps
        # unhashable values (lists, dicts) away from the dict lookup, which
        # would otherwise raise TypeError.
        if isinstance(value, str):
            if value == "QuickGELUActivation()":
                # Instantiate only when requested, so each placeholder gets
                # its own module instance.
                value = QuickGELUActivation()
            elif value in literals:
                value = literals[value]
        resolved[key] = value
    return resolved
420+
@abstractmethod
def forward(self, *args, **kwargs):
    """Abstract forward pass; this implementation always raises.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "forward method is not implemented"
    raise NotImplementedError(message)
0 commit comments