 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import base64
 import itertools
 import logging
 import os

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -101,7 +103,11 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
         tokens = self.tokenizer.encode(f"{B_INST} ")
         first_message = True  # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it.
         for message in dialog:
-            content = message["content"].strip()
+            if isinstance(message["content"], list):
+                content = message["content"][0]["text"]
+            else:
+                content = message["content"]
+            content = content.strip()
             if message["role"] == "system":
                 encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
                 first_message = False
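
The hunk above lets the Llama 2 chat formatter accept either a plain string or an OpenAI-style list of content parts. A minimal sketch of that normalization in isolation (the helper name and the non-list fallback are illustrative, not part of the patch):

# Sketch only: pull the text out of a message whose "content" may be a
# string or an OpenAI-style list of parts like [{"type": "text", "text": ...}].
def first_text_content(message: dict) -> str:
    content = message["content"]
    if isinstance(content, list):
        return content[0]["text"]
    return content
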
@@ -138,6 +144,7 @@ class GeneratorArgs:
     speculate_k: int = 5
     sequential_prefill: bool = False
     max_autotune: bool = False
+    # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
     is_torchtune_model: bool = False

     def __post_init__(self):
@@ -600,9 +607,8 @@ def generate(

         if len(prompt.shape) > 1:
             prompt = prompt.squeeze(0)
-        T = prompt.size(0)
-        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
-        T_new = T + max_new_tokens
+        prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
             model = model.to(device=device)
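
Renaming T to prompt_length makes the clamp above easier to read: the request can never push the sequence past the cache size. A worked example with made-up numbers:

# Illustrative only: max_seq_length=2048, start_pos=0, 1800-token prompt.
# A request for 512 new tokens is clamped to 2048 - 0 - 1800 = 248.
max_new_tokens = min(512, 2048 - 0 - 1800)  # -> 248
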
@@ -616,7 +622,7 @@ def generate(
                         batch_size=1,
                         dtype=self.dtype,
                         encoder_max_seq_len=6404,
-                        decoder_max_seq_len=T_new,
+                        decoder_max_seq_len=max_seq_length,
                     )
                 else:
                     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
@@ -629,7 +635,7 @@ def generate(
             model.reset_caches()

         input_pos = torch.arange(
-            start_pos, T + start_pos, device=device, dtype=torch.int
+            start_pos, prompt_length + start_pos, device=device, dtype=torch.int
         )

         prefill_t0 = time.perf_counter()
@@ -655,7 +661,9 @@ def generate(
         # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
         callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)

-        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
         accept_counts = [0] * (
             speculate_k + 1
         )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -678,7 +686,7 @@ def generate(
             )

             accept_counts[len(next_tokens) - 1] += 1
-            num_added = min(T_new - input_pos - 1, len(next_tokens))
+            num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
             for token in next_tokens[:num_added,]:
                 callback(token)
                 yield token, None
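
In the speculative path, the min() above caps how many of the accepted draft tokens are actually emitted. With made-up numbers, purely to show the arithmetic:

# Illustrative only: 5 draft tokens accepted, but the remaining budget
# allows just min(100 - 97 - 1, 5) = 2 of them to be yielded.
num_added = min(100 - 97 - 1, 5)  # -> 2
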
@@ -741,6 +749,7 @@ def _gen_model_input(
         prompt: Union[str | List[Any]],
         image_prompts: Optional[List[str | Image.Image]] = None,
         max_new_tokens: Optional[int] = None,
+        max_seq_len: Optional[int] = 2048,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
         """
         Convert prompt and image prompts into consumable model input args.
@@ -757,7 +766,7 @@ def _gen_model_input(
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
         """

-        # Not Llama 3.2 11B
+        # Text-Only model
         if self.model.config.model_type != ModelType.Flamingo:
             # Single String prompt
             if isinstance(prompt, str):
@@ -778,32 +787,69 @@ def _gen_model_input(
         assert (
             image_prompts is None or len(image_prompts) == 1
         ), "At most one image is supported at the moment"
+
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
-            images = image_prompts
+            images = None

         assert (
             max_new_tokens is not None
         ), "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(
-            prompt, str
-        ), "(Currently) prompt must be a str for Flamingo models"

-        is_multimodal = images is not None
-        content = [{"type": "text", "content": prompt}]
+        image_found = False
+        messages = []
+        for message in prompt:
+            if isinstance(message["content"], str):
+                if not image_found and image_prompts:
+                    messages.append(
+                        Message(
+                            role=message["role"],
+                            content=[
+                                {"type": "image", "content": images[0]},
+                                {"type": "text", "content": message["content"]},
+                            ],
+                        )
+                    )
+                    image_found = True
+                else:
+                    messages.append(Message(**message))
+
+            elif isinstance(message["content"], list):
+                images = None
+                for content_dict in message["content"]:
+                    if content_dict["type"] == "text":
+                        prompt_arg = content_dict["text"]
+                    elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        images = [Image.open(BytesIO(base64_decoded))]
+                        image_found = True
+
+                is_multimodal = images is not None
+                content = [{"type": "text", "content": prompt_arg}]

-        if is_multimodal:
-            content = [{"type": "image", "content": images[0]}] + content
+                if is_multimodal:
+                    content = [{"type": "image", "content": images[0]}] + content

-        messages = [
+                messages.append(
+                    Message(
+                        role=message["role"],
+                        content=content,
+                    )
+                )
+
+        messages.append(
             Message(
-                role="user",
-                content=content,
-                eot=True,
-            ),
-            Message(role="assistant", content=""),
-        ]
+                role="assistant",
+                content="",
+            )
+        )

         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))

@@ -812,7 +858,7 @@ def _gen_model_input(
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)

-            if is_multimodal:
+            if image_found:
                 batch = padded_collate_tiled_images_and_mask(
                     [data], pad_direction="left", pad_max_images=1
                 )
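
The multimodal branch now takes images either from --image-prompts paths or from an OpenAI-style image_url carrying a base64 data URL. A self-contained sketch of that decoding step, with a throwaway image generated on the spot rather than a hard-coded payload:

import base64
from io import BytesIO

from PIL import Image

# Build a disposable data URL from a 1x1 PNG so the example is runnable.
buf = BytesIO()
Image.new("RGB", (1, 1)).save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

# Mirrors the patched _gen_model_input: strip the prefix, decode, open.
raw = base64.b64decode(data_url.split(";base64,")[1])
image = Image.open(BytesIO(raw))
print(image.size)  # (1, 1)
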
@@ -822,17 +868,27 @@ def _gen_model_input(
                 batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                     self.dtype
                 )
+
             else:
                 encoded = torch.tensor(data["tokens"], device=device).view(-1)
                 seq_len = encoded.size(0)
                 batch = {}

             total_response_length = seq_len + max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
+            batch["causal_mask"] = torch.nn.functional.pad(
+                torch.tril(
+                    torch.ones(
+                        size=(total_response_length, total_response_length),
+                        dtype=torch.bool,
+                    )
+                ),
+                (
+                    0,
+                    max_seq_len - total_response_length,
+                    0,
+                    max_seq_len - total_response_length,
+                ),
+                value=0,
             )

         logging.debug(encoded)
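
The causal mask is now padded out to max_seq_len so later decode steps can index into a fixed-size mask instead of one sized to this request. A small sketch of the same padding with toy sizes (the numbers are illustrative, not values used by torchchat):

import torch
import torch.nn.functional as F

total_response_length = 6  # e.g. prompt tokens + max_new_tokens
max_seq_len = 8            # fixed mask size the decoder indexes into

mask = torch.tril(
    torch.ones(total_response_length, total_response_length, dtype=torch.bool)
)
pad = max_seq_len - total_response_length
# For a 2-D tensor, F.pad takes (last-dim left, last-dim right, first-dim top, first-dim bottom).
causal_mask = F.pad(mask, (0, pad, 0, pad), value=0)
print(causal_mask.shape)  # torch.Size([8, 8])
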
@@ -845,12 +901,6 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")

-        encoded, batch = self._gen_model_input(
-            generator_args.prompt,
-            generator_args.image_prompts,
-            generator_args.max_new_tokens,
-        )
-
         model_size = sum(
             [
                 p.numel() * p.dtype.itemsize
@@ -896,6 +946,12 @@ def chat(
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+        encoded, batch = self._gen_model_input(
+            [{"role": "user", "content": generator_args.prompt}],
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )

         if generator_args.chat_mode:
             print(
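
chat() now hands _gen_model_input a list of message dicts instead of a bare string. The two prompt shapes the hunks above handle look roughly like this (contents and payload are placeholders):

# Plain text turn:
text_only = [{"role": "user", "content": "What is in this image?"}]

# OpenAI-style turn with an inline base64 image (payload elided):
with_image = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": "data:image/png;base64,..."},
        ],
    }
]
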
@@ -907,7 +963,10 @@ def chat(
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")

-        elif not generator_args.is_torchtune_model:
+        # `is_torchtune_model` is a misnomer since it doesn't capture all
+        # torchtune models (i.e. Flamingo)
+        # See Issue: https://github.com/pytorch/torchchat/issues/1273
+        elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo:
             max_seq_length = min(
                 encoded.size(0) + generator_args.max_new_tokens,
                 (