@@ -732,8 +732,27 @@ def _callback(self, x, *, buffer, done_generating):
732732 print ("" .join (buffer ), end = "" , flush = True )
733733 buffer .clear ()
734734 # print(, end='', flush=True)
735-
736- def _gen_model_input (self , prompt : Union [str | List [Any ]], image_prompts : Optional [List [str | Image .Image ]] = None , max_new_tokens : Optional [int ] = None ) -> Tuple :
735+
736+ def _gen_model_input (
737+ self ,
738+ prompt : Union [str | List [Any ]],
739+ image_prompts : Optional [List [str | Image .Image ]] = None ,
740+ max_new_tokens : Optional [int ] = None ,
741+ ) -> Tuple [torch .Tensor , Optional [Dict [str , Any ]]]:
742+ """
743+ Convert prompt and image prompts into consumable model input args.
744+
745+ When prompt is a list, the anticipated format is OpenAI API Inspired:
746+ [ ..., {"role": message["role"], "content": message["content"]}, ...]
747+
748+ Args:
749+ prompt (Union[str, List[Any]]): Prompt or list of dialog.
750+ image_prompts (Optional[List[str | Image.Image]]): List of image prompts. Used only with Llama 3.2 11B.
751+ max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Used only with Llama 3.2 11B.
752+
753+ Returns:
754+ Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
755+ """
737756
738757 # Not Llama 3.2 11B
739758 if self .model .config .model_type != ModelType .Flamingo :
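Aside (not part of the diff): the docstring above says a list-form prompt follows an OpenAI-API-inspired dialog format. A minimal sketch of such a list, with hypothetical roles and message contents chosen only for illustration:

    # Hypothetical dialog passed as `prompt` when it is a list rather than a str.
    dialog = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Explain what a causal attention mask does."},
    ]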
@@ -753,14 +772,20 @@ def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Option
             return encoded, None

         # Llama 3.2 11B
-        assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
+        assert (
+            image_prompts is None or len(image_prompts) == 1
+        ), "At most one image is supported at the moment"
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
             images = image_prompts

-        assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(prompt, str), "(Currently) prompt must be a str for Flamingo models"
+        assert (
+            max_new_tokens is not None
+        ), "max_new_tokens must be specified for Flamingo models"
+        assert isinstance(
+            prompt, str
+        ), "(Currently) prompt must be a str for Flamingo models"

         is_multimodal = images is not None
         content = [{"type": "text", "content": prompt}]
@@ -791,21 +816,21 @@ def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Option
                 encoded = batch.pop("tokens").to(device).view(-1)
                 seq_len = encoded.size(0)
                 batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
+                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
+                    self.dtype
+                )
             else:
-                encoded = torch.tensor(
-                    data["tokens"], device=device
-                ).view(-1)
+                encoded = torch.tensor(data["tokens"], device=device).view(-1)
                 seq_len = encoded.size(0)
                 batch = {}

             total_response_length = seq_len + max_new_tokens
             batch["causal_mask"] = torch.tril(
-                    torch.ones(
-                        size=(total_response_length, total_response_length),
-                        dtype=torch.bool,
-                    )
-                )
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
+            )

         logging.debug(encoded)
         return encoded, batch
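Aside (not part of the diff): the mask built at the end of this hunk is a lower-triangular boolean matrix sized to the encoded prompt length plus the generation budget, so position i can attend only to positions 0..i. A standalone sketch with illustrative sizes, assuming only stock PyTorch:

    import torch

    # Illustrative sizes only: an 8-token encoded prompt and up to 4 new tokens.
    seq_len = 8
    max_new_tokens = 4
    total_response_length = seq_len + max_new_tokens

    # Lower-triangular boolean mask: entries above the diagonal are masked out.
    causal_mask = torch.tril(
        torch.ones(
            size=(total_response_length, total_response_length),
            dtype=torch.bool,
        )
    )

    assert causal_mask.shape == (12, 12)
    assert bool(causal_mask[3, 0]) and not bool(causal_mask[0, 3])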