@@ -116,7 +116,8 @@ def forward(
         encoder_mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
     ) -> Tensor:
-        if encoder_input:
+        if encoder_input is not None:
+            encoder_input = encoder_input.view(1, 1, *encoder_input.shape)
             encoder_output = self.encoder(
                 encoder_input,
             )
@@ -223,7 +224,7 @@ def _llava(cls):
                 'encoder': clip_vision_encoder,
                 'decoder': Transformer
             },
-            fusion_class=DeepFusionModel,
+            fusion_class=ConcateFusion,
         )

     @classmethod
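For context on the `fusion_class` swap: `DeepFusionModel` fuses vision features into the decoder through interleaved cross-attention layers, whereas LLaVA-style models splice projected image embeddings directly into the token-embedding sequence. A minimal sketch of that concatenation approach follows; the class name, constructor, and `forward` signature here are illustrative only and do not reproduce torchchat's actual `ConcateFusion`:

```python
import torch
from torch import nn


class ConcatFusionSketch(nn.Module):
    """Illustrative concatenation-style fusion (not torchchat's ConcateFusion)."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module, projector: nn.Module):
        super().__init__()
        self.encoder = encoder      # e.g. a CLIP vision encoder
        self.decoder = decoder      # e.g. the LLaMA transformer
        self.projector = projector  # maps vision features to the decoder's embed dim

    def forward(self, pre_tokens, encoder_input, post_tokens):
        # Embed the text on either side of the image slot.
        pre = self.decoder.tok_embeddings(pre_tokens)
        post = self.decoder.tok_embeddings(post_tokens)
        # Project image features into token-embedding space and splice
        # them between the two text segments, LLaVA-style.
        img = self.projector(self.encoder(encoder_input))
        embeds = torch.cat([pre, img, post], dim=1)
        # Assumes the decoder can consume embeddings directly.
        return self.decoder(embeds)
```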
@@ -968,46 +969,3 @@ def setup_caches(self, max_batch_size, max_seq_length):
 
 except:
     pass
-
-
-if __name__ == "__main__":
-    def prepare_image(target_h: int, target_w: int) -> torch.Tensor:
-        """Read image into a tensor and resize the image so that it fits in
-        a target_h x target_w canvas.
-
-        Args:
-            image (Image): An Image object.
-            target_h (int): Target height.
-            target_w (int): Target width.
-
-        Returns:
-            torch.Tensor: resized image tensor.
-        """
-        image = Image.open(
-            requests.get(
-                "https://llava-vl.github.io/static/images/view.jpg", stream=True
-            ).raw)
-
-        img = torchvision.transforms.functional.pil_to_tensor(image)
-        # height ratio
-        ratio_h = img.shape[1] / target_h
-        # width ratio
-        ratio_w = img.shape[2] / target_w
-        # resize the image so that it fits in a target_h x target_w canvas
-        ratio = max(ratio_h, ratio_w)
-        output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio))
-        img = torchvision.transforms.Resize(size=output_size)(img)
-        return img
-
-    pre_tokens = torch.tensor([[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116,
-                                21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892,
-                                322, 1248, 568, 6089, 304, 278, 5199, 29915, 29879, 5155,
-                                29889, 3148, 1001, 29901, 29871]])
-    img = prepare_image(336, 336)
-    post_tokens = torch.tensor([[29871, 13, 462, 9651, 1724, 526, 278, 2712, 306, 881,
-                                 367, 274, 1300, 2738, 1048, 746, 306, 6493, 1244, 29973,
-                                 319, 1799, 9047, 13566, 29901]])
-
-    llava_model = Model.from_params("/home/gasoonjia/torchchat/torchchat/model_params/llava-1.5.json")
-
-    llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens)
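As a worked example of the resize math in the deleted `prepare_image`: a 1000x1500 image targeted at a 336x336 canvas gives `ratio = max(1000/336, 1500/336) ≈ 4.46`, so `output_size = (int(1000/4.46), int(1500/4.46)) = (224, 336)`; the larger side is pinned to the canvas edge and the aspect ratio is preserved.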