 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import base64
 import itertools
 import logging
 import os
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -600,9 +602,8 @@ def generate(
 
         if len(prompt.shape) > 1:
             prompt = prompt.squeeze(0)
-        T = prompt.size(0)
-        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
-        T_new = T + max_new_tokens
+        prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
             model = model.to(device=device)
@@ -616,7 +617,7 @@ def generate(
                     batch_size=1,
                     dtype=self.dtype,
                     encoder_max_seq_len=6404,
-                    decoder_max_seq_len=T_new,
+                    decoder_max_seq_len=max_seq_length,
                 )
             else:
                 model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
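Note on this hunk: the decoder cache is now sized to the model's full `max_seq_length` instead of `T_new` (this request's prompt plus new tokens), so it is allocated once and can serve later turns in the same session. A rough sketch of the difference, with illustrative numbers (none of these values come from the source):

```python
# Illustrative numbers only; the names mirror the diff above.
prompt_length, max_new_tokens, max_seq_length = 32, 100, 2048

# Before: cache sized to this one request; a longer follow-up turn in
# the same chat session would not fit without rebuilding the cache.
decoder_cache_old = prompt_length + max_new_tokens  # 132 positions

# After: cache sized once to the full context window.
decoder_cache_new = max_seq_length  # 2048 positions
```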
@@ -629,7 +630,7 @@ def generate(
             model.reset_caches()
 
         input_pos = torch.arange(
-            start_pos, T + start_pos, device=device, dtype=torch.int
+            start_pos, prompt_length + start_pos, device=device, dtype=torch.int
         )
 
         prefill_t0 = time.perf_counter()
@@ -655,7 +656,9 @@ def generate(
         # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
         callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)
 
-        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
         accept_counts = [0] * (
             speculate_k + 1
         )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -678,7 +681,7 @@ def generate(
                 )
 
                 accept_counts[len(next_tokens) - 1] += 1
-                num_added = min(T_new - input_pos - 1, len(next_tokens))
+                num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
                 for token in next_tokens[:num_added,]:
                     callback(token)
                     yield token, None
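The clamp above now bounds the emitted tokens by the remaining new-token budget rather than the total sequence length `T_new`. A toy example of the arithmetic (values are illustrative, and the reading of `input_pos` as the loop's position counter is an assumption):

```python
# Hypothetical values: a speculative round verified three tokens, but
# only one slot of the generation budget remains.
max_new_tokens = 10            # budget for this generate() call
input_pos = 8                  # assumed: the loop's position counter
next_tokens = [101, 102, 103]  # tokens accepted this round

num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
assert num_added == 1          # only the first token is emitted
```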
@@ -741,6 +744,7 @@ def _gen_model_input(
         prompt: Union[str | List[Any]],
         image_prompts: Optional[List[str | Image.Image]] = None,
         max_new_tokens: Optional[int] = None,
+        max_seq_len: Optional[int] = 2048,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
         """
         Convert prompt and image prompts into consumable model input args.
@@ -757,7 +761,7 @@ def _gen_model_input(
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
         """
 
-        # Not Llama 3.2 11B
+        # Text-Only model
         if self.model.config.model_type != ModelType.Flamingo:
             # Single String prompt
             if isinstance(prompt, str):
@@ -778,32 +782,69 @@ def _gen_model_input(
         assert (
             image_prompts is None or len(image_prompts) == 1
         ), "At most one image is supported at the moment"
+
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
-            images = image_prompts
+            images = None
 
         assert (
             max_new_tokens is not None
         ), "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(
-            prompt, str
-        ), "(Currently) prompt must be a str for Flamingo models"
 
-        is_multimodal = images is not None
-        content = [{"type": "text", "content": prompt}]
+        image_found = False
+        messages = []
+        for message in prompt:
+            if isinstance(message["content"], str):
+                if not image_found and image_prompts:
+                    messages.append(
+                        Message(
+                            role=message["role"],
+                            content=[
+                                {"type": "image", "content": images[0]},
+                                {"type": "text", "content": message["content"]},
+                            ],
+                        )
+                    )
+                    image_found = True
+                else:
+                    messages.append(Message(**message))
+
+            elif isinstance(message["content"], list):
+                images = None
+                for content_dict in message["content"]:
+                    if content_dict["type"] == "text":
+                        prompt_arg = content_dict["text"]
+                    elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        images = [Image.open(BytesIO(base64_decoded))]
+                        image_found = True
+
+                is_multimodal = images is not None
+                content = [{"type": "text", "content": prompt_arg}]
+
+                if is_multimodal:
+                    content = [{"type": "image", "content": images[0]}] + content
 
-        if is_multimodal:
-            content = [{"type": "image", "content": images[0]}] + content
+                messages.append(
+                    Message(
+                        role=message["role"],
+                        content=content,
+                    )
+                )
 
-        messages = [
+        messages.append(
             Message(
-                role="user",
-                content=content,
-                eot=True,
-            ),
-            Message(role="assistant", content=""),
-        ]
+                role="assistant",
+                content="",
+            )
+        )
 
         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
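For context, the rewritten branch accepts the Flamingo prompt as a list of chat-style dicts instead of a bare string. Two content shapes are handled: a plain string (any `image_prompts` file is attached to the first such turn) and an OpenAI-style content list whose `image_url` entry carries a base64 data URL. Hypothetical inputs for each case, with keys matching what the loop reads:

```python
# 1) Plain-text turn; an image passed via image_prompts is attached
#    to the first string-content message.
prompt = [{"role": "user", "content": "What is in this image?"}]

# 2) OpenAI-style turn with the image inlined as a base64 data URL
#    ("<BASE64>" is a placeholder for the encoded bytes).
prompt = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": "data:image/jpeg;base64,<BASE64>"},
        ],
    }
]
```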
@@ -812,7 +853,7 @@ def _gen_model_input(
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
+            if image_found:
                 batch = padded_collate_tiled_images_and_mask(
                     [data], pad_direction="left", pad_max_images=1
                 )
@@ -822,17 +863,27 @@ def _gen_model_input(
                 batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                     self.dtype
                 )
+
             else:
                 encoded = torch.tensor(data["tokens"], device=device).view(-1)
                 seq_len = encoded.size(0)
                 batch = {}
 
                 total_response_length = seq_len + max_new_tokens
-                batch["causal_mask"] = torch.tril(
-                    torch.ones(
-                        size=(total_response_length, total_response_length),
-                        dtype=torch.bool,
-                    )
+                batch["causal_mask"] = torch.nn.functional.pad(
+                    torch.tril(
+                        torch.ones(
+                            size=(total_response_length, total_response_length),
+                            dtype=torch.bool,
+                        )
+                    ),
+                    (
+                        0,
+                        max_seq_len - total_response_length,
+                        0,
+                        max_seq_len - total_response_length,
+                    ),
+                    value=0,
                 )
 
         logging.debug(encoded)
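The padding above grows the lower-triangular mask to a static `max_seq_len` × `max_seq_len`, so its shape no longer depends on the prompt length; `torch.nn.functional.pad` reads the 4-tuple as (last-dim left, last-dim right, second-to-last top, second-to-last bottom). A small self-contained check of the same construction:

```python
import torch
import torch.nn.functional as F

total_response_length, max_seq_len = 3, 5  # tiny illustrative sizes

mask = torch.tril(
    torch.ones(total_response_length, total_response_length, dtype=torch.bool)
)
# Pad only on the right and bottom; padded cells become False (0).
pad = max_seq_len - total_response_length
padded = F.pad(mask, (0, pad, 0, pad), value=0)

assert padded.shape == (max_seq_len, max_seq_len)
assert not padded[:, total_response_length:].any()  # right block is all False
```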
@@ -845,12 +896,6 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
 
-        encoded, batch = self._gen_model_input(
-            generator_args.prompt,
-            generator_args.image_prompts,
-            generator_args.max_new_tokens,
-        )
-
         model_size = sum(
             [
                 p.numel() * p.dtype.itemsize
@@ -896,6 +941,12 @@ def chat(
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+        encoded, batch = self._gen_model_input(
+            [{"role": "user", "content": generator_args.prompt}],
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )
 
         if generator_args.chat_mode:
             print(
@@ -907,16 +958,16 @@ def chat(
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")
 
-        elif not generator_args.is_torchtune_model:
-            max_seq_length = min(
-                encoded.size(0) + generator_args.max_new_tokens,
-                (
-                    text_transformer_args.block_size
-                    if text_transformer_args is not None
-                    else 2048
-                ),
-                max_seq_length,
-            )
+        # elif not generator_args.is_torchtune_model:
+        #     max_seq_length = min(
+        #         encoded.size(0) + generator_args.max_new_tokens,
+        #         (
+        #             text_transformer_args.block_size
+        #             if text_transformer_args is not None
+        #             else 2048
+        #         ),
+        #         max_seq_length,
+        #     )
 
         max_seq_length = (
             max_seq_length + self.speculative_builder_args.speculate_k + 1
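The final adjustment reserves headroom for speculative decoding: a round can accept up to `speculate_k` draft tokens plus one bonus token (matching the `speculate_k + 1` sizing of `accept_counts` earlier in the diff), so the cache length is bumped accordingly. Worked with illustrative numbers:

```python
speculate_k = 5        # illustrative draft length
max_seq_length = 2048  # illustrative context window

# Reserve room so a full speculative round (k drafts + 1 bonus token)
# never writes past the end of the KV cache.
max_seq_length = max_seq_length + speculate_k + 1  # -> 2054
```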