@@ -21,7 +21,7 @@
 import torch._inductor.config

 try:
-    from _torchchat_test_script import flamingo_transform, padded_collate
+    from _torchchat_test_script import flamingo_transform
 except ImportError:
     pass

@@ -38,8 +38,9 @@
 from torchchat.utils.device_info import get_device_info

 # torchtune model definition dependencies
-from torchtune.data import Message
-from torchtune.generation._generation import sample as tune_sample
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.generation import sample as tune_sample
 from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.training import set_default_dtype

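Two import-level changes here: batch collation moves off the private test script (`padded_collate` from `_torchchat_test_script`) onto torchtune's own `padded_collate_tiled_images_and_mask`, and `sample` is now imported from the public `torchtune.generation` namespace rather than the private `torchtune.generation._generation` module. If older torchtune releases still need to be supported, a guarded import is one way to stay compatible (a hedged sketch; whether the private path exists in a given release is an assumption):

```python
try:
    # Public API in newer torchtune releases.
    from torchtune.generation import sample as tune_sample
except ImportError:
    # Assumed fallback for older releases that only shipped the private module.
    from torchtune.generation._generation import sample as tune_sample
```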
@@ -357,15 +358,25 @@ def prefill(

         if batch is not None:
             # TODO: Verify sequential prefill works with multimodal models
-            logits = model(**batch)[:, -1]
-            return tune_sample(logits, 0, 500)
+            tokens = batch["tokens"]
+            if 'encoder_input' in batch:
+                encoder_input = batch['encoder_input']
+            else:
+                encoder_input = None
+
+            seq_len = tokens.size(1)
+            mask = batch["causal_mask"][None, :seq_len]
+            encoder_mask = batch["encoder_mask"]
+            input_pos = input_pos.view(1, -1)
+            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+            return tune_sample(logits, temperature=0, top_k=500)
         elif sequential_prefill:
             for i in range(width):
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
                 # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
                 logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
         elif self.model.config.model_type == ModelType.Flamingo:
-            logits = model(x)
+            assert False, "Flamingo requires batch"
         else:
             # input_pos: [B, S]
             logits = model(x, input_pos)
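Multimodal prefill now unpacks the collated batch into explicit keyword arguments (tokens, the causal mask sliced to the prompt length, encoder input, and encoder mask) and samples the first token from the last position's logits via `tune_sample(logits, temperature=0, top_k=500)`, which makes the choice effectively greedy. For intuition, here is a self-contained sketch of temperature/top-k sampling in the same spirit; it is illustrative only, not torchtune's implementation:

```python
import torch

def sample_sketch(logits, temperature=1.0, top_k=None):
    """Illustrative temperature/top-k sampling (similar in spirit to
    tune_sample, but not the torchtune implementation)."""
    if temperature == 0:
        # Degenerate case: pick the most likely token outright.
        return logits.argmax(dim=-1, keepdim=True)
    logits = logits / temperature
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        # Mask out everything below the k-th largest logit.
        logits = logits.masked_fill(logits < v[..., -1:], float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# e.g. logits of shape [batch, vocab]:
tok = sample_sketch(torch.randn(1, 128), temperature=0, top_k=500)
```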
@@ -387,10 +398,10 @@ def decode_one_token(
         assert input_pos.shape[-1] == 1
         x = x.view(1, -1)
         if model.config.model_type == ModelType.Flamingo:
-            if batch is not None:
-                logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
-            else:
-                logits = model(x)
+            assert batch is not None, "Flamingo requires batch"
+            mask = batch["causal_mask"][None, input_pos.item(), None, :]
+            encoder_mask = batch["encoder_mask"][:, -1:]
+            logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
         else:
             logits = model(x, input_pos)
         # print(f"x: {x},\n  input_pos: {input_pos}\n")
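Single-token decode now slices one row out of the precomputed causal mask: `batch["causal_mask"][None, input_pos.item(), None, :]` yields a `[1, 1, total_len]` boolean mask saying which positions the current token may attend to. A small runnable illustration with hypothetical sizes:

```python
import torch

total_len = 8  # hypothetical prompt length + max_new_tokens
causal_mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))

input_pos = torch.tensor([5])                       # currently decoding position 5
row = causal_mask[None, input_pos.item(), None, :]  # shape [1, 1, 8]
print(row)  # True for columns 0..5 (visible history), False for 6..7 (the future)
```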
@@ -593,7 +604,8 @@ def generate(
                     self.is_torchtune_model
                     or self.model.config.model_type == ModelType.Flamingo
                 ):
-                    model.setup_caches(max_batch_size=1, dtype=self.dtype)
+                    # 6404 is the max encoder sequence length affordable on one GPU for a single image input
+                    model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
                 else:
                     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
                 if is_speculative and draft_model is not model:
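The cache setup now sizes the encoder and decoder caches independently: the decoder cache only needs to cover `T_new` (prompt plus new tokens), while the encoder cache is pinned at 6404. That figure is consistent with one image under Llama 3.2 Vision-style preprocessing, though the diff does not spell this out, so treat the arithmetic below as an inference (560x560 tiles, 14x14 patches, up to 4 tiles per image are all assumptions):

```python
# Assumed vision settings (Llama 3.2 Vision-style); not stated in the diff itself.
tile_size, patch_size, max_tiles = 560, 14, 4

patches_per_tile = (tile_size // patch_size) ** 2   # 40 * 40 = 1600
tokens_per_tile = patches_per_tile + 1              # + 1 CLS token = 1601
encoder_max_seq_len = tokens_per_tile * max_tiles   # 1601 * 4 = 6404
assert encoder_max_seq_len == 6404
```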
@@ -742,10 +754,22 @@ def chat(
             ]

             transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
-            data = transform({"messages": messages}, inference=True)
-            batch = padded_collate([data], self.builder_args.device)
-            batch.pop("mask")
-            encoded = batch["tokens"]
+
+            with torch.device(device=self.builder_args.device), set_default_dtype(self.dtype):
+                data = transform({"messages": messages}, inference=True)
+                batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
+                # set_default_dtype cannot handle the dtype of the image tensor inside the batch; cast it manually
+                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
+                seq_len = len(data["tokens"])
+                total_response_length = seq_len + generator_args.max_new_tokens
+                batch["causal_mask"] = torch.tril(
+                    torch.ones(
+                        size=(total_response_length, total_response_length),
+                        dtype=torch.bool,
+                    )
+                )
+                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+                encoded = batch["tokens"].view(-1)

         else:
             encoded = self.encode_tokens(
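The chat path now collates under a device and default-dtype context, so tensors created during collation land on the right device and dtype without per-tensor `.to()` calls. The one exception is the image tensor, which the transform has already materialized (typically as float32) before the dtype context can influence it, hence the explicit cast; the causal mask is also built once over `seq_len + max_new_tokens` so decode can index into it by position later. A minimal demonstration of the dtype behavior, using hypothetical shapes:

```python
import torch

images = torch.rand(1, 3, 560, 560)         # produced earlier, e.g. by the transform: float32
old = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    fresh = torch.ones(4)                   # new floating tensors pick up the bfloat16 default
    assert fresh.dtype == torch.bfloat16
    assert images.dtype == torch.float32    # pre-existing tensors keep their dtype...
    images = images.to(torch.bfloat16)      # ...hence the explicit cast in the diff
finally:
    torch.set_default_dtype(old)
```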