44# This source code is licensed under the license found in the
55# LICENSE file in the root directory of this source tree.
66import argparse
7+ import base64
78import itertools
89import logging
910import os
1213
1314from abc import ABC , abstractmethod
1415from dataclasses import dataclass
16+ from io import BytesIO
1517from os import PathLike
1618from pathlib import Path
1719from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
@@ -733,6 +735,9 @@ def _callback(self, x, *, buffer, done_generating):
733735 buffer .clear ()
734736 # print(, end='', flush=True)
735737
def print_m(self, message):
    """Print a message's role followed by a compact summary of its content.

    Args:
        message: Object exposing ``role`` and ``content`` attributes.
            ``content`` is either a list of dicts that each carry a
            ``"type"`` key, or a plain string.

    For list content, entries whose ``"type"`` is ``"text"`` are printed in
    full; every other entry is reduced to its ``"type"`` string. For string
    content, the string itself is shown as the single summary entry.
    """
    content = message.content
    if isinstance(content, str):
        # Iterating a str yields characters, and indexing a character with
        # "type" raises TypeError. An empty string summarizes to an empty
        # list, which is what iterating it would have produced.
        summary = [content] if content else []
    else:
        summary = [
            part if part["type"] == "text" else part["type"] for part in content
        ]
    print(message.role, summary)
740+
736741 def _gen_model_input (
737742 self ,
738743 prompt : Union [str | List [Any ]],
@@ -775,6 +780,7 @@ def _gen_model_input(
775780 assert (
776781 image_prompts is None or len (image_prompts ) == 1
777782 ), "At most one image is supported at the moment"
783+
778784 if image_prompts and isinstance (image_prompts [0 ], str ):
779785 images = [Image .open (image_prompts [0 ])]
780786 else :
@@ -783,24 +789,45 @@ def _gen_model_input(
783789 assert (
784790 max_new_tokens is not None
785791 ), "max_new_tokens must be specified for Flamingo models"
786- assert isinstance (
787- prompt , str
788- ), "(Currently) prompt must be a str for Flamingo models"
789792
790- is_multimodal = images is not None
791- content = [{"type" : "text" , "content" : prompt }]
793+ image_found = False
794+ messages = []
795+ for message in prompt :
796+ if isinstance (message ["content" ], str ):
797+ messages .append (Message (** message ))
798+
799+ elif isinstance (message ["content" ], list ):
800+ images = None
801+ for content_dict in message ["content" ]:
802+ if content_dict ["type" ] == "text" :
803+ prompt_arg = content_dict ["text" ]
804+ elif content_dict ["type" ] == "image_url" :
805+ assert (
806+ images is None
807+ ), "At most one image is supported at the moment"
808+
809+ base64_decoded = base64 .b64decode (
810+ content_dict ["image_url" ].split (";base64," )[1 ]
811+ )
812+ images = [Image .open (BytesIO (base64_decoded ))]
813+ image_found = True
792814
793- if is_multimodal :
794- content = [{"type" : "image " , "content" : images [ 0 ]}] + content
815+ is_multimodal = images is not None
816+ content = [{"type" : "text " , "content" : prompt_arg }]
795817
796- messages = [
797- Message (
798- role = "user" ,
799- content = content ,
800- eot = True ,
801- ),
802- Message (role = "assistant" , content = "" ),
803- ]
818+ if is_multimodal :
819+ content = [{"type" : "image" , "content" : images [0 ]}] + content
820+
821+ messages .append (
822+ Message (
823+ role = "user" ,
824+ content = content ,
825+ )
826+ )
827+
828+ print ("MESSAGE CONTENTS:" )
829+ messages .append (Message (role = "assistant" , content = "" ))
830+ [self .print_m (m ) for m in messages ]
804831
805832 transform = llama3_2_vision_transform (str (self .tokenizer_args .tokenizer_path ))
806833
0 commit comments