@@ -732,67 +732,106 @@ def _callback(self, x, *, buffer, done_generating):
             print("".join(buffer), end="", flush=True)
             buffer.clear()
         # print(, end='', flush=True)
-
-    def _gen_model_input(self, prompt: str, image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
-        assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
+
+    def _gen_model_input(
+        self,
+        prompt: Union[str, List[Any]],
+        image_prompts: Optional[List[str | Image.Image]] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
+        """
+        Convert prompt and image prompts into consumable model input args.
+
+        When prompt is a list, the anticipated format is OpenAI API-inspired:
+            [..., {"role": message["role"], "content": message["content"]}, ...]
+
+        Args:
+            prompt (Union[str, List[Any]]): Prompt string or list of dialog messages.
+            image_prompts (Optional[List[str | Image.Image]]): List of image prompts. Used only with Llama 3.2 11B.
+            max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Used only with Llama 3.2 11B.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
+        """
+
+        # Not Llama 3.2 11B
+        if self.model.config.model_type != ModelType.Flamingo:
+            # Single String prompt
+            if isinstance(prompt, str):
+                encoded = self.encode_tokens(
+                    prompt, bos=True, device=self.builder_args.device
+                )
+            # List of dialog
+            else:
+                tokens = self.chat_formatter.encode_dialog_prompt(prompt)
+                encoded = torch.tensor(
+                    tokens, dtype=torch.int, device=self.builder_args.device
+                )
+
+            logging.debug(encoded)
+            return encoded, None
+
+        # Llama 3.2 11B
+        assert (
+            image_prompts is None or len(image_prompts) == 1
+        ), "At most one image is supported at the moment"
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
             images = image_prompts
 
-        if self.model.config.model_type == ModelType.Flamingo:
-            assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
+        assert (
+            max_new_tokens is not None
+        ), "max_new_tokens must be specified for Flamingo models"
+        assert isinstance(
+            prompt, str
+        ), "(Currently) prompt must be a str for Flamingo models"
 
-            is_multimodal = images is not None
-            content = [{"type": "text", "content": prompt}]
+        is_multimodal = images is not None
+        content = [{"type": "text", "content": prompt}]
 
-            if is_multimodal:
-                content = [{"type": "image", "content": images[0]}] + content
+        if is_multimodal:
+            content = [{"type": "image", "content": images[0]}] + content
 
-            messages = [
-                Message(
-                    role="user",
-                    content=content,
-                    eot=True,
-                ),
-                Message(role="assistant", content=""),
-            ]
+        messages = [
+            Message(
+                role="user",
+                content=content,
+                eot=True,
+            ),
+            Message(role="assistant", content=""),
+        ]
 
-            transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
+        transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
-            device = torch.device(device=self.builder_args.device)
+        device = torch.device(device=self.builder_args.device)
 
-            with device, set_default_dtype(self.dtype):
-                data = transform({"messages": messages}, inference=True)
+        with device, set_default_dtype(self.dtype):
+            data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
-                batch = padded_collate_tiled_images_and_mask(
-                    [data], pad_direction="left", pad_max_images=1
-                )
-                encoded = batch.pop("tokens").to(device).view(-1)
-                seq_len = encoded.size(0)
-                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
-            else:
-                encoded = torch.tensor(
-                    data["tokens"], device=device
-                ).view(-1)
-                seq_len = encoded.size(0)
-                batch = {}
-
-            total_response_length = seq_len + max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
-            )
-        else:
-            encoded = self.encode_tokens(
-                prompt, bos=True, device=self.builder_args.device
+        if is_multimodal:
+            batch = padded_collate_tiled_images_and_mask(
+                [data], pad_direction="left", pad_max_images=1
+            )
+            encoded = batch.pop("tokens").to(device).view(-1)
+            seq_len = encoded.size(0)
+            batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+            batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
+                self.dtype
+            )
+        else:
+            encoded = torch.tensor(data["tokens"], device=device).view(-1)
+            seq_len = encoded.size(0)
+            batch = {}
+
+        total_response_length = seq_len + max_new_tokens
+        batch["causal_mask"] = torch.tril(
+            torch.ones(
+                size=(total_response_length, total_response_length),
+                dtype=torch.bool,
+            )
         )
-            batch = None
-
+
         logging.debug(encoded)
         return encoded, batch
 
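Editor's note: the new docstring describes the list form of `prompt` only tersely. A minimal illustration of the OpenAI API-inspired dialog shape it refers to (the messages here are hypothetical, not from this commit):

# Hypothetical dialog in the format _gen_model_input accepts when
# `prompt` is a list rather than a plain string.
dialog = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Name the capital of France."},
]
# In the non-Flamingo branch above, such a list is passed to
# self.chat_formatter.encode_dialog_prompt(dialog) and the resulting
# token ids are wrapped in a torch tensor.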
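A second note on the Flamingo path: the causal mask is prebuilt to cover the full response length (encoded prompt length plus `max_new_tokens`). Below is a standalone sketch of the same construction with made-up sizes; only `torch` is assumed, and `seq_len`/`max_new_tokens` are illustrative stand-ins for the values computed in `_gen_model_input`:

import torch

# Made-up sizes for illustration; in _gen_model_input, seq_len is the
# encoded prompt length and max_new_tokens comes from the caller.
seq_len = 4
max_new_tokens = 3
total_response_length = seq_len + max_new_tokens

# Lower-triangular boolean mask: position i may attend to positions
# 0..i only, across both the prompt and the yet-to-be-generated tokens.
causal_mask = torch.tril(
    torch.ones(
        size=(total_response_length, total_response_length),
        dtype=torch.bool,
    )
)

print(causal_mask.shape)  # torch.Size([7, 7])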