@@ -42,40 +42,64 @@ def identity(**kwargs):
4242 return list (kwargs .values ())[0 ]
4343
4444
class MultiModalProjector(nn.Module):
    """Two-layer MLP applied to image features: linear -> activation -> linear.

    The first linear layer maps ``args.in_channels`` to ``args.out_channels``;
    the second keeps the width at ``args.out_channels``.

    Args:
        args: Configuration object exposing ``in_channels``, ``out_channels``
            and ``activation``. ``activation`` may be either a ready-to-call
            activation (an ``nn.Module`` instance or a plain function) or an
            activation *class* such as ``nn.GELU``, which is instantiated with
            no arguments.
    """

    def __init__(self, args: "ProjectorArgs"):
        super().__init__()

        self.linear_1 = nn.Linear(args.in_channels, args.out_channels, bias=True)

        # A class (e.g. ``nn.GELU``) must be instantiated first; otherwise
        # ``self.act(x)`` would construct a module instead of activating ``x``.
        activation = args.activation
        self.act = activation() if isinstance(activation, type) else activation

        self.linear_2 = nn.Linear(args.out_channels, args.out_channels, bias=True)

    def forward(self, image_features):
        """Project ``image_features`` through linear_1, the activation and linear_2.

        Args:
            image_features: Tensor whose last dimension is ``in_channels``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``out_channels``.
        """
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        return self.linear_2(hidden_states)
58+
class ConcateFusion(nn.Module):
    """Fuse text and image embeddings by concatenation before decoding.

    The token-embedding layer is lifted out of the decoder so that text
    embeddings can be concatenated with projected image features before the
    result is handed to the decoder.

    Args:
        encoder: Module applied to ``encoder_input`` to produce image features.
        decoder: Module that receives the fused embeddings. Its attribute named
            ``token_embedding_name`` is set to ``None`` by this constructor, so
            the decoder is expected to tolerate a missing embedding layer
            (assumption about the decoder -- confirm against its implementation).
        token_embedding_name: Name of the embedding attribute on ``decoder``.
        mm_proj_in_channels: Input width of the multimodal projector.
        mm_proj_out_channels: Output width of the multimodal projector.
        mm_proj_activation: Activation for the projector; either an activation
            class (instantiated with no arguments) or a ready-to-call activation.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        token_embedding_name="tok_embeddings",
        mm_proj_in_channels=1024,
        mm_proj_out_channels=4096,
        mm_proj_activation=nn.GELU,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

        # Escalate the embedding layer out of the decoder: text and image
        # embeddings must be fused together before being passed to the decoder.
        self.tok_embeddings = getattr(self.decoder, token_embedding_name)

        # Set the decoder's own embedding layer to None so the decoder skips it.
        setattr(self.decoder, token_embedding_name, None)

        # Pass an activation instance, not the class: calling the class on a
        # tensor would build a module rather than apply the activation.
        if isinstance(mm_proj_activation, type):
            activation = mm_proj_activation()
        else:
            activation = mm_proj_activation

        self.mm_projector = MultiModalProjector(
            ProjectorArgs(
                in_channels=mm_proj_in_channels,
                out_channels=mm_proj_out_channels,
                activation=activation,
            )
        )

    def forward(
        self,
        tokens: Tensor,
        *,
        post_tokens: Optional[Tensor] = None,
        encoder_input: Optional[Tensor] = None,
        encoder_mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> Tensor:
        """Embed the prompt (optionally with an image) and run the decoder.

        Args:
            tokens: Token ids that precede the image, or the whole prompt when
                there is no image.
            post_tokens: Token ids that follow the image. Must be given if and
                only if ``encoder_input`` is given.
            encoder_input: Input for ``self.encoder``; ``None`` for text-only.
            encoder_mask: Accepted for interface compatibility; not used here.
            input_pos: Accepted for interface compatibility; not used here.

        Returns:
            Whatever ``self.decoder`` returns for the fused embeddings.

        Raises:
            ValueError: If exactly one of ``encoder_input`` / ``post_tokens``
                is provided.
        """
        decoder_input = self._get_decoder_input(
            tokens, encoder_input=encoder_input, post_tokens=post_tokens
        )
        return self.decoder(decoder_input)

    def _get_decoder_input(
        self,
        tokens: Tensor,
        *,
        encoder_input: Optional[Tensor],
        post_tokens: Optional[Tensor],
    ):
        """Build the embedding sequence that is fed to the decoder.

        Without an image this is the embedding of ``tokens``. With an image it
        is ``[embed(tokens), project(encode(encoder_input)), embed(post_tokens)]``
        concatenated along dim 1 (assumes dim 1 is the sequence axis, i.e.
        inputs shaped (batch, seq, ...) -- TODO confirm with callers).

        Raises:
            ValueError: If exactly one of ``encoder_input`` / ``post_tokens``
                is provided.
        """
        # Compare against None explicitly: bool() of a multi-element tensor raises.
        if (encoder_input is None) != (post_tokens is None):
            raise ValueError(
                "encoder_input and post_tokens must be both None or not None"
            )

        pre_img_embed = self.tok_embeddings(tokens)
        if encoder_input is None:
            return pre_img_embed

        # Encode the image, then project its features to the embedding width.
        image_embeds = self.mm_projector(self.encoder(encoder_input))
        post_img_embed = self.tok_embeddings(post_tokens)
        return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
74100
75101
76102
77-
78-
79103class ModelType (Enum ):
80104 TextOnly = "text_only"
81105 Llama3_1 = "llama3_1"
0 commit comments