@@ -734,13 +734,7 @@ def _callback(self, x, *, buffer, done_generating):
         if len(buffer) == 4 or done_generating:
             print("".join(buffer), end="", flush=True)
             buffer.clear()
-        # print(, end='', flush=True)
-
-    def print_m(self, message):
-        print(
-            message.role,
-            [t["type"] if t["type"] != "text" else t for t in message.content],
-        )
+        print("", end="", flush=True)
 
     def _gen_model_input(
         self,
@@ -764,7 +758,7 @@ def _gen_model_input(
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
         """
 
-        # Not Llama 3.2 11B
+        # Text-Only model
         if self.model.config.model_type != ModelType.Flamingo:
             # Single String prompt
             if isinstance(prompt, str):
@@ -819,7 +813,7 @@ def _gen_model_input(
 
         is_multimodal = images is not None
         content = [{"type": "text", "content": prompt_arg}]
-
+        []
         if is_multimodal:
             content = [{"type": "image", "content": images[0]}] + content
 
@@ -830,27 +824,24 @@ def _gen_model_input(
             )
         )
 
-        print("MESSAGE CONTENTS:")
-        messages.append(Message(role="assistant", content=""))
-        [self.print_m(m) for m in messages]
-
         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
         device = torch.device(device=self.builder_args.device)
 
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
+            if image_found:
                 batch = padded_collate_tiled_images_and_mask(
                     [data], pad_direction="left", pad_max_images=1
                 )
                 encoded = batch.pop("tokens").to(device).view(-1)
                 seq_len = encoded.size(0)
                 batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
                 batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                     self.dtype
                 )
+
             else:
                 encoded = torch.tensor(data["tokens"], device=device).view(-1)
                 seq_len = encoded.size(0)
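
For reference, a minimal standalone sketch of the buffered streaming pattern that the first hunk keeps in `_callback`: decoded pieces accumulate in a small buffer and are flushed to stdout every four pieces, or when generation finishes. The `stream_print` helper and its sample input are hypothetical stand-ins for the tokenizer-driven callback.

```python
# Hypothetical stand-in for the tokenizer-driven callback in generate.py.
def stream_print(decoded_pieces):
    buffer = []
    for i, piece in enumerate(decoded_pieces):
        done_generating = i == len(decoded_pieces) - 1
        buffer.append(piece)
        # Flush in chunks of four pieces (or at the end), mirroring _callback.
        if len(buffer) == 4 or done_generating:
            print("".join(buffer), end="", flush=True)
            buffer.clear()

stream_print(["Stre", "am", "ing", " wo", "rks", "\n"])
```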
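Likewise, a sketch of how the multimodal branch of `_gen_model_input` assembles its encoder inputs, mirroring the calls shown in the last hunk (`llama3_2_vision_transform`, then `padded_collate_tiled_images_and_mask` with `pad_direction="left"` and `pad_max_images=1`). The tokenizer path, image file, prompt text, and the torchtune import locations are assumptions for illustration; the empty assistant turn marks the generation point, as the pre-change code did with `messages.append(Message(role="assistant", content=""))`.

```python
from PIL import Image
from torchtune.data import Message, padded_collate_tiled_images_and_mask
from torchtune.models.llama3_2_vision import llama3_2_vision_transform

# Hypothetical inputs for illustration only.
TOKENIZER_PATH = "checkpoints/Llama-3.2-11B-Vision/tokenizer.model"
image = Image.open("dog.jpg")

# One user turn carrying an image plus text, as _gen_model_input builds it.
content = [
    {"type": "image", "content": image},
    {"type": "text", "content": "What is in this image?"},
]
messages = [
    Message(role="user", content=content),
    Message(role="assistant", content=""),  # generation point
]

transform = llama3_2_vision_transform(TOKENIZER_PATH)
data = transform({"messages": messages}, inference=True)

# Collate a single sample: left-padded, at most one image.
batch = padded_collate_tiled_images_and_mask(
    [data], pad_direction="left", pad_max_images=1
)
encoded = batch.pop("tokens").view(-1)
seq_len = encoded.size(0)
# Trim the encoder mask to the actual token count, as the diff does.
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
```

Left padding with a single-image cap keeps the prompt tokens flush against the generation position, so the trimmed `encoder_mask` lines up with `seq_len` exactly as in the `if image_found:` branch above.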