@@ -11,8 +11,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from PIL import Image
-import requests
+
 import torchvision
 
 from typing import Any, Callable, Dict, Optional, Union
@@ -35,14 +34,10 @@
 from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
 from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
 from torchtune.modules.model_fusion import DeepFusionModel
+from torchtune.models.clip import clip_vision_encoder
 
 from torchchat.utils.build_utils import find_multiple, get_precision
 
-from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
-from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
-from torchtune.modules.model_fusion import DeepFusionModel
-from torchtune.models.clip import clip_vision_encoder
-
 config_path = Path(f"{str(Path(__file__).parent)}/model_params")
 
 
@@ -61,19 +56,13 @@ def identity(**kwargs):
     return list(kwargs.values())[0]
 
 
-@dataclass
-class ProjectorArgs:
-    in_channels: int = 1024
-    out_channels: int = 4096
-    activation: nn.Module = nn.GELU()
-
 
 class MultiModalProjector(nn.Module):
-    def __init__(self, args: ProjectorArgs):
+    def __init__(self, in_channels: int, out_channels: int, act: nn.Module):
         super().__init__()
 
-        self.linear_1 = nn.Linear(args.in_channels, args.out_channels, bias=True)
-        self.act = args.activation
-        self.linear_2 = nn.Linear(args.out_channels, args.out_channels, bias=True)
+        self.linear_1 = nn.Linear(in_channels, out_channels, bias=True)
+        self.act = act
+        self.linear_2 = nn.Linear(out_channels, out_channels, bias=True)
 
     def forward(self, image_features):
@@ -105,11 +94,9 @@ def __init__(
         self.decoder.__setattr__(token_embedding_name, None)
 
         self.mm_projector = MultiModalProjector(
-            ProjectorArgs(
-                in_channels=mm_proj_in_channels,
-                out_channels=mm_proj_out_channels,
-                activation=mm_proj_activation,
-            )
+            in_channels=mm_proj_in_channels,
+            out_channels=mm_proj_out_channels,
+            act=mm_proj_activation,
         )
 
     def forward(
@@ -123,9 +110,7 @@ def forward(
     ) -> Tensor:
         if encoder_input is not None:
             encoder_input = encoder_input.view(1, 1, *encoder_input.shape)
-            encoder_output = self.encoder(
-                encoder_input,
-            )
+            encoder_output = self.encoder(encoder_input)
             encoder_output = self._encoder_feature_select(encoder_output)
         else:
             encoder_output = None
@@ -143,10 +128,10 @@ def forward(
 
         return self.decoder(decoder_input, input_pos=input_pos)
 
-    def setup_caches(self, batch_size, max_seq_len):
+    def setup_caches(self, batch_size, max_seq_len) -> None:
         self.decoder.setup_caches(batch_size, max_seq_len)
 
-    def _encoder_feature_select(self, encoder_output):
+    def _encoder_feature_select(self, encoder_output) -> Tensor:
         selected_image_feature = encoder_output[1][0].view(
             *encoder_output[1][0].shape[2:]
         )
@@ -160,7 +145,7 @@ def _get_decoder_input(
         *,
         encoder_output: Optional[Tensor],
         post_tokens: Optional[Tensor],
-    ):
+    ) -> Tensor:
         if encoder_output is None:
             assert post_tokens is None
             return self.tok_embeddings(tokens)
@@ -245,16 +230,17 @@ def _llava(cls):
 
     @classmethod
     def get_recipe(cls, model_type):
-        if model_type == ModelType.TextOnly:
-            return cls._text_only()
-        elif model_type == ModelType.Flamingo:
-            return cls._flamingo()
-        elif model_type == ModelType.Llama3_1:
-            return cls._llama3_1()
-        elif model_type == ModelType.Llava:
-            return cls._llava()
-        else:
-            raise ValueError(f"Can not find the model recipe for {model_type}")
+        match model_type:
+            case ModelType.TextOnly:
+                return cls._text_only()
+            case ModelType.Flamingo:
+                return cls._flamingo()
+            case ModelType.Llama3_1:
+                return cls._llama3_1()
+            case ModelType.Llava:
+                return cls._llava()
+            case _:
+                raise ValueError(f"Can not find the model recipe for {model_type}")
 
 
 @dataclass
@@ -475,7 +461,7 @@ def build_model(self) -> nn.Module:
 
         return recipe.fusion_class(**modules)
 
-    def _replace_know_params(self, params):
+    def _replace_known_params(self, params):
         patterns = {"QuickGELUActivation()": QuickGELUActivation()}
         for key, value in params.items():
             if isinstance(value, Hashable) and value in patterns:
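
For context, a minimal standalone sketch of the projector API after this refactor: plain keyword arguments replace the removed `ProjectorArgs` dataclass. The forward pass and the 1024/4096 channel sizes (the old `ProjectorArgs` defaults) are illustrative assumptions, not part of the diff.

```python
import torch
from torch import nn


# Hypothetical standalone copy of the flattened MultiModalProjector above.
class MultiModalProjector(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, act: nn.Module):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, out_channels, bias=True)
        self.act = act
        self.linear_2 = nn.Linear(out_channels, out_channels, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # Project vision-encoder features into the decoder's embedding space.
        return self.linear_2(self.act(self.linear_1(image_features)))


# Example: map 1024-dim vision features to a 4096-dim decoder space
# (the defaults the removed ProjectorArgs used to carry); shapes are illustrative.
projector = MultiModalProjector(in_channels=1024, out_channels=4096, act=nn.GELU())
patches = torch.randn(1, 576, 1024)
print(projector(patches).shape)  # torch.Size([1, 576, 4096])
```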