@@ -786,7 +786,7 @@ def _gen_model_input(
786786 if image_prompts and isinstance (image_prompts [0 ], str ):
787787 images = [Image .open (image_prompts [0 ])]
788788 else :
789- images = image_prompts
789+ images = None
790790
791791 assert (
792792 max_new_tokens is not None
@@ -796,7 +796,19 @@ def _gen_model_input(
796796 messages = []
797797 for message in prompt :
798798 if isinstance (message ["content" ], str ):
799- messages .append (Message (** message ))
799+ if not image_found and image_prompts :
800+ messages .append (
801+ Message (
802+ role = message ["role" ],
803+ content = [
804+ {"type" : "image" , "content" : images [0 ]},
805+ {"type" : "text" , "content" : message ["content" ]},
806+ ],
807+ )
808+ )
809+ image_found = True
810+ else :
811+ messages .append (Message (** message ))
800812
801813 elif isinstance (message ["content" ], list ):
802814 images = None
@@ -816,7 +828,7 @@ def _gen_model_input(
816828
817829 is_multimodal = images is not None
818830 content = [{"type" : "text" , "content" : prompt_arg }]
819-
831+
820832 if is_multimodal :
821833 content = [{"type" : "image" , "content" : images [0 ]}] + content
822834
@@ -826,6 +838,7 @@ def _gen_model_input(
826838 content = content ,
827839 )
828840 )
841+
829842 messages .append (
830843 Message (
831844 role = "assistant" ,
@@ -929,7 +942,7 @@ def chat(
929942 text_transformer_args .max_seq_length if text_transformer_args else 2048
930943 )
931944 encoded , batch = self ._gen_model_input (
932- generator_args .prompt ,
945+ [{ "role" : "user" , "content" : generator_args .prompt }] ,
933946 generator_args .image_prompts ,
934947 generator_args .max_new_tokens ,
935948 max_seq_length ,
@@ -945,16 +958,16 @@ def chat(
945958 if get_system_prompt == "y" or get_system_prompt == "Y" :
946959 self .system_prompt = input ("What is your system prompt? \n " )
947960
948- elif not generator_args .is_torchtune_model :
949- max_seq_length = min (
950- encoded .size (0 ) + generator_args .max_new_tokens ,
951- (
952- text_transformer_args .block_size
953- if text_transformer_args is not None
954- else 2048
955- ),
956- max_seq_length ,
957- )
961+ # elif not generator_args.is_torchtune_model:
962+ # max_seq_length = min(
963+ # encoded.size(0) + generator_args.max_new_tokens,
964+ # (
965+ # text_transformer_args.block_size
966+ # if text_transformer_args is not None
967+ # else 2048
968+ # ),
969+ # max_seq_length,
970+ # )
958971
959972 max_seq_length = (
960973 max_seq_length + self .speculative_builder_args .speculate_k + 1
0 commit comments