@@ -733,66 +733,80 @@ def _callback(self, x, *, buffer, done_generating):
         buffer.clear()
         # print(, end='', flush=True)
 
-    def _gen_model_input(self, prompt: str, image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
+    def _gen_model_input(self, prompt: Union[str, List[Any]], image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
+
+        # Not Llama 3.2 11B
+        if self.model.config.model_type != ModelType.Flamingo:
+            # Single string prompt
+            if isinstance(prompt, str):
+                encoded = self.encode_tokens(
+                    prompt, bos=True, device=self.builder_args.device
+                )
+            # List of dialog messages
+            else:
+                tokens = self.chat_formatter.encode_dialog_prompt(prompt)
+                encoded = torch.tensor(
+                    tokens, dtype=torch.int, device=self.builder_args.device
+                )
+
+            logging.debug(encoded)
+            return encoded, None
+
+        # Llama 3.2 11B
         assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
             images = image_prompts
 
-        if self.model.config.model_type == ModelType.Flamingo:
-            assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
+        assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
+        assert isinstance(prompt, str), "(Currently) prompt must be a str for Flamingo models"
 
-            is_multimodal = images is not None
-            content = [{"type": "text", "content": prompt}]
+        is_multimodal = images is not None
+        content = [{"type": "text", "content": prompt}]
 
-            if is_multimodal:
-                content = [{"type": "image", "content": images[0]}] + content
+        if is_multimodal:
+            content = [{"type": "image", "content": images[0]}] + content
 
-            messages = [
-                Message(
-                    role="user",
-                    content=content,
-                    eot=True,
-                ),
-                Message(role="assistant", content=""),
-            ]
+        messages = [
+            Message(
+                role="user",
+                content=content,
+                eot=True,
+            ),
+            Message(role="assistant", content=""),
+        ]
 
-            transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
+        transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
-            device = torch.device(device=self.builder_args.device)
+        device = torch.device(device=self.builder_args.device)
 
-            with device, set_default_dtype(self.dtype):
-                data = transform({"messages": messages}, inference=True)
+        with device, set_default_dtype(self.dtype):
+            data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
-                batch = padded_collate_tiled_images_and_mask(
-                    [data], pad_direction="left", pad_max_images=1
-                )
-                encoded = batch.pop("tokens").to(device).view(-1)
-                seq_len = encoded.size(0)
-                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
-            else:
-                encoded = torch.tensor(
-                    data["tokens"], device=device
-                ).view(-1)
-                seq_len = encoded.size(0)
-                batch = {}
-
-            total_response_length = seq_len + max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
+        if is_multimodal:
+            batch = padded_collate_tiled_images_and_mask(
+                [data], pad_direction="left", pad_max_images=1
+            )
+            encoded = batch.pop("tokens").to(device).view(-1)
+            seq_len = encoded.size(0)
+            batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+            batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
+        else:
+            encoded = torch.tensor(
+                data["tokens"], device=device
+            ).view(-1)
+            seq_len = encoded.size(0)
+            batch = {}
+
+        total_response_length = seq_len + max_new_tokens
+        batch["causal_mask"] = torch.tril(
+            torch.ones(
+                size=(total_response_length, total_response_length),
+                dtype=torch.bool,
             )
-        else:
-            encoded = self.encode_tokens(
-                prompt, bos=True, device=self.builder_args.device
-            )
-            batch = None
-
+        )
+
         logging.debug(encoded)
         return encoded, batch
 
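A quick sketch of what the new early-return branch does for text-only models: a plain string goes through `encode_tokens`, while a list is treated as a chat dialog and run through `chat_formatter.encode_dialog_prompt`. The snippet below is a self-contained illustration of that dispatch, not torchchat's API; the stand-in tokenizers and the `{"role": ..., "content": ...}` dialog shape are assumptions for the example.

```python
from typing import Any, List, Tuple, Union

# Stand-in tokenizers; the real code calls self.encode_tokens and
# self.chat_formatter.encode_dialog_prompt on the Generator instead.
def encode_str(prompt: str) -> List[int]:
    return [1] + [ord(c) % 97 for c in prompt]  # leading 1 stands in for BOS

def encode_dialog(dialog: List[Any]) -> List[int]:
    tokens: List[int] = []
    for msg in dialog:  # assumed {"role": ..., "content": ...} message shape
        tokens.extend(encode_str(f'{msg["role"]}: {msg["content"]}'))
    return tokens

def gen_model_input(prompt: Union[str, List[Any]]) -> Tuple[List[int], None]:
    # Mirrors the non-Flamingo branch: str -> direct encode, list -> dialog encode.
    if isinstance(prompt, str):
        encoded = encode_str(prompt)
    else:
        encoded = encode_dialog(prompt)
    return encoded, None  # text-only models return no extra batch inputs

print(gen_model_input("What is a llama?"))
print(gen_model_input([{"role": "user", "content": "What is a llama?"}]))
```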
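The mask built at the end of the Flamingo path is just a lower-triangular boolean matrix sized to the prompt plus the full decode budget, so position i can attend only to positions 0..i. A minimal runnable illustration, with toy sizes standing in for `seq_len` and `max_new_tokens`:

```python
import torch

seq_len = 4          # stand-in for encoded.size(0)
max_new_tokens = 2   # stand-in for the decode budget
total_response_length = seq_len + max_new_tokens

# Same construction as batch["causal_mask"] above: row i is True
# at columns <= i, i.e., each position sees itself and earlier ones.
causal_mask = torch.tril(
    torch.ones(
        size=(total_response_length, total_response_length),
        dtype=torch.bool,
    )
)
print(causal_mask.int())  # 6x6 lower-triangular 0/1 matrix
```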