@@ -655,7 +655,9 @@ def generate(
         # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
         callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)
 
-        input_pos = torch.tensor([start_pos + prompt_length], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
         accept_counts = [0] * (
             speculate_k + 1
         )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -736,12 +738,6 @@ def _callback(self, x, *, buffer, done_generating):
             buffer.clear()
         # print(, end='', flush=True)
 
-    def print_m(self, message):
-        print(
-            message.role,
-            [t["type"] if t["type"] != "text" else t for t in message.content],
-        )
-
     def _gen_model_input(
         self,
         prompt: Union[str | List[Any]],
@@ -764,7 +760,7 @@ def _gen_model_input(
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
         """
 
-        # Not Llama 3.2 11B
+        # Text-Only model
         if self.model.config.model_type != ModelType.Flamingo:
             # Single String prompt
             if isinstance(prompt, str):
@@ -819,7 +815,7 @@ def _gen_model_input(
 
         is_multimodal = images is not None
         content = [{"type": "text", "content": prompt_arg}]
-
+        []
         if is_multimodal:
             content = [{"type": "image", "content": images[0]}] + content
 
@@ -830,18 +826,14 @@ def _gen_model_input(
             )
         )
 
-        print("MESSAGE CONTENTS:")
-        messages.append(Message(role="assistant", content=""))
-        [self.print_m(m) for m in messages]
-
         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
         device = torch.device(device=self.builder_args.device)
 
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
+            if image_found:
                 batch = padded_collate_tiled_images_and_mask(
                     [data], pad_direction="left", pad_max_images=1
                 )
@@ -851,6 +843,7 @@ def _gen_model_input(
                 batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                     self.dtype
                 )
+
             else:
                 encoded = torch.tensor(data["tokens"], device=device).view(-1)
                 seq_len = encoded.size(0)
@@ -883,13 +876,6 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
 
-        encoded, batch = self._gen_model_input(
-            generator_args.prompt,
-            generator_args.image_prompts,
-            generator_args.max_new_tokens,
-            generator_args.max_seq_length,
-        )
-
         model_size = sum(
             [
                 p.numel() * p.dtype.itemsize
@@ -935,6 +921,12 @@ def chat(
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+        encoded, batch = self._gen_model_input(
+            generator_args.prompt,
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )
 
         if generator_args.chat_mode:
             print(