@@ -359,16 +359,16 @@ def prefill(
         if batch is not None:
             # TODO: Verify sequential prefill works with multimodal models
             tokens = batch["tokens"]
-            if 'encoder_input' in tokens:
-                encoder_input = tokens['encoder_input']
+            if 'encoder_input' in batch:
+                encoder_input = batch['encoder_input']
             else:
                 encoder_input = None
-            
+
+            seq_len = tokens.size(1)
             mask = batch["causal_mask"][None, :seq_len]
-            input_pos = batch["input_pos"][None, :seq_len]
             encoder_mask = batch["encoder_mask"]
-
-            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_post, encoder_mask=encoder_mask)[:, -1]
+            input_pos = input_pos.view(1, -1)
+            logits = model(tokens=tokens, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
             return tune_sample(logits, temperature=0, top_k=500)
         elif sequential_prefill:
             for i in range(width):
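
The updated prefill branch above pulls encoder_input from the batch dict rather than the token tensor, derives seq_len from the token tensor, and reshapes the 1-D input_pos to a (1, seq_len) batch dimension before calling the model. A minimal, hedged sketch of that shape handling; the concrete sizes (seq_len=5, total_len=8) are illustrative assumptions, not values from this code:

import torch

seq_len = 5
total_len = 8                                          # assumed full causal-mask size
tokens = torch.zeros(1, seq_len, dtype=torch.long)     # (batch=1, seq_len)
input_pos = torch.arange(seq_len)                      # 1-D positions 0..seq_len-1
causal_mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))

mask = causal_mask[None, :seq_len]                     # -> (1, seq_len, total_len)
input_pos = input_pos.view(1, -1)                      # -> (1, seq_len), as in the diff
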
@@ -604,7 +604,7 @@ def generate(
                     self.is_torchtune_model
                     or self.model.config.model_type == ModelType.Flamingo
                 ):
-                    model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
+                    model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=max_seq_length - 1)
                 else:
                     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
                 if is_speculative and draft_model is not model:
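
For context on the cache setup above: in torchtune-style models, setup_caches allocates the decoder KV-cache with decoder_max_seq_len slots, so every position written during generation has to fit under that bound. The guard below is purely illustrative and not part of this repository; check_cache_fits, prompt_len, and max_new_tokens are assumed names:

def check_cache_fits(prompt_len: int, max_new_tokens: int, decoder_max_seq_len: int) -> None:
    # Last KV-cache slot written during decode is prompt_len + max_new_tokens - 1 (0-indexed).
    last_position = prompt_len + max_new_tokens - 1
    if last_position >= decoder_max_seq_len:
        raise ValueError(
            f"decoder cache too small: need {last_position + 1} slots, have {decoder_max_seq_len}"
        )
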
@@ -753,18 +753,19 @@ def chat(
             ]

             transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
-            data = transform({"messages": messages}, inference=True)
-            batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
-            seq_len = len(data["tokens"])
-            total_response_length = seq_len + generator_args.max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                                        torch.ones(
-                                            size=(total_response_length, total_response_length),
-                                            dtype=torch.bool,
+
+            with torch.device(device=self.builder_args.device):
+                data = transform({"messages": messages}, inference=True)
+                batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
+                seq_len = len(data["tokens"])
+                batch["causal_mask"] = torch.tril(
+                                            torch.ones(
+                                                size=(generator_args.max_new_tokens, generator_args.max_new_tokens),
+                                                dtype=torch.bool,
+                                            )
                                         )
-                                    )
-            batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-            encoded = batch["tokens"]
+                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+                encoded = batch["tokens"]

         else:
             encoded = self.encode_tokens(
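
The new chat path above builds the multimodal batch inside a torch.device context, so tensor factory calls made within it, such as the torch.ones used for the causal mask, allocate directly on the target device. A small sketch of that pattern (PyTorch 2.0+); the device selection below is an assumption for illustration:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed device choice
with torch.device(device):
    # Factory calls (torch.ones, torch.zeros, ...) inside the context allocate on `device`.
    causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
print(causal_mask.device)  # cuda:0 when CUDA is available, otherwise cpu
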