@@ -321,9 +321,19 @@ def check_inputs(
321321 width ,
322322 prompt_embeds = None ,
323323 negative_prompt_embeds = None ,
324+ image_embeds = None ,
324325 callback_on_step_end_tensor_inputs = None ,
325326 ):
326- if not isinstance (image , torch .Tensor ) and not isinstance (image , PIL .Image .Image ):
327+ if image is not None and image_embeds is not None :
328+ raise ValueError (
329+ f"Cannot forward both `image`: { image } and `image_embeds`: { image_embeds } . Please make sure to"
330+ " only forward one of the two."
331+ )
332+ if image is None and image_embeds is None :
333+ raise ValueError (
334+ "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
335+ )
336+ if image is not None and not isinstance (image , torch .Tensor ) and not isinstance (image , PIL .Image .Image ):
327337 raise ValueError ("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" { type (image )} " )
328338 if height % 16 != 0 or width % 16 != 0 :
329339 raise ValueError (f"`height` and `width` have to be divisible by 16 but are { height } and { width } ." )
@@ -463,6 +473,7 @@ def __call__(
463473 latents : Optional [torch .Tensor ] = None ,
464474 prompt_embeds : Optional [torch .Tensor ] = None ,
465475 negative_prompt_embeds : Optional [torch .Tensor ] = None ,
476+ image_embeds : Optional [torch .Tensor ] = None ,
466477 output_type : Optional [str ] = "np" ,
467478 return_dict : bool = True ,
468479 attention_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -512,6 +523,12 @@ def __call__(
512523 prompt_embeds (`torch.Tensor`, *optional*):
513524 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
514525 provided, text embeddings are generated from the `prompt` input argument.
526+ negative_prompt_embeds (`torch.Tensor`, *optional*):
527+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
528+ provided, text embeddings are generated from the `negative_prompt` input argument.
529+ image_embeds (`torch.Tensor`, *optional*):
530+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
531+ image embeddings are generated from the `image` input argument.
515532 output_type (`str`, *optional*, defaults to `"pil"`):
516533 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
517534 return_dict (`bool`, *optional*, defaults to `True`):
@@ -556,6 +573,7 @@ def __call__(
556573 width ,
557574 prompt_embeds ,
558575 negative_prompt_embeds ,
576+ image_embeds ,
559577 callback_on_step_end_tensor_inputs ,
560578 )
561579
@@ -599,7 +617,8 @@ def __call__(
599617 if negative_prompt_embeds is not None :
600618 negative_prompt_embeds = negative_prompt_embeds .to (transformer_dtype )
601619
602- image_embeds = self .encode_image (image , device )
620+ if image_embeds is None :
621+ image_embeds = self .encode_image (image , device )
603622 image_embeds = image_embeds .repeat (batch_size , 1 , 1 )
604623 image_embeds = image_embeds .to (transformer_dtype )
605624
0 commit comments