 # limitations under the License.
 
 import inspect
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union, Any
 
 import numpy as np
 import torch
@@ -43,7 +43,7 @@
     Examples:
         ```python
         >>> import torch
-        >>> from diffusers import CogView4Pipeline
+        >>> from diffusers import CogView4ControlPipeline
 
         >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16)
         >>> control_image = load_image(
@@ -420,6 +420,14 @@ def do_classifier_free_guidance(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def interrupt(self):
         return self._interrupt
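
The two properties added above expose read-only views of private state that `__call__` manages below (`_current_timestep` is set at the top of each denoising iteration and reset to `None` after the loop). A plausible consumer, not part of this commit, is a step-end callback; a minimal sketch, assuming a `pipe` and `control_image` set up as in the docstring example:

```python
# Sketch: inspect the new read-only properties from a step-end callback.
# The (pipe, step_index, timestep, callback_kwargs) signature is the standard
# diffusers callback convention; the logging body is an illustrative assumption.
def log_step(pipe, step_index, timestep, callback_kwargs):
    print(f"step {step_index}: current_timestep={pipe.current_timestep}, "
          f"attention_kwargs={pipe.attention_kwargs}")
    return callback_kwargs

image = pipe(
    prompt="a photo of a cat",
    control_image=control_image,
    callback_on_step_end=log_step,
).images[0]
```
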
@@ -446,6 +454,7 @@ def __call__(
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         output_type: str = "pil",
         return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
@@ -559,6 +568,8 @@ def __call__(
             negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
         self._interrupt = False
 
         # Default call parameters
@@ -652,6 +663,8 @@ def __call__(
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
+
+                self._current_timestep = t
                 latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -664,6 +677,7 @@ def __call__(
                     original_size=original_size,
                     target_size=target_size,
                     crop_coords=crops_coords_top_left,
+                    attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]
 
@@ -676,6 +690,7 @@ def __call__(
                     original_size=original_size,
                     target_size=target_size,
                     crop_coords=crops_coords_top_left,
+                    attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]
 
@@ -700,6 +715,8 @@ def __call__(
             if XLA_AVAILABLE:
                 xm.mark_step()
 
+        self._current_timestep = None
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
             image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
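
Taken together, the hunks thread `attention_kwargs` from the public `__call__` signature into both transformer forward passes (conditional and unconditional) on every step. A minimal end-to-end sketch; the `{"scale": 0.8}` payload is an assumption (it mirrors the LoRA-scale convention other diffusers pipelines use for such kwargs), and the image URL is illustrative:

```python
import torch
from diffusers import CogView4ControlPipeline
from diffusers.utils import load_image

pipe = CogView4ControlPipeline.from_pretrained(
    "THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

control_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)

# attention_kwargs is stored on the pipeline and forwarded verbatim to the
# transformer each step; {"scale": 0.8} is an assumed payload, not confirmed
# by this commit.
image = pipe(
    prompt="a portrait in the style of Vermeer",
    control_image=control_image,
    attention_kwargs={"scale": 0.8},
).images[0]
image.save("output.png")
```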