 import torch._inductor.config
 
 try:
-    from _torchchat_test_script import flamingo_transform, padded_collate
+    from _torchchat_test_script import flamingo_transform
 except ImportError:
     pass
 
@@ -38,8 +38,9 @@
 from torchchat.utils.device_info import get_device_info
 
 # torchtune model definition dependencies
-from torchtune.data import Message
-from torchtune.generation._generation import sample as tune_sample
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.generation import sample as tune_sample
 from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.training import set_default_dtype
 
@@ -357,15 +358,25 @@ def prefill(
 
         if batch is not None:
             # TODO: Verify sequential prefill works with multimodal models
-            logits = model(**batch)[:, -1]
-            return tune_sample(logits, 0, 500)
+            tokens = batch["tokens"]
+            if 'encoder_input' in batch:
+                encoder_input = batch['encoder_input']
+            else:
+                encoder_input = None
+
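+            # Slice the precomputed causal mask and position ids down to the prompt length for prefill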
+            seq_len = tokens.size(1)
+            mask = batch["causal_mask"][None, :seq_len]
+            input_pos = batch["input_pos"][None, :seq_len]
+            encoder_mask = batch["encoder_mask"]
+
+            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+            return tune_sample(logits, temperature=0, top_k=500)
         elif sequential_prefill:
             for i in range(width):
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
                 # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
                 logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
         elif self.model.config.model_type == ModelType.Flamingo:
-            logits = model(x)
+            assert False, "Flamingo requires batch"
         else:
             # input_pos: [B, S]
             logits = model(x, input_pos)
@@ -387,10 +398,10 @@ def decode_one_token(
         assert input_pos.shape[-1] == 1
         x = x.view(1, -1)
         if model.config.model_type == ModelType.Flamingo:
-            if batch is not None:
-                logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
-            else:
-                logits = model(x)
+            assert batch is not None, "Flamingo requires batch"
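+            # Select the causal-mask row for the current decode position; indexing yields shape [1, 1, total_response_length]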
+            mask = batch["causal_mask"][None, input_pos.item(), None, :]
+            encoder_mask = batch["encoder_mask"][:, -1:]
+            logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
         else:
             logits = model(x, input_pos)
         # print(f"x: {x},\n  input_pos: {input_pos}\n")
@@ -593,7 +604,7 @@ def generate(
                     self.is_torchtune_model
                     or self.model.config.model_type == ModelType.Flamingo
                 ):
-                    model.setup_caches(max_batch_size=1, dtype=self.dtype)
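+                    # torchtune caches take separate max lengths for the (vision) encoder and the decoder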
+                    model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
                 else:
                     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
                 if is_speculative and draft_model is not model:
@@ -743,8 +754,16 @@ def chat(
 
             transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
             data = transform({"messages": messages}, inference=True)
-            batch = padded_collate([data], self.builder_args.device)
-            batch.pop("mask")
+            batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
+            seq_len = len(data["tokens"])
+            total_response_length = seq_len + generator_args.max_new_tokens
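+            # Lower-triangular boolean mask covering the prompt plus every token that may be generated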
+            batch["causal_mask"] = torch.tril(
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
+            )
+            batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
             encoded = batch["tokens"]
 
         else: