@@ -902,20 +902,26 @@ def _encode_prompt_with_clip(
902902 tokenizer ,
903903 prompt : str ,
904904 device = None ,
905+ text_input_ids = None ,
905906 num_images_per_prompt : int = 1 ,
906907):
907908 prompt = [prompt ] if isinstance (prompt , str ) else prompt
908909 batch_size = len (prompt )
909910
910- text_inputs = tokenizer (
911- prompt ,
912- padding = "max_length" ,
913- max_length = 77 ,
914- truncation = True ,
915- return_tensors = "pt" ,
916- )
911+ if tokenizer is not None :
912+ text_inputs = tokenizer (
913+ prompt ,
914+ padding = "max_length" ,
915+ max_length = 77 ,
916+ truncation = True ,
917+ return_tensors = "pt" ,
918+ )
919+
920+ text_input_ids = text_inputs .input_ids
921+ else :
922+ if text_input_ids is None :
923+ raise ValueError ("text_input_ids must be provided when the tokenizer is not specified" )
917924
918- text_input_ids = text_inputs .input_ids
919925 prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = True )
920926
921927 pooled_prompt_embeds = prompt_embeds [0 ]
@@ -937,6 +943,7 @@ def encode_prompt(
937943 max_sequence_length ,
938944 device = None ,
939945 num_images_per_prompt : int = 1 ,
946+ text_input_ids_list = None ,
940947):
941948 prompt = [prompt ] if isinstance (prompt , str ) else prompt
942949
@@ -945,13 +952,14 @@ def encode_prompt(
945952
946953 clip_prompt_embeds_list = []
947954 clip_pooled_prompt_embeds_list = []
948- for tokenizer , text_encoder in zip (clip_tokenizers , clip_text_encoders ):
955+ for i , ( tokenizer , text_encoder ) in enumerate ( zip (clip_tokenizers , clip_text_encoders ) ):
949956 prompt_embeds , pooled_prompt_embeds = _encode_prompt_with_clip (
950957 text_encoder = text_encoder ,
951958 tokenizer = tokenizer ,
952959 prompt = prompt ,
953960 device = device if device is not None else text_encoder .device ,
954961 num_images_per_prompt = num_images_per_prompt ,
962+ text_input_ids = text_input_ids_list [i ] if text_input_ids_list else None ,
955963 )
956964 clip_prompt_embeds_list .append (prompt_embeds )
957965 clip_pooled_prompt_embeds_list .append (pooled_prompt_embeds )
0 commit comments