@@ -128,6 +128,14 @@ def forward(
         decoder_input = self._get_decoder_input(
             tokens, encoder_output=encoder_output, post_tokens=post_tokens
         )
+
+        if input_pos is None:
+            input_pos = torch.arange(
+                decoder_input.shape[1],
+                device=decoder_input.device,
+                dtype=torch.int,
+            )
+
         return self.decoder(decoder_input, input_pos=input_pos)

     def setup_caches(self, batch_size, max_seq_len):
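For reference, the added default makes an omitted input_pos equivalent to passing sequential positions for the whole prompt. A minimal standalone sketch of that behavior (the tensor shape here is illustrative, not taken from the model):

    import torch

    # Hypothetical decoder input of shape [batch, seq_len, dim]; only seq_len matters here.
    decoder_input = torch.zeros(1, 35, 4096)

    # Equivalent of the new default: positions 0..seq_len-1 as int32.
    input_pos = torch.arange(decoder_input.shape[1], dtype=torch.int)
    assert input_pos.tolist() == list(range(35))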
@@ -977,98 +985,3 @@ def setup_caches(self, max_batch_size, max_seq_length):

 except:
     pass
-
-
-if __name__ == "__main__":
-    def prepare_image(target_h: int, target_w: int) -> torch.Tensor:
-        """Read image into a tensor and resize the image so that it fits in
-        a target_h x target_w canvas.
-
-        Args:
-            image (Image): An Image object.
-            target_h (int): Target height.
-            target_w (int): Target width.
-
-        Returns:
-            torch.Tensor: resized image tensor.
-        """
-        image = Image.open(
-            requests.get(
-                "https://llava-vl.github.io/static/images/view.jpg", stream=True
-            ).raw)
-
-        img = torchvision.transforms.functional.pil_to_tensor(image)
-        # height ratio
-        ratio_h = img.shape[1] / target_h
-        # width ratio
-        ratio_w = img.shape[2] / target_w
-        # resize the image so that it fits in a target_h x target_w canvas
-        ratio = max(ratio_h, ratio_w)
-        output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio))
-        img = torchvision.transforms.Resize(size=output_size)(img)
-        return img
-
-
-    def image_preprocess(img: torch.Tensor, target_h: int, target_w: int, rescale_factor, image_mean, image_std) -> torch.Tensor:
-        # pad the image with median rgb value, to make a square
-        l_pad = (target_w - img.shape[2]) // 2
-        t_pad = (target_h - img.shape[1]) // 2
-        # ceil division
-        r_pad = -((target_w - img.shape[2]) // -2)
-        b_pad = -((target_h - img.shape[1]) // -2)
-
-        torch._check(l_pad >= 0)
-        torch._check(t_pad >= 0)
-        torch._check(r_pad >= 0)
-        torch._check(b_pad >= 0)
-
-        # This is different from the original implementation, due to export limitations.
-        resized = torch.nn.functional.pad(
-            img,
-            (l_pad, r_pad, t_pad, b_pad),
-        )
-        # originally:
-        # resized = F.pad(
-        #     img,
-        #     padding=(l_pad, t_pad, r_pad, b_pad),
-        #     fill=tuple(int(x * 255) for x in self.image_mean),
-        # )
-
-        # TODO: implement _upsample_bicubic_aa.out in portable kernel library.
-        # here padded shape should be max(h, w) x max(h, w)
-        # skipping resize for now due to missing _upsample_bicubic_aa kernel in portable
-        # resized = resize(
-        #     padded,
-        #     size=[
-        #         self.image_processor.crop_size["height"],
-        #         self.image_processor.crop_size["width"],
-        #     ],
-        #     interpolation="bicubic",
-        # )
-        # torch._check(resized.size(1) == self.config.crop_size["height"])
-        # torch._check(resized.size(2) == self.config.crop_size["width"])
-        # print(resized.shape)
-        # cropped = F.center_crop(img, output_size=[w, w])
-        # print(cropped.shape)
-        scaled = resized * rescale_factor
-        # print(scaled)
-        from torchvision.transforms.v2 import functional as tvF
-        normed = tvF.normalize(
-            scaled, image_mean, image_std
-        )
-        # print(normed)
-        return normed.unsqueeze(0)
-
-    pre_tokens = torch.tensor([[    1,   319, 13563,  1546,   263, 12758,  5199,   322,   385, 23116,
-         21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
-           322,  1248,   568,  6089,   304,   278,  5199, 29915, 29879,  5155,
-         29889,  3148,  1001, 29901, 29871]])
-    img = prepare_image(336, 336)
-    post_tokens = torch.tensor([[29871,    13,   462,  9651,  1724,   526,   278,  2712,   306,   881,
-           367,   274,  1300,  2738,  1048,   746,   306,  6493,  1244, 29973,
-           319,  1799,  9047, 13566, 29901]])
-
-    llava_model = Model.from_params("/home/gasoonjia/torchchat/torchchat/model_params/llava-1.5.json")
-    llava_model.setup_caches(1, 2048)
-    img = image_preprocess(img=img, target_h=336, target_w=336, image_mean=[0.48145466, 0.4578275, 0.40821073], image_std=[0.26862954, 0.26130258, 0.27577711], rescale_factor=0.00392156862745098)
-    llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens)
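The resize-then-pad logic removed above is worth seeing in isolation: the image is first shrunk by the larger of the height/width ratios so it fits the canvas, then padded to a square, with the ceil-division trick giving any odd leftover pixel to the right/bottom side. A minimal self-contained sketch with illustrative sizes (500x1344 is a hypothetical source image, not the sample photo used above):

    import torch

    target_h, target_w = 336, 336
    h, w = 500, 1344                               # hypothetical source image size
    ratio = max(h / target_h, w / target_w)        # shrink by the larger ratio
    out_h, out_w = int(h / ratio), int(w / ratio)  # 125 x 336, fits the canvas

    img = torch.zeros(3, out_h, out_w)
    l_pad = (target_w - out_w) // 2                # floor division
    t_pad = (target_h - out_h) // 2                # 211 // 2 = 105
    r_pad = -((target_w - out_w) // -2)            # ceil division via negated floor
    b_pad = -((target_h - out_h) // -2)            # 106, so 105 + 106 = 211
    padded = torch.nn.functional.pad(img, (l_pad, r_pad, t_pad, b_pad))
    assert padded.shape == (3, target_h, target_w)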